# Downstream Evaluation of SSL Representations

This notebook performs the **linear evaluation protocol** on the frozen feature representations learned by our self-supervised models (SimCLR and BYOL) for the surgical phase recognition task.

**Objectives:**
1. **Load Pretrained Features**  
   - Read in the feature embeddings extracted from each frame by the pretrained encoder (with and without attention).

2. **Prepare Labels and Splits**  
   - Load the corresponding per-frame CSV label files.

3. **Train Linear Classifier**  
   - For each SSL model (SimCLR, BYOL, with/without attention), train a simple logistic regression or single-layer MLP on the training embeddings.
   - Keep the encoder frozen; only the classifier weights are updated.

4. **Evaluate Performance**  
   - Compute standard metrics:  
     - **Accuracy** over full video sequences  
     - **Macro-averaged Precision, Recall, and F1-score** across the 7 surgical phases  
     - **Confusion Matrix** to inspect class-wise performance  
   - Optionally, assess robustness by retraining/evaluating on reduced label subsets (50%, 25%, 10%, 5%).

5. **Visualize Results**  
   - Plot F1-score vs. fraction of labels used  
   - Display confusion matrices side-by-side for SimCLR vs. BYOL, and attention vs. no-attention configurations.



In [1]:
from __future__ import print_function, division
import os
import torch
import torchvision
import pandas as pd
from skimage import io, transform, util
from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
from torchvision import transforms, utils, models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
import random
from sklearn.metrics import confusion_matrix
import seaborn as sns
from collections import Counter
from copy import deepcopy
from sklearn import preprocessing


# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

Device: cuda:0
Number of workers: 16


## 2. Downstream task (Linear evaluation)


### Load dataset

In [3]:
class PituDataset(Dataset):
    """Pituitary endoscopy frame dataset described by a CSV annotation file.

    The CSV is read with a header row and must contain an ``id`` column (read
    as ``str``) and a ``label`` column (read as ``int``). Each frame is loaded
    from ``<root_dir>/images/<id>.png``.
    """

    def __init__(self, csv_file, root_dir, transform=None, maxSize=0, unlabeled=False):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory that contains the ``images`` folder.
            transform (callable, optional): Optional transform applied to the
                ``{'image', 'label'}`` sample dict.
            maxSize (int, optional): If > 0, keep only ``maxSize`` rows chosen by a
                fixed-seed random permutation (same subset on every run).
            unlabeled (bool, optional): If True, ignore labels and return -1 instead.
        """
        self.dataset = pd.read_csv(csv_file, header=0, dtype={'id': str, 'label': int})

        if maxSize > 0:
            # Seeded permutation -> reproducible subset of at most maxSize rows.
            idx = np.random.RandomState(seed=42).permutation(range(len(self.dataset)))
            reduced_dataset = self.dataset.iloc[idx[0:maxSize]]
            self.dataset = reduced_dataset.reset_index(drop=True)

        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'images')
        self.transform = transform
        self.unlabeled = unlabeled
        self.classes = ['Desconocida', 'Preparacion colgajo', 'Etmoidectomia', 'Apertura selar', 
                        'Apertura dural', 'Reseccion tumoral', 'Cierre']

    def __len__(self):
        """Return the number of rows (frames) kept from the CSV."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load one frame and return ``{'image': ndarray, 'label': np.int64}`` (after ``transform``)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Read the image
        img_name = os.path.join(self.img_dir, self.dataset.id[idx] + '.png')
        image = io.imread(img_name)

        if self.unlabeled:
            label = np.int64(-1)  # -1 marks "no label" while keeping the same dtype
        else:
            # np.int64 instead of the `np.long` alias: the alias is missing in some
            # NumPy releases and maps to a platform-dependent width in others.
            label = np.int64(self.dataset.label[idx])
        sample = {'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)
        return sample

In [5]:
#Preprocessing
class Rescale(object):
    """Resize the image of a sample to a target size.

    Args:
        output_size (tuple or int): Target size. A tuple is used directly as
            (height, width). An int sets the length of the shorter side while the
            original aspect ratio is kept.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_shape(self, height, width):
        """Return the (new_h, new_w) pair, truncated to ints, for an input of the given size."""
        size = self.output_size
        if not isinstance(size, int):
            target_h, target_w = size
        elif height > width:
            # Portrait: width is the short side.
            target_h, target_w = size * height / width, size
        else:
            # Landscape or square: height is the short side.
            target_h, target_w = size, size * width / height
        return int(target_h), int(target_w)

    def __call__(self, sample):
        height, width = sample['image'].shape[:2]
        resized = transform.resize(sample['image'], self._target_shape(height, width))
        return {'image': resized, 'label': sample['label']}

class ToTensor(object):
    """Turn the ndarray image and the label of a sample into PyTorch tensors."""

    def __call__(self, sample):
        # Move the channel axis to the front:
        #   numpy image: H x W x C  ->  torch image: C x H x W
        channels_first = sample['image'].transpose(2, 0, 1)

        return {
            'image': torch.from_numpy(channels_first),
            'label': torch.tensor(sample['label'], dtype=torch.long),
        }
    
class Normalize(object):
    """Channel-wise normalisation: subtract the mean and divide by the standard deviation.

    Args:
        mean: Sequence with one mean per channel.
        std: Sequence with one standard deviation per channel.
    """

    def __init__(self, mean, std):
        assert len(mean)==len(std),'Length of mean and std vectors is not the same'
        self.mean, self.std = np.array(mean), np.array(std)

    def __call__(self, sample):
        image = sample['image']
        # Unpacking into three names also enforces a C x H x W layout.
        n_channels, _, _ = image.shape
        assert n_channels==len(self.mean), 'Length of mean and image is not the same' 

        # Build the statistics with the image's dtype/device and reshape them to
        # (C, 1, 1) so they broadcast over the spatial dimensions.
        mean_t = torch.as_tensor(self.mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
        std_t = torch.as_tensor(self.std, dtype=image.dtype, device=image.device).view(-1, 1, 1)

        # In-place update: the tensor stored in the incoming sample is modified.
        image.sub_(mean_t)
        image.div_(std_t)

        return {'image': image, 'label': sample['label']}

class CenterCrop(object):
    """Cut out the central region of the image.

    Args:
        output_size (tuple or int): Crop size as (height, width). An int gives a
            square crop.

    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, tuple):
            assert len(output_size) == 2
            self.output_size = output_size
        else:
            self.output_size = (output_size, output_size)

    def __call__(self, sample):
        image = sample['image']
        height, width = image.shape[:2]
        crop_h, crop_w = self.output_size

        # Half of the surplus on each axis; 0 when the image is not larger than
        # the crop, in which case the slice keeps that axis whole.
        top = max(height - crop_h, 0) // 2
        left = max(width - crop_w, 0) // 2

        cropped = image[top: top + crop_h, left: left + crop_w]
        return {'image': cropped, 'label': sample['label']}

