In [14]:
# ConvProbe class
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import random
import numpy as np

# Dataset
import sys
sys.path.append('/home/fonta42/Desktop/masters-degree/data/vess-map/')
from vess_map_dataset import VessMapDataset

# Training loop
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import jaccard_score  # For calculating IoU
import numpy as np

# Save the results
import json

# Clear memory
import gc

In [15]:
class ConvProbe(nn.Module):
    def __init__(self, layer_name, output_size=(256, 256), seed=None, single_channel=True):
        super(ConvProbe, self).__init__()

        # Load pre-trained ResNet-18 model
        self.model = models.resnet18(pretrained=True)
        self.layer_name = layer_name
        self.output_size = output_size
        self.single_channel = single_channel

        # Freeze model weights to prevent training
        for param in self.model.parameters():
            param.requires_grad = False

        # Get number of channels from the specified layer
        in_channels = 1 if self.single_channel else self.get_layer_channels(layer_name)


        # Initialize the convolutional probe (to be trained)
        self.conv_probe = nn.Conv2d(
            in_channels=in_channels,  # Use the layer's number of channels
            out_channels=1,  # Single output channel for binary segmentation
            kernel_size=3,
            padding=1  # Add padding to preserve spatial dimensions
        )

        # Set a fixed random seed for consistency, if provided
        if seed is not None:
            self.set_seed(seed)

        # Register a hook to capture activations from the specified layer
        self.activations = {}
        self.register_hook()

    # Helper function to set random seed
    def set_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Helper function to get the number of channels from the specified layer
    def get_layer_channels(self, layer_name):
        # Iterate over the named modules to find the specified layer
        for name, module in self.model.named_modules():
            if name == layer_name:
                if hasattr(module, 'out_channels'):
                    return module.out_channels
                elif hasattr(module, 'num_features'):
                    return module.num_features  # For BatchNorm layers
                else:
                    raise ValueError(f"Layer {layer_name} does not have out_channels or num_features.")
        raise ValueError(f"Layer {layer_name} not found.")

    # Register a forward hook to capture activations
    def register_hook(self):
        def hook_fn(module, input, output):
            self.activations[self.layer_name] = output

        # Register the hook on the specified layer
        for name, module in self.model.named_modules():
            if name == self.layer_name:
                module.register_forward_hook(hook_fn)
                break

    # Forward pass through the model
    def forward(self, x, channel_idx=None):
        _ = self.model(x)  # Forward pass to get activations
        activation = self.activations[self.layer_name]  # Retrieve the stored activation

        # Select a specific channel if provided
        if channel_idx is not None:
            if channel_idx < 0 or channel_idx >= activation.size(1):
                raise ValueError(f"Channel index {channel_idx} is out of bounds for activation with {activation.size(1)} channels.")
            activation = activation[:, channel_idx:channel_idx+1, :, :]  # Select the specific channel

        
        activation = self.conv_probe(activation) # TODO: implement another conv 3x3
        
        # Interpolate activation to match the desired output size
        activation = F.interpolate(activation, size=self.output_size, mode='bilinear', align_corners=True)

        # Apply the convolutional probe to the activation
        out = self.conv_probe(activation)

        return out

In [16]:
# Define paths
image_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/images'
mask_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/labels'
skeleton_dir = '/home/fonta42/Desktop/masters-degree/data/vess-map/skeletons'

image_size = 256

# Initialize the dataset
vess_dataset = VessMapDataset(image_dir, mask_dir, skeleton_dir, image_size, apply_transform=True)

# Get the train and test loaders
train_loader, test_loader = vess_dataset.vess_map_dataloader(batch_size=32, train_size=0.8)

In [17]:
def calculate_iou(outputs, masks, threshold=0.5):
    # Apply sigmoid to outputs to get probabilities between 0 and 1
    preds = torch.sigmoid(outputs)
    preds = (preds > threshold).float()  # Convert probabilities to binary predictions

    # Ensure masks are float type
    masks = masks.float()

    # Compute intersection and union
    intersection = (preds * masks).sum(dim=(1, 2))
    union = ((preds + masks) > 0).float().sum(dim=(1, 2))

    # Avoid division by zero
    iou = torch.where(union == 0, torch.tensor(1.0).to(outputs.device), intersection / union)

    # Return mean IoU over the batch
    return iou.mean().item()

