# Blindness Detection (Diabetic Retinopathy) Model Training

### EfficientNet-B4 — 0.92 Validation Accuracy (0.6 Validation Loss)

In [None]:
import torch
import torchvision
import os
import cv2
import numpy as np
from torch import nn
from glob import glob
import pandas as pd
import torch.nn.functional as F

In [None]:
# Every file exactly one directory below ./Data; that sub-directory name is used as the class label.
DATA_GLOB = './Data/*/*'
# glob() returns files in an OS/filesystem-dependent order; sort so the list (and any seeded
# split built from it) is reproducible across machines and runs.
image_paths = sorted(glob(DATA_GLOB))

In [None]:
# Derive (label, file name) from the last two path components: ./Data/<label>/<file>.
# os.path.split/basename are used instead of fixed indices into split(os.path.sep):
# on Windows glob yields mixed separators (e.g. './Data\\0\\x.jpeg'), so index 2 was the
# file name and index 3 raised IndexError; fixed indices also break if the root depth changes.
labels = []
images = []

for image_path in image_paths:
    folder, file_name = os.path.split(image_path)
    labels.append(os.path.basename(folder))
    images.append(file_name)

In [None]:
# Show the sorted, de-duplicated class folder names found under ./Data.
class_folder_names = np.unique(labels)
class_folder_names

In [None]:
labels = np.array(labels, dtype='str')
# Sorted unique class names, computed once; position i is the model's output index for class i.
# Enumerating them (instead of zipping with a hard-coded range(0, 5)) keeps the mapping correct
# for any number of class folders.
class_names = np.unique(labels)
n_classes = len(class_names)
label2pred = {name: idx for idx, name in enumerate(class_names)}
pred2label = {idx: name for idx, name in enumerate(class_names)}

In [None]:
class DiabeticRetinopathy(torch.utils.data.Dataset):
    """Map-style dataset of retina images paired with their label strings.

    Args:
        image_paths: Sequence of image file paths readable by ``cv2.imread``.
        labels: Sequence of label strings, aligned index-for-index with ``image_paths``.
        transforms: Optional callable applied to the resized RGB array.
        image_size: ``dsize`` passed to ``cv2.resize`` (width, height); defaults to (400, 400).
        label_map: Optional dict from label string to class index. When ``None`` the
            module-level ``label2pred`` is looked up at access time.

    ``__getitem__`` returns ``(image, label)`` where ``label`` is a ``torch.tensor`` class index.
    """

    def __init__(self, image_paths, labels, transforms=None, image_size=(400, 400), label_map=None):
        super().__init__()

        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.image_size = image_size
        self.label_map = label_map

    def __len__(self):
        return len(self.image_paths)

    def _encode_label(self, label):
        """Return the class index of ``label`` as a tensor."""
        mapping = label2pred if self.label_map is None else self.label_map
        return torch.tensor(mapping[label])

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        image = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads BGR; convert to RGB before resizing and transforming.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self.image_size)

        label = self._encode_label(self.labels[idx])

        if self.transforms:
            image = self.transforms(image)

        return image, label

In [None]:
# Use the GPU when one is visible; otherwise fall back to the CPU instead of failing on the
# first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from sklearn.model_selection import train_test_split

# Fixed seed so the subsample and the train/validation split are reproducible across runs.
SPLIT_SEED = 42

# Keep only a stratified 5% of the data: test_size=0.95 sends 95% to the discarded `_` outputs.
# NOTE(review): confirm this subsampling is intentional (e.g. for quick experiments).
image_paths, _, labels, _ = train_test_split(
    image_paths, labels, test_size=0.95, shuffle=True, stratify=labels, random_state=SPLIT_SEED)

# Stratified 80/20 train/validation split of the retained subset.
train_image_paths, test_image_paths, train_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.2, shuffle=True, stratify=labels, random_state=SPLIT_SEED)

In [None]:
# Standard ImageNet per-channel mean / std used for normalisation.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: array -> CHW float tensor, 380x380 centre crop, channel normalisation.
# No augmentation is applied.
tr_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop(380),
        torchvision.transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]
)

# Validation pipeline: the same three steps, built as its own Compose object.
val_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop(380),
        torchvision.transforms.Normalize(_imagenet_mean, _imagenet_std),
    ]
)


In [None]:
# One dataset per split; they differ only in their paths, labels and transform pipeline.
train_dataset = DiabeticRetinopathy(train_image_paths, train_labels, tr_transforms)
test_dataset = DiabeticRetinopathy(test_image_paths, test_labels, val_transforms)

