In [1]:
# ConvProbe class
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import random
import numpy as np

# Dataset
import sys
sys.path.append('/home/fonta42/Desktop/masters-degree/data/vess-map/')
from vess_map_dataset import VessMapDataset

# Training loop
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import jaccard_score  # For calculating IoU
import numpy as np

In [2]:
class ConvProbe(nn.Module):
    def __init__(self, layer_name, output_size=(256, 256), seed=None, single_channel=True):
        super(ConvProbe, self).__init__()

        # Load pre-trained ResNet-18 model
        self.model = models.resnet18(pretrained=True)
        self.layer_name = layer_name
        self.output_size = output_size
        self.single_channel = single_channel

        # Freeze model weights to prevent training
        for param in self.model.parameters():
            param.requires_grad = False

        # Get number of channels from the specified layer
        in_channels = 1 if self.single_channel else self.get_layer_channels(layer_name)


        # Initialize the convolutional probe (to be trained)
        self.conv_probe = nn.Conv2d(
            in_channels=in_channels,  # Use the layer's number of channels
            out_channels=1,  # Single output channel for binary segmentation
            kernel_size=3,
            padding=1  # Add padding to preserve spatial dimensions
        )

        # Set a fixed random seed for consistency, if provided
        if seed is not None:
            self.set_seed(seed)

        # Register a hook to capture activations from the specified layer
        self.activations = {}
        self.register_hook()

    # Helper function to set random seed
    def set_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Helper function to get the number of channels from the specified layer
    def get_layer_channels(self, layer_name):
        # Iterate over the named modules to find the specified layer
        for name, module in self.model.named_modules():
            if name == layer_name:
                if hasattr(module, 'out_channels'):
                    return module.out_channels
                elif hasattr(module, 'num_features'):
                    return module.num_features  # For BatchNorm layers
                else:
                    raise ValueError(f"Layer {layer_name} does not have out_channels or num_features.")
        raise ValueError(f"Layer {layer_name} not found.")

    # Register a forward hook to capture activations
    def register_hook(self):
        def hook_fn(module, input, output):
            self.activations[self.layer_name] = output

        # Register the hook on the specified layer
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                module.register_forward_hook(hook_fn)
                break

    # Forward pass through the model
    def forward(self, x, channel_idx=None):
        _ = self.model(x)  # Forward pass to get activations
        activation = self.activations[self.layer_name]  # Retrieve the stored activation

        # Select a specific channel if provided
        if channel_idx is not None:
            if channel_idx < 0 or channel_idx >= activation.size(1):
                raise ValueError(f"Channel index {channel_idx} is out of bounds for activation with {activation.size(1)} channels.")
            activation = activation[:, channel_idx:channel_idx+1, :, :]  # Select the specific channel

        # Interpolate activation to match the desired output size
        activation = F.interpolate(activation, size=self.output_size, mode='bilinear', align_corners=False)

        # Apply the convolutional probe to the activation
        out = self.conv_probe(activation)

        return out

In [3]:
# Define paths
image_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/images'
mask_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/labels'
skeleton_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/skeletons'

image_size = 256

# Initialize the dataset
vess_dataset = VessMapDataset(image_dir, mask_dir, skeleton_dir, image_size, apply_transform=True)

# Get the train and test loaders
train_loader, test_loader = vess_dataset.vess_map_dataloader(batch_size=100, train_size=0.8)

In [4]:
def calculate_iou(outputs, masks, threshold=0.5):
    # Apply sigmoid to outputs to get probabilities between 0 and 1
    preds = torch.sigmoid(outputs)
    preds = (preds > threshold).float()  # Convert probabilities to binary predictions

    # Ensure masks are float type
    masks = masks.float()

    # Compute intersection and union
    intersection = (preds * masks).sum(dim=(1, 2))
    union = ((preds + masks) > 0).float().sum(dim=(1, 2))

    # Avoid division by zero
    iou = torch.where(union == 0, torch.tensor(1.0).to(outputs.device), intersection / union)

    # Return mean IoU over the batch
    return iou.mean().item()

