In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
import random
from PIL import Image

## **Utilities**

In [None]:
def plot_loss(loss_history):
    """Draw the per-epoch average training loss as a line chart.

    Args:
        loss_history: Sequence of loss values, one entry per epoch. Entry ``k``
            is plotted at x-position ``k + 1`` so the axis reads as 1-based
            epoch numbers.
    """
    epoch_numbers = range(1, len(loss_history) + 1)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epoch_numbers, loss_history, marker='o', linestyle='-')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Average Loss")
    ax.set_title("Training Loss Over Epochs")
    ax.grid(True)
    plt.show()

class CustomDataset(Dataset):
    """Map-style dataset pairing pre-loaded samples with their labels."""

    def __init__(self, data, labels, transform=None):
        """
        Args:
            data (tensor): Input image data, indexed along its first dimension.
            labels (tensor): Corresponding labels, indexed the same way as ``data``.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Returns the number of samples, taken from ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Returns the ``(image, label)`` pair at ``idx``; only the image is transformed."""
        image = self.data[idx]
        label = self.labels[idx]

        # Compare against None rather than relying on truthiness: a callable
        # that defines __len__/__bool__ may evaluate as falsy and would
        # otherwise be skipped silently.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

def get_data_loader(batch_size=32, num_samples=1000, image_size=(3, 197, 135),
                    num_classes=5, num_workers=2):
    """Builds a shuffled DataLoader over randomly generated placeholder data.

    Args:
        batch_size (int): Number of samples per batch.
        num_samples (int): Number of fake images to generate.
        image_size (tuple): Shape of one fake image as (C, H, W).
        num_classes (int): Fake labels are drawn uniformly from ``[0, num_classes)``.
        num_workers (int): Worker processes used by the DataLoader.

    Returns:
        DataLoader yielding ``(images, labels)`` batches from a ``CustomDataset``.
    """
    # Augmentations only. The samples are already torch tensors, so no
    # ToTensor() step is used: that transform accepts PIL images / ndarrays
    # and raises TypeError when handed a tensor.
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),  # 50% chance of horizontal flip
        transforms.RandomRotation(degrees=(-10, 10)),  # Random rotation between -10 to 10 degrees
    ])

    # Example: Generate random tensor data (Replace with actual dataset)
    random_data = torch.rand(num_samples, *image_size)  # Fake image data in [0, 1)
    # NOTE(review): one class index per image. The CNN in this notebook emits a
    # per-pixel 2-channel map, which CrossEntropyLoss would pair with per-pixel
    # targets -- confirm the label layout when plugging in the real dataset.
    random_labels = torch.randint(0, num_classes, (num_samples,))  # Fake labels

    dataset = CustomDataset(random_data, random_labels, transform=transform)

    # Create DataLoader
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    return data_loader

## **Implement Model**


In [None]:
class CNN(nn.Module):
    """Fully convolutional network with a dilated-convolution trunk and one skip connection.

    Layout: two 3x3 convolutions, five dilated 3x3 convolutions, a channel-wise
    concatenation of the layer-2 and layer-7 feature maps, then three final
    convolutions that produce a 2-channel map of raw logits.

    Spatial sizes: every dilated layer uses a symmetric padding equal to its
    *width* dilation, so the width is preserved while the height shrinks by
    2, 4, 8, 16 and 32 rows respectively (62 rows in total). The input must
    therefore be at least 63 rows high, and an input of shape (N, 3, H, W)
    yields an output of shape (N, 2, H - 62, W).
    """

    def __init__(self):
        super().__init__()

        # Initial standard convolutional layers (spatial size preserved)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

        # Dilated convolutional layers; dilation is given as (height, width)
        self.dilated3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, dilation=(3, 2), padding=2)
        self.dilated4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, dilation=(6, 4), padding=4)
        self.dilated5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, dilation=(12, 8), padding=8)
        self.dilated6 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, dilation=(24, 16), padding=16)
        self.dilated7 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, dilation=(48, 32), padding=32)

        # Concatenation (32 + 128 = 160 channels) is handled in the forward pass
        self.conv9 = nn.Conv2d(in_channels=160, out_channels=128, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1)
        self.conv11 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)

    @staticmethod
    def _center_crop(feature_map, target_h, target_w):
        """Returns the central ``target_h`` x ``target_w`` window of the last two dims."""
        top = (feature_map.shape[-2] - target_h) // 2
        left = (feature_map.shape[-1] - target_w) // 2
        return feature_map[..., top:top + target_h, left:left + target_w]

    def forward(self, x):
        """Runs the network.

        Args:
            x (tensor): Batch of shape (N, 3, H, W).

        Returns:
            tensor: Logits of shape (N, 2, H', W'), where (H', W') is the
            spatial size left after the dilated layers.
        """
        # Initial conv layers
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))

        # Dilated convolutions
        d3 = F.relu(self.dilated3(x2))
        d4 = F.relu(self.dilated4(d3))
        d5 = F.relu(self.dilated5(d4))
        d6 = F.relu(self.dilated6(d5))
        d7 = F.relu(self.dilated7(d6))

        # The dilated trunk trims rows symmetrically from top and bottom, so the
        # skip branch is centre-cropped to the same window before concatenation;
        # torch.cat requires identical sizes in every dimension except dim=1.
        skip = x2
        if skip.shape[-2:] != d7.shape[-2:]:
            skip = self._center_crop(skip, d7.shape[-2], d7.shape[-1])

        # Concatenation of layer 2 and layer 7
        concat = torch.cat((skip, d7), dim=1)

        # Final convolutions
        x9 = F.relu(self.conv9(concat))
        x10 = F.relu(self.conv10(x9))
        x11 = self.conv11(x10)  # No activation as it's typically used for logits

        return x11

