In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, random_split
import tqdm
import pandas as pd
import copy
import numpy as np
from sklearn.metrics import roc_auc_score, confusion_matrix

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [3]:
year = 2019
month = 11
perc = 0.9

x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt")
x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt")
y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt")
y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt")

print(f"x_train shape:", x_train.shape)
print(f"x_test shape:", x_test.shape)
print(f"y_train shape:", y_train.shape)
print(f"y_test shape:", y_test.shape)

x_train shape: torch.Size([631620, 85])
x_test shape: torch.Size([282120, 85])
y_train shape: torch.Size([631620])
y_test shape: torch.Size([282120])


  x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt")
  x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt")
  y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt")
  y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt")


In [4]:
results_df = pd.read_csv("results.csv", dtype={
    "model": str,
    "year": int,
    "month": int,
    "perc": float,
    "epoch": int,
    "train_loss": float,
    "val_loss": float,
    "acc": float,
    "prec": float,
    "rec": float,
    "f1": float,
    "auc": float,
    "tp": int,
    "fp": int,
    "fn": int,
    "tn": int,
    "best_threshold": float,
    "done": bool
})

filtered_df = results_df[
    (results_df["model"] == "mlp") &
    (results_df["year"] == year) &
    (results_df["month"] == month) &
    (results_df["perc"] == perc)
]

if filtered_df.empty:
    latest_epoch = 0
    is_trained = False
else:
    latest_epoch = filtered_df["epoch"].max()
    is_trained = filtered_df["done"].any()

print("Latest epoch:", latest_epoch)
print("Is trained?", is_trained)

FileNotFoundError: [Errno 2] No such file or directory: 'results.csv'

In [None]:
input_size = x_train.shape[1]

In [None]:
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

In [None]:
train_size = int(0.9 * x_train.shape[0])
val_size = x_train.shape[0] - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

In [None]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=10, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=10, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, num_workers=10, pin_memory=True)

In [None]:
class MLP(nn.Module):
    def __init__(self, input_size: int):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.layers(x)

In [None]:
def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs, results_df, patience=5):
    best_threshold = 0.0
    best_val_f1 = 0.0
    epochs_no_improve = 0
    best_model_state = None
    train_losses = list()
    val_losses = list()
    best_epoch = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()  # Set model to training mode
        train_loss = 0.0
        for inputs, targets in tqdm.tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update the weights

            train_loss += loss.item()

        train_loss /= len(train_loader)
        print("Train loss:", train_loss)
    
        # Validation
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_probs = []  # Store probabilities for ROC-AUC
        print("Validating...")
        with torch.no_grad():
            for inputs, targets in tqdm.tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(inputs).squeeze(1)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                # Get predictions and probabilities (assuming binary classification with sigmoid output)
                probs = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid if needed
                labels = targets.cpu().numpy()

                all_labels.extend(labels)
                all_probs.extend(probs.flatten())

        val_loss /= len(val_loader)

        # Find threshold for predictions
        print("Looking for threshold")
        best_threshold_epoch = 0
        best_f1_epoch = 0
        for threshold in tqdm.tqdm(np.arange(0.05, 0.96, 0.01)):
            preds_binary = (all_probs > threshold).astype(int)
            cm = confusion_matrix(all_labels, preds_binary)
            tp = cm[1, 1]
            fp = cm[0, 1]
            fn = cm[1, 0]
            tn = cm[0, 0]
            precision = 0 if tp == 0 else tp / (tp + fp)
            recall = 0 if tp == 0 else tp / (tp + fn)
            f1 = 0 if precision * recall == 0 else 2 * precision * recall / (precision + recall)
            if f1 > best_f1_epoch:
                best_threshold_epoch = threshold
                best_f1_epoch = f1
        print(f"Best threshold: {best_threshold_epoch}")
        all_preds = (all_probs > best_threshold_epoch).astype(int)

        cm = confusion_matrix(all_labels, all_preds)
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]

        accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0 # Handle division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        roc_auc = roc_auc_score(all_labels, all_probs)

        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss      :{val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

        new_row = pd.DataFrame(
            {
                "model": ["mlp"],
                "year": [year],
                "month": [month],
                "perc": [perc],
                "epoch": [latest_epoch + epoch + 1],
                "train_loss": [epoch_loss],
                "val_loss": [val_loss],
                "acc": [accuracy],
                "prec": [precision],
                "rec": [recall],
                "f1": [f1],
                "auc": [roc_auc],
                "tp": [tp],
                "fp": [fp],
                "fn": [fn],
                "tn": [tn],
                "best_threshold": [best_threshold_epoch],
                "done": [False]
            }
        )
        results_df = pd.concat([results_df, new_row], ignore_index=True)
        results_df.to_csv("results.csv", index=False)

        torch.save(model.state_dict(), f"./model_mlp_{year}_{month}_{perc}_{latest_epoch + epoch + 1}.pth")

        if f1 > best_val_f1:
            best_val_f1 = f1
            best_threshold = best_threshold_epoch
            epochs_no_improve = 0
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = latest_epoch + epoch + 1
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print("Best epoch:", best_epoch)
                model.load_state_dict(best_model_state)
                break
    
    return best_threshold


In [None]:
def test(model, test_loader, device, criterion, best_threshold):
    model.eval()
    test_loss = 0.0
    all_labels = []
    all_preds = []
    all_probs = []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)  # Use criterion here
            test_loss += loss.item()

            probs = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid if needed
            preds = (probs > best_threshold).astype(int)  # Convert probabilities to predictions
            labels = targets.cpu().numpy()

            all_labels.extend(labels)
            all_preds.extend(preds.flatten())
            all_probs.extend(probs.flatten())

    test_loss /= len(test_loader)

    cm = confusion_matrix(all_labels, all_preds)
    tp = cm[1, 1]
    fp = cm[0, 1]
    fn = cm[1, 0]
    tn = cm[0, 0]

    accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    try:
      roc_auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        roc_auc = 0.0

    print(f"Test Metrics:")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")
    print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")
    print(f"Test Loss: {test_loss:.4f}") # Print the loss as well


