In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, random_split
import tqdm
import pandas as pd
import copy
import numpy as np
from sklearn.metrics import roc_auc_score, confusion_matrix

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [3]:
year = 2021
month = 11
perc = 0.75

x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt")
x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt")
y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt")
y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt")

print(f"x_train shape:", x_train.shape)
print(f"x_test shape:", x_test.shape)
print(f"y_train shape:", y_train.shape)
print(f"y_test shape:", y_test.shape)

  x_train = torch.load(f"ds/mlp{year}{month}{perc}x_train.pt")


x_train shape: torch.Size([1526336, 85])
x_test shape: torch.Size([313544, 85])
y_train shape: torch.Size([1526336])
y_test shape: torch.Size([313544])


  x_test = torch.load(f"ds/mlp{year}{month}{perc}x_test.pt")
  y_train = torch.load(f"ds/mlp{year}{month}{perc}y_train.pt")
  y_test = torch.load(f"ds/mlp{year}{month}{perc}y_test.pt")


In [4]:
results_df = pd.read_csv("results.csv", dtype={
    "model": str,
    "year": int,
    "month": int,
    "perc": float,
    "epoch": int,
    "train_loss": float,
    "val_loss": float,
    "acc": float,
    "prec": float,
    "rec": float,
    "f1": float,
    "auc": float,
    "tp": int,
    "fp": int,
    "fn": int,
    "tn": int,
    "best_threshold": float,
    "done": bool
})

filtered_df = results_df[
    (results_df["model"] == "mlp") &
    (results_df["year"] == year) &
    (results_df["month"] == month) &
    (results_df["perc"] == perc)
]

if filtered_df.empty:
    latest_epoch = 0
    is_trained = False
else:
    latest_epoch = filtered_df["epoch"].max()
    is_trained = filtered_df["done"].any()

print("Latest epoch:", latest_epoch)
print("Is trained?", is_trained)

Latest epoch: 0
Is trained? False


In [5]:
input_size = x_train.shape[1]

In [6]:
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

In [7]:
train_size = int(0.8 * x_train.shape[0])
val_size = x_train.shape[0] - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

In [8]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=5, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=5, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, num_workers=5, pin_memory=True)

In [9]:
class MLP(nn.Module):
    def __init__(self, input_size: int):
        super(MLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.layers(x)

In [10]:
def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs, results_df, patience=5):
    best_threshold = 0.0
    best_val_f1 = 0.0
    epochs_no_improve = 0
    best_model_state = None
    train_losses = list()
    val_losses = list()
    best_epoch = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()  # Set model to training mode
        train_loss = 0.0
        for inputs, targets in tqdm.tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update the weights

            train_loss += loss.item()

        train_loss /= len(train_loader)
        print("Train loss:", train_loss)
    
        # Validation
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_probs = []  # Store probabilities for ROC-AUC
        print("Validating...")
        with torch.no_grad():
            for inputs, targets in tqdm.tqdm(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(inputs).squeeze(1)
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                # Get predictions and probabilities (assuming binary classification with sigmoid output)
                probs = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid if needed
                labels = targets.cpu().numpy()

                all_labels.extend(labels)
                all_probs.extend(probs.flatten())

        val_loss /= len(val_loader)

        # Find threshold for predictions
        print("Looking for threshold")
        best_threshold_epoch = 0
        best_f1_epoch = 0
        for threshold in tqdm.tqdm(np.arange(0.05, 0.96, 0.01)):
            preds_binary = (all_probs > threshold).astype(int)
            cm = confusion_matrix(all_labels, preds_binary)
            tp = cm[1, 1]
            fp = cm[0, 1]
            fn = cm[1, 0]
            tn = cm[0, 0]
            precision = 0 if tp == 0 else tp / (tp + fp)
            recall = 0 if tp == 0 else tp / (tp + fn)
            f1 = 0 if precision * recall == 0 else 2 * precision * recall / (precision + recall)
            if f1 > best_f1_epoch:
                best_threshold_epoch = threshold
                best_f1_epoch = f1
        print(f"Best threshold: {best_threshold_epoch}")
        all_preds = (all_probs > best_threshold_epoch).astype(int)

        cm = confusion_matrix(all_labels, all_preds)
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]

        accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0 # Handle division by zero
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        roc_auc = roc_auc_score(all_labels, all_probs)

        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss      :{val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

        new_row = pd.DataFrame(
            {
                "model": ["mlp"],
                "year": [year],
                "month": [month],
                "perc": [perc],
                "epoch": [latest_epoch + epoch + 1],
                "train_loss": [train_loss],
                "val_loss": [val_loss],
                "acc": [accuracy],
                "prec": [precision],
                "rec": [recall],
                "f1": [f1],
                "auc": [roc_auc],
                "tp": [tp],
                "fp": [fp],
                "fn": [fn],
                "tn": [tn],
                "best_threshold": [best_threshold_epoch],
                "done": [False]
            }
        )
        results_df = pd.concat([results_df, new_row], ignore_index=True)
        results_df.to_csv("results.csv", index=False)

        torch.save(model.state_dict(), f"./model_mlp_{year}_{month}_{perc}_{latest_epoch + epoch + 1}.pth")

        if f1 > best_val_f1:
            best_val_f1 = f1
            best_threshold = best_threshold_epoch
            epochs_no_improve = 0
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = latest_epoch + epoch + 1
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print("Best epoch:", best_epoch)
                model.load_state_dict(best_model_state)
                break
    
    return best_threshold


In [11]:
def test(model, test_loader, device, criterion, best_threshold):
    model.eval()
    test_loss = 0.0
    all_labels = []
    all_preds = []
    all_probs = []
    with torch.no_grad():
        for inputs, targets in tqdm.tqdm(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs).squeeze(1)
            loss = criterion(outputs, targets)  # Use criterion here
            test_loss += loss.item()

            probs = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid if needed
            preds = (probs > best_threshold).astype(int)  # Convert probabilities to predictions
            labels = targets.cpu().numpy()

            all_labels.extend(labels)
            all_preds.extend(preds.flatten())
            all_probs.extend(probs.flatten())

    test_loss /= len(test_loader)

    cm = confusion_matrix(all_labels, all_preds)
    tp = cm[1, 1]
    fp = cm[0, 1]
    fn = cm[1, 0]
    tn = cm[0, 0]

    accuracy = (tp + tn) / (tp + fp + fn + tn) if (tp + fp + fn + tn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    try:
      roc_auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        roc_auc = 0.0

    print(f"Test Metrics:")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")
    print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")
    print(f"Test Loss: {test_loss:.4f}") # Print the loss as well


In [12]:
model = MLP(input_size).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

best_threshold = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    1000,
    results_df
)

Epoch 1/1000


100%|██████████| 38159/38159 [00:58<00:00, 651.71it/s]


Train loss: 0.2859578181562697
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1202.32it/s]


