In [1]:
!pip install torch torchvision

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curan

In [2]:
import numpy as np
import torch
from torch.utils.data import Dataset
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gc

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.cuda.empty_cache()

In [3]:
class SegmentationData(Dataset):    
    def __init__(self, image_paths, mask_paths, metadata, transforms=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.metadata = metadata.reset_index(drop=True)  # Ensure clean index
        self.transforms = transforms

    def __len__(self):
        return len(self.metadata)

    def get_image_path(self, image_id):
        return f"{self.image_paths}{image_id}"

    def get_mask_path(self, mask_id):
        return f"{self.mask_paths}{mask_id}"

    def __getitem__(self, idx):
        # Get file paths
        image_path = self.get_image_path(self.metadata.iloc[idx]['image'])
        mask_path = self.get_mask_path(self.metadata.iloc[idx]['mask'])

        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Failed to load image: {image_path}")
        
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        mask = cv2.imread(mask_path, 0)
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_path}")

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
            
            # Ensure mask has channel dimension (1, H, W) and is float
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)
            mask = mask.float() / 255.0  # Normalize to [0, 1] and ensure float
        else:
            image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            mask = torch.from_numpy(mask).unsqueeze(0).float() / 255.0

        return image, mask
            
    def train_test_split(self, train_ratio=0.8):
        total_samples = len(self.metadata)
        train_size = int(total_samples * train_ratio)
        
        train_metadata = self.metadata.iloc[:train_size].reset_index(drop=True)
        test_metadata = self.metadata.iloc[train_size:].reset_index(drop=True)

        train_dataset = SegmentationData(
            image_paths=self.image_paths,
            mask_paths=self.mask_paths,
            metadata=train_metadata,
            transforms=self.transforms
        )

        test_dataset = SegmentationData(
            image_paths=self.image_paths,
            mask_paths=self.mask_paths,
            metadata=test_metadata,
            transforms=self.transforms
        )

        return train_dataset, test_dataset


def training_transforms():
    return A.Compose([
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.0625,
            scale_limit=0.1,
            rotate_limit=45,
            p=0.5
        ),
        A.OneOf([
            A.ElasticTransform(p=0.3),
            A.GridDistortion(p=0.3),
            A.OpticalDistortion(p=0.3),
        ], p=0.3),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])


def validation_transforms():
    return A.Compose([
        A.Resize(256, 256),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import timm


class ConvolutionalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvolutionalBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        
        return x


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv_block = ConvolutionalBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.conv_block(x)
        pooled = self.pool(skip)
        return skip, pooled


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_block = ConvolutionalBlock(out_channels + skip_channels, out_channels)

    def forward(self, x, skip_connection):
        x = self.upconv(x)
        x = torch.cat([x, skip_connection], dim=1)
        x = self.conv_block(x)
        return x


class ResNetFeatureExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNetFeatureExtractor, self).__init__()
        resnet = models.resnet50(weights='IMAGENET1K_V2' if pretrained else None)
        
        # Extract features from layer2 (good mid-level features)
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,  # 256 channels
            resnet.layer2,  # 512 channels
        )
        self.out_channels = 512
        self.adaptive_pool = nn.AdaptiveAvgPool2d((16, 16))
    
    def forward(self, x):
        features = self.features(x)
        features = self.adaptive_pool(features)
        return features


class VisionTransformerFeatures(nn.Module):
    def __init__(self, img_size=256, output_channels=3):
        super(VisionTransformerFeatures, self).__init__()
        
        self.vit = timm.create_model('vit_tiny_patch16_224', pretrained=False, img_size=img_size)
        vit_features = self.vit.embed_dim
        
        self.projection = nn.Sequential(
            nn.Linear(vit_features, 256),
            nn.ReLU(),
            nn.Linear(256, output_channels * 16 * 16)
        )
        self.output_channels = output_channels
    
    def forward(self, x):
        B = x.size(0)
        
        vit_out = self.vit.forward_features(x)
        vit_out = vit_out.mean(dim=1)
        
        vit_out = self.projection(vit_out)
        vit_out = vit_out.view(B, self.output_channels, 16, 16)
        
        return vit_out


