## **Libraries**

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [3]:
cqt_dir="/content/drive/MyDrive/Automatic Guitar Transcription/Data Set/GuitarSet/cqt"
json_dir="/content/drive/MyDrive/Automatic Guitar Transcription/Data Set/GuitarSet/merge_annotation"

## **Dataset**

In [18]:
class FretNetDataset(Dataset):
    def __init__(self, file_list, cqt_dir, label_dir, normalize=True):
        self.file_list = file_list
        self.cqt_dir = cqt_dir
        self.label_dir = label_dir
        self.normalize = normalize

    def __len__(self):
        return len(self.file_list)

    def align_json_to_cqt(self, label_data, target_len):
        """
        JSON etiketlerini CQT uzunluğuna göre hizalar.
        Eksikse 'deviation' sıfırla doldurulur.
        'onsets' içinde 3 boyutlu yapı varsa düzleştirir.
        """
        aligned = {}

        for key in ['tablature', 'onsets', 'deviation']:
            arr = label_data.get(key)

            # Eksik 'deviation' için sıfırla doldur
            if arr is None and key == 'deviation':
                arr = [[0.0] * target_len for _ in range(6)]

            if arr is None:
                raise KeyError(f"Etiket dosyasında '{key}' bulunamadı.")

            # Eğer 3 boyutlu (nested list) varsa düzleştir
            if isinstance(arr[0][0], list):
                arr = [[x[0] if isinstance(x, list) else x for x in string_arr] for string_arr in arr]

            # Trim ve pad
            aligned_arr = []
            for string_arr in arr:
                string_arr = string_arr[:target_len]
                if len(string_arr) < target_len:
                    last_val = string_arr[-1] if string_arr else 0
                    pad = [last_val] * (target_len - len(string_arr))
                    string_arr += pad
                aligned_arr.append(string_arr)

            aligned[key] = aligned_arr

        return aligned

    def __getitem__(self, idx):
        filename = self.file_list[idx]

        # CQT yükle
        cqt_path = os.path.join(self.cqt_dir, f"{filename}_cqt.npy")
        cqt = np.load(cqt_path)  # [6, F, T]
        cqt = torch.tensor(cqt, dtype=torch.float32)  # [6, F, T]

        if self.normalize:
            cqt = torch.clamp(cqt, min=-80.0, max=0.0)
            cqt = (cqt + 80.0) / 80.0  # Normalize to [0, 1]

        time_dim = cqt.shape[2]  # T

        # Etiketleri yükle
        label_path = os.path.join(self.label_dir, f"{filename}_fretnet.json")
        with open(label_path, 'r') as f:
            label_data = json.load(f)

        # JSON'u CQT frame sayısına hizala
        aligned_data = self.align_json_to_cqt(label_data, target_len=time_dim)

        tablature = torch.tensor(aligned_data['tablature'], dtype=torch.long)       # [6, T]
        onsets    = torch.tensor(aligned_data['onsets'], dtype=torch.float32)       # [6, T]
        deviation = torch.tensor(aligned_data['deviation'], dtype=torch.float32)    # [6, T]

        return cqt, tablature, onsets, deviation

In [19]:
import os

cqt_dir = "/content/drive/MyDrive/Automatic Guitar Transcription/Data Set/GuitarSet/cqt"
label_dir = "/content/drive/MyDrive/Automatic Guitar Transcription/Data Set/GuitarSet/merge_annotation"

# Tüm _cqt.npy dosyalarının isimlerini al
file_list = sorted([
    f.replace("_cqt.npy", "") for f in os.listdir(cqt_dir)
    if f.endswith("_cqt.npy")
])

# Opsiyonel: %80 eğitim - %20 validation ayır
from sklearn.model_selection import train_test_split
train_files, val_files = train_test_split(file_list, test_size=0.2, random_state=42)

In [20]:
def collate_fn_fretnet(batch):
    cqt_list, tab_list, onset_list, dev_list = zip(*batch)

    min_F = min([x.shape[1] for x in cqt_list])  # F: frekans ekseni
    min_T = min([x.shape[2] for x in cqt_list])  # T: zaman ekseni

    def trim_cqt(x):      # [6, F, T] → [6, min_F, min_T]
        return x[:, :min_F, :min_T]

    def trim_label(x):    # [6, T] → [6, min_T]
        return x[:, :min_T]

    cqt_batch = torch.stack([trim_cqt(x) for x in cqt_list])         # [B, 6, F, T]
    tab_batch = torch.stack([trim_label(x) for x in tab_list])       # [B, 6, T]
    onset_batch = torch.stack([trim_label(x) for x in onset_list])   # [B, 6, T]
    dev_batch = torch.stack([trim_label(x) for x in dev_list])       # [B, 6, T]

    return cqt_batch, tab_batch, onset_batch, dev_batch