Looking for threshold


100%|██████████| 91/91 [00:59<00:00,  1.54it/s]


Best threshold: 0.5400000000000001
Validation Metrics - Epoch 1/1000:
Loss      :0.2106
Accuracy:  0.9109
Precision: 0.9006
Recall:    0.9235
F1-score:  0.9119
ROC-AUC:   0.9737
Confusion Matrix:
140651 11657
15529 137431
Epoch 2/1000


100%|██████████| 38159/38159 [00:59<00:00, 644.12it/s]


Train loss: 0.19771351786326194
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1196.89it/s]


Looking for threshold


100%|██████████| 91/91 [00:57<00:00,  1.57it/s]


Best threshold: 0.32000000000000006
Validation Metrics - Epoch 2/1000:
Loss      :0.1806
Accuracy:  0.9258
Precision: 0.9128
Recall:    0.9411
F1-score:  0.9267
ROC-AUC:   0.9816
Confusion Matrix:
143335 8973
13690 139270
Epoch 3/1000


100%|██████████| 38159/38159 [00:57<00:00, 663.60it/s]


Train loss: 0.16647106700758255
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1276.98it/s]


Looking for threshold


100%|██████████| 91/91 [00:58<00:00,  1.57it/s]


Best threshold: 0.5600000000000002
Validation Metrics - Epoch 3/1000:
Loss      :0.1592
Accuracy:  0.9349
Precision: 0.9265
Recall:    0.9445
F1-score:  0.9354
ROC-AUC:   0.9847
Confusion Matrix:
143860 8448
11411 141549
Epoch 4/1000


100%|██████████| 38159/38159 [00:57<00:00, 663.45it/s]


Train loss: 0.14964225380247143
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1238.04it/s]


Looking for threshold


100%|██████████| 91/91 [00:57<00:00,  1.58it/s]


Best threshold: 0.43000000000000005
Validation Metrics - Epoch 4/1000:
Loss      :0.1453
Accuracy:  0.9404
Precision: 0.9299
Recall:    0.9524
F1-score:  0.9410
ROC-AUC:   0.9873
Confusion Matrix:
145055 7253
10941 142019
Epoch 5/1000


100%|██████████| 38159/38159 [00:57<00:00, 663.59it/s]


Train loss: 0.13916772406924924
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1191.60it/s]


Looking for threshold


100%|██████████| 91/91 [00:59<00:00,  1.52it/s]


Best threshold: 0.42000000000000004
Validation Metrics - Epoch 5/1000:
Loss      :0.1321
Accuracy:  0.9452
Precision: 0.9360
Recall:    0.9556
F1-score:  0.9457
ROC-AUC:   0.9894
Confusion Matrix:
145550 6758
9960 143000
Epoch 6/1000


100%|██████████| 38159/38159 [01:01<00:00, 622.80it/s]


