In [None]:
# Cell 1: Install dependencies
# %pip install scikit-learn scikit-image SimpleITK nibabel nilearn albumentations seaborn pandas numpy matplotlib tqdm pydicom scipy umap-learn

In [3]:
# Cell 1: Import necessary libraries and setup environment
import os
import sys
import json
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from umap import UMAP

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# GPU Setup
def configure_gpu():
    """Configure GPU settings optimized for NVIDIA 4070Ti"""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Enable CUDNN benchmark for optimized convolution performance
        torch.backends.cudnn.benchmark = True
        # Print GPU info
        print(f"Using GPU: {torch.cuda.get_device_name(device)}")
        print(f"Total GPU Memory: {torch.cuda.get_device_properties(device).total_memory / 1e9:.2f} GB")
    else:
        raise EnvironmentError("CUDA-compatible GPU not found. Please check your GPU configuration.")
    return device

# Memory monitoring
def print_gpu_stats():
    """Print current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9  # Convert to GB
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f"GPU Memory Allocated: {allocated:.2f} GB")
        print(f"GPU Memory Reserved: {reserved:.2f} GB")

# Configure device
device = configure_gpu()
print_gpu_stats()

Using GPU: NVIDIA GeForce RTX 4070 Ti
Total GPU Memory: 12.88 GB
GPU Memory Allocated: 0.00 GB
GPU Memory Reserved: 0.00 GB


In [6]:
# Cell 2a: Model Definitions
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

# Base Convolutional Block
class ConvBlock(nn.Module):
    """Memory-efficient convolutional block."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.block = nn.Sequential(OrderedDict([
            ('conv', nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)),
            ('bn', nn.BatchNorm3d(out_channels)),
            ('relu', nn.ReLU(inplace=True))
        ]))

    def forward(self, x):
        return self.block(x)

# Unsupervised Autoencoder
class BaseAutoencoder(nn.Module):
    def __init__(self, latent_dim=256):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        torch.backends.cudnn.benchmark = True

    def forward(self, x):
        z, skip_connections = self.encoder(x)
        reconstruction = self.decoder(z, skip_connections)
        return reconstruction

    def encode(self, x):
        z, _ = self.encoder(x)
        return z

    def decode(self, z):
        batch_size = z.size(0)
        device = z.device
        dummy_skips = (
            torch.zeros(batch_size, 32, 64, 64, 64, device=device),
            torch.zeros(batch_size, 64, 32, 32, 32, device=device),
            torch.zeros(batch_size, 128, 16, 16, 16, device=device),
            torch.zeros(batch_size, 256, 8, 8, 8, device=device)
        )
        return self.decoder(z, dummy_skips)

class Encoder(nn.Module):
    def __init__(self, latent_dim=256):
        super().__init__()
        self.init_conv = ConvBlock(1, 16)
        self.down1 = nn.Sequential(
            ConvBlock(16, 32, stride=2),
            ConvBlock(32, 32)
        )
        self.down2 = nn.Sequential(
            ConvBlock(32, 64, stride=2),
            ConvBlock(64, 64)
        )
        self.down3 = nn.Sequential(
            ConvBlock(64, 128, stride=2),
            ConvBlock(128, 128)
        )
        self.down4 = nn.Sequential(
            ConvBlock(128, 256, stride=2),
            ConvBlock(256, 256)
        )
        self.flatten_size = 256 * 8 * 8 * 8
        self.fc = nn.Linear(self.flatten_size, latent_dim)

    def forward(self, x):
        x = self.init_conv(x)
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        flat = torch.flatten(d4, start_dim=1)
        z = self.fc(flat)
        return z, (d1, d2, d3, d4)

class Decoder(nn.Module):
    def __init__(self, latent_dim=256):
        super().__init__()
        self.flatten_size = 256 * 8 * 8 * 8
        self.fc = nn.Linear(latent_dim, self.flatten_size)
        self.up1 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2),
            ConvBlock(128, 128)
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2),
            ConvBlock(64, 64)
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2),
            ConvBlock(32, 32)
        )
        self.up4 = nn.Sequential(
            nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2),
            ConvBlock(16, 16)
        )
        self.final_conv = nn.Conv3d(16, 1, kernel_size=1)

    def forward(self, z, skip_connections):
        x = self.fc(z)
        x = x.view(-1, 256, 8, 8, 8)
        d1, d2, d3, d4 = skip_connections
        x = self.up1(x + d4)
        x = self.up2(x + d3)
        x = self.up3(x + d2)
        x = self.up4(x + d1)
        x = torch.sigmoid(self.final_conv(x))
        return x

