In [1]:
import os
import ssl
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

from config import Config, Device
from datasets import BalancedMRIDataset
from models import InceptionV3
from trainer import Trainer
from tester import Tester
from models import darknet53

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Device object exposed by the project's Device helper; printed so the run
# records which backend was used.
device = Device.device
print(device)

mps


In [3]:
# Input locations: the "data" directory under the current working directory
# and the label file name handed to the dataset class.
data_path = os.path.join(os.getcwd(), "data")
labels_path = "train.csv"

# Hyperparameters and dataset statistics read from the project Config.
batch_size = Config.batch_size
num_epochs = Config.num_epochs
learning_rate = Config.learning_rate
mean = Config.mean  # mean of the entire dataset
std = Config.std  # std of the entire dataset
# NOTE(review): the transform pipelines resize to 299x299 and do not read
# image_size — confirm whether this value is still needed.
image_size = 224

In [4]:
# Dataset statistics divided by 255 and rounded to 4 decimals.
# NOTE(review): this presumes Config.mean / Config.std are expressed in 0-255
# pixel units while the tensors being normalised are in the 0-1 range — confirm.
rescaled_mean = round(mean / 255, 4)
rescaled_std = round(std / 255, 4)
resclaed_mean = rescaled_mean  # backward-compatible alias for the original misspelled name

# Spatial size every pipeline resizes to.
# NOTE(review): image_size (224) from the setup cell is not used here.
input_size = (299, 299)

# Single-channel normalisation shared by all pipelines.
normalize = transforms.Normalize(mean=[rescaled_mean], std=[rescaled_std])

# Default training pipeline: small random rotation, then resize and normalise.
train_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=10),
    transforms.Resize(input_size),
    transforms.ToTensor(),
    normalize,
])

# Heavier augmentation pipeline: random flips plus rotation.
augment_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.Resize(input_size),
    transforms.ToTensor(),
    normalize,
])

# Deterministic pipeline for validation / test. Note that the tensor
# conversion happens before the resize here, unlike the training pipelines.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(input_size),
    normalize,
])

In [5]:

# Positional arguments common to every split.
dataset_source = (data_path, labels_path)

# Training split: the only one that passes the augmentation pipeline and
# switches augmentation on.
train_dataset = BalancedMRIDataset(
    *dataset_source,
    split='train',
    max_slices=20,
    transform=train_transforms,
    augment_transform=augment_transforms,
    augment=True,
)

# Validation and test splits share the deterministic transform pipeline.
val_dataset = BalancedMRIDataset(
    *dataset_source, split='val', max_slices=20, transform=test_transforms
)
test_dataset = BalancedMRIDataset(
    *dataset_source, split='test', max_slices=20, transform=test_transforms
)

# Only the training loader shuffles; all three use a batch size of 32.
train_dl = DataLoader(dataset=train_dataset, shuffle=True, batch_size=32)
val_dl = DataLoader(dataset=val_dataset, batch_size=32)
test_dl = DataLoader(dataset=test_dataset, batch_size=32)

In [6]:
# Pull one batch from the training loader and display its shape as a sanity check.
batch_iter = iter(train_dl)
data_, label_ = next(batch_iter)
data_.shape

torch.Size([32, 20, 299, 299])

In [7]:
def initialize_weights(model):
    """Re-initialise the Conv2d, Linear and BatchNorm2d layers of ``model`` in place.

    Scheme applied to each sub-module returned by ``model.modules()``:

    * ``nn.Conv2d``      - Kaiming-normal weights (``mode='fan_out'``,
      ``nonlinearity='relu'``), bias set to 0.
    * ``nn.Linear``      - Xavier-normal weights, bias set to 0.
    * ``nn.BatchNorm2d`` - weight set to 1, bias set to 0.

    Parameters that do not exist are skipped: layers built with ``bias=False``
    and batch-norm layers built with ``affine=False`` expose ``None`` for the
    missing tensors. All other module types are left untouched.

    Args:
        model: An ``nn.Module`` whose sub-modules are visited.

    Returns:
        None. The parameters are modified in place.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
        elif isinstance(module, nn.BatchNorm2d):
            # affine=False batch norm has no learnable scale.
            if module.weight is not None:
                nn.init.constant_(module.weight, 1)
        else:
            continue

        # Every handled layer type gets a zero bias when it has one.
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)


In [8]:
# Build the network with in_channels=1 and num_classes=1 and move it to the
# selected device.
# NOTE(review): the sample batch inspected earlier has a second dimension of
# 20 — presumably the slices are reshaped before reaching the model; confirm.
model = darknet53(in_channels=1, num_classes=1).to(device=device)
# Replace the default parameters with the custom initialisation scheme.
initialize_weights(model)

In [9]:
# Binary cross-entropy computed on raw logits (the loss applies the sigmoid itself).
criterion = nn.BCEWithLogitsLoss().to(device)
# NOTE(review): lr is hard-coded here while learning_rate read from Config in
# the setup cell is not used — confirm which value is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [10]:
# Class name of the model instance; used to build the checkpoint file name.
model_name = model.__class__.__name__
model_name

'Darknet53'

In [11]:
# Create the checkpoint directory up front so that a missing folder cannot
# abort the run when the first checkpoint is written.
checkpoint_path = f"saved_models/{model_name}.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# NOTE(review): num_epochs is hard-coded to 100 here; the value read from
# Config in the setup cell is not used — confirm which one is intended.
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_dl=train_dl,
    val_dl=val_dl,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    device=device,
    num_epochs=100,
    patience=20,
    threshold=0.5,
    save_path=checkpoint_path
)

# Start training
trainer.train()

Epoch 1/100


100%|██████████| 89/89 [30:23<00:00, 20.49s/it]  


Train Loss: 669.1884, Train Accuracy: 0.9901
Val Loss: 233.0131, Val Accuracy: 0.8754
Precision: 0.0000, Recall: 0.0000, AUC: 0.5000, Avg Metric: 0.1667
Confusion Matrix:
[[548   0]
 [ 78   0]]
Epoch 2/100


 46%|████▌     | 41/89 [13:55<16:47, 21.00s/it]

In [None]:
# Restore the weights written during training. map_location re-targets the
# stored tensors to the active device, so the checkpoint also loads on a
# machine where the device it was saved from is unavailable.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(
    torch.load(f"saved_models/{model_name}.pth", map_location=device)
)

In [None]:
# Evaluate the model on the test split, passing the same criterion and the
# same threshold value (0.5) that were given to the Trainer.
tester = Tester(
    model=model,
    criterion=criterion,
    test_dl=test_dl,
    test_dataset=test_dataset,
    device=device,
    threshold=0.5
)

# Run the evaluation.
tester.test()