Train loss: 0.12857103632402325
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1173.89it/s]


Looking for threshold


100%|██████████| 91/91 [00:57<00:00,  1.58it/s]


Best threshold: 0.5100000000000001
Validation Metrics - Epoch 6/1000:
Loss      :0.1201
Accuracy:  0.9497
Precision: 0.9418
Recall:    0.9585
F1-score:  0.9501
ROC-AUC:   0.9912
Confusion Matrix:
145993 6315
9026 143934
Epoch 7/1000


100%|██████████| 38159/38159 [00:55<00:00, 683.65it/s]


Train loss: 0.12096042792631695
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1259.00it/s]


Looking for threshold


100%|██████████| 91/91 [00:57<00:00,  1.58it/s]


Best threshold: 0.45000000000000007
Validation Metrics - Epoch 7/1000:
Loss      :0.1201
Accuracy:  0.9490
Precision: 0.9384
Recall:    0.9608
F1-score:  0.9495
ROC-AUC:   0.9909
Confusion Matrix:
146336 5972
9601 143359
Epoch 8/1000


100%|██████████| 38159/38159 [00:57<00:00, 658.91it/s]


Train loss: 0.11487410850247438
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1304.45it/s]


Looking for threshold


100%|██████████| 91/91 [00:57<00:00,  1.58it/s]


Best threshold: 0.38000000000000006
Validation Metrics - Epoch 8/1000:
Loss      :0.1117
Accuracy:  0.9541
Precision: 0.9442
Recall:    0.9650
F1-score:  0.9545
ROC-AUC:   0.9925
Confusion Matrix:
146974 5334
8686 144274
Epoch 9/1000


100%|██████████| 38159/38159 [00:58<00:00, 653.51it/s]


Train loss: 0.11206605615460132
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1289.65it/s]


Looking for threshold


100%|██████████| 91/91 [00:56<00:00,  1.60it/s]


Best threshold: 0.45000000000000007
Validation Metrics - Epoch 9/1000:
Loss      :0.1081
Accuracy:  0.9554
Precision: 0.9527
Recall:    0.9582
F1-score:  0.9554
ROC-AUC:   0.9931
Confusion Matrix:
145941 6367
7249 145711
Epoch 10/1000


100%|██████████| 38159/38159 [00:55<00:00, 684.67it/s]


Train loss: 0.10593503298499257
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1259.27it/s]


Looking for threshold


100%|██████████| 91/91 [00:57<00:00,  1.59it/s]


Best threshold: 0.27
Validation Metrics - Epoch 10/1000:
Loss      :0.1172
Accuracy:  0.9558
Precision: 0.9430
Recall:    0.9700
F1-score:  0.9563
ROC-AUC:   0.9934
Confusion Matrix:
147734 4574
8929 144031
Epoch 11/1000


100%|██████████| 38159/38159 [00:58<00:00, 651.02it/s]


Train loss: 0.10301148380399014
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1296.57it/s]


Looking for threshold


100%|██████████| 91/91 [00:57<00:00,  1.58it/s]


Best threshold: 0.5600000000000002
Validation Metrics - Epoch 11/1000:
Loss      :0.1083
Accuracy:  0.9562
Precision: 0.9474
Recall:    0.9659
F1-score:  0.9565
ROC-AUC:   0.9932
Confusion Matrix:
147111 5197
8172 144788
Epoch 12/1000


100%|██████████| 38159/38159 [00:56<00:00, 678.58it/s]


Train loss: 0.09961143990491741
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1304.77it/s]


Looking for threshold


100%|██████████| 91/91 [00:58<00:00,  1.57it/s]


Best threshold: 0.4100000000000001
Validation Metrics - Epoch 12/1000:
Loss      :0.1020
Accuracy:  0.9571
Precision: 0.9492
Recall:    0.9658
F1-score:  0.9574
ROC-AUC:   0.9938
Confusion Matrix:
147095 5213
7879 145081
Epoch 13/1000


100%|██████████| 38159/38159 [00:59<00:00, 645.77it/s]


Train loss: 0.0968974452467597
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1284.14it/s]


Looking for threshold


100%|██████████| 91/91 [00:58<00:00,  1.57it/s]


Best threshold: 0.34
Validation Metrics - Epoch 13/1000:
Loss      :0.1109
Accuracy:  0.9544
Precision: 0.9415
Recall:    0.9689
F1-score:  0.9550
ROC-AUC:   0.9928
Confusion Matrix:
147573 4735
9177 143783
Epoch 14/1000


100%|██████████| 38159/38159 [00:57<00:00, 668.46it/s]


Train loss: 0.09730960528986558
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1271.55it/s]


Looking for threshold


100%|██████████| 91/91 [00:58<00:00,  1.55it/s]


