In [1]:
import torch
import torchvision
import os
import cv2
import numpy as np
from torch import nn
from glob import glob
import timm
import torch.nn.functional as F

In [2]:
# Every file one level below a class folder: ./Data/<class_name>/<image_file>.
# glob() yields entries in filesystem order, which differs between machines and OSes;
# sorting makes the random_state-seeded train/test split reproducible.
image_paths = sorted(glob('./Data/*/*'))

In [3]:
# Fail early with a readable message if the glob matched nothing (e.g. wrong working directory).
assert image_paths, "No images found under ./Data/<class>/ - check the working directory."
len(image_paths)

1944

In [4]:
labels = []
images = []

for image_path in image_paths:
    # Layout is <root>/<class_name>/<file>, so the class is the parent folder's name.
    # os.path.dirname/basename accept either separator, unlike split(os.path.sep)[2],
    # which assumed exactly './Data/...' and the host OS separator.
    labels.append(os.path.basename(os.path.dirname(image_path)))
    images.append(os.path.basename(image_path))

In [7]:
# Sorted set of class names discovered from the folder names.
np.unique(np.asarray(labels))

array(['Cataracts', 'Glaucoma', 'Healthy', 'Uveitis'], dtype='<U9')

In [5]:
labels = np.array(labels, dtype='str')

# np.unique returns the class names sorted, so the name <-> index assignment is deterministic.
class_names = [str(name) for name in np.unique(labels)]
n_classes = len(class_names)
label2pred = {name: idx for idx, name in enumerate(class_names)}   # class name -> index
pred2label = {idx: name for idx, name in enumerate(class_names)}   # index -> class name

In [None]:
class IrisDisease(torch.utils.data.Dataset):
    """Image-classification dataset over image files on disk.

    Each item is ``(image, label)``:
      * ``image`` is the file read with OpenCV, converted BGR -> RGB, resized to
        ``image_size`` x ``image_size`` and then passed through ``transforms`` if given;
      * ``label`` is a 0-d ``torch`` tensor holding the class index from ``class_to_idx``.

    Args:
        image_paths: sequence of image file paths.
        labels: sequence of class names, aligned element-wise with ``image_paths``.
        transforms: optional callable applied to the resized RGB array.
        image_size: side length in pixels to resize to (default 300).
        class_to_idx: mapping from class name to integer index. When ``None``, the
            notebook-level ``label2pred`` mapping is used, looked up at access time.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transforms=None, image_size=300, class_to_idx=None):
        super().__init__()
        if len(image_paths) != len(labels):
            raise ValueError(
                "image_paths and labels must have the same length "
                "(got {} and {})".format(len(image_paths), len(labels))
            )

        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms
        self.image_size = image_size
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None instead of raising.
            raise FileNotFoundError("Could not read image file: {}".format(path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_size, self.image_size))

        mapping = label2pred if self.class_to_idx is None else self.class_to_idx
        label = torch.tensor(mapping[self.labels[idx]])

        if self.transforms:
            image = self.transforms(image)

        return image, label

In [None]:
# Use a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from sklearn.model_selection import train_test_split

# Stratify on the class labels so the train and test splits keep the same class proportions.
train_image_paths, test_image_paths, train_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.25, random_state=42, stratify=labels
)

In [None]:
# Training pipeline: H x W x 3 RGB array -> 3 x H x W float tensor.
# NOTE(review): IrisDisease already resizes every image to 300x300, so CenterCrop(300)
# leaves it unchanged and no augmentation actually happens. There is also no mean/std
# Normalize even though the backbone is pretrained; if one is added, apply the same step
# in every pipeline (including the single-image inference cell).
tr_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop(300),
    ]
)

# Evaluation pipeline: tensor conversion only.
val_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

In [None]:
# One dataset per split; the training split uses tr_transforms, the test split val_transforms.
train_dataset = IrisDisease(train_image_paths, train_labels, transforms=tr_transforms)
test_dataset = IrisDisease(test_image_paths, test_labels, transforms=val_transforms)

In [None]:
# Sanity check: the two splits together cover every discovered image exactly once.
assert len(train_dataset) + len(test_dataset) == len(image_paths)
len(train_dataset)

In [None]:
# A seeded generator makes the training shuffle order reproducible from run to run;
# the test loader keeps a fixed order so per-epoch metrics are comparable.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [None]:
import matplotlib.pyplot as plt

# Preview one training sample: permute the CHW tensor to HWC for matplotlib, then print its class.
sample_image, sample_label = train_dataset[10]
plt.imshow(sample_image.permute(1, 2, 0))
print(pred2label[sample_label.item()])

In [None]:
class ClassificationBase(nn.Module):
    """Mixin with train/validation step helpers for multi-class image classifiers.

    Subclasses implement ``forward`` returning per-class scores that are passed
    straight to ``F.cross_entropy``. Relies on the notebook-level ``device`` and
    ``accuracy`` names at call time.
    """

    def _loss_and_acc(self, batch):
        """Move an ``(images, labels)`` batch to ``device``; return (loss, acc, batch_size)."""
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return loss, acc, labels.size(0)

    def training_step(self, batch):
        """Return ``(loss, acc)`` for one training batch; ``loss`` keeps its autograd graph."""
        loss, acc, _ = self._loss_and_acc(batch)
        return loss, acc

    def validation_step(self, batch):
        """Return detached metrics for one batch plus its size ``n`` (used for weighting)."""
        loss, acc, n = self._loss_and_acc(batch)
        return {'val_loss': loss.detach(), 'val_acc': acc, 'n': n}

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch validation results into epoch-level floats.

        Batches are weighted by their size (``'n'``, defaulting to 1 when absent) so a
        short final batch does not count as much as a full one; the result equals the
        per-sample mean loss and overall accuracy.

        Raises:
            ValueError: if ``outputs`` is empty (e.g. an empty validation loader).
        """
        if not outputs:
            raise ValueError("validation_epoch_end received no batch outputs (empty val_loader?)")
        weights = torch.tensor([float(x.get('n', 1)) for x in outputs])
        losses = torch.stack([x['val_loss'].float().cpu() for x in outputs])
        accs = torch.stack([x['val_acc'].float().cpu() for x in outputs])
        total = weights.sum()
        epoch_loss = (losses * weights).sum() / total
        epoch_acc = (accs * weights).sum() / total
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        """Print the epoch's aggregated validation metrics."""
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))
        

