In [1]:
import os
import ssl
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

from config import Config, Device
from datasets import BalancedMRIDataset
from models import InceptionV3
from trainer import Trainer
from tester import Tester

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = Device.device
print(device)

mps


In [3]:
data_path = os.path.join(os.getcwd(), "data")
labels_path = "train.csv"

batch_size = Config.batch_size
num_epochs = Config.num_epochs
learning_rate = Config.learning_rate
mean = Config.mean  # mean of the entire datasaet
std = Config.std  # std of the entire dataaset
image_size = 224

In [4]:
resclaed_mean = round(mean/255, 4)  # re-scale the actual mean
rescaled_std = round(std/255, 4)  # re-scale the actual std

train_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[resclaed_mean], std=[rescaled_std])
])

augment_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    # transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[resclaed_mean], std=[rescaled_std])
])

test_transforms = transforms.Compose([
    # transforms.Lambda(lambda img: img.astype(np.float32)),
    transforms.ToTensor(),
    transforms.Resize((299, 299)),
    transforms.Normalize(mean=[resclaed_mean], std=[rescaled_std])
])

In [5]:

train_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='train',
    transform=train_transforms,
    augment_transform=augment_transforms,
    augment=True,
    max_slices=20
)

val_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='val',
    transform=test_transforms,
    max_slices=20
)

test_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='test',
    transform=test_transforms,
    max_slices=20
)

train_dl = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=32)
test_dl = DataLoader(test_dataset, batch_size=32)

In [6]:
data_, label_ = next(iter(train_dl))
data_.size()

torch.Size([32, 20, 299, 299])

In [7]:
ssl._create_default_https_context = ssl._create_stdlib_context

In [8]:
model = InceptionV3().to(device=device)

In [13]:
criterion = nn.BCEWithLogitsLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [14]:
model_name = model.__class__.__name__
model_name

'Inception3'

In [17]:
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    val_dl=val_dl,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    device=device,
    num_epochs=100,
    patience=100,
    threshold=0.5,
    save_path=f"saved_models/{model_name}.pth"
)

# Start training
trainer.train()

Epoch 1/100


100%|██████████| 89/89 [14:38<00:00,  9.87s/it]


Confusion Matrix:
[[487  61]
 [ 58  20]]
Train Loss: 767.3479, Train Accuracy: 0.8245
Val Loss: 436.0755, Val Accuracy: 0.8099
Precision: 0.2469, Recall: 0.2564, AUC: 0.5725, Avg Metric: 0.4377
Epoch 2/100


100%|██████████| 89/89 [14:50<00:00, 10.00s/it]


Confusion Matrix:
[[ 92 456]
 [ 15  63]]
Train Loss: 689.7377, Train Accuracy: 0.9177
Val Loss: 538.7695, Val Accuracy: 0.2476
Precision: 0.1214, Recall: 0.8077, AUC: 0.4878, Avg Metric: 0.3922
Epoch 3/100


100%|██████████| 89/89 [14:46<00:00,  9.97s/it]


Confusion Matrix:
[[446 102]
 [ 59  19]]
Train Loss: 555.1585, Train Accuracy: 0.9741
Val Loss: 333.2873, Val Accuracy: 0.7428
Precision: 0.1570, Recall: 0.2436, AUC: 0.5287, Avg Metric: 0.3811
Epoch 4/100


100%|██████████| 89/89 [14:32<00:00,  9.81s/it]


Confusion Matrix:
[[548   0]
 [ 78   0]]
Train Loss: 520.8567, Train Accuracy: 0.9858
Val Loss: 169.8881, Val Accuracy: 0.8754
Precision: 0.0000, Recall: 0.0000, AUC: 0.5000, Avg Metric: 0.2918
Epoch 5/100


100%|██████████| 89/89 [14:30<00:00,  9.78s/it]


Confusion Matrix:
[[471  77]
 [ 64  14]]
Train Loss: 478.4733, Train Accuracy: 0.9890
Val Loss: 330.6435, Val Accuracy: 0.7748
Precision: 0.1538, Recall: 0.1795, AUC: 0.5195, Avg Metric: 0.3694
Epoch 6/100


100%|██████████| 89/89 [14:39<00:00,  9.88s/it]