class HVUEArchitecture(nn.Module):
    def __init__(self, num_classes=4, img_size=256, pretrained=True):
        super(HVUEArchitecture, self).__init__()
        
        self.num_classes = num_classes
        
        # ResNet feature extractor (upgraded from DenseNet)
        self.resnet_features = ResNetFeatureExtractor(pretrained=pretrained)
        resnet_channels = self.resnet_features.out_channels
        
        # Vision Transformer
        self.vit = VisionTransformerFeatures(img_size=img_size, output_channels=3)
        vit_channels = self.vit.output_channels
        
        # U-Net Encoder (3 channels for RGB images)
        self.enc1 = EncoderBlock(3, 64)  # Changed from 2 to 3 for RGB
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 512)
        
        # Bottleneck
        self.bottleneck = ConvolutionalBlock(512, 1024)
        
        # Feature fusion at bottleneck
        fusion_channels = 1024 + resnet_channels + vit_channels
        
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(fusion_channels, 1024, 3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True)
        )
        
        # U-Net Decoder
        self.dec1 = DecoderBlock(1024, 512, 512)
        self.dec2 = DecoderBlock(512, 256, 256)
        self.dec3 = DecoderBlock(256, 128, 128)
        self.dec4 = DecoderBlock(128, 64, 64)
        
        # Output layer
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
    
    def forward(self, x):
        # U-Net Encoder with skip connections
        s1, p1 = self.enc1(x)
        s2, p2 = self.enc2(p1)
        s3, p3 = self.enc3(p2)
        s4, p4 = self.enc4(p3)
        
        # Bottleneck
        b = self.bottleneck(p4)
        
        # ResNet features (upgraded from DenseNet)
        resnet_feat = self.resnet_features(x)
        
        # Vision Transformer features
        vit_feat = self.vit(x)
        
        # Concatenate all features at bottleneck
        fused = torch.cat([b, resnet_feat, vit_feat], dim=1)
        
        # Reduce channels through fusion
        fused = self.fusion_conv(fused)
        
        # U-Net Decoder with skip connections
        d1 = self.dec1(fused, s4)
        d2 = self.dec2(d1, s3)
        d3 = self.dec3(d2, s2)
        d4 = self.dec4(d3, s1)
        
        # Output (logits)
        out = self.final_conv(d4)
        
        return out



In [5]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import gc


def dice_coef_binary(y_true, y_pred, smooth=1e-6):
    y_true_f = y_true.contiguous().view(-1)
    y_pred_f = y_pred.contiguous().view(-1)
    
    intersection = torch.sum(y_true_f * y_pred_f)
    dice = (2.0 * intersection + smooth) / (torch.sum(y_true_f) + torch.sum(y_pred_f) + smooth)
    
    return dice