In [18]:
# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [19]:
# Training code
def train_conv_probe_all_channels(layer_name, num_channels, train_loader, test_loader, image_size, device, num_epochs=50, lr=0.001):
    iou_results = {}  # Dictionary to store IoU results for each channel

    for channel_idx in range(num_channels):
        print(f"Training for channel {channel_idx} of layer {layer_name}")

        probe_model = ConvProbe(layer_name=layer_name, output_size=(image_size, image_size), seed=42)
        probe_model.to(device)

        # BinaryCEWithLogitsLoss for binary segmentation
        criterion = nn.BCEWithLogitsLoss()

        # Define optimizer
        optimizer = optim.Adam(probe_model.conv_probe.parameters(), lr=lr)

        # Track the best IoU for this channel
        best_iou = 0.0
        best_iou_train = 0.0

        for epoch in range(num_epochs):
            probe_model.train()
            running_loss = 0.0
            train_iou = 0.0

            # Training loop
            for inputs, masks, _ in train_loader:
                inputs = inputs.to(device)
                masks = masks.to(device).float()  # Ensure masks are float type

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass using the current channel
                outputs = probe_model(inputs, channel_idx=channel_idx)
                outputs = outputs.squeeze(1)  # Shape: [batch_size, height, width]
                masks = masks.squeeze(1)      # Shape: [batch_size, height, width]

                # Compute loss
                loss = criterion(outputs, masks)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                # Accumulate running loss
                running_loss += loss.item() * inputs.size(0)

                # Compute IoU for this batch
                batch_iou = calculate_iou(outputs, masks)
                train_iou += batch_iou * inputs.size(0)

            # Calculate epoch loss and IoU
            epoch_loss = running_loss / len(train_loader.dataset)
            epoch_iou = train_iou / len(train_loader.dataset)
            #if epoch == num_epochs - 1:
            #    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train IoU: {epoch_iou:.4f}')
            print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Train IoU: {epoch_iou:.4f}')
            

            # Validation loop
            probe_model.eval()
            val_loss = 0.0
            val_iou = 0.0
            with torch.no_grad():  # Disable gradient calculations during validation
                for val_inputs, val_masks, _ in test_loader:
                    val_inputs = val_inputs.to(device)
                    val_masks = val_masks.to(device).float()  # Ensure masks are float type

                    # Forward pass using the current channel
                    val_outputs = probe_model(val_inputs, channel_idx=channel_idx)
                    val_outputs = val_outputs.squeeze(1)  # Shape: [batch_size, height, width]
                    val_masks = val_masks.squeeze(1)      # Shape: [batch_size, height, width]

                    # Compute loss
                    loss = criterion(val_outputs, val_masks)
                    val_loss += loss.item() * val_inputs.size(0)

                    # Compute IoU for this batch
                    batch_val_iou = calculate_iou(val_outputs, val_masks)
                    val_iou += batch_val_iou * val_inputs.size(0)

            # Calculate validation loss and IoU
            val_loss /= len(test_loader.dataset)
            val_iou /= len(test_loader.dataset)
            
            #if epoch == num_epochs - 1:
            #    print(f'Validation Loss: {val_loss:.4f}, Validation IoU: {val_iou:.4f}')
            print(f'Validation Loss: {val_loss:.4f}, Validation IoU: {val_iou:.4f}')
                


        # Save IoU metrics for this channel
        iou_results[channel_idx] = {'train_iou': best_iou_train, 'val_iou': best_iou}
        
        # Free GPU memory after training for the current channel is completed
        del probe_model, optimizer, criterion
        del inputs, masks, outputs
        del val_inputs, val_masks, val_outputs
        torch.cuda.empty_cache()  # Clear CUDA memory
        gc.collect() 
        
    del train_loader, test_loader
    torch.cuda.empty_cache()  # Clear CUDA memory
    gc.collect() 
    


    return iou_results

