In [None]:
# Experiment notes for this notebook (kept as a bare string; a markdown cell would render better).
# NOTE(review): whether only the first and last layers are actually trained depends on
# ResNet18MRI freezing the remaining layers (not visible in this notebook) — confirm.
'''
model : BehradG/resnet-18-finetuned-MRI-Brain
resnet18 pretrained on brain mri images
model first layer modified to take 1-channel image instead of 3-channel
model last layer modified to output=1 instead of output=2
model fine tuned on just first layer and last layer 

'''

In [13]:
import os
import random

import numpy as np
import pandas as pd
# from sklearn.utils.class_weight import compute_class_weight

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader

from config import Config, Device
from datasets import MRIDataset, BalancedMRIDataset
from models import ResNet18MRI
from train import Trainer_for_resnet18
from test_file import Tester

In [2]:
# Compute device chosen by the project's Device helper (the recorded run printed "mps").
device = Device.device
print(f"{device}")

mps


In [4]:
# Image data lives under ./data relative to the notebook's working directory.
data_path = os.path.join(os.getcwd(), "data")
# NOTE(review): relative path, not joined with data_path — confirm how BalancedMRIDataset resolves it.
labels_path = "train.csv"

# Hyper-parameters and dataset statistics from the project Config.
# NOTE(review): later cells hard-code batch size 32, lr 1e-4 and 10 epochs instead of
# batch_size / learning_rate / num_epochs below.
batch_size = Config.batch_size
num_epochs = Config.num_epochs
learning_rate = Config.learning_rate
mean = Config.mean  # mean of the entire dataset (presumably on the 0-255 pixel scale; divided by 255 later)
std = Config.std  # std of the entire dataset
image_size = 224  # side length (pixels) the slices are resized to

# Seed every RNG in play (python, numpy, torch) so random augmentation, DataLoader shuffling
# and the freshly initialised first conv layer are reproducible on Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# ToTensor() scales uint8 images to [0, 1], so the dataset statistics (presumably on the
# 0-255 scale) are divided by 255 to put Normalize in the same units.
rescaled_mean = round(mean / 255, 4)
rescaled_std = round(std / 255, 4)
resclaed_mean = rescaled_mean  # backward-compatible alias for the old misspelled name

# One Normalize shared by every pipeline so train and eval are normalised identically.
normalize = transforms.Normalize(mean=[rescaled_mean], std=[rescaled_std])

# Base training pipeline: mild random rotation, resize to the model input size, tensor, normalise.
train_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

# Heavier augmentation (random flips + rotation), passed as augment_transform to the train dataset.
augment_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

# Deterministic eval pipeline. Resize runs *before* ToTensor, exactly as in the training
# pipelines: resizing a PIL image and resizing a float tensor give slightly different
# pixels, so the old ToTensor -> Resize order fed val/test a different preprocessing path.
test_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

In [6]:

# Training split: base transforms plus the random-augmentation pipeline.
train_dataset = BalancedMRIDataset(
    data_path, labels_path,
    split='train',
    transform=train_transforms,
    augment_transform=augment_transforms,
    augment=True,
    max_slices=20,
)

# Validation and test splits share the deterministic eval pipeline (no augmentation);
# built in order: val first, then test.
val_dataset, test_dataset = (
    BalancedMRIDataset(
        data_path, labels_path,
        split=split_name,
        transform=test_transforms,
        max_slices=20,
    )
    for split_name in ('val', 'test')
)

# Only the training loader shuffles; evaluation order stays fixed.
# NOTE(review): batch size is hard-coded to 32; the config cell's batch_size
# (Config.batch_size) is not used here.
train_dl = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dl, test_dl = (DataLoader(ds, batch_size=32) for ds in (val_dataset, test_dataset))

In [7]:
# Sanity-check the tensor layout of one shuffled training batch; the recorded output is
# (32, 20, 224, 224) — presumably (batch, slices, height, width).
data_, label_ = next(iter(train_dl))
data_.shape

torch.Size([32, 20, 224, 224])

