In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, random_split
import tqdm
import pandas as pd
import copy
import numpy as np
import multiprocessing as mp
from sklearn.metrics import roc_auc_score, confusion_matrix

In [2]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [None]:
# Run configuration: these names are also read by later cells (results
# bookkeeping, checkpoint file names).
model_name = "mlpnorm"
year = 2019
month = 11
perc = 0

# weights_only=True restricts unpickling to tensors and plain containers,
# which avoids executing arbitrary pickled code and silences the torch.load
# FutureWarning about the default value of this flag.
# NOTE(review): assumes the .pt files hold plain tensors — confirm.
x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt", weights_only=True)
x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt", weights_only=True)
y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt", weights_only=True)
y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt", weights_only=True)

print("x_train shape:", x_train.shape)
print("x_test shape:", x_test.shape)
print("y_train shape:", y_train.shape)
print("y_test shape:", y_test.shape)

  x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt")


x_train shape: torch.Size([863004, 85])
x_test shape: torch.Size([50736, 85])
y_train shape: torch.Size([863004])
y_test shape: torch.Size([50736])


  x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt")
  y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt")
  y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt")


In [4]:
def apply_normalization(x_train, x_test):
    """Standardize both splits column-wise using statistics of the training split.

    Args:
        x_train: 2-D tensor whose per-column mean and standard deviation
            (computed along dim 0) define the transformation.
        x_test: tensor with the same number of columns, transformed with the
            training statistics so no test information leaks into the scaler.

    Returns:
        Tuple ``(x_train_normalized, x_test_normalized)``.

    Columns whose training standard deviation is zero (constant columns) or
    undefined (a single training row) are only mean-centred instead of being
    divided by zero, so the result never contains NaN/inf caused by the scaler.
    """
    mean = torch.mean(x_train, dim=0)
    std = torch.std(x_train, dim=0)

    # `std > 0` is False for both 0 and NaN, so every degenerate column gets a
    # neutral divisor of 1 while all other columns keep their exact std.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))

    x_train_normalized = (x_train - mean) / safe_std
    x_test_normalized = (x_test - mean) / safe_std

    return x_train_normalized, x_test_normalized

In [5]:
def apply_pca_explained_variance_normalized(x_train, x_test, target_variance=0.9):
    """Standardize the data, then project it onto the leading principal components.

    The number of components is the smallest count whose cumulative explained
    variance ratio (measured on the training split) reaches ``target_variance``.

    Args:
        x_train: 2-D training tensor; the normalization statistics and the
            principal directions are fitted on it.
        x_test: tensor with the same number of columns, transformed with the
            training statistics and directions.
        target_variance: fraction of the total variance to retain.

    Returns:
        Tuple ``(x_train_pca, x_test_pca, n_components)`` where
        ``n_components`` is a 0-dim integer tensor.
    """
    # Standardize with training statistics only.
    x_train_normalized, x_test_normalized = apply_normalization(x_train, x_test)

    # torch.pca_lowrank defaults to q = min(6, m, n); request every available
    # component explicitly, otherwise the explained-variance ratios below are
    # normalised over at most 6 singular values and overstate what is retained.
    n_samples, n_features = x_train_normalized.shape
    max_components = min(n_samples, n_features)
    U, S, V = torch.pca_lowrank(x_train_normalized, q=max_components)

    # Calculate explained variance ratio over the full spectrum.
    explained_variance_ratio = S**2 / torch.sum(S**2)
    cumulative_variance_ratio = torch.cumsum(explained_variance_ratio, dim=0)

    # First index whose cumulative ratio reaches the target. argmax over an
    # all-False mask would return 0 (i.e. one component), so when the target is
    # never reached (e.g. rounding just below 1.0) keep every component instead.
    reached = cumulative_variance_ratio >= target_variance
    if bool(reached.any()):
        n_components = torch.argmax(reached.int()) + 1
    else:
        n_components = torch.tensor(S.numel(), device=S.device)

    # Project the normalized data onto the selected components.
    x_train_pca = torch.matmul(x_train_normalized, V[:, :n_components])
    x_test_pca = torch.matmul(x_test_normalized, V[:, :n_components])

    return x_train_pca, x_test_pca, n_components

In [None]:
apply_norm = True

