In [1]:
%%capture
# Install the notebook's third-party dependencies; `-q` plus %%capture keep pip's output out of the notebook.
# NOTE(review): both packages are unpinned and tf-nightly changes daily — consider pinning versions
# so a fresh "Restart & Run All" reproduces the same environment.
!pip install -q tf-nightly[and-cuda]
!pip install -q monai

## Load Libs

In [2]:
# import torch
# from monai.networks.nets import SegResNet

# # Define the model parameters
# spatial_dims = 2
# in_channels = 1
# out_channels = 2
# blocks_down = [3, 4, 23, 3]
# blocks_up = [3, 4, 3]
# upsample_mode = "deconv"

# blocks_down = [3, 4, 23, 3]  # Standard ResNet101 block configuration
# blocks_up = [3, 6, 3]        # Enhanced decoder
# init_filters = 32

# # Instantiate the model
# model = SegResNet(
#     spatial_dims=spatial_dims,
#     in_channels=in_channels,
#     out_channels=out_channels,
#     blocks_down=blocks_down,
#     blocks_up=blocks_up,
#     upsample_mode=upsample_mode,
#     init_filters = init_filters
# )

# # Calculate the number of parameters
# num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

# print(f"The SegResNet model has {num_parameters/1000000} Million trainable parameters.")

# # Optional: Calculate model size in MB (assuming float32 parameters)
# bytes_per_parameter = 4 # for float32
# model_size_mb = (num_parameters * bytes_per_parameter) / (1024 * 1024)
# print(f"The estimated model size is {model_size_mb:.2f} MB (assuming float32).")

In [3]:
%%writefile utils.py

import os
import h5py
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from monai.networks.nets import SegResNet
from tqdm.notebook import tqdm, trange