Load the data with labels.

In [6]:
# Per-channel statistics handed to Normalize below.
# NOTE(review): presumably computed on the training frames after rescaling to [0, 1] — confirm.
pixel_mean = [0.312, 0.120, 0.117]
pixel_std = [0.280, 0.158, 0.160]

# Deterministic preprocessing only (no augmentation): crop -> resize -> tensor -> normalise.
img_transforms = transforms.Compose([CenterCrop((256, 320)),
                                     Rescale((224,224)),
                                     ToTensor(),
                                     Normalize(mean=pixel_mean, std=pixel_std)])

train_img_data = PituDataset(csv_file="/home/train_set.csv",
                                      root_dir='/home',
                                      #maxSize=100000,
                                      transform=img_transforms)

val_img_data = PituDataset(csv_file="/home/val_set.csv",
                            root_dir='/home',
                            transform=img_transforms)


print("Number of training examples:", len(train_img_data))
# This split is read from val_set.csv, so report it as the validation set
# (the held-out test set is loaded separately further down).
print("Number of validation examples:", len(val_img_data))

Number of training examples: 213907
Number of test examples: 56431


## Load the pre-trained model (SimCLR or BYOL)

### SimCLR implementation

In [4]:
class SimCLR(pl.LightningModule):
    """SimCLR contrastive model: a ResNet-50 whose ``fc`` layer is replaced by an MLP projection head.

    Args:
        hidden_dim (int): Output size of the projection head; its hidden layer is
            ``4 * hidden_dim`` wide.
        lr (float): Learning rate for AdamW.
        temperature (float): Divisor applied to the cosine similarities in the
            InfoNCE loss. Must be > 0.
        weight_decay (float): Weight decay for AdamW.
        max_epochs (int): Used as ``T_max`` of the cosine annealing schedule.
    """
    
    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=100):
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        
        # Base model f(.): ResNet-50 
        self.convnet = torchvision.models.resnet50()
        in_features = self.convnet.fc.in_features  # 2048 for ResNet-50 (this is vector h)

        # The MLP for g(.) consists of Linear->ReLU->Linear
        # this is the projection head: 2048 → 4 * hidden_dim → hidden_dim
        self.convnet.fc = nn.Sequential(
            nn.Linear(in_features, 4 * hidden_dim),  # Linear(2048, 4*hidden_dim) (input h)
            nn.ReLU(inplace=True),
            nn.Linear(4 * hidden_dim, hidden_dim)  # Linear(4*hidden_dim, hidden_dim) (output z)
        )

    def configure_optimizers(self):
        """Return AdamW plus a cosine schedule that anneals down to ``lr / 50``."""
        optimizer = optim.AdamW(self.parameters(), 
                                lr=self.hparams.lr, 
                                weight_decay=self.hparams.weight_decay)
        # Local is named `scheduler` so it does not shadow the `lr_scheduler`
        # module imported at the top of the notebook.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=self.hparams.max_epochs,
                                                         eta_min=self.hparams.lr/50)
        return [optimizer], [scheduler]
        
    def info_nce_loss(self, batch, mode='train'):
        """Compute the InfoNCE loss for one batch and log loss and ranking metrics.

        Args:
            batch (dict): ``batch['image']`` is a sequence of image tensors that is
                concatenated along dim 0. The positive of sample ``i`` is taken to be
                the sample half a batch away, so the sequence is expected to hold
                the augmented views in matching order. Labels are not used.
            mode (str): Prefix for the logged metric names (e.g. 'train', 'val').

        Returns:
            torch.Tensor: Scalar mean negative log-likelihood.
        """
        imgs = torch.cat(batch['image'], dim=0) # doesn't use labels
        # Use the module's own device rather than a notebook-level global so the
        # class also works when it is placed on another device.
        imgs = imgs.to(device=self.device, dtype=torch.float) 
        
        # Encode all images
        feats = self.convnet(imgs) # this is vector z
        
        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(feats[:,None,:], feats[None,:,:], dim=-1) 
        
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        
        cos_sim.masked_fill_(self_mask, -9e15)
        
        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
        
        # InfoNCE loss
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()
        
        ###--- Log metrics ---###
        
        # Logging loss
        self.log(mode+'_loss', nll)
        
        # Get ranking position of positive example
        comb_sim = torch.cat([cos_sim[pos_mask][:,None],  # First position positive example
                              cos_sim.masked_fill(pos_mask, -9e15)], 
                             dim=-1)
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean())
        self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean())
        self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean())
        
        return nll
        
    def training_step(self, batch, batch_idx):
        """Return the InfoNCE loss, logged under the 'train' prefix."""
        return self.info_nce_loss(batch, mode='train')
        
    def validation_step(self, batch, batch_idx):
        """Log the InfoNCE loss and ranking metrics under the 'val' prefix."""
        self.info_nce_loss(batch, mode='val')

