In [1]:
'''
Challenge: ResNet expects 3-channel input images, but the MRI images have 20 channels.
-> Add one layer (Conv2d) that transforms the 20 channels to 3 before the ResNet model.

Challenge: with a pretrained ResNet-50 model, either update all parameters (as aksh-ai does)
or fine-tune only the last layer.
-> Update all parameters.

Challenge: the standard ResNet-50 normalization transform is not applicable,
because we have grayscale 20-channel data.
'''


'\nchallenge: resnet input: 3 channels images, the mri images: 20 channels\n-> 1 layer (conv2d) to transform 16 to 3 berfore resnet model\n\nchallenge: using pretrained resent50 model and update all parameters (aksh-ai)\nor just fine tune last layer\n-> update all parameters\nchallenge: resnet-50 transform normalizing is not applicable\nwe have gray scale 18 channel data \n'

In [2]:
import os

import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

from config import Config, Device
from datasets import MRIDataset
from models import MriResentModel
from train import Trainer
from test_file import Tester

In [3]:
# Compute device selected by the project's Device helper; every model and loss below is moved to it.
# The recorded run printed "mps".
device = Device.device
print(device)

mps


In [4]:
# Dataset root: the "data" directory under the current working directory.
data_path = os.path.join(os.getcwd(), "data")
# Label file; a relative path, so pd.read_csv resolves it against the current working directory.
labels_path = "train.csv"

# Hyperparameters read from the shared project Config.
batch_size = Config.batch_size
# NOTE(review): not used by the visible cells — the Trainer call hard-codes num_epochs=100; confirm which is intended.
num_epochs = Config.num_epochs
# NOTE(review): not used by the visible cells — the optimizer hard-codes lr=0.01; confirm which is intended.
learning_rate = Config.learning_rate
# Side length passed to transforms.Resize for both height and width.
image_size = 256

In [5]:
def _base_steps():
    """Return fresh ToTensor + Resize steps that every split shares."""
    target_size = (image_size, image_size)
    return [transforms.ToTensor(), transforms.Resize(target_size)]


# ImageNet mean/std normalization is intentionally left out: those statistics are
# defined for 3-channel RGB input and do not apply to this multi-channel data.

# Training pipeline: shared steps followed by light augmentation
# (random horizontal flip, random rotation with degrees=10).
train_transforms = transforms.Compose(
    _base_steps() + [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]
)

# Evaluation pipeline: shared steps only, no augmentation.
test_transforms = transforms.Compose(_base_steps())

In [6]:
def _make_split(split, transform):
    """Build the MRIDataset for one split with the shared root, label file and slice cap."""
    return MRIDataset(
        data_path, labels_path, split=split, transform=transform, max_slices=20)


# Only the training split gets the augmenting transform.
train_dataset = _make_split("train", train_transforms)
val_dataset = _make_split("val", test_transforms)
test_dataset = _make_split("test", test_transforms)

# Shuffle only the training loader; validation and test keep dataset order.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size)
test_dl = DataLoader(test_dataset, batch_size=batch_size)

In [7]:
# Build the network and move it to the selected device.
# NOTE(review): (20, 1) presumably means 20 input channels (matching max_slices=20 in the
# datasets) and 1 output logit — confirm against MriResentModel's signature.
model = MriResentModel(20,1).to(device)