class EmbeddingsDataset(Dataset):
    """Dataset of stored embeddings, indexed by a parquet metadata table.

    ``master_metadata.parquet`` (inside ``metadata_path``) is expected to carry the
    columns ``file_id``, ``embedding_batch``, ``embedding_index`` and ``label``.
    Each embedding lives in ``<embeddings_path>/<embedding_batch>.h5`` under the
    ``'embeddings'`` dataset, at row ``embedding_index``.

    Only rows with a non-null ``label`` are served through ``__len__`` /
    ``__getitem__``; ``get_embedding`` searches the full table.
    """

    # Length of one stored embedding vector; items are reshaped to (1, EMBEDDING_DIM, 1).
    EMBEDDING_DIM = 384

    def __init__(self, embeddings_path, metadata_path, transform=None):
        """
        Initialize the dataset

        Args:
            embeddings_path: Path to the directory containing H5 embedding files
            metadata_path: Path to the directory containing metadata files
            transform: Optional callable applied to the reshaped embedding tensor
        """
        self.embeddings_path = embeddings_path
        self.metadata_path = metadata_path
        self.transform = transform
        self.master_metadata = pd.read_parquet(os.path.join(metadata_path, "master_metadata.parquet"))
        # Limit to data with labels
        self.metadata = self.master_metadata.dropna(subset=['label'])

    def __len__(self):
        """Number of labelled rows."""
        return len(self.metadata)

    def _read_embedding(self, batch_name, embedding_index):
        """Read one embedding row from the H5 file of ``batch_name``.

        The index is coerced with ``int()``: a parquet column that contained nulls is
        loaded as float64, and a float such as ``1.0`` is not a valid HDF5/NumPy index.
        The file is opened and closed on every call.
        """
        h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
        with h5py.File(h5_path, 'r') as h5f:
            return h5f['embeddings'][int(embedding_index)]

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` for the ``idx``-th labelled row.

        Returns:
            embedding: float32 tensor of shape ``(1, EMBEDDING_DIM, 1)``, i.e.
                (channel, height, width) for a 2-D network once batched; passed
                through ``transform`` when one was given.
            label: scalar ``torch.long`` tensor (0=negative, 1=positive).
        """
        row = self.metadata.iloc[idx]
        raw_embedding = self._read_embedding(row['embedding_batch'], row['embedding_index'])

        # (EMBEDDING_DIM,) -> (1, EMBEDDING_DIM, 1); the DataLoader adds the batch axis.
        embedding = torch.tensor(raw_embedding, dtype=torch.float32)
        embedding = embedding.view(1, self.EMBEDDING_DIM, 1)

        # Labels may arrive as floats (the column is nullable); store them as class indices.
        label = torch.tensor(int(row['label']), dtype=torch.long)

        if self.transform:
            embedding = self.transform(embedding)

        return embedding, label

    def get_embedding(self, file_id):
        """Return ``(embedding, label)`` for the first metadata row matching ``file_id``.

        The lookup uses the full metadata table, so the label may be NaN for
        unlabelled files; it is ``None`` when the table has no ``label`` column.
        The embedding is returned as read from the H5 file (not converted to a tensor).

        Raises:
            ValueError: if ``file_id`` is not present in the metadata.
        """
        file_info = self.master_metadata[self.master_metadata['file_id'] == file_id]

        if len(file_info) == 0:
            raise ValueError(f"File ID {file_id} not found in metadata")

        embedding = self._read_embedding(
            file_info['embedding_batch'].iloc[0],
            file_info['embedding_index'].iloc[0],
        )

        label = file_info['label'].iloc[0] if 'label' in file_info.columns else None
        return embedding, label

class SelfSupervisedHead(nn.Module):
    """Classification head that also exposes projected features for a contrastive loss.

    A 1x1 convolution + batch norm + ReLU lifts the incoming feature map to 128
    channels, global average pooling collapses the spatial axes, and a two-layer
    MLP projector yields a 128-d vector. That vector is both returned (for the
    contrastive objective) and fed to the final linear classifier.
    """
    def __init__(self, in_channels, num_classes=2):
        super().__init__()
        width = 128

        # Convolutional stem followed by spatial pooling.
        self.conv = nn.Conv2d(in_channels, width, kernel_size=1)
        self.bn = nn.BatchNorm2d(width)
        self.relu = nn.ReLU(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Projector MLP: width -> 256 -> width.
        projector_layers = [
            nn.Linear(width, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, width),
        ]
        self.projector = nn.Sequential(*projector_layers)

        # Linear classifier on top of the projected features.
        self.fc = nn.Linear(width, num_classes)

    def forward(self, x):
        """Return ``(logits, features)`` for a batch of feature maps ``x``."""
        pooled = self.global_pool(self.relu(self.bn(self.conv(x))))
        flat = pooled.view(pooled.size(0), -1)

        features = self.projector(flat)
        return self.fc(features), features

class SelfSupervisedCancerModel(nn.Module):
    """MONAI ``SegResNet`` backbone followed by a ``SelfSupervisedHead``.

    The backbone is built for 2-D, single-channel input and configured with two
    output channels; that output map is handed to the head, which produces class
    logits and projected features.
    """
    def __init__(self, num_classes=2):
        super().__init__()

        # The backbone's final layer is configured to emit this many channels,
        # which is therefore the head's input width.
        backbone_out_channels = 2

        self.backbone = SegResNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=backbone_out_channels,
            blocks_down=[3, 4, 23, 3],
            blocks_up=[3, 6, 3],
            upsample_mode="deconv",
            init_filters=32,
        )

        self.ssl_head = SelfSupervisedHead(backbone_out_channels, num_classes)

        # Neutralise a built-in classifier if this backbone exposes one.
        # NOTE(review): presumably a no-op for SegResNet — confirm the attribute exists.
        if hasattr(self.backbone, 'class_layers'):
            self.backbone.class_layers = nn.Identity()

    def forward(self, x, return_features=False):
        """Return class logits, or ``(logits, projected_features)`` when ``return_features`` is true."""
        logits, projected = self.ssl_head(self.backbone(x))
        return (logits, projected) if return_features else logits

# NTXent Loss for contrastive learning
class NTXentLoss(nn.Module):
    """
    Label-aware NT-Xent (temperature-scaled cross entropy) contrastive loss.

    Samples sharing a label are positives for one another. For every anchor that
    has at least one positive, the loss is the negative log-softmax of the
    temperature-scaled cosine similarities, averaged over that anchor's positives;
    the result is the mean over those anchors.

    Compared with a plain "first positive as target" cross entropy this:
      * removes the anchor's self-similarity from the softmax denominator
        (masked with -inf instead of being multiplied by zero, which would leave
        a spurious logit of 0 in the denominator);
      * skips anchors without any positive instead of giving them target index 0;
      * uses every positive rather than an arbitrary first one;
      * builds the pair mask with tensor ops instead of O(n^2) Python loops.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """
    def __init__(self, temperature=0.07):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        # Unused by forward(); kept so code that accesses this attribute keeps working.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: ``(batch, dim)`` tensor of embeddings (L2-normalised here).
            labels: ``batch`` class labels (tensor or sequence; any dtype that
                supports equality comparison).

        Returns:
            Scalar tensor. A constant zero (no gradient) when the batch contains
            no positive pair at all.
        """
        features = F.normalize(features, dim=1)
        batch_size = features.shape[0]
        labels = torch.as_tensor(labels, device=features.device).reshape(-1)

        # Temperature-scaled cosine similarity between every pair of samples.
        logits = torch.matmul(features, features.T) / self.temperature

        # Positives: same label, different sample.
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=features.device)
        pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask

        if not pos_mask.any():  # If no positive pairs, return 0
            return logits.new_zeros(())

        # Exclude the anchor itself from the softmax over candidates.
        logits = logits.masked_fill(self_mask, float("-inf"))
        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

        # Sum log-probabilities over positives only. masked_fill (rather than a
        # multiplication) keeps the -inf diagonal from turning into NaN.
        pos_counts = pos_mask.sum(dim=1)
        pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)

        # Average over positives per anchor, then over anchors that have positives.
        has_pos = pos_counts > 0
        per_anchor = -pos_log_prob[has_pos] / pos_counts[has_pos]
        return per_anchor.mean()

def create_data_loaders(embeddings_path, metadata_path, batch_size=32, test_size=0.2, random_state=42,
                        num_workers=4, pin_memory=True):
    """Create PyTorch DataLoaders for training and validation.

    Builds one ``EmbeddingsDataset`` and splits its row positions into a
    stratified train/validation partition, so both loaders share the same
    underlying dataset object.

    Args:
        embeddings_path: Directory containing the H5 embedding files.
        metadata_path: Directory containing ``master_metadata.parquet``.
        batch_size: Batch size used by both loaders.
        test_size: Validation share, forwarded to ``train_test_split``.
        random_state: Seed for the split.
        num_workers: Worker processes per loader (previously fixed at 4).
        pin_memory: Whether the loaders pin host memory (previously fixed at True).

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    from torch.utils.data import Subset

    full_dataset = EmbeddingsDataset(embeddings_path, metadata_path)

    # Take the label column in one shot instead of one .iloc lookup per row.
    stratify_labels = full_dataset.metadata['label'].to_numpy()

    train_indices, val_indices = train_test_split(
        range(len(full_dataset)),
        test_size=test_size,
        random_state=random_state,
        stratify=stratify_labels,
    )

    loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(Subset(full_dataset, train_indices), shuffle=True, **loader_options)
    val_loader = DataLoader(Subset(full_dataset, val_indices), shuffle=False, **loader_options)

    return train_loader, val_loader

def _snapshot_state_dict(model):
    """Return a deep copy of ``model.state_dict()`` that later optimizer steps cannot modify."""
    import copy
    return copy.deepcopy(model.state_dict())


def _log_validation_figures(writer, epoch, all_labels, all_probs, val_ap, tn, fp, fn, tp):
    """Render the PR curve and the confusion matrix for one epoch and add both to TensorBoard."""
    import matplotlib.pyplot as plt
    from sklearn.metrics import precision_recall_curve

    precision, recall, _ = precision_recall_curve(all_labels, all_probs)
    fig = plt.figure()
    plt.plot(recall, precision, marker='.', label=f'AP={val_ap:.3f}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'PR Curve - Epoch {epoch}')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    writer.add_figure(f'PR_Curve/epoch_{epoch}', fig, epoch)
    plt.close(fig)

    cells = [[tn, fp], [fn, tp]]
    fig = plt.figure(figsize=(8, 6))
    plt.imshow(cells, cmap='Blues', interpolation='nearest')
    plt.colorbar()
    plt.title(f'Confusion Matrix - Epoch {epoch}')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    # Switch the annotation colour on dark cells so the count stays readable.
    thresh = (tn + fp + fn + tp) / 2
    for i in range(2):
        for j in range(2):
            plt.text(j, i, cells[i][j], ha="center", va="center",
                     color="white" if cells[i][j] > thresh else "black")
    plt.xticks([0, 1], ['Negative', 'Positive'])
    plt.yticks([0, 1], ['Negative', 'Positive'])
    writer.add_figure(f'Confusion_Matrix/epoch_{epoch}', fig, epoch)
    plt.close(fig)


def train_model_with_logging(model, train_loader, val_loader, checkpoint_path:str,
                             device, writer, num_epochs=10, learning_rate=0.001,
                             temperature=0.07):
    """Train the cancer detection model with a classification + contrastive loss and TensorBoard logging.

    Each epoch runs one pass over ``train_loader`` (cross entropy plus
    ``0.5 *`` the contrastive loss whenever the batch holds both classes) and one
    pass over ``val_loader``. Validation ROC-AUC selects the best weights, which
    are saved to disk and loaded back into ``model`` before returning.

    Args:
        model: Module called as ``model(x)`` -> logits and
            ``model(x, return_features=True)`` -> ``(logits, features)``.
        train_loader: Loader yielding ``(embeddings, labels)`` batches for training.
        val_loader: Loader yielding ``(embeddings, labels)`` batches for validation.
        checkpoint_path: If this file exists, model/optimizer (and, when possible,
            scheduler) state are restored from it and epoch numbering continues
            from the stored ``'epoch'``.
            NOTE(review): this function never writes ``checkpoint_path``; the file
            must be produced elsewhere — confirm whether saving was intended here.
        device: Device that batches (and a loaded checkpoint) are moved to.
        writer: TensorBoard ``SummaryWriter``-like object.
        num_epochs: Number of epochs to run in this call.
        learning_rate: Initial Adam learning rate.
        temperature: Temperature of the contrastive loss.

    Returns:
        ``model`` with the best-AUC weights loaded (or its final weights when no
        epoch produced a usable AUC).
    """
    from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, confusion_matrix

    # Define loss functions and optimizer
    classification_criterion = nn.CrossEntropyLoss()
    contrastive_criterion = NTXentLoss(temperature)

    # Honour the caller's learning rate (it used to be hard-coded to 0.001).
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    # Cyclical cosine schedule with a warm restart every 5 epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5,           # Restart every 5 epochs
        T_mult=1,        # Keep the same cycle length
        eta_min=1e-5     # Minimum learning rate
    )

    start_epoch = 0
    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint saved on GPU be restored on CPU and vice versa.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']

        # Only load scheduler if it's compatible with our current scheduler
        try:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        except Exception as exc:
            print("Warning: Couldn't load scheduler state from checkpoint (may be different type)")
            print(f"  reason: {exc!r}")
            # Manually advance scheduler to match the epoch
            for _ in range(start_epoch):
                scheduler.step()

        print(f"Resuming training from epoch {start_epoch}")

    # Directory for best-model files: a module-level LOG_DIR if one exists (notebook
    # usage), otherwise the writer's own log directory, otherwise next to the checkpoint.
    save_dir = (
        globals().get("LOG_DIR")
        or getattr(writer, "log_dir", None)
        or os.path.dirname(checkpoint_path)
        or "."
    )
    os.makedirs(save_dir, exist_ok=True)

    # Training loop
    best_val_auc = 0.0
    best_model_weights = None
    total_epochs = start_epoch + num_epochs

    # Continue epoch numbering after a resume so TensorBoard steps and saved
    # file names do not collide with the earlier run.
    epoch_iterator = trange(start_epoch, total_epochs, desc="Epochs")
    for epoch in epoch_iterator:
        print("epoc: ", epoch)
        # Training phase
        model.train()
        train_loss = 0.0
        train_class_loss = 0.0
        train_ssl_loss = 0.0
        train_correct = 0
        train_total = 0

        batch_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{total_epochs}", leave=False)
        for batch_idx, (embeddings, labels) in enumerate(batch_iterator):
            embeddings, labels = embeddings.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with features
            outputs, features = model(embeddings, return_features=True)

            # Classification loss
            class_loss = classification_criterion(outputs, labels)

            # Contrastive loss needs samples from both classes in the batch.
            pos_indices = (labels == 1).nonzero(as_tuple=True)[0]
            neg_indices = (labels == 0).nonzero(as_tuple=True)[0]

            if len(pos_indices) > 0 and len(neg_indices) > 0:
                contrast_features = torch.cat([features[pos_indices], features[neg_indices]], dim=0)
                contrast_labels = torch.cat([
                    torch.ones(len(pos_indices), device=device),
                    torch.zeros(len(neg_indices), device=device)
                ])

                ssl_loss = contrastive_criterion(contrast_features, contrast_labels)

                # Total loss (weighted combination)
                loss = class_loss + 0.5 * ssl_loss
                train_ssl_loss += ssl_loss.item() * embeddings.size(0)
            else:
                # Single-class batch: classification loss only.
                loss = class_loss
                ssl_loss = torch.tensor(0.0)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Statistics
            train_loss += loss.item() * embeddings.size(0)
            train_class_loss += class_loss.item() * embeddings.size(0)
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            batch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})

            # Log batch-level metrics (every 10 batches)
            if batch_idx % 10 == 0:
                global_step = epoch * len(train_loader) + batch_idx
                writer.add_scalar('Batch/Loss/train', loss.item(), global_step)
                writer.add_scalar('Batch/ClassLoss/train', class_loss.item(), global_step)
                writer.add_scalar('Batch/SSLLoss/train', ssl_loss.item(), global_step)

                # Parameter histograms are heavier, so log them every 50 batches only.
                if batch_idx % 50 == 0:
                    for name, param in model.named_parameters():
                        if param.requires_grad:
                            writer.add_histogram(f'Parameters/{name}', param.data, global_step)

        train_loss = train_loss / len(train_loader.dataset)
        train_class_loss = train_class_loss / len(train_loader.dataset)
        train_ssl_loss = train_ssl_loss / len(train_loader.dataset)
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        all_labels = []
        all_probs = []
        all_preds = []

        val_iterator = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{total_epochs}", leave=False)
        with torch.no_grad():
            for embeddings, labels in val_iterator:
                embeddings, labels = embeddings.to(device), labels.to(device)

                outputs = model(embeddings)
                loss = classification_criterion(outputs, labels)

                val_loss += loss.item() * embeddings.size(0)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

                # Store for metrics calculation; column 1 is the positive-class probability.
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())
                probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
                all_probs.extend(probs)

                val_iterator.set_postfix({"val_loss": f"{loss.item():.4f}"})

        val_loss = val_loss / len(val_loader.dataset)
        val_acc = val_correct / val_total

        # ROC-AUC is undefined with a single class present; report NaN instead of crashing.
        if len(set(all_labels)) > 1:
            val_auc = roc_auc_score(all_labels, all_probs)
        else:
            val_auc = float('nan')
        val_ap = average_precision_score(all_labels, all_probs)
        val_f1 = f1_score(all_labels, all_preds)
        # labels=[0, 1] guarantees a 2x2 matrix even if a class is missing this epoch.
        tn, fp, fn, tp = confusion_matrix(all_labels, all_preds, labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

        # Advance the cosine schedule by one epoch. The metric must NOT be passed:
        # this scheduler's step() argument is an epoch number, not a score.
        scheduler.step()

        # Save best model (deep copy, so later training cannot overwrite the snapshot).
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_model_weights = _snapshot_state_dict(model)
            torch.save(best_model_weights, f"{save_dir}/best_model_epoch_{epoch}.pth")
            writer.add_text('Training', f'New best model saved at epoch {epoch} with AUC {best_val_auc:.4f}', epoch)

        epoch_iterator.set_postfix({
            "train_loss": f"{train_loss:.4f}",
            "train_acc": f"{train_acc:.4f}",
            "val_loss": f"{val_loss:.4f}",
            "val_acc": f"{val_acc:.4f}",
            "val_auc": f"{val_auc:.4f}"
        })

        # Log epoch-level metrics to TensorBoard
        writer.add_scalar('Epoch/Loss/train', train_loss, epoch)
        writer.add_scalar('Epoch/Loss/val', val_loss, epoch)
        writer.add_scalar('Epoch/ClassLoss/train', train_class_loss, epoch)
        writer.add_scalar('Epoch/SSLLoss/train', train_ssl_loss, epoch)
        writer.add_scalar('Epoch/Accuracy/train', train_acc, epoch)
        writer.add_scalar('Epoch/Accuracy/val', val_acc, epoch)
        writer.add_scalar('Epoch/AUC/val', val_auc, epoch)
        writer.add_scalar('Epoch/AP/val', val_ap, epoch)
        writer.add_scalar('Epoch/F1/val', val_f1, epoch)
        writer.add_scalar('Epoch/Sensitivity/val', sensitivity, epoch)
        writer.add_scalar('Epoch/Specificity/val', specificity, epoch)
        writer.add_scalar('Epoch/LearningRate', optimizer.param_groups[0]['lr'], epoch)

        # Log PR curve and confusion matrix every other epoch
        if epoch % 2 == 0:
            _log_validation_figures(writer, epoch, all_labels, all_probs, val_ap, tn, fp, fn, tp)

    # Log model graph with sample input
    try:
        sample_input = torch.rand(1, 1, 384, 1).to(device)
        writer.add_graph(model, sample_input)
    except Exception as e:
        print(f"Couldn't add model graph to TensorBoard: {e}")

    # Load best model weights (none exist if no epoch ran or no AUC beat the initial 0.0).
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)
    return model

Writing utils.py


In [4]:
import os
import h5py
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from monai.networks.nets import SegResNet
from tqdm.notebook import tqdm, trange

class EmbeddingsDataset(Dataset):
    """Dataset of stored embeddings, indexed by a parquet metadata table.

    ``master_metadata.parquet`` (inside ``metadata_path``) is expected to carry the
    columns ``file_id``, ``embedding_batch``, ``embedding_index`` and ``label``.
    Each embedding lives in ``<embeddings_path>/<embedding_batch>.h5`` under the
    ``'embeddings'`` dataset, at row ``embedding_index``.

    Only rows with a non-null ``label`` are served through ``__len__`` /
    ``__getitem__``; ``get_embedding`` searches the full table.
    """

    # Length of one stored embedding vector; items are reshaped to (1, EMBEDDING_DIM, 1).
    EMBEDDING_DIM = 384

    def __init__(self, embeddings_path, metadata_path, transform=None):
        """
        Initialize the dataset

        Args:
            embeddings_path: Path to the directory containing H5 embedding files
            metadata_path: Path to the directory containing metadata files
            transform: Optional callable applied to the reshaped embedding tensor
        """
        self.embeddings_path = embeddings_path
        self.metadata_path = metadata_path
        self.transform = transform
        self.master_metadata = pd.read_parquet(os.path.join(metadata_path, "master_metadata.parquet"))
        # Limit to data with labels
        self.metadata = self.master_metadata.dropna(subset=['label'])

    def __len__(self):
        """Number of labelled rows."""
        return len(self.metadata)

    def _read_embedding(self, batch_name, embedding_index):
        """Read one embedding row from the H5 file of ``batch_name``.

        The index is coerced with ``int()``: a parquet column that contained nulls is
        loaded as float64, and a float such as ``1.0`` is not a valid HDF5/NumPy index.
        The file is opened and closed on every call.
        """
        h5_path = os.path.join(self.embeddings_path, f"{batch_name}.h5")
        with h5py.File(h5_path, 'r') as h5f:
            return h5f['embeddings'][int(embedding_index)]

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` for the ``idx``-th labelled row.

        Returns:
            embedding: float32 tensor of shape ``(1, EMBEDDING_DIM, 1)``, i.e.
                (channel, height, width) for a 2-D network once batched; passed
                through ``transform`` when one was given.
            label: scalar ``torch.long`` tensor (0=negative, 1=positive).
        """
        row = self.metadata.iloc[idx]
        raw_embedding = self._read_embedding(row['embedding_batch'], row['embedding_index'])

        # (EMBEDDING_DIM,) -> (1, EMBEDDING_DIM, 1); the DataLoader adds the batch axis.
        embedding = torch.tensor(raw_embedding, dtype=torch.float32)
        embedding = embedding.view(1, self.EMBEDDING_DIM, 1)

        # Labels may arrive as floats (the column is nullable); store them as class indices.
        label = torch.tensor(int(row['label']), dtype=torch.long)

        if self.transform:
            embedding = self.transform(embedding)

        return embedding, label

    def get_embedding(self, file_id):
        """Return ``(embedding, label)`` for the first metadata row matching ``file_id``.

        The lookup uses the full metadata table, so the label may be NaN for
        unlabelled files; it is ``None`` when the table has no ``label`` column.
        The embedding is returned as read from the H5 file (not converted to a tensor).

        Raises:
            ValueError: if ``file_id`` is not present in the metadata.
        """
        file_info = self.master_metadata[self.master_metadata['file_id'] == file_id]

        if len(file_info) == 0:
            raise ValueError(f"File ID {file_id} not found in metadata")

        embedding = self._read_embedding(
            file_info['embedding_batch'].iloc[0],
            file_info['embedding_index'].iloc[0],
        )

        label = file_info['label'].iloc[0] if 'label' in file_info.columns else None
        return embedding, label

class SelfSupervisedHead(nn.Module):
    """Classification head that also exposes projected features for a contrastive loss.

    A 1x1 convolution + batch norm + ReLU lifts the incoming feature map to 128
    channels, global average pooling collapses the spatial axes, and a two-layer
    MLP projector yields a 128-d vector. That vector is both returned (for the
    contrastive objective) and fed to the final linear classifier.
    """
    def __init__(self, in_channels, num_classes=2):
        super().__init__()
        width = 128

        # Convolutional stem followed by spatial pooling.
        self.conv = nn.Conv2d(in_channels, width, kernel_size=1)
        self.bn = nn.BatchNorm2d(width)
        self.relu = nn.ReLU(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Projector MLP: width -> 256 -> width.
        projector_layers = [
            nn.Linear(width, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, width),
        ]
        self.projector = nn.Sequential(*projector_layers)

        # Linear classifier on top of the projected features.
        self.fc = nn.Linear(width, num_classes)

    def forward(self, x):
        """Return ``(logits, features)`` for a batch of feature maps ``x``."""
        pooled = self.global_pool(self.relu(self.bn(self.conv(x))))
        flat = pooled.view(pooled.size(0), -1)

        features = self.projector(flat)
        return self.fc(features), features

class SelfSupervisedCancerModel(nn.Module):
    """MONAI ``SegResNet`` backbone followed by a ``SelfSupervisedHead``.

    The backbone is built for 2-D, single-channel input and configured with two
    output channels; that output map is handed to the head, which produces class
    logits and projected features.
    """
    def __init__(self, num_classes=2):
        super().__init__()

        # The backbone's final layer is configured to emit this many channels,
        # which is therefore the head's input width.
        backbone_out_channels = 2

        self.backbone = SegResNet(
            spatial_dims=2,
            in_channels=1,
            out_channels=backbone_out_channels,
            blocks_down=[3, 4, 23, 3],
            blocks_up=[3, 6, 3],
            upsample_mode="deconv",
            init_filters=32,
        )

        self.ssl_head = SelfSupervisedHead(backbone_out_channels, num_classes)

        # Neutralise a built-in classifier if this backbone exposes one.
        # NOTE(review): presumably a no-op for SegResNet — confirm the attribute exists.
        if hasattr(self.backbone, 'class_layers'):
            self.backbone.class_layers = nn.Identity()

    def forward(self, x, return_features=False):
        """Return class logits, or ``(logits, projected_features)`` when ``return_features`` is true."""
        logits, projected = self.ssl_head(self.backbone(x))
        return (logits, projected) if return_features else logits

# NTXent Loss for contrastive learning
class NTXentLoss(nn.Module):
    """
    Label-aware NT-Xent (temperature-scaled cross entropy) contrastive loss.

    Samples sharing a label are positives for one another. For every anchor that
    has at least one positive, the loss is the negative log-softmax of the
    temperature-scaled cosine similarities, averaged over that anchor's positives;
    the result is the mean over those anchors.

    Compared with a plain "first positive as target" cross entropy this:
      * removes the anchor's self-similarity from the softmax denominator
        (masked with -inf instead of being multiplied by zero, which would leave
        a spurious logit of 0 in the denominator);
      * skips anchors without any positive instead of giving them target index 0;
      * uses every positive rather than an arbitrary first one;
      * builds the pair mask with tensor ops instead of O(n^2) Python loops.

    Args:
        temperature: Divisor applied to the cosine similarities.
    """
    def __init__(self, temperature=0.07):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        # Unused by forward(); kept so code that accesses this attribute keeps working.
        self.criterion = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: ``(batch, dim)`` tensor of embeddings (L2-normalised here).
            labels: ``batch`` class labels (tensor or sequence; any dtype that
                supports equality comparison).

        Returns:
            Scalar tensor. A constant zero (no gradient) when the batch contains
            no positive pair at all.
        """
        features = F.normalize(features, dim=1)
        batch_size = features.shape[0]
        labels = torch.as_tensor(labels, device=features.device).reshape(-1)

        # Temperature-scaled cosine similarity between every pair of samples.
        logits = torch.matmul(features, features.T) / self.temperature

        # Positives: same label, different sample.
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=features.device)
        pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask

        if not pos_mask.any():  # If no positive pairs, return 0
            return logits.new_zeros(())

        # Exclude the anchor itself from the softmax over candidates.
        logits = logits.masked_fill(self_mask, float("-inf"))
        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

        # Sum log-probabilities over positives only. masked_fill (rather than a
        # multiplication) keeps the -inf diagonal from turning into NaN.
        pos_counts = pos_mask.sum(dim=1)
        pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)

        # Average over positives per anchor, then over anchors that have positives.
        has_pos = pos_counts > 0
        per_anchor = -pos_log_prob[has_pos] / pos_counts[has_pos]
        return per_anchor.mean()

