In [26]:
# ConvProbe class
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import random
import numpy as np

# Dataset
import sys
sys.path.append('/home/fonta42/Desktop/masters-degree/data/vess-map/')
from vess_map_dataset import VessMapDataset

# Training loop
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import jaccard_score  # For calculating IoU
import numpy as np
import matplotlib.pyplot as plt

# Save the results
import json

# Clear memory
import gc

In [27]:
class ConvProbe(nn.Module):
    def __init__(self, model, layer_name, output_size=(256, 256), seed=None, single_channel=True):
        super(ConvProbe, self).__init__()

        # Load pre-trained ResNet-18 model
        self.model =  model
        self.layer_name = layer_name
        self.output_size = output_size
        self.single_channel = single_channel

        # Freeze model weights to prevent training
        for param in self.model.parameters():
            param.requires_grad = False

        # Get number of channels from the specified layer
        in_channels = 1 if self.single_channel else self.get_layer_channels(layer_name)


        # Initialize the first convolutional layer (to be trained)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,  # Numbers of channels for the activation
            out_channels=16,  # TODO: check if 16 is ok
            kernel_size=3,
            padding=1  # Preserve spatial dimensions
        )

        # Initialize the second convolutional layer (to be trained)
        self.conv2 = nn.Conv2d(
            in_channels=16,  # Input channels from conv1
            out_channels=1,  # Single output channel for binary segmentation
            kernel_size=3,
            padding=1  # Preserve spatial dimensions
        )

        # Set a fixed random seed for consistency, if provided
        if seed is not None:
            self.set_seed(seed)

        # Register a hook to capture activations from the specified layer
        self.activations = {}
        self.register_hook()

    # Helper function to set random seed
    def set_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Helper function to get the number of channels from the specified layer
    def get_layer_channels(self, layer_name):
        # Iterate over the named modules to find the specified layer
        for name, module in self.model.named_modules():
            if name == layer_name:
                if hasattr(module, 'out_channels'):
                    return module.out_channels
                elif hasattr(module, 'num_features'):
                    return module.num_features  # For BatchNorm layers
                else:
                    raise ValueError(f"Layer {layer_name} does not have out_channels or num_features.")
        raise ValueError(f"Layer {layer_name} not found.")

    # Register a forward hook to capture activations
    def register_hook(self):
        def hook_fn(module, input, output):
            self.activations[self.layer_name] = output

        # Register the hook on the specified layer
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                module.register_forward_hook(hook_fn)
                break

    # Forward pass through the model
    def forward(self, x, channel_idx=None):
        _ = self.model(x)  # Forward pass to get activations
        activation = self.activations[self.layer_name]  # Retrieve the stored activation

        # Select a specific channel if provided
        if channel_idx is not None:
            if channel_idx < 0 or channel_idx >= activation.size(1):
                raise ValueError(f"Channel index {channel_idx} is out of bounds for activation with {activation.size(1)} channels.")
            activation = activation[:, channel_idx:channel_idx+1, :, :]  # Select the specific channel

        # Apply the first convolutional layer
        activation = self.conv1(activation)

        # Interpolate activation to match the desired output size
        activation = F.interpolate(activation, size=self.output_size, mode='bilinear', align_corners=True)

        # Apply the second convolutional layer
        out = self.conv2(activation) # TODO: add a second conv
        
        return out

In [28]:
# Define paths
image_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/images'
mask_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/labels'
skeleton_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/skeletons'

image_size = 256

# Initialize the dataset
vess_dataset = VessMapDataset(image_dir, mask_dir, skeleton_dir, image_size, apply_transform=True)

# Get the train and test loaders
train_loader, test_loader = vess_dataset.vess_map_dataloader(batch_size=32, train_size=0.8)

In [29]:
def calculate_iou(outputs, masks, threshold=0.5):
    # Apply sigmoid to outputs to get probabilities between 0 and 1
    preds = torch.sigmoid(outputs)
    preds = (preds > threshold).float()  # Convert probabilities to binary predictions

    # Ensure masks are float type
    masks = masks.float()

    # Compute intersection and union
    intersection = (preds * masks).sum(dim=(1, 2))
    union = ((preds + masks) > 0).float().sum(dim=(1, 2))

    # Avoid division by zero
    iou = torch.where(union == 0, torch.tensor(1.0).to(outputs.device), intersection / union)

    # Return mean IoU over the batch
    return iou.mean().item()

