In [97]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import h5py
import timm
import numpy as np
from scipy import ndimage
import torch.nn.functional as F
from torchvision import transforms
import torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
# Add these imports
import pandas as pd
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import RichProgressBar, LearningRateMonitor

In [98]:
class PathEnhancer:
    """Turns a single-channel path image into an 11-channel feature stack.

    Channels produced by :meth:`enhance_paths`:

    * ``0``    - binarised image (``pixel > 127.5``)
    * ``1-8``  - for every on-path pixel, the binarised value of each of its
      8 neighbours (zero everywhere off the path)
    * ``9-10`` - Euclidean distance from every pixel to each of the two
      brightest pixels of the input image

    Every method expects a 2-D ``(height, width)`` tensor.
    """

    # (row, col) offsets of the 8 neighbours. The position in this tuple is
    # the channel index written by ``get_path_direction``.
    _NEIGHBOR_OFFSETS = (
        (-1, -1), (-1, 0), (-1, 1),
        (0, -1),           (0, 1),
        (1, -1),  (1, 0),  (1, 1),
    )

    def __init__(self):
        # Kept for backward compatibility; no method of this class reads it.
        self.kernel = torch.ones(3, 3)

    def find_endpoints(self, image):
        """Return the flat (row-major) indices of the two largest pixels.

        Args:
            image: 2-D tensor with at least two elements.

        Returns:
            1-D integer tensor of length 2, ordered from largest to
            second-largest pixel value.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        _, indices = torch.topk(image.reshape(-1), 2)
        return indices

    def get_path_direction(self, image):
        """Collect the 8-neighbourhood of every on-path pixel.

        Args:
            image: 2-D tensor; pixels ``> 0.5`` are treated as on-path.

        Returns:
            Float tensor of shape ``(height, width, 8)``. Entry ``[i, j, k]``
            is the value of the k-th neighbour of pixel ``(i, j)`` (zero
            padding outside the image) when ``(i, j)`` is on-path, else 0.
        """
        height, width = image.shape
        padded = F.pad(image, (1, 1, 1, 1))
        on_path = (image > 0.5).unsqueeze(-1)

        # One shifted view of the padded image per neighbour offset replaces
        # the per-pixel Python loop.
        neighbours = torch.stack(
            [
                padded[1 + di:1 + di + height, 1 + dj:1 + dj + width]
                for di, dj in self._NEIGHBOR_OFFSETS
            ],
            dim=-1,
        ).float()
        return torch.where(on_path, neighbours, torch.zeros_like(neighbours))

    def enhance_paths(self, image):
        """Build the ``(11, height, width)`` feature stack for one image.

        Args:
            image: 2-D tensor of pixel intensities; values above 127.5 are
                treated as path pixels.

        Returns:
            Float tensor of shape ``(11, height, width)``, laid out as
            described in the class docstring.
        """
        # Binarize
        binary = (image > 127.5).float()

        # Endpoints are located on the raw intensities. After binarisation
        # every path pixel ties at 1.0, so top-k would return two arbitrary
        # path pixels instead of the two brightest ones.
        endpoints = self.find_endpoints(image)

        # Neighbourhood features of the binarised path
        directions = self.get_path_direction(binary)

        # Distance of every pixel to each endpoint, computed by broadcasting
        # a column of row coordinates against a row of column coordinates.
        height, width = binary.shape
        rows = torch.arange(height, dtype=torch.float32, device=binary.device).unsqueeze(1)
        cols = torch.arange(width, dtype=torch.float32, device=binary.device).unsqueeze(0)
        distance_maps = []
        for flat_index in endpoints:
            y, x = divmod(int(flat_index), width)
            distance_maps.append(torch.sqrt((rows - y) ** 2 + (cols - x) ** 2))
        endpoint_distances = torch.stack(distance_maps)

        # Combine features
        return torch.cat([
            binary.unsqueeze(0),
            directions.permute(2, 0, 1),
            endpoint_distances,
        ])

In [99]:
class PathfinderDataModule(pl.LightningDataModule):
    """Lightning data module that splits one HDF5 file 80/10/10.

    The file is expected to contain an ``images`` and a ``labels`` dataset of
    equal length.

    Args:
        data_path: Path to the HDF5 file.
        batch_size: Batch size used by all three dataloaders.
        seed: Seed for the random split, so the train/val/test partition is
            the same on every run and on every ``setup`` call.
    """

    def __init__(self, data_path, batch_size=64, seed=42):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.seed = seed
        self.enhancer = PathEnhancer()

        # Populated by setup().
        self.train_indices = None
        self.val_indices = None
        self.test_indices = None
        self.pos_weight = None

    def setup(self, stage=None):
        """Create the split indices and the class weight (once).

        Lightning calls ``setup`` separately for each stage. The split is
        computed only on the first call and seeded, so a later test stage
        sees the same partition as training and cannot pick up samples that
        were trained on.
        """
        if self.train_indices is not None:
            return

        with h5py.File(self.data_path, 'r') as f:
            total = len(f['images'])
            labels = f['labels'][:]

        generator = torch.Generator().manual_seed(self.seed)
        indices = torch.randperm(total, generator=generator)
        train_size = int(0.8 * total)
        val_size = int(0.1 * total)

        self.train_indices = indices[:train_size]
        self.val_indices = indices[train_size:train_size + val_size]
        self.test_indices = indices[train_size + val_size:]

        print(f"\nDataset Statistics:")
        print(f"Total samples: {total:,}")
        print(f"Train: {len(self.train_indices):,}")
        print(f"Val: {len(self.val_indices):,}")
        print(f"Test: {len(self.test_indices):,}")

        # Negative/positive ratio over all labels, for class-balanced losses.
        # Falls back to 1.0 when the file contains no positive sample
        # instead of dividing by zero.
        num_pos = int((labels == 1).sum())
        num_neg = int((labels == 0).sum())
        ratio = num_neg / num_pos if num_pos > 0 else 1.0
        # float64 matches the dtype the numpy-based division used to yield.
        self.pos_weight = torch.tensor([ratio], dtype=torch.float64)

    def _get_dataset(self, indices):
        """Wrap a set of sample indices in an ``EnhancedPathfinderDataset``."""
        return EnhancedPathfinderDataset(self.data_path, indices, self.enhancer)

    def _make_loader(self, indices, shuffle):
        """Build a single-process DataLoader over the given indices."""
        return DataLoader(
            self._get_dataset(indices),
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=0
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_indices, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._make_loader(self.val_indices, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test split."""
        return self._make_loader(self.test_indices, shuffle=False)

