In [2]:
# ConvProbe class
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import random
import numpy as np

# Dataset
import sys
sys.path.append('../data/vess-map/')
from vess_map_dataset import VessMapDataset

# Training loop
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Save the results
import json
import time

# Clear memory
import gc

In [2]:
class ConvProbe(nn.Module):
    def __init__(self, model, layer_name, output_size=(256, 256), seed=None, single_channel=True):
        super(ConvProbe, self).__init__()

        # Load pre-trained ResNet-18 model
        self.model =  model
        self.layer_name = layer_name
        self.output_size = output_size
        self.single_channel = single_channel

        # Freeze model weights to prevent training
        for param in self.model.parameters():
            param.requires_grad = False

        # Get number of channels from the specified layer
        in_channels = 1 if self.single_channel else self.get_layer_channels(layer_name)


        # Initialize the first convolutional layer (to be trained)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,  # Numbers of channels for the activation
            out_channels=1,  
            kernel_size=3,
            padding=1  # Preserve spatial dimensions
        )

        # Initialize the second convolutional layer (to be trained)
        self.conv2 = nn.Conv2d(
            in_channels=1,  # Input channels from conv1
            out_channels=1,  # Single output channel for binary segmentation
            kernel_size=3,
            padding=1  # Preserve spatial dimensions
        )

        # Set a fixed random seed for consistency, if provided
        if seed is not None:
            self.set_seed(seed)

        # Register a hook to capture activations from the specified layer
        self.activations = {}
        self.register_hook()

    # Helper function to set random seed
    def set_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Helper function to get the number of channels from the specified layer
    def get_layer_channels(self, layer_name):
        # Iterate over the named modules to find the specified layer
        for name, module in self.model.named_modules():
            if name == layer_name:
                if hasattr(module, 'out_channels'):
                    return module.out_channels
                elif hasattr(module, 'num_features'):
                    return module.num_features  # For BatchNorm layers
                else:
                    raise ValueError(f"Layer {layer_name} does not have out_channels or num_features.")
        raise ValueError(f"Layer {layer_name} not found.")

    # Register a forward hook to capture activations
    def register_hook(self):
        def hook_fn(module, input, output):
            self.activations[self.layer_name] = output

        # Register the hook on the specified layer
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                module.register_forward_hook(hook_fn)
                break

    # Forward pass through the model
    def forward(self, x, channel_idx=None):
        _ = self.model(x)  # Forward pass to get activations
        activation = self.activations[self.layer_name]  # Retrieve the stored activation

        # Select a specific channel if provided
        if channel_idx is not None:
            if channel_idx < 0 or channel_idx >= activation.size(1):
                raise ValueError(f"Channel index {channel_idx} is out of bounds for activation with {activation.size(1)} channels.")
            activation = activation[:, channel_idx:channel_idx+1, :, :]  # Select the specific channel

        # Apply the first convolutional layer
        activation = self.conv1(activation)

        # Interpolate activation to match the desired output size
        activation = F.interpolate(activation, size=self.output_size, mode='bilinear', align_corners=True)

        # Apply the second convolutional layer
        out = self.conv2(activation) 
        
        return out

In [3]:
# Define paths
image_dir = '../data/vess-map/images'
mask_dir = '../data/vess-map/labels'
skeleton_dir = '../data/vess-map/skeletons'

image_size = 256

# Initialize the dataset
vess_dataset = VessMapDataset(image_dir, mask_dir, skeleton_dir, image_size, apply_transform=True)

# Get the train and test loaders
train_loader, test_loader = vess_dataset.vess_map_dataloader(batch_size=80, train_size=0.8)

In [5]:
def calculate_iou(outputs, masks, threshold=0.5):
    # Apply sigmoid to outputs to get probabilities between 0 and 1
    preds = torch.sigmoid(outputs)
    preds = (preds > threshold).float()  # Convert probabilities to binary predictions

    # Ensure masks are float type
    masks = masks.float()

    # Compute intersection and union
    intersection = (preds * masks).sum(dim=(1, 2))
    union = ((preds + masks) > 0).float().sum(dim=(1, 2))

    # Avoid division by zero
    iou = torch.where(union == 0, torch.tensor(1.0).to(outputs.device), intersection / union)

    # Return mean IoU over the batch
    return iou.mean().item()