def train(train_dataset, val_dataset, model, epochs=50, batch_size=2, learning_rate=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if torch.cuda.is_available():
        num_workers = 2
        pin_memory = True
        prefetch_factor = 2
    else:
        num_workers = 0
        pin_memory = False
        prefetch_factor = None
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True if num_workers > 0 else False,
        prefetch_factor=prefetch_factor
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True if num_workers > 0 else False,
        prefetch_factor=prefetch_factor
    )
    
    model = model.to(device)
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-7, verbose=True
    )
    
    use_amp = torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda') if use_amp else None
    
    print(f"\n{'='*80}")
    print(f"Training Configuration:")
    print(f"{'='*80}")
    print(f"  Device: {device}")
    print(f"  Mixed Precision (AMP): {use_amp}")
    print(f"  Num Workers: {num_workers}")
    print(f"  Batch Size: {batch_size} (REDUCED for memory safety)")
    print(f"  Train samples: {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")
    print(f"  Total epochs: {epochs}")
    if torch.cuda.is_available():
        print(f"  GPU Memory Available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"{'='*80}\n")
    
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_dice = 0.0
        
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [Train]')
        for batch_idx, (images, masks) in enumerate(train_pbar):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            
            # Ensure masks have correct shape (B, 1, H, W) and are float
            if masks.dim() == 3:
                masks = masks.unsqueeze(1)
            masks = masks.float()
            
            optimizer.zero_grad(set_to_none=True)
            
            if use_amp:
                with torch.amp.autocast('cuda'):
                    outputs = model(images)
                    loss = criterion(outputs, masks)
                    
                    # Compute dice for monitoring
                    probs = torch.sigmoid(outputs)
                    dice = dice_coef_binary(masks, probs)
                
                scaler.scale(loss).backward()
                
                # Gradient clipping to prevent explosion
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = criterion(outputs, masks)
                
                probs = torch.sigmoid(outputs)
                dice = dice_coef_binary(masks, probs)
                
                loss.backward()
                
                # Gradient clipping to prevent explosion
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
            
            # Check for NaN
            if torch.isnan(loss) or torch.isnan(dice):
                print(f"\n⚠ WARNING: NaN detected at batch {batch_idx}. Skipping batch...")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
                continue
            
            train_loss += loss.item() * images.size(0)
            train_dice += dice.item() * images.size(0)
            
            train_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'dice': f'{dice.item():.4f}'
            })
            
            if batch_idx % 10 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()
                
        
        avg_train_loss = train_loss / len(train_dataset)
        avg_train_dice = train_dice / len(train_dataset)
        
        model.eval()
        val_loss = 0.0
        val_dice = 0.0
        
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [Val]')
            for images, masks in val_pbar:
                try:
                    images = images.to(device, non_blocking=True)
                    masks = masks.to(device, non_blocking=True)
                    
                    # Ensure masks have correct shape (B, 1, H, W) and are float
                    if masks.dim() == 3:
                        masks = masks.unsqueeze(1)
                    masks = masks.float()
                    
                    # Forward pass
                    if use_amp:
                        with torch.amp.autocast('cuda'):
                            outputs = model(images)
                            loss = criterion(outputs, masks)
                            probs = torch.sigmoid(outputs)
                            dice = dice_coef_binary(masks, probs)
                    else:
                        outputs = model(images)
                        loss = criterion(outputs, masks)
                        probs = torch.sigmoid(outputs)
                        dice = dice_coef_binary(masks, probs)
                    
                    # Accumulate metrics
                    val_loss += loss.item() * images.size(0)
                    val_dice += dice.item() * images.size(0)
                    
                    val_pbar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'dice': f'{dice.item():.4f}'
                    })
                    
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"\n⚠ WARNING: OOM during validation. Skipping batch...")
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                        continue
                    else:
                        raise e
        
        avg_val_loss = val_loss / len(val_dataset)
        avg_val_dice = val_dice / len(val_dataset)
        
        scheduler.step(avg_val_loss)
        
        print(f"\n{'='*80}")
        print(f"Epoch {epoch+1}/{epochs} Summary:")
        print(f"{'='*80}")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Dice: {avg_train_dice:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f} | Val Dice: {avg_val_dice:.4f}")
        
        if torch.cuda.is_available():
            print(f"  GPU Memory: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB / {torch.cuda.max_memory_allocated(0) / 1e9:.2f} GB (peak)")
            torch.cuda.reset_peak_memory_stats()
        
        print(f"{'='*80}\n")
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': avg_val_loss,
                'val_dice': avg_val_dice,
            }, 'best_model_hvue.pth')
            print(f"  ✓ Best model saved (Val Loss: {avg_val_loss:.4f})")
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
    
    print("\n✓ Training completed!")
    return model