In [20]:
# Function to get layer names and number of channels from ResNet-18
def get_resnet18_layers_info():
    resnet18 = models.resnet18(pretrained=True)
    layers_info = {}
    for name, module in resnet18.named_modules():
        if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
            if hasattr(module, 'out_channels'):
                out_channels = module.out_channels
            elif hasattr(module, 'num_features'):
                out_channels = module.num_features
            else:
                continue
            layers_info[name] = out_channels
    return layers_info

In [21]:
# Function to train the probe for all layers and collect results
def train_conv_probe_all_layers(train_loader, test_loader, image_size, device, num_epochs=20, lr=0.001):
    all_layers_iou_results = {}
    layers_info = get_resnet18_layers_info()
    for layer_name, num_channels in layers_info.items():
        print(f"\nProcessing layer: {layer_name} with {num_channels} channels")
        iou_results = train_conv_probe_all_channels(
            layer_name=layer_name,
            num_channels=num_channels,
            train_loader=train_loader,
            test_loader=test_loader,
            image_size=image_size,
            device=device,
            num_epochs=num_epochs,
            lr=lr
        )
        all_layers_iou_results[layer_name] = iou_results
        
        # Free GPU memory
        torch.cuda.memory_reserved(device=device)
        torch.cuda.memory_allocated(device=device)
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()  # Collect unused GPU memory

    return all_layers_iou_results

In [22]:
# Function to save results to a JSON file
def save_results_to_json(results, filename):
    # Convert tensors or numpy types to native Python types
    def convert(o):
        if isinstance(o, np.float32) or isinstance(o, np.float64):
            return float(o)
        if isinstance(o, torch.Tensor):
            return o.item()
        raise TypeError

    with open(filename, 'w') as f:
        json.dump(results, f, default=convert)

In [23]:
# Call the training function
all_layers_iou_results = train_conv_probe_all_layers(
    train_loader=train_loader,
    test_loader=test_loader,
    image_size=image_size,
    device=device,
    num_epochs=50,
)

# Save the results to a JSON file
save_results_to_json(all_layers_iou_results, 'resnet18_layers_iou_results.json') #TODO salvar por canal/camada
#TODO: add the probe na saida de uma UNET segmentacao para ver se o resultado é bom(deveria ser), ver questao de tempo de convergencia
# TODO: verificar porque esta indo a zero a IOU, ver resultado da rede para ver se nao esta todo branco ou preto


Processing layer: conv1 with 64 channels
Training for channel 0 of layer conv1
Epoch 1/50, Loss: 0.6860, Train IoU: 0.2130
Validation Loss: 0.6852, Validation IoU: 0.1960
Epoch 2/50, Loss: 0.6842, Train IoU: 0.1910
Validation Loss: 0.6831, Validation IoU: 0.1799
Epoch 3/50, Loss: 0.6828, Train IoU: 0.1652
Validation Loss: 0.6816, Validation IoU: 0.1581
Epoch 4/50, Loss: 0.6812, Train IoU: 0.1410
Validation Loss: 0.6806, Validation IoU: 0.1401
Epoch 5/50, Loss: 0.6796, Train IoU: 0.1198
Validation Loss: 0.6785, Validation IoU: 0.1177
Epoch 6/50, Loss: 0.6779, Train IoU: 0.1066
Validation Loss: 0.6769, Validation IoU: 0.1045
Epoch 7/50, Loss: 0.6764, Train IoU: 0.0911
Validation Loss: 0.6747, Validation IoU: 0.0937
Epoch 8/50, Loss: 0.6750, Train IoU: 0.0804
Validation Loss: 0.6731, Validation IoU: 0.0843
Epoch 9/50, Loss: 0.6737, Train IoU: 0.0705
Validation Loss: 0.6710, Validation IoU: 0.0794
Epoch 10/50, Loss: 0.6717, Train IoU: 0.0666
Validation Loss: 0.6694, Validation IoU: 0.0757

KeyboardInterrupt: 

In [None]:
# TODO: save metrics for every epoch
{
    "layer": {
        "channel_idx": {"iou_train": 00,
                        "iou_validation": 00}
    }
}