Confusion Matrix:
[[274 274]
 [ 22  56]]
Train Loss: 430.8303, Train Accuracy: 0.9954
Val Loss: 393.9540, Val Accuracy: 0.5272
Precision: 0.1697, Recall: 0.7179, AUC: 0.6090, Avg Metric: 0.4716
Epoch 7/100


100%|██████████| 89/89 [14:30<00:00,  9.78s/it]


Confusion Matrix:
[[275 273]
 [ 25  53]]
Train Loss: 389.5077, Train Accuracy: 0.9986
Val Loss: 459.7212, Val Accuracy: 0.5240
Precision: 0.1626, Recall: 0.6795, AUC: 0.5907, Avg Metric: 0.4553
Epoch 8/100


100%|██████████| 89/89 [14:23<00:00,  9.71s/it]


Confusion Matrix:
[[215 333]
 [ 14  64]]
Train Loss: 348.2242, Train Accuracy: 0.9993
Val Loss: 521.0973, Val Accuracy: 0.4457
Precision: 0.1612, Recall: 0.8205, AUC: 0.6064, Avg Metric: 0.4758
Epoch 9/100


100%|██████████| 89/89 [14:48<00:00,  9.98s/it]


Confusion Matrix:
[[531  17]
 [ 66  12]]
Train Loss: 297.5147, Train Accuracy: 0.9989
Val Loss: 216.4145, Val Accuracy: 0.8674
Precision: 0.4138, Recall: 0.1538, AUC: 0.5614, Avg Metric: 0.4784
Epoch 10/100


100%|██████████| 89/89 [14:57<00:00, 10.08s/it]


Confusion Matrix:
[[548   0]
 [ 68  10]]
Train Loss: 246.1203, Train Accuracy: 1.0000
Val Loss: 225.2817, Val Accuracy: 0.8914
Precision: 1.0000, Recall: 0.1282, AUC: 0.5641, Avg Metric: 0.6732
Epoch 11/100


100%|██████████| 89/89 [14:29<00:00,  9.77s/it]


Confusion Matrix:
[[540   8]
 [ 62  16]]
Train Loss: 203.1829, Train Accuracy: 1.0000
Val Loss: 272.1619, Val Accuracy: 0.8882
Precision: 0.6667, Recall: 0.2051, AUC: 0.5953, Avg Metric: 0.5867
Epoch 12/100


100%|██████████| 89/89 [14:06<00:00,  9.51s/it]


Confusion Matrix:
[[539   9]
 [ 53  25]]
Train Loss: 173.1529, Train Accuracy: 1.0000
Val Loss: 294.7462, Val Accuracy: 0.9010
Precision: 0.7353, Recall: 0.3205, AUC: 0.6520, Avg Metric: 0.6523
Epoch 13/100


100%|██████████| 89/89 [14:21<00:00,  9.68s/it]


Confusion Matrix:
[[518  30]
 [ 54  24]]
Train Loss: 143.0957, Train Accuracy: 1.0000
Val Loss: 410.9548, Val Accuracy: 0.8658
Precision: 0.4444, Recall: 0.3077, AUC: 0.6265, Avg Metric: 0.5393
Epoch 14/100


100%|██████████| 89/89 [14:01<00:00,  9.45s/it]


Confusion Matrix:
[[540   8]
 [ 58  20]]
Train Loss: 122.7394, Train Accuracy: 1.0000
Val Loss: 243.6190, Val Accuracy: 0.8946
Precision: 0.7143, Recall: 0.2564, AUC: 0.6209, Avg Metric: 0.6218
Epoch 15/100


100%|██████████| 89/89 [13:25<00:00,  9.05s/it]


Confusion Matrix:
[[526  22]
 [ 57  21]]
Train Loss: 102.4391, Train Accuracy: 1.0000
Val Loss: 359.6725, Val Accuracy: 0.8738
Precision: 0.4884, Recall: 0.2692, AUC: 0.6145, Avg Metric: 0.5438
Epoch 16/100


100%|██████████| 89/89 [13:14<00:00,  8.93s/it]


Confusion Matrix:
[[542   6]
 [ 60  18]]