Best threshold: 0.5500000000000002
Validation Metrics - Epoch 14/1000:
Loss      :0.0932
Accuracy:  0.9603
Precision: 0.9552
Recall:    0.9657
F1-score:  0.9604
ROC-AUC:   0.9946
Confusion Matrix:
147086 5222
6897 146063
Epoch 15/1000


100%|██████████| 38159/38159 [00:57<00:00, 668.88it/s]


Train loss: 0.09150800787394456
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1245.58it/s]


Looking for threshold


100%|██████████| 91/91 [00:59<00:00,  1.52it/s]


Best threshold: 0.37000000000000005
Validation Metrics - Epoch 15/1000:
Loss      :0.0929
Accuracy:  0.9610
Precision: 0.9529
Recall:    0.9697
F1-score:  0.9612
ROC-AUC:   0.9949
Confusion Matrix:
147693 4615
7295 145665
Epoch 16/1000


100%|██████████| 38159/38159 [00:56<00:00, 680.29it/s]


Train loss: 0.08970450926302913
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1241.78it/s]


Looking for threshold


100%|██████████| 91/91 [00:59<00:00,  1.52it/s]


Best threshold: 0.5800000000000002
Validation Metrics - Epoch 16/1000:
Loss      :0.0898
Accuracy:  0.9626
Precision: 0.9571
Recall:    0.9684
F1-score:  0.9627
ROC-AUC:   0.9953
Confusion Matrix:
147492 4816
6606 146354
Epoch 17/1000


100%|██████████| 38159/38159 [00:56<00:00, 671.66it/s]


Train loss: 0.08785655134426863
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1244.36it/s]


Looking for threshold


100%|██████████| 91/91 [00:59<00:00,  1.52it/s]


Best threshold: 0.4700000000000001
Validation Metrics - Epoch 17/1000:
Loss      :0.0851
Accuracy:  0.9639
Precision: 0.9587
Recall:    0.9694
F1-score:  0.9640
ROC-AUC:   0.9955
Confusion Matrix:
147645 4663
6366 146594
Epoch 18/1000


100%|██████████| 38159/38159 [00:56<00:00, 679.81it/s]


Train loss: 0.08605547958913616
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1239.85it/s]


Looking for threshold


100%|██████████| 91/91 [01:00<00:00,  1.52it/s]


Best threshold: 0.4800000000000001
Validation Metrics - Epoch 18/1000:
Loss      :0.0917
Accuracy:  0.9621
Precision: 0.9546
Recall:    0.9702
F1-score:  0.9624
ROC-AUC:   0.9949
Confusion Matrix:
147773 4535
7021 145939
Epoch 19/1000


100%|██████████| 38159/38159 [00:56<00:00, 672.08it/s]


Train loss: 0.0965257204631687
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1226.31it/s]


Looking for threshold


100%|██████████| 91/91 [01:01<00:00,  1.49it/s]


Best threshold: 0.4800000000000001
Validation Metrics - Epoch 19/1000:
Loss      :0.0886
Accuracy:  0.9630
Precision: 0.9606
Recall:    0.9655
F1-score:  0.9630
ROC-AUC:   0.9953
Confusion Matrix:
147048 5260
6025 146935
Epoch 20/1000


100%|██████████| 38159/38159 [01:01<00:00, 617.57it/s]


Train loss: 0.08287655979085258
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1119.66it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.45it/s]


Best threshold: 0.4000000000000001
Validation Metrics - Epoch 20/1000:
Loss      :0.0856
Accuracy:  0.9641
Precision: 0.9576
Recall:    0.9711
F1-score:  0.9643
ROC-AUC:   0.9957
Confusion Matrix:
147913 4395
6554 146406
Epoch 21/1000


100%|██████████| 38159/38159 [01:01<00:00, 617.17it/s]


Train loss: 0.08285902100995987
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1185.36it/s]


Looking for threshold


100%|██████████| 91/91 [01:01<00:00,  1.47it/s]


Best threshold: 0.4800000000000001
Validation Metrics - Epoch 21/1000:
Loss      :0.0851
Accuracy:  0.9643
Precision: 0.9579
Recall:    0.9711
F1-score:  0.9645
ROC-AUC:   0.9957
Confusion Matrix:
147913 4395
6504 146456
Epoch 22/1000


100%|██████████| 38159/38159 [00:56<00:00, 673.56it/s]


Train loss: 0.08217395528504623
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1254.64it/s]


Looking for threshold


100%|██████████| 91/91 [00:59<00:00,  1.53it/s]


Best threshold: 0.5000000000000001
Validation Metrics - Epoch 22/1000:
Loss      :0.0853
Accuracy:  0.9635
Precision: 0.9575
Recall:    0.9700
F1-score:  0.9637
ROC-AUC:   0.9956
Confusion Matrix:
147733 4575
6555 146405
Epoch 23/1000


100%|██████████| 38159/38159 [00:59<00:00, 640.08it/s]


Train loss: 0.07947315720555104
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1206.31it/s]