In [7]:
# Cell 2b: VAE and Semi-Supervised Models
class VAE(nn.Module):
    def __init__(self, latent_dim=256):
        super().__init__()
        self.encoder = VAEEncoder(latent_dim)
        self.decoder = VAEDecoder(latent_dim)
        torch.backends.cudnn.benchmark = True

    def reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def forward(self, x):
        mu, log_var, skip_connections = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.decoder(z, skip_connections)
        return reconstruction, mu, log_var

class VAEEncoder(nn.Module):
    def __init__(self, latent_dim=256):
        super().__init__()
        self.init_conv = ConvBlock(1, 16)
        self.down1 = nn.Sequential(
            ConvBlock(16, 32, stride=2),
            ConvBlock(32, 32)
        )
        self.down2 = nn.Sequential(
            ConvBlock(32, 64, stride=2),
            ConvBlock(64, 64)
        )
        self.down3 = nn.Sequential(
            ConvBlock(64, 128, stride=2),
            ConvBlock(128, 128)
        )
        self.down4 = nn.Sequential(
            ConvBlock(128, 256, stride=2),
            ConvBlock(256, 256)
        )
        self.flatten_size = 256 * 8 * 8 * 8
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)

    def forward(self, x):
        x = self.init_conv(x)
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        flat = torch.flatten(d4, start_dim=1)
        mu = self.fc_mu(flat)
        log_var = self.fc_var(flat)
        return mu, log_var, (d1, d2, d3, d4)

class VAEDecoder(nn.Module):
    def __init__(self, latent_dim=256):
        super().__init__()
        self.flatten_size = 256 * 8 * 8 * 8
        self.fc = nn.Linear(latent_dim, self.flatten_size)
        self.up1 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2),
            ConvBlock(128, 128)
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2),
            ConvBlock(64, 64)
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2),
            ConvBlock(32, 32)
        )
        self.up4 = nn.Sequential(
            nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2),
            ConvBlock(16, 16)
        )
        self.final_conv = nn.Conv3d(16, 1, kernel_size=1)

    def forward(self, z, skip_connections):
        x = self.fc(z)
        x = x.view(-1, 256, 8, 8, 8)
        d1, d2, d3, d4 = skip_connections
        x = self.up1(x + d4)
        x = self.up2(x + d3)
        x = self.up3(x + d2)
        x = self.up4(x + d1)
        x = torch.sigmoid(self.final_conv(x))
        return x

# Semi-Supervised Autoencoder
class SemiSupervisedAE(nn.Module):
    def __init__(self, latent_dim=256, num_classes=3):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        torch.backends.cudnn.benchmark = True

    def forward(self, x):
        z, skip_connections = self.encoder(x)
        reconstruction = self.decoder(z, skip_connections)
        classification = self.classifier(z)
        return reconstruction, classification, z

# Semi-Supervised VAE
class SSVAE(nn.Module):
    def __init__(self, latent_dim=256, num_classes=3):
        super().__init__()
        self.encoder = VAEEncoder(latent_dim)
        self.decoder = VAEDecoder(latent_dim)
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )
        torch.backends.cudnn.benchmark = True

    def reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def forward(self, x):
        mu, log_var, skip_connections = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.decoder(z, skip_connections)
        classification = self.classifier(z)
        return reconstruction, classification, mu, log_var

In [8]:
# Cell 2: Load models and dataset

class ModelLoader:
    """Handles loading of pretrained models"""
    def __init__(self, checkpoint_dir='checkpoints'):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.model_configs = {
            'autoencoder': (BaseAutoencoder, 'autoencoder_checkpoint.pth'),
            'vae': (VAE, 'vae_checkpoint.pth'),
            'ssae': (SemiSupervisedAE, 'ssae_checkpoint.pth'),
            'ssvae': (SSVAE, 'ssvae_checkpoint.pth')
        }

    def load_model(self, model_name):
        """Load a specific model from checkpoint"""
        if model_name not in self.model_configs:
            raise ValueError(f"Unknown model: {model_name}")

        ModelClass, checkpoint_file = self.model_configs[model_name]
        checkpoint_path = self.checkpoint_dir / checkpoint_file

        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # Initialize model
        model = ModelClass()
        
        # Load checkpoint
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        model.to(device)

        return model

    def load_all_models(self):
        """Load all models"""
        models = {}
        for model_name in self.model_configs:
            print(f"Loading {model_name}...")
            models[model_name] = self.load_model(model_name)
            print_gpu_stats()
        return models

# Load validation dataset
def load_validation_data(batch_size=8):
    """Load validation dataset from validated_file_paths.csv"""
    df = pd.read_csv('validated_file_paths.csv')
    val_dataset = DaTScanDataset(df)  # Your existing dataset class
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True
    )
    return val_loader

In [10]:
# Cell 2c: Dataset Class and Loading Functions
import pydicom
import numpy as np
from torch.utils.data import Dataset

