In [1]:
import numpy as np 
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models 
import torchvision.transforms
from torchvision.transforms import ToTensor
import os
from PIL import Image 
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm

In [2]:
def open_image(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


# define dataset directory 
path = 'retina_dataset/dataset'
#img_size = 224 
#batch_size = 32

image = Image.open('retina_dataset/dataset/1_normal/NL_001.png')
#image.show()


image_transforms = transforms.Compose([
    transforms.Resize((204, 308)),  #1232, 816
    transforms.CenterCrop(size=(204, 206)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4563, 0.2717, 0.1612], std=[0.5519, 0.3326, 0.2021])
    #transforms.Normalize(mean=[0.3066, 0.1828, 0.1091], std=[0.324, 0.1947, 0.1162])
    #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


dataset = datasets.ImageFolder(root=path, transform=image_transforms)

In [3]:
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False)

In [4]:
class CNN(nn.Module):
    def __init__(self, in_channels=3, num_classes=4):  # num class = 10
        """
        Define the layers of the convolutional neural network.

        Parameters:
            in_channels: int
                The number of channels in the input image. For MNIST, this is 1 (grayscale images).
            num_classes: int
                The number of classes we want to predict, in our case 10 (digits 0 to 9).
        """
        in_channels: 3 
        num_classes: 4
        super(CNN, self).__init__()

        # First convolutional layer: 3 input channel, 8 output channels, 10x10 kernel, stride 1, padding 0
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=10, stride=1, padding=0)
        # subsequent convolutional layers: in_channels must be equal to the out_channels of the immediately preceding layer
        # Max pooling layer: 2x2 window, stride 2
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Second convolutional layer: 8 input channels, 16 output channels, 8x8 kernel, stride 1, padding 3
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=8, stride=1, padding=3)
        # Fully connected layer: 32*13*13 input features (after two 2x2 poolings), 10 output features (num_classes)
      
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=5)
        self.fc1 = nn.Linear(in_features=5408, out_features=num_classes)
    def forward(self, x):
        """
        Define the forward pass of the neural network.

        Parameters:
            x: torch.Tensor
                The input tensor.

        Returns:
            torch.Tensor
                The output tensor after passing through the network.
        """
        x = F.relu(self.conv1(x))  # Apply first convolution and ReLU activation
        x = self.pool(x)           # Apply max pooling
        x = F.relu(self.conv2(x))  # Apply second convolution and ReLU activation
        x = self.pool(x)           # Apply max pooling
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)  # Flatten the tensor
        x = self.fc1(x)            # Apply fully connected layer
        return x

#print(train_loader.dataset[0].size())
input_size = 126072        #204*206*3
num_classes = 4  # 4 types of eye: normal, cataracts, glaucoma, retina disease
learning_rate = 0.001
batch_size = 32
num_epochs = 75  

device = "cuda" if torch.cuda.is_available() else "cpu"

model = CNN(in_channels=3, num_classes=num_classes).to(device)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

#tensor_to_image(train_loader.dataset[0][0][1])


for epoch in range(num_epochs):
    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    for batch_index, (data, targets) in enumerate(tqdm(train_loader)):
        # Move data and targets to the device (GPU/CPU)
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass: compute the model output
        scores = model(data)
        loss = criterion(scores, targets)

        # Backward pass: compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # Optimization step: update the model parameters
        optimizer.step()

Epoch [1/75]


100%|██████████| 121/121 [00:32<00:00,  3.77it/s]


Epoch [2/75]


100%|██████████| 121/121 [00:33<00:00,  3.57it/s]


Epoch [3/75]


100%|██████████| 121/121 [00:34<00:00,  3.56it/s]


Epoch [4/75]


100%|██████████| 121/121 [00:35<00:00,  3.41it/s]


Epoch [5/75]


100%|██████████| 121/121 [00:34<00:00,  3.47it/s]


Epoch [6/75]


100%|██████████| 121/121 [00:34<00:00,  3.55it/s]


Epoch [7/75]


100%|██████████| 121/121 [00:34<00:00,  3.52it/s]


Epoch [8/75]


