In [None]:
import os
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm

from sklearn.model_selection import train_test_split

import albumentations as A
from albumentations.pytorch import ToTensorV2

import torch.nn as nn
import torch.nn.functional as F

from torchmetrics import Accuracy, F1Score ,CohenKappa

from dretino.dataloader.build_features import DRDataModule
from dretino.models.mseloss import ModelMSE
from dretino.models.train_model import Model, train
from dretino.visualization.visualize import show_images, cal_mean, plot_metrics

In [None]:
PATH = '../data/processed/'

# Load the IDRiD disease-grading label files (training and testing splits),
# keeping only the image name and its retinopathy grade.
dfx = pd.read_csv(
    PATH + '2.Groundtruths/a.IDRiD_Disease_Grading_Training_Labels.csv',
    usecols=['Image name', 'Retinopathy grade'],
)
df_test = pd.read_csv(
    PATH + '2.Groundtruths/b.IDRiD_Disease_Grading_Testing_Labels.csv',
    usecols=['Image name', 'Retinopathy grade'],
)

dfx.head()

In [None]:
# Hold out 10% of the labelled training images for validation. Stratifying on
# the grade keeps the grade proportions (approximately) equal in both splits.
df_train, df_valid = train_test_split(
    dfx,
    test_size=0.1,
    random_state=42,
    stratify=dfx['Retinopathy grade'].values,
)

# Renumber the rows of each split 0..n-1, discarding the shuffled index.
df_train, df_valid = (part.reset_index(drop=True) for part in (df_train, df_valid))

df_train.head()

In [None]:
# Per-channel normalization statistics, as fractions of max_pixel_value.
# NOTE(review): presumably computed on the training images (e.g. with the
# imported `cal_mean`) — TODO confirm provenance.
NORM_MEAN = (0.5237, 0.2542, 0.0853)
NORM_STD = (0.2649, 0.1497, 0.0876)
IMAGE_SIZE = 224        # final side length fed to the model
TRAIN_RESIZE = 250      # resize before the random crop during training


def _normalize_and_tensorize():
    """Return the steps shared by every pipeline: normalize, then convert to a tensor."""
    return [
        A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=255.0),
        ToTensorV2(),
    ]


def _eval_transforms():
    """Build the deterministic pipeline used for validation and testing."""
    return A.Compose(
        [
            A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            *_normalize_and_tensorize(),
        ]
    )


# Training: resize slightly larger, then take a random IMAGE_SIZE crop.
# NOTE(review): eval resizes straight to IMAGE_SIZE, so objects appear at a
# different scale than in training (224/250 crop vs. full image). Consider
# Resize(TRAIN_RESIZE) + CenterCrop(IMAGE_SIZE) for eval to match.
train_transforms = A.Compose(
    [
        A.Resize(width=TRAIN_RESIZE, height=TRAIN_RESIZE),
        A.RandomCrop(height=IMAGE_SIZE, width=IMAGE_SIZE),
        *_normalize_and_tensorize(),
    ]
)

# Validation and test use identical preprocessing; build them from one source
# so they cannot drift apart.
val_transforms = _eval_transforms()
test_transforms = _eval_transforms()

In [None]:
# Wrap the three splits in the project's data module. The validation rows were
# split off the training labels, so they are read from the same image folder.
dm = DRDataModule(
    df_train,
    df_valid,
    df_test,
    train_path=PATH + 'images_resized',
    valid_path=PATH + 'images_resized',
    test_path=PATH + 'test_images_resized',
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    test_transforms=test_transforms,
    batch_size=4,
    num_workers=4,
)

# Visual sanity check of what the training loader yields.
show_images(dm.train_dataloader())

In [None]:
# Build the MSE-trained model on the `resnet50d` architecture name.
# NOTE(review): `train_one_epoch` regresses the grade as a single float per
# image, so the output width is presumably 1 despite `num_classes=5` — confirm
# in dretino.models.mseloss.
model = ModelMSE(
    model_name='resnet50d',
    num_classes=5,
    additional_layers=False,
)

print(model)

In [None]:
def train_one_epoch(model,dataloader,optimizer,loss_fn):
    loss_ = 0
    acc_ = []
    f1_ = []
    kappa_ = []
    for idx,(x,y) in tqdm(enumerate(dataloader)):
        optimizer.zero_grad()
        y = torch.argmax(y,dim=-1)
        y = torch.unsqueeze(y,1).to(torch.float32)
        logits = model(x)
        loss = loss_fn(y,logits)
        loss_+=loss.item()
        loss.backward()
        optimizer.step()
        predictions = logits.data
        predictions[predictions < 0.5] = 0
        predictions[(predictions >= 0.5) & (predictions < 1.5)] = 1
        predictions[(predictions >= 1.5) & (predictions < 2.5)] = 2
        predictions[(predictions >= 2.5) & (predictions < 3.5)] = 3
        predictions[(predictions >= 3.5) & (predictions < 1000000000000)] = 4
        preds = predictions.long().view(-1)
        acc_.append(accuracy(preds, y.to(torch.int16).view(-1)))
        f1_.append(metric(preds, y.to(torch.int16).view(-1)))
        kappa_.append(kappametric(preds, y.to(torch.int16).view(-1)))
    print(f"loss : {loss_/idx}\n",
          f"acc :  {np.array(acc_).mean()}\n",
          f"f1 :  {np.array(f1_).mean()}\n",
          f"kappa :  {np.array(kappa_).mean()}")


In [None]:
# Classification metrics on the regression outputs once they are binned back
# to integer grades inside `train_one_epoch`.
# NOTE(review): `Accuracy()` without `task=` only works on torchmetrics < 0.11;
# newer releases require e.g. `Accuracy(task="multiclass", num_classes=5)`.
accuracy = Accuracy()
# With micro averaging (the default here), single-label multiclass F1 is
# numerically identical to accuracy. Macro averaging weights every grade
# equally, so it reports something accuracy does not.
metric = F1Score(num_classes=5, average='macro')
# NOTE(review): the grades are ordered 0-4; quadratic-weighted kappa
# (`CohenKappa(num_classes=5, weights='quadratic')`) is the usual choice for
# ordinal grading — consider switching.
kappametric = CohenKappa(num_classes=5)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# MSE between the model outputs and the float grade labels.
loss_fn = nn.MSELoss()

train_one_epoch(model, dm.train_dataloader(), optimizer, loss_fn)