def evaluate(model, test_dataset, batch_size=2):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if torch.cuda.is_available():
        num_workers = 2
        pin_memory = True
    else:
        num_workers = 0
        pin_memory = False
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=True if num_workers > 0 else False
    )
    
    model = model.to(device)
    model.eval()
    
    total_dice = 0.0
    total_iou = 0.0
    total_pixel_acc = 0.0
    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    total_specificity = 0.0
    total_samples = 0
    
    # For AUC-ROC calculation
    all_probs = []
    all_labels = []
    
    print(f"\n{'='*80}")
    print(f"Evaluating on {len(test_dataset)} samples...")
    print(f"{'='*80}\n")
    
    use_amp = torch.cuda.is_available()
    
    with torch.no_grad():
        for images, masks in tqdm(test_loader, desc='Evaluating'):
            try:
                images = images.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)
                
                # Ensure masks have correct shape (B, 1, H, W) and are float
                if masks.dim() == 3:
                    masks = masks.unsqueeze(1)
                masks = masks.float()
                
                # Forward pass
                if use_amp:
                    with torch.amp.autocast('cuda'):
                        outputs = model(images)
                        probs = torch.sigmoid(outputs)
                else:
                    outputs = model(images)
                    probs = torch.sigmoid(outputs)
                
                # Binary predictions (threshold at 0.5)
                preds = (probs > 0.5).float()
                
                # Flatten for metric calculation
                masks_flat = masks.view(-1)
                preds_flat = preds.view(-1)
                probs_flat = probs.view(-1)
                
                # Compute metrics
                # 1. Dice Coefficient (using binary predictions for consistency)
                dice = dice_coef_binary(masks, preds)  # Changed from probs to preds
                
                # 2. IoU (Intersection over Union / Jaccard Index)
                intersection = torch.sum(masks_flat * preds_flat)
                union = torch.sum(masks_flat) + torch.sum(preds_flat) - intersection
                iou = (intersection + 1e-6) / (union + 1e-6)
                
                # 3. Pixel Accuracy
                pixel_acc = torch.sum(preds_flat == masks_flat) / masks_flat.numel()
                
                # 4. Precision, Recall, F1
                true_positives = torch.sum(masks_flat * preds_flat)
                false_positives = torch.sum((1 - masks_flat) * preds_flat)
                false_negatives = torch.sum(masks_flat * (1 - preds_flat))
                true_negatives = torch.sum((1 - masks_flat) * (1 - preds_flat))
                
                precision = (true_positives + 1e-6) / (true_positives + false_positives + 1e-6)
                recall = (true_positives + 1e-6) / (true_positives + false_negatives + 1e-6)
                f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
                
                # 5. Specificity
                specificity = (true_negatives + 1e-6) / (true_negatives + false_positives + 1e-6)
                
                # Skip if NaN
                if torch.isnan(dice) or torch.isnan(iou):
                    print(f"\n⚠ WARNING: NaN detected in evaluation. Skipping batch...")
                    continue
                
                # Accumulate metrics
                total_dice += dice.item() * images.size(0)
                total_iou += iou.item() * images.size(0)
                total_pixel_acc += pixel_acc.item() * images.size(0)
                total_precision += precision.item() * images.size(0)
                total_recall += recall.item() * images.size(0)
                total_f1 += f1.item() * images.size(0)
                total_specificity += specificity.item() * images.size(0)
                total_samples += images.size(0)
                
                # Collect for AUC-ROC (sample to avoid memory issues)
                if len(all_probs) < 10000:  # Limit samples for AUC calculation
                    all_probs.append(probs_flat.cpu())
                    all_labels.append(masks_flat.cpu())
                
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print(f"\n⚠ WARNING: OOM during evaluation. Skipping batch...")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e
    
    # Calculate averages
    avg_dice = total_dice / total_samples if total_samples > 0 else 0
    avg_iou = total_iou / total_samples if total_samples > 0 else 0
    avg_pixel_acc = total_pixel_acc / total_samples if total_samples > 0 else 0
    avg_precision = total_precision / total_samples if total_samples > 0 else 0
    avg_recall = total_recall / total_samples if total_samples > 0 else 0
    avg_f1 = total_f1 / total_samples if total_samples > 0 else 0
    avg_specificity = total_specificity / total_samples if total_samples > 0 else 0
    # Print results
    print(f"\n{'='*80}")
    print("COMPREHENSIVE EVALUATION RESULTS")
    print(f"{'='*80}")
    print(f"\n Primary Metrics:")
    print(f"  Dice Coefficient:        {avg_dice:.4f}  (Overlap-based metric)")
    print(f"  IoU (Jaccard Index):     {avg_iou:.4f}  (Intersection over Union)")
    print(f"  Pixel Accuracy:          {avg_pixel_acc:.4f}  (Correct pixels / Total pixels)")
    
    print(f"\n Classification Metrics:")
    print(f"  Precision:               {avg_precision:.4f}  (TP / (TP + FP))")
    print(f"  Recall (Sensitivity):    {avg_recall:.4f}  (TP / (TP + FN))")
    print(f"  F1 Score:                {avg_f1:.4f}  (Harmonic mean of Precision & Recall)")
    print(f"  Specificity:             {avg_specificity:.4f}  (TN / (TN + FP))")
        
    print(f"{'='*80}\n")
    
    return {
        'dice': avg_dice,
        'iou': avg_iou,
        'pixel_accuracy': avg_pixel_acc,
        'precision': avg_precision,
        'recall': avg_recall,
        'f1_score': avg_f1,
        'specificity': avg_specificity,
    }

