In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, random_split
import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score, confusion_matrix

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [3]:
year = 2023
month = 11
perc = 0.9

x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt")
x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt")
y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt")
y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt")

print(f"x_train shape:", x_train.shape)
print(f"x_test shape:", x_test.shape)
print(f"y_train shape:", y_train.shape)
print(f"y_test shape:", y_test.shape)

  x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt")


x_train shape: torch.Size([863004, 84])
x_test shape: torch.Size([50736, 84])
y_train shape: torch.Size([863004])
y_test shape: torch.Size([50736])


  x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt")
  y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt")
  y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt")


In [4]:
input_size = x_train.shape[1]

In [5]:
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

In [6]:
train_size = int(0.8 * x_train.shape[0])
val_size = x_train.shape[0] - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

In [7]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=10, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=10, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, num_workers=10, pin_memory=True)

In [8]:
class MLP(nn.Module):
    def __init__(self, input_size: int):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.layers(x)

In [9]:
def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs):
    best_threshold = 0.0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()  # Set model to training mode
        train_loss = 0.0
        for inputs, targets in tqdm.tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update the weights

            train_loss += loss.item()

        train_loss /= len(train_loader)
        print("Train loss:", train_loss)
    
        # Validation
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_probs = []  # Store probabilities for ROC-AUC
        print("Validating...")
        with torch.no_grad():
            for inputs, targets in tqdm.tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(inputs).squeeze(1)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                # Get predictions and probabilities (assuming binary classification with sigmoid output)
                probs = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid if needed
                labels = targets.cpu().numpy()

                all_labels.extend(labels)
                all_probs.extend(probs.flatten())

        val_loss /= len(val_loader)

        # Find threshold for predictions
        print("Looking for threshold")
        best_threshold = 0
        best_f1 = 0
        for threshold in tqdm.tqdm(np.arange(0.2, 0.81, 0.01)):
            preds_binary = (all_probs > threshold).astype(int)
            cm = confusion_matrix(all_labels, preds_binary)
            tp = cm[1, 1]
            fp = cm[0, 1]
            fn = cm[1, 0]
            tn = cm[0, 0]
            precision = 0 if tp == 0 else tp / (tp + fp)
            recall = 0 if tp == 0 else tp / (tp + fn)
            f1 = 0 if precision * recall == 0 else 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_threshold = threshold
                best_f1 = f1
        print(f"Best threshold: {best_threshold}")
        all_preds = (all_probs > best_threshold).astype(int)

        cm = confusion_matrix(all_labels, all_preds)
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]

        accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0 # Handle division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        roc_auc = roc_auc_score(all_labels, all_probs)

        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss      :{val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")
    
    return best_threshold


In [10]:
def test(model, test_loader, device, criterion, best_threshold):
    model.eval()
    test_loss = 0.0
    all_labels = []
    all_preds = []
    all_probs = []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)  # Use criterion here
            test_loss += loss.item()

            probs = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid if needed
            preds = (probs > best_threshold).astype(int)  # Convert probabilities to predictions
            labels = targets.cpu().numpy()

            all_labels.extend(labels)
            all_preds.extend(preds.flatten())
            all_probs.extend(probs.flatten())

    test_loss /= len(test_loader)

    cm = confusion_matrix(all_labels, all_preds)
    tp = cm[1, 1]
    fp = cm[0, 1]
    fn = cm[1, 0]
    tn = cm[0, 0]

    accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    try:
      roc_auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        roc_auc = 0.0

    print(f"Test Metrics:")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")
    print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")
    print(f"Test Loss: {test_loss:.4f}") # Print the loss as well


In [11]:
model = MLP(input_size).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

best_threshold = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    10
)

Epoch 1/10


100%|██████████| 21576/21576 [00:33<00:00, 644.40it/s]


Train loss: 0.12577165061432294
Validating...


100%|██████████| 5394/5394 [00:04<00:00, 1148.27it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.77it/s]


Best threshold: 0.7300000000000004
Validation Metrics - Epoch 1/10:
Loss      :0.0497
Accuracy:  0.9850
Precision: 0.9930
Recall:    0.9769
F1-score:  0.9849
ROC-AUC:   0.9977
Confusion Matrix:
84506 1999
594 85502
Epoch 2/10


100%|██████████| 21576/21576 [00:33<00:00, 644.00it/s]


Train loss: 0.055270030343017605
Validating...


100%|██████████| 5394/5394 [00:04<00:00, 1167.80it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.70it/s]


Best threshold: 0.7100000000000004
Validation Metrics - Epoch 2/10:
Loss      :0.0411
Accuracy:  0.9920
Precision: 0.9999
Recall:    0.9842
F1-score:  0.9920
ROC-AUC:   0.9981
Confusion Matrix:
85137 1368
11 86085
Epoch 3/10