if apply_norm:
    # Columns to standardize: a base set, the same set shifted by 42, and column 84.
    index_mask = [0, 1, 2, 8] + list(range(15, 42))
    index_mask = index_mask + [column + 42 for column in index_mask] + [84]
    # Every remaining column is passed through unchanged.
    keep_indexes = [column for column in range(x_train.shape[1]) if column not in index_mask]

    x_train_norm, x_test_norm = apply_normalization(
        x_train[:, index_mask],
        x_test[:, index_mask],
    )

    # Normalized columns first, untouched columns after them (column order changes).
    new_x_train = torch.hstack([x_train_norm, x_train[:, keep_indexes]])
    new_x_test = torch.hstack([x_test_norm, x_test[:, keep_indexes]])

    print(new_x_train.shape)
    print(new_x_test.shape)

    x_train = new_x_train
    x_test = new_x_test

In [None]:
apply_PCA = False

if apply_PCA:
    # Columns fed to PCA: a base set, the same set shifted by 42, and column 84.
    # NOTE(review): these indexes refer to the original column layout; the
    # normalization cell reorders the columns of x_train/x_test when it runs,
    # so confirm that only one of apply_norm / apply_PCA is enabled.
    index_mask = [0, 1, 2, 8] + list(range(15, 42))
    index_mask = index_mask + [column + 42 for column in index_mask] + [84]
    keep_indexes = [column for column in range(x_train.shape[1]) if column not in index_mask]

    x_train_pca, x_test_pca, num_components_used = apply_pca_explained_variance_normalized(
        x_train[:, index_mask],
        x_test[:, index_mask],
        target_variance=0.9,
    )

    # Principal components first, untouched columns after them.
    new_x_train = torch.hstack([x_train_pca, x_train[:, keep_indexes]])
    new_x_test = torch.hstack([x_test_pca, x_test[:, keep_indexes]])

    print(f"Number of components used: {num_components_used}")
    print(new_x_train.shape)
    print(new_x_test.shape)

    x_train = new_x_train
    x_test = new_x_test

Number of components used: 5
torch.Size([863004, 27])
torch.Size([50736, 27])


In [8]:
# Column schema of results.csv; shared by the reader and the empty fallback.
results_dtypes = {
    "model": str,
    "year": int,
    "month": int,
    "perc": float,
    "epoch": int,
    "train_loss": float,
    "val_loss": float,
    "acc": float,
    "prec": float,
    "rec": float,
    "f1": float,
    "auc": float,
    "tp": int,
    "fp": int,
    "fn": int,
    "tn": int,
    "best_threshold": float,
    "done": bool
}

try:
    results_df = pd.read_csv("results.csv", dtype=results_dtypes)
except (FileNotFoundError, pd.errors.EmptyDataError):
    # First run (or a zero-byte file): start from an empty, correctly typed
    # frame so the notebook works on a fresh checkout.
    results_df = pd.DataFrame(
        {column: pd.Series(dtype=dtype) for column, dtype in results_dtypes.items()}
    )

# Rows previously logged for this exact model / dataset configuration.
filtered_df = results_df[
    (results_df["model"] == model_name) &
    (results_df["year"] == year) &
    (results_df["month"] == month) &
    (results_df["perc"] == perc)
]

if filtered_df.empty:
    latest_epoch = 0
    is_trained = False
else:
    # Plain Python scalars keep later f-strings and comparisons free of NumPy types.
    latest_epoch = int(filtered_df["epoch"].max())
    is_trained = bool(filtered_df["done"].any())

print("Latest epoch:", latest_epoch)
print("Is trained?", is_trained)

Latest epoch: 0
Is trained? False


In [9]:
# Width of the first network layer: number of feature columns after preprocessing.
input_size = x_train.size(1)

In [10]:
# Pair each feature tensor with its label tensor so loaders yield (inputs, targets).
test_dataset = TensorDataset(x_test, y_test)
train_dataset = TensorDataset(x_train, y_train)

In [11]:
# Hold out 20% of the training samples for validation.
train_size = int(0.8 * x_train.shape[0])
val_size = x_train.shape[0] - train_size
# A dedicated seeded generator makes the split identical across kernel
# restarts; without it, resuming from saved epochs would validate on samples
# that an earlier session trained on.
train_dataset, val_dataset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

