In [None]:
import os
import random

import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision import transforms

from ProgDataset import SnippetProgDataset

In [None]:
# Ask PyTorch's CUDA caching allocator to hand back cached GPU memory it is not
# currently using (per the PyTorch docs), so a re-run of the notebook starts
# from as clean a GPU state as possible.
torch.cuda.empty_cache()

In [None]:
def getName(s):
    """Return the file-name component of the path ``s``.

    Uses ``os.path.basename`` so the result is correct with the native path
    separator of whichever platform the notebook runs on (the previous
    implementation only split on a backslash).
    """
    return os.path.basename(s)


def splitSongs(split = 0.8, path = "", seed = 42):
    """Collect every song file, label it and split it into train/validation lists.

    Files found under ``Progressive_Rock_Songs`` get label 1 and files found
    under ``Not_Progressive_Rock`` get label 0. Both folders are looked up
    inside ``path`` (or relative to the working directory when ``path`` is "").

    Args:
        split: Fraction of the songs placed in the first (training) list.
        path: Directory containing the two song folders; "" means the current
            working directory.
        seed: Seed for the shuffle applied before splitting. Pass ``None`` for
            a different shuffle on every call.

    Returns:
        ``(train_songs, validate_songs)``; each is a list of
        ``(file_path, file_name, label)`` tuples.
    """
    # os.path.join(path, name) also covers path == "" (it returns just `name`).
    prog_folder = os.path.join(path, "Progressive_Rock_Songs")
    non_prog_folder = os.path.join(path, "Not_Progressive_Rock")

    songs = [(i, getName(i), 1) for i in librosa.util.find_files(prog_folder)]
    songs += [(i, getName(i), 0) for i in librosa.util.find_files(non_prog_folder)]

    # The list is "all prog, then all non-prog". Slicing it unshuffled would
    # leave the validation tail made up of (almost) only one class, so shuffle
    # with a dedicated, seeded RNG to mix the classes reproducibly.
    random.Random(seed).shuffle(songs)

    split_idx = int(split * len(songs))
    return songs[:split_idx], songs[split_idx:]

train_songs, validate_songs = splitSongs()

In [None]:
# Single-step preprocessing pipeline: every snippet goes through ToTensor.
transform = transforms.Compose([transforms.ToTensor()])

# Build the training and validation datasets (in that order) from their
# respective song lists, sharing the same transform.
train_dataset, validate_dataset = (
    SnippetProgDataset(song_list, transform = transform)
    for song_list in (train_songs, validate_songs)
)

In [None]:
# Prefer the GPU whenever PyTorch reports one as available; otherwise use the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"The current device used is {device}")

In [None]:
# Training hyper-parameters, gathered in one cell so they are easy to tune.
BATCH_SIZE = 32        # samples per batch for both data loaders
NUM_EPOCHS = 3         # number of passes over the training loader
LEARNING_RATE = 1e-2   # learning rate handed to the optimizer

In [None]:
# Batch iterators over the two datasets; both draw their samples in shuffled order.
train_loader = DataLoader(
    train_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
)
validation_loader = DataLoader(
    validate_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
)