In [100]:
class EnhancedPathfinderDataset(Dataset):
    """Map-style dataset serving path-enhanced samples from an HDF5 file.

    Args:
        h5_path: Path to an HDF5 file with ``images`` and ``labels`` datasets.
        indices: Sequence (list or 1-D tensor) of row indices in the file
            that belong to this split.
        enhancer: Object exposing ``enhance_paths(image)``.
    """

    def __init__(self, h5_path, indices, enhancer):
        self.h5_path = h5_path
        self.indices = indices
        self.enhancer = enhancer
        # The length depends only on the index list, so the file is not
        # opened here; it is opened on each item access instead.
        self.length = len(indices)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with
              * ``'image'``    - output of ``enhancer.enhance_paths``
              * ``'original'`` - raw image as a float tensor
              * ``'label'``    - label as an int64 tensor
        """
        # h5py is handed a plain Python int, not a 0-dim tensor.
        true_idx = int(self.indices[idx])
        with h5py.File(self.h5_path, 'r') as f:
            image = torch.from_numpy(f['images'][true_idx]).float()
            label = torch.tensor(f['labels'][true_idx]).long()

        # Apply path enhancement
        enhanced = self.enhancer.enhance_paths(image)

        return {
            'image': enhanced,
            'original': image,
            'label': label
        }

In [101]:
class SequentialDeiT(pl.LightningModule):
    """DeiT-tiny backbone plus one extra self-attention layer over its tokens.

    The backbone is created for 32x32 inputs with 11 channels, 8x8 patches
    and two output classes.

    Args:
        learning_rate: Peak learning rate of the one-cycle schedule.
        weight_decay: AdamW weight decay.
        warmup_steps: Stored in ``hparams`` only; the warm-up length is
            governed by ``pct_start`` of the one-cycle schedule.
    """

    def __init__(
        self,
        learning_rate=1e-4,
        weight_decay=0.01,
        warmup_steps=100
    ):
        super().__init__()
        self.save_hyperparameters()

        # Modified DeiT for sequential processing
        self.model = timm.create_model(
            'deit_tiny_patch16_224',
            pretrained=True,
            num_classes=2,
            img_size=32,
            in_chans=11,
            patch_size=8
        )

        # Extra attention over the backbone tokens. batch_first=True is
        # required because forward_features yields (batch, tokens, dim);
        # with the sequence-first default the layer would attend across the
        # samples of a batch instead of across the tokens of one image.
        self.path_attention = nn.MultiheadAttention(
            embed_dim=192,  # DeiT-tiny hidden dim
            num_heads=6,
            dropout=0.1,
            batch_first=True
        )

        # Metrics
        metrics = torchmetrics.MetricCollection({
            'accuracy': torchmetrics.Accuracy(task='binary'),
            'precision': torchmetrics.Precision(task='binary'),
            'recall': torchmetrics.Recall(task='binary'),
            'f1': torchmetrics.F1Score(task='binary')
        })

        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def forward(self, x):
        """Return class logits of shape ``(batch, 2)`` for input images ``x``."""
        # Token features from the DeiT backbone
        features = self.model.forward_features(x)

        # Residual self-attention over the tokens
        attn_out, _ = self.path_attention(
            features, features, features,
            need_weights=False
        )
        features = features + attn_out

        # Mean-pool the tokens and classify
        return self.model.head(features.mean(1))

    def _predict(self, batch):
        """Run the model on a batch; returns ``(logits, preds, targets, acc)``."""
        x, y = batch['image'], batch['label']
        logits = self(x)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()
        return logits, preds, y, acc

    def training_step(self, batch, batch_idx):
        logits, preds, y, acc = self._predict(batch)
        loss = F.cross_entropy(logits, y)

        self.train_metrics.update(preds, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        # Epoch-level train_accuracy / train_precision / train_recall / train_f1
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, preds, y, acc = self._predict(batch)
        loss = F.cross_entropy(logits, y)

        self.val_metrics.update(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        # Makes 'val_accuracy' (and the other val_* metrics) available to
        # callbacks that monitor them.
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        _, preds, y, acc = self._predict(batch)

        self.test_metrics.update(preds, y)
        self.log('test_acc', acc, prog_bar=True)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """AdamW with a one-cycle schedule that is stepped every batch."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams.learning_rate,
            total_steps=self.trainer.estimated_stepping_batches,
            pct_start=0.1
        )

        # total_steps counts optimizer steps, so the scheduler has to advance
        # per step; Lightning's default interval is per epoch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

