In [1]:
import os
import time
import numpy as np
import pandas as pd
import json


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import matplotlib.pyplot as plt
from PIL import Image

import timm
import wandb

import torchmetrics
import wandb

from utils import load_image, LossMeter, save_metrics_to_json

  from .autonotebook import tqdm as notebook_tqdm


# hyperparameters

In [2]:
class config:
    # Hyperparameters
    LEARNING_RATE = 0.0001
    BATCH_SIZE = 128
    NUM_EPOCHS = 30
    
    # Other
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    NUM_WORKERS = 0

    DATA_PATH = './data/archive'
    IMG_SIZE = 112
    WANDB = False

# Dataset

In [3]:
# Generate data paths with labels
data_dir = config.DATA_PATH
filepaths = []
labels = []

folds = [f for f in os.listdir(data_dir) if not f.startswith('.')]
for fold in folds:
    foldpath = os.path.join(data_dir, fold)
    filelist = [f for f in os.listdir(foldpath) if not f.startswith('.')]
    for file in filelist:
        fpath = os.path.join(foldpath, file)
        
        filepaths.append(fpath)
        labels.append(fold)

# Concatenate data paths with labels into one dataframe
Fseries = pd.Series(filepaths, name='paths')
Lseries = pd.Series(labels, name='labels')
df = pd.concat([Fseries, Lseries], axis=1)

df

Unnamed: 0,paths,labels
0,./data/archive\no\1 no.jpeg,no
1,./data/archive\no\10 no.jpg,no
2,./data/archive\no\11 no.jpg,no
3,./data/archive\no\12 no.jpg,no
4,./data/archive\no\13 no.jpg,no
...,...,...
248,./data/archive\yes\Y95.jpg,yes
249,./data/archive\yes\Y96.jpg,yes
250,./data/archive\yes\Y97.JPG,yes
251,./data/archive\yes\Y98.JPG,yes


# Label encoding

In [4]:
 label_encoder = LabelEncoder()
 df['labels']= label_encoder.fit_transform(df['labels'])
 df

Unnamed: 0,paths,labels
0,./data/archive\no\1 no.jpeg,0
1,./data/archive\no\10 no.jpg,0
2,./data/archive\no\11 no.jpg,0
3,./data/archive\no\12 no.jpg,0
4,./data/archive\no\13 no.jpg,0
...,...,...
248,./data/archive\yes\Y95.jpg,1
249,./data/archive\yes\Y96.jpg,1
250,./data/archive\yes\Y97.JPG,1
251,./data/archive\yes\Y98.JPG,1


In [5]:
train_df, test_df = train_test_split(df, train_size=0.7, shuffle=True, random_state=123, stratify=df['labels'])

# Dataset

In [7]:
class BrainMRIdataset(Dataset):
    def __init__(self, df, transform=None):
        self.paths = df['paths']
        self.labels = df['labels']
        self.transform = transform
    
    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths.iloc[idx]
        image = load_image(path, size=(config.IMG_SIZE, config.IMG_SIZE))
        if self.transform:
           image = self.transform(image)

        label = torch.tensor(self.labels.iloc[idx])
        
        return (image, label)

In [8]:
# Checking the dataset
train_loader = BrainMRIdataset(df=train_df, transform=transforms.ToTensor())

for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


Image batch dimensions: torch.Size([1, 112, 112])
Image label dimensions: torch.Size([])


# Model

In [9]:
#model = timm.create_model('convnextv2_tiny.fcmae_ft_in22k_in1k', pretrained=True, num_classes=2, in_chans=1)

In [10]:
def update_metrics(metrics, dataset_type, metric_name, value):
    if dataset_type not in metrics:
        metrics[dataset_type] = {}
    
    if metric_name not in metrics[dataset_type]:
        metrics[dataset_type][metric_name] = []

    metrics[dataset_type][metric_name].append(value)

class TensorEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return obj.tolist()
        return super().default(obj)