def accuracy(outputs, labels):
    """Fraction of rows of ``outputs`` whose highest-scoring column equals ``labels``.

    Returns a 0-d float tensor. The division is done in Python, so an empty batch
    raises ``ZeroDivisionError``.
    """
    predicted = torch.max(outputs, dim=1)[1]
    n_correct = (predicted == labels).sum().item()
    return torch.tensor(n_correct / len(predicted))

def evaluate(model, val_loader):
    """Run ``model.validation_step`` over every batch of ``val_loader`` and aggregate.

    The model is put in eval mode (BatchNorm uses its running statistics instead of
    updating them from validation data; dropout is disabled) and autograd is off.
    The model's previous train/eval mode is restored afterwards, even on error, so a
    surrounding training loop is unaffected.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the collected outputs.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = [model.validation_step(batch) for batch in val_loader]
    finally:
        model.train(was_training)
    return model.validation_epoch_end(outputs)

In [None]:
class EfficientNetB3(ClassificationBase):
    """timm ``efficientnet_b3`` (pretrained weights) with a new linear classification head.

    ``forward`` returns raw, unnormalised logits. ``F.cross_entropy`` (used by
    ``ClassificationBase``) applies log-softmax itself. Squashing the outputs first,
    e.g. with a sigmoid, confines every score to (0, 1). That caps how confident the
    softmax can become and weakens the training gradient. Arg-max predictions are
    unaffected by removing the sigmoid, since it is monotonic.

    Args:
        num_classes: size of the output layer; defaults to the notebook-level ``n_classes``.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            num_classes = n_classes

        self.network = timm.create_model('efficientnet_b3', pretrained=True)
        num_ftrs = self.network.classifier.in_features
        # Replace the pretrained head with one sized for this task's classes.
        self.network.classifier = nn.Linear(num_ftrs, num_classes)

    def forward(self, batch):
        batch = batch.to(device)
        return self.network(batch)