Looking for threshold


100%|██████████| 91/91 [01:01<00:00,  1.48it/s]


Best threshold: 0.27
Validation Metrics - Epoch 23/1000:
Loss      :0.0946
Accuracy:  0.9621
Precision: 0.9494
Recall:    0.9761
F1-score:  0.9626
ROC-AUC:   0.9953
Confusion Matrix:
148665 3643
7923 145037
Epoch 24/1000


100%|██████████| 38159/38159 [00:59<00:00, 641.82it/s]


Train loss: 0.07885479689333491
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1143.30it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.46it/s]


Best threshold: 0.4700000000000001
Validation Metrics - Epoch 24/1000:
Loss      :0.0771
Accuracy:  0.9665
Precision: 0.9610
Recall:    0.9723
F1-score:  0.9666
ROC-AUC:   0.9964
Confusion Matrix:
148082 4226
6013 146947
Epoch 25/1000


100%|██████████| 38159/38159 [00:58<00:00, 656.86it/s]


Train loss: 0.07714929671624805
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1229.70it/s]


Looking for threshold


100%|██████████| 91/91 [01:00<00:00,  1.50it/s]


Best threshold: 0.6000000000000002
Validation Metrics - Epoch 25/1000:
Loss      :0.0893
Accuracy:  0.9630
Precision: 0.9629
Recall:    0.9629
F1-score:  0.9629
ROC-AUC:   0.9955
Confusion Matrix:
146658 5650
5644 147316
Epoch 26/1000


100%|██████████| 38159/38159 [00:58<00:00, 654.39it/s]


Train loss: 0.07739456578860708
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1083.05it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.45it/s]


Best threshold: 0.5200000000000001
Validation Metrics - Epoch 26/1000:
Loss      :0.0798
Accuracy:  0.9653
Precision: 0.9602
Recall:    0.9706
F1-score:  0.9654
ROC-AUC:   0.9961
Confusion Matrix:
147836 4472
6133 146827
Epoch 27/1000


100%|██████████| 38159/38159 [01:01<00:00, 617.42it/s]


Train loss: 0.07678503020668065
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1196.49it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.46it/s]


Best threshold: 0.5800000000000002
Validation Metrics - Epoch 27/1000:
Loss      :0.0764
Accuracy:  0.9678
Precision: 0.9663
Recall:    0.9692
F1-score:  0.9678
ROC-AUC:   0.9965
Confusion Matrix:
147622 4686
5152 147808
Epoch 28/1000


100%|██████████| 38159/38159 [00:57<00:00, 662.59it/s]


Train loss: 0.07651961018671025
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1223.97it/s]


Looking for threshold


100%|██████████| 91/91 [01:01<00:00,  1.48it/s]


Best threshold: 0.4100000000000001
Validation Metrics - Epoch 28/1000:
Loss      :0.0800
Accuracy:  0.9668
Precision: 0.9631
Recall:    0.9708
F1-score:  0.9669
ROC-AUC:   0.9965
Confusion Matrix:
147860 4448
5673 147287
Epoch 29/1000


100%|██████████| 38159/38159 [01:02<00:00, 605.88it/s]


Train loss: 0.0740357009724299
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1151.36it/s]


Looking for threshold


100%|██████████| 91/91 [01:01<00:00,  1.48it/s]


Best threshold: 0.44000000000000006
Validation Metrics - Epoch 29/1000:
Loss      :0.0773
Accuracy:  0.9668
Precision: 0.9634
Recall:    0.9704
F1-score:  0.9669
ROC-AUC:   0.9965
Confusion Matrix:
147799 4509
5618 147342
Epoch 30/1000


100%|██████████| 38159/38159 [01:00<00:00, 626.55it/s]


Train loss: 0.07436752606853933
Validating...


100%|██████████| 9540/9540 [00:07<00:00, 1231.26it/s]


Looking for threshold


100%|██████████| 91/91 [01:01<00:00,  1.47it/s]


Best threshold: 0.5600000000000002
Validation Metrics - Epoch 30/1000:
Loss      :0.0760
Accuracy:  0.9678
Precision: 0.9664
Recall:    0.9691
F1-score:  0.9678
ROC-AUC:   0.9965
Confusion Matrix:
147605 4703
5134 147826
Epoch 31/1000


100%|██████████| 38159/38159 [01:01<00:00, 623.79it/s]


Train loss: 0.08029440404135069
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1102.65it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.44it/s]


Best threshold: 0.4800000000000001
Validation Metrics - Epoch 31/1000:
Loss      :0.0762
Accuracy:  0.9672
Precision: 0.9611
Recall:    0.9737
F1-score:  0.9674
ROC-AUC:   0.9965
Confusion Matrix:
148306 4002
5999 146961
Epoch 32/1000


100%|██████████| 38159/38159 [01:02<00:00, 605.70it/s]


Train loss: 0.07268573008730145
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1081.92it/s]