In [12]:
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=128,
    num_workers=10,
    pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    num_workers=10,
    pin_memory=True,
)

In [13]:
class MLP(nn.Module):
    """Fully connected network: three 256-unit ReLU layers and one output unit.

    The forward pass returns the raw output of the last linear layer; no
    activation is applied to it inside the module.
    """

    def __init__(self, input_size: int):
        super().__init__()
        hidden_units = 256
        stack = []
        in_features = input_size
        # Three Linear+ReLU pairs; built in order so the Sequential indexes
        # (0, 2, 4 for the hidden Linear layers) match saved state dicts.
        for _ in range(3):
            stack.append(nn.Linear(in_features, hidden_units))
            stack.append(nn.ReLU())
            in_features = hidden_units
        # Single-unit output head at index 6.
        stack.append(nn.Linear(hidden_units, 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        logits = self.layers(x)
        return logits

In [14]:
def calculate_metrics(threshold, all_probs, all_labels):
    """Compute the F1 score obtained by thresholding probabilities.

    Kept as a plain top-level function so it can be dispatched through
    ``multiprocessing.Pool.starmap``.

    Args:
        threshold: decision threshold; a sample is predicted positive when its
            probability is strictly greater than this value.
        all_probs: sequence or array of predicted probabilities.
        all_labels: sequence or array of ground-truth labels; the value 1 is
            treated as the positive class, anything else as negative.

    Returns:
        Tuple ``(threshold, f1)``. ``f1`` is 0.0 when there are no true
        positives (including the case where only one class is present).
    """
    # Explicit array conversion: comparing a Python list with a plain float
    # raises TypeError, and it only worked before when `threshold` happened to
    # be a NumPy scalar.
    probs = np.asarray(all_probs)
    labels = np.asarray(all_labels)

    predicted_positive = probs > threshold
    actual_positive = labels == 1

    # Counting directly avoids indexing a confusion matrix that is not 2x2
    # when only one class occurs in the labels and predictions.
    tp = int(np.count_nonzero(predicted_positive & actual_positive))
    fp = int(np.count_nonzero(predicted_positive & ~actual_positive))
    fn = int(np.count_nonzero(~predicted_positive & actual_positive))

    if tp == 0:
        # Precision and recall are both 0 (or undefined), so F1 is defined as 0.
        return threshold, 0.0

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return threshold, f1

In [15]:
def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs, results_df, patience=5):
    """Train ``model`` with per-epoch validation, checkpointing and F1-based early stopping.

    For every epoch the function:
      1. runs one optimisation pass over ``train_loader``;
      2. evaluates on ``val_loader`` and applies a sigmoid to the outputs;
      3. sweeps thresholds 0.05..0.95 (step 0.01) with ``calculate_metrics`` in
         a pool of 10 processes and keeps the threshold with the highest F1;
      4. appends one metrics row to ``results_df`` and rewrites ``results.csv``;
      5. saves the model state dict to ``./model_<...>_<epoch>.pth``.

    Besides its arguments, this function reads the module-level names
    ``model_name``, ``year``, ``month``, ``perc`` and ``latest_epoch``.

    Args:
        model: network to optimise; modified in place.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        optimizer: optimiser already bound to ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, targets)`` on the
            squeezed model outputs.
        device: device the batches are moved to.
        num_epochs: maximum number of epochs to run.
        results_df: DataFrame of previously logged rows; new rows are appended
            to a local copy that is written to ``results.csv``.
        patience: number of consecutive epochs without a validation-F1
            improvement after which training stops.

    Returns:
        The decision threshold of the epoch with the best validation F1. On
        return, ``model`` holds the weights of that same epoch (if any epoch
        improved on an F1 of 0), whether or not early stopping triggered.
    """
    best_threshold = 0.0
    best_val_f1 = 0.0
    epochs_no_improve = 0
    best_model_state = None
    best_epoch = 0
    # Loop-invariant grid of candidate thresholds.
    thresholds = np.arange(0.05, 0.96, 0.01)

    for epoch in range(num_epochs):
        # Epoch number counted from previously logged epochs of this configuration.
        global_epoch = latest_epoch + epoch + 1

        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()  # Set model to training mode
        train_loss = 0.0
        for inputs, targets in tqdm.tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update the weights

            train_loss += loss.item()

        train_loss /= len(train_loader)
        print("Train loss:", train_loss)

        # Validation
        model.eval()
        val_loss = 0.0
        label_batches = []
        prob_batches = []  # Probabilities are also needed for ROC-AUC
        print("Validating...")
        with torch.no_grad():
            for inputs, targets in tqdm.tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(inputs).squeeze(1)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                prob_batches.append(torch.sigmoid(outputs).cpu().numpy().ravel())
                label_batches.append(targets.cpu().numpy().ravel())

        val_loss /= len(val_loader)

        # Contiguous arrays (instead of Python lists of scalars) make the
        # threshold comparison well-defined and are much cheaper to send to
        # the worker processes below.
        all_labels = np.concatenate(label_batches)
        all_probs = np.concatenate(prob_batches)

        # Find threshold for predictions
        print("Looking for threshold")

        with mp.Pool(10) as pool:
            results = pool.starmap(
                calculate_metrics,
                [(threshold, all_probs, all_labels) for threshold in thresholds]
            )

        # max() returns the first maximal element, i.e. the lowest threshold
        # among ties.
        best_threshold_epoch, _ = max(results, key=lambda result: result[1])

        print(f"Best threshold: {best_threshold_epoch}")
        all_preds = (all_probs > best_threshold_epoch).astype(int)

        # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent.
        cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]

        accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0  # Handle division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        try:
            roc_auc = roc_auc_score(all_labels, all_probs)
        except ValueError:
            # Same fallback as test(): ROC-AUC is undefined with a single class.
            roc_auc = 0.0

        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss      :{val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

        # NOTE(review): "done" is always written as False and nothing in this
        # function marks a run as finished — confirm where it should be set.
        new_row = pd.DataFrame(
            {
                "model": [model_name],
                "year": [year],
                "month": [month],
                "perc": [perc],
                "epoch": [global_epoch],
                "train_loss": [train_loss],
                "val_loss": [val_loss],
                "acc": [accuracy],
                "prec": [precision],
                "rec": [recall],
                "f1": [f1],
                "auc": [roc_auc],
                "tp": [tp],
                "fp": [fp],
                "fn": [fn],
                "tn": [tn],
                "best_threshold": [best_threshold_epoch],
                "done": [False]
            }
        )
        results_df = pd.concat([results_df, new_row], ignore_index=True)
        results_df.to_csv("results.csv", index=False)

        torch.save(model.state_dict(), f"./model_{model_name}_{year}_{month}_{perc}_{global_epoch}.pth")

        if f1 > best_val_f1:
            best_val_f1 = f1
            best_threshold = best_threshold_epoch
            epochs_no_improve = 0
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = global_epoch
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print("Best epoch:", best_epoch)
                break

    # Restore the best weights on every exit path so the returned threshold
    # always belongs to the model state the caller ends up with. Skipped when
    # no epoch ever improved on F1 == 0 (there is nothing to restore).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return best_threshold


In [16]:
def test(model, test_loader, device, criterion, best_threshold):
    """Evaluate ``model`` on ``test_loader`` and print classification metrics.

    Args:
        model: network to evaluate; switched to eval mode.
        test_loader: iterable of ``(inputs, targets)`` batches.
        device: device the batches are moved to.
        criterion: loss called as ``criterion(outputs, targets)`` on the
            squeezed model outputs; its batch average is reported.
        best_threshold: a sample is predicted positive when its sigmoid
            probability is strictly greater than this value.

    Returns:
        None. Accuracy, precision, recall, F1, ROC-AUC, the confusion matrix
        (printed as ``tp fn`` / ``fp tn``) and the mean loss are printed.
    """
    model.eval()
    test_loss = 0.0
    label_batches = []
    prob_batches = []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            prob_batches.append(torch.sigmoid(outputs).cpu().numpy().ravel())
            label_batches.append(targets.cpu().numpy().ravel())

    test_loss /= len(test_loader)

    # One array per quantity instead of Python lists of scalars.
    all_labels = np.concatenate(label_batches)
    all_probs = np.concatenate(prob_batches)
    all_preds = (all_probs > best_threshold).astype(int)

    # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent from
    # both the labels and the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    tp = cm[1, 1]
    fp = cm[0, 1]
    fn = cm[1, 0]
    tn = cm[0, 0]

    accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    try:
        roc_auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # roc_auc_score raises ValueError, e.g. when only one class is present.
        roc_auc = 0.0

    print(f"Test Metrics:")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")
    print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")
    print(f"Test Loss: {test_loss:.4f}")  # Print the loss as well


In [17]:
# Network, loss on raw outputs, and optimiser for this run.
model = MLP(input_size).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Train for at most 1000 epochs; train() stops earlier via its patience setting.
best_threshold = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_epochs=1000,
    results_df=results_df,
)