In [30]:
# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [68]:
def train_conv_probe_all_channels(model, layer_name, num_channels, train_loader, test_loader, image_size, device, num_epochs=50, lr=0.001, single_channel=True):
    iou_results = {}  # Dictionary to store IoU results for each channel

    for channel_idx in range(num_channels):
        print(f"Training for channel {channel_idx} of layer {layer_name}")
        probe_model = ConvProbe(model=model, layer_name=layer_name, output_size=(
            image_size, image_size), seed=42, single_channel=single_channel)
        probe_model.to(device)

        # BinaryCEWithLogitsLoss for binary segmentation
        criterion = nn.BCEWithLogitsLoss()

        # Define optimizer for ConvProbe parameters only #TODO adjust to train only convprobe parameters
        optimizer = optim.Adam(list(probe_model.conv1.parameters(
        )) + list(probe_model.conv2.parameters()), lr=lr)

        # TODO: save metrics for every epoch

        # Track metrics for this channel
        train_losses = []
        train_ious = []
        val_losses = []
        val_ious = []

        # Get a sample image from the validation loader for saving outputs
        sample_inputs, sample_masks, _ = next(iter(test_loader))
        sample_inputs = sample_inputs.to(device)[:1]  # Take one sample
        sample_masks = sample_masks.to(device).float()[:1]

        for epoch in range(num_epochs):
            probe_model.train()
            running_loss = 0.0
            train_iou = 0.0

            # Training loop
            for inputs, masks, _ in train_loader:
                inputs = inputs.to(device)
                masks = masks.to(device).float()  # Ensure masks are float type

                # Zero the parameter gradients
                optimizer.zero_grad()

                if single_channel:
                    # Forward pass using the current channel
                    outputs = probe_model(inputs, channel_idx=channel_idx)
                else:
                    outputs = probe_model(inputs)

                # Shape: [batch_size, height, width]
                outputs = outputs.squeeze(1)
                # Shape: [batch_size, height, width]
                masks = masks.squeeze(1)

                # Compute loss
                loss = criterion(outputs, masks)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                # Accumulate running loss
                running_loss += loss.item() * inputs.size(0)

                # Compute IoU for this batch
                batch_iou = calculate_iou(outputs, masks)
                train_iou += batch_iou * inputs.size(0)

            # Calculate epoch loss and IoU
            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_iou = train_iou / len(train_loader.dataset)
            print(
                f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train IoU: {epoch_iou:.4f}')

            # Append metrics
            train_losses.append(epoch_loss)
            train_ious.append(epoch_iou)

            # Validation loop
            probe_model.eval()
            val_loss = 0.0
            val_iou = 0.0
            with torch.no_grad():
                for val_inputs, val_masks, _ in test_loader:
                    val_inputs = val_inputs.to(device)
                    # Ensure masks are float type
                    val_masks = val_masks.to(device).float()

                    # Forward pass using the current channel
                    if single_channel:
                        val_outputs = probe_model(
                            val_inputs, channel_idx=channel_idx)
                    else:
                        val_outputs = probe_model(val_inputs)

                    # Shape: [batch_size, height, width]
                    val_outputs = val_outputs.squeeze(1)
                    # Shape: [batch_size, height, width]
                    val_masks = val_masks.squeeze(1)

                    # Compute loss
                    loss = criterion(val_outputs, val_masks)
                    val_loss += loss.item() * val_inputs.size(0)

                    # Compute IoU for this batch
                    batch_val_iou = calculate_iou(val_outputs, val_masks)
                    val_iou += batch_val_iou * val_inputs.size(0)

            # Calculate validation loss and IoU
            val_loss /= len(test_loader.dataset)
            val_iou /= len(test_loader.dataset)
            print(
                f'Validation Loss: {val_loss:.4f}, Validation IoU: {val_iou:.4f}')

            # Append validation metrics
            val_losses.append(val_loss)
            val_ious.append(val_iou)

            # Save model output on sample image for debugging
            with torch.no_grad():
                if single_channel:
                    sample_output = probe_model(
                        sample_inputs, channel_idx=channel_idx)
                else:
                    sample_output = probe_model(
                        sample_inputs)
                sample_output = sample_output.squeeze(1)
                sample_output_np = sample_output.cpu().numpy()[0]

                if epoch == num_epochs - 1 and val_iou < 0.0001:
                    # Save the image
                    plt.imsave(
                        f"./conv-probe-debbug/output_layer_{layer_name}_channel_{channel_idx}_epoch_{epoch+1}.png", sample_output_np, cmap='gray')

        # Save per-epoch metrics for this channel
        iou_results[channel_idx] = {
            'train_loss': train_losses,
            'train_iou': train_ious,
            'val_loss': val_losses,
            'val_iou': val_ious
        }

        # Free GPU memory after training for the current channel is completed
        del probe_model, optimizer, criterion
        torch.cuda.empty_cache()
        gc.collect()
        if not single_channel:
            break

    del train_loader, test_loader
    torch.cuda.empty_cache()
    gc.collect()

    return iou_results

In [32]:
# Function to get layer names and number of channels from ResNet-18
def get_resnet18_layers_info():
    resnet18 = models.resnet18(pretrained=True)
    layers_info = {}
    for name, module in resnet18.named_modules():
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
            if hasattr(module, 'out_channels'):
                out_channels = module.out_channels
            elif hasattr(module, 'num_features'):
                out_channels = module.num_features
            else:
                continue
            layers_info[name] = out_channels
    return layers_info

In [33]:
# Function to save results to a JSON file
def save_results_to_json(results, filename):
    # Convert tensors or numpy types to native Python types
    def convert(o):
        if isinstance(o, np.float32) or isinstance(o, np.float64):
            return float(o)
        if isinstance(o, torch.Tensor):
            return o.item()
        raise TypeError

    with open(filename, 'w') as f:
        json.dump(results, f, default=convert)

In [34]:
# Function to train the probe for all layers and collect results
def train_conv_probe_all_layers(train_loader, test_loader, image_size, device, num_epochs=20, lr=0.001):
    all_layers_iou_results = {}
    layers_info = get_resnet18_layers_info()
    for layer_name, num_channels in layers_info.items():
        print(f"\nProcessing layer: {layer_name} with {num_channels} channels")
        
        # Declare the model to use
        model = models.resnet18(pretrained=True)
        
        iou_results = train_conv_probe_all_channels(
            model=model,
            layer_name=layer_name,
            num_channels=num_channels,
            train_loader=train_loader,
            test_loader=test_loader,
            image_size=image_size,
            device=device,
            num_epochs=num_epochs,
            lr=lr
        )
        
        save_results_to_json(iou_results, f'./conv-probe-layers-results/{layer_name}_activation_results.json') #TODO salvar por canal/camada
        all_layers_iou_results[layer_name] = iou_results
        
        # Free GPU memory
        torch.cuda.memory_reserved(device=device)
        torch.cuda.memory_allocated(device=device)
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()  # Collect unused GPU memory
        break #TODO: remove
    save_results_to_json(all_layers_iou_results, './conv-probe-layers-results/resnet18_layers_iou_results.json') #TODO salvar por canal/camada


    return all_layers_iou_results

In [36]:
# Call the training function
all_layers_iou_results = train_conv_probe_all_layers(
    train_loader=train_loader,
    test_loader=test_loader,
    image_size=image_size,
    device=device,
    num_epochs=50,
    lr=0.01
)

# Save the results to a JSON file
# TODO: add the probe na saida de uma UNET segmentacao para ver se o resultado é bom(deveria ser), ver questao de tempo de convergencia
# TODO: verificar porque esta indo a zero a IOU, ver resultado da rede para ver se nao esta todo branco ou preto




Processing layer: conv1 with 64 channels
Training for channel 0 of layer conv1
Epoch 1/50, Loss: 0.6707, Train IoU: 0.0990
Validation Loss: 0.5855, Validation IoU: 0.0000
Epoch 2/50, Loss: 0.5663, Train IoU: 0.0000
Validation Loss: 0.5496, Validation IoU: 0.0001
Epoch 3/50, Loss: 0.5596, Train IoU: 0.0000
Validation Loss: 0.5766, Validation IoU: 0.0000
Epoch 4/50, Loss: 0.5658, Train IoU: 0.0000
Validation Loss: 0.5557, Validation IoU: 0.0000
Epoch 5/50, Loss: 0.5482, Train IoU: 0.0002
Validation Loss: 0.5375, Validation IoU: 0.0001
Epoch 6/50, Loss: 0.5418, Train IoU: 0.0024
Validation Loss: 0.5410, Validation IoU: 0.0025
Epoch 7/50, Loss: 0.5448, Train IoU: 0.0084
Validation Loss: 0.5431, Validation IoU: 0.0047
Epoch 8/50, Loss: 0.5404, Train IoU: 0.0100
Validation Loss: 0.5381, Validation IoU: 0.0030
Epoch 9/50, Loss: 0.5385, Train IoU: 0.0076
Validation Loss: 0.5338, Validation IoU: 0.0021
Epoch 10/50, Loss: 0.5370, Train IoU: 0.0065
Validation Loss: 0.5348, Validation IoU: 0.0028

KeyboardInterrupt: 

In [37]:
import segmentation_models_pytorch as smp

def train_unet_model(train_loader, val_loader, device, num_epochs=50, lr=0.001):
    # Define the model
    model = smp.Unet(
        encoder_name="resnet18",        # Choose encoder, e.g. resnet18
        encoder_weights=None,           # Use NONE pre-trained weights for encoderFalse
        in_channels=3,                  # Input channels (3 for RGB images)
        classes=1,                      # Output channels (1 for binary segmentation)
        activation=None                 # No activation function on output
    )
    model.to(device)

    # Define loss function
    criterion = nn.BCEWithLogitsLoss()

    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, masks, _ in train_loader:
            inputs = inputs.to(device)
            masks = masks.to(device).float()

            optimizer.zero_grad()

            outputs = model(inputs)
            outputs = outputs.squeeze(1)
            masks = masks.squeeze(1)

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

        # You can add validation and metric calculation similar to previous code

    # Save the trained model
    torch.save(model.state_dict(), './models/conv_probe_unet_vess_map.pth')

    return model

  from .autonotebook import tqdm as notebook_tqdm


In [38]:
unet_model = train_unet_model(train_loader=train_loader, val_loader=test_loader, device=device, num_epochs=50, lr=0.001)

Epoch 1/50, Loss: 0.8090
Epoch 2/50, Loss: 0.6150
Epoch 3/50, Loss: 0.5303
Epoch 4/50, Loss: 0.4718
Epoch 5/50, Loss: 0.4314
Epoch 6/50, Loss: 0.3875
Epoch 7/50, Loss: 0.3466
Epoch 8/50, Loss: 0.3224
Epoch 9/50, Loss: 0.3065
Epoch 10/50, Loss: 0.2992
Epoch 11/50, Loss: 0.2844
Epoch 12/50, Loss: 0.2557
Epoch 13/50, Loss: 0.2484
Epoch 14/50, Loss: 0.2404
Epoch 15/50, Loss: 0.2244
Epoch 16/50, Loss: 0.2144
Epoch 17/50, Loss: 0.2085
Epoch 18/50, Loss: 0.2041
Epoch 19/50, Loss: 0.2001
Epoch 20/50, Loss: 0.1942
Epoch 21/50, Loss: 0.1914
Epoch 22/50, Loss: 0.1932
Epoch 23/50, Loss: 0.1844
Epoch 24/50, Loss: 0.1824
Epoch 25/50, Loss: 0.1788
Epoch 26/50, Loss: 0.1764
Epoch 27/50, Loss: 0.1706
Epoch 28/50, Loss: 0.1711
Epoch 29/50, Loss: 0.1666
Epoch 30/50, Loss: 0.1660
Epoch 31/50, Loss: 0.1674
Epoch 32/50, Loss: 0.1646
Epoch 33/50, Loss: 0.1619
Epoch 34/50, Loss: 0.1638
Epoch 35/50, Loss: 0.1710
Epoch 36/50, Loss: 0.1641
Epoch 37/50, Loss: 0.1648
Epoch 38/50, Loss: 0.1582
Epoch 39/50, Loss: 0.

In [57]:
for name, module in unet_model.named_modules():
    print(name)


encoder
encoder.conv1
encoder.bn1
encoder.relu
encoder.maxpool
encoder.layer1
encoder.layer1.0
encoder.layer1.0.conv1
encoder.layer1.0.bn1
encoder.layer1.0.relu
encoder.layer1.0.conv2
encoder.layer1.0.bn2
encoder.layer1.1
encoder.layer1.1.conv1
encoder.layer1.1.bn1
encoder.layer1.1.relu
encoder.layer1.1.conv2
encoder.layer1.1.bn2
encoder.layer2
encoder.layer2.0
encoder.layer2.0.conv1
encoder.layer2.0.bn1
encoder.layer2.0.relu
encoder.layer2.0.conv2
encoder.layer2.0.bn2
encoder.layer2.0.downsample
encoder.layer2.0.downsample.0
encoder.layer2.0.downsample.1
encoder.layer2.1
encoder.layer2.1.conv1
encoder.layer2.1.bn1
encoder.layer2.1.relu
encoder.layer2.1.conv2
encoder.layer2.1.bn2
encoder.layer3
encoder.layer3.0
encoder.layer3.0.conv1
encoder.layer3.0.bn1
encoder.layer3.0.relu
encoder.layer3.0.conv2
encoder.layer3.0.bn2
encoder.layer3.0.downsample
encoder.layer3.0.downsample.0
encoder.layer3.0.downsample.1
encoder.layer3.1
encoder.layer3.1.conv1
encoder.layer3.1.bn1
encoder.layer3.1.re

In [64]:
def get_out_channels(module):
    if isinstance(module, nn.Conv2d):
        return module.out_channels
    # If module is a fused type like Conv2dReLU, check its children
    for submodule in module.children():
        if isinstance(submodule, nn.Conv2d):
            return submodule.out_channels
    raise AttributeError("No Conv2d found in this module")

layer_name = 'decoder.blocks.4.conv2.0' #TODO: check if this is the best layer indeed

for name, module in unet_model.named_modules():
    if name == layer_name:
        print(f"Found layer: {name}")
        try:
            num_channels = get_out_channels(module)
            print(f"Number of channels: {num_channels}")
        except AttributeError as e:
            print(e)

num_channels


Found layer: decoder.blocks.4.conv2.0
Number of channels: 16


16

In [69]:
# Initialize ConvProbe with the UNet model
iou_results = train_conv_probe_all_channels(
            model=unet_model,
            layer_name=layer_name,
            num_channels=num_channels,
            train_loader=train_loader,
            test_loader=test_loader,
            image_size=image_size,
            device=device,
            num_epochs=10,
            lr=0.001,
            single_channel=False)

Training for channel 0 of layer decoder.blocks.4.conv2.0
Epoch 1/10, Loss: 0.7170, Train IoU: 0.2141
Validation Loss: 0.5918, Validation IoU: 0.4335
Epoch 2/10, Loss: 0.5630, Train IoU: 0.5123
Validation Loss: 0.5046, Validation IoU: 0.6133
Epoch 3/10, Loss: 0.4746, Train IoU: 0.6433
Validation Loss: 0.4469, Validation IoU: 0.6257
Epoch 4/10, Loss: 0.4140, Train IoU: 0.6601
Validation Loss: 0.4037, Validation IoU: 0.6364
Epoch 5/10, Loss: 0.3688, Train IoU: 0.6696
Validation Loss: 0.3594, Validation IoU: 0.6535
Epoch 6/10, Loss: 0.3250, Train IoU: 0.6919
Validation Loss: 0.3267, Validation IoU: 0.6649
Epoch 7/10, Loss: 0.2911, Train IoU: 0.7027
Validation Loss: 0.3055, Validation IoU: 0.6736
Epoch 8/10, Loss: 0.2663, Train IoU: 0.7116
Validation Loss: 0.2719, Validation IoU: 0.6880
Epoch 9/10, Loss: 0.2411, Train IoU: 0.7223
Validation Loss: 0.2534, Validation IoU: 0.6920
Epoch 10/10, Loss: 0.2217, Train IoU: 0.7298
Validation Loss: 0.2369, Validation IoU: 0.6952