In [None]:
# Number of training samples left after subsampling and splitting.
n_train_samples = len(train_dataset)
n_train_samples

In [None]:
# Same batch size for both loaders; only the training batches are shuffled.
_batch_size = 20
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=_batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=_batch_size, shuffle=False)

In [None]:
import matplotlib.pyplot as plt

# Fetch the sample once: every index access re-reads and re-transforms the image from disk.
sample_image, sample_label = train_dataset[10]

# Undo the Normalize step so imshow gets values in [0, 1]; the normalised tensor has negative
# and >1 values that imshow would clip, distorting the preview.
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
display_image = (sample_image * _std + _mean).clamp(0, 1)

# CHW -> HWC for matplotlib.
plt.imshow(display_image.permute(1, 2, 0))
print(pred2label[sample_label.item()])

In [None]:
class ClassificationBase(nn.Module):
    """Base ``nn.Module`` with training/validation step helpers for classifiers.

    Subclasses implement ``forward``. A batch is an ``(images, labels)`` pair; both are moved
    to the module-level ``device`` before the forward pass, and accuracy is computed with the
    module-level ``accuracy`` helper.
    """

    def _loss_and_accuracy(self, batch):
        """Move ``batch`` to ``device``, run the model, and return (cross-entropy loss, accuracy)."""
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return loss, acc

    def training_step(self, batch):
        """Return ``(loss, acc)`` for one training batch; ``loss`` keeps its graph for ``backward()``."""
        return self._loss_and_accuracy(batch)

    def validation_step(self, batch):
        """Return ``{'val_loss': ..., 'val_acc': ...}`` for one validation batch.

        Runs under ``torch.no_grad()``: the loss is never backpropagated, so building an
        autograd graph (and keeping activations alive for it) would only waste memory and time.
        """
        with torch.no_grad():
            loss, acc = self._loss_and_accuracy(batch)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``validation_step`` results into Python floats.

        This is an unweighted mean over batches, so a smaller final batch counts as much as a
        full one.
        """
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the aggregated validation metrics for ``epoch``."""
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}\n".format(epoch, result['val_loss'], result['val_acc']))
        

def accuracy(outputs, labels):
    """Fraction of rows whose highest-scoring column equals the target label.

    Args:
        outputs: Scores with classes along dim 1.
        labels: Integer class indices, one per row of ``outputs``.

    Returns:
        A 0-d float tensor created from the Python ratio (correct / total).
    """
    _, predicted = outputs.max(dim=1)
    n_correct = (predicted == labels).sum().item()
    n_total = len(predicted)
    return torch.tensor(n_correct / n_total)

def evaluate(model, val_loader):
    """Run ``model.validation_step`` on every batch and aggregate with ``validation_epoch_end``.

    The pass runs with the model in eval mode, so dropout/batch-norm layers behave as at
    inference and batch-norm running statistics are not updated with validation data.
    Autograd is disabled for the pass. The model's previous train/eval mode is restored
    afterwards, even if a batch raises, so calling this mid-training is safe.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = [model.validation_step(batch) for batch in val_loader]
    finally:
        model.train(was_training)
    return model.validation_epoch_end(outputs)

In [None]:
!pip install efficientnet_pytorch

In [None]:
from efficientnet_pytorch import EfficientNet

class EfficientNetB4(ClassificationBase):
    """``efficientnet_pytorch`` EfficientNet-B4 with pretrained weights and a new classification head.

    Args:
        num_classes: Output size of the replacement head; defaults to the module-level ``n_classes``.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = n_classes

        self.network = EfficientNet.from_pretrained('efficientnet-b4')
        # Size the new head from the backbone's existing head rather than hard-coding 1792,
        # so it stays correct if the variant above is changed.
        self.network._fc = nn.Linear(self.network._fc.in_features, num_classes)

    def forward(self, batch):
        batch = batch.to(device)
        return self.network(batch)


model = EfficientNetB4()

