In [1]:
# ConvProbe class
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import random
import numpy as np

# Dataset
import sys
sys.path.append('/home/fonta42/Desktop/masters-degree/data/vess-map/')
from vess_map_dataset import VessMapDataset

# Training loop
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import jaccard_score  # For calculating IoU
import numpy as np

In [2]:
class ConvProbe(nn.Module):
    """Trainable 3x3 convolutional probe on top of a frozen, pre-trained ResNet-18.

    A forward hook captures the output of the backbone layer called
    ``layer_name``. That activation is bilinearly resized to ``output_size``
    and passed through a single ``nn.Conv2d`` with two output channels, which
    is the only part of the module that is meant to be trained.

    Args:
        layer_name: Dotted module name inside the ResNet-18 (as yielded by
            ``named_modules()``), e.g. ``'layer1.0.conv1'``.
        output_size: ``(height, width)`` the hooked activation is resized to
            before the probe convolution is applied.
        seed: Seed applied to ``random``, NumPy and PyTorch before the probe
            weights are created.

    Raises:
        ValueError: If ``layer_name`` does not exist in the backbone or its
            channel count cannot be determined.
    """

    def __init__(self, layer_name, output_size=(256, 256), seed=42):
        super(ConvProbe, self).__init__()

        # Load pre-trained ResNet-18 model
        self.model = models.resnet18(pretrained=True)
        self.layer_name = layer_name
        self.output_size = output_size

        # Freeze model weights
        for param in self.model.parameters():
            param.requires_grad = False
        # Frozen weights are not enough: BatchNorm must also stay in eval mode,
        # otherwise it normalises with batch statistics and rewrites its
        # running statistics on every forward pass.
        self.model.eval()

        # Seed BEFORE the probe is created so that its initial weights are
        # actually governed by `seed` (seeding afterwards has no effect on them).
        self.set_seed(seed)

        # Initialize the convolutional probe (to be trained)
        self.conv_probe = nn.Conv2d(
            in_channels=self.get_layer_channels(layer_name),  # Use the layer channels
            out_channels=2,  # For segmentation output
            kernel_size=3,
            padding=1  # Add padding to preserve input size
        )

        # Register a hook for the selected layer
        self.activations = {}
        self.register_hook()

    def set_seed(self, seed):
        """Seed Python, NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def train(self, mode=True):
        """Switch the probe between train/eval mode while keeping the backbone in eval mode.

        ``nn.Module.train`` recurses into every child, which would also put the
        frozen backbone's BatchNorm layers into training mode. The backbone is
        therefore forced back to eval mode after the default behaviour ran.
        """
        super().train(mode)
        self.model.eval()
        return self

    @staticmethod
    def _own_channels(module):
        """Return the channel count a conv / batch-norm layer emits, or None for other modules."""
        if isinstance(module, nn.Conv2d):
            return module.out_channels
        if isinstance(module, nn.BatchNorm2d):
            # BatchNorm2d has no `out_channels`; its channel count is `num_features`
            return module.num_features
        return None

    def get_layer_channels(self, layer_name):
        """Return the number of channels produced by the backbone layer ``layer_name``.

        * ``nn.Conv2d`` / ``nn.BatchNorm2d``: read from the module itself.
        * ``nn.ReLU``: has no channel attribute, so the value of the closest
          convolution / batch norm registered before it is used (heuristic
          based on registration order; it matches the ResNet layout).
        * ``nn.Sequential``: the last convolution / batch norm found anywhere
          inside it is used.

        Raises:
            ValueError: If the layer is missing, of an unsupported type, or no
                channel count can be inferred for it.
        """
        preceding = None  # channels of the last conv / batch norm seen so far
        for name, module in self.model.named_modules():
            if name != layer_name:
                channels = self._own_channels(module)
                if channels is not None:
                    preceding = channels
                continue

            channels = self._own_channels(module)
            if channels is not None:
                return channels

            if isinstance(module, nn.ReLU):
                if preceding is None:
                    raise ValueError(
                        f"Cannot infer channels for layer {layer_name}: "
                        "no convolution or batch norm precedes it."
                    )
                return preceding

            if isinstance(module, nn.Sequential):
                # Walk the contained modules back to front to find the layer
                # that determines the container's output channels.
                for sub_module in reversed(list(module.modules())):
                    channels = self._own_channels(sub_module)
                    if channels is not None:
                        return channels
                raise ValueError(
                    f"Cannot infer channels for layer {layer_name}: "
                    "it contains no convolution or batch norm."
                )

            raise ValueError(f"Unsupported module type: {type(module)} for layer {layer_name}")
        raise ValueError(f"Layer {layer_name} not found in the model.")

    def register_hook(self):
        """Attach a forward hook that stores the selected layer's output in ``self.activations``."""
        def hook_fn(module, inputs, output):
            self.activations[self.layer_name] = output

        # Register the hook on the first module whose name matches
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                # Keep the handle so the hook can be removed later if needed
                self._hook_handle = module.register_forward_hook(hook_fn)
                break

    def forward(self, x):
        """Run the backbone on ``x`` and return the probe output for the hooked activation.

        The returned tensor has two channels and the spatial size
        ``self.output_size``.
        """
        _ = self.model(x)  # Forward pass; only needed to trigger the hook
        activation = self.activations[self.layer_name]  # Get hooked activation

        # Resize the activation to the requested output size
        activation = F.interpolate(
            activation, size=self.output_size, mode='bilinear', align_corners=False
        )

        # Apply convolutional probe
        return self.conv_probe(activation)

