In [1]:
#Params = [batch_size: 103, learning_rate: 0.00023844542114720738, dropout_rate: 0.11187612585438168, num_heads: 9, patch_size: 4, num_layers: 4]

batch_size = 300
lr = 0.00023
dropout_rate = 0.11
num_heads = 9
patch_size = 4
num_layers = 4

In [2]:
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchmetrics import Precision, Recall
from dataset import ImageDataset
import numpy as np
import datetime
import random
import time
import time
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import ImageDataset
import torch.nn.functional as F
from torchmetrics import Precision, Recall
from torchvision.models import resnet50
import warnings
from collections import defaultdict
import wandb
import datetime
import os
from vit.vit import VisionTransformer

warnings.filterwarnings('ignore')
#intecubic interpol

run_name = f'vit_after_tuning_reduced{datetime.datetime.now()}'
run_path = f'training_checkpoints/{run_name}'

wandb.init(project="cells", 
           entity="adamsoja",
          name=run_name)

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(2233)

mean = [0.5006, 0.3526, 0.5495]
std = [0.1493, 0.1341, 0.1124]

from albumentations import (
    Compose,
    Resize,
    OneOf,
    RandomBrightness,
    RandomContrast,
    MotionBlur,
    MedianBlur,
    GaussianBlur,
    VerticalFlip,
    HorizontalFlip,
    ShiftScaleRotate,
    Normalize,
)

transform = Compose(
    [
        Normalize(mean=mean, std=std),
        OneOf([RandomBrightness(limit=0.1, p=1), RandomContrast(limit=0.1, p=0.8)]),
        OneOf([MotionBlur(blur_limit=3), MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3),], p=0.7,),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
    ]
)

transform_test = Compose(
    [Normalize(mean=mean, std=std)]
)

def get_valid_embedding_dims(num_heads):
    return [dim for dim in range(32, 257) if dim % num_heads == 0]