def create_data_loaders(embeddings_path, metadata_path, batch_size=32, test_size=0.2, random_state=42,
                        num_workers=4, pin_memory=True):
    """Create PyTorch DataLoaders for training and validation.

    Builds one ``EmbeddingsDataset`` and splits its row positions into a
    stratified train/validation partition, so both loaders share the same
    underlying dataset object.

    Args:
        embeddings_path: Directory containing the H5 embedding files.
        metadata_path: Directory containing ``master_metadata.parquet``.
        batch_size: Batch size used by both loaders.
        test_size: Validation share, forwarded to ``train_test_split``.
        random_state: Seed for the split.
        num_workers: Worker processes per loader (previously fixed at 4).
        pin_memory: Whether the loaders pin host memory (previously fixed at True).

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    from torch.utils.data import Subset

    full_dataset = EmbeddingsDataset(embeddings_path, metadata_path)

    # Take the label column in one shot instead of one .iloc lookup per row.
    stratify_labels = full_dataset.metadata['label'].to_numpy()

    train_indices, val_indices = train_test_split(
        range(len(full_dataset)),
        test_size=test_size,
        random_state=random_state,
        stratify=stratify_labels,
    )

    loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(Subset(full_dataset, train_indices), shuffle=True, **loader_options)
    val_loader = DataLoader(Subset(full_dataset, val_indices), shuffle=False, **loader_options)

    return train_loader, val_loader

def _log_validation_figures(writer, epoch, all_labels, all_probs, val_ap, tn, fp, fn, tp):
    """Attach a precision-recall curve and a 2x2 confusion-matrix figure to TensorBoard."""
    import matplotlib.pyplot as plt
    from sklearn.metrics import precision_recall_curve

    # Precision-recall curve
    precision, recall, _ = precision_recall_curve(all_labels, all_probs)
    fig = plt.figure()
    plt.plot(recall, precision, marker='.', label=f'AP={val_ap:.3f}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'PR Curve - Epoch {epoch}')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    writer.add_figure(f'PR_Curve/epoch_{epoch}', fig, epoch)
    plt.close(fig)  # figures are created every other epoch; close to avoid leaking them

    # Confusion matrix (rows = actual, columns = predicted)
    counts = [[tn, fp], [fn, tp]]
    fig = plt.figure(figsize=(8, 6))
    plt.imshow(counts, cmap='Blues', interpolation='nearest')
    plt.colorbar()
    plt.title(f'Confusion Matrix - Epoch {epoch}')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    # Cells holding more than half of all samples get white text for contrast.
    thresh = (tn + fp + fn + tp) / 2
    for i in range(2):
        for j in range(2):
            plt.text(j, i, counts[i][j],
                     ha="center", va="center",
                     color="white" if counts[i][j] > thresh else "black")
    plt.xticks([0, 1], ['Negative', 'Positive'])
    plt.yticks([0, 1], ['Negative', 'Positive'])
    writer.add_figure(f'Confusion_Matrix/epoch_{epoch}', fig, epoch)
    plt.close(fig)


def train_model_with_logging(model, train_loader, val_loader, checkpoint_path:str,
                             device, writer, num_epochs=10, learning_rate=0.001,
                             temperature=0.07):
    """Train the model with a classification + contrastive loss and log to TensorBoard.

    Each training batch minimises ``CrossEntropyLoss`` on the model outputs and,
    when the batch contains both labels 0 and 1, adds ``0.5 * NTXentLoss`` computed
    on the features returned by ``model(embeddings, return_features=True)``.
    After every epoch the validation set is scored and the weights with the
    highest validation ROC-AUC are snapshotted and saved.

    Args:
        model: Network called as ``model(x)`` and ``model(x, return_features=True)``.
        train_loader: DataLoader yielding ``(embeddings, labels)`` training batches.
        val_loader: DataLoader yielding ``(embeddings, labels)`` validation batches.
        checkpoint_path: Optional checkpoint to resume from. If the file exists,
            model / optimizer / scheduler state, the start epoch and (if present)
            ``best_val_auc`` are restored from it. This function only reads the
            checkpoint; it never writes to ``checkpoint_path``.
        device: Device the batches (and the checkpoint tensors) are moved to.
        writer: TensorBoard ``SummaryWriter``-like object. Best-model files are
            written into ``writer.log_dir`` when that attribute is available,
            otherwise next to ``checkpoint_path``.
        num_epochs: Index of the last epoch (exclusive). When resuming, training
            continues from the checkpoint's epoch up to this value.
        learning_rate: Initial learning rate for Adam.
        temperature: Temperature passed to ``NTXentLoss``.

    Returns:
        The same ``model`` object, with the best-AUC weights loaded if at least
        one epoch improved on the starting best AUC; otherwise left as trained.
    """
    import copy
    from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, confusion_matrix

    # Define loss functions and optimizer
    classification_criterion = nn.CrossEntropyLoss()
    contrastive_criterion = NTXentLoss(temperature)  # Contrastive loss for self-supervision

    # Honour the caller's learning rate (it used to be silently replaced by 0.001).
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    # Cyclical learning rate with warm restarts.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=5,           # Restart every 5 epochs
        T_mult=1,        # Keep the same cycle length
        eta_min=1e-5     # Minimum learning rate
    )

    # Defaults for a fresh run; overwritten below when a checkpoint is found.
    start_epoch = 0
    best_val_auc = 0.0
    best_model_weights = None

    if os.path.exists(checkpoint_path):
        # map_location lets a checkpoint saved on GPU be resumed on CPU and vice versa.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Only load scheduler if it's compatible with our current scheduler
        try:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        except Exception as exc:
            print(f"Warning: Couldn't load scheduler state from checkpoint (may be different type): {exc}")
            # Manually advance scheduler to match the epoch
            for _ in range(checkpoint['epoch']):
                scheduler.step()

        start_epoch = checkpoint['epoch']
        best_val_auc = checkpoint.get('best_val_auc', best_val_auc)
        print(f"Resuming training from epoch {start_epoch}")

    # Directory for best-model snapshots: the writer's log directory when it
    # exposes one, otherwise the checkpoint's directory (or the CWD).
    output_dir = getattr(writer, 'log_dir', None) or os.path.dirname(checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)

    # Training loop; starts at the restored epoch when resuming.
    epoch_iterator = trange(start_epoch, num_epochs, desc="Epochs")
    for epoch in epoch_iterator:
        # Training phase
        model.train()
        train_loss = 0.0
        train_class_loss = 0.0
        train_ssl_loss = 0.0
        train_correct = 0
        train_total = 0

        # Use tqdm for batch progress in training
        batch_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", leave=False)
        for batch_idx, (embeddings, labels) in enumerate(batch_iterator):
            embeddings, labels = embeddings.to(device), labels.to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass with features
            outputs, features = model(embeddings, return_features=True)

            # Classification loss
            class_loss = classification_criterion(outputs, labels)

            # Supervised-contrastive term: samples sharing a label are positives.
            pos_indices = (labels == 1).nonzero(as_tuple=True)[0]
            neg_indices = (labels == 0).nonzero(as_tuple=True)[0]

            # Only compute contrastive loss if we have samples from both classes
            if len(pos_indices) > 0 and len(neg_indices) > 0:
                pos_features = features[pos_indices]
                neg_features = features[neg_indices]

                # Positives first, then negatives, with matching 1/0 labels.
                all_features = torch.cat([pos_features, neg_features], dim=0)
                all_labels = torch.cat([
                    torch.ones(len(pos_indices), device=device),
                    torch.zeros(len(neg_indices), device=device)
                ])

                # Compute contrastive loss
                ssl_loss = contrastive_criterion(all_features, all_labels)

                # Total loss (weighted combination)
                loss = class_loss + 0.5 * ssl_loss
                train_ssl_loss += ssl_loss.item() * embeddings.size(0)
            else:
                # Single-class batch: fall back to the classification loss alone.
                loss = class_loss
                ssl_loss = torch.tensor(0.0)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Statistics (losses are weighted by batch size for a per-sample mean)
            train_loss += loss.item() * embeddings.size(0)
            train_class_loss += class_loss.item() * embeddings.size(0)
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            # Update progress bar with current loss
            batch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})

            # Log batch-level metrics (every 10 batches)
            if batch_idx % 10 == 0:
                global_step = epoch * len(train_loader) + batch_idx
                writer.add_scalar('Batch/Loss/train', loss.item(), global_step)
                writer.add_scalar('Batch/ClassLoss/train', class_loss.item(), global_step)
                writer.add_scalar('Batch/SSLLoss/train', ssl_loss.item(), global_step)

                # Add histograms of model parameters
                if batch_idx % 50 == 0:
                    for name, param in model.named_parameters():
                        if param.requires_grad:
                            writer.add_histogram(f'Parameters/{name}', param.data, global_step)

        train_loss = train_loss / len(train_loader.dataset)
        train_class_loss = train_class_loss / len(train_loader.dataset)
        train_ssl_loss = train_ssl_loss / len(train_loader.dataset)
        train_acc = train_correct / train_total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        all_labels = []
        all_probs = []
        all_preds = []

        # Use tqdm for batch progress in validation
        val_iterator = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}", leave=False)
        with torch.no_grad():
            for embeddings, labels in val_iterator:
                embeddings, labels = embeddings.to(device), labels.to(device)

                # Forward pass
                outputs = model(embeddings)
                loss = classification_criterion(outputs, labels)

                # Statistics
                val_loss += loss.item() * embeddings.size(0)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

                # Store for metrics calculation; column 1 is the positive-class probability.
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())
                probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
                all_probs.extend(probs)

                # Update progress bar
                val_iterator.set_postfix({"val_loss": f"{loss.item():.4f}"})

        val_loss = val_loss / len(val_loader.dataset)
        val_acc = val_correct / val_total

        # Calculate AUC and other metrics
        val_auc = roc_auc_score(all_labels, all_probs)
        val_ap = average_precision_score(all_labels, all_probs)
        val_f1 = f1_score(all_labels, all_preds)
        # labels=[0, 1] forces a 2x2 matrix even if one class is missing from
        # the labels/predictions, so the four-way unpacking cannot fail.
        tn, fp, fn, tp = confusion_matrix(all_labels, all_preds, labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

        # Advance the cosine schedule by one epoch. The scheduler's optional
        # argument is an epoch index, not a metric, so the AUC must not be passed.
        scheduler.step()

        # Save best model
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            # Deep copy: state_dict() tensors share storage with the live
            # parameters, so a shallow copy would keep changing during training.
            best_model_weights = copy.deepcopy(model.state_dict())
            torch.save(best_model_weights, os.path.join(output_dir, f"best_model_epoch_{epoch}.pth"))
            writer.add_text('Training', f'New best model saved at epoch {epoch} with AUC {best_val_auc:.4f}', epoch)

        # Update epoch progress bar with metrics
        epoch_iterator.set_postfix({
            "train_loss": f"{train_loss:.4f}",
            "train_acc": f"{train_acc:.4f}",
            "val_loss": f"{val_loss:.4f}",
            "val_acc": f"{val_acc:.4f}",
            "val_auc": f"{val_auc:.4f}"
        })

        # Log epoch-level metrics to TensorBoard
        writer.add_scalar('Epoch/Loss/train', train_loss, epoch)
        writer.add_scalar('Epoch/Loss/val', val_loss, epoch)
        writer.add_scalar('Epoch/ClassLoss/train', train_class_loss, epoch)
        writer.add_scalar('Epoch/SSLLoss/train', train_ssl_loss, epoch)
        writer.add_scalar('Epoch/Accuracy/train', train_acc, epoch)
        writer.add_scalar('Epoch/Accuracy/val', val_acc, epoch)
        writer.add_scalar('Epoch/AUC/val', val_auc, epoch)
        writer.add_scalar('Epoch/AP/val', val_ap, epoch)
        writer.add_scalar('Epoch/F1/val', val_f1, epoch)
        writer.add_scalar('Epoch/Sensitivity/val', sensitivity, epoch)
        writer.add_scalar('Epoch/Specificity/val', specificity, epoch)
        writer.add_scalar('Epoch/LearningRate', optimizer.param_groups[0]['lr'], epoch)

        # Log PR curve and confusion matrix (every second epoch)
        if epoch % 2 == 0:
            _log_validation_figures(writer, epoch, all_labels, all_probs, val_ap, tn, fp, fn, tp)

    # Log model graph with sample input (best effort; tracing may fail)
    try:
        sample_input = torch.rand(1, 1, 384, 1).to(device)
        writer.add_graph(model, sample_input)
    except Exception as e:
        print(f"Couldn't add model graph to TensorBoard: {e}")

    # Load best model weights, if any epoch improved on the starting best AUC.
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)
    return model

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [5]:
# # First, save this code as a new cell to run after interrupting your current training
# # Run this immediately after your training is interrupted at epoch 4

# import torch

# # 1. Create a checkpoint from your current model state
# def create_checkpoint_from_model(model_path, output_checkpoint_path, current_epoch=4):
#     """
#     Create a checkpoint file from an existing .pth model file
#     """
#     # Load the model state dict
#     model_state_dict = torch.load(model_path)
    
