In [1]:
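'''
Model: BehradG/resnet-18-finetuned-MRI-Brain (Hugging Face checkpoint)
- ResNet-18 pretrained on brain MRI images.
- First layer modified to take a 1-channel (grayscale) image instead of a 3-channel one.
- Last layer modified to output a single value (output=1) instead of two (output=2).
- Intended to fine-tune only the first and last layers. NOTE: the cell that sets
  requires_grad=True on every parameter makes the whole network trainable in this run.

'''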

'\nmodel : BehradG/resnet-18-finetuned-MRI-Brain\nresnet18 pretrained on brain mri images\nmodel first layer modified to take 1-channel image instead of 3-channel\nmodel last layer modified to output=1 instead of output=2\nmodel fine tuned on just first layer and last layer \n\n'

In [2]:
import os

import numpy as np
import pandas as pd
# from sklearn.utils.class_weight import compute_class_weight       

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

from config import Config, Device
from datasets import MRIDataset, BalancedMRIDataset
from models import ResNet18MRI
from train import Trainer_for_resnet18
from test_file import Tester    

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Bind the torch device chosen by the project config and echo it, so the run log records where training ran.
print(device := Device.device)

mps


In [4]:
# Locations: image data lives in ./data; the labels file name is handed to the dataset class as-is.
data_path = os.path.join(os.getcwd(), "data")
labels_path = "train.csv"

# Training hyper-parameters, read from the project config.
batch_size, num_epochs, learning_rate = (
    Config.batch_size,
    Config.num_epochs,
    Config.learning_rate,
)

# Intensity statistics of the entire dataset, presumably on the raw 0-255 pixel scale
# (the transforms cell divides them by 255 before normalising).
mean, std = Config.mean, Config.std

# Target side length, in pixels, for resized slices.
image_size = 224

In [5]:
# The dataset statistics are on the raw pixel scale (they are divided by 255 here), while
# ToTensor() maps 8-bit pixels to [0, 1], so the normalisation statistics are rescaled to match.
rescaled_mean = round(mean / 255, 4)
rescaled_std = round(std / 255, 4)
resclaed_mean = rescaled_mean  # backward-compatible alias for the previous, misspelled name

# Shared tail of every pipeline. Resize runs BEFORE ToTensor in all three pipelines so that
# training and val/test slices go through the same resampling path. Previously the test
# pipeline resized the tensor after ToTensor, which does not match the PIL resize used in training.
_common_tail = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[rescaled_mean], std=[rescaled_std]),
]

# Base training pipeline: light rotation jitter, then the shared tail.
train_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    *_common_tail,
])

# Heavier augmentation pipeline (flips + rotation), passed to the dataset as augment_transform.
augment_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=10),
    *_common_tail,
])

# Deterministic evaluation pipeline for val/test: no random ops.
test_transforms = transforms.Compose(list(_common_tail))

In [6]:

# Number of slices taken per MRI sample. The batch-shape check shows this becomes dim 1,
# e.g. [batch, 20, 224, 224]. Every split uses the same value.
MAX_SLICES = 20

# Training split: base transform plus the extra augmentation pipeline.
# How `augment` / `augment_transform` are combined is decided inside BalancedMRIDataset.
train_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='train',
    transform=train_transforms,
    augment_transform=augment_transforms,
    augment=True,
    max_slices=MAX_SLICES,
)

# Validation / test splits: only the evaluation transform is passed (no augment_transform).
val_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='val',
    transform=test_transforms,
    max_slices=MAX_SLICES,
)

test_dataset = BalancedMRIDataset(
    data_path,
    labels_path,
    split='test',
    transform=test_transforms,
    max_slices=MAX_SLICES,
)

# Use the batch size read from Config in the config cell instead of a hard-coded 32,
# so the value is defined in exactly one place.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size)
test_dl = DataLoader(test_dataset, batch_size=batch_size)

In [7]:
# Sanity check: pull one shuffled training batch and display its tensor shape.
data_, label_ = next(iter(train_dl))
data_.shape

torch.Size([32, 20, 224, 224])

In [8]:
# Build the model. Per the warning it prints, this loads the BehradG/resnet-18-finetuned-MRI-Brain
# checkpoint with a re-initialised 1-channel first conv. Then move it to the selected device.
model = ResNet18MRI()
model = model.to(device)

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at BehradG/resnet-18-finetuned-MRI-Brain and are newly initialized because the shapes did not match:
- resnet.embedder.embedder.convolution.weight: found shape torch.Size([64, 3, 7, 7]) in the checkpoint and torch.Size([64, 1, 7, 7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Mark every parameter of the network as trainable (full fine-tuning).
# NOTE(review): the notebook header says only the first and last layers are fine-tuned;
# this call unfreezes everything. Confirm which behaviour is intended.
# The trailing ';' stops the notebook from echoing the module repr returned by requires_grad_.
model.requires_grad_(True);

In [10]:
# Loss and optimizer.
# The classification head emits a single raw logit (output=1), so BCEWithLogitsLoss applies
# the sigmoid internally.
# NOTE(review): the recorded validation confusion matrices show roughly 7 negatives per positive,
# and several epochs predict (almost) all-negative. Consider
# nn.BCEWithLogitsLoss(pos_weight=torch.tensor([n_neg / n_pos], device=device)) with the counts
# taken from the training labels.
criterion = nn.BCEWithLogitsLoss().to(device)

# The learning rate comes from Config (read into `learning_rate` in the config cell) instead of
# a hard-coded duplicate, so there is a single place to tune it.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [11]:
# Class name of the model. It is used to name the saved checkpoint file.
model_name = type(model).__name__
model_name

'ResNet18MRI'

In [12]:
# Checkpoints are written to saved_models/<ModelClass>.pth. Create the directory up front so
# saving cannot fail on a fresh checkout (this is a no-op if it already exists).
save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)

trainer = Trainer_for_resnet18(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    val_dl=val_dl,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    device=device,
    # Epoch budget comes from Config (read in the config cell) rather than a hard-coded 100.
    num_epochs=num_epochs,
    # Early-stopping patience in epochs. In the recorded run the best Avg Metric was at
    # epoch 21 and training stopped at epoch 31.
    patience=10,
    # Presumably the probability cut-off for a positive prediction. TODO: confirm in the Trainer.
    threshold=0.5,
    save_path=os.path.join(save_dir, f"{model_name}.pth"),
)

# Start training
trainer.train()

100%|██████████| 89/89 [06:06<00:00,  4.12s/it]


Confusion Matrix:
[[10900    60]
 [ 1560     0]]
Epoch 1/100, Train Loss: 819.7164, Train Accuracy: 0.7755
Epoch 1/100, Val Accuracy: 0.8706, Precision: 0.0000, Recall: 0.0000, AUC: 0.4973, Avg Metric: 0.2902


100%|██████████| 89/89 [06:37<00:00,  4.47s/it]


Confusion Matrix:
[[9820 1140]
 [1460  100]]
Epoch 2/100, Train Loss: 803.9556, Train Accuracy: 0.7911
Epoch 2/100, Val Accuracy: 0.7923, Precision: 0.0806, Recall: 0.0641, AUC: 0.4800, Avg Metric: 0.3124


100%|██████████| 89/89 [06:39<00:00,  4.48s/it]


Confusion Matrix:
[[10920    40]
 [ 1560     0]]
Epoch 3/100, Train Loss: 787.3235, Train Accuracy: 0.7883
Epoch 3/100, Val Accuracy: 0.8722, Precision: 0.0000, Recall: 0.0000, AUC: 0.4982, Avg Metric: 0.2907


100%|██████████| 89/89 [06:32<00:00,  4.41s/it]


Confusion Matrix:
[[10960     0]
 [ 1560     0]]
Epoch 4/100, Train Loss: 752.8545, Train Accuracy: 0.7972
Epoch 4/100, Val Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.5000, Avg Metric: 0.2918


100%|██████████| 89/89 [06:25<00:00,  4.33s/it]


Confusion Matrix:
[[10920    40]
 [ 1560     0]]
Epoch 5/100, Train Loss: 678.4534, Train Accuracy: 0.8177
Epoch 5/100, Val Accuracy: 0.8722, Precision: 0.0000, Recall: 0.0000, AUC: 0.4982, Avg Metric: 0.2907


100%|██████████| 89/89 [06:09<00:00,  4.15s/it]


Confusion Matrix:
[[9900 1060]
 [1360  200]]
Epoch 6/100, Train Loss: 606.7041, Train Accuracy: 0.8585
Epoch 6/100, Val Accuracy: 0.8067, Precision: 0.1587, Recall: 0.1282, AUC: 0.5157, Avg Metric: 0.3645


100%|██████████| 89/89 [06:10<00:00,  4.17s/it]


Confusion Matrix:
[[10540   420]
 [ 1480    80]]
Epoch 7/100, Train Loss: 549.7631, Train Accuracy: 0.8826
Epoch 7/100, Val Accuracy: 0.8482, Precision: 0.1600, Recall: 0.0513, AUC: 0.5065, Avg Metric: 0.3532


100%|██████████| 89/89 [06:11<00:00,  4.17s/it]


Confusion Matrix:
[[10560   400]
 [ 1440   120]]
Epoch 8/100, Train Loss: 541.6665, Train Accuracy: 0.8865
Epoch 8/100, Val Accuracy: 0.8530, Precision: 0.2308, Recall: 0.0769, AUC: 0.5202, Avg Metric: 0.3869


100%|██████████| 89/89 [06:08<00:00,  4.15s/it]


Confusion Matrix:
[[10960     0]
 [ 1560     0]]
Epoch 9/100, Train Loss: 516.2483, Train Accuracy: 0.8848
Epoch 9/100, Val Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.5000, Avg Metric: 0.2918


100%|██████████| 89/89 [06:04<00:00,  4.10s/it]


Confusion Matrix:
[[5700 5260]
 [ 680  880]]
Epoch 10/100, Train Loss: 504.9040, Train Accuracy: 0.8840
Epoch 10/100, Val Accuracy: 0.5256, Precision: 0.1433, Recall: 0.5641, AUC: 0.5421, Avg Metric: 0.4110


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[10660   300]
 [ 1540    20]]
Epoch 11/100, Train Loss: 500.0452, Train Accuracy: 0.8947
Epoch 11/100, Val Accuracy: 0.8530, Precision: 0.0625, Recall: 0.0128, AUC: 0.4927, Avg Metric: 0.3095


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[10660   300]
 [ 1460   100]]
Epoch 12/100, Train Loss: 489.4446, Train Accuracy: 0.9046
Epoch 12/100, Val Accuracy: 0.8594, Precision: 0.2500, Recall: 0.0641, AUC: 0.5184, Avg Metric: 0.3912


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[10940    20]
 [ 1560     0]]
Epoch 13/100, Train Loss: 486.7590, Train Accuracy: 0.8904
Epoch 13/100, Val Accuracy: 0.8738, Precision: 0.0000, Recall: 0.0000, AUC: 0.4991, Avg Metric: 0.2913


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[10540   420]
 [ 1400   160]]
Epoch 14/100, Train Loss: 474.5050, Train Accuracy: 0.8982
Epoch 14/100, Val Accuracy: 0.8546, Precision: 0.2759, Recall: 0.1026, AUC: 0.5321, Avg Metric: 0.4110


100%|██████████| 89/89 [06:00<00:00,  4.05s/it]


Confusion Matrix:
[[10360   600]
 [ 1480    80]]
Epoch 15/100, Train Loss: 463.4072, Train Accuracy: 0.9004
Epoch 15/100, Val Accuracy: 0.8339, Precision: 0.1176, Recall: 0.0513, AUC: 0.4983, Avg Metric: 0.3343


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[10720   240]
 [ 1540    20]]
Epoch 16/100, Train Loss: 455.7757, Train Accuracy: 0.8904
Epoch 16/100, Val Accuracy: 0.8578, Precision: 0.0769, Recall: 0.0128, AUC: 0.4955, Avg Metric: 0.3159


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[5080 5880]
 [ 700  860]]
Epoch 17/100, Train Loss: 454.0553, Train Accuracy: 0.8947
Epoch 17/100, Val Accuracy: 0.4744, Precision: 0.1276, Recall: 0.5513, AUC: 0.5074, Avg Metric: 0.3844


100%|██████████| 89/89 [06:00<00:00,  4.05s/it]


Confusion Matrix:
[[10700   260]
 [ 1500    60]]
Epoch 18/100, Train Loss: 438.2871, Train Accuracy: 0.9039
Epoch 18/100, Val Accuracy: 0.8594, Precision: 0.1875, Recall: 0.0385, AUC: 0.5074, Avg Metric: 0.3618


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[10600   360]
 [ 1480    80]]
Epoch 19/100, Train Loss: 427.7929, Train Accuracy: 0.9092
Epoch 19/100, Val Accuracy: 0.8530, Precision: 0.1818, Recall: 0.0513, AUC: 0.5092, Avg Metric: 0.3620


100%|██████████| 89/89 [06:02<00:00,  4.07s/it]


Confusion Matrix:
[[10540   420]
 [ 1480    80]]
Epoch 20/100, Train Loss: 414.1100, Train Accuracy: 0.9067
Epoch 20/100, Val Accuracy: 0.8482, Precision: 0.1600, Recall: 0.0513, AUC: 0.5065, Avg Metric: 0.3532


100%|██████████| 89/89 [06:00<00:00,  4.05s/it]


Confusion Matrix:
[[10500   460]
 [ 1360   200]]
Epoch 21/100, Train Loss: 397.7838, Train Accuracy: 0.9011
Epoch 21/100, Val Accuracy: 0.8546, Precision: 0.3030, Recall: 0.1282, AUC: 0.5431, Avg Metric: 0.4286


100%|██████████| 89/89 [06:01<00:00,  4.06s/it]


Confusion Matrix:
[[9560 1400]
 [1300  260]]
Epoch 22/100, Train Loss: 375.2558, Train Accuracy: 0.9121
Epoch 22/100, Val Accuracy: 0.7843, Precision: 0.1566, Recall: 0.1667, AUC: 0.5195, Avg Metric: 0.3692


100%|██████████| 89/89 [06:00<00:00,  4.06s/it]


Confusion Matrix:
[[9720 1240]
 [1340  220]]
Epoch 23/100, Train Loss: 376.2390, Train Accuracy: 0.9018
Epoch 23/100, Val Accuracy: 0.7939, Precision: 0.1507, Recall: 0.1410, AUC: 0.5139, Avg Metric: 0.3619


100%|██████████| 89/89 [06:00<00:00,  4.05s/it]


Confusion Matrix:
[[10460   500]
 [ 1540    20]]
Epoch 24/100, Train Loss: 357.1448, Train Accuracy: 0.8996
Epoch 24/100, Val Accuracy: 0.8371, Precision: 0.0385, Recall: 0.0128, AUC: 0.4836, Avg Metric: 0.2961


100%|██████████| 89/89 [06:02<00:00,  4.08s/it]


Confusion Matrix:
[[10420   540]
 [ 1440   120]]
Epoch 25/100, Train Loss: 329.6895, Train Accuracy: 0.9128
Epoch 25/100, Val Accuracy: 0.8419, Precision: 0.1818, Recall: 0.0769, AUC: 0.5138, Avg Metric: 0.3669


100%|██████████| 89/89 [05:56<00:00,  4.00s/it]


Confusion Matrix:
[[10600   360]
 [ 1400   160]]
Epoch 26/100, Train Loss: 319.9134, Train Accuracy: 0.9191
Epoch 26/100, Val Accuracy: 0.8594, Precision: 0.3077, Recall: 0.1026, AUC: 0.5349, Avg Metric: 0.4232


100%|██████████| 89/89 [05:52<00:00,  3.96s/it]


Confusion Matrix:
[[8600 2360]
 [1000  560]]
Epoch 27/100, Train Loss: 303.1804, Train Accuracy: 0.9209
Epoch 27/100, Val Accuracy: 0.7316, Precision: 0.1918, Recall: 0.3590, AUC: 0.5718, Avg Metric: 0.4275


100%|██████████| 89/89 [05:53<00:00,  3.97s/it]


Confusion Matrix:
[[10960     0]
 [ 1560     0]]
Epoch 28/100, Train Loss: 287.7815, Train Accuracy: 0.9230
Epoch 28/100, Val Accuracy: 0.8754, Precision: 0.0000, Recall: 0.0000, AUC: 0.5000, Avg Metric: 0.2918


100%|██████████| 89/89 [05:53<00:00,  3.97s/it]


Confusion Matrix:
[[9600 1360]
 [1460  100]]
Epoch 29/100, Train Loss: 264.3314, Train Accuracy: 0.9230
Epoch 29/100, Val Accuracy: 0.7748, Precision: 0.0685, Recall: 0.0641, AUC: 0.4700, Avg Metric: 0.3025


100%|██████████| 89/89 [05:53<00:00,  3.97s/it]


Confusion Matrix:
[[10120   840]
 [ 1520    40]]
Epoch 30/100, Train Loss: 247.7530, Train Accuracy: 0.9387
Epoch 30/100, Val Accuracy: 0.8115, Precision: 0.0455, Recall: 0.0256, AUC: 0.4745, Avg Metric: 0.2942


100%|██████████| 89/89 [05:53<00:00,  3.97s/it]


Confusion Matrix:
[[8520 2440]
 [1320  240]]
Epoch 31/100, Train Loss: 225.3331, Train Accuracy: 0.9415
Epoch 31/100, Val Accuracy: 0.6997, Precision: 0.0896, Recall: 0.1538, AUC: 0.4656, Avg Metric: 0.3144
Early stopping triggered
