In [1]:
!pip install torch torchvision imageio opencv-python-headless tifffile imagecodecs #installing all the required libraries

Collecting torch
  Downloading torch-2.4.0-cp312-none-macosx_11_0_arm64.whl.metadata (26 kB)
Collecting torchvision
  Downloading torchvision-0.19.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (6.0 kB)
Collecting opencv-python-headless
  Downloading opencv_python_headless-4.10.0.84-cp37-abi3-macosx_11_0_arm64.whl.metadata (20 kB)
Downloading torch-2.4.0-cp312-none-macosx_11_0_arm64.whl (62.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.1/62.1 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading torchvision-0.19.0-cp312-cp312-macosx_11_0_arm64.whl (1.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading opencv_python_headless-4.10.0.84-cp37-abi3-macosx_11_0_arm64.whl (54.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.8/54.8 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collec

In [2]:
import os

# Get the current working directory
current_dir = os.getcwd()

# Update the paths to point to the dataset directories
flori21_image_dir = os.path.join(current_dir, 'FLoRI21_DataPort_Extracted')
fire_image_dir = os.path.join(current_dir, 'FIRE_Extracted/Images')

# Print the paths to verify
print("FLoRI21 Image Directory:", flori21_image_dir)
print("FIRE Image Directory:", fire_image_dir)

FLoRI21 Image Directory: /Users/aryanmhalsank/Desktop/Retinal Classification/FLoRI21_DataPort_Extracted
FIRE Image Directory: /Users/aryanmhalsank/Desktop/Retinal Classification/FIRE_Extracted/Images


In [3]:
import os
from PIL import Image
import tifffile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as TF
import random

# Custom RandomZoom Transform
class RandomZoom(object):
    def __init__(self, zoom_range):
        self.zoom_range = zoom_range

    def __call__(self, img):
        # Define the zoom factor
        zoom_factor = random.uniform(1 - self.zoom_range, 1 + self.zoom_range)
        
        # Resize and crop
        width, height = img.size
        new_width = int(width * zoom_factor)
        new_height = int(height * zoom_factor)
        
        # Resize the image
        img = TF.resize(img, (new_height, new_width))
        
        # Center crop
        new_width, new_height = img.size
        crop_size = min(new_width, new_height)
        img = TF.center_crop(img, crop_size)
        
        return img

# Dataset definition with increased augmentation
class RetinalImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_files = []
        for root, dirs, files in os.walk(image_dir):
            for file in files:
                if file.lower().endswith(('.jpg', '.tif', '.tiff', '.png')):
                    self.image_files.append(os.path.join(root, file))
        self.transform = transform
        print(f"Found {len(self.image_files)} images in {self.image_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        if img_path.lower().endswith(('.tif', '.tiff')):
            image = tifffile.imread(img_path)
            image = Image.fromarray(image).convert("L")  # Convert to grayscale
        else:
            image = Image.open(img_path).convert("L")  # Convert to grayscale
        
        if self.transform:
            image = self.transform(image)
    
        return image

# Path to the datasets
flori21_image_dir = '/Users/aryanmhalsank/Desktop/Retinal Classification/FLoRI21_DataPort'
fire_image_dir = '/Users/aryanmhalsank/Desktop/Retinal Classification/FIRE/Images'

# Data augmentation and normalization
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    RandomZoom(0.1),
    transforms.ToTensor(),  # Converts to tensor with values in [0, 1]
])

# Load the datasets
flori21_dataset = RetinalImageDataset(image_dir=flori21_image_dir, transform=transform)
fire_dataset = RetinalImageDataset(image_dir=fire_image_dir, transform=transform)

flori21_dataloader = DataLoader(flori21_dataset, batch_size=16, shuffle=True)
fire_dataloader = DataLoader(fire_dataset, batch_size=16, shuffle=True)

# Enhanced NN Model with residual connections and dropout
class EnhancedRetinalImageModel(nn.Module):
    def __init__(self):
        super(EnhancedRetinalImageModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.ConvTranspose2d(32, 1, kernel_size=2, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

# Initialize the model, loss function, and optimizer with weight decay
model = EnhancedRetinalImageModel()
reconstruction_criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-3, step_size_up=5, mode='triangular')

# Training function
def train_model(dataloader, model, reconstruction_criterion, optimizer, scheduler, num_epochs=20):
    losses = []
    accuracies = []
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        correct = 0
        total = 0

        for images in dataloader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = reconstruction_criterion(outputs, images)
            loss.backward()
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            
            predicted = (outputs > 0.5).float()
            total += images.numel()
            correct += (predicted == images).sum().item()

        accuracy = correct / total
        losses.append(epoch_loss / len(dataloader))
        accuracies.append(accuracy * 100)

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(dataloader):.4f}, Accuracy: {accuracy * 100:.2f}%")

    return losses, accuracies

# Train the model
num_epochs = 20  # Increase number of epochs for better convergence
flori21_losses, flori21_accuracies = train_model(flori21_dataloader, model, reconstruction_criterion, optimizer, scheduler, num_epochs=num_epochs)
fire_losses, fire_accuracies = train_model(fire_dataloader, model, reconstruction_criterion, optimizer, scheduler, num_epochs=num_epochs)

# Plot loss and accuracy
epochs = range(1, num_epochs + 1)
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(epochs, flori21_losses, 'r', label='FLoRI21 Loss')
plt.plot(epochs, fire_losses, 'b', label='FIRE Loss')
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, flori21_accuracies, 'r', label='FLoRI21 Accuracy')
plt.plot(epochs, fire_accuracies, 'b', label='FIRE Accuracy')
plt.title('Training Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()


Found 20 images in /Users/aryanmhalsank/Desktop/Retinal Classification/FLoRI21_DataPort
Found 268 images in /Users/aryanmhalsank/Desktop/Retinal Classification/FIRE/Images


RuntimeError: stack expects each tensor to be equal size, but got [1, 248, 248] at entry 0 and [1, 230, 230] at entry 1