In [2]:
import torch
import pickle
from model.EffectDecoder import EffectDecoder
from transformers import ASTModel
from sklearn.metrics import accuracy_score, roc_auc_score
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class EffectClassifier(torch.nn.Module):
    def __init__(self, n_classes,embed_dim=768):
        super(EffectClassifier, self).__init__()
        self.pretrained = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.embed = torch.nn.Linear(embed_dim,embed_dim)
        self.cls = torch.nn.Linear(embed_dim, n_classes)
        self.relu = torch.nn.ReLU()
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.pretrained(**x).pooler_output
        x = self.relu(self.embed(x))
        x = self.cls(x)
        return self.softmax(x)

In [5]:
with open("data/guitar_sample_dataset_multiclass.pkl", "rb") as f:
    dataset = pickle.load(f)

In [15]:
dataset[0]

{'dry_tone_path': 'data/Train_submission/Train_submission\\1-E1-Major 00.wav',
 'wet_tone_path': 'data/wet_tones/1-E1-Major 00_wet_0.wav',
 'wet_tone_features': {'input_values': tensor([[[-0.9958, -1.2776, -0.9260,  ..., -1.2501, -1.2776, -1.2776],
          [-0.9575, -1.1735, -0.7967,  ..., -1.2776, -1.2776, -1.2776],
          [-0.7767, -1.0278, -0.6510,  ..., -1.2776, -1.2776, -1.2776],
          ...,
          [ 0.4670,  0.4670,  0.4670,  ...,  0.4670,  0.4670,  0.4670],
          [ 0.4670,  0.4670,  0.4670,  ...,  0.4670,  0.4670,  0.4670],
          [ 0.4670,  0.4670,  0.4670,  ...,  0.4670,  0.4670,  0.4670]]],
        device='cuda:0')},
 'effect_names': ['Reverb'],
 'effects': tensor([[0., 1., 0., 0., 0.]]),
 'parameters': tensor([[0.6732, 0.4756, 0.9093, 0.2037, 0.4671, 0.7824, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.00

In [6]:
train_data, test_data = train_test_split(dataset, test_size=0.2)

In [28]:
def eval(model, loss_fn, dl):
    model.eval()
    total_loss = 0
    labels = []
    preds = []
    for batch in tqdm.tqdm(dl):
        features = batch['wet_tone_features'].to(device)
        label = batch['effects'].to(device)
        with torch.no_grad():
            output = model(features)
        loss = loss_fn(output, label)
        total_loss += loss.item()
        preds.append(torch.argmax(output, dim=-1).cpu().numpy())
        labels.append(torch.argmax(label).cpu().numpy())
    print(f"Accuracy:{accuracy_score(labels, preds)} | Total Loss:{total_loss}")
    return

In [29]:
def train(model, optimizer, loss_fn, train_loader,test_loader, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm.tqdm(train_loader):
            optimizer.zero_grad()
            features = batch['wet_tone_features'].to(device)
            labels = batch['effects'].to(device)
            output = model(features)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss}")
        eval(model, loss_fn, test_loader)
    return

In [32]:
model = EffectClassifier(5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=.0000005)
loss_fn = torch.nn.CrossEntropyLoss()

In [33]:
train(model, optimizer, loss_fn, train_data, test_data, epochs=5)

100%|██████████| 400/400 [01:34<00:00,  4.23it/s]


Epoch 1, Loss: 627.8390402793884


100%|██████████| 100/100 [00:06<00:00, 15.92it/s]


Accuracy:0.56 | Total Loss:152.8575165271759


100%|██████████| 400/400 [01:34<00:00,  4.21it/s]


Epoch 2, Loss: 582.9692931175232


100%|██████████| 100/100 [00:06<00:00, 15.92it/s]


Accuracy:0.67 | Total Loss:140.95802009105682


100%|██████████| 400/400 [01:32<00:00,  4.34it/s]


Epoch 3, Loss: 537.9732059836388


100%|██████████| 100/100 [00:06<00:00, 15.64it/s]


Accuracy:0.7 | Total Loss:132.32380890846252


100%|██████████| 400/400 [01:41<00:00,  3.95it/s]


Epoch 4, Loss: 509.97031432390213


100%|██████████| 100/100 [00:06<00:00, 15.88it/s]


Accuracy:0.73 | Total Loss:127.62350732088089


100%|██████████| 400/400 [01:29<00:00,  4.47it/s]


Epoch 5, Loss: 492.1641817688942


100%|██████████| 100/100 [00:05<00:00, 16.82it/s]

Accuracy:0.74 | Total Loss:124.5223405957222