Train Loss: 100.4107, Train Accuracy: 1.0000
Val Loss: 505.7478, Val Accuracy: 0.8946
Precision: 0.7500, Recall: 0.2308, AUC: 0.6099, Avg Metric: 0.6251
Epoch 17/100


100%|██████████| 89/89 [13:22<00:00,  9.02s/it]


Confusion Matrix:
[[527  21]
 [ 52  26]]
Train Loss: 100.3141, Train Accuracy: 1.0000
Val Loss: 322.5864, Val Accuracy: 0.8834
Precision: 0.5532, Recall: 0.3333, AUC: 0.6475, Avg Metric: 0.5900
Epoch 18/100


100%|██████████| 89/89 [13:24<00:00,  9.04s/it]


Confusion Matrix:
[[546   2]
 [ 76   2]]
Train Loss: 110.6125, Train Accuracy: 1.0000
Val Loss: 235.5790, Val Accuracy: 0.8754
Precision: 0.5000, Recall: 0.0256, AUC: 0.5110, Avg Metric: 0.4670
Epoch 19/100


100%|██████████| 89/89 [13:24<00:00,  9.04s/it]


Confusion Matrix:
[[546   2]
 [ 74   4]]
Train Loss: 94.2638, Train Accuracy: 0.9993
Val Loss: 210.9961, Val Accuracy: 0.8786
Precision: 0.6667, Recall: 0.0513, AUC: 0.5238, Avg Metric: 0.5322
Epoch 20/100


100%|██████████| 89/89 [13:28<00:00,  9.08s/it]


Confusion Matrix:
[[316 232]
 [ 16  62]]
Train Loss: 80.9773, Train Accuracy: 1.0000
Val Loss: 685.1912, Val Accuracy: 0.6038
Precision: 0.2109, Recall: 0.7949, AUC: 0.6858, Avg Metric: 0.5365
Epoch 21/100


100%|██████████| 89/89 [13:16<00:00,  8.95s/it]


Confusion Matrix:
[[536  12]
 [ 49  29]]
Train Loss: 78.3882, Train Accuracy: 1.0000
Val Loss: 288.4523, Val Accuracy: 0.9026
Precision: 0.7073, Recall: 0.3718, AUC: 0.6749, Avg Metric: 0.6606
Epoch 22/100


100%|██████████| 89/89 [13:14<00:00,  8.93s/it]


Confusion Matrix:
[[528  20]
 [ 57  21]]
Train Loss: 61.4214, Train Accuracy: 1.0000
Val Loss: 361.1205, Val Accuracy: 0.8770
Precision: 0.5122, Recall: 0.2692, AUC: 0.6164, Avg Metric: 0.5528
Epoch 23/100


100%|██████████| 89/89 [13:12<00:00,  8.91s/it]


Confusion Matrix:
[[529  19]
 [ 52  26]]
Train Loss: 60.2420, Train Accuracy: 1.0000
Val Loss: 356.7512, Val Accuracy: 0.8866
Precision: 0.5778, Recall: 0.3333, AUC: 0.6493, Avg Metric: 0.5992
Epoch 24/100


100%|██████████| 89/89 [13:15<00:00,  8.94s/it]


Confusion Matrix:
[[545   3]
 [ 69   9]]
Train Loss: 55.6897, Train Accuracy: 1.0000
Val Loss: 311.9065, Val Accuracy: 0.8850
Precision: 0.7500, Recall: 0.1154, AUC: 0.5550, Avg Metric: 0.5835
Epoch 25/100


100%|██████████| 89/89 [13:23<00:00,  9.03s/it]


Confusion Matrix:
[[548   0]
 [ 74   4]]
Train Loss: 51.5896, Train Accuracy: 1.0000
Val Loss: 322.5732, Val Accuracy: 0.8818
Precision: 1.0000, Recall: 0.0513, AUC: 0.5256, Avg Metric: 0.6444
Epoch 26/100


100%|██████████| 89/89 [13:52<00:00,  9.35s/it]


Confusion Matrix:
[[548   0]
 [ 71   7]]
Train Loss: 49.2542, Train Accuracy: 1.0000
Val Loss: 400.9298, Val Accuracy: 0.8866
Precision: 1.0000, Recall: 0.0897, AUC: 0.5449, Avg Metric: 0.6588
Epoch 27/100


