In [None]:
import json
import os

import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import wandb
from PIL import Image
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import StratifiedKFold
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
# Ask PyTorch to use deterministic algorithms globally so repeated runs are
# comparable.
# NOTE(review): per the PyTorch docs, operations without a deterministic
# implementation raise a RuntimeError under this flag, and cuBLAS on
# CUDA >= 10.2 additionally needs the CUBLAS_WORKSPACE_CONFIG environment
# variable to be set before launch — confirm the kernel environment exports it.
torch.use_deterministic_algorithms(True)

In [None]:
# Smoke-test switch: when True, the training section keeps only the first
# 300 rows of the label table so a full pass finishes quickly.
DEBUG = False

In [None]:
# Shared experiment configuration, addressed relative to this notebook.
CONFIG_JSON_PATH = "../../config/config.json"

with open(CONFIG_JSON_PATH) as config_file:
    cfg = json.load(config_file)

# Aliases into the loaded config; they reference the same dict objects as
# ``cfg``, so the override below is visible through ``cfg`` as well.
_dataset_cfg = cfg["dataset"]
_params_cfg = cfg["params"]

# This notebook pins its own backbone regardless of what the file specifies.
_params_cfg["model_name"] = "seresnext50_32x4d.racm_in1k"

# --- dataset locations ---
LABEL_TRAIN_PATH = _dataset_cfg["label_train"]
TRAIN_CROP_DATASET_DIR = _dataset_cfg["train_crop"]

# --- hyper-parameters read from the config file ---
MODEL_NAME = _params_cfg["model_name"]
BATCH_SIZE = _params_cfg["batch_size"]
NUM_WORKERS = _params_cfg["num_workers"]
IMG_SIZE = _params_cfg["img_size"]
TEST_SIZE = _params_cfg["test_size"]
LEARNING_RATE = _params_cfg["learning_rate"]

# --- values fixed by this notebook ---
SEED = 2027
MAX_EPOCH = 20
N_SPLITS = 5

In [None]:
# Display the effective configuration (including the model_name override
# applied above) so it is recorded in the notebook output.
cfg

In [None]:
# Seed the random number generators once, up front, with the notebook-wide SEED.
# NOTE(review): per the Lightning docs, workers=True also makes DataLoader
# worker processes derive their seeds from SEED — confirm for the installed
# pytorch_lightning version.
pl.seed_everything(SEED, workers=True)

In [None]:
# Quick look at the raw label table. The training section re-reads the same
# CSV, so nothing downstream depends on this preview cell.
df = pd.read_csv(LABEL_TRAIN_PATH)
df.head()

### dataset

