**IMPORTS AND TPU SETUP**

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pydicom
from pydicom.errors import InvalidDicomError
import nibabel as nib
import cv2
from scipy import ndimage
from tqdm import tqdm
import warnings
import gc
import time
import signal
import sys
from contextlib import contextmanager
warnings.filterwarnings('ignore')

# TPU-SPECIFIC SETUP
# Probe for the torch_xla stack; TPU_AVAILABLE records whether the import worked
# and selects between the TPU and the GPU/CPU code paths further down.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.utils.utils as xu
except ImportError:
    TPU_AVAILABLE = False
    print("TPU libraries not available, falling back to GPU/CPU")
else:
    TPU_AVAILABLE = True
    print("TPU libraries loaded successfully")

# TPU Environment Setup
if not TPU_AVAILABLE:
    # GPU fallback settings
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
else:
    os.environ.update({
        "XLA_USE_BF16": "1",  # Enable bfloat16 for TPU efficiency
        "XLA_TENSOR_ALLOCATOR_MAXSIZE": "100000000",  # 100MB tensor allocator
        "TPU_NUM_DEVICES": "8",  # v3-8 = 8 cores
    })

# Reduce thread contention (for TPU): pin the native thread pools to one thread each.
os.environ.update({
    "ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS": "1",
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
})

print(f"TPU Available: {TPU_AVAILABLE}")
if TPU_AVAILABLE:
    print(f"TPU Devices: {xm.xrt_world_size()}")

# TIMEOUT AND HANGING PREVENTION
class TimeoutError(Exception):
    """Raised by the ``timeout`` context manager when its alarm fires.

    NOTE: within this module the name shadows the built-in ``TimeoutError``;
    this class derives directly from ``Exception``, not from ``OSError``.
    """