#     # Create model
#     model = SelfSupervisedCancerModel()
#     model.load_state_dict(model_state_dict)
    
#     # Create a new optimizer and scheduler (we'll have to reinitialize these)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
#         optimizer, 
#         T_0=5,           # Restart every 5 epochs
#         T_mult=1,        # Keep the same cycle length
#         eta_min=1e-6     # Minimum learning rate
#     )
    
#     # If you've already trained for 4 epochs, advance the scheduler
#     for _ in range(current_epoch):
#         scheduler.step()
    
#     # Create checkpoint dictionary
#     checkpoint = {
#         'epoch': current_epoch,
#         'model_state_dict': model_state_dict,
#         'optimizer_state_dict': optimizer.state_dict(),
#         'scheduler_state_dict': scheduler.state_dict(),
#         'best_val_loss': float('inf'),  # We don't know this, so use a default
#         'train_loss': 0.0,              # We don't have this info
#         'val_loss': 0.0,                # We don't have this info
#         'accuracy': 0.0                 # We don't have this info
#     }
    
#     # Save the checkpoint
#     torch.save(checkpoint, output_checkpoint_path)
#     print(f"Created checkpoint at {output_checkpoint_path} from model at {model_path}")
    
#     return checkpoint

# # Create a checkpoint from your saved model
# model_path = "/kaggle/input/histopathology-cancer-classify/pytorch/midsize_model/1/best_model_epoch_2.pth"  # Your current saved model
# checkpoint_path = "/kaggle/working/checkpoint.pth"
# checkpoint = create_checkpoint_from_model(model_path, checkpoint_path, current_epoch=2)