In [6]:
import pandas as pd

if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

base_path = "/kaggle/input/augmented-forest-segmentation/Forest Segmented/Forest Segmented/"
image_paths = base_path + 'images/'
mask_paths = base_path + 'masks/'
metadata_path = base_path + 'meta_data.csv'

print("\n[1] Loading metadata...")
metadata = pd.read_csv(metadata_path)
print(f"    Total samples: {len(metadata)}")


[1] Loading metadata...
    Total samples: 5108


In [7]:
print("\n[2] Creating dataset (on-demand loading, NO preloading)...")
dataset = SegmentationData(
    image_paths=image_paths,
    mask_paths=mask_paths,
    metadata=metadata,
    transforms=None
)
print("    ✓ Dataset created (images will load on-the-fly)")



[2] Creating dataset (on-demand loading, NO preloading)...
    ✓ Dataset created (images will load on-the-fly)


In [8]:
print("\n[3] Splitting dataset...")
train_dataset, test_dataset = dataset.train_test_split(train_ratio=0.7)

train_dataset.transforms = training_transforms()
test_dataset.transforms = validation_transforms()

print(f"    Train samples: {len(train_dataset)}")
print(f"    Test samples: {len(test_dataset)}")



[3] Splitting dataset...
    Train samples: 3575
    Test samples: 1533


  original_init(self, **validated_kwargs)


In [9]:
print("\n[4] Creating model with ResNet50 backbone...")
model = HVUEArchitecture(
    num_classes=1,
    pretrained=True
)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"    ✓ Using ResNet50 (upgraded from DenseNet)")
print(f"    ✓ Pretrained on ImageNet: Yes")
print(f"    Total parameters: {total_params:,}")
print(f"    Trainable parameters: {trainable_params:,}")


[4] Creating model with ResNet50 backbone...


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 187MB/s]


    ✓ Using ResNet50 (upgraded from DenseNet)
    ✓ Pretrained on ImageNet: Yes
    Total parameters: 62,090,921
    Trainable parameters: 62,090,921


In [10]:
print("\n[5] Training model with MEMORY-SAFE settings...")

trained_model = train(
    train_dataset=train_dataset,
    val_dataset=test_dataset,
    model=model,
    epochs=5,
    batch_size=2,
    learning_rate=0.0001
)


[5] Training model with MEMORY-SAFE settings...





Training Configuration:
  Device: cuda
  Mixed Precision (AMP): True
  Num Workers: 2
  Batch Size: 2 (REDUCED for memory safety)
  Train samples: 3575
  Validation samples: 1533
  Total epochs: 5
  GPU Memory Available: 15.83 GB



Epoch 1/5 [Train]: 100%|██████████| 1788/1788 [03:28<00:00,  8.58it/s, loss=0.8789, dice=0.7177]
Epoch 1/5 [Val]: 100%|██████████| 767/767 [00:25<00:00, 30.04it/s, loss=0.4296, dice=0.1425]



