In [1]:
batch_size = 279
learning_rate = 0.003483773601114058
dropout_rate = 0.2757656331185102

In [2]:
import time
import sys
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import ImageDataset
import torch.nn.functional as F
from torchmetrics import Precision, Recall
from torchvision.models import resnet18
import warnings
from collections import defaultdict
import wandb
import datetime
import os
import torchvision.models as models
import torch.nn as nn
#intecubic interpol

class EfficientNetB0(nn.Module):
    def __init__(self, num_classes=4):
        super(EfficientNetB0, self).__init__()
        self.base_model = models.efficientnet_b0(pretrained=False)
        num_ftrs = self.base_model.classifier[1].in_features
        self.base_model.classifier[1] = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        return self.base_model(x)


warnings.filterwarnings('ignore')


run_name = f'efficient_net_b0_after_tune2{datetime.datetime.now()}'
run_path = f'training_checkpoints/{run_name}'

wandb.init(project="cells", 
           entity="adamsoja",
          name=run_name)

import random
random.seed(2233)
torch.manual_seed(2233)

mean = [0.5006, 0.3526, 0.5495]
std = [0.1493, 0.1341, 0.1124]


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(2233)



from albumentations import (
    Compose,
    Resize,
    OneOf,
    RandomBrightness,
    RandomContrast,
    MotionBlur,
    MedianBlur,
    GaussianBlur,
    VerticalFlip,
    HorizontalFlip,
    ShiftScaleRotate,
    Normalize,
)

transform = Compose(
    [
        Normalize(mean=mean, std=std),
        OneOf([RandomBrightness(limit=0.1, p=1), RandomContrast(limit=0.1, p=0.8)]),
        OneOf([MotionBlur(blur_limit=3), MedianBlur(blur_limit=3), GaussianBlur(blur_limit=3),], p=0.7,),
        VerticalFlip(p=0.5),
        HorizontalFlip(p=0.5),
    ]
)

transform_test = Compose(
    [Normalize(mean=mean, std=std)]
)

class MyModel(nn.Module):
    def __init__(self, model, learning_rate, weight_decay):
        super(MyModel, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.criterion = nn.CrossEntropyLoss()
        self.metric_precision = Precision(task="multiclass", num_classes=4, average=None).to('cuda')
        self.metric_recall = Recall(task="multiclass", num_classes=4, average=None).to('cuda')
        self.train_loss = []
        self.valid_loss = []
        self.precision_per_epochs = []
        self.recall_per_epochs = []

        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode="min", factor=0.1, patience=7, min_lr=5e-6)
        self.step = 0

    
    def forward(self, x):
        return self.model(x)

    def train_one_epoch(self, trainloader):
        self.step += 1
        self.train()
        for batch_idx, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to('cuda'), labels.to('cuda')
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            _, preds = torch.max(outputs, 1)
            _, labels = torch.max(labels, 1)
            self.metric_precision(preds, labels)
            self.metric_recall(preds, labels)
            self.train_loss.append(loss.item())


        

        
        avg_loss = np.mean(self.train_loss)
        self.train_loss.clear()
        precision = self.metric_precision.compute()
        recall = self.metric_recall.compute()
        self.precision_per_epochs.append(precision)
        self.recall_per_epochs.append(recall)
        print(f'train_loss: {avg_loss}')
        print(f'train_precision: {precision}')
        print(f'train_recall: {recall}')

        wandb.log({'loss': avg_loss},step=self.step)
        wandb.log({'Normal precision': precision[0].item()},step=self.step)
        wandb.log({'Inflamatory precision': precision[1].item()},step=self.step)
        wandb.log({'Tumor precision': precision[2].item()},step=self.step)
        wandb.log({'Other precision': precision[3].item()},step=self.step)


        wandb.log({'Normal recall': recall[0].item()},step=self.step)
        wandb.log({'Inflamatory recall': recall[1].item()},step=self.step)
        wandb.log({'Tumor recall': recall[2].item()},step=self.step)
        wandb.log({'Other recall': recall[3].item()},step=self.step)
        
        
        self.metric_precision.reset()
        self.metric_recall.reset()


    

    def evaluate(self, testloader):
        self.eval()
        with torch.no_grad():
            for batch_idx, (inputs, labels) in enumerate(testloader):
                inputs, labels = inputs.to('cuda'), labels.to('cuda')
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                _, labels = torch.max(labels, 1)
                self.metric_precision(preds, labels)
                self.metric_recall(preds, labels)
                self.valid_loss.append(loss.item())
    
        avg_loss = np.mean(self.valid_loss)
        self.scheduler.step(avg_loss)
        self.valid_loss.clear()
        precision = self.metric_precision.compute()
        recall = self.metric_recall.compute()
        print(f'val_loss: {avg_loss}')
        print(f'val_precision: {precision}')
        print(f'val_recall: {recall}')
        self.metric_precision.reset()
        self.metric_recall.reset()

        wandb.log({'val_loss': avg_loss}, step=self.step)
        
        wandb.log({'val_Normal precision': precision[0].item()},step=self.step)
        wandb.log({'val_Inflamatory precision': precision[1].item()},step=self.step)
        wandb.log({'val_Tumor precision': precision[2].item()},step=self.step)
        wandb.log({'val_Other precision': precision[3].item()},step=self.step)


        wandb.log({'val_Normal recall': recall[0].item()},step=self.step)
        wandb.log({'val_Inflamatory recall': recall[1].item()},step=self.step)
        wandb.log({'val_Tumor recall': recall[2].item()},step=self.step)
        wandb.log({'val_Other recall': recall[3].item()},step=self.step)


        for param_group in self.optimizer.param_groups:
            print(f"Learning rate: {param_group['lr']}")
        return avg_loss