Looking for threshold


100%|██████████| 91/91 [01:03<00:00,  1.44it/s]


Best threshold: 0.4700000000000001
Validation Metrics - Epoch 32/1000:
Loss      :0.0729
Accuracy:  0.9688
Precision: 0.9650
Recall:    0.9727
F1-score:  0.9689
ROC-AUC:   0.9968
Confusion Matrix:
148154 4154
5369 147591
Epoch 33/1000


100%|██████████| 38159/38159 [01:01<00:00, 615.73it/s]


Train loss: 0.07247975791930106
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1123.01it/s]


Looking for threshold


100%|██████████| 91/91 [01:03<00:00,  1.44it/s]


Best threshold: 0.6100000000000001
Validation Metrics - Epoch 33/1000:
Loss      :0.0744
Accuracy:  0.9688
Precision: 0.9656
Recall:    0.9721
F1-score:  0.9688
ROC-AUC:   0.9968
Confusion Matrix:
148055 4253
5271 147689
Epoch 34/1000


100%|██████████| 38159/38159 [01:02<00:00, 612.45it/s]


Train loss: 0.07038166533356705
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1169.74it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.45it/s]


Best threshold: 0.39000000000000007
Validation Metrics - Epoch 34/1000:
Loss      :0.0766
Accuracy:  0.9691
Precision: 0.9639
Recall:    0.9745
F1-score:  0.9692
ROC-AUC:   0.9967
Confusion Matrix:
148419 3889
5555 147405
Epoch 35/1000


100%|██████████| 38159/38159 [01:00<00:00, 634.92it/s]


Train loss: 0.06943741707450776
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1154.80it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.46it/s]


Best threshold: 0.5500000000000002
Validation Metrics - Epoch 35/1000:
Loss      :0.0781
Accuracy:  0.9686
Precision: 0.9628
Recall:    0.9746
F1-score:  0.9687
ROC-AUC:   0.9968
Confusion Matrix:
148439 3869
5728 147232
Epoch 36/1000


100%|██████████| 38159/38159 [01:04<00:00, 593.09it/s]


Train loss: 0.06927580010977333
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1064.03it/s]


Looking for threshold


100%|██████████| 91/91 [01:03<00:00,  1.43it/s]


Best threshold: 0.44000000000000006
Validation Metrics - Epoch 36/1000:
Loss      :0.0747
Accuracy:  0.9684
Precision: 0.9667
Recall:    0.9702
F1-score:  0.9684
ROC-AUC:   0.9968
Confusion Matrix:
147767 4541
5092 147868
Epoch 37/1000


100%|██████████| 38159/38159 [01:00<00:00, 629.25it/s]


Train loss: 0.06885708972096022
Validating...


100%|██████████| 9540/9540 [00:09<00:00, 1049.00it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.41it/s]


Best threshold: 0.5800000000000002
Validation Metrics - Epoch 37/1000:
Loss      :0.0760
Accuracy:  0.9702
Precision: 0.9692
Recall:    0.9711
F1-score:  0.9701
ROC-AUC:   0.9971
Confusion Matrix:
147900 4408
4701 148259
Epoch 38/1000


100%|██████████| 38159/38159 [01:05<00:00, 584.47it/s]


Train loss: 0.06975963740306863
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1091.70it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.41it/s]


Best threshold: 0.5800000000000002
Validation Metrics - Epoch 38/1000:
Loss      :0.0888
Accuracy:  0.9697
Precision: 0.9678
Recall:    0.9716
F1-score:  0.9697
ROC-AUC:   0.9968
Confusion Matrix:
147977 4331
4919 148041
Epoch 39/1000


100%|██████████| 38159/38159 [01:04<00:00, 592.58it/s]


Train loss: 0.06906203810246014
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1080.05it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.41it/s]


Best threshold: 0.6300000000000001
Validation Metrics - Epoch 39/1000:
Loss      :0.0697
Accuracy:  0.9713
Precision: 0.9719
Recall:    0.9706
F1-score:  0.9712
ROC-AUC:   0.9972
Confusion Matrix:
147826 4482
4279 148681
Epoch 40/1000


100%|██████████| 38159/38159 [01:04<00:00, 592.06it/s]


Train loss: 0.09009703240485026
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1180.35it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.47it/s]


Best threshold: 0.6400000000000001
Validation Metrics - Epoch 40/1000:
Loss      :0.0755
Accuracy:  0.9694
Precision: 0.9683
Recall:    0.9704
F1-score:  0.9694
ROC-AUC:   0.9969
Confusion Matrix:
147796 4512
4831 148129
Epoch 41/1000


100%|██████████| 38159/38159 [01:05<00:00, 583.49it/s]


Train loss: 0.06582678559123838
Validating...


100%|██████████| 9540/9540 [00:09<00:00, 1019.48it/s]


Looking for threshold