@contextmanager
def timeout(duration):
    """Abort the enclosed block with ``TimeoutError`` after ``duration`` seconds.

    Implemented with ``SIGALRM``. On platforms without ``SIGALRM`` the block
    simply runs with no time limit.

    The context manager is safe to nest: on exit it reinstalls the handler that
    was active on entry and re-arms whatever alarm an enclosing ``timeout`` had
    pending. Without this, leaving an inner block would cancel every outer
    deadline.

    Args:
        duration: Time limit in whole seconds (passed to ``signal.alarm``).

    Raises:
        TimeoutError: Inside the ``with`` body, when the alarm fires.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f"Operation timed out after {duration} seconds")

    if not hasattr(signal, 'SIGALRM'):  # non-Unix systems: no alarm support
        yield
        return

    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    if previous_handler is None:
        # The old handler was not installed from Python and cannot be
        # reinstalled as-is; fall back to the default disposition.
        previous_handler = signal.SIG_DFL

    # signal.alarm returns the seconds left on any alarm it replaces.
    outer_remaining = signal.alarm(duration)
    started = time.monotonic()
    try:
        yield
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)
        if outer_remaining:
            # Resume the enclosing deadline, less the time spent in this block.
            # At least 1s, because alarm(0) would cancel rather than fire.
            elapsed = int(time.monotonic() - started)
            signal.alarm(max(1, outer_remaining - elapsed))


**TPU OPTIMIZED CONFIGURATION AND TRAINING PIPELINE**

In [None]:
class TPUConfig:
    """Static settings for the stage-1 pipeline: paths, sizes, device, run control.

    The class body executes once, when the class is defined; that is when
    ``DEVICE`` is selected and the device announcement is printed.
    """

    # --- Dataset locations ---
    TRAIN_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv'
    LOCALIZER_CSV_PATH = '/kaggle/input/rsna-intracranial-aneurysm-detection/train_localizers.csv'
    SERIES_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/series/'
    SEGMENTATION_DIR = '/kaggle/input/rsna-intracranial-aneurysm-detection/segmentations/'

    # --- Stage-1 sizes and optimiser settings (TPU values differ from the fallback) ---
    STAGE1_TARGET_SIZE = (16, 32, 32)  # volume shape fed to the model
    STAGE1_BATCH_SIZE = 8 if TPU_AVAILABLE else 2
    STAGE1_EPOCHS = 10
    STAGE1_LR = 1e-3 if TPU_AVAILABLE else 3e-4

    # --- Device selection: XLA device when torch_xla imported, else CUDA, else CPU ---
    if not TPU_AVAILABLE:
        DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using fallback device: {DEVICE}")
    else:
        DEVICE = xm.xla_device()
        print(f"Using TPU device: {DEVICE}")

    MIXED_PRECISION = True
    N_FOLDS = 3

    # --- Column names in the training CSV ---
    ID_COL = 'SeriesInstanceUID'
    LABEL_COLS = [
        'Left Infraclinoid Internal Carotid Artery',
        'Right Infraclinoid Internal Carotid Artery',
        'Left Supraclinoid Internal Carotid Artery',
        'Right Supraclinoid Internal Carotid Artery',
        'Left Middle Cerebral Artery',
        'Right Middle Cerebral Artery',
        'Anterior Communicating Artery',
        'Left Anterior Cerebral Artery',
        'Right Anterior Cerebral Artery',
        'Left Posterior Communicating Artery',
        'Right Posterior Communicating Artery',
        'Basilar Tip',
        'Other Posterior Circulation',
        'Aneurysm Present',
    ]
    TARGET_COL = 'Aneurysm Present'

    # --- Debug settings: when enabled, training uses only the first DEBUG_SAMPLES rows ---
    DEBUG_MODE = False
    DEBUG_SAMPLES = 100

    # --- Run control ---
    TPU_SYNC_FREQUENCY = 10  # Sync every N batches
    CHECKPOINT_FREQUENCY = 3  # Save every N epochs
    MAX_TIMEOUT_SECONDS = 3600  # 1 hour timeout per epoch


print(f"✅ TPU Configuration loaded - Device: {TPUConfig.DEVICE}")

# TPU-OPTIMIZED 3D UNET
class TPUOptimized3DUNet(nn.Module):
    """Small 3D U-Net with one conv block per level and at most two pooling stages.

    Layout for ``features = (f0, f1, ..., fk)``:

    * Encoder: one conv block per entry. The first ``L = min(2, k)`` blocks are
      each followed by a 2x max-pool; their pre-pool outputs are the skip
      connections. The remaining blocks run at the lowest resolution.
    * Decoder: ``L`` stages, deepest first. Each stage upsamples 2x to the
      channel count of its skip, concatenates the skip, and applies one conv
      block.
    * Head: a 1x1x1 conv from ``f0`` channels to ``out_channels``.

    The output therefore has the same spatial size as the input. The decoder
    previously paired the first upsample with a skip taken at the bottleneck
    resolution and ended on ``features[-2]`` channels while ``final_conv``
    expected ``features[0]``, so ``forward`` raised a channel-mismatch error
    for the default configuration.

    Args:
        spatial_dims: Accepted for interface compatibility; only 3D layers are
            built and the value is not used.
        in_channels: Channels of the input volume.
        out_channels: Channels of the returned feature map.
        features: Channel count of each encoder block, shallowest first.
        dropout: ``Dropout3d`` probability after each conv block (0 disables it).

    Raises:
        ValueError: If ``features`` is empty.
    """

    # Upper bound on pooling/upsampling stages, regardless of len(features).
    MAX_LEVELS = 2

    def __init__(self, spatial_dims=3, in_channels=1, out_channels=16,
                 features=(16, 32, 64, 32), dropout=0.1):
        super().__init__()

        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        self.features = features
        self.dropout = dropout

        # Encoder: a single conv block per level.
        self.encoder_blocks = nn.ModuleList()
        prev_channels = in_channels
        for feature_count in features:
            self.encoder_blocks.append(self._conv_block(prev_channels, feature_count, dropout))
            prev_channels = feature_count

        # Only the first `num_levels` encoder blocks are followed by pooling.
        num_levels = min(self.MAX_LEVELS, len(features) - 1)
        self.downsample_layers = nn.ModuleList([
            nn.MaxPool3d(kernel_size=2, stride=2) for _ in range(num_levels)
        ])

        # Decoder: mirror exactly the pooled levels, deepest first, so every
        # upsample meets a skip of matching resolution and channel count.
        self.upsample_layers = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        current_channels = features[-1]
        for level in reversed(range(num_levels)):
            skip_channels = features[level]
            self.upsample_layers.append(nn.ConvTranspose3d(
                current_channels, skip_channels,
                kernel_size=2, stride=2, bias=False
            ))
            # Input is the upsampled tensor concatenated with the skip.
            self.decoder_blocks.append(self._conv_block(skip_channels * 2, skip_channels, dropout))
            current_channels = skip_channels

        # current_channels is features[0] here, with or without decoder stages.
        self.final_conv = nn.Conv3d(current_channels, out_channels, kernel_size=1, bias=True)

        self._initialize_weights()

    @staticmethod
    def _conv_block(in_channels, out_channels, dropout):
        """Build the Conv3d -> BatchNorm3d -> ReLU -> (Dropout3d) unit used at every level."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout) if dropout > 0 else nn.Identity()
        )

    def _initialize_weights(self):
        """Kaiming-initialise Conv3d weights; reset BatchNorm3d to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map ``(N, in_channels, D, H, W)`` to ``(N, out_channels, D, H, W)``."""
        num_levels = len(self.downsample_layers)
        skip_connections = []

        # Encoder path: remember the pre-pool activation of each pooled level.
        for i, encoder_block in enumerate(self.encoder_blocks):
            x = encoder_block(x)
            if i < num_levels:
                skip_connections.append(x)
                x = self.downsample_layers[i](x)

        # Decoder path: consume skips deepest-first.
        for upsample_layer, decoder_block in zip(self.upsample_layers, self.decoder_blocks):
            skip = skip_connections.pop()
            x = upsample_layer(x)

            # Odd input sizes lose a voxel when pooled; resample to the skip's size.
            if x.shape[2:] != skip.shape[2:]:
                x = nn.functional.interpolate(
                    x, size=skip.shape[2:],
                    mode='trilinear', align_corners=False
                )

            x = decoder_block(torch.cat([x, skip], dim=1))

        return self.final_conv(x)

