In [None]:
import pandas as pd
import os
from PIL import Image, ImageOps
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
from torchvision.models import EfficientNet_B0_Weights
import torch.autograd.profiler as profiler

In [None]:
# Location of the spreadsheet that holds the refractive error measurements
excel_file_path = './ODOCS RED REFLEX DATABASE/Choithram Netralaya Data/acuityvalues.xlsx'

# Read the spreadsheet into a DataFrame
acuity_data = pd.read_excel(io=excel_file_path)

# Print the leading rows as a quick sanity check of what was read
print(acuity_data.head(n=5))

In [None]:
# Root folder of the image dataset, given relative to the notebook's location
dataset_path = './ODOCS RED REFLEX DATABASE/Choithram Netralaya Data/Images'

# File suffixes that are treated as images
image_extensions = [
    '.jpg',
    '.jpeg',
    '.png',
]

# Function to load and organize images, checking for both raw and cropped versions
def _leading_date(filename):
    """Return the leading ``YYYY-MM-DD`` stamp of ``filename``.

    Falls back to the text before the first '-' when the name does not start
    with such a stamp.
    """
    import re  # function-scope import keeps this cell runnable on its own

    stamp = re.match(r'\d{4}-\d{2}-\d{2}', filename)
    return stamp.group(0) if stamp else filename.split('-')[0]


def load_images_from_directory(directory_path):
    """Recursively load every image under ``directory_path``.

    A file counts as an image when its lower-cased name ends with one of
    ``image_extensions``. Folders and files are visited in sorted order so the
    result is reproducible across filesystems.

    Args:
        directory_path: Root folder to walk.

    Returns:
        A tuple ``(images, image_metadata)`` of two parallel lists: the loaded
        PIL images and, for each one, a dict with the keys ``filename``,
        ``eye_type``, ``date_info``, ``is_cropped`` and ``path``.

    Notes:
        Each image is fully decoded and its file is closed before the next one
        is opened. ``Image.open`` on its own is lazy and keeps the file open,
        which runs out of file descriptors on large datasets. The price is
        that all decoded pixel data is held in memory.
        Files that cannot be opened or decoded are reported and skipped.
    """
    images = []
    image_metadata = []

    for root, dirs, files in os.walk(directory_path):
        dirs.sort()  # sorting in place steers os.walk, making traversal order deterministic

        if not files:
            print(f"Skipping empty folder: {root}")
            continue  # Nothing to load directly inside this folder

        for file in sorted(files):
            # Ignore anything that is not an image
            if not file.lower().endswith(tuple(image_extensions)):
                continue

            file_path = os.path.join(root, file)

            try:
                # The context manager closes the file; load() decodes the
                # pixels first so the image stays usable afterwards.
                with Image.open(file_path) as image:
                    image.load()
            except Exception as e:
                print(f"Error loading image {file}: {e}")
                continue

            images.append(image)
            image_metadata.append({
                'filename': file,
                # "OD" in the name marks the right eye; everything else is treated as left ("OS")
                'eye_type': "OD" if "OD" in file else "OS",
                # Full date stamp, e.g. 2022-12-08 (previously only the year part was kept)
                'date_info': _leading_date(file),
                # File names ending in 's.jpg' are flagged as the cropped version
                'is_cropped': file.endswith('s.jpg'),
                'path': file_path
            })

            print(f"Loaded {file} from {root}")

    return images, image_metadata

# Walk the dataset folder and collect every image together with its metadata
images, metadata = load_images_from_directory(directory_path=dataset_path)

# Report how many images were collected
print(f"Total images loaded: {len(images)}")

In [None]:
# Function to map images to acuity data using the folder name for patient ID
def map_images_to_acuity(images_metadata, acuity_data):
    """Attach sphere/cylinder labels to image metadata.

    The patient ID is taken from the name of the folder that directly contains
    the image and is matched against the ``patient`` column of ``acuity_data``.
    Right-eye images (``eye_type == 'OD'``) read the ``r sphere`` /
    ``r cylinder`` columns; all other images read ``l sphere`` / ``l cylinder``.

    Args:
        images_metadata: Iterable of dicts with at least the keys ``path``,
            ``filename`` and ``eye_type``.
        acuity_data: DataFrame with the columns ``patient``, ``r sphere``,
            ``r cylinder``, ``l sphere`` and ``l cylinder``.

    Returns:
        A list of new dicts (the input dicts are not modified), each a copy of
        the input entry extended with ``sphere`` and ``cylinder``. Entries are
        skipped, with a printed message, when the folder name is not an
        integer, when no row matches the patient, or when either value is
        missing. A missing (NaN) label would otherwise turn the training loss
        into NaN.
    """
    mapped_data = []

    for meta in images_metadata:
        # The folder that directly contains the image is named after the patient number
        folder_name = os.path.basename(os.path.dirname(meta['path']))

        try:
            patient_id = int(folder_name)  # Convert folder name to integer for matching
        except ValueError:
            print(f"Skipping {meta['filename']} as folder name '{folder_name}' is not a valid patient ID.")
            continue

        patient_data = acuity_data[acuity_data['patient'] == patient_id]
        if patient_data.empty:
            print(f"Skipping {meta['filename']}: no acuity data for patient {patient_id}.")
            continue

        if meta['eye_type'] == 'OD':
            sphere_col, cylinder_col = 'r sphere', 'r cylinder'
        else:
            sphere_col, cylinder_col = 'l sphere', 'l cylinder'

        # If several rows share the patient ID, the first one wins
        sphere = patient_data[sphere_col].values[0]
        cylinder = patient_data[cylinder_col].values[0]

        if pd.isna(sphere) or pd.isna(cylinder):
            print(f"Skipping {meta['filename']}: missing sphere/cylinder value for patient {patient_id}.")
            continue

        # Build a new dict so re-running the cell never sees half-updated input metadata
        mapped_data.append({**meta, 'sphere': sphere, 'cylinder': cylinder})

    return mapped_data