In [None]:
model = MLP(input_size).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

best_threshold = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    1000,
    results_df
)

Epoch 1/10


100%|██████████| 17765/17765 [00:30<00:00, 581.75it/s]


Train loss: 0.3772745685261335
Validating...


100%|██████████| 1974/1974 [00:01<00:00, 1037.29it/s]


Looking for threshold


100%|██████████| 91/91 [00:11<00:00,  7.82it/s]


Best threshold: 0.36000000000000004
Validation Metrics - Epoch 1/10:
Loss      :0.2508
Accuracy:  0.8959
Precision: 0.8797
Recall:    0.9177
F1-score:  0.8983
ROC-AUC:   0.9640
Confusion Matrix:
29051 2605
3971 27535
Epoch 2/10


100%|██████████| 17765/17765 [00:30<00:00, 582.34it/s]


Train loss: 0.3247240778044088
Validating...


100%|██████████| 1974/1974 [00:01<00:00, 1024.49it/s]


Looking for threshold


100%|██████████| 91/91 [00:11<00:00,  7.76it/s]


Best threshold: 0.37000000000000005
Validation Metrics - Epoch 2/10:
Loss      :0.2716
Accuracy:  0.9072
Precision: 0.8979
Recall:    0.9194
F1-score:  0.9086
ROC-AUC:   0.9709
Confusion Matrix:
29105 2551
3308 28198
Epoch 3/10


  2%|▏         | 370/17765 [00:00<00:31, 559.32it/s]Exception in thread Thread-9 (_pin_memory_loop):
Traceback (most recent call last):
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 55, in _pin_memory_loop
    do_one_step()
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 32, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/a

KeyboardInterrupt: 

In [None]:
test(
    model,
    test_loader,
    device,
    criterion,
    best_threshold
)

100%|██████████| 8817/8817 [00:07<00:00, 1156.73it/s]


Test Metrics:
Accuracy:  0.7662
Precision: 0.8077
Recall:    0.6988
F1-score:  0.7493
ROC-AUC:   0.8226
Confusion Matrix:
98566 42494
23471 117589
Test Loss: 0.9688