In [None]:
class PupilCSV(Dataset):
    """Image dataset yielding the METS label plus seven min-max scaled clinical values.

    Each item is a 9-tuple ``(image, label, age, ac, sbp, dbp, hdlc, tg, bs)``:
    the image loaded from the ``path`` column (RGB, optionally transformed),
    the ``METS`` column as a ``torch.long`` tensor, and the min-max normalised
    clinical columns as ``torch.float`` scalars.

    Args:
        df: Table with columns ``path``, ``METS`` and every name in
            ``AUX_COLUMNS``. It is copied via ``reset_index``; the caller's
            frame is not modified.
        transform: Optional callable applied to the PIL image.
        norm_stats: Optional mapping ``column -> (min, max)`` used instead of
            the statistics of ``df`` itself. Passing the training set's
            ``norm_stats`` to the validation set keeps both splits on the same
            scale. Columns missing from the mapping fall back to the
            statistics of ``df``.

    Attributes:
        norm_stats: The ``(min, max)`` pair actually used for each column.
    """

    # Clinical columns used as auxiliary targets, in the order they follow
    # the image and label in the tuple returned by ``__getitem__``.
    AUX_COLUMNS = ('age', 'AC', 'SBP', 'DBP', 'HDLC', 'TG', 'BS')

    def __init__(self, df: pd.DataFrame, transform=None, norm_stats=None):
        self.df = df.reset_index(drop=True)
        self.norm_stats = {}

        for col in self.AUX_COLUMNS:
            if norm_stats is not None and col in norm_stats:
                lo, hi = norm_stats[col]
            else:
                lo, hi = self.df[col].min(), self.df[col].max()

            centered = self.df[col] - lo
            span = hi - lo
            # A constant column has zero range; dividing would fill the target
            # with NaN and poison the loss, so map such a column to 0 instead.
            if span != 0:
                self.df[f'{col}_normalized'] = centered / span
            else:
                self.df[f'{col}_normalized'] = centered * 0.0
            self.norm_stats[col] = (lo, hi)

        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, *aux_targets)``."""
        label = torch.tensor(self.df.loc[idx, 'METS'], dtype=torch.long)
        aux_targets = tuple(
            torch.tensor(self.df.loc[idx, f'{col}_normalized'], dtype=torch.float)
            for col in self.AUX_COLUMNS
        )

        image = Image.open(self.df.loc[idx, 'path']).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return (image, label, *aux_targets)

In [None]:
class PupilTestCSV(Dataset):
    """Image dataset yielding only ``(image, label)`` pairs.

    Reads the image location from the ``path`` column and the class from the
    ``METS`` column; no auxiliary clinical values are returned.
    """

    def __init__(self, df: pd.DataFrame, transform=None):
        self.transform = transform
        # Work on a re-indexed copy so positional indices 0..n-1 are valid labels.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        source_path = self.df.loc[idx, 'path']
        target = self.df.loc[idx, 'METS']

        picture = Image.open(source_path).convert('RGB')
        picture = self.transform(picture) if self.transform else picture

        return picture, torch.tensor(target, dtype=torch.long)

### Dataloader

In [None]:
class PupilDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train / validation / optional test tables.

    Args:
        train_df: Table for training; wrapped in ``PupilCSV`` with augmentation.
        val_df: Table for validation; wrapped in ``PupilCSV`` without augmentation.
        test_df: Optional table for testing; wrapped in ``PupilTestCSV``.
            Pass ``None`` when there is no test set.
        batch_size: Batch size used by every loader.
        num_workers: Worker process count used by every loader.
        img_size: Images are resized to ``(img_size, img_size)``.
    """

    # Per-channel statistics used by ``transforms.Normalize`` for both pipelines.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(
            self,
            train_df,
            val_df,
            test_df,
            batch_size: int = 32,
            num_workers: int = 4,
            img_size: int = 224,
    ):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.img_size = img_size

        # Stays None unless a test table is supplied; see ``test_dataloader``.
        self.test_dataset = None

        # Training pipeline: geometric and colour augmentation on the PIL
        # image, then tensor conversion, normalisation and random erasing
        # (which operates on the tensor and therefore comes last).
        self.train_transforms = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),

            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),

            transforms.RandomRotation(degrees=45),

            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
                shear=10,
            ),

            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.1
            ),

            transforms.RandomPerspective(
                distortion_scale=0.5,
                p=0.5,
                fill=0
            ),

            transforms.ToTensor(),

            transforms.Normalize(
                mean=self._NORM_MEAN,
                std=self._NORM_STD
            ),

            transforms.RandomErasing(
                p=0.5,
                scale=(0.02, 0.33),
                ratio=(0.3, 3.3),
                value='random'
            ),
        ])

        # Evaluation pipeline: deterministic resize + normalisation only.
        self.val_transforms = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=self._NORM_MEAN,
                std=self._NORM_STD
            ),
        ])

    def prepare_data(self):
        """Nothing to download or pre-process; the tables are passed in directly."""
        pass

    def setup(self, stage=None):
        """Build the datasets from the stored tables (``stage`` is ignored)."""
        self.train_dataset = PupilCSV(
            self.train_df,
            transform=self.train_transforms
        )

        self.val_dataset = PupilCSV(
            self.val_df,
            transform=self.val_transforms
        )

        # Compare against None explicitly: evaluating a DataFrame in a boolean
        # context raises "The truth value of a DataFrame is ambiguous".
        if self.test_df is not None:
            self.test_dataset = PupilTestCSV(
                self.test_df,
                transform=self.val_transforms,
            )
        else:
            self.test_dataset = None

    def _make_loader(self, dataset, shuffle: bool):
        """Create a DataLoader with this module's batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test dataset.

        Raises:
            RuntimeError: If no test table was supplied or ``setup`` has not
                been run yet.
        """
        if self.test_dataset is None:
            raise RuntimeError(
                "No test dataset available: pass test_df to PupilDataModule "
                "and call setup() before requesting the test dataloader."
            )
        return self._make_loader(self.test_dataset, shuffle=False)

### model