In [5]:
# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
# Training code
def train_conv_probe_all_channels(layer_name, num_channels, train_loader, test_loader, image_size, device, num_epochs=50, lr=0.001):
    iou_results = {}  # Dictionary to store IoU results for each channel

    for channel_idx in range(num_channels):
        print(f"Training for channel {channel_idx} of layer {layer_name}")

        # Initialize the ConvProbe model for the current channel
        probe_model = ConvProbe(layer_name=layer_name, output_size=(image_size, image_size), seed=42)
        probe_model.to(device)

        # Define loss function (BCEWithLogitsLoss for binary segmentation)
        criterion = nn.BCEWithLogitsLoss()

        # Define optimizer
        optimizer = optim.Adam(probe_model.conv_probe.parameters(), lr=lr)

        # Track the best IoU for this channel
        best_iou = 0.0
        best_iou_train = 0.0

        for epoch in range(num_epochs):
            probe_model.train()
            running_loss = 0.0
            train_iou = 0.0

            # Training loop
            for inputs, masks, _ in train_loader:
                inputs = inputs.to(device)
                masks = masks.to(device).float()  # Ensure masks are float type

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass using the current channel
                outputs = probe_model(inputs, channel_idx=channel_idx)
                outputs = outputs.squeeze(1)  # Shape: [batch_size, height, width]
                masks = masks.squeeze(1)      # Shape: [batch_size, height, width]

                # Compute loss
                loss = criterion(outputs, masks)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                # Accumulate running loss
                running_loss += loss.item() * inputs.size(0)

                # Compute IoU for this batch
                batch_iou = calculate_iou(outputs, masks)
                train_iou += batch_iou * inputs.size(0)

            # Calculate epoch loss and IoU
            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_iou = train_iou / len(train_loader.dataset)
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train IoU: {epoch_iou:.4f}')

            # Validation loop
            probe_model.eval()
            val_loss = 0.0
            val_iou = 0.0
            with torch.no_grad():  # Disable gradient calculations during validation
                for val_inputs, val_masks, _ in test_loader:
                    val_inputs = val_inputs.to(device)
                    val_masks = val_masks.to(device).float()  # Ensure masks are float type

                    # Forward pass using the current channel
                    val_outputs = probe_model(val_inputs, channel_idx=channel_idx)
                    val_outputs = val_outputs.squeeze(1)  # Shape: [batch_size, height, width]
                    val_masks = val_masks.squeeze(1)      # Shape: [batch_size, height, width]

                    # Compute loss
                    loss = criterion(val_outputs, val_masks)
                    val_loss += loss.item() * val_inputs.size(0)

                    # Compute IoU for this batch
                    batch_val_iou = calculate_iou(val_outputs, val_masks)
                    val_iou += batch_val_iou * val_inputs.size(0)

            # Calculate validation loss and IoU
            val_loss /= len(test_loader.dataset)
            val_iou /= len(test_loader.dataset)
            print(f'Validation Loss: {val_loss:.4f}, Validation IoU: {val_iou:.4f}')

            # Save the model if the IoU on validation set improves
            if val_iou > best_iou:
                best_iou = val_iou
                best_iou_train = epoch_iou
                torch.save(probe_model.state_dict(), f"best_model_layer_{layer_name}_channel_{channel_idx}.pth")

        # Save IoU metrics for this channel
        iou_results[channel_idx] = {'train_iou': best_iou_train, 'val_iou': best_iou}

    return iou_results

# Specify the layer name you want to probe and the number of channels in the layer
layer_name = 'layer1.0.conv1'
num_channels = 64  # Specify the correct number of channels for this layer

# Call the training function
iou_results = train_conv_probe_all_channels(layer_name, num_channels, train_loader, test_loader, image_size, device)

# The IoU results are now stored in the iou_results dictionary
iou_results


Training for channel 0 of layer layer1.0.conv1




Epoch 1/50, Loss: 0.6538, Train IoU: 0.1934
Validation Loss: 0.6429, Validation IoU: 0.0051
Epoch 2/50, Loss: 0.6525, Train IoU: 0.1898
Validation Loss: 0.6420, Validation IoU: 0.0051
Epoch 3/50, Loss: 0.6524, Train IoU: 0.1804
Validation Loss: 0.6413, Validation IoU: 0.0047
Epoch 4/50, Loss: 0.6515, Train IoU: 0.1741
Validation Loss: 0.6429, Validation IoU: 0.0052
Epoch 5/50, Loss: 0.6509, Train IoU: 0.1692
Validation Loss: 0.6412, Validation IoU: 0.0047
Epoch 6/50, Loss: 0.6498, Train IoU: 0.1600
Validation Loss: 0.6404, Validation IoU: 0.0042
Epoch 7/50, Loss: 0.6501, Train IoU: 0.1567
Validation Loss: 0.6397, Validation IoU: 0.0047
Epoch 8/50, Loss: 0.6489, Train IoU: 0.1456
Validation Loss: 0.6385, Validation IoU: 0.0047
Epoch 9/50, Loss: 0.6493, Train IoU: 0.1399
Validation Loss: 0.6375, Validation IoU: 0.0043
Epoch 10/50, Loss: 0.6485, Train IoU: 0.1361
Validation Loss: 0.6377, Validation IoU: 0.0047
Epoch 11/50, Loss: 0.6476, Train IoU: 0.1237
Validation Loss: 0.6387, Validatio

{0: {'train_iou': 0.019819479435682297, 'val_iou': 0.007086846511811018},
 1: {'train_iou': 0.3331749439239502, 'val_iou': 0.2583027482032776},
 2: {'train_iou': 0.22809568047523499, 'val_iou': 0.23373261094093323},
 3: {'train_iou': 0.023337652906775475, 'val_iou': 0.07883557677268982},
 4: {'train_iou': 0.0019053773721680045, 'val_iou': 0.0022562628146260977},
 5: {'train_iou': 0.0030571818351745605, 'val_iou': 0.002256092382594943},
 6: {'train_iou': 0.15609829127788544, 'val_iou': 0.23499742150306702},
 7: {'train_iou': 0.2236640900373459, 'val_iou': 0.22666750848293304},
 8: {'train_iou': 0.0024948965292423964, 'val_iou': 0.0022370938677340746},
 9: {'train_iou': 0.042495984584093094, 'val_iou': 0.028424764052033424},
 10: {'train_iou': 0.23182006180286407, 'val_iou': 0.2348329871892929},
 11: {'train_iou': 0.002046882873401046, 'val_iou': 0.002749454928562045},
 12: {'train_iou': 0.17671269178390503, 'val_iou': 0.23449666798114777},
 13: {'train_iou': 0.02316795289516449, 'val_io