Epoch 1/1000


100%|██████████| 5394/5394 [00:08<00:00, 607.47it/s]


Train loss: 0.3675663983259415
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 944.50it/s] 

Looking for threshold





Best threshold: 0.44000000000000006
Validation Metrics - Epoch 1/1000:
Loss      :0.3491
Accuracy:  0.8438
Precision: 0.8301
Recall:    0.8643
F1-score:  0.8469
ROC-AUC:   0.9241
Confusion Matrix:
74578 11706
15262 71055
Epoch 2/1000


100%|██████████| 5394/5394 [00:09<00:00, 589.67it/s]


Train loss: 0.3442218260971493
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 968.77it/s] 

Looking for threshold





Best threshold: 0.42000000000000004
Validation Metrics - Epoch 2/1000:
Loss      :0.3387
Accuracy:  0.8499
Precision: 0.8384
Recall:    0.8667
F1-score:  0.8523
ROC-AUC:   0.9285
Confusion Matrix:
74784 11500
14415 71902
Epoch 3/1000


100%|██████████| 5394/5394 [00:08<00:00, 606.92it/s]


Train loss: 0.33699963819463474
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 984.79it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 3/1000:
Loss      :0.3352
Accuracy:  0.8517
Precision: 0.8418
Recall:    0.8661
F1-score:  0.8538
ROC-AUC:   0.9299
Confusion Matrix:
74731 11553
14042 72275
Epoch 4/1000