# print("Checkpoint created! Download this file to resume training later.")
# print("To resume training, upload this checkpoint file and run the updated training code.")

In [6]:
# if __name__ == "__main__":
#     # Import TensorBoard SummaryWriter
#     from torch.utils.tensorboard import SummaryWriter
    
#     # Define paths
#     EMBEDDINGS_PATH = "/kaggle/input/histopath-cancer-embeddings/embeddings/"
#     METADATA_PATH = "/kaggle/input/histopath-cancer-embeddings/metadata/"
#     LOG_DIR = "/kaggle/working/tensorboard_logs"
    
#     # Create TensorBoard writer
#     writer = SummaryWriter(log_dir=LOG_DIR)
#     print(f"TensorBoard logs will be saved to: {LOG_DIR}")
    
#     # Set device
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     print(f"Using device: {device}")
    
#     # Create data loaders
#     train_loader, val_loader = create_data_loaders(
#         EMBEDDINGS_PATH, 
#         METADATA_PATH,
#         batch_size=512
#     )
#     print("data loaded")

#     init_channels = 32
#     # Create model
#     model = SelfSupervisedCancerModel()
#     model = model.to(device)
    
#     # Create a modified train_model function with TensorBoard logging

#     checkpoint_path = "/kaggle/working/checkpoint.pth"
#     # Train model with self-supervised learning and TensorBoard logging
#     trained_model = train_model_with_logging(
#         model=model,
#         train_loader=train_loader,
#         val_loader=val_loader,
#         device=device,
#         writer=writer,
#         num_epochs=3,
#         temperature=0.2,  # Temperature for contrastive loss
#         checkpoint_path="/kaggle/working/checkpoint.pth"
#     )
    
#     # Add embedding visualization to TensorBoard
#     print("\nAdding embeddings to TensorBoard...")
#     # Get a sample of embeddings
#     try:
#         # Create dataset for visualization
#         vis_dataset = EmbeddingsDataset(EMBEDDINGS_PATH, METADATA_PATH)
#         vis_loader = DataLoader(vis_dataset, batch_size=128, shuffle=False)
        
#         # Extract features using model
#         all_features = []
#         all_labels = []
#         all_images = []  # We don't have actual images, so we'll create simple representations
        
#         model.eval()
#         with torch.no_grad():
#             for i, (embeddings, labels) in enumerate(vis_loader):
#                 if i >= 10:  # Limit to 10 batches for visualization
#                     break
                    