### BYOL Implementation

In [None]:
class BYOL(pl.LightningModule):
    """
    BYOL (Bootstrap Your Own Latent) implementation.

    This class defines the architecture and training process for a self-supervised learning
    model, allowing it to learn useful representations without using labeled data.

    Args:
        hidden_dim (int): The size of the hidden vector in the MLPs of the student and teacher projection heads.
        projection_size (int): The size of the output vector from the projection head (dimension of the embedding space).
        lr (float): Learning rate for the optimizer.
        momentum (float): Momentum parameter for the SGD optimizer.
        weight_decay (float): Weight decay for L2 regularization.
        moving_average_decay (float): Decay factor for the exponential moving average used to update the teacher model. e.g. 0.99

    Note:
        ``MLP`` and ``L2_loss`` are not defined in this class; they must be
        available in the enclosing namespace.
    """
    
    def __init__(self, hidden_dim, projection_size, lr, momentum, weight_decay, moving_average_decay):
        super().__init__()
        self.save_hyperparameters()
        
        # Base encoder f(.): ResNet-50
        self.backbone = torchvision.models.resnet50()
        in_features = self.backbone.fc.in_features  # 2048 for ResNet-50 (this is vector h)
        self.backbone.fc = nn.Identity()  # Remove the final classification layer to get the feature vector
        
        # Projection head g(·)
        self.student_projector = MLP(in_features, hidden_dim, projection_size)
        
        # Prediction head q(·)
        self.student_predictor = MLP(projection_size, hidden_dim, projection_size)  # (output vector q)
        
        # Teacher model: only the projector is duplicated; the backbone is shared
        # with the student (see shared_step).
        self.teacher_projector = copy.deepcopy(self.student_projector)
        
        # EMA parameters
        self.moving_average_decay = moving_average_decay
        
   
    def configure_optimizers(self):
        """Return an SGD optimizer over all module parameters.

        An AdamW alternative used previously:
            optim.AdamW(self.parameters(), lr=self.hparams.lr,
                        weight_decay=self.hparams.weight_decay)
        """
        optimizer = optim.SGD(self.parameters(),
                              lr=self.hparams.lr,
                              weight_decay=self.hparams.weight_decay,
                              momentum=self.hparams.momentum)
        
        
        return optimizer


    @torch.no_grad()
    def update_moving_average(self):
        """
        Updates the weights of the teacher model as a moving average of the student model's weights.
        """
        for student_params, teacher_params in zip(self.student_projector.parameters(), self.teacher_projector.parameters()):
            teacher_params.data = teacher_params.data * self.moving_average_decay + (1. - self.moving_average_decay) * student_params.data
          
        
    def initializes_target_network(self):
        '''
        Initializes the target (teacher) network with the same weights as the student model.
        Ensures the teacher's parameters do not require gradient updates.
        '''
        
        for student_params, teacher_params in zip(self.student_projector.parameters(), self.teacher_projector.parameters()):
            teacher_params.data.copy_(student_params.data)  # initialize
            teacher_params.requires_grad = False  # not update by gradient
            
    
    def on_train_start(self):
        # Initialize the teacher network at the start of training
        self.initializes_target_network()


    def forward(self, x):
        '''
        Forward pass through the backbone, the student projector and the student predictor.
        '''
        features = self.backbone(x)
        student_projection = self.student_projector(features)
        student_prediction = self.student_predictor(student_projection)
        return student_prediction
    

    def shared_step(self, img1, img2):
        """Compute the symmetric BYOL loss for two views of the same batch.

        Each view's student prediction is compared (via ``L2_loss``) with the
        teacher projection of the other view; the two terms are summed and averaged.
        """
    
        # get student projections: backbone + MLP projection head
        feats1 = self.backbone(img1) #this is h
        feats2 = self.backbone(img2)
        
        student_proj1 = self.student_projector(feats1) #this is g
        student_proj2 = self.student_projector(feats2)

        # Apply the predictor MLP to the student's projections
        student_pred1 = self.student_predictor(student_proj1) # this is q
        student_pred2 = self.student_predictor(student_proj2)

        # Get teacher projections (no gradient updates)
        with torch.no_grad():
            # The teacher projector is applied to the same backbone features as the student.
            teacher_proj1 = self.teacher_projector(feats1) 
            teacher_proj2 = self.teacher_projector(feats2)
        
        # calculate loss
        loss = L2_loss(student_pred1, teacher_proj2)
        loss += L2_loss(student_pred2, teacher_proj1)        

        return loss.mean() #loss = (loss1 + loss2).mean()

    
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss, then update the teacher by EMA."""
        img1, img2 = batch['image'][0], batch['image'][1]
        # Use the module's own device rather than a notebook-level global.
        img1 = img1.to(device=self.device, dtype=torch.float)
        img2 = img2.to(device=self.device, dtype=torch.float)

        loss = self.shared_step(img1, img2)
        self.log('train_loss', loss)

        # Update the teacher model
        self.update_moving_average()
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss (the teacher is not updated)."""
        img1, img2 = batch['image'][0], batch['image'][1]
        img1 = img1.to(device=self.device, dtype=torch.float)
        img2 = img2.to(device=self.device, dtype=torch.float)

        loss = self.shared_step(img1, img2)
        self.log('val_loss', loss)
                
        return loss