torch.cuda.empty_cache()
trainset = ImageDataset(data_path='train_data', transform=transform, reduce=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=5)

testset = ImageDataset(data_path='validation_data', transform=transform_test, reduce=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=3)

weight_decay = 0.00001

model = EfficientNetB0()
model = model.to('cuda')


my_model = MyModel(model=model, learning_rate=learning_rate, weight_decay=weight_decay)
my_model = my_model.to('cuda')

num_epochs = 100
early_stop_patience = 15
best_val_loss = float('inf')
best_model_state_dict = None

for epoch in range(num_epochs):
    print('========================================')
    print(f'EPOCH: {epoch}') 
    time_start = time.perf_counter()
    my_model.train_one_epoch(trainloader)
    val_loss = my_model.evaluate(testloader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state_dict = my_model.state_dict()
        torch.save(best_model_state_dict, f'{run_path}.pth')
        
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= early_stop_patience:
        print(f"Early stopping at epoch {epoch} with best validation loss {best_val_loss}")
        break
    time_epoch = time.perf_counter() - time_start
    print(f'epoch {epoch} time:  {time_epoch/60}')
    print('--------------------------------')

# Load the best model state dict
print(f'{run_path}.pth')
my_model.load_state_dict(torch.load(f'{run_path}.pth'))

[34m[1mwandb[0m: Currently logged in as: [33madamsoja[0m. Use [1m`wandb login --relogin`[0m to force relogin


EPOCH: 0
train_loss: 1.0904135073142789
train_precision: tensor([0.3979, 0.5591, 0.4624, 0.0084], device='cuda:0')
train_recall: tensor([2.4603e-01, 7.5420e-01, 4.9935e-01, 2.5202e-04], device='cuda:0')
val_loss: 1.0249288244680925
val_precision: tensor([0.4883, 0.5935, 0.4355, 0.0000], device='cuda:0')
val_recall: tensor([0.0083, 0.7432, 0.7536, 0.0000], device='cuda:0')
Learning rate: 0.003483773601114058
epoch 0 time:  0.2878741121499994
--------------------------------
EPOCH: 1
train_loss: 1.023477902298031
train_precision: tensor([0.4299, 0.5824, 0.5276, 0.0000], device='cuda:0')
train_recall: tensor([0.3286, 0.7868, 0.4978, 0.0000], device='cuda:0')
val_loss: 1.0224532591296243
val_precision: tensor([0.4859, 0.6222, 0.6302, 0.0000], device='cuda:0')
val_recall: tensor([0.5060, 0.7731, 0.4980, 0.0000], device='cuda:0')
Learning rate: 0.003483773601114058
epoch 1 time:  0.28538948363333627
--------------------------------
EPOCH: 2
train_loss: 0.9460268854380012
train_precision: ten

<All keys matched successfully>

In [4]:
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import DataLoader
import numpy as np
import torch 
import torch.nn as nn
my_model.load_state_dict(torch.load(f'{run_path}.pth'))

def test_report(model, dataloader):
    """Prints confusion matrix for testing dataset
    dataloader should be of batch_size=1."""

    y_pred = []
    y_test = []
    model.eval()
    with torch.no_grad():
        for data, label in dataloader:
            output = model(data)
            label = label.numpy()
            output = output.numpy()
            y_pred.append(np.argmax(output))
            y_test.append(np.argmax(label))
        print(confusion_matrix(y_test, y_pred))
        print(classification_report(y_test, y_pred))

testset =ImageDataset(data_path='test_data', transform=transform_test, reduce=True)
dataloader = DataLoader(testset, batch_size=1, shuffle=True)

test_report(my_model.to('cpu'), dataloader)

[[602  88 151   4]
 [ 79 767  43   0]
 [114  41 752   1]
 [  7   5  10  36]]
              precision    recall  f1-score   support

           0       0.75      0.71      0.73       845
           1       0.85      0.86      0.86       889
           2       0.79      0.83      0.81       908
           3       0.88      0.62      0.73        58

    accuracy                           0.80      2700
   macro avg       0.82      0.76      0.78      2700
weighted avg       0.80      0.80      0.80      2700