model = EfficientNetB3()

In [None]:
def fit(epochs, model, train_loader, val_loader, opt_func=torch.optim.Adam, lr=1e-5):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``val_loader`` after each one.

    Args:
        epochs: number of passes over ``train_loader``.
        model: module providing ``training_step`` / ``validation_step`` /
            ``validation_epoch_end`` / ``epoch_end`` (e.g. a ``ClassificationBase`` subclass).
        train_loader: iterable of ``(images, labels)`` training batches.
        val_loader: iterable of ``(images, labels)`` validation batches.
        opt_func: optimizer class, constructed as ``opt_func(model.parameters(), lr)``.
        lr: learning rate (default 1e-5, the value previously hard-coded).

    Returns:
        List with one ``evaluate`` result dict per epoch.
    """
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Explicitly (re-)enter train mode so BatchNorm/dropout behave as in training.
        model.train()
        total_loss, total_acc, n_batches = 0.0, 0.0, 0
        for batch in tqdm.tqdm(train_loader):
            loss, acc = model.training_step(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_acc += float(acc)
            n_batches += 1
        # Report the mean over the epoch's batches rather than only the last batch.
        n = max(n_batches, 1)
        print("Epoch [{}], loss: {:.4f}, acc: {:.4f}".format(epoch, total_loss / n, total_acc / n))
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history

In [None]:
# nn.Module.to() moves parameters and buffers in place and returns the same module;
# the trailing ';' suppresses the module repr in the notebook output.
model.to(device);

In [None]:
# Baseline metrics on the held-out split before any fine-tuning.
evaluate(model=model, val_loader=test_dataloader)

In [None]:
import tqdm

# fit() returns one metrics dict per epoch. Bind it directly: `history += ...` needed a
# `history` that no earlier cell defines, so it raised NameError on a fresh kernel.
history = fit(20, model, train_dataloader, test_dataloader)

In [None]:
# Persist only the learned parameters (the state_dict), not the module object itself.
torch.save(obj=model.state_dict(), f='model_1.pth')

In [None]:
# Per-epoch validation curves pulled out of fit()'s history list.
losses = [epoch_metrics['val_loss'] for epoch_metrics in history]
accs = [epoch_metrics['val_acc'] for epoch_metrics in history]

In [None]:
# One point per recorded epoch. Derive the x-axis from the data: a hard-coded 30-point
# axis fails with a length mismatch when training runs for a different number of epochs.
plt.plot(range(1, len(losses) + 1), losses)
plt.xlabel('epoch')
plt.ylabel('validation loss');

In [None]:
# One point per recorded epoch; the x-axis follows the data instead of assuming 30 epochs.
plt.plot(range(1, len(accs) + 1), accs)
plt.xlabel('epoch')
plt.ylabel('validation accuracy');

In [None]:
sample_path = './Data/Cataracts/102_1.jpg'
image = cv2.imread(sample_path)
if image is None:
    # cv2.imread returns None (instead of raising) when the file is missing or unreadable.
    raise FileNotFoundError("Could not read image file: {}".format(sample_path))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.imshow(image)

In [None]:
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# Resize with cv2.resize(image, (300, 300)) before ToTensor, the same way IrisDisease
# prepares images for training and validation. A torchvision Resize on the float tensor
# can interpolate differently and skew inference inputs.
image = cv2.resize(image, (300, 300))
image = test_transforms(image).unsqueeze(0)  # add a batch dimension -> (1, 3, 300, 300)

In [None]:
model.eval()
# Inference only: skip building the autograd graph.
with torch.no_grad():
    logits = model(image)
pred2label[logits.argmax(dim=1).item()]

In [None]:
# Restore saved weights. map_location='cpu' lets a checkpoint written on a GPU load on a
# CPU-only machine. torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load('model_1.pth', map_location='cpu')
model.load_state_dict(checkpoint)
model