In [21]:
from torch.utils.data import DataLoader

train_dataset = FretNetDataset(
    file_list=train_files,
    cqt_dir=cqt_dir,
    label_dir=label_dir,
    normalize=True
)

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn_fretnet
)

val_dataset = FretNetDataset(
    file_list=val_files,
    cqt_dir=cqt_dir,
    label_dir=label_dir,
    normalize=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn_fretnet
)

## **Model**

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleFretNet(nn.Module):
    def __init__(self, num_strings=6, num_frets=19):
        super().__init__()
        self.num_strings = num_strings
        self.num_frets = num_frets

        # CNN Backbone
        self.conv1 = nn.Conv2d(6, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.dropout2 = nn.Dropout2d(0.5)

        self.conv3 = nn.Conv2d(32, 48, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(48)
        self.dropout3 = nn.Dropout2d(0.25)

        self.pool = nn.MaxPool2d(kernel_size=(2, 1))  # sadece frekans ekseninde havuzlama

        # Prediction heads (tab, deviation, onset)
        self.tab_head = nn.Conv2d(48, num_strings * (num_frets + 1), kernel_size=1)
        self.dev_head = nn.Conv2d(48, num_strings, kernel_size=1)
        self.onset_head = nn.Conv2d(48, num_strings, kernel_size=1)

    def forward(self, x):  # x: [B, 6, F, T]
        x = F.relu(self.bn1(self.conv1(x)))      # [B, 16, F, T]

        x = F.relu(self.bn2(self.conv2(x)))      # [B, 32, F, T]
        x = self.pool(x)                         # [B, 32, F//2, T]
        x = self.dropout2(x)

        x = F.relu(self.bn3(self.conv3(x)))      # [B, 48, F//2, T]
        x = self.pool(x)                         # [B, 48, F//4, T]
        x = self.dropout3(x)

        # Head outputs
        tab = self.tab_head(x)       # [B, 6×(F+1), F', T]
        dev = self.dev_head(x)       # [B, 6, F', T]
        onset = self.onset_head(x)   # [B, 6, F', T]

        # Frekans boyutunu ortalama al → [B, *, T]
        tab = tab.mean(dim=2)        # [B, 6×(F+1), T]
        dev = dev.mean(dim=2)        # [B, 6, T]
        onset = onset.mean(dim=2)    # [B, 6, T]

        return tab, onset, dev

In [23]:
import numpy as np

example_cqt = np.load("/content/drive/MyDrive/Automatic Guitar Transcription/Data Set/GuitarSet/cqt/00_BN1-129-Eb_comp_mix_cqt.npy")
print(example_cqt.shape)

(6, 144, 962)


In [24]:
for fname in train_files[:5]:
    cqt = np.load(os.path.join(cqt_dir, f"{fname}_cqt.npy"))
    with open(os.path.join(label_dir, f"{fname}_fretnet.json")) as f:
        data = json.load(f)
    print(fname)
    print("CQT shape:", cqt.shape)  # (6, 144, T)
    print("tablature shape:", np.array(data["tablature"]).shape)  # (6, T')
    print("---")


04_SS1-100-C#_solo_mix
CQT shape: (6, 144, 1241)
tablature shape: (6, 1350)
---
05_Rock1-130-A_comp_mix
CQT shape: (6, 144, 954)
tablature shape: (6, 1035)
---
03_SS1-68-E_solo_mix
CQT shape: (6, 144, 1824)
tablature shape: (6, 1980)
---
01_Funk1-97-C_solo_mix
CQT shape: (6, 144, 1279)
tablature shape: (6, 1386)
---
03_Funk1-97-C_comp_mix
CQT shape: (6, 144, 1279)
tablature shape: (6, 1386)
---


## **Train**

In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleFretNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# BCE Loss
bce_loss = nn.BCEWithLogitsLoss(reduction='none')  # manual masking için

# Pitch deviation için alternatif (Bernoulli yerine MSE)
mse_loss = nn.MSELoss(reduction='none')

# Loss ağırlıkları (makaledeki gibi)
λ_inh = 10.0
γ_total = 10.0

In [26]:
batch = next(iter(train_loader))
cqt, tablature, onsets, deviation = batch
tab_out, onset_out, dev_out = model(cqt.to(device))

print("tab_out shape:", tab_out.shape)         # ✅ [B, 6*(F+1), T]
print("onset_out shape:", onset_out.shape)     # ✅ [B, 6, T]
print("deviation_out shape:", dev_out.shape)   # ✅ [B, 6, T]
print("tablature label shape:", tablature.shape)  # ✅ [B, 6, T]

tab_out shape: torch.Size([16, 120, 954])
onset_out shape: torch.Size([16, 6, 954])
deviation_out shape: torch.Size([16, 6, 954])
tablature label shape: torch.Size([16, 6, 954])


In [36]:
import torch.nn as nn

def compute_fretnet_loss(tab_logits, onset_logits, dev_preds,
                         tab_labels, onset_labels, dev_labels, mask,
                         λ_inh=10.0, γ_total=10.0,
                         bce_loss=None, mse_loss=None):
    """
    Hesaplanan loss fonksiyonu:
    - tab_logits: [B, 6*(F+1), 1, T] veya [B, 6*(F+1), T]
    - tab_labels: [B, 6, T] (class index, -1 için ignore)
    - onset_logits, dev_preds: [B, 6, 1, T] veya [B, 6, T]
    - onset_labels, dev_labels, mask: [B, 6, T]

    Args:
        bce_loss: BCEWithLogitsLoss objesi (reduction='none' önerilir)
        mse_loss: MSELoss objesi (reduction='none' önerilir)
    """

    # Squeeze 3. boyutu varsa kaldır
    if tab_logits.dim() == 4:
        tab_logits = tab_logits.squeeze(2)  # [B, 6*(F+1), T]
    if onset_logits.dim() == 4:
        onset_logits = onset_logits.squeeze(2)  # [B, 6, T]
    if dev_preds.dim() == 4:
        dev_preds = dev_preds.squeeze(2)  # [B, 6, T]

    B, _, T = onset_logits.shape
    F_plus1 = tab_logits.shape[1] // 6

    # Tab loss (CrossEntropy)
    tab_logits_reshaped = tab_logits.view(B, 6, F_plus1, T).permute(0, 1, 3, 2)  # [B, 6, T, F+1]
    tab_gt = tab_labels  # [B, 6, T]
    ce_loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

    L_tab = ce_loss_fn(
        tab_logits_reshaped.reshape(-1, F_plus1),  # [B*6*T, F+1]
        tab_gt.reshape(-1)                         # [B*6*T]
    )

    # Onset loss (BCE)
    if bce_loss is None:
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')
    L_onset = bce_loss(onset_logits, onset_labels)
    L_onset = (L_onset * mask).sum() / mask.sum()

    # Deviation loss (MSE)
    if mse_loss is None:
        mse_loss = nn.MSELoss(reduction='none')
    L_dev = mse_loss(dev_preds, dev_labels)
    L_dev = (L_dev * mask).sum() / mask.sum()

    # Inhibition loss (ceza)
    pred_classes = tab_logits_reshaped.argmax(-1).float()  # [B, 6, T]
    inhibition_penalty = (pred_classes > 0).float()
    L_inh = ((inhibition_penalty.sum(dim=1) > 1).float().sum()) / B

    # Toplam loss
    L_total = (1 / γ_total) * (L_tab + λ_inh * L_inh + L_onset) + L_dev

    return L_total


In [37]:
from tqdm import tqdm

def train_one_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0.0
    num_batches = 0

    # tqdm ile ilerleme çubuğu ekle
    loop = tqdm(dataloader, desc="Training", leave=False)

    for batch in loop:
        cqt, tablature, onsets, deviation = batch
        cqt = cqt.to(device)
        tablature = tablature.to(device)
        onsets = onsets.to(device)
        deviation = deviation.to(device)

        tab_logits, onset_logits, dev_preds = model(cqt)
        mask = (tablature != -1).float()

        loss = compute_fretnet_loss(
            tab_logits=tab_logits,
            onset_logits=onset_logits,
            dev_preds=dev_preds,
            tab_labels=tablature,
            onset_labels=onsets,
            dev_labels=deviation,
            mask=mask,
            λ_inh=10.0,
            γ_total=10.0,
            bce_loss=nn.BCEWithLogitsLoss(reduction='none'),
            mse_loss=nn.MSELoss(reduction='none')
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    return avg_loss

In [38]:
for epoch in range(10):
    avg_loss = train_one_epoch(model, train_loader, optimizer, device)
    print(f"Epoch {epoch+1}/10 - Loss: {avg_loss:.4f}")



Epoch 1/10 - Loss: 769.1353




Epoch 2/10 - Loss: 793.3719




Epoch 3/10 - Loss: 749.6189




Epoch 4/10 - Loss: 782.7628




Epoch 5/10 - Loss: 784.8541




Epoch 6/10 - Loss: 775.0661




Epoch 7/10 - Loss: 779.0017




Epoch 8/10 - Loss: 793.2157




Epoch 9/10 - Loss: 782.8243


                                                                   

Epoch 10/10 - Loss: 781.5977