#                 embeddings = embeddings.to(device)
#                 _, features = model(embeddings, return_features=True)
#                 all_features.append(features.cpu())
#                 all_labels.extend(labels.numpy())
                
#                 # Create simple image representations (1=white, 0=black)
#                 for label in labels:
#                     img = torch.ones(3, 32, 32) if label == 1 else torch.zeros(3, 32, 32)
#                     all_images.append(img)
        
#         # Concatenate features
#         if all_features:
#             all_features = torch.cat(all_features)
            
#             # Convert labels to strings for visualization
#             label_names = ['Negative' if l == 0 else 'Positive' for l in all_labels]
            
#             # Add embeddings to TensorBoard
#             writer.add_embedding(
#                 all_features, 
#                 metadata=label_names,
#                 label_img=torch.stack(all_images) if all_images else None,
#                 global_step=0
#             )
            
#     except Exception as e:
#         print(f"Error adding embeddings to TensorBoard: {e}")
    
#     # Save the model
#     torch.save(trained_model.state_dict(), "/kaggle/working/cancer_detector_model.pth")
#     print("Model saved to /kaggle/working/cancer_detector_model.pth")
    
#     # Close TensorBoard writer
#     writer.close()
#     print("TensorBoard logging complete")
    
#     # Example of inference and feature extraction
#     print("\nExample of model inference:")
#     try:
#         # Load dataset
#         dataset = EmbeddingsDataset(EMBEDDINGS_PATH, METADATA_PATH)
        
#         # Get a sample embedding
#         sample_id = dataset.master_metadata['file_id'].iloc[0]
#         embedding_data, true_label = dataset.get_embedding(sample_id)
        
#         # Prepare for inference
#         embedding_tensor = torch.tensor(embedding_data, dtype=torch.float32).view(1, 1, 384, 1).to(device)
        
#         # Run inference
#         model.eval()
#         with torch.no_grad():
#             # Get both output and features for visualization
#             output, features = model(embedding_tensor, return_features=True)
#             probs = torch.softmax(output, dim=1)
#             predicted_class = torch.argmax(probs, dim=1).item()
#             confidence = probs[0][predicted_class].item()
        
#         print(f"Sample ID: {sample_id}")
#         print(f"True label: {true_label} ({'Positive' if true_label == 1 else 'Negative'})")
#         print(f"Predicted class: {predicted_class} ({'Positive' if predicted_class == 1 else 'Negative'})")
#         print(f"Confidence: {confidence:.4f}")
        
#         # Extract features for all samples for visualization/clustering
#         print("\nExtracting features for all samples...")
#         all_features = []
#         all_labels = []
        
#         eval_dataset = EmbeddingsDataset(EMBEDDINGS_PATH, METADATA_PATH)
#         eval_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)
        
#         with torch.no_grad():
#             for embeddings, labels in eval_loader:
#                 embeddings = embeddings.to(device)
#                 _, batch_features = model(embeddings, return_features=True)
#                 all_features.append(batch_features.cpu().numpy())
#                 all_labels.append(labels.cpu().numpy())
        
#         all_features = np.vstack(all_features)
#         all_labels = np.concatenate(all_labels)
        
#         # Save features for downstream visualization/analysis
#         np.savez("/kaggle/working/cancer_features.npz", 
#                  features=all_features, 
#                  labels=all_labels)
        
#         print(f"Extracted features shape: {all_features.shape}")
#         print(f"Features saved to /kaggle/working/cancer_features.npz")
        
        
#     except Exception as e:
#         print(f"Error during inference: {e}")

## Distributed Training

In [7]:
%%writefile distributed_training_addon.py

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from utils import *

def init_distributed():
    """Initialize the torch.distributed process group from launcher environment variables.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment (these
    are only set when launched through torch.distributed.launch / torchrun) and
    falls back to a single-process configuration when they are absent.

    Returns:
        tuple: ``(is_distributed, rank, local_rank, world_size)``. When
        ``WORLD_SIZE`` is 1 or unset this is always ``(False, 0, 0, 1)`` and no
        process group is created.
    """
    # Rendezvous details (master address/port) come from the environment.
    dist_url = "env://"  # default

    # only works with torch.distributed.launch // torch.run
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get('WORLD_SIZE', "1"))
    local_rank = int(os.environ.get('LOCAL_RANK', "0"))

    if world_size <= 1:
        print("Not running in distributed mode")
        return False, 0, 0, 1

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Select this process's GPU before creating the process group so that
        # NCCL and every later .cuda() call use the right device.
        torch.cuda.set_device(local_rank)

    # Guard against re-running the cell: initializing the default group twice raises.
    if not dist.is_initialized():
        dist.init_process_group(
            # NCCL needs CUDA; fall back to gloo for CPU-only launches.
            backend="nccl" if use_cuda else "gloo",
            init_method=dist_url,
            world_size=world_size,
            rank=rank)

    # synchronizes all the processes to reach this point before moving on
    dist.barrier()
    return True, rank, local_rank, world_size


def create_distributed_data_loaders(embeddings_path, metadata_path, rank, world_size, batch_size=256, test_size=0.2, random_state=42):
    """Build per-process training and validation DataLoaders for distributed runs.

    The dataset is split once with a label-stratified ``train_test_split``; each
    part is then wrapped in a ``DistributedSampler`` configured with the given
    ``rank`` and ``world_size``.

    Args:
        embeddings_path: Forwarded to ``EmbeddingsDataset`` as its first argument.
        metadata_path: Forwarded to ``EmbeddingsDataset`` as its second argument.
        rank: Rank of the calling process, passed to both samplers.
        world_size: Number of replicas, passed to both samplers.
        batch_size: Batch size used by both loaders.
        test_size: Passed to ``train_test_split`` (size of the validation part).
        random_state: Seed passed to ``train_test_split``.

    Returns:
        tuple: ``(train_loader, val_loader, train_sampler)``. The training sampler
        is returned so the caller can call ``set_epoch`` on it.
    """
    from torch.utils.data import Subset

    dataset = EmbeddingsDataset(embeddings_path, metadata_path)
    positions = range(len(dataset))

    train_idx, val_idx = train_test_split(
        positions,
        test_size=test_size,
        random_state=random_state,
        # One label per dataset position, read from the dataset's metadata table.
        stratify=[dataset.metadata['label'].iloc[pos] for pos in positions],
    )

    def _sharded_loader(indices, shuffle_shard):
        """Wrap one split in a Subset, a DistributedSampler and a DataLoader."""
        subset = Subset(dataset, indices)
        sampler = DistributedSampler(
            subset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle_shard,
        )
        loader = DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=False,  # ordering is delegated to the sampler
            num_workers=4,
            pin_memory=True,
            sampler=sampler,
        )
        return loader, sampler

    # Only the training shard is shuffled.
    train_loader, train_sampler = _sharded_loader(train_idx, True)
    val_loader, _ = _sharded_loader(val_idx, False)

    return train_loader, val_loader, train_sampler