In [8]:
def compute_class_weights_from_csv(csv_file_path, label_column='prediction'):
    """
    Compute "balanced" class weights from the label column of a CSV file.

    Args:
        csv_file_path: path (or any object accepted by ``pd.read_csv``) of the label file.
        label_column: name of the column holding the class labels. Defaults to
            ``'prediction'``, the column this notebook's label file uses.

    Returns:
        A 1-D ``torch.float`` tensor with one weight per distinct label, ordered by
        ascending label value (the order produced by ``np.unique``).

    Raises:
        KeyError: if ``label_column`` is not a column of the CSV.
    """
    df = pd.read_csv(csv_file_path)

    if label_column not in df.columns:
        raise KeyError(
            f"label column {label_column!r} not found in {csv_file_path!r}; "
            f"available columns: {list(df.columns)}")

    # Labels are cast to int so that e.g. 0.0 / 1.0 are treated as classes 0 / 1.
    labels = df[label_column].to_numpy().astype(int)

    # 'balanced' weights each class inversely to its frequency:
    # n_samples / (n_classes * count(class)).
    unique_labels = np.unique(labels)
    class_weights = compute_class_weight(
        class_weight='balanced', classes=unique_labels, y=labels)

    # Convert to torch tensor
    return torch.tensor(class_weights, dtype=torch.float)


# One balanced weight per label, ordered by ascending label value.
per_class_weights = compute_class_weights_from_csv(labels_path)
print(per_class_weights)
# Binary setup: keep only the entry at index 1, which is the weight of label 1
# as long as the CSV contains exactly the labels 0 and 1 (adjust otherwise).
class_weights = per_class_weights[1]
print("Class Weights:", class_weights)

tensor([0.5713, 4.0051])
Class Weights: tensor(4.0051)