In [102]:
class _HistoryRecorder(pl.Callback):
    """Appends the train/val loss and accuracy of each epoch to a history dict."""

    _METRICS = ('train_loss', 'val_loss', 'train_acc', 'val_acc')

    def __init__(self, history, stage):
        super().__init__()
        self.history = history
        self.stage = stage

    def on_train_epoch_end(self, trainer, pl_module):
        for name in self._METRICS:
            value = trainer.callback_metrics.get(name)
            if value is not None:
                self.history[f'{self.stage}_{name}'].append(float(value))


def _fit_stage(model, data_path, stage, history, max_epochs):
    """Fit ``model`` on one dataset, recording per-epoch metrics in ``history``."""
    datamodule = PathfinderDataModule(data_path, batch_size=64)
    use_gpu = torch.cuda.is_available()

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        # Fall back to CPU / full precision when no GPU is present.
        accelerator='gpu' if use_gpu else 'cpu',
        devices=1,
        precision=16 if use_gpu else 32,
        limit_train_batches=0.5,
        limit_val_batches=0.3,
        callbacks=[
            # 'val_acc' is the validation accuracy key the model logs.
            ModelCheckpoint(
                dirpath=f'checkpoints/{stage}',
                filename='{epoch}-{val_acc:.2f}',
                monitor='val_acc',
                mode='max'
            ),
            EarlyStopping(monitor='val_loss', patience=3),
            RichProgressBar(),
            LearningRateMonitor(logging_interval='step'),
            _HistoryRecorder(history, stage),
        ],
        logger=TensorBoardLogger("logs", name=f"deit_{stage}")
    )

    print("\n" + "="*50)
    print(f"Training on {stage.capitalize()} Dataset...")
    print("="*50)
    trainer.fit(model, datamodule)
    return trainer