100%|██████████| 121/121 [00:34<00:00,  3.52it/s]


Epoch [9/75]


100%|██████████| 121/121 [01:27<00:00,  1.39it/s]


Epoch [10/75]


100%|██████████| 121/121 [10:54<00:00,  5.41s/it]  


Epoch [11/75]


100%|██████████| 121/121 [00:31<00:00,  3.82it/s]


Epoch [12/75]


100%|██████████| 121/121 [00:32<00:00,  3.70it/s]


Epoch [13/75]


100%|██████████| 121/121 [00:33<00:00,  3.64it/s]


Epoch [14/75]


100%|██████████| 121/121 [00:37<00:00,  3.21it/s]


Epoch [15/75]


100%|██████████| 121/121 [00:34<00:00,  3.52it/s]


Epoch [16/75]


100%|██████████| 121/121 [00:35<00:00,  3.43it/s]


Epoch [17/75]


100%|██████████| 121/121 [00:35<00:00,  3.37it/s]


Epoch [18/75]


100%|██████████| 121/121 [00:36<00:00,  3.32it/s]


Epoch [19/75]


100%|██████████| 121/121 [00:35<00:00,  3.40it/s]


Epoch [20/75]


100%|██████████| 121/121 [00:35<00:00,  3.37it/s]


Epoch [21/75]


100%|██████████| 121/121 [00:34<00:00,  3.47it/s]


Epoch [22/75]


100%|██████████| 121/121 [00:34<00:00,  3.47it/s]


Epoch [23/75]


100%|██████████| 121/121 [00:35<00:00,  3.45it/s]


Epoch [24/75]


100%|██████████| 121/121 [00:35<00:00,  3.45it/s]


Epoch [25/75]


100%|██████████| 121/121 [00:35<00:00,  3.44it/s]


Epoch [26/75]


100%|██████████| 121/121 [00:35<00:00,  3.43it/s]


Epoch [27/75]


100%|██████████| 121/121 [00:35<00:00,  3.42it/s]


Epoch [28/75]


100%|██████████| 121/121 [00:35<00:00,  3.41it/s]


Epoch [29/75]


100%|██████████| 121/121 [00:35<00:00,  3.40it/s]


Epoch [30/75]


100%|██████████| 121/121 [00:35<00:00,  3.40it/s]


Epoch [31/75]


100%|██████████| 121/121 [00:35<00:00,  3.40it/s]


Epoch [32/75]


100%|██████████| 121/121 [00:35<00:00,  3.41it/s]


Epoch [33/75]


100%|██████████| 121/121 [00:35<00:00,  3.41it/s]


Epoch [34/75]


100%|██████████| 121/121 [00:35<00:00,  3.40it/s]


Epoch [35/75]


100%|██████████| 121/121 [00:35<00:00,  3.38it/s]


Epoch [36/75]


100%|██████████| 121/121 [00:35<00:00,  3.39it/s]


Epoch [37/75]


100%|██████████| 121/121 [00:35<00:00,  3.37it/s]


Epoch [38/75]


100%|██████████| 121/121 [00:36<00:00,  3.34it/s]


Epoch [39/75]


100%|██████████| 121/121 [00:37<00:00,  3.27it/s]


Epoch [40/75]


100%|██████████| 121/121 [00:37<00:00,  3.26it/s]


Epoch [41/75]


100%|██████████| 121/121 [00:36<00:00,  3.28it/s]


Epoch [42/75]


100%|██████████| 121/121 [01:20<00:00,  1.51it/s]


Epoch [43/75]


100%|██████████| 121/121 [00:33<00:00,  3.58it/s]


Epoch [44/75]


100%|██████████| 121/121 [00:34<00:00,  3.46it/s]


Epoch [45/75]


100%|██████████| 121/121 [00:35<00:00,  3.37it/s]


Epoch [46/75]


100%|██████████| 121/121 [00:35<00:00,  3.38it/s]


Epoch [47/75]


100%|██████████| 121/121 [00:36<00:00,  3.36it/s]


Epoch [48/75]


100%|██████████| 121/121 [00:36<00:00,  3.34it/s]