100%|██████████| 91/91 [01:05<00:00,  1.39it/s]


Best threshold: 0.5800000000000002
Validation Metrics - Epoch 41/1000:
Loss      :0.0677
Accuracy:  0.9713
Precision: 0.9691
Recall:    0.9734
F1-score:  0.9713
ROC-AUC:   0.9973
Confusion Matrix:
148262 4046
4724 148236
Epoch 42/1000


100%|██████████| 38159/38159 [01:05<00:00, 586.06it/s]


Train loss: 0.06605293771029118
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1086.38it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.40it/s]


Best threshold: 0.5400000000000001
Validation Metrics - Epoch 42/1000:
Loss      :0.0703
Accuracy:  0.9707
Precision: 0.9670
Recall:    0.9744
F1-score:  0.9707
ROC-AUC:   0.9971
Confusion Matrix:
148412 3896
5063 147897
Epoch 43/1000


100%|██████████| 38159/38159 [01:04<00:00, 593.13it/s]


Train loss: 0.06552112833867602
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1097.54it/s]


Looking for threshold


100%|██████████| 91/91 [01:05<00:00,  1.39it/s]


Best threshold: 0.44000000000000006
Validation Metrics - Epoch 43/1000:
Loss      :0.0701
Accuracy:  0.9707
Precision: 0.9685
Recall:    0.9728
F1-score:  0.9707
ROC-AUC:   0.9972
Confusion Matrix:
148164 4144
4815 148145
Epoch 44/1000


100%|██████████| 38159/38159 [01:03<00:00, 601.35it/s]


Train loss: 0.06579948519968379
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1106.27it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.40it/s]


Best threshold: 0.5100000000000001
Validation Metrics - Epoch 44/1000:
Loss      :0.0715
Accuracy:  0.9705
Precision: 0.9683
Recall:    0.9728
F1-score:  0.9705
ROC-AUC:   0.9971
Confusion Matrix:
148159 4149
4845 148115
Epoch 45/1000


100%|██████████| 38159/38159 [01:05<00:00, 585.78it/s]


Train loss: 0.06492287557020024
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1082.64it/s]


Looking for threshold


100%|██████████| 91/91 [01:05<00:00,  1.39it/s]


Best threshold: 0.49000000000000005
Validation Metrics - Epoch 45/1000:
Loss      :0.0665
Accuracy:  0.9721
Precision: 0.9700
Recall:    0.9743
F1-score:  0.9721
ROC-AUC:   0.9974
Confusion Matrix:
148392 3916
4587 148373
Epoch 46/1000


100%|██████████| 38159/38159 [01:04<00:00, 587.26it/s]


Train loss: 0.0630382210488664
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1094.10it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.41it/s]


Best threshold: 0.4700000000000001
Validation Metrics - Epoch 46/1000:
Loss      :0.0696
Accuracy:  0.9718
Precision: 0.9731
Recall:    0.9702
F1-score:  0.9716
ROC-AUC:   0.9973
Confusion Matrix:
147766 4542
4081 148879
Epoch 47/1000


100%|██████████| 38159/38159 [01:04<00:00, 596.22it/s]


Train loss: 0.06260764001718394
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1118.11it/s]


Looking for threshold


100%|██████████| 91/91 [01:05<00:00,  1.39it/s]


Best threshold: 0.4700000000000001
Validation Metrics - Epoch 47/1000:
Loss      :0.0667
Accuracy:  0.9722
Precision: 0.9704
Recall:    0.9741
F1-score:  0.9722
ROC-AUC:   0.9974
Confusion Matrix:
148356 3952
4520 148440
Epoch 48/1000


100%|██████████| 38159/38159 [01:04<00:00, 590.09it/s]


Train loss: 0.06265265038580509
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1103.20it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.41it/s]


Best threshold: 0.5100000000000001
Validation Metrics - Epoch 48/1000:
Loss      :0.0661
Accuracy:  0.9725
Precision: 0.9721
Recall:    0.9729
F1-score:  0.9725
ROC-AUC:   0.9975
Confusion Matrix:
148178 4130
4260 148700
Epoch 49/1000


100%|██████████| 38159/38159 [01:01<00:00, 625.45it/s]


Train loss: 0.06191763182400536
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1081.10it/s]


Looking for threshold


100%|██████████| 91/91 [01:05<00:00,  1.40it/s]


Best threshold: 0.38000000000000006
Validation Metrics - Epoch 49/1000:
Loss      :0.0775
Accuracy:  0.9724
Precision: 0.9712
Recall:    0.9736
F1-score:  0.9724
ROC-AUC:   0.9970
Confusion Matrix:
148282 4026
4404 148556
Epoch 50/1000


100%|██████████| 38159/38159 [01:06<00:00, 576.88it/s]


Train loss: 0.06310621313913567
Validating...


100%|██████████| 9540/9540 [00:09<00:00, 1049.26it/s]