# TPU-OPTIMIZED DICOM PROCESSOR
class TPUDICOMProcessor:
    """Load a DICOM series directory into a small, normalised float32 volume.

    At most ``max_files`` files of a series are read. Each slice is resized to
    ``SLICE_SHAPE``, the slices are ordered and stacked, then the volume is
    intensity-normalised to [0, 1] and resampled to ``target_size``. Every
    failure path returns an all-zero volume of ``target_size`` instead of
    raising.
    """

    # In-plane (rows, cols) every slice is resized to before stacking.
    SLICE_SHAPE = (96, 96)  # Reduced for TPU memory

    def __init__(self, target_size=None):
        """
        Args:
            target_size: Output volume shape; defaults to
                ``TPUConfig.STAGE1_TARGET_SIZE`` when falsy.
        """
        self.target_size = target_size or TPUConfig.STAGE1_TARGET_SIZE
        self.max_files = 15  # reduced for TPU memory

    def load_dicom_series(self, series_path):
        """Read, order, stack and preprocess the ``.dcm`` files in ``series_path``.

        Args:
            series_path: Directory containing the series' ``.dcm`` files.

        Returns:
            float32 ndarray of shape ``target_size`` with values in [0, 1];
            all zeros if the directory is missing, has no ``.dcm`` files, or
            no file could be decoded.
        """
        try:
            if not os.path.exists(series_path):
                return self._get_dummy_volume()

            # os.listdir returns entries in arbitrary order; sort so that the
            # subsampling below picks the same files on every run.
            dicom_files = sorted(f for f in os.listdir(series_path) if f.endswith('.dcm'))
            if not dicom_files:
                return self._get_dummy_volume()

            # file limited for TPU memory: take an evenly strided subset
            step_size = max(1, len(dicom_files) // self.max_files)
            selected_files = dicom_files[::step_size][:self.max_files]

            keyed_slices = []  # (sort key, 2D array)
            for read_index, file_name in enumerate(selected_files):
                try:
                    with timeout(30):  # 30 second timeout per file
                        ds = pydicom.dcmread(
                            os.path.join(series_path, file_name),
                            force=True,
                            stop_before_pixels=False
                        )
                        arr = self._extract_slice(ds)
                        if arr is not None:
                            keyed_slices.append((self._slice_sort_key(ds, read_index), arr))

                        # Critical: immediate cleanup
                        del ds

                except TimeoutError:
                    print(f"Timeout loading {file_name}")
                    continue
                except Exception as e:
                    print(f"Error loading {file_name}: {str(e)[:100]}")
                    continue

            if not keyed_slices:
                return self._get_dummy_volume()

            # File names do not encode slice position, so stack by the spatial
            # key instead of by read order (sort is stable for equal keys).
            keyed_slices.sort(key=lambda item: item[0])
            volume = np.stack([arr for _, arr in keyed_slices], axis=0).astype(np.float32)
            del keyed_slices  # Immediate cleanup

            return self._preprocess_volume(volume)

        except Exception as e:
            print(f"Failed to load series {os.path.basename(series_path)}: {str(e)[:100]}")
            return self._get_dummy_volume()

    def _extract_slice(self, ds):
        """Return one float32 slice of ``SLICE_SHAPE`` from a dataset, or None if unusable.

        2D pixel data is used directly; for 3D pixel data only the middle
        index along the first axis is kept. Other ranks are skipped.
        """
        if not hasattr(ds, 'pixel_array'):
            return None

        arr = ds.pixel_array.astype(np.float32)
        if arr.ndim == 3:
            arr = arr[arr.shape[0] // 2]
        elif arr.ndim != 2:
            return None

        if arr.shape != self.SLICE_SHAPE:
            rows, cols = self.SLICE_SHAPE
            # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
            arr = cv2.resize(arr, (cols, rows), interpolation=cv2.INTER_LINEAR)
        return arr

    @staticmethod
    def _slice_sort_key(ds, fallback_index):
        """Return a sortable key that places a slice within its series.

        Preference order:
          0. position along the slice normal, from ``ImagePositionPatient``
             projected onto the cross product of the two
             ``ImageOrientationPatient`` direction vectors;
          1. ``InstanceNumber``;
          2. the order in which the file was read.

        The leading integer keeps keys of different kinds comparable.
        """
        try:
            position = np.asarray(ds.ImagePositionPatient, dtype=np.float64)
            orientation = np.asarray(ds.ImageOrientationPatient, dtype=np.float64)
            if position.shape == (3,) and orientation.shape == (6,):
                normal = np.cross(orientation[:3], orientation[3:])
                return (0, float(np.dot(position, normal)))
        except (AttributeError, TypeError, ValueError):
            pass  # geometry tags missing or malformed: fall through to the next key

        try:
            return (1, float(ds.InstanceNumber))
        except (AttributeError, TypeError, ValueError):
            return (2, float(fallback_index))

    def _get_dummy_volume(self):
        """Return dummy volume for failed loads"""
        return np.zeros(self.target_size, dtype=np.float32)

    def _preprocess_volume(self, volume):
        """Clip to the 2nd-98th percentile, scale to [0, 1], resample to ``target_size``.

        A volume whose two percentiles coincide (e.g. constant intensity)
        becomes all zeros.
        """
        # Clip extreme outliers
        low, high = np.percentile(volume, [2, 98])
        volume = np.clip(volume, low, high)

        # Normalize
        if high > low:
            volume = (volume - low) / (high - low)
        else:
            volume = np.zeros_like(volume)

        # Resize to target shape (linear interpolation)
        if volume.shape != tuple(self.target_size):
            zoom_factors = [self.target_size[i] / volume.shape[i] for i in range(3)]
            volume = ndimage.zoom(volume, zoom_factors, order=1)

        # Ensure correct data type and range
        return np.clip(volume, 0, 1).astype(np.float32)

# TPU-OPTIMIZED DATASET
class TPUSegmentationDataset(Dataset):
    """Dataset yielding a preprocessed volume, a placeholder mask and the binary label.

    Each item is a dict with keys ``'volume'`` and ``'mask'`` (float32 tensors
    with a leading channel axis), ``'has_aneurysm'`` (float32 scalar tensor)
    and ``'series_id'``. Items that cannot be loaded are replaced by an
    all-zero item labelled 0.0 with ``series_id == "DUMMY_FAIL"``.
    """

    def __init__(self, df, series_dir, processor, mode='train'):
        """
        Args:
            df: DataFrame with the ``TPUConfig.ID_COL`` and
                ``TPUConfig.TARGET_COL`` columns.
            series_dir: Directory that holds one sub-directory per series id.
            processor: Object providing ``load_dicom_series(path)``; its
                ``target_size`` attribute, when present, sets the shape of
                dummy items.
            mode: Label used only in the creation message.
        """
        self.df = df.copy()  # Avoid reference issues
        self.series_dir = series_dir
        self.processor = processor
        self.mode = mode

        print(f"Created {mode} dataset with {len(self.df)} samples")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Get item with robust error handling for TPU.

        Tries up to three times (with a short pause in between) and falls
        back to a dummy item, so a bad series never aborts the loader.
        """
        max_retries = 2

        for attempt in range(max_retries + 1):
            try:
                return self._get_item_safe(idx)
            except Exception as e:
                if attempt < max_retries:
                    print(f"Retry {attempt + 1} for item {idx}: {str(e)[:50]}")
                    time.sleep(0.1)  # Brief pause before retry

        print(f"Failed to load item {idx} after {max_retries} retries")
        return self._get_dummy_item()

    def _get_item_safe(self, idx):
        """Load one item under a 2-minute timeout; exceptions propagate to the caller."""
        with timeout(120):  # 2 minute timeout per item
            row = self.df.iloc[idx]
            series_id = row[TPUConfig.ID_COL]
            label = row[TPUConfig.TARGET_COL]

            volume = self.processor.load_dicom_series(os.path.join(self.series_dir, series_id))

            # Placeholder mask derived from the label only; no segmentation
            # file is read here.
            mask = self._create_dummy_mask(volume.shape, int(label))

            return {
                'volume': torch.from_numpy(volume).float().unsqueeze(0),  # channel dim
                'mask': torch.from_numpy(mask).float().unsqueeze(0),
                'has_aneurysm': torch.tensor(float(label), dtype=torch.float32),
                'series_id': series_id
            }

    def _create_dummy_mask(self, shape, has_aneurysm):
        """Return a float32 mask of ``shape``: zeros, plus a centred cube of ones if positive.

        The cube's half-width is ``min(shape) // 8`` but never less than one
        voxel, so a positive label always yields a non-empty mask even for
        very small volumes.
        """
        mask = np.zeros(shape, dtype=np.float32)

        if has_aneurysm:
            half = max(1, (min(shape) // 4) // 2)
            region = tuple(
                slice(max(0, extent // 2 - half), min(extent, extent // 2 + half))
                for extent in shape
            )
            mask[region] = 1.0

        return mask

    def _get_dummy_item(self):
        """Return dummy item for failed loads.

        The shape follows the processor's ``target_size`` so that dummy items
        collate with real ones even when the processor was built with a
        non-default size; ``TPUConfig.STAGE1_TARGET_SIZE`` is the fallback.
        """
        target_size = getattr(self.processor, 'target_size', None)
        if target_size is None:
            target_size = TPUConfig.STAGE1_TARGET_SIZE

        return {
            'volume': torch.zeros((1, *target_size), dtype=torch.float32),
            'mask': torch.zeros((1, *target_size), dtype=torch.float32),
            'has_aneurysm': torch.tensor(0.0, dtype=torch.float32),
            'series_id': "DUMMY_FAIL"
        }

# TPU-OPTIMIZED MODEL
class TPUSegmentationModel(nn.Module):
    """Two-headed model: a shared 3D U-Net backbone feeding a voxel-wise
    segmentation head and a volume-level classification head.

    Args:
        in_channels: Channels of the input volume.
        seg_channels: Channels of the segmentation logits.
    """

    def __init__(self, in_channels=1, seg_channels=1):
        super().__init__()

        feature_channels = 32  # width of the backbone's output feature map

        # Shared feature extractor
        self.backbone = TPUOptimized3DUNet(
            spatial_dims=3,
            in_channels=in_channels,
            out_channels=feature_channels,
            features=(16, 32, 64, 32),  # TPU-friendly sizes
            dropout=0.1
        )

        # Per-voxel logits via a 1x1x1 convolution
        self.seg_head = nn.Conv3d(feature_channels, seg_channels, kernel_size=1, bias=True)

        # Volume-level logit: global average pool, then a two-layer MLP
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.classifier = nn.Sequential(
            nn.Linear(feature_channels, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(64, 1)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialise every Linear layer's weight and zero its bias."""
        linear_layers = (module for module in self.modules() if isinstance(module, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return ``(seg_logits, cls_logits)`` for a batch of volumes."""
        feature_map = self.backbone(x)
        pooled = torch.flatten(self.global_pool(feature_map), 1)
        return self.seg_head(feature_map), self.classifier(pooled)

# TPU-OPTIMIZED TRAINING FUNCTIONS
def tpu_train_epoch(model, loader, optimizer, device, epoch):
    """Train ``model`` for one pass over ``loader``.

    The loss is the sum of a voxel-wise BCE-with-logits on the segmentation
    output and a BCE-with-logits on the classification output. A batch that
    times out (5 minutes) or raises is reported and skipped.

    Args:
        model: Module returning ``(seg_logits, cls_logits)``.
        loader: Iterable of dict batches with ``'volume'``, ``'mask'`` and
            ``'has_aneurysm'`` tensors.
        optimizer: Optimizer over ``model.parameters()``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used for the progress-bar label.

    Returns:
        Mean loss over the batches that completed, or ``float('inf')`` when no
        batch completed. Returning 0.0 in that case would make a fully failed
        epoch look like a perfect one.
    """
    model.train()
    total_loss = 0
    num_batches = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1} Training")

    for batch_idx, batch in enumerate(pbar):
        try:
            with timeout(300):  # 5-minute timeout per batch
                # Move to device
                volume = batch['volume'].to(device, non_blocking=True)
                mask = batch['mask'].to(device, non_blocking=True)
                has_aneurysm = batch['has_aneurysm'].to(device, non_blocking=True)

                optimizer.zero_grad()

                # Forward pass
                seg_logits, cls_logits = model(volume)

                # Calculate losses
                seg_loss = nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
                cls_loss = nn.functional.binary_cross_entropy_with_logits(
                    cls_logits.view(-1), has_aneurysm
                )
                total_loss_batch = seg_loss + cls_loss

                # Backward pass
                total_loss_batch.backward()

                # Gradient clipping for TPU stability
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                # TPU-specific step
                if TPU_AVAILABLE:
                    xm.optimizer_step(optimizer)  # TPU-optimized step
                    if batch_idx % TPUConfig.TPU_SYNC_FREQUENCY == 0:
                        xm.mark_step()  # Sync TPU cores
                else:
                    optimizer.step()

                # Read the loss back once and reuse the Python float below.
                batch_loss = total_loss_batch.item()
                total_loss += batch_loss
                num_batches += 1

                pbar.set_postfix({
                    'Loss': f'{batch_loss:.4f}',
                    'Avg': f'{total_loss/num_batches:.4f}'
                })

                # Memory cleanup for long training (nothing to do on TPU)
                if batch_idx % 20 == 0 and not TPU_AVAILABLE:
                    torch.cuda.empty_cache()
                    gc.collect()

        except TimeoutError:
            print(f"Batch {batch_idx} timed out, skipping...")
            continue
        except Exception as e:
            print(f"Error in batch {batch_idx}: {str(e)[:100]}")
            continue

    # Final TPU sync
    if TPU_AVAILABLE:
        xm.mark_step()

    if num_batches == 0:
        print("No training batch completed; reporting loss as inf")
        return float('inf')

    return total_loss / num_batches


def tpu_validate_epoch(model, loader, device, epoch):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Uses the same loss as training (segmentation BCE + classification BCE,
    both with logits). A batch that times out (3 minutes) or raises is
    reported and skipped.

    Args:
        model: Module returning ``(seg_logits, cls_logits)``.
        loader: Iterable of dict batches with ``'volume'``, ``'mask'`` and
            ``'has_aneurysm'`` tensors.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index, used for the progress-bar label.

    Returns:
        Mean loss over the batches that completed, or ``float('inf')`` when no
        batch completed. A 0.0 here would beat every real validation loss and
        be saved by the caller as the best model.
    """
    model.eval()
    total_loss = 0
    num_batches = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1} Validation")

    with torch.no_grad():
        for batch_idx, batch in enumerate(pbar):
            try:
                with timeout(180):  # 3-minute timeout per validation batch
                    # Move to device
                    volume = batch['volume'].to(device, non_blocking=True)
                    mask = batch['mask'].to(device, non_blocking=True)
                    has_aneurysm = batch['has_aneurysm'].to(device, non_blocking=True)

                    # Forward pass
                    seg_logits, cls_logits = model(volume)

                    # Calculate losses
                    seg_loss = nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
                    cls_loss = nn.functional.binary_cross_entropy_with_logits(
                        cls_logits.view(-1), has_aneurysm
                    )

                    # Read the loss back once and reuse the Python float below.
                    batch_loss = (seg_loss + cls_loss).item()
                    total_loss += batch_loss
                    num_batches += 1

                    pbar.set_postfix({
                        'Val Loss': f'{batch_loss:.4f}',
                        'Avg': f'{total_loss/num_batches:.4f}'
                    })

            except TimeoutError:
                print(f"Validation batch {batch_idx} timed out, skipping...")
                continue
            except Exception as e:
                print(f"Validation error in batch {batch_idx}: {str(e)[:100]}")
                continue

    # Final TPU sync
    if TPU_AVAILABLE:
        xm.mark_step()

    if num_batches == 0:
        print("No validation batch completed; reporting loss as inf")
        return float('inf')

    return total_loss / num_batches

# MAIN TPU TRAINING FUNCTION
def main_tpu_training():
    """Run the stage-1 training loop and write checkpoints to the working directory.

    Steps: read the training CSV, hold out the first rows for validation,
    build datasets/loaders/model/optimizer, then train for up to
    ``TPUConfig.STAGE1_EPOCHS`` epochs with early stopping (patience 3).

    Files written:
        * ``tpu_aneurysm_best.pth`` whenever the validation loss improves;
        * ``tpu_aneurysm_epoch_{n}.pth`` every ``CHECKPOINT_FREQUENCY`` epochs;
        * ``tpu_aneurysm_emergency_epoch_{n}.pth`` /
          ``tpu_aneurysm_error_epoch_{n}.pth`` when an epoch times out / raises;
        * ``tpu_aneurysm_final.pth`` at the end.

    An epoch that times out or raises stops the run if it is one of the first
    two epochs; later epochs are skipped and training continues.

    Returns:
        ``(model, best_val_loss)`` on completion, or ``(None, None)`` if an
        error escaped the per-epoch handlers (the traceback is printed).
    """
    print("TPU BRAIN ANEURYSM TRAINING")
    print(f"{'='*60}")
    print(f"Device: {TPUConfig.DEVICE}")
    print(f"TPU Available: {TPU_AVAILABLE}")

    if TPU_AVAILABLE:
        print(f"TPU Cores: {xm.xrt_world_size()}")
        print(f"Current TPU Core: {xm.get_ordinal()}")

    try:
        # Load and prepare data
        print("Loading training data...")
        train_df = pd.read_csv(TPUConfig.TRAIN_CSV_PATH)

        if TPUConfig.DEBUG_MODE:
            train_df = train_df.head(TPUConfig.DEBUG_SAMPLES)
            print(f"Debug mode: using {len(train_df)} samples")

        print(f"Training samples: {len(train_df)}")
        print(f"Positive cases: {train_df[TPUConfig.TARGET_COL].sum()}")

        # Simple train/val split: the first rows are held out, with no
        # shuffling or stratification.
        val_size = max(10, len(train_df) // 10)  # At least 10 samples for validation
        val_df = train_df[:val_size].copy().reset_index(drop=True)
        train_df = train_df[val_size:].copy().reset_index(drop=True)

        print(f"Train: {len(train_df)}, Val: {len(val_df)}")

        # Create datasets
        processor = TPUDICOMProcessor()
        train_dataset = TPUSegmentationDataset(train_df, TPUConfig.SERIES_DIR, processor, 'train')
        val_dataset = TPUSegmentationDataset(val_df, TPUConfig.SERIES_DIR, processor, 'val')

        # Data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=TPUConfig.STAGE1_BATCH_SIZE,
            shuffle=True,
            num_workers=0,
            pin_memory=False,
            drop_last=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=TPUConfig.STAGE1_BATCH_SIZE,
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            drop_last=False
        )

        print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

        # Create model
        print("Creating TPU-optimized model...")
        model = TPUSegmentationModel().to(TPUConfig.DEVICE)

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")

        # Optimizer - TPU optimized
        optimizer = optim.AdamW(
            model.parameters(),
            lr=TPUConfig.STAGE1_LR,
            weight_decay=1e-4,
            eps=1e-8  # TPU-friendly epsilon
        )

        # Training loop with comprehensive hang prevention
        best_val_loss = float('inf')
        patience_counter = 0
        patience_limit = 3
        # Number of epochs started. Kept separately from the loop variable so
        # the final checkpoint is well-defined even if the loop body never runs.
        epochs_run = 0

        print(f"\nStarting training for {TPUConfig.STAGE1_EPOCHS} epochs...")

        for epoch in range(TPUConfig.STAGE1_EPOCHS):
            epochs_run = epoch + 1
            print(f"\n{'='*20} EPOCH {epoch+1}/{TPUConfig.STAGE1_EPOCHS} {'='*20}")

            epoch_start_time = time.time()

            try:
                with timeout(TPUConfig.MAX_TIMEOUT_SECONDS):
                    # Training phase
                    print("Training phase...")
                    train_loss = tpu_train_epoch(model, train_loader, optimizer, TPUConfig.DEVICE, epoch)

                    # Validation phase
                    print("Validation phase...")
                    val_loss = tpu_validate_epoch(model, val_loader, TPUConfig.DEVICE, epoch)

                    epoch_time = time.time() - epoch_start_time

                    print(f"\nEpoch {epoch+1} Results:")
                    print(f"   Train Loss: {train_loss:.4f}")
                    print(f"   Val Loss: {val_loss:.4f}")
                    print(f"   Time: {epoch_time:.1f}s")

                    # Save best model
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        patience_counter = 0

                        checkpoint = {
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': train_loss,
                            'val_loss': val_loss,
                            'best_val_loss': best_val_loss
                        }

                        torch.save(checkpoint, 'tpu_aneurysm_best.pth')
                        print(f"Saved best model (val_loss: {val_loss:.4f})")
                    else:
                        patience_counter += 1
                        print(f"Patience: {patience_counter}/{patience_limit}")

                    # Regular checkpoint saving
                    if (epoch + 1) % TPUConfig.CHECKPOINT_FREQUENCY == 0:
                        checkpoint = {
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': train_loss,
                            'val_loss': val_loss
                        }
                        torch.save(checkpoint, f'tpu_aneurysm_epoch_{epoch+1}.pth')
                        print(f"Saved checkpoint at epoch {epoch+1}")

                    # Early stopping
                    if patience_counter >= patience_limit:
                        print(f"Early stopping at epoch {epoch+1}")
                        break

                    # Memory cleanup
                    if TPU_AVAILABLE:
                        # Automatic TPU memory Management
                        xm.mark_step()
                    else:
                        torch.cuda.empty_cache()
                        gc.collect()

            except TimeoutError:
                print(f"Epoch {epoch+1} timed out after {TPUConfig.MAX_TIMEOUT_SECONDS}s")
                print("Saving emergency checkpoint...")

                emergency_checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'status': 'timeout_interrupted'
                }
                torch.save(emergency_checkpoint, f'tpu_aneurysm_emergency_epoch_{epoch+1}.pth')

                # Try to continue or break based on severity
                if epoch >= 2:  # If progress is made
                    print("Attempting to continue training...")
                    continue
                else:
                    print("Early epoch timeout, stopping training")
                    break

            except Exception as e:
                print(f"Unexpected error in epoch {epoch+1}: {str(e)}")
                print("Saving emergency checkpoint...")

                try:
                    emergency_checkpoint = {
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'status': 'error_interrupted',
                        'error': str(e)
                    }
                    torch.save(emergency_checkpoint, f'tpu_aneurysm_error_epoch_{epoch+1}.pth')
                except Exception as save_error:
                    # Best effort only. `Exception` rather than a bare except,
                    # so KeyboardInterrupt/SystemExit are not swallowed here.
                    print("Failed to save emergency checkpoint")
                    print(f"   Reason: {str(save_error)[:100]}")

                # Decide whether to continue or stop
                if epoch >= 2:
                    print("Attempting to recover and continue...")
                    time.sleep(10)  # Brief recovery pause
                    continue
                else:
                    print("Critical error in early training, stopping")
                    break

        print(f"\nTraining completed!")
        print(f"Best validation loss: {best_val_loss:.4f}")

        # Final model save
        final_checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            'total_epochs': epochs_run,
            'status': 'completed'
        }
        torch.save(final_checkpoint, 'tpu_aneurysm_final.pth')
        print("Saved final model checkpoint")

        return model, best_val_loss

    except Exception as e:
        print(f"Critical training error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None


In [None]:
# RUN TRAINING
# Kick off the full pipeline; both names are None if training aborted on a critical error.
(model, best_loss) = main_tpu_training()
print(
    "Expected training time: ? hours",
    "Output: tpu_aneurysm_final.pth",
    sep="\n",
)

**Model Evaluation and Performance**

In [None]:
def quick_tpu_evaluation(model_path='tpu_aneurysm_best.pth'):
    """Score a saved checkpoint's classification head on a small sample and print metrics.

    NOTE(review): the sample is drawn from the training CSV, so these numbers
    are a smoke test of the pipeline, not a held-out estimate.

    Args:
        model_path: Checkpoint file. Either a dict with a
            ``'model_state_dict'`` entry or a bare state dict. If the file
            does not exist the model keeps its random initial weights.

    Returns:
        Dict with 1-D arrays ``'predictions'`` (0/1), ``'probabilities'``
        (sigmoid of the classification logit) and ``'true_labels'`` (0/1),
        plus ``'accuracy'`` (0 when no prediction succeeded). ``None`` if
        the evaluation itself failed; the traceback is printed.
    """
    print("TPU MODEL EVALUATION")
    print("="*50)

    try:
        # Load test data
        train_df = pd.read_csv(TPUConfig.TRAIN_CSV_PATH)

        # Create small test set
        test_size = min(50, len(train_df) // 10)  # Small for quick eval
        test_df = train_df.sample(n=test_size, random_state=42).reset_index(drop=True)

        print(f"Test set: {len(test_df)} samples")
        print(f"Positive cases: {test_df[TPUConfig.TARGET_COL].sum()}")

        # Load model
        model = TPUSegmentationModel().to(TPUConfig.DEVICE)

        if os.path.exists(model_path):
            print(f"Loading model from {model_path}")
            # Deserialize on the CPU; load_state_dict then copies the values
            # into the parameters that already live on TPUConfig.DEVICE.
            checkpoint = torch.load(model_path, map_location='cpu')

            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
                print(f"Best val loss: {checkpoint.get('best_val_loss', 'unknown')}")
            else:
                model.load_state_dict(checkpoint)
        else:
            print(f"Model file {model_path} not found, using random weights")

        # Create test dataset and loader
        processor = TPUDICOMProcessor()
        test_dataset = TPUSegmentationDataset(test_df, TPUConfig.SERIES_DIR, processor, 'test')
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,  # Small batch for evaluation
            shuffle=False,
            num_workers=0
        )

        # Evaluation
        model.eval()
        predictions = []
        probabilities = []
        true_labels = []

        print("Running evaluation...")
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(test_loader, desc="Evaluating")):
                try:
                    with timeout(60):  # 1-minute timeout per batch
                        volume = batch['volume'].to(TPUConfig.DEVICE, non_blocking=True)

                        # Forward pass
                        seg_logits, cls_logits = model(volume)

                        # Reduce to plain Python scalars. Keeping a 1-element
                        # array here made the ':.3f' formatting below raise.
                        prob = float(torch.sigmoid(cls_logits).reshape(-1)[0].item())
                        pred = 1 if prob > 0.5 else 0
                        # The label never needs to leave the CPU.
                        true_label = int(round(float(batch['has_aneurysm'].reshape(-1)[0].item())))

                        probabilities.append(prob)
                        predictions.append(pred)
                        true_labels.append(true_label)

                        # Periodic TPU sync
                        if TPU_AVAILABLE and batch_idx % 10 == 0:
                            xm.mark_step()

                except TimeoutError:
                    print(f"Evaluation batch {batch_idx} timed out")
                    continue
                except Exception as e:
                    print(f"Error in evaluation batch {batch_idx}: {str(e)[:50]}")
                    continue

        # Calculate metrics
        predictions = np.array(predictions)
        probabilities = np.array(probabilities)
        true_labels = np.array(true_labels)
        accuracy = 0

        if len(predictions) > 0:
            from sklearn.metrics import accuracy_score, roc_auc_score, classification_report

            accuracy = accuracy_score(true_labels, predictions)

            if len(np.unique(true_labels)) > 1:  # Check if both classes are available
                auc = roc_auc_score(true_labels, probabilities)
                print(f"AUC-ROC: {auc:.3f}")
            else:
                print("Only one class in test set, cannot calculate AUC")

            print(f"Accuracy: {accuracy:.3f}")
            print(f"Predictions made: {len(predictions)}")

            print("\nClassification Report:")
            # Explicit labels keep the two target names valid even when only
            # one class appears among the labels and predictions.
            print(classification_report(
                true_labels, predictions,
                labels=[0, 1],
                target_names=['No Aneurysm', 'Aneurysm'],
                zero_division=0
            ))

            # Sample predictions
            print(f"\nSample Predictions:")
            for true_value, pred_value, prob_value in list(zip(true_labels, predictions, probabilities))[:5]:
                status = "✅" if pred_value == true_value else "❌"
                print(f"{status} True: {int(true_value)}, Pred: {int(pred_value)}, Prob: {prob_value:.3f}")

        else:
            print("No successful predictions made")

        print(f"\nEvaluation completed!")

        return {
            'predictions': predictions,
            'probabilities': probabilities,
            'true_labels': true_labels,
            'accuracy': accuracy
        }

    except Exception as e:
        print(f"Evaluation error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