Epoch [49/75]


100%|██████████| 121/121 [00:37<00:00,  3.26it/s]


Epoch [50/75]


100%|██████████| 121/121 [00:38<00:00,  3.13it/s]


Epoch [51/75]


100%|██████████| 121/121 [00:38<00:00,  3.12it/s]


Epoch [52/75]


100%|██████████| 121/121 [00:38<00:00,  3.13it/s]


Epoch [53/75]


100%|██████████| 121/121 [00:38<00:00,  3.16it/s]


Epoch [54/75]


100%|██████████| 121/121 [00:38<00:00,  3.17it/s]


Epoch [55/75]


100%|██████████| 121/121 [00:38<00:00,  3.17it/s]


Epoch [56/75]


100%|██████████| 121/121 [00:38<00:00,  3.16it/s]


Epoch [57/75]


100%|██████████| 121/121 [00:38<00:00,  3.15it/s]


Epoch [58/75]


100%|██████████| 121/121 [00:38<00:00,  3.15it/s]


Epoch [59/75]


100%|██████████| 121/121 [00:38<00:00,  3.17it/s]


Epoch [60/75]


100%|██████████| 121/121 [00:38<00:00,  3.17it/s]


Epoch [61/75]


100%|██████████| 121/121 [00:37<00:00,  3.22it/s]


Epoch [62/75]


100%|██████████| 121/121 [00:37<00:00,  3.23it/s]


Epoch [63/75]


100%|██████████| 121/121 [00:39<00:00,  3.07it/s]


Epoch [64/75]


100%|██████████| 121/121 [00:37<00:00,  3.26it/s]


Epoch [65/75]


100%|██████████| 121/121 [00:45<00:00,  2.66it/s]


Epoch [66/75]


100%|██████████| 121/121 [00:42<00:00,  2.86it/s]


Epoch [67/75]


100%|██████████| 121/121 [00:41<00:00,  2.93it/s]


Epoch [68/75]


100%|██████████| 121/121 [00:39<00:00,  3.06it/s]


Epoch [69/75]


100%|██████████| 121/121 [00:39<00:00,  3.07it/s]


Epoch [70/75]


100%|██████████| 121/121 [00:39<00:00,  3.06it/s]


Epoch [71/75]


100%|██████████| 121/121 [00:43<00:00,  2.81it/s]


Epoch [72/75]


100%|██████████| 121/121 [00:41<00:00,  2.94it/s]


Epoch [73/75]


100%|██████████| 121/121 [00:40<00:00,  2.95it/s]


Epoch [74/75]


100%|██████████| 121/121 [00:41<00:00,  2.95it/s]


Epoch [75/75]


100%|██████████| 121/121 [00:41<00:00,  2.94it/s]


In [5]:
def check_accuracy(loader, model, dataset_type):
    """
    Checks the accuracy of the model on the given dataset loader.

    Parameters:
        loader: DataLoader
            The DataLoader for the dataset to check accuracy on.
        model: nn.Module
            The neural network model.
    """
    # if loader.dataset.train:
    #     print("Checking accuracy on training data")
    # else:
    #     print("Checking accuracy on test data")

    print(f"Checking accuracy on {dataset_type} data") 

    num_correct = 0
    num_samples = 0
    model.eval()  # Set the model to evaluation mode

    with torch.no_grad():  # Disable gradient calculation
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            # Forward pass: compute the model output
            scores = model(x)
            _, predictions = scores.max(1)  # Get the index of the max log-probability
            num_correct += (predictions == y).sum()  # Count correct predictions
            num_samples += predictions.size(0)  # Count total samples

        # Calculate accuracy
        accuracy = float(num_correct) / float(num_samples) * 100
        print(f"Got {num_correct}/{num_samples} with accuracy {accuracy:.2f}%")
    
    model.train()  # Set the model back to training mode

# Final accuracy check on training and test sets
check_accuracy(train_loader, model, 'training')
check_accuracy(test_loader, model, 'test')


Checking accuracy on training data
Got 481/481 with accuracy 100.00%
Checking accuracy on test data
Got 63/120 with accuracy 52.50%