Looking for threshold


100%|██████████| 91/91 [01:05<00:00,  1.39it/s]


Best threshold: 0.44000000000000006
Validation Metrics - Epoch 50/1000:
Loss      :0.0649
Accuracy:  0.9725
Precision: 0.9683
Recall:    0.9768
F1-score:  0.9725
ROC-AUC:   0.9975
Confusion Matrix:
148768 3540
4865 148095
Epoch 51/1000


100%|██████████| 38159/38159 [01:06<00:00, 577.90it/s]


Train loss: 0.06149509567260918
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1101.78it/s]


Looking for threshold


100%|██████████| 91/91 [01:03<00:00,  1.43it/s]


Best threshold: 0.44000000000000006
Validation Metrics - Epoch 51/1000:
Loss      :0.0631
Accuracy:  0.9742
Precision: 0.9723
Recall:    0.9762
F1-score:  0.9742
ROC-AUC:   0.9977
Confusion Matrix:
148688 3620
4242 148718
Epoch 52/1000


100%|██████████| 38159/38159 [01:05<00:00, 579.37it/s]


Train loss: 0.06010500495009072
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1131.45it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.47it/s]


Best threshold: 0.5300000000000001
Validation Metrics - Epoch 52/1000:
Loss      :0.0623
Accuracy:  0.9743
Precision: 0.9741
Recall:    0.9745
F1-score:  0.9743
ROC-AUC:   0.9977
Confusion Matrix:
148430 3878
3953 149007
Epoch 53/1000


100%|██████████| 38159/38159 [01:01<00:00, 624.08it/s]


Train loss: 0.06101197458030553
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1153.20it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.45it/s]


Best threshold: 0.6400000000000001
Validation Metrics - Epoch 53/1000:
Loss      :0.0679
Accuracy:  0.9728
Precision: 0.9734
Recall:    0.9719
F1-score:  0.9727
ROC-AUC:   0.9975
Confusion Matrix:
148033 4275
4040 148920
Epoch 54/1000


100%|██████████| 38159/38159 [01:05<00:00, 582.61it/s]


Train loss: 0.058870956363779474
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1085.21it/s]


Looking for threshold


100%|██████████| 91/91 [01:03<00:00,  1.43it/s]


Best threshold: 0.5700000000000002
Validation Metrics - Epoch 54/1000:
Loss      :0.0625
Accuracy:  0.9743
Precision: 0.9752
Recall:    0.9733
F1-score:  0.9743
ROC-AUC:   0.9978
Confusion Matrix:
148245 4063
3771 149189
Epoch 55/1000


100%|██████████| 38159/38159 [01:00<00:00, 635.42it/s]


Train loss: 0.05882553122429379
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1151.40it/s]


Looking for threshold


100%|██████████| 91/91 [01:02<00:00,  1.47it/s]


Best threshold: 0.5800000000000002
Validation Metrics - Epoch 55/1000:
Loss      :0.0656
Accuracy:  0.9740
Precision: 0.9724
Recall:    0.9756
F1-score:  0.9740
ROC-AUC:   0.9976
Confusion Matrix:
148596 3712
4210 148750
Epoch 56/1000


100%|██████████| 38159/38159 [00:58<00:00, 650.21it/s]


Train loss: 0.05793177091239318
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1168.14it/s]


Looking for threshold


100%|██████████| 91/91 [01:04<00:00,  1.42it/s]


Best threshold: 0.37000000000000005
Validation Metrics - Epoch 56/1000:
Loss      :0.0651
Accuracy:  0.9737
Precision: 0.9701
Recall:    0.9775
F1-score:  0.9738
ROC-AUC:   0.9977
Confusion Matrix:
148881 3427
4592 148368
Epoch 57/1000


100%|██████████| 38159/38159 [01:03<00:00, 600.31it/s]


Train loss: 0.05789980548284572
Validating...


100%|██████████| 9540/9540 [00:08<00:00, 1128.59it/s]


Looking for threshold


100%|██████████| 91/91 [01:03<00:00,  1.44it/s]


Best threshold: 0.4700000000000001
Validation Metrics - Epoch 57/1000:
Loss      :0.0635
Accuracy:  0.9740
Precision: 0.9717
Recall:    0.9763
F1-score:  0.9740
ROC-AUC:   0.9977
Confusion Matrix:
148694 3614
4323 148637
Early stopping!!!
Early stopping!!!
Early stopping!!!
Best epoch: 52


In [13]:
test(
    model,
    test_loader,
    device,
    criterion,
    best_threshold
)

100%|██████████| 9799/9799 [00:08<00:00, 1142.08it/s]


Test Metrics:
Accuracy:  0.7824
Precision: 0.8398
Recall:    0.6981
F1-score:  0.7624
ROC-AUC:   0.8032
Confusion Matrix:
109441 47331
20881 135891
Test Loss: 11.1693