In [None]:
class ProgCNN(nn.Module):
    """Small CNN that maps a single snippet image to a probability in [0, 1].

    Architecture: two Conv2d(3x3) -> ReLU -> MaxPool2d(2) blocks, a flatten,
    then four fully connected layers (1024 -> 64 -> 4 -> 1) with a sigmoid on
    the final output.

    A third convolution block used to be referenced by ``forward`` although
    its definition was commented out, so every forward pass raised
    ``AttributeError``; it has been removed. The first linear layer was also
    sized for 128 channels, while the last convolution block produces 64. Its
    input size is now derived from ``input_shape`` so it always matches the
    convolution output.

    Args:
        input_shape: ``(channels, height, width)`` of one input sample. The
            default matches the ``(1, 160, 216)`` shape used with ``summary``
            in this notebook.
    """

    def __init__(self, input_shape = (1, 160, 216)):
        super(ProgCNN, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size = 3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size = 3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2)
        )

        self.flatten = nn.Flatten()

        # Push one all-zero sample through the convolution stack to learn how
        # many features reach the first linear layer (64 * 38 * 52 for the
        # default shape). This keeps fc1 in sync with the conv layers.
        with torch.no_grad():
            probe = torch.zeros(1, *input_shape)
            flat_features = self.flatten(self.conv2(self.conv1(probe))).shape[1]

        self.fc1 = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 64),
            nn.ReLU()
        )
        self.fc3 = nn.Sequential(
            nn.Linear(64, 4),
            nn.ReLU()
        )
        self.fc4 = nn.Sequential(
            nn.Linear(4, 1),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of probabilities for the batch ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        # torch.sigmoid replaces the deprecated F.sigmoid; the output stays a
        # probability, which is what nn.BCELoss expects.
        x = torch.sigmoid(self.fc4(x))
        return x

In [None]:
# The training and validation loops move every batch to `device`, so the
# network's parameters must live on the same device or the first forward pass
# fails with a device-mismatch error on CUDA machines.
model = ProgCNN().to(device)
# torchsummary prints its report itself, so it is not wrapped in print().
# Passing `device` makes its probe input match where the model now lives
# (torchsummary's own default is "cuda").
summary(model, (1, 160, 216), device = device)

In [None]:
# Binary cross-entropy on probabilities (the model's forward ends in a sigmoid).
loss_fn = nn.BCELoss()
# Adam over every parameter of the model, using the configured learning rate.
optimizer = torch.optim.Adam(params = model.parameters(), lr = LEARNING_RATE)

In [None]:
def train(epoch, train_loader, train_losses, num_batch_prints = 5):
    """Run one training epoch over ``train_loader``.

    Uses the notebook-level ``model``, ``loss_fn``, ``optimizer`` and
    ``device``. A sample counts as predicted-positive when the model output is
    above 0.5.

    Args:
        epoch: Epoch number, used only in the printed messages.
        train_loader: Iterable of ``(inputs, labels, metadata)`` batches that
            also supports ``len()`` and exposes ``.dataset``.
        train_losses: List that receives the batch loss each time a progress
            line is printed (mutated in place).
        num_batch_prints: Approximate number of progress lines per epoch.

    Returns:
        Number of correctly classified training samples in this epoch (int).
    """
    model.train()
    correct = 0
    running_loss = 0.0
    num_batches = len(train_loader)
    # With fewer batches than requested prints the integer division gives 0,
    # which would make the modulo below raise ZeroDivisionError; clamp to 1.
    batch_print_idx = max(1, num_batches // max(1, num_batch_prints))
    if epoch == 1:
        print(f"Train Epoch: {epoch}")

    for batch_idx, (inputs, labels, metadata) in enumerate(train_loader):
        inputs = inputs.to(device).float()
        labels = labels.to(device).float()

        outputs = model(inputs).reshape(-1)
        loss = loss_fn(outputs, labels)

        predictions = (outputs > 0.5).long()
        # .item() keeps the counter a plain int, so it prints as a number and
        # not as "tensor(...)".
        correct += (predictions == labels).sum().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        if batch_idx % batch_print_idx == 0:
            print(f"\tTrain Batch: {(batch_idx // batch_print_idx) + 1} \tBatchwise Loss: {batch_loss:.4f}")
            train_losses.append(batch_loss)

    num_samples = len(train_loader.dataset)
    # Report the mean of the per-batch losses rather than just the last
    # batch's loss; guard both divisions so an empty loader does not crash.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    accuracy = 100. * correct / num_samples if num_samples else 0.0
    print(f"Train Epoch: {epoch}\tEpochwise Loss: {epoch_loss:.4f}\tEpoch Accuracy: [{correct}/{num_samples}] {accuracy:.2f}%")
    return correct

In [None]:
def getSongPredictions(arr):
    """Majority vote over 0/1 snippet predictions; a tie or an empty list gives 0."""
    return 1 if sum(arr) > len(arr) / 2 else 0


def validate(validation_loader, test_losses):
    """Evaluate the model on ``validation_loader`` and aggregate votes per song.

    Uses the notebook-level ``model``, ``loss_fn`` and ``device``. Snippet
    predictions (model output above 0.5) are grouped by
    ``metadata['song_name']`` and reduced to one label per song with
    ``getSongPredictions``.

    Args:
        validation_loader: Iterable of ``(inputs, labels, metadata)`` batches
            that also supports ``len()`` and exposes ``.dataset``.
        test_losses: List that receives the mean batch loss of this run
            (mutated in place; ``nan`` when the loader has no batches).

    Returns:
        Dict mapping each song name to its majority-vote prediction (0 or 1).
    """
    # eval() is equivalent to train(False), so one call is enough.
    model.eval()

    test_loss = 0.0
    correct = 0
    predictions_dict = {}

    with torch.no_grad():
        for inputs, labels, metadata in validation_loader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).float()

            outputs = model(inputs).reshape(-1)
            test_loss += loss_fn(outputs, labels).item()

            predictions = (outputs > 0.5).long()
            # .item() keeps the counter a plain int, so it prints as a number
            # and not as "tensor(...)".
            correct += (predictions == labels).sum().item()
            # tolist() brings the whole batch of predictions back in one go
            # instead of one .item() call per snippet.
            for song_name, prediction in zip(metadata['song_name'], predictions.tolist()):
                predictions_dict.setdefault(song_name, []).append(prediction)

    num_batches = len(validation_loader)
    num_samples = len(validation_loader.dataset)
    # Guard both divisions so an empty loader/dataset reports instead of crashing.
    test_loss = test_loss / num_batches if num_batches else float("nan")
    test_losses.append(test_loss)
    accuracy = 100. * correct / num_samples if num_samples else 0.0
    print(f"Test Set: Avg Loss: {test_loss:.4f}, Accuracy: {correct}/{num_samples} ({accuracy:.2f}%)\n")
    return {name: getSongPredictions(votes) for name, votes in predictions_dict.items()}

In [None]:
def validateSongs(validate_songs, prediction_dict):
    """Print how many songs were classified correctly.

    Args:
        validate_songs: Iterable of ``(file_path, file_name, label)`` tuples.
        prediction_dict: Mapping from file name to the predicted label.

    Every song counts towards the total; a song without an entry in
    ``prediction_dict`` is never counted as correct.
    """
    correct = 0
    total = 0
    for _, song_name, truth in validate_songs:
        total += 1
        # Compare by truthiness, matching the 0/1 labels used in the notebook.
        if song_name in prediction_dict and bool(truth) == bool(prediction_dict[song_name]):
            correct += 1
    # Guard the division so an empty song list reports 0 instead of crashing.
    accuracy = correct / total * 100 if total else 0.0
    print(f"Songs correctly classified {correct}/{total}: {accuracy:.4f}")

In [None]:
# Loss histories filled in by train() / validate(), plus a running count of
# correctly classified training samples summed over all epochs.
train_loss_arr, test_loss_arr = [], []
correct = 0

# Baseline: score the untrained model at song level.
baseline_predictions = validate(validation_loader, test_loss_arr)
validateSongs(validate_songs, baseline_predictions)

for epoch in range(1, NUM_EPOCHS + 1):
    correct += train(epoch, train_loader, train_loss_arr)

# Score again once all epochs have finished.
final_predictions = validate(validation_loader, test_loss_arr)
validateSongs(validate_songs, final_predictions)