In [36]:
import torch
from torchvision.models import resnet18
import torch.nn as nn
ckpt_path = 'training_checkpoints/resnet18_normalize_batch_2048_test_from_pretrained2024-04-28 22:44:55.674715.pth'
model = resnet18()
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
num_classes = 4
model.fc = nn.Sequential(
    nn.Dropout(0.4),
    nn.Linear(model.fc.in_features, num_classes))


In [37]:
state_dict = torch.load(ckpt_path)

new_state_dict = {}
for key, value in state_dict.items():
    new_key = key.replace('model.', '')  
    new_state_dict[new_key] = value

In [38]:
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [39]:
for name, param in model.named_parameters():
    if not name.startswith('fc'):  
        param.requires_grad = False

In [42]:
i=0
for name, param in model.named_parameters():
    if i%2 == 0:
        param.requires_grad = True
    i+=1

In [43]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)

conv1.weight True
bn1.weight False
bn1.bias True
layer1.0.conv1.weight False
layer1.0.bn1.weight True
layer1.0.bn1.bias False
layer1.0.conv2.weight True
layer1.0.bn2.weight False
layer1.0.bn2.bias True
layer1.1.conv1.weight False
layer1.1.bn1.weight True
layer1.1.bn1.bias False
layer1.1.conv2.weight True
layer1.1.bn2.weight False
layer1.1.bn2.bias True
layer2.0.conv1.weight False
layer2.0.bn1.weight True
layer2.0.bn1.bias False
layer2.0.conv2.weight True
layer2.0.bn2.weight False
layer2.0.bn2.bias True
layer2.0.downsample.0.weight False
layer2.0.downsample.1.weight True
layer2.0.downsample.1.bias False
layer2.1.conv1.weight True
layer2.1.bn1.weight False
layer2.1.bn1.bias True
layer2.1.conv2.weight False
layer2.1.bn2.weight True
layer2.1.bn2.bias False
layer3.0.conv1.weight True
layer3.0.bn1.weight False
layer3.0.bn1.bias True
layer3.0.conv2.weight False
layer3.0.bn2.weight True
layer3.0.bn2.bias False
layer3.0.downsample.0.weight True
layer3.0.downsample.1.weight False
layer3.0.downsa

In [44]:
import time
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import ImageDataset
import torch.nn.functional as F
from torchmetrics import Precision, Recall
from torchvision.models import resnet18
import warnings
from collections import defaultdict
import wandb
import datetime
import os
warnings.filterwarnings('ignore')
torch.set_float32_matmul_precision('high')
#intecubic interpol

run_name = f'resnet18_normalize_batch_2048_test_from_pretrained{datetime.datetime.now()}'
run_path = f'training_checkpoints/{run_name}'

wandb.init(project="cells", 
           entity="adamsoja",
          name=run_name)

import random
random.seed(2233)
torch.manual_seed(2233)

#After /255 so in loading dataset there are no division by 255 just this normalization
mean = [0.5006, 0.3526, 0.5495]
std = [0.1493, 0.1341, 0.1124]


from albumentations import (
    Compose,
    Resize,
    OneOf,
    RandomBrightness,
    RandomContrast,
    MotionBlur,
    MedianBlur,
    GaussianBlur,
    VerticalFlip,
    HorizontalFlip,
    ShiftScaleRotate,
    Normalize,
)

transform = Compose(
    [
        Normalize(mean=mean, std=std),
        OneOf([RandomBrightness(limit=0.4, p=1), RandomContrast(limit=0.4, p=0.8)]),
        OneOf([MotionBlur(blur_limit=3), MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3),], p=0.7,),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
    ]
)

transform_test = Compose(
    [Normalize(mean=mean, std=std)]
)

class MyModel(nn.Module):
    def __init__(self, model, learning_rate, weight_decay):
        super(MyModel, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.criterion = nn.CrossEntropyLoss()
        self.metric_precision = Precision(task="multiclass", num_classes=4, average=None).to('cuda')
        self.metric_recall = Recall(task="multiclass", num_classes=4, average=None).to('cuda')
        self.train_loss = []
        self.valid_loss = []
        self.precision_per_epochs = []
        self.recall_per_epochs = []

        self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=weight_decay)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode="min", factor=0.0001, patience=2, min_lr=5e-6, verbose=True)
        self.step = 0

    
    def forward(self, x):
        return self.model(x)

    def train_one_epoch(self, trainloader):
        self.step += 1
        self.train()
        for batch_idx, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            _, preds = torch.max(outputs, 1)
            _, labels = torch.max(labels, 1)
            self.metric_precision(preds, labels)
            self.metric_recall(preds, labels)
            self.train_loss.append(loss.item())


        

        
        avg_loss = np.mean(self.train_loss)
        self.train_loss.clear()
        precision = self.metric_precision.compute()
        recall = self.metric_recall.compute()
        self.precision_per_epochs.append(precision)
        self.recall_per_epochs.append(recall)
        print(f'train_loss: {avg_loss}')
        print(f'train_precision: {precision}')
        print(f'train_recall: {recall}')

        wandb.log({'loss': avg_loss},step=self.step)
        wandb.log({'Normal precision': precision[0].item()},step=self.step)
        wandb.log({'Inflamatory precision': precision[1].item()},step=self.step)
        wandb.log({'Tumor precision': precision[2].item()},step=self.step)
        wandb.log({'Other precision': precision[3].item()},step=self.step)


        wandb.log({'Normal recall': recall[0].item()},step=self.step)
        wandb.log({'Inflamatory recall': recall[1].item()},step=self.step)
        wandb.log({'Tumor recall': recall[2].item()},step=self.step)
        wandb.log({'Other recall': recall[3].item()},step=self.step)
        
        
        self.metric_precision.reset()
        self.metric_recall.reset()


    

    def evaluate(self, testloader):
        self.eval()
        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(testloader):
                inputs, labels = inputs.to('cuda'), labels.to('cuda')
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                _, labels = torch.max(labels, 1)
                self.metric_precision(preds, labels)
                self.metric_recall(preds, labels)
                self.valid_loss.append(loss.item())
    
        avg_loss = np.mean(self.valid_loss)
        self.scheduler.step(avg_loss)
        self.valid_loss.clear()
        precision = self.metric_precision.compute()
        recall = self.metric_recall.compute()
        print(f'val_loss: {avg_loss}')
        print(f'val_precision: {precision}')
        print(f'val_recall: {recall}')
        self.metric_precision.reset()
        self.metric_recall.reset()

        wandb.log({'val_loss': avg_loss}, step=self.step)
        
        wandb.log({'val_Normal precision': precision[0].item()},step=self.step)
        wandb.log({'val_Inflamatory precision': precision[1].item()},step=self.step)
        wandb.log({'val_Tumor precision': precision[2].item()},step=self.step)
        wandb.log({'val_Other precision': precision[3].item()},step=self.step)


        wandb.log({'val_Normal recall': recall[0].item()},step=self.step)
        wandb.log({'val_Inflamatory recall': recall[1].item()},step=self.step)
        wandb.log({'val_Tumor recall': recall[2].item()},step=self.step)
        wandb.log({'val_Other recall': recall[3].item()},step=self.step)


        for param_group in self.optimizer.param_groups:
            print(f"Learning rate: {param_group['lr']}")
        return avg_loss