In [11]:
class Trainer():
    def __init__(
        self, 
        model, 
        device, 
        optimizer, 
        criterion,
        epochs,
        loss_meter, 
    ):
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.epochs = epochs
        self.loss_meter = loss_meter
        self.hist = {'val_loss':[],
                    'val_acc':[],
                    'val_f1':[],
                    'val_auroc':[],
                    'train_loss':[],
                    'train_acc':[],
                    'train_f1': [],
                    'train_auroc': [],
                    }
        
        self.best_test_auroc = -np.inf
        

    def fit(self, train_loader, test_loader, save_path, patience = 0):
        train_time = time.time()
        
        for epoch in range(self.epochs):
            t = time.time()
            self.model.train()
            train_loss = self.loss_meter()
            train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(self.device)
            train_f1 = torchmetrics.F1Score(task="multiclass", num_classes=2, average='macro').to(self.device)
            train_auroc = torchmetrics.AUROC(task="multiclass", num_classes=2).to(self.device)
            
            for idx, (images, labels) in enumerate(train_loader):
                images = images.to(self.device)
                labels = labels.to(self.device)
                
                logits = self.model(images)
                #print(logits)
                labels=labels.to(torch.int64)
                
                loss = self.criterion(logits, labels)
                
                train_loss.update(loss.detach().item())
                train_acc.update(logits.detach(), labels)
                train_f1.update(logits.detach(), labels)
                train_auroc.update(logits.detach(), labels)
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()  
                
            _loss = train_loss.avg
            _acc = train_acc.compute()
            _f1 = train_f1.compute()
            _roc = train_auroc.compute()

            if config.WANDB:
                wandb.log({'train loss': _loss,
                        'train acc': _acc,
                        'train f1_score': _f1,
                        'train AUROC': _roc
                        })

            train_acc.reset()
            train_f1.reset()
            train_auroc.reset()
            
            self.hist['train_loss'].append(_loss)
            self.hist['train_acc'].append(_acc)
            self.hist['train_f1'].append(_f1)
            self.hist['train_auroc'].append(_roc)
            
            print(f' Train Epoch: {epoch+1}/{self.epochs} | Loss: {_loss:.5f} | Accuracy: {_acc:.4f}% | F1 Score: {_f1:.4f} | AUROC: {_roc:.4f} | Time: {time.time() - t}')
            
            val_loss, val_acc, val_f1, val_auroc = self.validate(test_loader, save_path)
            
            self.hist['val_loss'].append(val_loss)
            self.hist['val_acc'].append(val_acc)
            self.hist['val_f1'].append(val_f1)
            self.hist['val_auroc'].append(val_auroc)
            
        
        avg_loss = torch.mean(torch.tensor(self.hist['train_loss']))
        avg_acc = torch.mean(torch.tensor(self.hist['train_acc']))
        avg_f1 = torch.mean(torch.tensor(self.hist['train_f1']))
        avg_auroc = torch.mean(torch.tensor(self.hist['train_auroc']))

        print(f'Training Time: {(time.time() - train_time) // 60:.0f}m {(time.time() - train_time) % 60:.0f}s | Avg Loss: {avg_loss:.5f} | Avg Accuracy: {avg_acc:.3f}% | Avg F1 Score: {avg_f1:.4f} | Avg AUROC:{avg_auroc:.4f}')
        
            
    
    def validate(self, test_loader, save_path):
        test_time = time.time()
        self.model.eval()

        val_loss = self.loss_meter()    
        val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(self.device)
        val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=2, average='macro').to(self.device)       
        val_auroc = torchmetrics.AUROC(task="multiclass", num_classes=2).to(self.device)

        test_pred = []
        test_targets = []
        preds = []        
        for idx, (images, labels) in enumerate(test_loader):
            with torch.no_grad():
                images = images.to(self.device)
                labels = labels.to(torch.int64).to(self.device)
                
                logits = self.model(images)
                
                loss = self.criterion(logits, labels)
                
                val_loss.update(loss.detach().item())
                
                test_targets.append(labels)
                preds.append(logits.detach())
                
                
        test_targets = torch.cat(test_targets).flatten()
        preds = torch.cat(preds)

        
        loss = val_loss.avg
        acc = val_acc(preds, test_targets)
        f1 = val_f1(preds, test_targets)
        auroc = val_auroc(preds, test_targets)             
                
        if auroc > self.best_test_auroc: 
            self.best_test_auroc = auroc
            
            base_dir = "./data/pretrain_convnext/"
            if not os.path.exists(base_dir):
                os.mkdir(base_dir)

            torch.save({"model_state_dict": self.model.state_dict(),
                        "best_auroc": self.best_test_auroc,
                        },
                        save_path)
                    
            print(f'Checkpoint saved at {save_path} '
                f'| Test acc: {acc :.2f}% '
                f'| Test F1: {f1 :.3f}% '
                f'| Best AUROC: {self.best_test_auroc:.3f}')           
            
        val_acc.reset()
        val_f1.reset()
        val_auroc.reset()
            
        if config.WANDB:
            wandb.log({'val loss':loss,
                    'val acc': acc,
                    'val f1_score': f1,
                    'val AUROC': auroc
                    })

        print(f"Validation Epoch: {(time.time() - test_time) // 60:.0f}m {(time.time() - test_time) % 60:.0f}s | Accuracy: {acc:.2f}% | F1 Score: {f1:.4f} | AUROC: {auroc:.4f}")
        return loss, acc.item(), f1.item(), auroc.item()