def load_dicom(file_path):
    """Load and process a DICOM file"""
    try:
        ds = pydicom.dcmread(file_path)
        pixel_array = ds.pixel_array.astype(np.float32)
        
        # Apply rescaling if attributes are present
        if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
            slope = ds.RescaleSlope
            intercept = ds.RescaleIntercept
            pixel_array = pixel_array * slope + intercept
            
        return pixel_array, ds
    except Exception as e:
        raise IOError(f"Error reading DICOM file {file_path}: {e}")

def process_volume(volume, target_shape=(128, 128, 128)):
    """Process a 3D volume (normalize, resize, mask)"""
    # Normalize intensity
    volume = np.clip(volume, a_min=0, a_max=None)
    vmin, vmax = volume.min(), volume.max()
    if vmax > vmin:
        volume = (volume - vmin) / (vmax - vmin)
    else:
        volume = volume - vmin

    # Resize to target shape if needed
    if volume.shape != target_shape:
        # Your resizing logic here (maintaining aspect ratio)
        pass

    return volume

class DaTScanDataset(Dataset):
    """Dataset class for DaTSCAN images"""
    def __init__(self, dataframe):
        self.df = dataframe
        self._calculate_dataset_statistics()

    def _calculate_dataset_statistics(self):
        """Calculate dataset statistics"""
        stats_list = []
        for _, row in self.df.iterrows():
            try:
                volume, _ = load_dicom(row["file_path"])
                stats_list.append({
                    'min': volume.min(),
                    'max': volume.max()
                })
            except Exception as e:
                print(f"Error processing file {row['file_path']}: {e}")

        if stats_list:
            self.stats = {
                'min': min(stat['min'] for stat in stats_list),
                'max': max(stat['max'] for stat in stats_list)
            }
        else:
            self.stats = {'min': 0, 'max': 1}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        try:
            file_path = self.df.iloc[idx]["file_path"]
            
            # Load DICOM
            volume, _ = load_dicom(file_path)
            
            # Process volume
            processed_vol = process_volume(volume)
            
            # Convert to tensor
            volume_tensor = torch.from_numpy(np.expand_dims(processed_vol, axis=0)).float()

            return {
                "volume": volume_tensor,
                "label": self.df.iloc[idx]["label"],
                "path": file_path
            }

        except Exception as e:
            print(f"Error loading file {file_path}: {str(e)}")
            return None

In [None]:
# Cell 3: Latent Space Extraction
class LatentSpaceExtractor:
    """Extracts and manages latent representations from models"""
    def __init__(self, models, val_loader):
        self.models = models
        self.val_loader = val_loader
        self.latent_vectors = {}
        self.labels = None

    def extract_latent_vectors(self, batch_size=8):
        """Extract latent vectors from all models"""
        for model_name, model in self.models.items():
            print(f"\nExtracting latent vectors for {model_name}...")
            vectors = []
            labels = []

            with torch.no_grad():
                for batch in tqdm(self.val_loader):
                    volumes = batch['volume'].to(device)
                    batch_labels = batch['label']

                    # Extract latent vectors based on model type
                    if isinstance(model, (VAE, SSVAE)):
                        _, _, mu, _ = model(volumes)
                        vectors.append(mu.cpu().numpy())
                    else:  # Autoencoder or Semi-supervised AE
                        if isinstance(model, SemiSupervisedAE):
                            _, _, z = model(volumes)
                        else:
                            z = model.encode(volumes)
                        vectors.append(z.cpu().numpy())

                    labels.extend(batch_labels)

                    # Clean up GPU memory
                    del volumes
                    torch.cuda.empty_cache()

            self.latent_vectors[model_name] = np.concatenate(vectors)
            if self.labels is None:
                self.labels = np.array(labels)

            print_gpu_stats()

        return self.latent_vectors, self.labels

# Example usage:
model_loader = ModelLoader()
models = model_loader.load_all_models()
val_loader = load_validation_data()
extractor = LatentSpaceExtractor(models, val_loader)
latent_vectors, labels = extractor.extract_latent_vectors()

print("\nExtraction complete! Latent vectors shape for each model:")
for model_name, vectors in latent_vectors.items():
    print(f"{model_name}: {vectors.shape}")

Loading autoencoder...


  checkpoint = torch.load(checkpoint_path)


GPU Memory Allocated: 1.70 GB
GPU Memory Reserved: 2.80 GB
Loading vae...
GPU Memory Allocated: 2.12 GB
GPU Memory Reserved: 3.48 GB
Loading ssae...
GPU Memory Allocated: 2.41 GB
GPU Memory Reserved: 3.48 GB
Loading ssvae...
GPU Memory Allocated: 2.83 GB
GPU Memory Reserved: 4.16 GB

Extracting latent vectors for autoencoder...


  0%|          | 0/374 [00:00<?, ?it/s]