In [None]:
!pip install --upgrade xgboost

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm
import os
import copy
import time
import pickle
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import xgboost as xgb
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler

# --- Backend / allocator configuration ---------------------------------------
# Let cuDNN autotune kernels and allow TF32 math for matmul / convolutions.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Allocator settings are assigned once, with the final value the notebook used.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'

# CUDA_LAUNCH_BLOCKING=1 is a debugging switch that makes every kernel launch
# synchronous. It is no longer forced on here; to debug a CUDA error, export
# CUDA_LAUNCH_BLOCKING=1 before starting the kernel.

# Set random seed for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# --- Hyperparameters ---------------------------------------------------------
num_classes = 5  # Classes 0,1,2,3,4
BATCH_SIZE = 2  # micro-batch size per forward/backward pass
ACCUMULATION_STEPS = 8  # micro-batches accumulated per optimizer step

num_epochs = 100
dropout_rate = 0.3
weight_decay = 0.01
patience = 20  # early-stopping patience (epochs)
min_delta = 0.0001
lr_patience = 10  # ReduceLROnPlateau patience (epochs)
lr_factor = 0.2  # ReduceLROnPlateau multiplicative factor
fine_tune_learning_rate = 1e-5

# --- XGBoost parameter grid (consumed by GridSearchCV) -------------------------
# GPU training is requested with tree_method='hist' plus `device`, the spelling
# used by XGBoost >= 2.0. The legacy 'gpu_hist' / 'predictor' / 'gpu_id' /
# 'use_label_encoder' entries are dropped. Falls back to CPU without CUDA.
xgb_params = {
    'max_depth': [6, 8, 10],
    'learning_rate': [0.01, 0.03, 0.05],
    'n_estimators': [200],
    'min_child_weight': [1, 2],
    'gamma': [0.1, 0.2],
    'subsample': [0.8, 0.9],
    'colsample_bytree': [0.8, 0.9],
    'tree_method': ['hist'],
    'device': ['cuda' if torch.cuda.is_available() else 'cpu'],
    'max_bin': [256],
    'objective': ['multi:softprob'],
    'num_class': [num_classes],
    'eval_metric': ['mlogloss'],
}


In [None]:
# Run on the first CUDA device when one is present, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Channel-wise normalisation shared by every pipeline below.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def _eval_pipeline():
    """Build the deterministic pipeline: resize to 384x384, to tensor, normalise."""
    return transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        normalize,
    ])


# Training-time augmentation, applied in this order. The steps before
# ToTensor operate on the loaded image; RandomErasing runs on the normalised tensor.
_train_steps = [
    transforms.Resize((384, 384)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation((-45, 45)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.RandomAffine(
        degrees=(-30, 30),
        translate=(0.15, 0.15),
        scale=(0.85, 1.15),
        shear=(-10, 10),
    ),
    transforms.RandomPerspective(p=0.3, distortion_scale=0.5),
    transforms.RandomAutocontrast(p=0.3),
    transforms.RandomEqualize(p=0.3),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.3),
    transforms.ToTensor(),
    normalize,
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.33)),
]

# One pipeline per split; 'val' and 'test' are identical and augmentation-free.
data_transforms = {
    'train': transforms.Compose(_train_steps),
    'val': _eval_pipeline(),
    'test': _eval_pipeline(),
}


In [None]:
# Memory optimization
import gc