Epoch 1/5 Summary:
  Train Loss: 0.5215 | Train Dice: 0.6741
  Val Loss: 0.4502 | Val Dice: 0.7017
  GPU Memory: 1.02 GB / 14.91 GB (peak)

  ✓ Best model saved (Val Loss: 0.4502)


Epoch 2/5 [Train]: 100%|██████████| 1788/1788 [03:24<00:00,  8.73it/s, loss=0.5261, dice=0.8362]
Epoch 2/5 [Val]: 100%|██████████| 767/767 [00:24<00:00, 31.67it/s, loss=0.3245, dice=0.1606]



Epoch 2/5 Summary:
  Train Loss: 0.4651 | Train Dice: 0.7142
  Val Loss: 0.4715 | Val Dice: 0.6845
  GPU Memory: 1.02 GB / 1.44 GB (peak)



Epoch 3/5 [Train]: 100%|██████████| 1788/1788 [03:25<00:00,  8.70it/s, loss=0.2558, dice=0.8524]
Epoch 3/5 [Val]: 100%|██████████| 767/767 [00:24<00:00, 31.76it/s, loss=0.6687, dice=0.1031]



Epoch 3/5 Summary:
  Train Loss: 0.4400 | Train Dice: 0.7326
  Val Loss: 0.4159 | Val Dice: 0.7297
  GPU Memory: 1.02 GB / 1.44 GB (peak)

  ✓ Best model saved (Val Loss: 0.4159)


Epoch 4/5 [Train]: 100%|██████████| 1788/1788 [03:25<00:00,  8.70it/s, loss=0.3273, dice=0.7736]
Epoch 4/5 [Val]: 100%|██████████| 767/767 [00:23<00:00, 31.96it/s, loss=0.4045, dice=0.1333]



Epoch 4/5 Summary:
  Train Loss: 0.4277 | Train Dice: 0.7343
  Val Loss: 0.4249 | Val Dice: 0.7059
  GPU Memory: 1.02 GB / 1.44 GB (peak)



Epoch 5/5 [Train]: 100%|██████████| 1788/1788 [03:25<00:00,  8.72it/s, loss=0.2482, dice=0.8767]
Epoch 5/5 [Val]: 100%|██████████| 767/767 [00:23<00:00, 32.16it/s, loss=0.5114, dice=0.1228]



Epoch 5/5 Summary:
  Train Loss: 0.4256 | Train Dice: 0.7424
  Val Loss: 0.4350 | Val Dice: 0.6914
  GPU Memory: 1.02 GB / 1.44 GB (peak)


✓ Training completed!


In [11]:
print("\n[6] Evaluating model...")

results = evaluate(
    model=trained_model,
    test_dataset=test_dataset,
    batch_size=2
)


[6] Evaluating model...

Evaluating on 1533 samples...



Evaluating: 100%|██████████| 767/767 [00:24<00:00, 31.14it/s]


COMPREHENSIVE EVALUATION RESULTS

 Primary Metrics:
  Dice Coefficient:        0.7650  (Overlap-based metric)
  IoU (Jaccard Index):     0.6750  (Intersection over Union)
  Pixel Accuracy:          0.8015  (Correct pixels / Total pixels)

 Classification Metrics:
  Precision:               0.8547  (TP / (TP + FP))
  Recall (Sensitivity):    0.7663  (TP / (TP + FN))
  F1 Score:                0.7650  (Harmonic mean of Precision & Recall)
  Specificity:             0.7241  (TN / (TN + FP))






In [12]:
print("\n[7] Saving final model...")
torch.save(trained_model.state_dict(), 'trained_model_final.pth')
print("    ✓ Model saved to trained_model_final.pth")

# Final cleanup
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()
print(f"    ✓ Memory cleaned up")

print("\n" + "="*80)
print("✓ TRAINING COMPLETED SUCCESSFULLY!")
print("="*80)


[7] Saving final model...
    ✓ Model saved to trained_model_final.pth
    ✓ Memory cleaned up

✓ TRAINING COMPLETED SUCCESSFULLY!