100%|██████████| 5394/5394 [00:09<00:00, 569.89it/s]


Train loss: 0.3327221684917548
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 978.83it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 4/1000:
Loss      :0.3324
Accuracy:  0.8528
Precision: 0.8424
Recall:    0.8679
F1-score:  0.8550
ROC-AUC:   0.9312
Confusion Matrix:
74882 11402
14005 72312
Epoch 5/1000


100%|██████████| 5394/5394 [00:09<00:00, 570.95it/s]


Train loss: 0.3292246349832726
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 948.36it/s] 

Looking for threshold





Best threshold: 0.39000000000000007
Validation Metrics - Epoch 5/1000:
Loss      :0.3307
Accuracy:  0.8542
Precision: 0.8453
Recall:    0.8670
F1-score:  0.8560
ROC-AUC:   0.9325
Confusion Matrix:
74807 11477
13691 72626
Epoch 6/1000


100%|██████████| 5394/5394 [00:09<00:00, 568.34it/s]


Train loss: 0.32660850193275537
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 960.09it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 6/1000:
Loss      :0.3279
Accuracy:  0.8544
Precision: 0.8429
Recall:    0.8710
F1-score:  0.8567
ROC-AUC:   0.9330
Confusion Matrix:
75154 11130
14006 72311
Epoch 7/1000


100%|██████████| 5394/5394 [00:08<00:00, 601.62it/s]


Train loss: 0.32442479533953805
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 960.80it/s] 

Looking for threshold





Best threshold: 0.42000000000000004
Validation Metrics - Epoch 7/1000:
Loss      :0.3248
Accuracy:  0.8548
Precision: 0.8385
Recall:    0.8789
F1-score:  0.8582
ROC-AUC:   0.9343
Confusion Matrix:
75832 10452
14609 71708
Epoch 8/1000


100%|██████████| 5394/5394 [00:08<00:00, 607.51it/s]


Train loss: 0.32221601165563124
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 966.21it/s] 

Looking for threshold





Best threshold: 0.44000000000000006
Validation Metrics - Epoch 8/1000:
Loss      :0.3235
Accuracy:  0.8568
Precision: 0.8471
Recall:    0.8707
F1-score:  0.8587
ROC-AUC:   0.9348
Confusion Matrix:
75130 11154
13563 72754
Epoch 9/1000