In [3]:
# Size value handed to the dataset; reused further down as the probe's output size
image_size = 256

# Locations of the VessMap images, label masks and skeletons on disk
image_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/images'
mask_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/labels'
skeleton_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/skeletons'

# Build the dataset with its transforms switched on
vess_dataset = VessMapDataset(
    image_dir,
    mask_dir,
    skeleton_dir,
    image_size,
    apply_transform=True,
)

# Ask the dataset for its train / test loaders
# (presumably train_size=0.8 means an 80/20 split; confirm in VessMapDataset)
train_loader, test_loader = vess_dataset.vess_map_dataloader(
    batch_size=4,
    train_size=0.8,
)

In [4]:
# Define IoU calculation function
def calculate_iou(preds, labels):
    """Return the macro-averaged IoU (Jaccard score) between predictions and labels.

    Args:
        preds: Tensor of per-class scores with the class dimension at index 1;
            the arg-max over that dimension is taken as the predicted class.
        labels: Tensor of target classes with the same number of elements as
            the arg-maxed predictions. It is cast to integers exactly like the
            training loss does (``.long()``), so the metric and the loss always
            look at the same targets even when the masks are stored as floats.

    Returns:
        The score returned by ``jaccard_score(..., average='macro')`` over all
        elements of the batch.
    """
    # Collapse the class dimension into hard class predictions
    pred_classes = torch.argmax(preds, dim=1)
    # Same integer cast as used for CrossEntropyLoss; avoids handing
    # continuous-valued float targets to scikit-learn
    label_classes = labels.long()

    preds_flat = pred_classes.detach().cpu().numpy().flatten()
    labels_flat = label_classes.detach().cpu().numpy().flatten()

    return jaccard_score(labels_flat, preds_flat, average='macro')

In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU,
# and show the choice as the cell output
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Backbone layer whose activations are probed
layer_name = 'layer1.0.conv1'

# Initialize the ConvProbe model
probe_model = ConvProbe(layer_name=layer_name, output_size=(image_size, image_size), seed=42)

# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
probe_model.to(device)

# Define loss function (Cross Entropy Loss for multi-class segmentation)
criterion = nn.CrossEntropyLoss()

# Only the probe convolution is optimised; the backbone stays frozen
optimizer = optim.Adam(probe_model.conv_probe.parameters(), lr=0.001)

# Number of epochs
num_epochs = 5

# Track the best validation IoU seen so far
best_iou = 0.0

for epoch in range(num_epochs):
    probe_model.train()
    # train() recurses into the frozen backbone and would switch its BatchNorm
    # layers to batch statistics (and overwrite their running statistics).
    # Keep the backbone in eval mode so training and validation see the same features.
    probe_model.model.eval()

    running_loss = 0.0
    train_iou = 0.0
    train_samples = 0  # samples actually iterated this epoch

    # Training loop
    for inputs, masks, _ in train_loader:  # Assuming skeletons are not needed
        inputs = inputs.to(device)
        masks = masks.to(device)
        batch_size = inputs.size(0)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = probe_model(inputs)

        # Compute loss (use masks with long type for CrossEntropyLoss)
        masks = masks.squeeze(1)  # Shape: [batch_size, height, width]
        loss = criterion(outputs, masks.long())

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * batch_size

        # Compute IoU for this batch (no gradient needed for the metric)
        batch_iou = calculate_iou(outputs.detach(), masks)
        train_iou += batch_iou * batch_size
        train_samples += batch_size

    # Average over the samples that were really seen. len(loader.dataset) can
    # differ from that when the loader draws from a subset via a sampler.
    epoch_loss = running_loss / max(train_samples, 1)
    epoch_iou = train_iou / max(train_samples, 1)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train IoU: {epoch_iou:.4f}')

    # Validation loop
    probe_model.eval()
    val_loss = 0.0
    val_iou = 0.0
    val_samples = 0
    with torch.no_grad():  # Disable gradient calculations during validation
        for val_inputs, val_masks, _ in test_loader:  # Assuming skeletons are not needed
            val_inputs = val_inputs.to(device)
            val_masks = val_masks.to(device)
            val_masks = val_masks.squeeze(1)  # Shape: [batch_size, height, width]
            val_batch_size = val_inputs.size(0)

            # Forward pass
            val_outputs = probe_model(val_inputs)

            # Compute loss
            val_loss += criterion(val_outputs, val_masks.long()).item() * val_batch_size

            # Compute IoU for this batch
            batch_val_iou = calculate_iou(val_outputs, val_masks)
            val_iou += batch_val_iou * val_batch_size
            val_samples += val_batch_size

    # Calculate validation loss and IoU over the samples actually evaluated
    val_loss /= max(val_samples, 1)
    val_iou /= max(val_samples, 1)
    print(f'Validation Loss: {val_loss:.4f}, Validation IoU: {val_iou:.4f}')

    # Save the model if the IoU on validation set improves
    if val_iou > best_iou:
        best_iou = val_iou
        torch.save(probe_model.state_dict(), f"best_model_layer_{layer_name}.pth")

print('Training complete')