In [None]:
class PupilModel(pl.LightningModule):
    """timm backbone trained on METS classification plus seven auxiliary regressions.

    The network's output vector is split along dim 1 into
    ``NUM_METS_LOGITS`` class logits followed by one unit per entry of
    ``AUX_TARGET_NAMES``. The class logits are scored with cross-entropy, each
    auxiliary unit with L1 loss, and all losses are summed with equal weight.

    Args:
        model_name: timm model identifier.
        pretrained: Whether timm should load pretrained weights.
        num_classes: Total number of output units of the network. Training and
            validation split the output into 2 + 7 units, so they require
            ``num_classes=9``; other values make ``torch.split`` fail.
        learning_rate: Learning rate passed to AdamW.
    """

    # Leading output units that hold the METS class logits.
    NUM_METS_LOGITS = 2
    # Auxiliary targets, in the order they follow (image, label) in a batch
    # and in which their output units follow the class logits.
    AUX_TARGET_NAMES = ('age', 'ac', 'sbp', 'dbp', 'hdlc', 'tg', 'bs')

    def __init__(
            self,
            model_name: str = MODEL_NAME,
            pretrained: bool = True,
            num_classes: int = 2,
            learning_rate: float = 1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes
        )

        self.criterion = nn.CrossEntropyLoss()

        self.aux_criterion = nn.L1Loss()

    def forward(self, x):
        """Return the raw output vector of the backbone for image batch ``x``."""
        return self.model(x)

    def _multitask_step(self, batch, stage: str):
        """Compute and log the summed multi-task loss and METS accuracy.

        Args:
            batch: ``(image, label, *aux_targets)`` with one auxiliary target
                per entry of ``AUX_TARGET_NAMES``.
            stage: Prefix for the logged metric names (``'train'`` or ``'val'``).

        Returns:
            The summed loss tensor.

        Raises:
            ValueError: If the batch does not carry exactly one target per
                auxiliary output.
        """
        x, y, *aux_targets = batch
        if len(aux_targets) != len(self.AUX_TARGET_NAMES):
            raise ValueError(
                f"expected {len(self.AUX_TARGET_NAMES)} auxiliary targets, "
                f"got {len(aux_targets)}"
            )

        logits = self(x)
        split_sizes = [self.NUM_METS_LOGITS] + [1] * len(self.AUX_TARGET_NAMES)
        logits_for_y, *aux_outputs = torch.split(logits, split_sizes, dim=1)

        total_loss = self.criterion(logits_for_y, y)
        for aux_output, aux_target in zip(aux_outputs, aux_targets):
            # squeeze(1) drops only the unit dimension. A bare squeeze() would
            # also drop the batch dimension when the batch holds one sample,
            # leaving a 0-d prediction against a (1,)-shaped target.
            total_loss = total_loss + self.aux_criterion(aux_output.squeeze(1), aux_target)

        preds = torch.argmax(logits_for_y, dim=1)
        acc = (preds == y).float().mean()
        self.log(f'{stage}_loss', total_loss, on_step=False, on_epoch=True)
        self.log(f'{stage}_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return total_loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss`` / ``train_acc`` and returns the loss."""
        return self._multitask_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss`` / ``val_acc``."""
        self._multitask_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Score METS classification only; logs ``test_loss`` / ``test_acc``.

        Only the first two batch elements (image, label) are used, so both
        2-tuple batches from ``PupilTestCSV`` and 9-tuple batches from
        ``PupilCSV`` are accepted.
        """
        x, y = batch[0], batch[1]
        # Only the leading units are class logits; the rest are regression outputs.
        logits_for_y = self(x)[:, :self.NUM_METS_LOGITS]
        loss = self.criterion(logits_for_y, y)
        preds = torch.argmax(logits_for_y, dim=1)
        acc = (preds == y).float().mean()
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        self.log('test_acc', acc, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a cosine-annealing schedule whose period is the notebook-level MAX_EPOCH."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCH)
        return [optimizer], [scheduler]

### train

In [None]:
# Re-read the label table so this section runs independently of the preview cell.
df = pd.read_csv(LABEL_TRAIN_PATH)
df = df.head(300) if DEBUG else df

# Resolve every file name against the cropped-image directory.
df['path'] = df['filename'].map(lambda filename: os.path.join(TRAIN_CROP_DATASET_DIR, filename))

# Stratify on the METS label so each fold keeps the class balance;
# every entry of ``folds`` is a (train_df, val_df) pair.
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
folds = [
    (df.iloc[train_index], df.iloc[val_index])
    for train_index, val_index in skf.split(df, df['METS'])
]

In [None]:
# The W&B run name is taken from the name of the current working directory.
# It is the same for every fold, so compute it once outside the loop.
notebook_path = os.path.abspath('')
run_name = os.path.basename(notebook_path)

# Iterate over the folds that were actually built rather than a hard-coded
# count, so changing N_SPLITS can neither index past the list nor skip folds.
for fold, (train_df, val_df) in enumerate(folds):
    data_module = PupilDataModule(
        train_df,
        val_df,
        None,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        img_size=IMG_SIZE,
    )

    # 9 output units = 2 METS class logits + 7 auxiliary regression outputs.
    model = PupilModel(
        model_name=MODEL_NAME,
        pretrained=True,
        num_classes=9,
        learning_rate=LEARNING_RATE
    )

    # Keep only the checkpoint with the best validation accuracy for this fold.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        dirpath='checkpoints',
        filename=f'fold_{fold}-' + '{epoch:02d}-{val_acc:.3f}',
        save_top_k=1,
        mode='max',
    )

    # Close any run left open (e.g. by an interrupted earlier execution)
    # before starting a fresh one for this fold.
    wandb.finish()
    wandb_logger = WandbLogger(project="ganka_ai_2024", name=f"{run_name}_seed_{SEED}_fold_{fold}")

    trainer = pl.Trainer(
        max_epochs=MAX_EPOCH,
        accelerator="gpu",
        devices=1,
        callbacks=[checkpoint_callback],
        log_every_n_steps=10,
        num_sanity_val_steps=0,
        logger=wandb_logger,
        deterministic=True
    )

    trainer.fit(model, datamodule=data_module)
    wandb_logger.finalize("success")
    wandb.finish()

    # Also persist the last-epoch state, independent of the best-val_acc checkpoint.
    final_model_path = f'./checkpoints/final_model_seed_{SEED}_fold_{fold}.pth'
    trainer.save_checkpoint(final_model_path)
    print(f"fold {fold} Final model saved at: {final_model_path}")