# Instantiate the network as the module-level `model` (later passed to
# train_model) and print its layer-by-layer summary.
model = CNN()
print(model)

CNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (dilated3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(3, 2))
  (dilated4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(6, 4))
  (dilated5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(8, 8), dilation=(12, 8))
  (dilated6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(16, 16), dilation=(24, 16))
  (dilated7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(32, 32), dilation=(48, 32))
  (conv9): Conv2d(160, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv10): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv11): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)


## **Training**

In [None]:
# Module-level training settings; train_model reads both as globals.
num_epochs = 16  # Total epochs (8 with lr=0.001, then 8 with lr=0.0001)
iteration = 55  # Upper bound on the number of batches processed per epoch

In [None]:
def train_model(model, train_loader, device):
    """Trains ``model`` with RMSprop and cross-entropy loss, then plots the loss curve.

    Runs for the module-level ``num_epochs`` epochs and processes at most the
    module-level ``iteration`` batches per epoch. The learning rate starts at
    0.001 and is lowered to 0.0001 from the ninth epoch onwards.

    Args:
        model (nn.Module): Network to optimise; moved to ``device`` in place.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        device (torch.device): Device on which the model and batches are placed.

    Returns:
        None. Progress is printed and the per-epoch average loss is passed to
        ``plot_loss``.
    """
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=0.001)

    loss_history = []  # Store loss per epoch

    model.to(device)

    for epoch in range(num_epochs):
        if epoch == 8:  # Reduce learning rate after first 8 epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.0001

        model.train()  # Set model to training mode
        running_loss = 0.0
        batches_seen = 0  # Batches actually processed this epoch

        for i, (inputs, labels) in enumerate(train_loader):
            if i >= iteration:  # Stop after the given number of iterations per epoch
                break

            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_seen += 1

            if (i + 1) % 10 == 0:  # Print every 10 iterations
                print(f"Epoch [{epoch+1}/{num_epochs}], Iteration [{i+1}/{iteration}], Loss: {loss.item():.4f}")

        # Average over the batches that really ran. The loader may yield fewer
        # than `iteration` batches, in which case dividing by `iteration`
        # would under-report the loss. NaN marks an epoch with no batches.
        avg_loss = running_loss / batches_seen if batches_seen else float("nan")
        loss_history.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] completed. Avg Loss: {avg_loss:.4f}")

    print("Training finished!")
    plot_loss(loss_history)  # Call the loss plot function

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data: DataLoader over the placeholder dataset built by get_data_loader.
train_loader = get_data_loader(batch_size=32)

# Train the module-level `model`; train_model prints progress and plots the loss.
train_model(model, train_loader, device)

## **Locate Fovea**

In [None]:
def gaussian_kernel_2d(size=3, sigma=1.0):
    """Builds a normalised ``size`` x ``size`` Gaussian weight matrix.

    Args:
        size (int): Side length of the square kernel.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        tensor: Kernel of shape (size, size) whose entries sum to 1.
    """
    # Integer offsets from the centre index, e.g. [-1, 0, 1] for size=3.
    offsets = torch.arange(size) - size // 2

    # Squared distance of every cell from the centre, built by broadcasting a
    # column of row offsets against a row of column offsets.
    squared_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2

    weights = torch.exp(-squared_dist / (2 * sigma**2))
    return weights / weights.sum()