In [9]:
class FocalLoss(nn.Module):
    """
    Sigmoid (binary / multi-label) focal loss for fighting class imbalance.

    Per element the loss is ``-alpha * (1 - pt) ** gamma * log(pt)``, where ``pt`` is
    the predicted probability of the true label. The mean over all elements is returned.

    Args:
        class_weights: optional re-weighting of the per-element loss.
            * ``None``: no re-weighting.
            * a single value (0-dim or 1-element tensor, or a Python number): treated
              as the weight of the POSITIVE class. Elements are weighted by
              ``target * w + (1 - target)``, so negatives keep weight 1. (Multiplying
              every element by the same scalar would only rescale the loss and would
              not counter class imbalance.)
            * any other tensor / sequence: multiplied onto the per-element loss with
              normal broadcasting (e.g. one weight per class column).
        device: device the weights are moved to; it must match the device of the
            tensors passed to ``forward``.
        alpha: constant multiplier applied to the focal term.
        gamma: focusing exponent; larger values down-weight well-classified elements.
    """
    def __init__(self, class_weights, device, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # Kept only so existing code that reads the attribute keeps working; the
        # log-probabilities are now computed with logsigmoid and need no epsilon.
        self.epsilon = 1e-12
        self.device = device

        if class_weights is None:
            self.class_weights = None
        else:
            # Accept plain numbers / lists as well as tensors.
            if not torch.is_tensor(class_weights):
                class_weights = torch.tensor(class_weights, dtype=torch.float)
            # Private copy on the target device, detached from any autograd graph.
            self.class_weights = class_weights.clone().detach().to(self.device)

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]

        ``logits`` are raw (pre-sigmoid) scores; ``target`` holds 0/1 labels of the
        same shape. Returns the scalar mean focal loss.
        """
        # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)), computed in a
        # numerically stable way. Unlike log(sigmoid(x) + eps), this neither clips the
        # loss nor zeroes the gradient for confidently wrong predictions.
        log_probs = nn.functional.logsigmoid(logits)
        log_one_minus_probs = nn.functional.logsigmoid(-logits)

        # calculate focal loss
        log_pt = target * log_probs + (1.0 - target) * log_one_minus_probs
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt

        if self.class_weights is not None:
            weights = self.class_weights
            if weights.numel() == 1:
                # Single value = positive-class weight: apply it to positives only.
                weights = target * weights + (1.0 - target)
            focal_loss = focal_loss * weights

        return torch.mean(focal_loss)


In [10]:
# SGD with momentum 0.9 and weight decay 1e-4 over everything returned by model.parameters().
# NOTE(review): lr is hard-coded to 0.01 here, while `learning_rate` was read from Config
# earlier and is not used by the visible cells — confirm which value is intended.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

In [11]:
# Class name of the model; it is used below to build the checkpoint file name.
model_name = type(model).__name__
model_name

'MriResentModel'

In [12]:
# Plain (unweighted) binary cross-entropy on raw logits.
# NOTE(review): the class weight computed earlier and the FocalLoss class are not used
# here, and the recorded validation confusion matrices frequently show no positive
# predictions at all. Consider nn.BCEWithLogitsLoss(pos_weight=...) or FocalLoss if the
# imbalance should be handled in the loss — this changes the experiment, so confirm first.
criterion = nn.BCEWithLogitsLoss().to(device)

# Make sure the checkpoint directory exists before the long training run starts, so
# saving a checkpoint cannot fail on a fresh checkout because the folder is missing.
save_path = f"saved_models/{model_name}.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# NOTE(review): num_epochs=100 is hard-coded although `num_epochs` was read from Config
# earlier — confirm which is intended. The recorded run stopped early after epoch 55.
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    val_dl=val_dl,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    device=device,
    num_epochs=100,
    patience=50,
    threshold=0.5,
    save_path=save_path
)

trainer.train()

100%|██████████| 59/59 [02:49<00:00,  2.87s/it]


Confusion Matrix:
[[548   0]
 [ 78   0]]
Epoch 1/100, Train Loss: 25.5248, Train Accuracy: 0.8638
Epoch 1/100, Val Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.5955, Avg Metric: 0.1985


100%|██████████| 59/59 [02:59<00:00,  3.04s/it]


Confusion Matrix:
[[548   0]
 [ 78   0]]
Epoch 2/100, Train Loss: 22.1772, Train Accuracy: 0.8739
Epoch 2/100, Val Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.6015, Avg Metric: 0.2005


100%|██████████| 59/59 [02:49<00:00,  2.88s/it]


Confusion Matrix:
[[541   7]
 [ 75   3]]
Epoch 3/100, Train Loss: 22.6883, Train Accuracy: 0.8718
Epoch 3/100, Val Accuracy: 0.8690, Precision: 0.3000, Recall: 0.0385, AUC: 0.6329, Avg Metric: 0.3238


100%|██████████| 59/59 [02:45<00:00,  2.81s/it]


Confusion Matrix:
[[548   0]
 [ 78   0]]
Epoch 4/100, Train Loss: 21.4023, Train Accuracy: 0.8686
Epoch 4/100, Val Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.6108, Avg Metric: 0.2036


100%|██████████| 59/59 [02:44<00:00,  2.80s/it]


Confusion Matrix:
[[300 248]
 [ 24  54]]
Epoch 5/100, Train Loss: 20.4392, Train Accuracy: 0.8739
Epoch 5/100, Val Accuracy: 0.5655, Precision: 0.1788, Recall: 0.6923, AUC: 0.6463, Avg Metric: 0.5058


100%|██████████| 59/59 [02:42<00:00,  2.75s/it]


Confusion Matrix:
[[548   0]
 [ 78   0]]
Epoch 6/100, Train Loss: 19.5426, Train Accuracy: 0.8761
Epoch 6/100, Val Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.6366, Avg Metric: 0.2122


100%|██████████| 59/59 [02:42<00:00,  2.75s/it]


Confusion Matrix:
[[535  13]
 [ 73   5]]
Epoch 7/100, Train Loss: 18.6146, Train Accuracy: 0.8862
Epoch 7/100, Val Accuracy: 0.8626, Precision: 0.2778, Recall: 0.0641, AUC: 0.6496, Avg Metric: 0.3305


100%|██████████| 59/59 [02:40<00:00,  2.73s/it]


Confusion Matrix:
[[476  72]
 [ 64  14]]
Epoch 8/100, Train Loss: 15.7454, Train Accuracy: 0.8936
Epoch 8/100, Val Accuracy: 0.7827, Precision: 0.1628, Recall: 0.1795, AUC: 0.6016, Avg Metric: 0.3146


100%|██████████| 59/59 [02:40<00:00,  2.72s/it]


Confusion Matrix:
[[546   2]
 [ 76   2]]
Epoch 9/100, Train Loss: 14.7121, Train Accuracy: 0.9032
Epoch 9/100, Val Accuracy: 0.8754, Precision: 0.5000, Recall: 0.0256, AUC: 0.5545, Avg Metric: 0.3600


100%|██████████| 59/59 [02:40<00:00,  2.72s/it]


Confusion Matrix:
[[523  25]
 [ 71   7]]
Epoch 10/100, Train Loss: 13.4269, Train Accuracy: 0.9059
Epoch 10/100, Val Accuracy: 0.8466, Precision: 0.2188, Recall: 0.0897, AUC: 0.6368, Avg Metric: 0.3151


100%|██████████| 59/59 [02:40<00:00,  2.72s/it]


Confusion Matrix:
[[533  15]
 [ 75   3]]
Epoch 11/100, Train Loss: 11.5894, Train Accuracy: 0.9223
Epoch 11/100, Val Accuracy: 0.8562, Precision: 0.1667, Recall: 0.0385, AUC: 0.6046, Avg Metric: 0.2699


100%|██████████| 59/59 [02:40<00:00,  2.72s/it]


Confusion Matrix:
[[534  14]
 [ 73   5]]
Epoch 12/100, Train Loss: 11.7959, Train Accuracy: 0.9245
Epoch 12/100, Val Accuracy: 0.8610, Precision: 0.2632, Recall: 0.0641, AUC: 0.5457, Avg Metric: 0.2910


100%|██████████| 59/59 [02:39<00:00,  2.71s/it]


Confusion Matrix:
[[524  24]
 [ 69   9]]
Epoch 13/100, Train Loss: 10.4826, Train Accuracy: 0.9378
Epoch 13/100, Val Accuracy: 0.8514, Precision: 0.2727, Recall: 0.1154, AUC: 0.6536, Avg Metric: 0.3472


100%|██████████| 59/59 [02:39<00:00,  2.70s/it]


Confusion Matrix:
[[542   6]
 [ 75   3]]
Epoch 14/100, Train Loss: 9.0263, Train Accuracy: 0.9378
Epoch 14/100, Val Accuracy: 0.8706, Precision: 0.3333, Recall: 0.0385, AUC: 0.5901, Avg Metric: 0.3206


100%|██████████| 59/59 [02:39<00:00,  2.70s/it]


Confusion Matrix:
[[423 125]
 [ 54  24]]
Epoch 15/100, Train Loss: 8.6091, Train Accuracy: 0.9527
Epoch 15/100, Val Accuracy: 0.7141, Precision: 0.1611, Recall: 0.3077, AUC: 0.5817, Avg Metric: 0.3501


100%|██████████| 59/59 [02:41<00:00,  2.74s/it]


Confusion Matrix:
[[436 112]
 [ 55  23]]
Epoch 16/100, Train Loss: 7.4380, Train Accuracy: 0.9532
Epoch 16/100, Val Accuracy: 0.7332, Precision: 0.1704, Recall: 0.2949, AUC: 0.5753, Avg Metric: 0.3468


100%|██████████| 59/59 [02:49<00:00,  2.87s/it]


Confusion Matrix:
[[541   7]
 [ 76   2]]
Epoch 17/100, Train Loss: 7.6423, Train Accuracy: 0.9521
Epoch 17/100, Val Accuracy: 0.8674, Precision: 0.2222, Recall: 0.0256, AUC: 0.5979, Avg Metric: 0.2819


100%|██████████| 59/59 [02:49<00:00,  2.87s/it]


Confusion Matrix:
[[496  52]
 [ 69   9]]
Epoch 18/100, Train Loss: 5.5409, Train Accuracy: 0.9612
Epoch 18/100, Val Accuracy: 0.8067, Precision: 0.1475, Recall: 0.1154, AUC: 0.5877, Avg Metric: 0.2835


100%|██████████| 59/59 [02:48<00:00,  2.86s/it]


Confusion Matrix:
[[504  44]
 [ 63  15]]
Epoch 19/100, Train Loss: 5.1252, Train Accuracy: 0.9697
Epoch 19/100, Val Accuracy: 0.8291, Precision: 0.2542, Recall: 0.1923, AUC: 0.5874, Avg Metric: 0.3446


100%|██████████| 59/59 [02:49<00:00,  2.87s/it]


Confusion Matrix:
[[354 194]
 [ 49  29]]
Epoch 20/100, Train Loss: 4.3034, Train Accuracy: 0.9734
Epoch 20/100, Val Accuracy: 0.6118, Precision: 0.1300, Recall: 0.3718, AUC: 0.5849, Avg Metric: 0.3622


100%|██████████| 59/59 [02:48<00:00,  2.86s/it]


Confusion Matrix:
[[516  32]
 [ 70   8]]
Epoch 21/100, Train Loss: 5.4088, Train Accuracy: 0.9606
Epoch 21/100, Val Accuracy: 0.8371, Precision: 0.2000, Recall: 0.1026, AUC: 0.5997, Avg Metric: 0.3008


100%|██████████| 59/59 [02:48<00:00,  2.86s/it]


Confusion Matrix:
[[499  49]
 [ 70   8]]
Epoch 22/100, Train Loss: 5.0036, Train Accuracy: 0.9660
Epoch 22/100, Val Accuracy: 0.8099, Precision: 0.1404, Recall: 0.1026, AUC: 0.5907, Avg Metric: 0.2779


100%|██████████| 59/59 [02:49<00:00,  2.87s/it]


Confusion Matrix:
[[515  33]
 [ 69   9]]
Epoch 23/100, Train Loss: 3.1486, Train Accuracy: 0.9824
Epoch 23/100, Val Accuracy: 0.8371, Precision: 0.2143, Recall: 0.1154, AUC: 0.6356, Avg Metric: 0.3217


100%|██████████| 59/59 [02:48<00:00,  2.85s/it]


Confusion Matrix:
[[534  14]
 [ 75   3]]
Epoch 24/100, Train Loss: 3.5631, Train Accuracy: 0.9766
Epoch 24/100, Val Accuracy: 0.8578, Precision: 0.1765, Recall: 0.0385, AUC: 0.6288, Avg Metric: 0.2812


100%|██████████| 59/59 [02:49<00:00,  2.87s/it]


Confusion Matrix:
[[341 207]
 [ 42  36]]
Epoch 25/100, Train Loss: 2.6217, Train Accuracy: 0.9824
Epoch 25/100, Val Accuracy: 0.6022, Precision: 0.1481, Recall: 0.4615, AUC: 0.5810, Avg Metric: 0.3969


100%|██████████| 59/59 [02:47<00:00,  2.84s/it]


Confusion Matrix:
[[508  40]
 [ 68  10]]
Epoch 26/100, Train Loss: 3.9106, Train Accuracy: 0.9766
Epoch 26/100, Val Accuracy: 0.8275, Precision: 0.2000, Recall: 0.1282, AUC: 0.6034, Avg Metric: 0.3105


100%|██████████| 59/59 [02:46<00:00,  2.83s/it]


Confusion Matrix:
[[530  18]
 [ 72   6]]
Epoch 27/100, Train Loss: 2.1412, Train Accuracy: 0.9867
Epoch 27/100, Val Accuracy: 0.8562, Precision: 0.2500, Recall: 0.0769, AUC: 0.5364, Avg Metric: 0.2878


100%|██████████| 59/59 [02:47<00:00,  2.84s/it]


Confusion Matrix:
[[529  19]
 [ 75   3]]
Epoch 28/100, Train Loss: 3.3405, Train Accuracy: 0.9814
Epoch 28/100, Val Accuracy: 0.8498, Precision: 0.1364, Recall: 0.0385, AUC: 0.5344, Avg Metric: 0.2364


100%|██████████| 59/59 [02:46<00:00,  2.82s/it]


Confusion Matrix:
[[472  76]
 [ 60  18]]
Epoch 29/100, Train Loss: 2.8875, Train Accuracy: 0.9809
Epoch 29/100, Val Accuracy: 0.7827, Precision: 0.1915, Recall: 0.2308, AUC: 0.6118, Avg Metric: 0.3447


100%|██████████| 59/59 [02:45<00:00,  2.80s/it]


Confusion Matrix:
[[532  16]
 [ 73   5]]
Epoch 30/100, Train Loss: 2.2915, Train Accuracy: 0.9851
Epoch 30/100, Val Accuracy: 0.8578, Precision: 0.2381, Recall: 0.0641, AUC: 0.6116, Avg Metric: 0.3046


100%|██████████| 59/59 [02:44<00:00,  2.79s/it]


Confusion Matrix:
[[513  35]
 [ 68  10]]
Epoch 31/100, Train Loss: 2.7067, Train Accuracy: 0.9846
Epoch 31/100, Val Accuracy: 0.8355, Precision: 0.2222, Recall: 0.1282, AUC: 0.5865, Avg Metric: 0.3123


100%|██████████| 59/59 [02:45<00:00,  2.80s/it]


Confusion Matrix:
[[519  29]
 [ 66  12]]
Epoch 32/100, Train Loss: 1.6604, Train Accuracy: 0.9894
Epoch 32/100, Val Accuracy: 0.8482, Precision: 0.2927, Recall: 0.1538, AUC: 0.6133, Avg Metric: 0.3533


100%|██████████| 59/59 [02:45<00:00,  2.80s/it]


Confusion Matrix:
[[536  12]
 [ 76   2]]
Epoch 33/100, Train Loss: 1.9249, Train Accuracy: 0.9867
Epoch 33/100, Val Accuracy: 0.8594, Precision: 0.1429, Recall: 0.0256, AUC: 0.5667, Avg Metric: 0.2451


100%|██████████| 59/59 [02:45<00:00,  2.80s/it]


Confusion Matrix:
[[526  22]
 [ 70   8]]
Epoch 34/100, Train Loss: 1.4289, Train Accuracy: 0.9904
Epoch 34/100, Val Accuracy: 0.8530, Precision: 0.2667, Recall: 0.1026, AUC: 0.6005, Avg Metric: 0.3232


100%|██████████| 59/59 [02:45<00:00,  2.80s/it]


Confusion Matrix:
[[500  48]
 [ 60  18]]
Epoch 35/100, Train Loss: 1.9910, Train Accuracy: 0.9867
Epoch 35/100, Val Accuracy: 0.8275, Precision: 0.2727, Recall: 0.2308, AUC: 0.6143, Avg Metric: 0.3726


100%|██████████| 59/59 [02:46<00:00,  2.81s/it]


Confusion Matrix:
[[528  20]
 [ 71   7]]
Epoch 36/100, Train Loss: 3.2035, Train Accuracy: 0.9761
Epoch 36/100, Val Accuracy: 0.8546, Precision: 0.2593, Recall: 0.0897, AUC: 0.5714, Avg Metric: 0.3068


100%|██████████| 59/59 [02:35<00:00,  2.64s/it]


Confusion Matrix:
[[504  44]
 [ 67  11]]
Epoch 37/100, Train Loss: 1.5241, Train Accuracy: 0.9926
Epoch 37/100, Val Accuracy: 0.8227, Precision: 0.2000, Recall: 0.1410, AUC: 0.5765, Avg Metric: 0.3058


100%|██████████| 59/59 [02:35<00:00,  2.64s/it]


Confusion Matrix:
[[531  17]
 [ 73   5]]
Epoch 38/100, Train Loss: 1.3521, Train Accuracy: 0.9904
Epoch 38/100, Val Accuracy: 0.8562, Precision: 0.2273, Recall: 0.0641, AUC: 0.5927, Avg Metric: 0.2947


100%|██████████| 59/59 [02:35<00:00,  2.64s/it]


Confusion Matrix:
[[527  21]
 [ 68  10]]
Epoch 39/100, Train Loss: 2.2941, Train Accuracy: 0.9867
Epoch 39/100, Val Accuracy: 0.8578, Precision: 0.3226, Recall: 0.1282, AUC: 0.6177, Avg Metric: 0.3562


100%|██████████| 59/59 [02:35<00:00,  2.64s/it]


Confusion Matrix:
[[543   5]
 [ 78   0]]
Epoch 40/100, Train Loss: 1.9222, Train Accuracy: 0.9856
Epoch 40/100, Val Accuracy: 0.8674, Precision: 0.0000, Recall: 0.0000, AUC: 0.6025, Avg Metric: 0.2008


100%|██████████| 59/59 [02:35<00:00,  2.64s/it]


Confusion Matrix:
[[533  15]
 [ 71   7]]
Epoch 41/100, Train Loss: 2.5691, Train Accuracy: 0.9856
Epoch 41/100, Val Accuracy: 0.8626, Precision: 0.3182, Recall: 0.0897, AUC: 0.5751, Avg Metric: 0.3277


100%|██████████| 59/59 [02:40<00:00,  2.72s/it]


Confusion Matrix:
[[541   7]
 [ 76   2]]
Epoch 42/100, Train Loss: 2.6873, Train Accuracy: 0.9824
Epoch 42/100, Val Accuracy: 0.8674, Precision: 0.2222, Recall: 0.0256, AUC: 0.6286, Avg Metric: 0.2922


100%|██████████| 59/59 [02:41<00:00,  2.74s/it]


Confusion Matrix:
[[502  46]
 [ 65  13]]
Epoch 43/100, Train Loss: 0.9484, Train Accuracy: 0.9947
Epoch 43/100, Val Accuracy: 0.8227, Precision: 0.2203, Recall: 0.1667, AUC: 0.5982, Avg Metric: 0.3284


100%|██████████| 59/59 [02:40<00:00,  2.72s/it]


Confusion Matrix:
[[533  15]
 [ 71   7]]
Epoch 44/100, Train Loss: 0.3400, Train Accuracy: 0.9989
Epoch 44/100, Val Accuracy: 0.8626, Precision: 0.3182, Recall: 0.0897, AUC: 0.6455, Avg Metric: 0.3511


100%|██████████| 59/59 [02:34<00:00,  2.61s/it]


Confusion Matrix:
[[532  16]
 [ 69   9]]
Epoch 45/100, Train Loss: 0.2518, Train Accuracy: 0.9989
Epoch 45/100, Val Accuracy: 0.8642, Precision: 0.3600, Recall: 0.1154, AUC: 0.6186, Avg Metric: 0.3647


100%|██████████| 59/59 [02:34<00:00,  2.61s/it]


Confusion Matrix:
[[521  27]
 [ 67  11]]
Epoch 46/100, Train Loss: 0.6602, Train Accuracy: 0.9968
Epoch 46/100, Val Accuracy: 0.8498, Precision: 0.2895, Recall: 0.1410, AUC: 0.5920, Avg Metric: 0.3408


100%|██████████| 59/59 [02:34<00:00,  2.61s/it]


Confusion Matrix:
[[522  26]
 [ 69   9]]
Epoch 47/100, Train Loss: 1.2719, Train Accuracy: 0.9952
Epoch 47/100, Val Accuracy: 0.8482, Precision: 0.2571, Recall: 0.1154, AUC: 0.5968, Avg Metric: 0.3231


100%|██████████| 59/59 [02:34<00:00,  2.61s/it]


Confusion Matrix:
[[523  25]
 [ 69   9]]
Epoch 48/100, Train Loss: 1.0563, Train Accuracy: 0.9947
Epoch 48/100, Val Accuracy: 0.8498, Precision: 0.2647, Recall: 0.1154, AUC: 0.6326, Avg Metric: 0.3376


100%|██████████| 59/59 [02:34<00:00,  2.61s/it]


Confusion Matrix:
[[522  26]
 [ 68  10]]
Epoch 49/100, Train Loss: 1.1943, Train Accuracy: 0.9941
Epoch 49/100, Val Accuracy: 0.8498, Precision: 0.2778, Recall: 0.1282, AUC: 0.6206, Avg Metric: 0.3422


100%|██████████| 59/59 [02:34<00:00,  2.62s/it]


Confusion Matrix:
[[539   9]
 [ 69   9]]
Epoch 50/100, Train Loss: 0.9140, Train Accuracy: 0.9957
Epoch 50/100, Val Accuracy: 0.8754, Precision: 0.5000, Recall: 0.1154, AUC: 0.6064, Avg Metric: 0.4073


100%|██████████| 59/59 [02:34<00:00,  2.62s/it]


Confusion Matrix:
[[528  20]
 [ 70   8]]
Epoch 51/100, Train Loss: 1.1238, Train Accuracy: 0.9947
Epoch 51/100, Val Accuracy: 0.8562, Precision: 0.2857, Recall: 0.1026, AUC: 0.5974, Avg Metric: 0.3286


100%|██████████| 59/59 [02:34<00:00,  2.61s/it]


Confusion Matrix:
[[523  25]
 [ 67  11]]
Epoch 52/100, Train Loss: 0.2519, Train Accuracy: 0.9989
Epoch 52/100, Val Accuracy: 0.8530, Precision: 0.3056, Recall: 0.1410, AUC: 0.6183, Avg Metric: 0.3550


100%|██████████| 59/59 [02:34<00:00,  2.61s/it]


Confusion Matrix:
[[514  34]
 [ 65  13]]
Epoch 53/100, Train Loss: 0.1588, Train Accuracy: 0.9995
Epoch 53/100, Val Accuracy: 0.8419, Precision: 0.2766, Recall: 0.1667, AUC: 0.6483, Avg Metric: 0.3639


100%|██████████| 59/59 [02:34<00:00,  2.62s/it]


Confusion Matrix:
[[527  21]
 [ 72   6]]
Epoch 54/100, Train Loss: 0.1793, Train Accuracy: 0.9989
Epoch 54/100, Val Accuracy: 0.8514, Precision: 0.2222, Recall: 0.0769, AUC: 0.6399, Avg Metric: 0.3130


100%|██████████| 59/59 [02:34<00:00,  2.62s/it]


Confusion Matrix:
[[514  34]
 [ 72   6]]
Epoch 55/100, Train Loss: 0.7169, Train Accuracy: 0.9963
Epoch 55/100, Val Accuracy: 0.8307, Precision: 0.1500, Recall: 0.0769, AUC: 0.6028, Avg Metric: 0.2766
Early stopping triggered


In [12]:
# Rebuild the architecture with the same constructor arguments used for training, then
# restore the weights from the path that was passed to the Trainer as save_path.
model = MriResentModel(20,1).to(device)
checkpoint_path = f"saved_models/{model_name}.pth"
# map_location places the stored tensors on the current device, so a checkpoint saved on
# one backend (e.g. mps) can still be loaded on a machine that only has another (e.g. cpu).
# NOTE(review): if the installed torch version supports it, also pass weights_only=True
# so that loading the file cannot execute arbitrary pickled code.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)



<All keys matched successfully>

In [13]:
# Evaluate the reloaded checkpoint on the test split.
# threshold=0.5 is the same value that was passed to the Trainer; presumably it is the
# probability cut-off for predicting the positive class — confirm in Tester.
tester = Tester(
    model=model,
    test_dl=test_dl,
    test_dataset=test_dataset,
    device=device,
    threshold=0.5
)

tester.test()

Confusion Matrix:
[[307 241]
 [ 27  51]]
Test Accuracy: 0.5719, Precision: 0.1747, Recall: 0.6538, AUC: 0.6642, Avg Metric: 0.4976