100%|██████████| 89/89 [13:36<00:00,  9.18s/it]


Confusion Matrix:
[[545   3]
 [ 69   9]]
Train Loss: 46.9213, Train Accuracy: 1.0000
Val Loss: 266.6337, Val Accuracy: 0.8850
Precision: 0.7500, Recall: 0.1154, AUC: 0.5550, Avg Metric: 0.5835
Epoch 28/100


100%|██████████| 89/89 [13:31<00:00,  9.12s/it]


Confusion Matrix:
[[547   1]
 [ 74   4]]
Train Loss: 44.6214, Train Accuracy: 1.0000
Val Loss: 373.6981, Val Accuracy: 0.8802
Precision: 0.8000, Recall: 0.0513, AUC: 0.5247, Avg Metric: 0.5772
Epoch 29/100


100%|██████████| 89/89 [13:28<00:00,  9.08s/it]


Confusion Matrix:
[[548   0]
 [ 77   1]]
Train Loss: 46.8466, Train Accuracy: 1.0000
Val Loss: 244.1366, Val Accuracy: 0.8770
Precision: 1.0000, Recall: 0.0128, AUC: 0.5064, Avg Metric: 0.6299
Epoch 30/100


100%|██████████| 89/89 [13:25<00:00,  9.05s/it]


Confusion Matrix:
[[529  19]
 [ 50  28]]
Train Loss: 69.0855, Train Accuracy: 0.9996
Val Loss: 419.7052, Val Accuracy: 0.8898
Precision: 0.5957, Recall: 0.3590, AUC: 0.6622, Avg Metric: 0.6148
Epoch 31/100


100%|██████████| 89/89 [13:33<00:00,  9.14s/it]


Confusion Matrix:
[[547   1]
 [ 72   6]]
Train Loss: 42.6791, Train Accuracy: 1.0000
Val Loss: 251.7989, Val Accuracy: 0.8834
Precision: 0.8571, Recall: 0.0769, AUC: 0.5375, Avg Metric: 0.6058
Epoch 32/100


100%|██████████| 89/89 [13:48<00:00,  9.30s/it]


Confusion Matrix:
[[537  11]
 [ 56  22]]
Train Loss: 37.7033, Train Accuracy: 1.0000
Val Loss: 530.8027, Val Accuracy: 0.8930
Precision: 0.6667, Recall: 0.2821, AUC: 0.6310, Avg Metric: 0.6139
Epoch 33/100


100%|██████████| 89/89 [13:47<00:00,  9.30s/it]


Confusion Matrix:
[[546   2]
 [ 71   7]]
Train Loss: 37.9915, Train Accuracy: 1.0000
Val Loss: 258.8815, Val Accuracy: 0.8834
Precision: 0.7778, Recall: 0.0897, AUC: 0.5430, Avg Metric: 0.5836
Epoch 34/100


100%|██████████| 89/89 [13:48<00:00,  9.31s/it]


Confusion Matrix:
[[541   7]
 [ 59  19]]
Train Loss: 32.8893, Train Accuracy: 1.0000
Val Loss: 281.9925, Val Accuracy: 0.8946
Precision: 0.7308, Recall: 0.2436, AUC: 0.6154, Avg Metric: 0.6230
Epoch 35/100


100%|██████████| 89/89 [13:53<00:00,  9.37s/it]


Confusion Matrix:
[[545   3]
 [ 68  10]]
Train Loss: 34.1928, Train Accuracy: 1.0000
Val Loss: 347.0190, Val Accuracy: 0.8866
Precision: 0.7692, Recall: 0.1282, AUC: 0.5614, Avg Metric: 0.5947
Epoch 36/100


100%|██████████| 89/89 [14:02<00:00,  9.46s/it]


Confusion Matrix:
[[538  10]
 [ 51  27]]
Train Loss: 36.2510, Train Accuracy: 1.0000
Val Loss: 370.5312, Val Accuracy: 0.9026
Precision: 0.7297, Recall: 0.3462, AUC: 0.6640, Avg Metric: 0.6595
Epoch 37/100


100%|██████████| 89/89 [14:02<00:00,  9.46s/it]


Confusion Matrix:
[[374 174]
 [ 31  47]]