def train_distributed_model(model, train_loader, val_loader, train_sampler, checkpoint_path, checkpoint_dir,
                            device, writer, rank, local_rank, is_distributed,
                            num_epochs=10, learning_rate=0.001, temperature=0.07):
    """Distributed training of the cancer detection model"""
    
    # Wrap model with DistributedDataParallel if using distributed training
    if is_distributed:
        # Convert BatchNorm to SyncBatchNorm for better stats across GPUs
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = DistributedDataParallel(model, device_ids=[local_rank])
        model_without_ddp = model.module
    else:
        model_without_ddp = model
    
    # Define loss functions and optimizer
    classification_criterion = nn.CrossEntropyLoss()
    contrastive_criterion = NTXentLoss(temperature)  # Contrastive loss for self-supervision
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    
    # Improved learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, 
        T_0=5,           # Restart every 5 epochs
        T_mult=1,        # Keep the same cycle length
        eta_min=1e-5     # Minimum learning rate
    )

    # Resume from checkpoint if exists
    start_epoch = 0
    best_val_auc = 0.0
    
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        if is_distributed:
            model_without_ddp.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
            
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        # Only load scheduler if it's compatible with our current scheduler
        try:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        except:
            print("Warning: Couldn't load scheduler state from checkpoint (may be different type)")
            # Manually advance scheduler to match the epoch
            for _ in range(checkpoint['epoch']):
                scheduler.step()
                
        start_epoch = checkpoint['epoch']
        if 'best_val_auc' in checkpoint:
            best_val_auc = checkpoint['best_val_auc']
        if rank == 0:
            print(f"Resuming training from epoch {start_epoch}")

    # Training loop
    best_model_weights = None
    
    from tqdm.notebook import trange, tqdm
    # Only show progress bar on main process
    if rank == 0:
        epoch_iterator = trange(start_epoch, num_epochs, desc="Epochs")
    else:
        epoch_iterator = range(start_epoch, num_epochs)
        
    for epoch in epoch_iterator:
        # Set sampler epoch for proper shuffling
        if is_distributed:
            train_sampler.set_epoch(epoch)
            
        # Training phase
        model.train()
        train_loss = 0.0
        train_class_loss = 0.0
        train_ssl_loss = 0.0
        train_correct = 0
        train_total = 0
        
        # Use tqdm for batch progress in training only on rank 0
        if rank == 0:
            batch_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", leave=False)
        else:
            batch_iterator = train_loader
            
        for batch_idx, (embeddings, labels) in enumerate(batch_iterator):
            embeddings, labels = embeddings.to(device), labels.to(device)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass with features
            outputs, features = model(embeddings, return_features=True)
            
            # Classification loss
            class_loss = classification_criterion(outputs, labels)
            
            # Self-supervised contrastive loss
            # Group features by class for positive pairs
            pos_indices = (labels == 1).nonzero(as_tuple=True)[0]
            neg_indices = (labels == 0).nonzero(as_tuple=True)[0]
            
            # Only compute contrastive loss if we have samples from both classes
            if len(pos_indices) > 0 and len(neg_indices) > 0:
                pos_features = features[pos_indices]
                neg_features = features[neg_indices]
                
                # Create positive and negative pairs for contrastive learning
                all_features = torch.cat([pos_features, neg_features], dim=0)
                all_labels = torch.cat([
                    torch.ones(len(pos_indices), device=device),
                    torch.zeros(len(neg_indices), device=device)
                ])
                
                # Compute contrastive loss
                ssl_loss = contrastive_criterion(all_features, all_labels)
                
                # Total loss (weighted combination)
                loss = class_loss + 0.5 * ssl_loss
                train_ssl_loss += ssl_loss.item() * embeddings.size(0)
            else:
                # If we don't have samples from both classes, just use classification loss
                loss = class_loss
                ssl_loss = torch.tensor(0.0)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Statistics
            train_loss += loss.item() * embeddings.size(0)
            train_class_loss += class_loss.item() * embeddings.size(0)
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            
            # Update progress bar with current loss (only on rank 0)
            if rank == 0 and isinstance(batch_iterator, tqdm):
                batch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})
            
            # Log batch-level metrics (every 10 batches) - only on rank 0
            if rank == 0 and batch_idx % 10 == 0:
                global_step = epoch * len(train_loader) + batch_idx
                writer.add_scalar('Batch/Loss/train', loss.item(), global_step)
                writer.add_scalar('Batch/ClassLoss/train', class_loss.item(), global_step)
                writer.add_scalar('Batch/SSLLoss/train', ssl_loss.item(), global_step)
                
                # Add histograms of model parameters
                if batch_idx % 50 == 0:
                    for name, param in model_without_ddp.named_parameters():
                        if param.requires_grad:
                            writer.add_histogram(f'Parameters/{name}', param.data, global_step)
        
        # Reduce metrics across all processes for accurate stats
        if is_distributed:
            # Create tensors on device for reduction
            train_loss_tensor = torch.tensor([train_loss], device=device)
            train_class_loss_tensor = torch.tensor([train_class_loss], device=device)
            train_ssl_loss_tensor = torch.tensor([train_ssl_loss], device=device)
            train_correct_tensor = torch.tensor([train_correct], device=device)
            train_total_tensor = torch.tensor([train_total], device=device)
            
            # All-reduce
            dist.all_reduce(train_loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(train_class_loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(train_ssl_loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(train_correct_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(train_total_tensor, op=dist.ReduceOp.SUM)
            
            # Update variables with reduced values
            train_loss = train_loss_tensor.item()
            train_class_loss = train_class_loss_tensor.item()
            train_ssl_loss = train_ssl_loss_tensor.item()
            train_correct = train_correct_tensor.item()
            train_total = train_total_tensor.item()
        
        train_loss = train_loss / train_total
        train_class_loss = train_class_loss / train_total
        train_ssl_loss = train_ssl_loss / train_total
        train_acc = train_correct / train_total
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        all_labels = []
        all_probs = []
        all_preds = []
        
        # Use tqdm for batch progress in validation (only on rank 0)
        if rank == 0:
            val_iterator = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}", leave=False)
        else:
            val_iterator = val_loader
            
        with torch.no_grad():
            for embeddings, labels in val_iterator:
                embeddings, labels = embeddings.to(device), labels.to(device)
                
                # Forward pass
                outputs = model(embeddings)
                loss = classification_criterion(outputs, labels)
                
                # Statistics
                val_loss += loss.item() * embeddings.size(0)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                
                # Store for metrics calculation
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())
                probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
                all_probs.extend(probs)
                
                # Update progress bar (only on rank 0)
                if rank == 0 and isinstance(val_iterator, tqdm): 
                    val_iterator.set_postfix({"val_loss": f"{loss.item():.4f}"})
        
        # Reduce validation metrics across all processes
        if is_distributed:
            # Create tensors on device for reduction
            val_loss_tensor = torch.tensor([val_loss], device=device)
            val_correct_tensor = torch.tensor([val_correct], device=device)
            val_total_tensor = torch.tensor([val_total], device=device)
            
            # All-reduce
            dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(val_correct_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(val_total_tensor, op=dist.ReduceOp.SUM)
            
            # Update variables with reduced values
            val_loss = val_loss_tensor.item()
            val_correct = val_correct_tensor.item()
            val_total = val_total_tensor.item()
            
            # Gather labels and predictions from all processes
            # First, convert lists to tensors
            all_labels_tensor = torch.tensor(all_labels, device=device)
            all_probs_tensor = torch.tensor(all_probs, device=device)
            all_preds_tensor = torch.tensor(all_preds, device=device)
            
            # Get sizes from all processes
            size_tensor = torch.tensor([len(all_labels)], device=device)
            # Get world size from the current process group
            current_world_size = dist.get_world_size()
            sizes = [torch.zeros_like(size_tensor) for _ in range(current_world_size)]
            dist.all_gather(sizes, size_tensor)
            
            # Create tensors to gather into
            max_size = max(size.item() for size in sizes)
            # Pad tensors to max size
            padded_labels = torch.zeros(max_size, device=device)
            padded_probs = torch.zeros(max_size, device=device)
            padded_preds = torch.zeros(max_size, device=device)
            
            # Copy data to padded tensors
            padded_labels[:len(all_labels)] = all_labels_tensor
            padded_probs[:len(all_probs)] = all_probs_tensor
            padded_preds[:len(all_preds)] = all_preds_tensor
            
            # Create list of tensors to gather into
            gathered_labels = [torch.zeros_like(padded_labels) for _ in range(current_world_size)]
            gathered_probs = [torch.zeros_like(padded_probs) for _ in range(current_world_size)]
            gathered_preds = [torch.zeros_like(padded_preds) for _ in range(current_world_size)]
            
            # Gather data from all processes
            dist.all_gather(gathered_labels, padded_labels)
            dist.all_gather(gathered_probs, padded_probs)
            dist.all_gather(gathered_preds, padded_preds)
            
            # Convert gathered tensors back to lists
            all_labels = []
            all_probs = []
            all_preds = []
            
            for i, size in enumerate(sizes):
                all_labels.extend(gathered_labels[i][:size].cpu().numpy())
                all_probs.extend(gathered_probs[i][:size].cpu().numpy())
                all_preds.extend(gathered_preds[i][:size].cpu().numpy())
        
        val_loss = val_loss / val_total
        val_acc = val_correct / val_total
        print("val loss: ", val_loss)
        print("val acc: ", val_acc)
        # Calculate AUC and other metrics (only on rank 0)
        if rank == 0:
            from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score, f1_score, confusion_matrix
            val_auc = roc_auc_score(all_labels, all_probs)
            val_ap = average_precision_score(all_labels, all_probs)
            val_f1 = f1_score(all_labels, all_preds)
            tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            
            # Update learning rate
            scheduler.step(val_auc)
            
            if not os.path.exists(os.path.dirname(checkpoint_dir)):
                os.makedirs(os.path.dirname(checkpoint_dir))
                
            # Save the checkpoint for this epoch
            epoch_checkpoint_path = f"{os.path.dirname(checkpoint_dir)}/model_epoch_{epoch}.pth"
            print("epoch_checkpoint_path: ", epoch_checkpoint_path)
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model_without_ddp.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_auc': val_auc
            }, epoch_checkpoint_path)
            writer.add_text('Training', f'Model saved at epoch {epoch}', epoch)
            
            # If this is also the best model, save it separately
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                best_model_weights = model_without_ddp.state_dict().copy()
                best_model_path = f"{os.path.dirname(checkpoint_dir)}/best_model_epoch_{epoch}.pth"
                torch.save(best_model_weights, best_model_path)
                writer.add_text('Training', f'New best model saved at epoch {epoch} with AUC {best_val_auc:.4f}', epoch)
                
                # Also update the latest checkpoint for resuming training
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model_without_ddp.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_val_auc': best_val_auc
                }, checkpoint_path)
            
            # Update epoch progress bar with metrics (only on rank 0)
            if isinstance(epoch_iterator, tqdm):
                epoch_iterator.set_postfix({
                    "train_loss": f"{train_loss:.4f}", 
                    "train_acc": f"{train_acc:.4f}",
                    "val_loss": f"{val_loss:.4f}", 
                    "val_acc": f"{val_acc:.4f}", 
                    "val_auc": f"{val_auc:.4f}"
                })
            
            # Log epoch-level metrics to TensorBoard (only on rank 0)
            writer.add_scalar('Epoch/Loss/train', train_loss, epoch)
            writer.add_scalar('Epoch/Loss/val', val_loss, epoch)
            writer.add_scalar('Epoch/ClassLoss/train', train_class_loss, epoch)
            writer.add_scalar('Epoch/SSLLoss/train', train_ssl_loss, epoch)
            writer.add_scalar('Epoch/Accuracy/train', train_acc, epoch)
            writer.add_scalar('Epoch/Accuracy/val', val_acc, epoch)
            writer.add_scalar('Epoch/AUC/val', val_auc, epoch)
            writer.add_scalar('Epoch/AP/val', val_ap, epoch)
            writer.add_scalar('Epoch/F1/val', val_f1, epoch)
            writer.add_scalar('Epoch/Sensitivity/val', sensitivity, epoch)
            writer.add_scalar('Epoch/Specificity/val', specificity, epoch)
            writer.add_scalar('Epoch/LearningRate', optimizer.param_groups[0]['lr'], epoch)
            
            # Log PR curve (once every few epochs)
            if epoch % 2 == 0:
                precision, recall, _ = precision_recall_curve(all_labels, all_probs)
                # Create figure
                import matplotlib.pyplot as plt
                fig = plt.figure()
                plt.plot(recall, precision, marker='.', label=f'AP={val_ap:.3f}')
                plt.xlabel('Recall')
                plt.ylabel('Precision')
                plt.title(f'PR Curve - Epoch {epoch}')
                plt.legend()
                plt.grid(True)
                plt.tight_layout()
                # Add to tensorboard
                writer.add_figure(f'PR_Curve/epoch_{epoch}', fig, epoch)
                plt.close(fig)
                
                # Add confusion matrix as figure
                fig = plt.figure(figsize=(8, 6))
                plt.imshow([[tn, fp], [fn, tp]], cmap='Blues', interpolation='nearest')
                plt.colorbar()
                plt.title(f'Confusion Matrix - Epoch {epoch}')
                plt.xlabel('Predicted')
                plt.ylabel('Actual')
                thresh = (tn + fp + fn + tp) / 2
                for i in range(2):
                    for j in range(2):
                        text = plt.text(j, i, [[tn, fp], [fn, tp]][i][j],
                                        ha="center", va="center", color="white" if [[tn, fp], [fn, tp]][i][j] > thresh else "black")
                plt.xticks([0, 1], ['Negative', 'Positive'])
                plt.yticks([0, 1], ['Negative', 'Positive'])
                writer.add_figure(f'Confusion_Matrix/epoch_{epoch}', fig, epoch)
                plt.close(fig)
        
        # Wait for all processes to complete epoch before continuing
        if is_distributed:
            dist.barrier()
    
    # Load best model weights (only if we have them and are on rank 0)
    if rank == 0 and best_model_weights is not None:
        model_without_ddp.load_state_dict(best_model_weights)
        
    return model_without_ddp if is_distributed else model