100%|██████████| 5394/5394 [00:10<00:00, 503.17it/s]


Train loss: 0.3207550944306694
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 926.64it/s] 

Looking for threshold





Best threshold: 0.4600000000000001
Validation Metrics - Epoch 9/1000:
Loss      :0.3217
Accuracy:  0.8587
Precision: 0.8522
Recall:    0.8679
F1-score:  0.8600
ROC-AUC:   0.9355
Confusion Matrix:
74888 11396
12987 73330
Epoch 10/1000


100%|██████████| 5394/5394 [00:08<00:00, 638.91it/s]


Train loss: 0.31918521018261636
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 953.14it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 10/1000:
Loss      :0.3225
Accuracy:  0.8575
Precision: 0.8452
Recall:    0.8753
F1-score:  0.8600
ROC-AUC:   0.9353
Confusion Matrix:
75524 10760
13837 72480
Epoch 11/1000


100%|██████████| 5394/5394 [00:08<00:00, 633.07it/s]


Train loss: 0.31755455920088055
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 937.38it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 11/1000:
Loss      :0.3210
Accuracy:  0.8580
Precision: 0.8464
Recall:    0.8748
F1-score:  0.8603
ROC-AUC:   0.9358
Confusion Matrix:
75477 10807
13696 72621
Epoch 12/1000


100%|██████████| 5394/5394 [00:08<00:00, 633.28it/s]


Train loss: 0.31642417323215827
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 938.42it/s] 

Looking for threshold





Best threshold: 0.4000000000000001
Validation Metrics - Epoch 12/1000:
Loss      :0.3197
Accuracy:  0.8582
Precision: 0.8453
Recall:    0.8767
F1-score:  0.8607
ROC-AUC:   0.9363
Confusion Matrix:
75647 10637
13842 72475
Epoch 13/1000


100%|██████████| 5394/5394 [00:09<00:00, 557.38it/s]


Train loss: 0.3152544028921661
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 959.62it/s] 

Looking for threshold





Best threshold: 0.4800000000000001
Validation Metrics - Epoch 13/1000:
Loss      :0.3191
Accuracy:  0.8605
Precision: 0.8539
Recall:    0.8698
F1-score:  0.8618
ROC-AUC:   0.9369
Confusion Matrix:
75052 11232
12839 73478
Epoch 14/1000


100%|██████████| 5394/5394 [00:08<00:00, 633.80it/s]


Train loss: 0.31410894534921135
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 938.86it/s] 

Looking for threshold





Best threshold: 0.38000000000000006
Validation Metrics - Epoch 14/1000:
Loss      :0.3177
Accuracy:  0.8597
Precision: 0.8460
Recall:    0.8794
F1-score:  0.8624
ROC-AUC:   0.9373
Confusion Matrix:
75878 10406
13810 72507
Epoch 15/1000


100%|██████████| 5394/5394 [00:08<00:00, 623.75it/s]


Train loss: 0.3132190177445726
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 963.37it/s] 

Looking for threshold





Best threshold: 0.4000000000000001
Validation Metrics - Epoch 15/1000:
Loss      :0.3172
Accuracy:  0.8598
Precision: 0.8474
Recall:    0.8775
F1-score:  0.8622
ROC-AUC:   0.9374
Confusion Matrix:
75711 10573
13633 72684
Epoch 16/1000


100%|██████████| 5394/5394 [00:08<00:00, 631.07it/s]


Train loss: 0.31236873723188063
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 954.62it/s] 

Looking for threshold





Best threshold: 0.45000000000000007
Validation Metrics - Epoch 16/1000:
Loss      :0.3166
Accuracy:  0.8605
Precision: 0.8512
Recall:    0.8737
F1-score:  0.8623
ROC-AUC:   0.9377
Confusion Matrix:
75386 10898
13178 73139
Epoch 17/1000


100%|██████████| 5394/5394 [00:09<00:00, 558.39it/s]


Train loss: 0.31129747880833386
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 947.78it/s] 

Looking for threshold





Best threshold: 0.49000000000000005
Validation Metrics - Epoch 17/1000:
Loss      :0.3178
Accuracy:  0.8613
Precision: 0.8599
Recall:    0.8631
F1-score:  0.8615
ROC-AUC:   0.9375
Confusion Matrix:
74476 11808
12139 74178
Epoch 18/1000


