In [11]:
from model import EyeDiseaseClassifierCNN

In [1]:
import torch 
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import numpy as np

from sklearn.metrics import confusion_matrix,classification_report

In [2]:
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [3]:
dataset = datasets.ImageFolder(root="processing_dataset_phase_final/", transform=transform)

total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [4]:
dataset.class_to_idx

{'a_healthy_eye': 0,
 'cataract': 1,
 'dry_eye_syndrome': 2,
 'exopthalmos': 3,
 'jaundice': 4,
 'pterygium': 5,
 'stye': 6,
 'subcon_hemorrage': 7}

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
device

device(type='cuda')

In [7]:
def model_train(model, optimizer, criterion, dataloader):
    model.train()
    actual = []
    predicted = []
    for batch in dataloader:
        images, labels = batch
        labels = labels.to(device)
        images = images.to(device)
        
        preds = model(images) 
        loss = criterion(preds, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        _, predicted_classes = torch.max(preds, 1)
        
        for label in labels.tolist(): actual.append(int(label))
        for predicted_class in predicted_classes.tolist(): predicted.append(int(predicted_class))
    
    actual = np.array(actual) 
    predicted = np.array(predicted)
    
    return actual, predicted

In [8]:
def model_eval(model,dataloader):
    model.eval()
    actual = []
    predicted = []
    with torch.no_grad():
        for batch in dataloader:
            images, labels = batch
            
            labels = labels.to(device)
            images = images.to(device)
            
            preds = model(images)
            #loss = criterion(preds, labels)
            
            _, predicted_classes = torch.max(preds, 1)
            
            for label in labels.tolist(): actual.append(int(label))
            for predicted_class in predicted_classes.tolist(): predicted.append(int(predicted_class)) 
        
    actual = np.array(actual) 
    predicted = np.array(predicted)
    
    return actual, predicted

In [12]:
with torch.no_grad():
    torch.cuda.empty_cache()

In [13]:
learning_rate = 0.00007

model = EyeDiseaseClassifierCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [14]:
num_epochs = 5

for i in range(num_epochs):
    print(f"Epoch #{i+1}")
    
    train_actual, train_predicted = model_train(model, optimizer, criterion, train_loader)
    eval_actual, eval_predicted = model_eval(model, val_loader)
    
    train_report = classification_report(train_actual, train_predicted, output_dict=True)
    print("Training: \n", classification_report(train_actual, train_predicted, zero_division=0))
    
    eval_report = classification_report(eval_actual, eval_predicted, output_dict=True)
    print("Evaluation: \n", classification_report(eval_actual, eval_predicted, zero_division=0))

Epoch #1
Training: 
               precision    recall  f1-score   support

           0       0.28      0.34      0.31       120
           1       0.34      0.32      0.33       105
           2       0.26      0.25      0.26        87
           3       0.00      0.00      0.00        34
           4       0.33      0.19      0.24        73
           5       0.38      0.42      0.40       113
           6       0.35      0.47      0.40       105
           7       0.28      0.27      0.27        71

    accuracy                           0.32       708
   macro avg       0.28      0.28      0.28       708
weighted avg       0.31      0.32      0.31       708

Evaluation: 
               precision    recall  f1-score   support

           0       0.73      0.62      0.67        13
           1       0.00      0.00      0.00        16
           2       0.18      0.33      0.24         6
           3       0.00      0.00      0.00         4
           4       1.00      0.10      0.18

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


KeyboardInterrupt: 