In [1]:
import sys
sys.path.append('/home/fonta42/Desktop/masters-degree/data/vess-map/')
from vess_map_dataset import VessMapDataset

import torch
from torchvision import models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import random

In [2]:
# Root folder of the VessMap data; the image, label and skeleton folders live inside it,
# so relocating the dataset only requires changing this one line.
# NOTE(review): absolute, machine-specific path (it is also hard-coded in the
# sys.path.append call of the first cell) — consider making it configurable.
data_root = '/home/fonta42/Desktop/masters-degree/data/vess-map'
image_dir = f'{data_root}/images'
mask_dir = f'{data_root}/labels'
skeleton_dir = f'{data_root}/skeletons'

# Passed to VessMapDataset; presumably the target side length in pixels — confirm in the dataset class.
image_size = 256

# Instantiate the dataset
dataset = VessMapDataset(image_dir, mask_dir, skeleton_dir, image_size)

# The dataset's image collection; element 0 is used as the sample input further down.
images = dataset.images

In [3]:
# Load an ImageNet-pretrained ResNet-18 and switch it to evaluation mode
# (e.g. BatchNorm layers use their running statistics instead of batch statistics).
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; update once the installed version is confirmed.
model = models.resnet18(pretrained=True)
model.eval()

# Output of every hooked submodule from the most recent forward pass, keyed by its
# dotted module name. Overwritten on each forward pass; if a module is invoked more
# than once per pass, only its last output is kept.
activations = {}


def get_activation(name):
    """Return a forward hook that stores the module's output in ``activations[name]``.

    Args:
        name: Key under which the output is stored (the module's dotted name).

    Returns:
        A callable with the ``(module, inputs, output)`` signature expected by
        ``nn.Module.register_forward_hook``.
    """
    def hook(module, inputs, output):
        # detach() so the stored tensor does not hold on to any autograd graph.
        activations[name] = output.detach()
    return hook


# Register a hook on every submodule. named_modules() also yields the model itself
# under the empty name ''; it is skipped because its output is just the final logits
# (already returned by the forward call) and an empty key is confusing downstream.
# The handles are kept so the hooks can be removed later with `handle.remove()`.
hook_handles = [
    layer.register_forward_hook(get_activation(name))
    for name, layer in model.named_modules()
    if name
]



In [4]:
# Torchvision's pretrained ResNet-18 expects RGB input normalized with the
# ImageNet channel statistics below; without Normalize the recorded activations
# come from inputs outside the distribution the weights were trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Image -> float tensor (ToTensor scales 8-bit images to [0, 1]) -> ImageNet-normalized tensor.
# NOTE(review): assumes 3-channel images, which resnet18's first convolution requires anyway.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Sample input; it is also passed to plot_activations below as the reference image.
img = images[0]  # presumably a PIL Image — confirm against VessMapDataset

# To tensor
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0)  # Add batch dimension

# Run on the GPU when one is available; the model and the input must share a device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_batch = input_batch.to(device)
model.to(device)

# Forward pass; the registered forward hooks fill `activations` as a side effect.
with torch.no_grad():
    output = model(input_batch)

In [None]:
from collections import defaultdict


def _group_layer_names(layer_names):
    """Bucket module names by their top-level block.

    A name beginning with 'layer' (e.g. 'layer1.0.conv1') is filed under its first
    dotted component ('layer1'); any other name forms a group of its own.
    Returns a defaultdict(list) mapping group name -> member names in input order.
    """
    buckets = defaultdict(list)
    for full_name in layer_names:
        key = full_name.split('.')[0] if full_name.startswith('layer') else full_name
        buckets[key].append(full_name)
    return buckets


layer_groups = _group_layer_names(activations.keys())

# Print each group header followed by its indented members, groups in sorted order.
for group_key in sorted(layer_groups):
    print(f"{group_key}:")
    for member in layer_groups[group_key]:
        print(f"  {member}")

In [6]:
def plot_activations(activations, original_image, num_activations=5, layer_types=None, rng=None):
    """Plot randomly sampled channel activations for every 4-D layer output.

    For each entry of ``activations`` whose tensor is 4-D, one row of panels is drawn.
    The row holds ``num_activations`` randomly chosen channels (fewer if the layer has
    fewer channels) of the first batch item, with ``original_image`` inserted in the
    middle of the row for reference.

    Args:
        activations: Mapping from layer name to its recorded output tensor. Entries
            that are not 4-D are skipped.
        original_image: Image drawn with a gray colormap in the middle panel.
        num_activations: Maximum number of channels to display per layer.
        layer_types: Optional iterable of substrings. When given, only layers whose
            name contains at least one of them are plotted.
        rng: Optional ``random.Random`` used to pick channels, for reproducible
            figures. Defaults to the module-level ``random`` functions.
    """
    sampler = random if rng is None else rng
    for layer_name, activation in activations.items():
        # Only plot if layer_name contains any of the layer_types
        if layer_types and not any(layer_type in layer_name for layer_type in layer_types):
            continue
        # Only 4-D outputs can be drawn as per-channel images.
        if activation.dim() != 4:
            continue

        num_channels = activation.size(1)
        channel_indices = sampler.sample(range(num_channels), min(num_activations, num_channels))

        # One panel per sampled channel, plus the original image in the middle
        # (None marks its slot). Every requested channel is therefore actually shown.
        panels = list(channel_indices)
        panels.insert(len(panels) // 2, None)

        # squeeze=False keeps `axes` 2-D even for a single panel, so indexing cannot
        # fail on layers with only one channel.
        fig, axes = plt.subplots(1, len(panels), figsize=(3 * len(panels), 5), squeeze=False)
        for ax, channel_idx in zip(axes[0], panels):
            if channel_idx is None:
                ax.imshow(original_image, cmap='gray')
                ax.set_title('Original Image')
            else:
                act = activation[0, channel_idx].cpu().numpy()
                mappable = ax.imshow(act, cmap='RdYlGn')
                ax.set_title(f'Layer: {layer_name}\nChannel: {channel_idx}')
                plt.colorbar(mappable, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        # Release the figure explicitly; one call can produce many figures.
        plt.close(fig)

In [None]:
# Convolution outputs only: every recorded module whose name contains 'conv'.
plot_activations(
    activations,
    original_image=img,
    num_activations=5,
    layer_types=['conv'],
)

In [None]:
# No name filter: every recorded module whose output is 4-D gets a row of panels.
plot_activations(activations, original_image=img, num_activations=5, layer_types=None)