100%|██████████| 5394/5394 [00:09<00:00, 560.94it/s]


Train loss: 0.31045321930955416
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 957.15it/s] 

Looking for threshold





Best threshold: 0.4000000000000001
Validation Metrics - Epoch 18/1000:
Loss      :0.3160
Accuracy:  0.8601
Precision: 0.8478
Recall:    0.8777
F1-score:  0.8625
ROC-AUC:   0.9378
Confusion Matrix:
75734 10550
13597 72720
Epoch 19/1000


100%|██████████| 5394/5394 [00:09<00:00, 581.94it/s]


Train loss: 0.3095543066224081
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 944.33it/s] 

Looking for threshold





Best threshold: 0.37000000000000005
Validation Metrics - Epoch 19/1000:
Loss      :0.3190
Accuracy:  0.8607
Precision: 0.8526
Recall:    0.8723
F1-score:  0.8623
ROC-AUC:   0.9376
Confusion Matrix:
75263 11021
13014 73303
Epoch 20/1000


100%|██████████| 5394/5394 [00:10<00:00, 532.52it/s]


Train loss: 0.30879150038743225
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 911.54it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 20/1000:
Loss      :0.3142
Accuracy:  0.8625
Precision: 0.8561
Recall:    0.8713
F1-score:  0.8636
ROC-AUC:   0.9384
Confusion Matrix:
75179 11105
12634 73683
Epoch 21/1000


100%|██████████| 5394/5394 [00:09<00:00, 541.85it/s]


Train loss: 0.30812251668120916
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 940.13it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 21/1000:
Loss      :0.3131
Accuracy:  0.8613
Precision: 0.8518
Recall:    0.8748
F1-score:  0.8631
ROC-AUC:   0.9388
Confusion Matrix:
75479 10805
13137 73180
Epoch 22/1000


100%|██████████| 5394/5394 [00:08<00:00, 651.32it/s]


Train loss: 0.30706378064212597
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 917.32it/s] 

Looking for threshold





Best threshold: 0.4000000000000001
Validation Metrics - Epoch 22/1000:
Loss      :0.3126
Accuracy:  0.8616
Precision: 0.8478
Recall:    0.8814
F1-score:  0.8643
ROC-AUC:   0.9391
Confusion Matrix:
76053 10231
13653 72664
Epoch 23/1000


100%|██████████| 5394/5394 [00:10<00:00, 532.05it/s]


Train loss: 0.3063840119842914
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 921.16it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 23/1000:
Loss      :0.3140
Accuracy:  0.8610
Precision: 0.8507
Recall:    0.8755
F1-score:  0.8630
ROC-AUC:   0.9384
Confusion Matrix:
75544 10740
13255 73062
Epoch 24/1000


100%|██████████| 5394/5394 [00:09<00:00, 560.81it/s]


Train loss: 0.30573106009594636
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 894.97it/s] 

Looking for threshold





Best threshold: 0.4700000000000001
Validation Metrics - Epoch 24/1000:
Loss      :0.3130
Accuracy:  0.8623
Precision: 0.8506
Recall:    0.8790
F1-score:  0.8646
ROC-AUC:   0.9396
Confusion Matrix:
75842 10442
13321 72996
Epoch 25/1000


100%|██████████| 5394/5394 [00:09<00:00, 590.48it/s]


Train loss: 0.3049061662832763
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 910.81it/s] 

Looking for threshold





Best threshold: 0.42000000000000004
Validation Metrics - Epoch 25/1000:
Loss      :0.3139
Accuracy:  0.8607
Precision: 0.8470
Recall:    0.8804
F1-score:  0.8633
ROC-AUC:   0.9385
Confusion Matrix:
75962 10322
13726 72591
Epoch 26/1000


100%|██████████| 5394/5394 [00:09<00:00, 585.70it/s]


Train loss: 0.30440104188667566
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 747.30it/s] 

Looking for threshold





Best threshold: 0.45000000000000007
Validation Metrics - Epoch 26/1000:
Loss      :0.3123
Accuracy:  0.8621
Precision: 0.8508
Recall:    0.8781
F1-score:  0.8642
ROC-AUC:   0.9394
Confusion Matrix:
75767 10517
13287 73030
Epoch 27/1000