In [6]:
# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
def train_conv_probe_all_channels(model, layer_name, num_channels, train_loader, test_loader, image_size, device, num_epochs=50, lr=0.001, single_channel=True):
    iou_results = {}  # Dictionary to store IoU results for each channel

    for channel_idx in range(num_channels):
        print(f"Training for channel {channel_idx} of layer {layer_name}")
        probe_model = ConvProbe(model=model, layer_name=layer_name, output_size=(
            image_size, image_size), seed=42, single_channel=single_channel)
        probe_model.to(device)

        # BinaryCEWithLogitsLoss for binary segmentation
        criterion = nn.BCEWithLogitsLoss()

        # Define optimizer for ConvProbe parameters only
        optimizer = optim.Adam(list(probe_model.conv1.parameters(
        )) + list(probe_model.conv2.parameters()), lr=lr)

        # Track metrics for this channel
        train_losses = []
        train_ious = []
        val_losses = []
        val_ious = []

        # Get a sample image from the validation loader for saving outputs
        sample_inputs, sample_masks, _ = next(iter(test_loader))
        sample_inputs = sample_inputs.to(device)[:1]  # Take one sample
        sample_masks = sample_masks.to(device).float()[:1]

        for epoch in range(num_epochs):
            probe_model.train()
            running_loss = 0.0
            train_iou = 0.0

            # Training loop
            for inputs, masks, _ in train_loader:
                inputs = inputs.to(device)
                masks = masks.to(device).float()  # Ensure masks are float type

                # Zero the parameter gradients
                optimizer.zero_grad()

                if single_channel:
                    # Forward pass using the current channel
                    outputs = probe_model(inputs, channel_idx=channel_idx)
                else:
                    outputs = probe_model(inputs)

                # Shape: [batch_size, height, width]
                outputs = outputs.squeeze(1)
                # Shape: [batch_size, height, width]
                masks = masks.squeeze(1)

                # Compute loss
                loss = criterion(outputs, masks)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                # Accumulate running loss
                running_loss += loss.item() * inputs.size(0)

                # Compute IoU for this batch
                batch_iou = calculate_iou(outputs, masks)
                train_iou += batch_iou * inputs.size(0)

            # Calculate epoch loss and IoU
            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_iou = train_iou / len(train_loader.dataset)
            if epoch % (num_epochs - 1) == 0:
                print(
                    f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train IoU: {epoch_iou:.4f}')

            # Append metrics
            train_losses.append(epoch_loss)
            train_ious.append(epoch_iou)

            # Validation loop
            probe_model.eval()
            val_loss = 0.0
            val_iou = 0.0
            with torch.no_grad():
                for val_inputs, val_masks, _ in test_loader:
                    val_inputs = val_inputs.to(device)
                    # Ensure masks are float type
                    val_masks = val_masks.to(device).float()

                    # Forward pass using the current channel
                    if single_channel:
                        val_outputs = probe_model(
                            val_inputs, channel_idx=channel_idx)
                    else:
                        val_outputs = probe_model(val_inputs)

                    # Shape: [batch_size, height, width]
                    val_outputs = val_outputs.squeeze(1)
                    # Shape: [batch_size, height, width]
                    val_masks = val_masks.squeeze(1)

                    # Compute loss
                    loss = criterion(val_outputs, val_masks)
                    val_loss += loss.item() * val_inputs.size(0)

                    # Compute IoU for this batch
                    batch_val_iou = calculate_iou(val_outputs, val_masks)
                    val_iou += batch_val_iou * val_inputs.size(0)

            # Calculate validation loss and IoU
            val_loss /= len(test_loader.dataset)
            val_iou /= len(test_loader.dataset)
            if epoch % (num_epochs - 1) == 0:
                print(
                f'Validation Loss: {val_loss:.4f}, Validation IoU: {val_iou:.4f}')

            # Append validation metrics
            val_losses.append(val_loss)
            val_ious.append(val_iou)

            # Save model output on sample image for debugging
            if epoch == num_epochs - 1 and val_iou < 0.0001:
                with torch.no_grad():
                    if single_channel:
                        sample_output = probe_model(
                            sample_inputs, channel_idx=channel_idx)
                    else:
                        sample_output = probe_model(
                            sample_inputs)
                    sample_output = sample_output.squeeze(1)
                    sample_output_np = sample_output.cpu().numpy()[0]
                    # Save the image
                    plt.imsave(f"./conv-probe-debbug/output_layer_{layer_name}_channel_{channel_idx}_epoch_{epoch+1}.png",
                               sample_output_np,
                               cmap='gray')

        # Save per-epoch metrics for this channel
        iou_results[channel_idx] = {
            'train_loss': train_losses,
            'train_iou': train_ious,
            'val_loss': val_losses,
            'val_iou': val_ious
        }

        # Free GPU memory after training for the current channel is completed
        del probe_model, optimizer, criterion
        torch.cuda.empty_cache()
        gc.collect()
        if not single_channel:
            break

    del train_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()

    return iou_results

In [7]:
# Function to get layer names and number of channels from ResNet-18
def get_resnet18_layers_info():
    resnet18 = models.resnet18(pretrained=True)
    layers_info = {}
    for name, module in resnet18.named_modules():
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
            if hasattr(module, 'out_channels'):
                out_channels = module.out_channels
            elif hasattr(module, 'num_features'):
                out_channels = module.num_features
            else:
                continue
            layers_info[name] = out_channels
    return layers_info

In [8]:
# Function to save results to a JSON file
def save_results_to_json(results, filename):
    # Convert tensors or numpy types to native Python types
    def convert(o):
        if isinstance(o, np.float32) or isinstance(o, np.float64):
            return float(o)
        if isinstance(o, torch.Tensor):
            return o.item()
        raise TypeError

    with open(filename, 'w') as f:
        json.dump(results, f, default=convert)

In [None]:
import time
import json

# Function to train the probe for all layers and collect results
def train_conv_probe_all_layers(train_loader, test_loader, image_size, device, num_epochs=20, lr=0.001):
    all_layers_iou_results = {}
    time_results = {}  # Dictionary to store time taken for each layer
    layers_info = get_resnet18_layers_info()
    
    start_total_time = time.time()  # Start timing for the entire process
    
    for layer_name, num_channels in layers_info.items():
        print(f"\nProcessing layer: {layer_name} with {num_channels} channels")
        
        # Start timing for each layer
        start_layer_time = time.time()
        
        # Declare the model to use
        model = models.resnet18(pretrained=True)
        
        # Train the conv probe for the current layer
        iou_results = train_conv_probe_all_channels(
            model=model,
            layer_name=layer_name,
            num_channels=num_channels,
            train_loader=train_loader,
            test_loader=test_loader,
            image_size=image_size,
            device=device,
            num_epochs=num_epochs,
            lr=lr
        )
        
        # Save individual layer results
        save_results_to_json(iou_results, f'./conv-probe-layers-results/{layer_name}_activation_results.json')
        all_layers_iou_results[layer_name] = iou_results
        
        # Calculate time taken for this layer and add to time_results
        layer_time_taken = time.time() - start_layer_time
        time_results[layer_name] = layer_time_taken
        print(f"Time taken for {layer_name}: {layer_time_taken:.2f} seconds")
        
        # Free GPU memory
        del model
        torch.cuda.memory_reserved(device=device)
        torch.cuda.memory_allocated(device=device)
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()  # Collect unused GPU memory

    # Calculate and store the total time taken
    total_time_taken = time.time() - start_total_time
    time_results['total_time'] = total_time_taken
    print(f"Total time taken: {total_time_taken:.2f} seconds")
    
    # Save all layers' IoU results and time results to JSON
    save_results_to_json(all_layers_iou_results, './conv-probe-layers-results/resnet18_layers_iou_results.json')
    with open('./conv-probe-layers-results/time_results.json', 'w') as f:
        json.dump(time_results, f, indent=4)

    return all_layers_iou_results, time_results

In [11]:
# Call the training function
all_layers_iou_results = train_conv_probe_all_layers(
    train_loader=train_loader,
    test_loader=test_loader,
    image_size=image_size,
    device=device,
    num_epochs=150,
    lr=0.01
)


Processing layer: layer4.1.bn1 with 512 channels
Training for channel 0 of layer layer4.1.bn1




Epoch 1/150, Loss: 0.7216, Train IoU: 0.2320
Validation Loss: 0.7136, Validation IoU: 0.2544
Epoch 150/150, Loss: 0.5400, Train IoU: 0.0000
Validation Loss: 0.5714, Validation IoU: 0.0000
Training for channel 1 of layer layer4.1.bn1
Epoch 1/150, Loss: 0.7173, Train IoU: 0.2320
Validation Loss: 0.7068, Validation IoU: 0.2547
Epoch 150/150, Loss: 0.5378, Train IoU: 0.0000
Validation Loss: 0.5686, Validation IoU: 0.0000
Training for channel 2 of layer layer4.1.bn1
Epoch 1/150, Loss: 0.7174, Train IoU: 0.2320
Validation Loss: 0.7069, Validation IoU: 0.2545
Epoch 150/150, Loss: 0.5392, Train IoU: 0.0000
Validation Loss: 0.5706, Validation IoU: 0.0000
Training for channel 3 of layer layer4.1.bn1
Epoch 1/150, Loss: 0.7176, Train IoU: 0.2320
Validation Loss: 0.7071, Validation IoU: 0.2545
Epoch 150/150, Loss: 0.5397, Train IoU: 0.0000
Validation Loss: 0.5699, Validation IoU: 0.0000
Training for channel 4 of layer layer4.1.bn1
Epoch 1/150, Loss: 0.7174, Train IoU: 0.2320
Validation Loss: 0.7071

In [12]:
""" import segmentation_models_pytorch as smp

def train_unet_model(train_loader, val_loader, device, num_epochs=50, lr=0.001):
    # Define the model
    model = smp.Unet(
        encoder_name="resnet18",        # Choose encoder, e.g. resnet18
        encoder_weights=None,           # Use NONE pre-trained weights for encoderFalse
        in_channels=3,                  # Input channels (3 for RGB images)
        classes=1,                      # Output channels (1 for binary segmentation)
        activation=None                 # No activation function on output
    )
    model.to(device)

    # Define loss function
    criterion = nn.BCEWithLogitsLoss()

    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, masks, _ in train_loader:
            inputs = inputs.to(device)
            masks = masks.to(device).float()

            optimizer.zero_grad()

            outputs = model(inputs)
            outputs = outputs.squeeze(1)
            masks = masks.squeeze(1)

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

        # You can add validation and metric calculation similar to previous code

    # Save the trained model
    torch.save(model.state_dict(), './models/conv_probe_unet_vess_map.pth')

    return model

unet_model = train_unet_model(train_loader=train_loader, val_loader=test_loader, device=device, num_epochs=200, lr=0.001) """

' import segmentation_models_pytorch as smp\n\ndef train_unet_model(train_loader, val_loader, device, num_epochs=50, lr=0.001):\n    # Define the model\n    model = smp.Unet(\n        encoder_name="resnet18",        # Choose encoder, e.g. resnet18\n        encoder_weights=None,           # Use NONE pre-trained weights for encoderFalse\n        in_channels=3,                  # Input channels (3 for RGB images)\n        classes=1,                      # Output channels (1 for binary segmentation)\n        activation=None                 # No activation function on output\n    )\n    model.to(device)\n\n    # Define loss function\n    criterion = nn.BCEWithLogitsLoss()\n\n    # Define optimizer\n    optimizer = optim.Adam(model.parameters(), lr=lr)\n\n    # Training loop\n    for epoch in range(num_epochs):\n        model.train()\n        running_loss = 0.0\n\n        for inputs, masks, _ in train_loader:\n            inputs = inputs.to(device)\n            masks = masks.to(device).float

In [13]:
""" layer_name = 'decoder.blocks.4.conv2.0' #TODO: check if this is the best layer to test, the results are OK

num_channels = 16

# Initialize ConvProbe with the UNet model
iou_results = train_conv_probe_all_channels(
            model=unet_model,
            layer_name=layer_name,
            num_channels=num_channels,
            train_loader=train_loader,
            test_loader=test_loader,
            image_size=image_size,
            device=device,
            num_epochs=100,
            lr=0.001,
            single_channel=False) """

" layer_name = 'decoder.blocks.4.conv2.0' #TODO: check if this is the best layer to test, the results are OK\n\nnum_channels = 16\n\n# Initialize ConvProbe with the UNet model\niou_results = train_conv_probe_all_channels(\n            model=unet_model,\n            layer_name=layer_name,\n            num_channels=num_channels,\n            train_loader=train_loader,\n            test_loader=test_loader,\n            image_size=image_size,\n            device=device,\n            num_epochs=100,\n            lr=0.001,\n            single_channel=False) "