In [12]:
def train():
    
    start_time = time.time()

    '''if config.TRANSFORM:
        train_transform = .Compose([
                                    #A.HorizontalFlip(p=0.5),
                                    A.Rotate(limit=10, p=0.5)
                                ])
    else:
        train_transform = None'''

    metrics = {
    'train': {'loss': [], 'acc': [], 'f1': [], 'auroc': []},
    'valid': {'loss': [], 'acc': [], 'f1': [], 'auroc': []},
    }

    train_set = BrainMRIdataset(
                    train_df,
                    transform=transforms.ToTensor()
                        )
                    
    test_set = BrainMRIdataset(
                    test_df,
                    transform=transforms.ToTensor()
                        )

    train_loader = DataLoader(
                train_set,    
                batch_size=config.BATCH_SIZE,
                shuffle=True,
                num_workers=config.NUM_WORKERS,
            )
    
    test_loader = DataLoader(
                test_set,    
                batch_size=config.BATCH_SIZE,
                shuffle=False,
                num_workers=config.NUM_WORKERS,
            )
    
    model = timm.create_model('convnextv2_tiny.fcmae_ft_in22k_in1k', pretrained=True, num_classes=2, in_chans=1)
    
    model = model.to(config.DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    criterion = F.cross_entropy
    trainer = Trainer(
                model, 
                config.DEVICE, 
                optimizer, 
                criterion,
                config.NUM_EPOCHS,
                LossMeter,
                )
    

    trainer.fit(train_loader,
                test_loader,
                save_path = f'./data/pretrain_convnext/ConvNext_finetuned_model_best_auroc.pth',
                )

    for value in trainer.hist['train_loss']:
        update_metrics(metrics, 'train', 'loss', value)

    for value in trainer.hist['train_acc']:
        update_metrics(metrics, 'train', 'acc', value)

    for value in trainer.hist['train_f1']:
        update_metrics(metrics, 'train', 'f1', value)

    for value in trainer.hist['train_auroc']:
        update_metrics(metrics, 'train', 'auroc', value)

    for value in trainer.hist['val_loss']:
        update_metrics(metrics, 'valid', 'loss', value)

    for value in trainer.hist['val_acc']:
        update_metrics(metrics, 'valid', 'acc', value)

    for value in trainer.hist['val_f1']:
        update_metrics(metrics, 'valid', 'f1', value)

    for value in trainer.hist['val_auroc']:
        update_metrics(metrics, 'valid', 'auroc', value)
    
    json_path = save_metrics_to_json(metrics, 'ft_convnext')

    elapsed_time = time.time() - start_time


    print('\nTraining complete in {:.0f}m {:.0f}s'.format(elapsed_time // 60, elapsed_time % 60))
    
    if config.WANDB:
        wandb.finish()

In [13]:
train()

 Train Epoch: 1/30 | Loss: 0.75728 | Accuracy: 0.4915% | F1 Score: 0.4794 | AUROC: 0.5059 | Time: 2.8602168560028076
Checkpoint saved at ./data/pretrain_convnext/ConvNext_finetuned_model_best_auroc.pth | Test acc: 0.76% | Test F1: 0.695% | Best AUROC: 0.957
Validation Epoch: 0m 0s | Accuracy: 0.76% | F1 Score: 0.6946 | AUROC: 0.9574
 Train Epoch: 2/30 | Loss: 0.38570 | Accuracy: 0.8531% | F1 Score: 0.8398 | AUROC: 0.9452 | Time: 0.8585059642791748
Validation Epoch: 0m 0s | Accuracy: 0.86% | F1 Score: 0.8532 | AUROC: 0.9479
 Train Epoch: 3/30 | Loss: 0.22405 | Accuracy: 0.9492% | F1 Score: 0.9472 | AUROC: 0.9914 | Time: 0.8619987964630127
Checkpoint saved at ./data/pretrain_convnext/ConvNext_finetuned_model_best_auroc.pth | Test acc: 0.84% | Test F1: 0.825% | Best AUROC: 0.971
Validation Epoch: 0m 0s | Accuracy: 0.84% | F1 Score: 0.8246 | AUROC: 0.9714
 Train Epoch: 4/30 | Loss: 0.11587 | Accuracy: 0.9661% | F1 Score: 0.9638 | AUROC: 1.0000 | Time: 0.8490002155303955
Validation Epoch: 0