In [None]:
import os

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from PIL import Image

In [None]:
# Choose the compute device once; later cells move the model and inputs onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Define Model

In [None]:
class Concatenate(nn.Module):
    """Concatenate two feature maps along the channel dimension (dim=1)."""

    def forward(self, x1, x2):
        return torch.cat([x1, x2], dim=1)

class ResBlock(nn.Module):
    """
    Two-branch convolutional block (3x3 and 5x5) with a residual connection.

    Each branch maps ``in_channels`` to ``out_channels // 2`` channels via
    Conv -> BatchNorm -> LeakyReLU -> Dropout. The two branch outputs are
    concatenated along dim=1 and added to a skip path. The skip path is a 1x1
    convolution when the channel counts differ, otherwise the identity.

    NOTE: the concatenated branches have ``2 * (out_channels // 2)`` channels,
    so ``out_channels`` must be even for the residual addition to line up.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor (expected even).
        dropout_prob: dropout probability at the end of each branch.
    """

    def __init__(self, in_channels, out_channels, dropout_prob=0.3):
        super(ResBlock, self).__init__()

        branch_channels = out_channels // 2

        # 3x3 convolution branch (stride 1, padding 1 keeps the spatial size)
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, 3, 1, 1),
            nn.BatchNorm2d(branch_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_prob)
        )

        # 5x5 convolution branch (stride 1, padding 2 keeps the spatial size)
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, 5, 1, 2),
            nn.BatchNorm2d(branch_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_prob)
        )

        # Skip connection; a 1x1 conv adapts the channel count when needed
        self.residual = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

        self.concatenate = Concatenate()

    def get_concatenated_features(self, x):
        """
        Return both branch outputs concatenated along dim=1, *before* the
        residual addition. This used to be a nested function inside
        ``__init__``, so it was never reachable as a method.
        """
        out3x3 = self.branch3x3(x)
        out5x5 = self.branch5x5(x)
        return self.concatenate(out3x3, out5x5)

    def forward(self, x):
        """Concatenated branch features plus the (possibly projected) input."""
        return self.get_concatenated_features(x) + self.residual(x)

class Aletheia4Net(nn.Module):
    """
    Three-class image classifier.

    Architecture: a stack of ResBlocks with 2x2 max-pooling, then global average
    pooling, then a fully connected head. Channels grow
    3 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512 across five max-pools, so a
    256x256 input reaches the last ResBlock at 8x8.

    Args:
        dropout_prob: dropout probability for the fully connected head only.
            The ResBlocks are built with their own default.
    """

    def __init__(self, dropout_prob=0.3):
        super(Aletheia4Net, self).__init__()

        # Convolutional layers with residual blocks and max-pooling
        self.conv_layers = nn.Sequential(
            ResBlock(3, 16),
            nn.MaxPool2d(2),
            ResBlock(16, 32),
            nn.MaxPool2d(2),
            ResBlock(32, 64),
            nn.MaxPool2d(2),
            ResBlock(64, 128),
            nn.MaxPool2d(2),
            ResBlock(128, 256),
            nn.MaxPool2d(2),
            ResBlock(256, 512)
        )

        # Global Average Pooling: (N, 512, H, W) -> (N, 512, 1, 1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected head producing 3 raw class scores
        self.fc_layers = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.Dropout(dropout_prob),
            nn.Linear(512, 3)
        )

    def get_last_resblock_output(self, x):
        """
        Run ``x`` through every conv layer except the last ResBlock. Then return
        that block's two branch outputs concatenated along dim=1, i.e. its
        features before the residual addition.
        """
        for layer in self.conv_layers[:-1]:
            x = layer(x)

        last_block = self.conv_layers[-1]
        # Same concatenation ResBlock.forward performs, minus the skip path.
        return torch.cat([last_block.branch3x3(x), last_block.branch5x5(x)], dim=1)

    def feature_size(self, image_size=256):
        """
        Return the flattened size of the ``conv_layers`` output for a
        3-channel ``image_size`` x ``image_size`` input.

        How the probe runs:
        - without autograd;
        - on the model's own device (a CPU tensor would fail on a CUDA model);
        - in eval mode, so the all-zero probe does not update BatchNorm
          running statistics.

        Every submodule's training flag is restored afterwards.
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            param_device = next(self.parameters()).device
            probe = torch.zeros(1, 3, image_size, image_size, device=param_device)
            with torch.no_grad():
                return self.conv_layers(probe).view(1, -1).size(1)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)  # (N, 512, 1, 1) -> (N, 512)
        return self.fc_layers(x)

    @property
    def last_conv_layer(self):
        """The first Conv2d of the last ResBlock's 3x3 branch (used as the Grad-CAM target)."""
        return self.conv_layers[-1].branch3x3[0]