100%|██████████| 21576/21576 [00:36<00:00, 583.51it/s]


Train loss: 0.050299476429407464
Validating...


100%|██████████| 5394/5394 [00:05<00:00, 1048.51it/s]


Looking for threshold


100%|██████████| 62/62 [00:23<00:00,  2.67it/s]


Best threshold: 0.22000000000000003
Validation Metrics - Epoch 3/10:
Loss      :0.0309
Accuracy:  0.9919
Precision: 0.9931
Recall:    0.9907
F1-score:  0.9919
ROC-AUC:   0.9988
Confusion Matrix:
85701 804
592 85504
Epoch 4/10


100%|██████████| 21576/21576 [00:37<00:00, 570.08it/s]


Train loss: 0.04023191972113583
Validating...


100%|██████████| 5394/5394 [00:05<00:00, 1005.80it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.71it/s]


Best threshold: 0.2
Validation Metrics - Epoch 4/10:
Loss      :0.0224
Accuracy:  0.9956
Precision: 0.9998
Recall:    0.9914
F1-score:  0.9956
ROC-AUC:   0.9987
Confusion Matrix:
85765 740
19 86077
Epoch 5/10


100%|██████████| 21576/21576 [00:34<00:00, 624.86it/s]


Train loss: 0.04767420536087492
Validating...


100%|██████████| 5394/5394 [00:04<00:00, 1136.71it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.72it/s]


Best threshold: 0.2
Validation Metrics - Epoch 5/10:
Loss      :0.0249
Accuracy:  0.9949
Precision: 1.0000
Recall:    0.9898
F1-score:  0.9949
ROC-AUC:   0.9983
Confusion Matrix:
85625 880
0 86096
Epoch 6/10


100%|██████████| 21576/21576 [00:32<00:00, 655.95it/s]


Train loss: 0.03578764749703192
Validating...


100%|██████████| 5394/5394 [00:04<00:00, 1156.30it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.74it/s]


Best threshold: 0.25000000000000006
Validation Metrics - Epoch 6/10:
Loss      :0.0225
Accuracy:  0.9963
Precision: 0.9995
Recall:    0.9932
F1-score:  0.9963
ROC-AUC:   0.9984
Confusion Matrix:
85913 592
42 86054
Epoch 7/10


100%|██████████| 21576/21576 [00:33<00:00, 648.21it/s]


Train loss: 0.030264798156805363
Validating...


100%|██████████| 5394/5394 [00:04<00:00, 1170.00it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.72it/s]


Best threshold: 0.2
Validation Metrics - Epoch 7/10:
Loss      :0.0383
Accuracy:  0.9919
Precision: 0.9992
Recall:    0.9846
F1-score:  0.9919
ROC-AUC:   0.9967
Confusion Matrix:
85174 1331
67 86029
Epoch 8/10


100%|██████████| 21576/21576 [00:33<00:00, 635.67it/s]


Train loss: 0.030184068677883723
Validating...


100%|██████████| 5394/5394 [00:04<00:00, 1175.25it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.72it/s]


Best threshold: 0.35000000000000014
Validation Metrics - Epoch 8/10:
Loss      :0.0305
Accuracy:  0.9939
Precision: 0.9999
Recall:    0.9879
F1-score:  0.9939
ROC-AUC:   0.9977
Confusion Matrix:
85458 1047
9 86087
Epoch 9/10


100%|██████████| 21576/21576 [00:34<00:00, 623.02it/s]


Train loss: 0.03159454565586905
Validating...


100%|██████████| 5394/5394 [00:04<00:00, 1123.67it/s]


Looking for threshold


100%|██████████| 62/62 [00:22<00:00,  2.70it/s]


Best threshold: 0.2
Validation Metrics - Epoch 9/10:
Loss      :0.0259
Accuracy:  0.9949
Precision: 1.0000
Recall:    0.9899
F1-score:  0.9949
ROC-AUC:   0.9979
Confusion Matrix:
85631 874
2 86094
Epoch 10/10


 10%|▉         | 2122/21576 [00:03<00:31, 616.74it/s]Exception in thread Thread-23 (_pin_memory_loop):
Traceback (most recent call last):
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 55, in _pin_memory_loop
    do_one_step()
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 32, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home

KeyboardInterrupt: 

In [15]:
test(
    model,
    test_loader,
    device,
    criterion,
    0.5
)

100%|██████████| 1586/1586 [00:01<00:00, 971.77it/s] 


Test Metrics:
Accuracy:  0.9922
Precision: 0.9978
Recall:    0.9866
F1-score:  0.9922
ROC-AUC:   0.9980
Confusion Matrix:
25028 340
54 25314
Test Loss: 0.0298