valid_embedding_dims = get_valid_embedding_dims(num_heads)
embedding_dim = valid_embedding_dims[len(valid_embedding_dims)//2]

  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Currently logged in as: [33madamsoja[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
model = VisionTransformer(image_size=32, 
                      in_channels=3, 
                      num_classes=4, 
                      hidden_dims=[32], 
                      dropout_rate=dropout_rate,
                      embedding_dim=embedding_dim,
                      patch_size=patch_size,
                      num_layers=3,
                      num_heads=num_heads,
                      use_linear_patch=False,
                      use_conv_stem=True,
                      use_conv_patch=False)

In [4]:
class MyModel(nn.Module):
    def __init__(self, model, learning_rate):
        super(MyModel, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()
        self.metric_precision = Precision(task="multiclass", num_classes=4, average=None).to('cuda')
        self.metric_recall = Recall(task="multiclass", num_classes=4, average=None).to('cuda')
        self.train_loss = []
        self.valid_loss = []
        self.precision_per_epochs = []
        self.recall_per_epochs = []

        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode="min", factor=0.1, patience=7, min_lr=5e-6, verbose=True)
        self.step = 0

    
    def forward(self, x):
        return self.model(x)

    def train_one_epoch(self, trainloader):
        self.step += 1
        self.train()
        for batch_idx, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            _, preds = torch.max(outputs, 1)
            _, labels = torch.max(labels, 1)
            self.metric_precision(preds, labels)
            self.metric_recall(preds, labels)
            self.train_loss.append(loss.item())


        

        
        avg_loss = np.mean(self.train_loss)
        self.train_loss.clear()
        precision = self.metric_precision.compute()
        recall = self.metric_recall.compute()
        self.precision_per_epochs.append(precision)
        self.recall_per_epochs.append(recall)
        print(f'train_loss: {avg_loss}')
        print(f'train_precision: {precision}')
        print(f'train_recall: {recall}')

        wandb.log({'loss': avg_loss}, step=self.step)
        
        # Logowanie precision dla każdej klasy
        wandb.log({'Normal precision': precision[0].item()}, step=self.step)
        wandb.log({'Inflamatory precision': precision[1].item()}, step=self.step)
        wandb.log({'Tumor precision': precision[2].item()}, step=self.step)
        wandb.log({'Other precision': precision[3].item()}, step=self.step)
        
        # Logowanie recall dla każdej klasy
        wandb.log({'Normal recall': recall[0].item()}, step=self.step)
        wandb.log({'Inflamatory recall': recall[1].item()}, step=self.step)
        wandb.log({'Tumor recall': recall[2].item()}, step=self.step)
        wandb.log({'Other recall': recall[3].item()}, step=self.step)
        
        # Obliczanie głównych metryk
        main_metrics_precision = (precision[0].item() + precision[1].item() + precision[2].item() + precision[3].item()) / 4
        main_metrics_recall = (recall[0].item() + recall[1].item() + recall[2].item() + recall[3].item()) / 4
        
        # Logowanie głównych metryk
        wandb.log({'main_metrics_precision': main_metrics_precision}, step=self.step)
        wandb.log({'main_metrics_recall': main_metrics_recall}, step=self.step)

        precision_ = main_metrics_precision
        recall_ = main_metrics_recall
        
        if (precision_ + recall_) > 0:
            f1_score_val = 2 * (precision_ * recall_) / (precision_ + recall_)
        else:
            f1_score_val = 0
        
        wandb.log({'f1_score_val': f1_score_val}, step=self.step)

        
        
        self.metric_precision.reset()
        self.metric_recall.reset()


    

    def evaluate(self, testloader):
        self.eval()
        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(testloader):
                inputs, labels = inputs.to('cuda'), labels.to('cuda')
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                _, labels = torch.max(labels, 1)
                self.metric_precision(preds, labels)
                self.metric_recall(preds, labels)
                self.valid_loss.append(loss.item())
    
        avg_loss = np.mean(self.valid_loss)
        self.scheduler.step(avg_loss)
        self.valid_loss.clear()
        precision = self.metric_precision.compute()
        recall = self.metric_recall.compute()
        print(f'val_loss: {avg_loss}')
        print(f'val_precision: {precision}')
        print(f'val_recall: {recall}')
        self.metric_precision.reset()
        self.metric_recall.reset()
    
        main_metrics_precision = (precision[0].item() + precision[1].item() + precision[2].item() + precision[3].item()) / 4
        
        main_metrics_recall = (recall[0].item() + recall[1].item() + recall[2].item() + recall[3].item()) / 4
        
        wandb.log({'val_loss': avg_loss}, step=self.step)
        
        wandb.log({'val_Normal precision': precision[0].item()}, step=self.step)
        wandb.log({'val_Inflamatory precision': precision[1].item()}, step=self.step)
        wandb.log({'val_Tumor precision': precision[2].item()}, step=self.step)
        wandb.log({'val_Other precision': precision[3].item()}, step=self.step)
        
        wandb.log({'val_Normal recall': recall[0].item()}, step=self.step)
        wandb.log({'val_Inflamatory recall': recall[1].item()}, step=self.step)
        wandb.log({'val_Tumor recall': recall[2].item()}, step=self.step)
        wandb.log({'val_Other recall': recall[3].item()}, step=self.step)
        
        wandb.log({'val_main_metrics_precision': main_metrics_precision}, step=self.step)
        wandb.log({'val_main_metrics_recall': main_metrics_recall}, step=self.step)

        precision_ = main_metrics_precision
        recall_ = main_metrics_recall
        
        if (precision_ + recall_) > 0:
            f1_score_val = 2 * (precision_ * recall_) / (precision_ + recall_)
        else:
            f1_score_val = 0
        
        wandb.log({'f1_score_val': f1_score_val}, step=self.step)
        
        

        for param_group in self.optimizer.param_groups:
            print(f"Learning rate: {param_group['lr']}")
        return avg_loss

In [5]:
my_model = MyModel(model=model, learning_rate=lr)
my_model = my_model.to('cuda')



trainset = ImageDataset(data_path='train_data', transform=transform, reduce=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=3)

testset = ImageDataset(data_path='validation_data', transform=transform_test, reduce=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

num_epochs = 100
early_stop_patience = 15
best_val_loss = float('inf')
best_model_state_dict = None

for epoch in range(num_epochs):
    print('========================================')
    print(f'EPOCH: {epoch}') 
    time_start = time.perf_counter()
    my_model.train_one_epoch(trainloader)
    val_loss = my_model.evaluate(testloader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state_dict = my_model.state_dict()
        torch.save(best_model_state_dict, f'{run_path}.pth')
        
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= early_stop_patience:
        print(f"Early stopping at epoch {epoch} with best validation loss {best_val_loss}")
        break
    time_epoch = time.perf_counter() - time_start
    print(f'epoch {epoch} time:  {time_epoch}')
    print('--------------------------------')

# Load the best model state dict
print(f'{run_path}.pth')
my_model.load_state_dict(torch.load(f'{run_path}.pth'))

EPOCH: 0
train_loss: 0.9282462659382051
train_precision: tensor([0.5405, 0.6304, 0.5762, 0.6645], device='cuda:0')
train_recall: tensor([0.3815, 0.7644, 0.6510, 0.0260], device='cuda:0')
val_loss: 0.7855055979768136
val_precision: tensor([0.6104, 0.6772, 0.7361, 0.5891], device='cuda:0')
val_recall: tensor([0.5666, 0.8495, 0.6308, 0.2468], device='cuda:0')
Learning rate: 0.00023
epoch 0 time:  125.27327691300161
--------------------------------
EPOCH: 1
train_loss: 0.8090495446997304
train_precision: tensor([0.5954, 0.6984, 0.6714, 0.6730], device='cuda:0')
train_recall: tensor([0.5498, 0.7816, 0.6754, 0.1436], device='cuda:0')
val_loss: 0.7428478674781054
val_precision: tensor([0.6640, 0.6844, 0.7264, 0.5464], device='cuda:0')
val_recall: tensor([0.5270, 0.8721, 0.6902, 0.3355], device='cuda:0')
Learning rate: 0.00023
epoch 1 time:  124.94000907000009
--------------------------------
EPOCH: 2
train_loss: 0.7729893701691781
train_precision: tensor([0.6132, 0.7137, 0.6864, 0.6764], devi

KeyboardInterrupt: 

In [6]:
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import DataLoader
import numpy as np
import torch 
import torch.nn as nn
my_model.load_state_dict(torch.load(f'{run_path}.pth'))

def test_report(model, dataloader):
    """Prints confusion matrix for testing dataset
    dataloader should be of batch_size=1."""

    y_pred = []
    y_test = []
    model.eval()
    with torch.no_grad():
        for data, label in dataloader:
            output = model(data)
            label = label.numpy()
            output = output.numpy()
            y_pred.append(np.argmax(output))
            y_test.append(np.argmax(label))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))

testset =ImageDataset(data_path='test_data', transform=transform_test, reduce=True)
dataloader = DataLoader(testset, batch_size=1, shuffle=True)

test_report(my_model.to('cpu'), dataloader)

[[589  95 157   4]
 [ 86 764  34   5]
 [102  45 757   4]
 [  3   7  12  36]]
              precision    recall  f1-score   support

           0       0.76      0.70      0.72       845
           1       0.84      0.86      0.85       889
           2       0.79      0.83      0.81       908
           3       0.73      0.62      0.67        58

    accuracy                           0.79      2700
   macro avg       0.78      0.75      0.76      2700
weighted avg       0.79      0.79      0.79      2700



In [None]:
len(testset)