In [None]:
# Load Aletheia4 model.
# Checkpoint location: override it with the ALETHEIA_MODEL_PATH environment
# variable instead of editing this relative default.
model_path = os.environ.get(
    "ALETHEIA_MODEL_PATH",
    "../../../SyntheticEyeLocal/StateDicts/Aletheia/4_0/model_epoch_18.pth",
)

model = Aletheia4Net().to(device)

# map_location remaps tensors saved on another device, e.g. a CUDA checkpoint
# opened on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))

# This notebook only runs inference: disable dropout and use the BatchNorm
# running statistics.
model.eval()

## Prepare Data

In [None]:
# Per-channel normalization statistics in R, G, B order (images are converted to
# RGB before transforming). Presumably computed on the training data — TODO confirm
# they match the values used when the checkpoint was trained.
mean, std = [0.499, 0.415, 0.372], [0.245, 0.223, 0.220]

In [None]:
# Both pipelines rescale so the shorter side is 304 px and then center-crop 256x256.
# They differ only in whether the pixel values are normalized.
def _resize_and_center_crop():
    """Build fresh resize + center-crop steps shared by both pipelines."""
    return [A.SmallestMaxSize(max_size=304), A.CenterCrop(256, 256)]

# Model input: resize/crop, normalize with `mean`/`std`, convert to a CHW tensor.
test_albumentations_transform = A.Compose(
    _resize_and_center_crop() + [A.Normalize(mean=mean, std=std), ToTensorV2()]
)

# Display copy: identical geometry, but un-normalized pixel values.
original_albumentations_transform = A.Compose(
    _resize_and_center_crop() + [ToTensorV2()]
)

In [None]:
def apply_test_transforms(image_path):
    """
    Load an image from disk and run it through the model's test-time pipeline.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        The ``"image"`` entry produced by ``test_albumentations_transform``
        (a CHW tensor, since that pipeline ends with ToTensorV2).
    """
    with Image.open(image_path) as raw:
        rgb_pixels = np.array(raw.convert("RGB"))
    return test_albumentations_transform(image=rgb_pixels)["image"]

In [None]:
def apply_original_transforms(image_path):
    """
    Load an image and apply only the display pipeline (no normalization).

    Uses ``original_albumentations_transform``: the same resize and center-crop
    as the test pipeline but without ``A.Normalize``. The result therefore lines
    up pixel-for-pixel with the model input and can be shown under a Grad-CAM
    overlay.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        The ``"image"`` entry produced by ``original_albumentations_transform``
        (a CHW tensor, since that pipeline ends with ToTensorV2).
    """
    with Image.open(image_path) as raw:
        rgb_pixels = np.array(raw.convert("RGB"))
    return original_albumentations_transform(image=rgb_pixels)["image"]

## Load and Show Image

In [None]:
# Sample image to explain. Override it with the SYNTHETICEYE_SAMPLE_IMAGE
# environment variable instead of editing this machine-specific default path.
file_path = os.environ.get(
    "SYNTHETICEYE_SAMPLE_IMAGE",
    "C:\\Users\\jacob\\OneDrive\\Desktop\\SyntheticEye\\SampleData\\GAN\\1.jpg",
)
img = Image.open(file_path)
img

In [None]:
# Model input: the normalized, batched tensor the classifier is evaluated on.
img_transformed = apply_test_transforms(file_path).unsqueeze(0).to(device)
# Display copy: the same resize/crop without normalization, shown under the gradient mask.
img_original = apply_original_transforms(file_path).unsqueeze(0).to(device)

## Get Prediction on Image