Writing distributed_training_addon.py


In [8]:
!mv /kaggle/input/histopathology-cancer-classify/pytorch/default/2/checkpoint.pth /kaggle/working/checkpoint.pth

mv: cannot remove '/kaggle/input/histopathology-cancer-classify/pytorch/default/2/checkpoint.pth': Read-only file system


In [9]:
%%writefile kaggle_distributed_runner.py

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from utils import *
from distributed_training_addon import *

# Define paths
EMBEDDINGS_PATH = "/kaggle/input/histopath-cancer-embeddings/embeddings/"
METADATA_PATH = "/kaggle/input/histopath-cancer-embeddings/metadata/"
LOG_DIR = "/kaggle/working/tensorboard_logs/"
CHECKPOINT_PATH = "/kaggle/working/checkpoint.pth"
CHECKPOINT_DIR = "/kaggle/working/checkpoint_dir/"
MODEL_SAVE_PATH = "/kaggle/working/cancer_detector_model.pth"

# Training hyper-parameters
# The same batch size is passed to create_distributed_data_loaders whether or
# not training is distributed; it is NOT scaled by the number of GPUs here.
# NOTE(review): whether this is a per-process or a global batch size depends on
# create_distributed_data_loaders — confirm against its implementation.
BATCH_SIZE = 512
NUM_EPOCHS = 8
CONTRASTIVE_TEMPERATURE = 0.12  # Temperature for contrastive loss

print(f"Available GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

# Initialize distributed environment
is_distributed, rank, local_rank, world_size = init_distributed()
print(f"Distributed: {is_distributed}, Rank: {rank}, Local Rank: {local_rank}, World Size: {world_size}")

# Set device (falls back to CPU when CUDA is unavailable)
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create TensorBoard writer only on main process; every other rank gets None
writer = None
if rank == 0:
    # Make sure the log directory exists before the writer is pointed at it
    os.makedirs(LOG_DIR, exist_ok=True)
    writer = SummaryWriter(log_dir=LOG_DIR)
    print(f"TensorBoard logs will be saved to: {LOG_DIR}")

try:
    # Create distributed data loaders
    train_loader, val_loader, train_sampler = create_distributed_data_loaders(
        EMBEDDINGS_PATH,
        METADATA_PATH,
        rank=rank,
        world_size=world_size,
        batch_size=BATCH_SIZE
    )
    print(f"Rank {rank}: Data loaders created")

    # Create model
    model = SelfSupervisedCancerModel()
    model = model.to(device)
    print(f"Rank {rank}: Model created and moved to device")

    # Train model with distributed training
    trained_model = train_distributed_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        train_sampler=train_sampler,
        device=device,
        writer=writer,
        rank=rank,
        local_rank=local_rank,
        is_distributed=is_distributed,
        num_epochs=NUM_EPOCHS,
        temperature=CONTRASTIVE_TEMPERATURE,
        checkpoint_path=CHECKPOINT_PATH,
        checkpoint_dir=CHECKPOINT_DIR
    )

    if rank == 0:
        print("Training complete")
        # Save the model
        torch.save(trained_model.state_dict(), MODEL_SAVE_PATH)
        print(f"Model saved to {MODEL_SAVE_PATH}")
finally:
    # Close the TensorBoard writer even when setup or training raised, so that
    # the events logged so far are flushed to disk.
    if writer is not None:
        writer.close()

# Clean up distributed processes
if is_distributed:
    dist.destroy_process_group()

# Example of inference (run only on one process)
if rank == 0:
    print("\nExample of model inference:")
    try:
        # Load dataset
        dataset = EmbeddingsDataset(EMBEDDINGS_PATH, METADATA_PATH)

        # Get a sample embedding
        sample_id = dataset.master_metadata['file_id'].iloc[0]
        embedding_data, true_label = dataset.get_embedding(sample_id)

        # Prepare for inference: reshape to a single-sample batch.
        # NOTE(review): the view assumes the embedding holds exactly 384 values — confirm.
        embedding_tensor = torch.tensor(embedding_data, dtype=torch.float32).view(1, 1, 384, 1).to(device)

        # Run inference
        trained_model.eval()
        with torch.no_grad():
            # The model also returns features when asked; only the logits are used here
            output, _ = trained_model(embedding_tensor, return_features=True)
            probs = torch.softmax(output, dim=1)
            predicted_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0][predicted_class].item()

        print(f"Sample ID: {sample_id}")
        print(f"True label: {true_label} ({'Positive' if true_label == 1 else 'Negative'})")
        print(f"Predicted class: {predicted_class} ({'Positive' if predicted_class == 1 else 'Negative'})")
        print(f"Confidence: {confidence:.4f}")

    except Exception as e:
        # The inference demo is best-effort: report the problem instead of failing the run
        print(f"Error during inference: {e}")

Writing kaggle_distributed_runner.py


In [10]:
# Ask PyTorch's CUDA caching allocator to release cached memory that is currently unused.
# NOTE(review): this only affects the notebook's own process; the training launched via
# launch_training.py runs in separate worker processes — confirm this cell is still needed.
import torch
torch.cuda.empty_cache()

In [11]:
%%writefile launch_training.py
from utils import *

import os
import sys
import torch.distributed as dist
import torch.multiprocessing as mp
import subprocess
import torch

def main():
    """
    Launch distributed training on the Kaggle T4x2 environment.

    Run this file with plain ``python``: it spawns ``torch.distributed.run``
    (torchrun) itself, with one worker per visible GPU, executing
    ``kaggle_distributed_runner.py``. The child's stdout and stderr are merged
    and streamed to this process's stdout line by line.

    Returns:
        int: Exit code of the torchrun process (0 on success).
    """
    print("Launching distributed training with torchrun...")
    # torchrun needs at least one worker; fall back to a single process when no
    # GPU is visible (the runner script itself falls back to the CPU device).
    world_size = max(1, torch.cuda.device_count())
    print(f"World size: {world_size}")

    # Run the main script with torchrun, using the current interpreter
    cmd = [
        sys.executable, "-m", "torch.distributed.run",
        f"--nproc_per_node={world_size}",
        "kaggle_distributed_runner.py"
    ]

    # The context manager closes the pipe and reaps the child even if streaming fails
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,  # line-buffered so output shows up as it is produced
    ) as process:
        # Stream output in real-time
        for line in process.stdout:
            print(line, end="", flush=True)

        return process.wait()


if __name__ == "__main__":
    # Entry point for the script: propagate the training exit code to the caller
    sys.exit(main())

Writing launch_training.py


In [12]:
!python launch_training.py

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
Launching distributed training with torchrun...
World size: 2
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/tensorboard/compat/__init__.py", line 42, in tf
    from tensorboard.compat import notf  # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ImportError: cannot import name 'notf' from 'tensorboard.compat' (/usr/local/lib/python3.11/dist-packages/tensorboard/compat/__init__.py)

During handling of the above exception, another exception occurred:

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/tensorboard/compat/__init__.py", line 42, in tf
    from tensorboard.compat import notf  # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ImportError: cannot import name 'notf' from 'tensorboard.compat' (/usr/local/lib/python3.11/dist-packages/tensorboard/co

In [13]:
# model = run_distributed_training(
#     embeddings_path='/kaggle/input/histopath-cancer-embeddings/embeddings',
#     metadata_path='/kaggle/input/histopath-cancer-embeddings/metadata',
#     batch_size=2048,  # Can use larger batch size now
#     num_epochs=20
# )