100%|██████████| 5394/5394 [00:08<00:00, 618.55it/s]


Train loss: 0.3037130754519755
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 907.80it/s] 

Looking for threshold





Best threshold: 0.4700000000000001
Validation Metrics - Epoch 27/1000:
Loss      :0.3140
Accuracy:  0.8622
Precision: 0.8531
Recall:    0.8750
F1-score:  0.8639
ROC-AUC:   0.9390
Confusion Matrix:
75496 10788
13003 73314
Epoch 28/1000


100%|██████████| 5394/5394 [00:09<00:00, 567.92it/s]


Train loss: 0.30311931134607334
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 816.35it/s] 

Looking for threshold





Best threshold: 0.4700000000000001
Validation Metrics - Epoch 28/1000:
Loss      :0.3115
Accuracy:  0.8633
Precision: 0.8508
Recall:    0.8812
F1-score:  0.8657
ROC-AUC:   0.9402
Confusion Matrix:
76034 10250
13338 72979
Epoch 29/1000


100%|██████████| 5394/5394 [00:09<00:00, 581.98it/s]


Train loss: 0.30256774597226316
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 874.73it/s] 

Looking for threshold





Best threshold: 0.4600000000000001
Validation Metrics - Epoch 29/1000:
Loss      :0.3109
Accuracy:  0.8636
Precision: 0.8538
Recall:    0.8775
F1-score:  0.8655
ROC-AUC:   0.9402
Confusion Matrix:
75717 10567
12970 73347
Epoch 30/1000


100%|██████████| 5394/5394 [00:09<00:00, 566.26it/s]


Train loss: 0.3019932804171315
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 882.81it/s] 

Looking for threshold





Best threshold: 0.42000000000000004
Validation Metrics - Epoch 30/1000:
Loss      :0.3113
Accuracy:  0.8630
Precision: 0.8560
Recall:    0.8727
F1-score:  0.8643
ROC-AUC:   0.9395
Confusion Matrix:
75303 10981
12665 73652
Epoch 31/1000


100%|██████████| 5394/5394 [00:09<00:00, 593.57it/s]


Train loss: 0.3014644068349146
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 843.11it/s] 

Looking for threshold





Best threshold: 0.39000000000000007
Validation Metrics - Epoch 31/1000:
Loss      :0.3124
Accuracy:  0.8643
Precision: 0.8603
Recall:    0.8698
F1-score:  0.8650
ROC-AUC:   0.9402
Confusion Matrix:
75051 11233
12192 74125
Epoch 32/1000


100%|██████████| 5394/5394 [00:08<00:00, 602.00it/s]


Train loss: 0.3007806080649374
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 868.78it/s] 

Looking for threshold





Best threshold: 0.36000000000000004
Validation Metrics - Epoch 32/1000:
Loss      :0.3106
Accuracy:  0.8627
Precision: 0.8484
Recall:    0.8832
F1-score:  0.8654
ROC-AUC:   0.9404
Confusion Matrix:
76205 10079
13619 72698
Epoch 33/1000


100%|██████████| 5394/5394 [00:09<00:00, 591.44it/s]


Train loss: 0.30015916174535273
Validating...


100%|██████████| 1349/1349 [00:01<00:00, 890.11it/s] 

Looking for threshold





Best threshold: 0.43000000000000005
Validation Metrics - Epoch 33/1000:
Loss      :0.3097
Accuracy:  0.8639
Precision: 0.8557
Recall:    0.8753
F1-score:  0.8654
ROC-AUC:   0.9402
Confusion Matrix:
75524 10760
12737 73580
Early stopping!!!
Early stopping!!!
Early stopping!!!
Best epoch: 28


In [18]:
# Evaluate on the held-out test set with the threshold returned by train().
test(model, test_loader, device, criterion, best_threshold)

100%|██████████| 397/397 [00:00<00:00, 544.88it/s]


Test Metrics:
Accuracy:  0.8327
Precision: 0.8241
Recall:    0.8460
F1-score:  0.8349
ROC-AUC:   0.9183
Confusion Matrix:
21461 3907
4581 20787
Test Loss: 0.3609