In [9]:
# Build the 1-channel ResNet-18 wrapper and move it to the selected device. The warning below
# is expected: the checkpoint's 3-channel stem conv cannot load into the 1-channel stem, so
# that layer starts from freshly initialised weights.
model = ResNet18MRI()
model = model.to(device)

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at BehradG/resnet-18-finetuned-MRI-Brain and are newly initialized because the shapes did not match:
- resnet.embedder.embedder.convolution.weight: found shape torch.Size([64, 3, 7, 7]) in the checkpoint and torch.Size([64, 1, 7, 7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Loss: the model emits a single logit per sample, so BCEWithLogitsLoss applies the sigmoid itself.
# TODO(review): the validation confusion matrices logged during training are roughly 7:1
# negative:positive and the model collapses to predicting "negative"; consider
# BCEWithLogitsLoss(pos_weight=...) computed from the training label counts.
criterion = nn.BCEWithLogitsLoss().to(device)

# AdamW over the trainable parameters only. Frozen parameters (if ResNet18MRI freezes any)
# never receive gradients and would be skipped by AdamW anyway; filtering makes the
# partial-fine-tuning intent explicit and is a no-op when nothing is frozen.
# NOTE(review): lr is hard-coded to 1e-4 rather than the config cell's learning_rate.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-4,
    weight_decay=1e-5,
)

In [11]:
# Class name of the model, used to name the checkpoint file.
model_name = type(model).__name__
model_name

'ResNet18MRI'

In [14]:
# Checkpoints go under saved_models/; create the folder up front so saving the best model
# does not depend on it already existing in a fresh checkout.
checkpoint_dir = "saved_models"
os.makedirs(checkpoint_dir, exist_ok=True)

# NOTE(review): num_epochs=10 is hard-coded; the config cell's num_epochs (Config.num_epochs)
# is not used.
# NOTE(review): the logged "Val Accuracy" exceeds 1 (e.g. 17.2843). It equals the slice-level
# accuracy from the confusion matrix times 20 (= max_slices): 10820 / 12520 = 0.8642, x20 = 17.284.
# That suggests the Trainer divides per-slice correct counts by the number of volumes rather than
# the number of slices. "Avg Metric" matches mean(accuracy, precision, recall), e.g.
# (17.2843 + 0 + 0) / 3 = 5.7614, so early stopping / best-model selection is dominated by the
# inflated accuracy. Fix belongs in train.Trainer_for_resnet18 (not visible here) — confirm.
# NOTE(review): AUC stays ~0.5 and recall ~0 every epoch; the model is not yet separating classes.
trainer = Trainer_for_resnet18(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    val_dl=val_dl,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    device=device,
    num_epochs=10,
    patience=5,
    threshold=0.5,
    save_path=os.path.join(checkpoint_dir, f"{model_name}.pth"),
)

# Start training
trainer.train()

100%|██████████| 89/89 [13:39<00:00,  9.21s/it]


Confusion Matrix:
[[10820   140]
 [ 1560     0]]
Epoch 1/10, Train Loss: 890.7513, Train Accuracy: 0.7603
Epoch 1/10, Val Accuracy: 17.2843, Precision: 0.0000, Recall: 0.0000, AUC: 0.4936, Avg Metric: 5.7614


100%|██████████| 89/89 [09:17<00:00,  6.26s/it]


Confusion Matrix:
[[10880    80]
 [ 1520    40]]
Epoch 2/10, Train Loss: 840.2845, Train Accuracy: 0.7833
Epoch 2/10, Val Accuracy: 17.4441, Precision: 0.3333, Recall: 0.0256, AUC: 0.5092, Avg Metric: 5.9344


100%|██████████| 89/89 [08:35<00:00,  5.79s/it]


Confusion Matrix:
[[10860   100]
 [ 1540    20]]
Epoch 3/10, Train Loss: 833.3933, Train Accuracy: 0.7869
Epoch 3/10, Val Accuracy: 17.3802, Precision: 0.1667, Recall: 0.0128, AUC: 0.5018, Avg Metric: 5.8532


100%|██████████| 89/89 [08:35<00:00,  5.79s/it]


Confusion Matrix:
[[10880    80]
 [ 1540    20]]
Epoch 4/10, Train Loss: 838.6473, Train Accuracy: 0.7734
Epoch 4/10, Val Accuracy: 17.4121, Precision: 0.2000, Recall: 0.0128, AUC: 0.5028, Avg Metric: 5.8750


100%|██████████| 89/89 [08:42<00:00,  5.88s/it]


Confusion Matrix:
[[10520   440]
 [ 1520    40]]
Epoch 5/10, Train Loss: 836.0139, Train Accuracy: 0.7833
Epoch 5/10, Val Accuracy: 16.8690, Precision: 0.0833, Recall: 0.0256, AUC: 0.4927, Avg Metric: 5.6593


100%|██████████| 89/89 [08:35<00:00,  5.80s/it]


Confusion Matrix:
[[10960     0]
 [ 1560     0]]
Epoch 6/10, Train Loss: 836.3028, Train Accuracy: 0.7706
Epoch 6/10, Val Accuracy: 17.5080, Precision: 0.0000, Recall: 0.0000, AUC: 0.5000, Avg Metric: 5.8360


100%|██████████| 89/89 [08:37<00:00,  5.81s/it]


Confusion Matrix:
[[10880    80]
 [ 1560     0]]
Epoch 7/10, Train Loss: 833.8117, Train Accuracy: 0.7794
Epoch 7/10, Val Accuracy: 17.3802, Precision: 0.0000, Recall: 0.0000, AUC: 0.4964, Avg Metric: 5.7934
Early stopping triggered