Train Loss: 31.2636, Train Accuracy: 1.0000
Val Loss: 965.3238, Val Accuracy: 0.6725
Precision: 0.2127, Recall: 0.6026, AUC: 0.6425, Avg Metric: 0.4959
Epoch 38/100


100%|██████████| 89/89 [14:01<00:00,  9.46s/it]


Confusion Matrix:
[[546   2]
 [ 68  10]]
Train Loss: 37.5624, Train Accuracy: 1.0000
Val Loss: 332.9092, Val Accuracy: 0.8882
Precision: 0.8333, Recall: 0.1282, AUC: 0.5623, Avg Metric: 0.6166
Epoch 39/100


100%|██████████| 89/89 [14:00<00:00,  9.44s/it]


Confusion Matrix:
[[547   1]
 [ 65  13]]
Train Loss: 29.3591, Train Accuracy: 1.0000
Val Loss: 366.9225, Val Accuracy: 0.8946
Precision: 0.9286, Recall: 0.1667, AUC: 0.5824, Avg Metric: 0.6633
Epoch 40/100


100%|██████████| 89/89 [13:56<00:00,  9.40s/it]


Confusion Matrix:
[[547   1]
 [ 71   7]]
Train Loss: 28.7134, Train Accuracy: 1.0000
Val Loss: 408.5463, Val Accuracy: 0.8850
Precision: 0.8750, Recall: 0.0897, AUC: 0.5440, Avg Metric: 0.6166
Epoch 41/100


100%|██████████| 89/89 [14:04<00:00,  9.49s/it]


Confusion Matrix:
[[543   5]
 [ 61  17]]
Train Loss: 24.4222, Train Accuracy: 1.0000
Val Loss: 381.2527, Val Accuracy: 0.8946
Precision: 0.7727, Recall: 0.2179, AUC: 0.6044, Avg Metric: 0.6284
Epoch 42/100


100%|██████████| 89/89 [14:12<00:00,  9.58s/it]


Confusion Matrix:
[[545   3]
 [ 63  15]]
Train Loss: 27.7212, Train Accuracy: 1.0000
Val Loss: 377.6500, Val Accuracy: 0.8946
Precision: 0.8333, Recall: 0.1923, AUC: 0.5934, Avg Metric: 0.6401
Epoch 43/100


100%|██████████| 89/89 [14:16<00:00,  9.63s/it]


Confusion Matrix:
[[544   4]
 [ 59  19]]
Train Loss: 27.4045, Train Accuracy: 1.0000
Val Loss: 374.1431, Val Accuracy: 0.8994
Precision: 0.8261, Recall: 0.2436, AUC: 0.6181, Avg Metric: 0.6563
Epoch 44/100


100%|██████████| 89/89 [14:23<00:00,  9.70s/it]


Confusion Matrix:
[[547   1]
 [ 61  17]]
Train Loss: 25.5949, Train Accuracy: 1.0000
Val Loss: 330.8820, Val Accuracy: 0.9010
Precision: 0.9444, Recall: 0.2179, AUC: 0.6081, Avg Metric: 0.6878
Epoch 45/100


100%|██████████| 89/89 [14:25<00:00,  9.73s/it]


Confusion Matrix:
[[361 187]
 [ 29  49]]
Train Loss: 28.7162, Train Accuracy: 1.0000
Val Loss: 980.5257, Val Accuracy: 0.6550
Precision: 0.2076, Recall: 0.6282, AUC: 0.6435, Avg Metric: 0.4969
Epoch 46/100


  0%|          | 0/89 [00:08<?, ?it/s]


KeyboardInterrupt: 

In [None]:
plt.figure()
plt.plot(trainer.train_losses, "bo", label="Training loss")
plt.plot(trainer.val_losses, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()

In [15]:
model.load_state_dict(torch.load(f"saved_models/{model_name}.pth"))

  model.load_state_dict(torch.load(f"saved_models/{model_name}.pth"))


<All keys matched successfully>

In [16]:
tester = Tester(
    model=model,
    criterion=criterion,
    test_dl=test_dl,
    test_dataset=test_dataset,
    device=device,
    threshold=0.5
)

tester.test()

Test Accuracy: 0.8802, Precision: 0.6154, Recall: 0.1026, AUC: 0.5467, Avg Metric: 0.5327