In [None]:
def fit(epochs, model, train_loader, val_loader, opt_func=torch.optim.Adam, lr=1e-4, weight_decay=1e-5):
    """Train ``model`` for ``epochs`` epochs, validating with ``evaluate`` after each one.

    Args:
        epochs: Number of passes over ``train_loader``.
        model: Model providing ``training_step(batch) -> (loss, acc)`` and ``epoch_end``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches, passed to ``evaluate``.
        opt_func: Optimizer class, called as ``opt_func(params, lr, weight_decay=weight_decay)``.
        lr: Learning rate (default 1e-4).
        weight_decay: Weight decay (default 1e-5).

    Returns:
        List with one ``evaluate`` result dict per epoch.
    """
    # Imported here so fit() does not rely on a global `tqdm` defined in a later cell.
    import tqdm

    history = []
    optimizer = opt_func(model.parameters(), lr, weight_decay=weight_decay)

    for epoch in range(epochs):
        # (Re)enter train mode each epoch so a mode change made elsewhere (e.g. around
        # validation) does not leak into training.
        model.train()
        batch_losses = []
        batch_accs = []
        for batch in tqdm.tqdm(train_loader):
            loss, acc = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            batch_losses.append(loss.item())
            batch_accs.append(float(acc))

        # Report epoch averages, not just the last batch's values; NaN if the loader was empty.
        n_batches = len(batch_losses)
        mean_loss = sum(batch_losses) / n_batches if n_batches else float('nan')
        mean_acc = sum(batch_accs) / n_batches if n_batches else float('nan')
        print("Epoch [{}], loss: {:.4f}, acc: {:.4f}".format(epoch, mean_loss, mean_acc))

        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [None]:
# Move parameters and buffers to the selected device (nn.Module.to returns the module itself).
model = model.to(device=device)

In [None]:
# Baseline metrics before training. The model is still in its default train mode here, so
# switch to eval mode and disable autograd: evaluating in train mode applies dropout and would
# update any batch-norm running statistics with validation data before training even starts.
model.eval()
with torch.no_grad():
    baseline_metrics = evaluate(model, test_dataloader)
baseline_metrics

In [None]:
import tqdm  # progress bars for the training loop

# Train for 20 epochs, validating on the held-out split after each epoch.
model.train()
history = fit(epochs=20, model=model, train_loader=train_dataloader, val_loader=test_dataloader)

# Final validation pass with the model in inference mode.
model.eval()
result = evaluate(model=model, val_loader=test_dataloader)

In [None]:
# Display the final validation metrics dict ({'val_loss': ..., 'val_acc': ...}) from the training cell.
result

In [None]:
# Persist only the learned weights (state_dict), not the pickled module object.
checkpoint_path = './EfficientNetB4-0.9.pth'
weights = model.state_dict()
torch.save(weights, checkpoint_path)

In [None]:
# Re-check validation metrics with the model in inference mode.
model.eval()
final_metrics = evaluate(model, test_dataloader)
final_metrics

In [None]:
# Per-epoch validation loss and accuracy, in training order, pulled from fit()'s history.
losses = [epoch_result['val_loss'] for epoch_result in history]
accs = [epoch_result['val_acc'] for epoch_result in history]

In [None]:
# Validation accuracy per epoch. The x-axis (1..N) is derived from the recorded history rather
# than hard-coded to 20 points, which raised a shape mismatch for any other epoch count.
epoch_numbers = np.arange(1, len(accs) + 1)
plt.xlabel('epoch')
plt.ylabel('validation accuracy')
plt.plot(epoch_numbers, accs)

In [None]:
# Validation loss per epoch. The x-axis (1..N) is derived from the recorded history rather
# than hard-coded to 20 points, which raised a shape mismatch for any other epoch count.
epoch_numbers = np.arange(1, len(losses) + 1)
plt.xlabel('epoch')
plt.ylabel('validation loss')
plt.plot(epoch_numbers, losses)

In [None]:
# map_location remaps stored tensors onto the current device, so a checkpoint saved from a
# CUDA run still loads on a CPU-only machine.
model.load_state_dict(torch.load('./EfficientNetB4-0.9.pth', map_location=device))

In [None]:
sample_path = './Data/0/10_right.jpeg'
image = cv2.imread(sample_path)
# cv2.imread returns None (rather than raising) when the file is missing or unreadable.
if image is None:
    raise FileNotFoundError(f"Could not read image: {sample_path}")
# OpenCV loads BGR; convert to RGB and apply the same 400x400 resize as the dataset.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (400, 400))
plt.imshow(image)

In [None]:
# Inference preprocessing must match val_transforms: centre-crop the 400x400 image to 380x380.
# Resize((380, 380)) would instead shrink the whole 400x400 frame, feeding the model a
# differently scaled image than it was trained and validated on.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.CenterCrop(380),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

image = test_transforms(image)
# Add a leading batch dimension: (3, 380, 380) -> (1, 3, 380, 380).
image = image.unsqueeze(0)

In [None]:
model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    logits = model(image)
# Batch of one image: take the highest-scoring class index of row 0 and map it to its label.
predicted_label = pred2label[logits.argmax(dim=1)[0].item()]
predicted_label