def apply_gaussian_smoothing_2d(output, kernel_size=3, sigma=1.0):
    """Applies 2D Gaussian smoothing to each depth slice of a 3D network output.

    Args:
        output (tensor): Tensor of shape (D, H, W); every ``output[d]`` is
            smoothed independently.
        kernel_size (int): Side length of the Gaussian kernel. Padding is
            ``kernel_size // 2``, so odd sizes keep (H, W) unchanged.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        tensor: Smoothed tensor with depth D on the first dimension, on the
        same device as ``output``.

    Raises:
        ValueError: If ``output`` is not 3-dimensional.
    """
    if output.dim() != 3:
        raise ValueError(f"expected a (D, H, W) tensor, got shape {tuple(output.shape)}")

    kernel = gaussian_kernel_2d(kernel_size, sigma)

    # Integer inputs are promoted to the kernel's float type; casting the
    # kernel down to an integer type instead would zero out its weights.
    if not output.is_floating_point():
        output = output.to(kernel.dtype)

    # conv2d needs input and weight on the same device with the same dtype, so
    # the kernel follows the input (e.g. CUDA or float64 outputs).
    kernel = kernel.to(device=output.device, dtype=output.dtype)
    kernel = kernel.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, k, k)

    # Treat depth as the batch dimension of single-channel images so that all
    # slices are smoothed in one convolution call instead of a Python loop.
    smoothed = F.conv2d(output.unsqueeze(1), kernel, padding=kernel_size // 2)

    return smoothed.squeeze(1)  # Shape: (D, H, W) for odd kernel sizes

def fill_max_probability_color(smoothed_output):
    """Finds the max probability location and highlights it on a 2D grayscale image.

    Args:
        smoothed_output (tensor): Tensor of shape (D, H, W). Values are treated
            as intensities in [0, 1]; anything outside that range is clipped
            before conversion to 8-bit gray levels.

    Returns:
        tuple: ``(image, (max_x, max_y, max_depth))`` where ``image`` is an RGB
        PIL image of the max-projection with the peak pixel painted red, and
        the coordinates are plain Python ints (x = column, y = row, depth =
        index along D at which the peak occurs).
    """
    # Convert 3D (D, H, W) tensor to 2D by taking max along depth. detach() so
    # tensors that still track gradients can be converted to NumPy below.
    max_projection, max_depth = torch.max(smoothed_output.detach(), dim=0)  # Shape: (H, W)

    # Get max probability location; argmax indexes the flattened (H, W) map, so
    # divmod by the width recovers (row, column) as built-in ints.
    flat_idx = int(torch.argmax(max_projection))
    max_y, max_x = divmod(flat_idx, max_projection.shape[1])

    # Scale to the 0-255 range. Clip first: casting out-of-range floats to
    # uint8 does not saturate and would yield arbitrary gray levels.
    clipped = np.clip(max_projection.cpu().numpy(), 0.0, 1.0)
    gray_image = (clipped * 255).astype(np.uint8)

    # Convert grayscale to PIL image (a 2-D uint8 array maps to mode 'L')
    img = Image.fromarray(gray_image)

    # Highlight the max probability pixel with red
    img_colored = img.convert("RGB")  # Convert grayscale to RGB
    pixels = img_colored.load()
    pixels[max_x, max_y] = (255, 0, 0)  # PIL addresses pixels as (x, y)

    return img_colored, (max_x, max_y, int(max_depth[max_y, max_x]))

def process_output(output, kernel_size=3, sigma=1.0):
    """Smooths ``output`` slice by slice and marks its highest-valued pixel.

    Args:
        output (tensor): Tensor handed to ``apply_gaussian_smoothing_2d``.
        kernel_size (int): Side length of the Gaussian kernel.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        tuple: ``(color_image, max_pos)`` exactly as produced by
        ``fill_max_probability_color`` for the smoothed tensor.
    """
    smoothed = apply_gaussian_smoothing_2d(output, kernel_size, sigma)
    marked_image, peak_position = fill_max_probability_color(smoothed)
    return marked_image, peak_position

def show_image(image):
    """Displays an image by calling its ``show()`` method.

    Args:
        image: Object exposing ``show()``; the example cell in this file passes
            the PIL image returned by ``process_output``.
    """
    image.show()

In [None]:
# Example usage
# Random values in [0, 1) stand in for a (D, H, W) network output.
output = torch.rand(5, 5, 5)  # Example grayscale output tensor
# kernel_size=1 yields a 1x1 kernel equal to [[1.0]], so smoothing leaves the values unchanged here.
color_img, max_position = process_output(output, kernel_size=1, sigma=1.0)

# max_position is (x, y, depth) of the highest value.
print("Max Probability Position:", max_position)
show_image(color_img)  # Show the result using PIL

Max Probability Position: (2, 4, 0)