def train_progressive_deit(
    data_path_medium='/kaggle/input/deit-data/merged_data_medium.h5',
    data_path_hard='/kaggle/input/deit-data/merged_data_hard.h5',
    max_epochs=5,
    seed=42,
):
    """Train one ``SequentialDeiT`` on the medium dataset, then on the hard one.

    Args:
        data_path_medium: HDF5 file used for the first training stage.
        data_path_hard: HDF5 file used for the second training stage.
        max_epochs: Maximum number of epochs per stage.
        seed: Seed passed to ``pl.seed_everything`` before training.

    Returns:
        ``(model, history)``. ``history`` maps
        ``'{medium|hard}_{train|val}_{loss|acc}'`` to a list holding one
        value per completed training epoch.
    """
    pl.seed_everything(seed, workers=True)

    print(f"GPU available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"Device: {torch.cuda.get_device_name()}")

    # Per-epoch training history; filled in by _HistoryRecorder.
    history = {
        'medium_train_loss': [], 'medium_val_loss': [],
        'medium_train_acc': [], 'medium_val_acc': [],
        'hard_train_loss': [], 'hard_val_loss': [],
        'hard_train_acc': [], 'hard_val_acc': []
    }

    model = SequentialDeiT()

    # The same model instance is trained on both datasets in turn.
    _fit_stage(model, data_path_medium, 'medium', history, max_epochs)
    _fit_stage(model, data_path_hard, 'hard', history, max_epochs)

    # Plot training history
    plot_training_history(history)

    return model, history

def plot_training_history(history):
    """Plot loss and accuracy curves and save them to ``training_history.png``.

    Args:
        history: Mapping with the keys
            ``'{medium|hard}_{train|val}_{loss|acc}'``. Each value may be a
            sequence of per-epoch numbers or a single number.
    """
    series = (
        ('medium_train', 'Medium Train'),
        ('medium_val', 'Medium Val'),
        ('hard_train', 'Hard Train'),
        ('hard_val', 'Hard Val'),
    )
    panels = (
        ('loss', 'Loss History', 'Loss'),
        ('acc', 'Accuracy History', 'Accuracy'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    for ax, (metric, title, ylabel) in zip(axes, panels):
        for prefix, label in series:
            # atleast_1d accepts scalars as well as lists; the marker keeps
            # a single-value series visible, which a bare line is not.
            values = np.atleast_1d(history[f'{prefix}_{metric}'])
            ax.plot(values, marker='o', label=label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

# Run the two-stage (medium, then hard) training. TensorBoard logs are
# written under logs/; the returned history is also plotted by the function.
model, history = train_progressive_deit()



GPU available: True
Device: Tesla P100-PCIE-16GB

Training on Medium Dataset...

Dataset Statistics:
Total samples: 200,000
Train: 160,000
Val: 20,000
Test: 20,000


Output()

NameError: name 'exit' is not defined

In [None]:
# View in TensorBoard (separate cell)
# Loads the TensorBoard notebook extension and points it at the logs/
# directory used by the TensorBoardLogger instances above.
%load_ext tensorboard
%tensorboard --logdir logs/

In [None]:
%load_ext tensorboard
%tensorboard --logdir logs/