# MFCC Model

In [None]:
from models import MFCCModel
import torch

In [None]:
model = MFCCModel()

In [None]:
torch.save(model, "saved_models/mfcc_model.ckpt")

# LSTM Model

In [None]:
from models import LSTMModel
from data_utils import get_dataloader_from_split

import wandb
from tqdm import tqdm

import torch
import torch.nn.functional as F

In [None]:
train_dataloader = get_dataloader_from_split("splits/embedding_train_split.txt")

In [None]:
test_dataloader = get_dataloader_from_split("splits/embedding_test_split.txt")

In [None]:
wandb.init(project="genre-classification")

In [None]:
def train_one_epoch(model, train_dataloader, optimizer):
    """Run one training pass over ``train_dataloader`` and log metrics to wandb.

    For every batch this does a forward pass, a cross-entropy loss, a backward
    pass and one optimizer step, then logs the batch loss as ``train_loss``.
    After the pass it logs the fraction of samples whose arg-max prediction
    matched the label as ``train_accuracy``.

    Args:
        model: network being trained; called as ``model(audios)``. Its output
            is fed to ``F.cross_entropy`` and arg-maxed over the last axis, so
            it is presumably ``(batch, num_classes)`` logits — TODO confirm.
        train_dataloader: iterable yielding ``(audios, labels)`` batches, where
            ``labels`` are class indices compared against the arg-max.
        optimizer: torch optimizer over ``model``'s parameters.
    """
    model.train()
    total, correct = 0, 0
    for audios, labels in tqdm(train_dataloader):
        optimizer.zero_grad()

        preds = model(audios)
        total += labels.shape[0]
        # .item() keeps the running count a plain Python int, so the logged
        # accuracy is a float rather than a 0-dim tensor.
        correct += (torch.argmax(preds, dim=-1) == labels).sum().item()

        loss = F.cross_entropy(preds, labels)
        loss.backward()
        optimizer.step()

        wandb.log({"train_loss": loss.item()})

    # An empty loader would otherwise make the accuracy 0/0.
    if total:
        wandb.log({"train_accuracy": correct / total})

In [None]:
def test_one_epoch(model, test_dataloader):
    """Evaluate ``model`` on ``test_dataloader`` and log metrics to wandb.

    Logs ``test_loss`` as the mean cross-entropy per sample and
    ``test_accuracy`` as the fraction of samples whose arg-max prediction
    equals the label. Nothing is logged if the loader yields no samples.

    Args:
        model: network to evaluate; called as ``model(audios)`` and switched
            to eval mode here.
        test_dataloader: iterable yielding ``(audios, labels)`` batches, where
            ``labels`` are class indices.
    """
    model.eval()
    total, correct = 0, 0
    total_loss = 0.0
    # No parameters are updated here, so don't build autograd graphs.
    with torch.no_grad():
        for audios, labels in tqdm(test_dataloader):
            preds = model(audios)
            total += labels.shape[0]
            correct += (torch.argmax(preds, dim=-1) == labels).sum().item()

            # Sum (not the default batch mean) so that dividing by the sample
            # count below yields the true per-sample mean loss, including
            # when the last batch is smaller.
            loss = F.cross_entropy(preds, labels, reduction="sum")
            total_loss += loss.item()

    # An empty loader would otherwise make both metrics 0/0.
    if total == 0:
        return
    wandb.log({"test_loss": total_loss / total})
    wandb.log({"test_accuracy": correct / total})

In [None]:
model = LSTMModel(last=False)
optimizer = torch.optim.Adam(model.parameters())

In [None]:
NUM_EPOCHS = 10

In [None]:
for epoch in range(NUM_EPOCHS):
    print(epoch)
    train_one_epoch(model, train_dataloader, optimizer)
    test_one_epoch(model, test_dataloader)

In [None]:
torch.save(model, "saved_models/lstm_model.ckpt")