In [7]:
# Load the saved .pt model
model_path = '/home/simclr_models/simclr_model.pt'

# Initialize the model. load_state_dict is strict by default, so the architecture
# built here (e.g. hidden_dim) has to match the one the checkpoint was saved from.
loaded_model = SimCLR(max_epochs=5, hidden_dim=64, lr=5e-4, temperature=0.07, weight_decay=1e-4)

# Load the state dictionary into the model.
# map_location='cpu' lets the checkpoint load even if it was saved from a GPU that
# is not present here; the encoder is moved to the target device afterwards.
state_dict = torch.load(model_path, map_location='cpu')
loaded_model.load_state_dict(state_dict)
print(loaded_model)

SimCLR(
  (convnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

In [None]:
# Keep only the ResNet-50 wrapped by the SimCLR module.
model = loaded_model.convnet

# Drop the last child of the ResNet (`fc`, which holds the SimCLR projection head)
# so the encoder stops at the pooled backbone features.
backbone_layers = list(model.children())
encoder = nn.Sequential(*backbone_layers[:-1])
encoder.to(device)

## Encode images
Next, we implement a small function to encode the images in our datasets. The output representations are then used as inputs to the Logistic Regression model.

In [9]:
@torch.no_grad()
def prepare_data_features(encoder, dataset):
    """
    Run every sample of ``dataset`` through ``encoder`` and collect the outputs.

    Samples are processed in dataset order (no shuffling), so row ``i`` of the
    returned tensors corresponds to ``dataset[i]``. Code that later slices the
    predictions by CSV row ranges (e.g. per patient) relies on this alignment.

    Args:
        encoder (nn.Module): Pre-trained encoder model without the final classification layer.
        dataset (Dataset): Map-style dataset yielding dicts with ``'image'`` and ``'label'``.

    Returns:
        tuple(torch.Tensor, torch.Tensor): ``(feats, labels)`` on the CPU; ``feats``
        stacks the encoder outputs along dim 0 and ``labels`` the matching labels.
    """
    # Set encoder to evaluation mode and move to the correct device
    encoder.eval()
    encoder = encoder.float()  # Ensure the encoder uses float precision
    encoder.to(device)

    feats = []
    labels_list = []

    # shuffle=False keeps the features aligned with the rows of the source CSV.
    data_loader = DataLoader(dataset, batch_size=64, num_workers=NUM_WORKERS, shuffle=False, drop_last=False)

    # Gradients are already disabled by the decorator, so no inner no_grad block is needed.
    for batch in data_loader:
        imgs = batch['image'].to(device, dtype=torch.float)
        features = encoder(imgs)

        feats.append(features.detach().cpu())
        labels_list.append(batch['label'].detach().cpu())

    feats = torch.cat(feats, dim=0)
    labels = torch.cat(labels_list, dim=0)

    return feats, labels

The image datasets `train_img_data` and `val_img_data` are passed through the pre-trained encoder `encoder` to extract one feature vector per frame.

In [10]:
# Push both image datasets through the frozen encoder; each call yields the
# stacked features and the matching labels.
x_train, y_train = prepare_data_features(encoder=encoder, dataset=train_img_data)
x_test, y_test = prepare_data_features(encoder=encoder, dataset=val_img_data)

for split_name, split_feats, split_labels in (("Training", x_train, y_train),
                                              ("Testing", x_test, y_test)):
    print(f"{split_name} data shape:", split_feats.shape, split_labels.shape)

Training data shape: torch.Size([213907, 2048, 1, 1]) torch.Size([213907])
Testing data shape: torch.Size([56431, 2048, 1, 1]) torch.Size([56431])


In [11]:
# Features come out as [N, feature_dim, height, width]; when those trailing
# spatial axes are present, average over them to obtain [N, feature_dim].
if x_test.ndim > 2:
    x_train, x_test = (torch.mean(split_feats, dim=[2, 3]) for split_feats in (x_train, x_test))

for split_name, split_feats, split_labels in (("Training", x_train, y_train),
                                              ("Testing", x_test, y_test)):
    print(f"{split_name} data shape:", split_feats.shape, split_labels.shape)

Training data shape: torch.Size([213907, 2048]) torch.Size([213907])
Testing data shape: torch.Size([56431, 2048]) torch.Size([56431])


After feature extraction, the features `x_train, x_test` are standardized for better training of the classifier.

In [12]:
# Standardise every feature dimension to zero mean and unit variance.
# The statistics are estimated on the training features only and then applied
# unchanged to the evaluation features.
scaler = preprocessing.StandardScaler().fit(x_train)
x_train = scaler.transform(x_train).astype(np.float32)  # float32 for the PyTorch classifier
x_test = scaler.transform(x_test).astype(np.float32)

New data loaders are created to work with the scaled features instead of raw image data.


In [13]:
def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test):
    """Wrap feature/label tensors into a training and an evaluation DataLoader.

    Args:
        X_train, y_train: Training features and labels (tensors with the same
            first dimension).
        X_test, y_test: Evaluation features and labels.

    Returns:
        (train_loader, val_loader): both use batches of 64; only the training
        loader shuffles its samples.
    """
    loaders = []
    for features, targets, do_shuffle in ((X_train, y_train, True),
                                          (X_test, y_test, False)):
        paired = TensorDataset(features, targets)
        loaders.append(DataLoader(paired, batch_size=64, shuffle=do_shuffle))

    train_loader, val_loader = loaders
    return train_loader, val_loader

In [14]:
# Build loaders over the standardised features (NumPy arrays -> tensors).
train_loader, val_loader = create_data_loaders_from_arrays(
    X_train=torch.from_numpy(x_train),
    y_train=y_train,
    X_test=torch.from_numpy(x_test),
    y_test=y_test,
)

These new loaders `train_loader, val_loader` will be used to train a simple classifier (e.g., logistic regression) on the extracted, standardized features.

### Logistic Regression

Now use the extracted features as input to a supervised downstream task.

In [16]:
class LogisticRegression(nn.Module):
    """Single linear layer mapping feature vectors to class logits (no softmax applied)."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        logits = self.linear(x)
        return logits

In [17]:
def compute_eval_metrics(predicted_labels, labels, zero_division=0.0):
    """Compute accuracy plus per-class and macro-averaged precision/recall and F1.

    Classes are the unique values found in ``labels`` (ground truth), in sorted
    order; the per-class lists follow that order.

    Args:
        predicted_labels (array-like): Predicted class index per sample.
        labels (array-like): Ground-truth class index per sample.
        zero_division (float, optional): Value used when a ratio has a zero
            denominator (a class that is never predicted has undefined precision;
            F1 is undefined when precision + recall == 0). Defaults to 0.0, so the
            macro averages stay finite instead of turning into NaN.

    Returns:
        tuple: ``(accuracy, class_precision, class_recall, precision, recall, f1_score)``
        where ``precision``/``recall`` are the unweighted means over classes and
        ``f1_score`` is the harmonic mean of those two macro averages (not the
        mean of per-class F1 values).
    """
    predicted_labels = np.asarray(predicted_labels)
    labels = np.asarray(labels)

    # Accuracy
    accuracy = np.mean(predicted_labels == labels)

    # Per-class precision and recall
    class_precision = []
    class_recall = []

    for label in np.unique(labels):
        is_pred = predicted_labels == label
        is_true = labels == label
        VP = np.sum(is_pred & is_true)    # correct detections (true positives)
        FP = np.sum(is_pred & ~is_true)   # incorrect detections (false positives)
        FN = np.sum(~is_pred & is_true)   # missed samples (false negatives)

        class_precision.append(_safe_ratio(VP, VP + FP, zero_division))
        class_recall.append(_safe_ratio(VP, VP + FN, zero_division))

    precision = np.mean(class_precision)
    recall = np.mean(class_recall)

    f1_score = _safe_ratio(2 * (precision * recall), precision + recall, zero_division)

    return accuracy, class_precision, class_recall, precision, recall, f1_score


def _safe_ratio(numerator, denominator, zero_division):
    """Return numerator / denominator, or ``zero_division`` when the denominator is 0."""
    if denominator == 0:
        return zero_division
    return numerator / denominator

In [18]:
# The classifier input width is the input size of the first projection-head
# layer, i.e. the dimensionality of the backbone features.
output_feature_dim = loaded_model.convnet.fc[0].in_features # 2048

# One logit per surgical phase (7 classes), placed on the training device.
logreg = LogisticRegression(input_dim=output_feature_dim, output_dim=7).to(device)

In [19]:
# Define the criterion - Weighted CE loss
# Number of labelled frames per class, in class-index order.
# NOTE(review): these counts are hard-coded and add up to 220024 — verify they
# still match the training split actually loaded above.
num_etiquetas = [9378, 10254, 37582, 45600, 20624, 50000, 46586]    

# Inverse-frequency weights: rarer classes contribute more to the loss.
total_etiquetas = sum(num_etiquetas)
weights = [total_etiquetas / num for num in num_etiquetas]

# CrossEntropyLoss takes the per-class weights through the `weight` keyword and
# expects a float Tensor, not a Python list.
criterion = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32))
criterion = criterion.to(device)

# Define the optimizer
optimizer = optim.AdamW(logreg.parameters(), lr=0.0001, weight_decay = 0.01)

# Define the learning-rate schedule: halve the learning rate every 5 scheduler steps.
# The schedule only takes effect if `lr_scheduler.step()` is called during training.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [None]:
# Variables to store metrics for plotting
train_losses = []
val_accuracies = []
val_precisions = []
val_recalls = []
val_f1_scores = []

# Variables for best epoch (model selection uses the validation F1-score;
# `best_acc` keeps its historical name but stores that F1 value).
eval_every_n_epochs = 10
best_acc = 0
best_epoch = -1
# Always defined, even if no evaluation ever improves on the initial score.
best_model_wts = copy.deepcopy(logreg.state_dict())

# NOTE(review): `lr_scheduler` is created in the previous cell but is never
# stepped in this loop, so the learning rate stays constant — confirm this is intended.

# time
since = time.time()
n=100
for epoch in range(n):
    
    # Training loop
    logreg.train()
    train_loss = 0.0
    
    for imgs, labels in train_loader:
        imgs = imgs.to(device).float()
        labels = labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()        
        
        # Forward pass
        outputs = logreg(imgs)
        
        # Compute loss
        loss = criterion(outputs, labels)
        train_loss += loss.item()
        
        # Backward pass and optimization step
        loss.backward()
        optimizer.step()
    
    # Mean loss per batch for this epoch
    train_loss /= len(train_loader)
    train_losses.append(train_loss)
    
    # Evaluation loop every `eval_every_n_epochs` epochs (epochs 0, 10, ..., 90)
    if epoch % eval_every_n_epochs == 0:
        logreg.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs = imgs.to(device).float()
                labels = labels.to(device)

                # Forward pass
                outputs = logreg(imgs)
                predictions = torch.argmax(outputs, dim=1)
                
                # Store predictions and labels for metric calculation
                all_preds.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        # Convert lists to numpy arrays for metric calculations
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)
        
        # Compute metrics
        unique_labels = np.unique(all_labels)
        accuracy, class_precision, class_recall, precision, recall, f1_score = compute_eval_metrics(all_preds, all_labels)
        
        # Store evaluation metrics
        val_accuracies.append(accuracy)
        val_precisions.append(precision)
        val_recalls.append(recall)
        val_f1_scores.append(f1_score)
        
        # Deep copy the best classifier weights. The classifier being trained is
        # `logreg`; `model` is the frozen ResNet backbone and must not be copied here.
        if f1_score > best_acc:
            best_acc = f1_score
            best_model_wts = copy.deepcopy(logreg.state_dict())
            best_epoch = epoch
            
        # Print the results
        print(f"Epoch {epoch}:")
        print(f"  - Train Loss: {train_loss:.4f}")
        print(f"  - Accuracy: {accuracy * 100:.2f}%")
        print(f"  - Precision: {precision * 100:.2f}%")
        print(f"  - Recall: {recall * 100:.2f}%")
        print(f"  - F1-Score: {f1_score * 100:.2f}%\n")
        

# `best_model_wts` holds the best-validation classifier; call
# `logreg.load_state_dict(best_model_wts)` to evaluate that checkpoint instead of
# the final-epoch weights.
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best model in epoch {:d} at val F1-Score: {:.4f}'.format(best_epoch, best_acc))

In [None]:
# Training loss: one point per epoch.
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(train_losses, label='Train Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.legend()
plt.show()

# Validation metrics: one point per evaluated epoch (every `eval_every_n_epochs`).
eval_epochs = range(0, n, eval_every_n_epochs)
fig, ax = plt.subplots(figsize=(6, 3))
for metric_values, metric_label in ((val_accuracies, 'Validation Accuracy'),
                                    (val_precisions, 'Validation Precision'),
                                    (val_recalls, 'Validation Recall'),
                                    (val_f1_scores, 'Validation F1-Score')):
    ax.plot(eval_epochs, metric_values, label=metric_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Score')
ax.set_title('Validation Metrics')
ax.legend()
plt.show()


In [22]:
def metrics_per_patient(outputs, labels, patients, paciente_indices): 
    """Aggregate evaluation metrics patient by patient.

    Samples of patient ``i`` are the rows ``paciente_indices[i]:paciente_indices[i+1]``
    of ``outputs``/``labels``, so both must be ordered exactly like the CSV the
    indices were derived from.

    Args:
        outputs (array-like): Predicted class index per sample.
        labels (array-like): Ground-truth class index per sample.
        patients (list): One entry per patient (only its length is used).
        paciente_indices (list[int]): ``len(patients) + 1`` row boundaries.

    Returns:
        tuple: mean and std over patients of accuracy, recall, precision and F1;
        per-class recall/precision mean and std over patients (classes a patient
        does not have are ignored via NaN); and the per-phase recall, precision
        and F1 summary with their "std" companions, in the original order.
    """
    predicted_labels = np.asarray(outputs)
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)

    acc_total = []
    recall_total = []
    precision_total = []
    f1_score_total = []
    # NaN-initialised: a patient that lacks a class leaves that cell as NaN, which
    # the nanmean/nanstd reductions below skip.
    class_recall_total = np.full((len(patients), len(unique_labels)), np.nan)
    class_precision_total = np.full((len(patients), len(unique_labels)), np.nan)

    for i in range(len(patients)):
        start, stop = paciente_indices[i], paciente_indices[i+1]
        pred = predicted_labels[start:stop]
        labels_GT = labels[start:stop]
        
        accuracy, class_precision, class_recall, precision, recall, f1_score = compute_eval_metrics(pred, labels_GT)
        
        acc_total.append(accuracy)
        recall_total.append(recall)
        precision_total.append(precision)
        f1_score_total.append(f1_score)
        
        # compute_eval_metrics reports only the classes present in this patient's
        # ground truth (sorted); map them onto the global class columns.
        cols = np.searchsorted(unique_labels, np.unique(labels_GT))
        class_recall_total[i, cols] = class_recall
        class_precision_total[i, cols] = class_precision
    
    acc_patients = sum(acc_total)/len(acc_total)
    acc_patients_std = np.std(np.array(acc_total))
    
    recall_patients = sum(recall_total)/len(recall_total)
    recall_patients_std = np.std(np.array(recall_total))
    
    precision_patients = sum(precision_total)/len(precision_total)
    precision_patients_std = np.std(np.array(precision_total))
    
    f1_score_patients = sum(f1_score_total)/len(f1_score_total)
    f1_score_patients_std = np.std(np.array(f1_score_total))
    
    
    class_recall_patients_mean = np.nanmean(class_recall_total,0)
    class_recall_patients_std = np.nanstd(class_recall_total,0)
    class_precision_patients_mean = np.nanmean(class_precision_total,0)
    class_precision_patients_std = np.nanstd(class_precision_total,0)
    
    recall_porfase = np.nanmean(class_recall_patients_mean)
    recall_porfase_std = np.nanstd(class_recall_patients_mean)
    precision_porfase = np.nanmean(class_precision_patients_mean)
    precision_porfase_std = np.nanstd(class_precision_patients_mean)
    f1_porfase = (2*recall_porfase*precision_porfase)/(recall_porfase+precision_porfase)
    # NOTE(review): this applies the F1 formula to the two standard deviations; it
    # is kept for output compatibility but is not a standard deviation of F1.
    f1_porfase_std = (2*recall_porfase_std*precision_porfase_std)/(recall_porfase_std+precision_porfase_std)

    return (acc_patients, acc_patients_std, recall_patients, recall_patients_std, precision_patients, 
            precision_patients_std, f1_score_patients, f1_score_patients_std, class_recall_patients_mean, 
            class_recall_patients_std, class_precision_patients_mean, class_precision_patients_std,recall_porfase,
            recall_porfase_std, precision_porfase,precision_porfase_std, f1_porfase, f1_porfase_std)

In [24]:
def obtain_patients_index(ruta_csv):
    """Find where each patient's block of frames starts in an annotation CSV.

    The patient identifier is the part of the ``id`` column before the first
    underscore. A new segment starts at every row whose patient identifier
    differs from the previous row's.

    Args:
        ruta_csv (str): Path to a CSV with a header row and an ``id`` column.

    Returns:
        tuple(list[str], list[int]): ``patients`` (one identifier per segment, in
        file order) and ``paciente_indices`` (the start row of every segment
        followed by the total number of rows, i.e. ``len(patients) + 1`` values).
    """
    # Read ids as strings (as PituDataset does) so purely numeric ids are not
    # parsed as integers, which have no `.split`.
    data = pd.read_csv(ruta_csv, header=0, dtype={"id": str})

    # Patient identifier = text before the first underscore
    patient_ids = [frame_id.split("_")[0] for frame_id in data["id"].tolist()]

    paciente_indices = [
        row for row, patient_id in enumerate(patient_ids)
        if row == 0 or patient_id != patient_ids[row - 1]
    ]
    patients = [patient_ids[row] for row in paciente_indices]

    paciente_indices.append(len(data))
    return patients, paciente_indices

# Patient identifiers and row boundaries of the test-set CSV.
patients, paciente_indices = obtain_patients_index(ruta_csv="/home/test_set.csv")

In [None]:
# Per-patient aggregation of the predictions, followed by a text report.
# NOTE(review): `all_preds` / `all_labels` are assigned in the training-loop cell
# from `val_loader` (the validation split), whereas `patients` / `paciente_indices`
# are derived from /home/test_set.csv. The row ranges therefore presumably do not
# match the predictions being sliced, and the heading below says "TEST RESULTS" —
# confirm which split is intended and that its predictions are in CSV row order.
(acc_patients, acc_patients_std, recall_patients, recall_patients_std, precision_patients, 
 precision_patients_std, f1_score_patients, f1_score_patients_std, class_recall_patients_mean, 
 class_recall_patients_std, class_precision_patients_mean, class_precision_patients_std,
 recall_porfase, recall_porfase_std, precision_porfase, precision_porfase_std, f1_porfase, f1_porfase_std) = metrics_per_patient(all_preds, all_labels, patients, paciente_indices)

# Mean ± standard deviation over patients.
print('TEST RESULTS: ')
print(f'- Accuracy: {acc_patients:.2f} ± {acc_patients_std:.2f}')
print(f'- Recall: {recall_patients:.2f} ± {recall_patients_std:.2f}')
print(f'- Precision: {precision_patients:.2f} ± {precision_patients_std:.2f}')
print(f'- F1_Score: {f1_score_patients:.2f} ± {f1_score_patients_std:.2f}')
print(f'- Class recall: {class_recall_patients_mean} ± {class_recall_patients_std}')
print(f'- Class precision: {class_precision_patients_mean} ± {class_precision_patients_std}')

# Summary across phases, computed from the per-class patient means.
print('PER PHASE RESULTS: ')
print(f'- Recall: {recall_porfase:.2f} ± {recall_porfase_std:.2f}')
print(f'- Precision: {precision_porfase:.2f} ± {precision_porfase_std:.2f}')
print(f'- F1-Score: {f1_porfase:.2f} ± {f1_porfase_std:.2f}\n')

In [26]:
# Function to obtain model outputs and labels
def obtain_label(model, dataset, dataloader):
    """Run ``model`` over ``dataloader`` and collect softmax outputs and labels.

    Args:
        model (nn.Module): Classifier producing 7 logits per sample.
        dataset: Only its length is used, to size the output arrays.
        dataloader: Yields ``(inputs, labels)`` batches covering ``dataset``; rows
            are written in iteration order, so use a non-shuffling loader when the
            original sample order matters.

    Returns:
        tuple(np.ndarray, np.ndarray): ``outputs_m`` of shape (len(dataset), 7),
        float64 class probabilities, and ``labels_m`` of shape (len(dataset),),
        int64 ground-truth labels.
    """
    numClasses = 7
    model.eval()   # Set the model to evaluation mode
    device = next(model.parameters()).device  # Get the device of the model
    
    numSamples = len(dataset)  # Size of the dataset
    # Explicit NumPy dtypes: the `np.float` / `np.int` aliases were removed from NumPy.
    outputs_m = np.zeros((numSamples, numClasses), dtype=np.float64)
    labels_m = np.zeros((numSamples,), dtype=np.int64)
    contSamples = 0

    # Iterate over the data
    with torch.no_grad():
        for inputs, labels in dataloader:
            batchSize = inputs.size(0)

            # Move data to the same device as the model
            inputs = inputs.to(device)

            # Forward pass, then softmax to turn logits into probabilities
            outputs = F.softmax(model(inputs), dim=1)

            outputs_m[contSamples:contSamples + batchSize, ...] = outputs.cpu().numpy()
            labels_m[contSamples:contSamples + batchSize] = labels.cpu().numpy()
            contSamples += batchSize

    return outputs_m, labels_m

In [None]:
# Held-out test frames, preprocessed exactly like the other splits.
test_img_data = PituDataset(
    csv_file="/home/test_set.csv",
    root_dir='/home',
    transform=img_transforms,
)

# Encode, average over the spatial axes, and standardise with the scaler that was
# fitted on the training features. Note that x_test / y_test are overwritten here.
x_test, y_test = prepare_data_features(encoder=encoder, dataset=test_img_data)
x_test = x_test.mean(dim=[2, 3])
x_test = scaler.transform(x_test).astype(np.float32)

# Only the second (non-shuffled) loader is needed; the training loader is discarded.
_, test_loader = create_data_loaders_from_arrays(
    X_train=torch.from_numpy(x_train),
    y_train=y_train,
    X_test=torch.from_numpy(x_test),
    y_test=y_test,
)

In [27]:
# Class probabilities and ground-truth labels for every test frame.
outputs_test, labels_test = obtain_label(
    model=logreg,
    dataset=test_img_data,
    dataloader=test_loader,
)

In [None]:
# Confusion matrix of the test predictions; normalize='true' scales each row
# (true class) to sum to 1.
predicted_labels_test = outputs_test.argmax(axis=1)
cm = confusion_matrix(labels_test, predicted_labels_test, normalize='true')

# Heatmap of the row-normalised matrix
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='0.2f', cmap='Blues', ax=ax)

# Titles and axis labels
ax.set_title('Confusion Matrix for SimCLR Model')
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
plt.show()