torch.cuda.empty_cache()
batch_size = 128

trainset = ImageDataset(data_path='train_data', transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=3)

testset = ImageDataset(data_path='validation_data', transform=transform_test)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)


learning_rate = 0.001
weight_decay = 0.00005


model = model.to('cuda')
my_model = MyModel(model=model, learning_rate=learning_rate, weight_decay=weight_decay)
my_model = my_model.to('cuda')

num_epochs = 100
early_stop_patience = 10
best_val_loss = float('inf')
best_model_state_dict = None

for epoch in range(num_epochs):
    print('========================================')
    print(f'EPOCH: {epoch}') 
    time_start = time.perf_counter()
    my_model.train_one_epoch(trainloader)
    val_loss = my_model.evaluate(testloader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state_dict = my_model.state_dict()
        torch.save(best_model_state_dict, f'{run_path}.pth')
        
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= early_stop_patience:
        print(f"Early stopping at epoch {epoch} with best validation loss {best_val_loss}")
        break
    time_epoch = time.perf_counter() - time_start
    print(f'epoch {epoch} time:  {time_epoch/60}')
    print('--------------------------------')

# Load the best model state dict
print(f'{run_path}.pth')
my_model.load_state_dict(torch.load(f'{run_path}.pth'))

0,1
Inflamatory precision,▆▅▅▅▃█▁█▃▅▅▇▂█▄▅▇
Inflamatory recall,▁▅▆▅▅▄▄▃▆▆█▅▆▆▅▅▅
Normal precision,▁▄▄▆▆▆▃▄▆▆█▇▇▆▆▇█
Normal recall,▂▄▄█▄▆█▇▅▁▇▁▅▅▅▅▄
Other precision,▁▃▃▅▄▄▅▅▇▄▆▇█▅▆▆▇
Other recall,▁█▆▇▇█▆▇▇▆▇▇█▇▆▇█
Tumor precision,▁▆▆██▇█▇▇▅█▅▇▇▅▇▆
Tumor recall,▅▁▃▃▆▇▁▆▃▇▅█▄▆▄▆▇
loss,█▂▂▂▂▁▂▁▂▂▁▁▁▁▂▁▁
val_Inflamatory precision,▅▁▂▅█▇▇▄▅▃▄▄▆▁▄▂▄

0,1
Inflamatory precision,0.79873
Inflamatory recall,0.84166
Normal precision,0.7348
Normal recall,0.71625
Other precision,0.74092
Other recall,0.60952
Tumor precision,0.77699
Tumor recall,0.76863
loss,0.56141
val_Inflamatory precision,0.82955


EPOCH: 0
train_loss: 0.6707911467253984
train_precision: tensor([0.6811, 0.7584, 0.7304, 0.6541], device='cuda:0')
train_recall: tensor([0.6626, 0.8144, 0.7222, 0.4227], device='cuda:0')
val_loss: 0.643190831070145
val_precision: tensor([0.7210, 0.7607, 0.7529, 0.6127], device='cuda:0')
val_recall: tensor([0.6331, 0.8647, 0.7441, 0.5939], device='cuda:0')
Learning rate: 0.001
epoch 0 time:  0.36732950931667196
--------------------------------
EPOCH: 1
train_loss: 0.6592486793615763
train_precision: tensor([0.6838, 0.7618, 0.7372, 0.6609], device='cuda:0')
train_recall: tensor([0.6698, 0.8173, 0.7225, 0.4466], device='cuda:0')
val_loss: 0.6020309990478887
val_precision: tensor([0.6845, 0.8152, 0.7794, 0.6850], device='cuda:0')
val_recall: tensor([0.7337, 0.8123, 0.7400, 0.5630], device='cuda:0')
Learning rate: 0.001
epoch 1 time:  0.383525579683328
--------------------------------
EPOCH: 2
train_loss: 0.6496819525051856
train_precision: tensor([0.6920, 0.7651, 0.7431, 0.6673], device='c

KeyboardInterrupt: 