def clear_gpu_cache():
    """Run the garbage collector, then release PyTorch's cached GPU memory.

    The collector runs first so that tensors kept alive only by reference
    cycles are freed before the CUDA caching allocator is asked to hand its
    unused blocks back. Garbage collection happens on CPU-only machines too;
    the CUDA cache is only touched when CUDA is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def get_gpu_memory():
    """Return ``(allocated, reserved)`` GPU memory, each divided by 1024**2.

    Both values come from ``torch.cuda``; ``(0, 0)`` is returned when CUDA
    is not available.
    """
    if not torch.cuda.is_available():
        return 0, 0
    bytes_per_mib = 1024**2
    allocated = torch.cuda.memory_allocated() / bytes_per_mib
    reserved = torch.cuda.memory_reserved() / bytes_per_mib
    return allocated, reserved

def print_gpu_memory():
    """Print the values reported by ``get_gpu_memory``; silent without CUDA."""
    if not torch.cuda.is_available():
        return
    allocated, reserved = get_gpu_memory()
    print(f'GPU Memory: {allocated:.2f}MB allocated, {reserved:.2f}MB reserved')


In [None]:
class TwoStreamDataset(Dataset):
    """Dataset of paired RGB / vessel images stored as ``<root>/<class_idx>/<name>.png``.

    An RGB image is kept only when a file with the same name exists in the
    matching class directory under ``vessel_root``. Each item is the tuple
    ``(rgb_image, vessel_image, label)``.

    Args:
        rgb_root: Directory containing one sub-directory per class index for
            the RGB stream.
        vessel_root: Directory with the same layout for the vessel stream.
        transform: Optional callable applied to each loaded image.
        num_classes: Number of class directories (``"0"`` .. ``"num_classes-1"``)
            to scan. Defaults to 5.

    Raises:
        RuntimeError: If no RGB/vessel pair is found at all.
    """

    def __init__(self, rgb_root, vessel_root, transform=None, num_classes=5):
        self.transform = transform
        self.image_pairs = []

        print(f"Initializing dataset with RGB root: {rgb_root}")
        print(f"Vessel root: {vessel_root}")

        for class_idx in range(num_classes):
            class_name = str(class_idx)
            rgb_class_dir = os.path.join(rgb_root, class_name)
            vessel_class_dir = os.path.join(vessel_root, class_name)

            if not os.path.exists(rgb_class_dir) or not os.path.exists(vessel_class_dir):
                print(f"Warning: Directory not found - RGB: {rgb_class_dir} or Vessel: {vessel_class_dir}")
                continue

            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample order (and therefore any seeded split of this dataset)
            # identical across machines and filesystems.
            rgb_files = sorted(f for f in os.listdir(rgb_class_dir) if f.endswith('.png'))

            for img_name in rgb_files:
                rgb_path = os.path.join(rgb_class_dir, img_name)
                vessel_path = os.path.join(vessel_class_dir, img_name)

                # RGB images without a vessel counterpart are skipped.
                if os.path.exists(vessel_path):
                    self.image_pairs.append({
                        'rgb': rgb_path,
                        'vessel': vessel_path,
                        'label': class_idx
                    })

        print(f"Total number of image pairs found: {len(self.image_pairs)}")
        if len(self.image_pairs) == 0:
            raise RuntimeError("No valid image pairs found!")

    def __len__(self):
        """Number of RGB/vessel pairs found at construction time."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(rgb, vessel, label)``.

        Both images are converted to 3-channel RGB before the transform is
        applied. Loading errors are logged with the offending paths and then
        re-raised unchanged. An out-of-range ``idx`` raises ``IndexError``.
        """
        # Looked up outside the try block so the handler below can always
        # reference `pair`, and a bad index surfaces as a plain IndexError.
        pair = self.image_pairs[idx]
        try:
            rgb_img = Image.open(pair['rgb']).convert('RGB')
            vessel_img = Image.open(pair['vessel']).convert('RGB')

            if self.transform:
                # NOTE(review): the transform is invoked separately per image,
                # so random augmentations are sampled independently for the
                # two streams - confirm this is intended.
                rgb_img = self.transform(rgb_img)
                vessel_img = self.transform(vessel_img)

            label = torch.tensor(pair['label'], dtype=torch.long)
            return rgb_img, vessel_img, label

        except Exception as e:
            print(f"Error loading pair {idx}: {str(e)}")
            print(f"Paths - RGB: {pair['rgb']}, Vessel: {pair['vessel']}")
            raise

# Dataset
print("Creating datasets...")
data_dir = "/kaggle/input/vdmdrneww/New_VDMDR"
rgb_root = os.path.join(data_dir, 'RGB')
vessel_root = os.path.join(data_dir, 'Vessel')

# The directory tree is scanned once. This instance carries the training
# augmentations and backs the training split.
full_dataset = TwoStreamDataset(rgb_root, vessel_root, transform=data_transforms['train'])

# Splits: 70% train / 15% val / remainder test, drawn with a fixed seed.
total_len = len(full_dataset)
train_len = int(0.7 * total_len)
val_len = int(0.15 * total_len)
test_len = total_len - train_len - val_len

generator = torch.Generator().manual_seed(42)
train_dataset, _val_split, _test_split = torch.utils.data.random_split(
    full_dataset,
    [train_len, val_len, test_len],
    generator=generator
)


def _with_transform(split, transform):
    """Return `split`'s samples served through `transform` instead of the training one.

    The subsets returned by random_split all point at `full_dataset`, so
    validation and test samples would otherwise receive the random training
    augmentations. A shallow copy shares the already-scanned file list and
    differs only in its `transform` attribute.
    """
    view = copy.copy(full_dataset)
    view.transform = transform
    return torch.utils.data.Subset(view, split.indices)


val_dataset = _with_transform(_val_split, data_transforms['val'])
test_dataset = _with_transform(_test_split, data_transforms['test'])

print(f"Split sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the settings shared by all three splits."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=2,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )


# Only the training loader shuffles.
dataloaders = {
    'train': _make_loader(train_dataset, shuffle=True),
    'val': _make_loader(val_dataset, shuffle=False),
    'test': _make_loader(test_dataset, shuffle=False)
}

dataset_sizes = {
    'train': len(train_dataset),
    'val': len(val_dataset),
    'test': len(test_dataset)
}

print("Dataset sizes:", dataset_sizes)

# Smoke test: pull one training batch so path/transform problems surface here
# rather than in the middle of training.
print("\nTesting first batch loading...")
try:
    train_iter = iter(dataloaders['train'])
    first_batch = next(train_iter)
    print("Successfully loaded first batch!")
    print(f"RGB tensor shape: {first_batch[0].shape}")
    print(f"Vessel tensor shape: {first_batch[1].shape}")
    print(f"Labels shape: {first_batch[2].shape}")
except Exception as e:
    print(f"Error loading first batch: {str(e)}")
    raise


In [None]:
class TwoStreamNetwork(nn.Module):
    """Two-stream network: one Swin backbone per input stream, fused by an MLP.

    Each stream (RGB and vessel) runs through its own
    ``swin_large_patch4_window12_384`` backbone created with ``num_classes=0``.
    The two outputs are concatenated along dim 1. In feature-extraction mode
    that concatenation is returned directly; otherwise it passes through the
    fusion MLP and the classification head.

    Args:
        num_classes: Output size of the classification head.
        feature_extraction: When True, no fusion/classifier layers are built
            and ``forward`` returns the concatenated backbone features.

    Note:
        Dropout probabilities are read from the module-level ``dropout_rate``.
    """

    def __init__(self, num_classes=5, feature_extraction=False):
        super(TwoStreamNetwork, self).__init__()

        # Two independent backbones (no weight sharing between streams).
        self.rgb_model = self._make_backbone()
        self.vessel_model = self._make_backbone()

        self.rgb_model.set_grad_checkpointing(enable=True)
        self.vessel_model.set_grad_checkpointing(enable=True)

        # Feature dimensions: 1536 is the per-backbone width this code assumes;
        # the fusion input is twice that because both streams are concatenated.
        self.feature_dim = 1536
        self.combined_features = self.feature_dim * 2
        hidden_dim = 1024

        self.feature_extraction = feature_extraction

        if not feature_extraction:
            # Fusion layers: three Linear -> LayerNorm -> GELU -> Dropout stages.
            self.fusion = nn.Sequential(
                nn.Linear(self.combined_features, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate),

                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate),

                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout_rate)
            )

            # Classification head
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.LayerNorm(hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout_rate),
                nn.Linear(hidden_dim // 2, num_classes)
            )

            self._init_weights()

    @staticmethod
    def _make_backbone():
        """Create one pretrained Swin-L (384x384) backbone without a classification head."""
        return timm.create_model(
            'swin_large_patch4_window12_384',  # Swin-L 384x384
            pretrained=True,  # Use pretrained weights
            use_checkpoint=True,  # Enable gradient checkpointing
            checkpoint_path=None,
            num_classes=0,  # Remove classification head
            drop_path_rate=0.2  # Stochastic depth
        )

    def _init_weights(self):
        """Kaiming-normal init for every Linear in fusion/classifier; biases set to zero."""
        if self.feature_extraction:
            return  # no fusion/classifier layers exist in this mode
        for head in (self.fusion, self.classifier):
            for m in head.modules():
                if isinstance(m, nn.Linear):
                    nn.init.kaiming_normal_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    # Same autocast spelling as the training loop; replaces the deprecated
    # torch.cuda.amp.autocast().
    @torch.amp.autocast('cuda')
    def forward(self, rgb_input, vessel_input):
        """Run both streams and return logits, or the raw concatenated features.

        Returns:
            The concatenation of both backbone outputs along dim 1 when
            ``feature_extraction`` is True, otherwise the classifier output.
        """
        # No torch.cuda.empty_cache() between the streams: the activations are
        # still referenced, so flushing the allocator cache on every forward
        # pass only forces re-allocation without freeing anything that is in use.
        rgb_features = self.rgb_model(rgb_input)
        vessel_features = self.vessel_model(vessel_input)

        # Combine features
        combined = torch.cat((rgb_features, vessel_features), dim=1)

        if self.feature_extraction:
            return combined

        # Fusion and classification
        fused = self.fusion(combined)
        output = self.classifier(fused)

        return output

# Model initialisation
print("Creating model...")
model = TwoStreamNetwork(num_classes=num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()

# AdamW over every parameter (both backbones plus the fusion/classifier head).
optimizer = optim.AdamW(
    model.parameters(),
    lr=fine_tune_learning_rate,
    weight_decay=weight_decay,
    betas=(0.9, 0.999),
    eps=1e-8
)

# Reduce the learning rate when the monitored value (the validation loss is
# what the training loop passes to scheduler.step) stops decreasing.
# `verbose=True` is not passed: the argument is deprecated in recent PyTorch,
# and the training loop already prints the current learning rate every epoch.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=lr_factor,
    patience=lr_patience,
    min_lr=1e-7
)

print("Model created successfully!")


In [None]:
# Training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, patience=5):
    """Training function with mixed precision, gradient accumulation, and early stopping"""
    since = time.time()
    scaler = torch.amp.GradScaler('cuda')
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_loss = float('inf')
    epochs_no_improve = 0
    
    # History
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    max_grad_norm = 1.0 
    
    # Training loop
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
                
            running_loss = 0.0
            running_corrects = 0
            
            pbar = tqdm(dataloaders[phase], desc=f'{phase} Epoch {epoch}')
            optimizer.zero_grad(set_to_none=True)
            
            for i, (rgb_inputs, vessel_inputs, labels) in enumerate(pbar):
                rgb_inputs = rgb_inputs.to(device, non_blocking=True)
                vessel_inputs = vessel_inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                
                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    with torch.amp.autocast('cuda'):
                        outputs = model(rgb_inputs, vessel_inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
                        loss = loss / ACCUMULATION_STEPS
                    
                    # Backward pass (only in training)
                    if phase == 'train':
                        scaler.scale(loss).backward()
                        
                        # If we've accumulated enough gradients
                        if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == len(pbar):
                            scaler.unscale_(optimizer)
                            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                            
                            scaler.step(optimizer)
                            scaler.update()
                            optimizer.zero_grad(set_to_none=True)
                        
                running_loss += (loss.item() * ACCUMULATION_STEPS) * rgb_inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                
                batch_loss = loss.item() * ACCUMULATION_STEPS
                batch_acc = torch.sum(preds == labels.data).double() / rgb_inputs.size(0)
                pbar.set_postfix({
                    'loss': f'{batch_loss:.4f}',
                    'acc': f'{batch_acc:.4f}'
                })
                
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            
            # Track history
            if phase == 'train':
                train_losses.append(epoch_loss)
                train_accs.append(epoch_acc.item())
            else:  # validation phase
                val_losses.append(epoch_loss)
                val_accs.append(epoch_acc.item())
                scheduler.step(epoch_loss)
                
                # Keep best model
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    
                # Early stop check
                if epochs_no_improve >= patience:
                    print(f'\nEarly stopping triggered after epoch {epoch}')
                    print(f'Best validation accuracy: {best_acc:.4f}')
                    model.load_state_dict(best_model_wts)
                    return model, train_losses, val_losses, train_accs, val_accs
                    
            torch.cuda.empty_cache()
                    
        current_lr = optimizer.param_groups[0]['lr']
        print(f'\nCurrent learning rate: {current_lr:.2e}\n')
        
    # Training complete
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    
    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, train_losses, val_losses, train_accs, val_accs


In [None]:
# Run the fine-tuning loop and keep the per-epoch histories for plotting.
print("Starting training...")
model_ft, train_losses, val_losses, train_accs, val_accs = train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    num_epochs=num_epochs,
    patience=patience,
)

# Learning curves: loss on the left panel, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

acc_ax.plot(train_accs, label='Train Accuracy')
acc_ax.plot(val_accs, label='Validation Accuracy')
acc_ax.set_title('Training and Validation Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()


In [None]:
# Feature extraction function
def extract_features(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect its outputs as NumPy arrays.

    Args:
        model: Network called as ``model(rgb_inputs, vessel_inputs)``; it is
            switched to eval mode and run without gradients.
        dataloader: Iterable yielding ``(rgb_inputs, vessel_inputs, labels)``.
        device: Device the inputs are moved to. float16 autocast is enabled
            only when this is a CUDA device.

    Returns:
        ``(features, labels)``: all batch outputs stacked row-wise as float32,
        and the matching labels as a 1-D array, in iteration order.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.eval()
    use_amp = torch.device(device).type == 'cuda'
    features = []
    labels = []

    with torch.no_grad():
        for rgb_inputs, vessel_inputs, batch_labels in tqdm(dataloader, desc='Extracting features'):
            rgb_inputs = rgb_inputs.to(device)
            vessel_inputs = vessel_inputs.to(device)

            with torch.amp.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                batch_features = model(rgb_inputs, vessel_inputs)

            # Cast to float32 so downstream scaling/classification never runs
            # on half-precision values, whatever dtype autocast produced.
            features.append(batch_features.float().cpu().numpy())
            labels.append(batch_labels.numpy())

    if not features:
        raise ValueError("extract_features received an empty dataloader")

    return np.vstack(features), np.concatenate(labels)

# Initialize feature extraction model
print("Creating feature extraction model...")
feature_model = TwoStreamNetwork(num_classes=num_classes, feature_extraction=True)
feature_model = feature_model.to(device)

# Pick the weights for the feature extractor. A checkpoint file keeps priority.
# When it is absent (nothing in this notebook writes it), fall back to the
# model that was just fine-tuned in memory, so the extractor does not silently
# run with only the backbones' pretrained initialisation.
finetuned_state = None
if os.path.exists('best_swin_model.pth'):
    print("Loading pre-trained weights...")
    # map_location keeps loading independent of the device the file was saved from.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load('best_swin_model.pth', map_location=device)
    finetuned_state = checkpoint['model_state_dict']
elif 'model_ft' in globals():
    print("Loading fine-tuned weights from the in-memory trained model...")
    finetuned_state = model_ft.state_dict()
else:
    print("Warning: no fine-tuned weights available; using the backbones' initial weights")

if finetuned_state is not None:
    # strict=False: the fusion/classifier entries of a full model have no
    # counterpart in the feature-extraction network.
    load_result = feature_model.load_state_dict(finetuned_state, strict=False)
    if load_result.missing_keys:
        # Every key of the feature model should have been provided; a miss
        # here means part of it kept its initial weights.
        print(f"Warning: {len(load_result.missing_keys)} feature-model tensors were not found in the loaded weights")

# Run the feature model over each split, in train -> val -> test order.
# NOTE(review): the 'train' loader shuffles and its dataset applies random
# augmentations, so the training features differ between runs - confirm
# whether a deterministic pass is wanted here.
print("Extracting features...")
(train_features, train_labels), (val_features, val_labels), (test_features, test_labels) = (
    extract_features(feature_model, dataloaders[split], device)
    for split in ('train', 'val', 'test')
)

if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Standardise using statistics fitted on the training features only; the same
# fitted scaler is then applied to validation and test features.
print("Scaling features...")
scaler = StandardScaler()
scaler.fit(train_features)
train_features_scaled = scaler.transform(train_features)
val_features_scaled = scaler.transform(val_features)
test_features_scaled = scaler.transform(test_features)

# Drop the unscaled arrays; only the scaled copies are used from here on.
del train_features, val_features, test_features

# Train XGBoost classifier
print("Training XGBoost classifier...")
# Early stopping is configured on the estimator: current XGBoost releases no
# longer accept `early_stopping_rounds` as a fit() argument. The tree method
# and device are not repeated here because the parameter grid (xgb_params)
# sets them on every candidate.
xgb_classifier = xgb.XGBClassifier(early_stopping_rounds=20)
grid_search = GridSearchCV(
    xgb_classifier,
    xgb_params,
    cv=3,
    n_jobs=1,
    verbose=2
)

# Fit XGBoost with early stopping. eval_set is forwarded to every fit
# (each CV fold and the final refit) and is the data early stopping monitors.
print("Fitting XGBoost...")
grid_search.fit(
    train_features_scaled,
    train_labels,
    eval_set=[(val_features_scaled, val_labels)],
    verbose=True
)

# Get best model
best_xgb = grid_search.best_estimator_
print(f"Best parameters: {grid_search.best_params_}")

# Evaluate model
print("\nEvaluating model...")
val_pred = best_xgb.predict(val_features_scaled)
val_acc = np.mean(val_pred == val_labels)
print(f"Validation accuracy: {val_acc:.4f}")

test_pred = best_xgb.predict(test_features_scaled)
test_acc = np.mean(test_pred == test_labels)
print(f"Test accuracy: {test_acc:.4f}")

# Fix the label set explicitly so the report and the confusion matrix always
# have one row per class, even if a class is absent from the test split.
# Otherwise target_names would not line up, or would be rejected.
class_ids = list(range(num_classes))

# Metrics
print('\nDetailed Classification Report:')
print(classification_report(test_labels, test_pred,
                          labels=class_ids,
                          target_names=[f'Class {i}' for i in class_ids],
                          digits=4,
                          zero_division=0))

print('\nConfusion Matrix:')
cm = confusion_matrix(test_labels, test_pred, labels=class_ids)
print(cm)

# Per-class accuracy = diagonal / row total. Classes with no test samples
# are reported as nan instead of triggering a division by zero.
class_support = cm.sum(axis=1)
class_accuracy = np.divide(
    cm.diagonal(),
    class_support,
    out=np.full(len(class_ids), np.nan),
    where=class_support > 0
)
for i, acc in enumerate(class_accuracy):
    print(f'Class {i} Accuracy: {acc:.4f}')

# Save ensemble model
# The state dict is moved to the CPU first: pickled CUDA tensors can only be
# restored on a machine with CUDA available.
# NOTE: this file is a pickle - only unpickle it from a trusted source.
feature_state_cpu = {
    name: tensor.detach().cpu()
    for name, tensor in feature_model.state_dict().items()
}

save_path = f'ensemble_model_{test_acc:.4f}.pkl'
with open(save_path, 'wb') as f:
    pickle.dump({
        'feature_model_state': feature_state_cpu,
        'xgboost_model': best_xgb,
        'scaler': scaler,
        'test_acc': test_acc,
        'best_params': grid_search.best_params_
    }, f)
print(f'\nEnsemble model saved to {save_path}')