In [None]:
# Inference only: eval mode (dropout off, BatchNorm running stats) and no
# autograd graph, since nothing here backpropagates.
model = model.to(device)
model.eval()
with torch.no_grad():
    scores = model(img_transformed)

# Softmax over the class dimension turns the raw scores into per-class probabilities.
probabilities = torch.softmax(scores, dim=1).squeeze()

# Index of the most probable class.
predicted_class = torch.argmax(probabilities).item()

probabilities_list = probabilities.tolist()

print(probabilities_list)
print(f"Predicted class index: {predicted_class}")

In [None]:
def generate_grad_cam(model, img_path, target_layer, class_of_interest=None):
    """
    Compute a Grad-CAM heatmap for a single image.

    Steps:
    1. Preprocess the image with ``apply_test_transforms``.
    2. Use hooks to capture ``target_layer``'s forward output (activations A)
       and the gradient of the chosen class score with respect to that output
       (dScore/dA).
    3. Weight each channel of A by its spatially averaged gradient.
    4. Average the weighted channels, pass the result through ReLU, and scale
       it so the maximum is 1.

    Args:
        model: classifier whose output is indexed as ``scores[0, class]``.
        img_path: path to the input image.
        target_layer: module inside ``model`` whose output is explained.
            Assumes a (1, C, H, W) output, as with a Conv2d.
        class_of_interest: class index to explain; defaults to the top-scoring class.

    Returns:
        2-D numpy array with the spatial size of ``target_layer``'s output.
        Values lie in [0, 1]; the map is all zeros if no location contributes
        positively.

    Raises:
        RuntimeError: if ``target_layer`` did not take part in the forward and
            backward pass.
    """
    input_img = apply_test_transforms(img_path).unsqueeze(0).to(device)

    captured = {}

    def _save_gradient(grad):
        captured["gradients"] = grad.detach()

    def _forward_hook(module, inputs, output):
        captured["activations"] = output.detach()
        # Grad-CAM needs the gradient of this layer's OUTPUT (activations),
        # not the gradient of its weights.
        output.register_hook(_save_gradient)

    handle = target_layer.register_forward_hook(_forward_hook)
    try:
        model.eval()
        scores = model(input_img)

        if class_of_interest is None:
            # argmax of the raw scores equals argmax of their softmax
            class_of_interest = torch.argmax(scores[0]).item()

        model.zero_grad()
        scores[0, class_of_interest].backward()
    finally:
        # Always detach the hook, even if the forward/backward pass raises.
        handle.remove()

    if "activations" not in captured or "gradients" not in captured:
        raise RuntimeError("target_layer did not participate in the forward/backward pass")

    activations = captured["activations"]
    gradients = captured["gradients"]

    # One weight per channel: the gradient averaged over the spatial dimensions.
    channel_weights = gradients.mean(dim=(2, 3), keepdim=True)

    # Weighted channels averaged into a single (H, W) map for the one image.
    grad_cam = (channel_weights * activations).mean(dim=1)[0]

    # Keep only locations that push the class score up.
    grad_cam = F.relu(grad_cam)

    # Normalize to [0, 1]; skip when the map is all zeros to avoid 0/0 = NaN.
    peak = grad_cam.max()
    if peak > 0:
        grad_cam = grad_cam / peak

    return grad_cam.cpu().numpy()

In [None]:
# Grad-CAM heatmap for the predicted class, taken at the first conv of the last ResBlock.
grad_cam_raw = generate_grad_cam(model, file_path, model.last_conv_layer)

# Un-normalized resized/cropped image, as HWC for display.
original_img = apply_original_transforms(file_path).cpu().numpy().transpose(1, 2, 0)

# Upsample the coarse heatmap to the display image's size.
# nan_to_num guards the uint8 cast against NaN from an all-zero map.
heatmap_u8 = np.uint8(255 * np.nan_to_num(grad_cam_raw))
grad_cam_map = np.array(
    Image.fromarray(heatmap_u8).resize(
        (original_img.shape[1], original_img.shape[0]), Image.LANCZOS
    )
)

# Overlay the heatmap on the image.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(original_img, alpha=0.6)
ax.imshow(grad_cam_map, cmap='jet', alpha=0.4)
ax.set_title("Grad-CAM overlay")
ax.axis('off')
plt.show()