# Attach the refractive error labels to the image metadata
mapped_images = map_images_to_acuity(images_metadata=metadata, acuity_data=acuity_data)

# Show the first few labelled entries
print(mapped_images[0:5])


In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Preprocessing pipeline applied to every image before it reaches the model
transform = transforms.Compose(
    [
        # Force every image to 299x299.
        # NOTE(review): torchvision's EfficientNet-B0 ImageNet weights are documented
        # with 224x224 crops - confirm that 299 is the intended size.
        transforms.Resize((299, 299)),
        # PIL image -> float tensor
        transforms.ToTensor(),
        # Per-channel normalization with the mean/std values below
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Custom dataset class to handle our images and labels
class RedReflexDataset(Dataset):
    """Dataset of red-reflex images paired with sphere and cylinder labels.

    Args:
        images_metadata: Sequence of dicts, each providing ``path`` (image file
            location) plus the numeric labels ``sphere`` and ``cylinder``.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, images_metadata, transform=None):
        self.images_metadata = images_metadata
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.images_metadata)

    def __getitem__(self, idx):
        """Return ``(image, sphere, cylinder)`` for sample ``idx``.

        ``sphere`` and ``cylinder`` are float32 tensors of shape ``(1,)``.

        Raises:
            RuntimeError: If the image cannot be opened or decoded. Returning
                ``None`` placeholders instead would only defer the failure to
                the DataLoader's default collate step, which cannot batch
                ``None`` and reports an error that hides the real cause.
        """
        record = self.images_metadata[idx]
        img_path = record['path']

        try:
            # convert() yields an independent RGB copy, so the file can be closed right away
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
        except Exception as e:
            raise RuntimeError(f"Error loading image {img_path}: {e}") from e

        # Apply transformations if they exist
        if self.transform:
            image = self.transform(image)

        # Labels are wrapped in a list so each one is a 1-element tensor
        sphere = torch.tensor([record['sphere']], dtype=torch.float32)
        cylinder = torch.tensor([record['cylinder']], dtype=torch.float32)

        return image, sphere, cylinder

# Load the dataset
dataset = RedReflexDataset(mapped_images, transform=transform)

# Batches are assembled in the main process (num_workers=0).
# Page-locked host memory only speeds up host-to-GPU copies, so it is requested
# only when training on CUDA; on a CPU-only machine it would be a no-op.
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    pin_memory=(device.type == "cuda"),
)

# Load EfficientNet-B0 with the IMAGENET1K_V1 weights
# (EfficientNet_B0_Weights is already imported in the notebook's import cell)
model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)

# Replace the last classifier layer with a 2-output head: one value for sphere, one for cylinder
model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=2)

# Move the model to the selected device (GPU when available, otherwise CPU)
model = model.to(device)

# Print model architecture
print(model)


In [None]:
# Regression objective: mean squared error between predictions and targets
criterion = nn.MSELoss(reduction='mean')

# Adam optimizer over all model parameters
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Gradient scaler used by the mixed precision training loop
scaler = torch.amp.GradScaler('cuda')

In [None]:
# Training loop
epochs = 5  # Number of training epochs

# Mixed precision is only used on CUDA; on CPU the autocast context is a no-op
use_amp = device.type == "cuda"

for epoch in range(epochs):
    model.train()  # Set model to training mode

    running_loss = 0.0
    # The loop variable is deliberately not called `images`, so the image list
    # created by the loading cell is not overwritten by a batch tensor.
    for batch_images, spheres, cylinders in dataloader:
        batch_images = batch_images.to(device)
        # The dataset yields labels with a trailing singleton dimension; flatten
        # them to one value per sample to match the per-sample predictions.
        # Otherwise MSELoss broadcasts the two shapes against each other and
        # averages over every prediction/target pair.
        spheres = spheres.to(device).reshape(-1)
        cylinders = cylinders.to(device).reshape(-1)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass with mixed precision (same torch.amp API family as the GradScaler)
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(batch_images)
            # Column 0 is the sphere prediction, column 1 the cylinder prediction
            predicted_spheres, predicted_cylinders = outputs[:, 0], outputs[:, 1]

            # Calculate losses
            loss_sphere = criterion(predicted_spheres, spheres)
            loss_cylinder = criterion(predicted_cylinders, cylinders)
            loss = loss_sphere + loss_cylinder

        # Backward pass and optimization with scaled gradients
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Track loss
        running_loss += loss.item()

    # Average of the per-batch losses for this epoch
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(dataloader):.4f}")

print("Training complete!")