## Imports & Configuration

In [1]:
import os
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import h5py
import numpy as np
from torchmetrics import Accuracy
import matplotlib.pyplot as plt
from datetime import datetime

# Configuration
pl.seed_everything(42)
EXPERIMENT_NAME = f"perceiver_{datetime.now().strftime('%Y%m%d_%H%M')}"
MAX_EPOCHS = 25 

# Training configurations
BATCH_SIZE = 128
NUM_WORKERS = 4
LEARNING_RATE = 1e-3

print(f"Starting experiment: {EXPERIMENT_NAME}")
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Device: {torch.cuda.get_device_name(0)}")

Starting experiment: perceiver_20241208_0425
GPU Available: True
GPU Device: Tesla T4


## Dataset Class

In [2]:
class PathfinderDataset(Dataset):
    def __init__(self, h5_file, difficulty):
        """Initialize dataset with proper data loading and preprocessing."""
        self.h5_file = h5py.File(h5_file, "r")
        self.images = self.h5_file["images"]
        self.labels = self.h5_file["labels"]
        self.difficulty = difficulty
        self.log_dataset_info()
    
    def log_dataset_info(self):
        """Log dataset statistics for monitoring."""
        total_samples = len(self.labels)
        labels_array = self.labels[:]
        pos_samples = np.sum(labels_array == 1)
        neg_samples = np.sum(labels_array == 0)
        
        print(f"\nDataset Information ({self.difficulty}):")
        print(f"Total samples: {total_samples}")
        print(f"Positive samples: {pos_samples} ({pos_samples/total_samples*100:.2f}%)")
        print(f"Negative samples: {neg_samples} ({neg_samples/total_samples*100:.2f}%)")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Get and preprocess a single sample."""
        image = self.images[idx].astype(np.float32)
        image = image / 255.0  # Normalize to [0,1]
        sequence = image.reshape(-1)  # Flatten to sequence
        sequence = torch.tensor(sequence, dtype=torch.float32).unsqueeze(-1)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return sequence, label

In [3]:
# First, let's modify the DataLoader configuration to work better in notebooks
class PathfinderDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=128):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.difficulties = ['easy', 'medium', 'hard']
        
    def setup(self, stage=None):
        for difficulty in self.difficulties:
            file_path = os.path.join(self.data_dir, f'merged_data_{difficulty}.h5')
            dataset = PathfinderDataset(file_path, difficulty)
            
            total_size = len(dataset)
            train_size = int(0.7 * total_size)
            val_size = int(0.15 * total_size)
            test_size = total_size - train_size - val_size
            
            splits = torch.utils.data.random_split(
                dataset, 
                [train_size, val_size, test_size],
                generator=torch.Generator().manual_seed(42)
            )
            
            setattr(self, f'{difficulty}_train', splits[0])
            setattr(self, f'{difficulty}_val', splits[1])
            setattr(self, f'{difficulty}_test', splits[2])
    
    def train_dataloader(self, difficulty='easy'):
        return DataLoader(
            getattr(self, f'{difficulty}_train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,  # Changed to 0 to avoid multiprocessing issues
            pin_memory=True
        )
    
    def val_dataloader(self, difficulty='easy'):
        return DataLoader(
            getattr(self, f'{difficulty}_val'),
            batch_size=self.batch_size,
            num_workers=0,  # Changed to 0
            pin_memory=True
        )
    
    def test_dataloader(self, difficulty='easy'):
        return DataLoader(
            getattr(self, f'{difficulty}_test'),
            batch_size=self.batch_size,
            num_workers=0,  # Changed to 0
            pin_memory=True
        )

## Perceiver Model

In [4]:
class PerceiverEncoder(nn.Module):
    def __init__(self, 
                 input_dim=1,           # Input channel dimension
                 num_latents=64,        # Number of latent vectors
                 latent_dim=128,        # Dimension of latent vectors
                 num_self_attn=4,       # Number of self-attention iterations
                 num_cross_attn=1):     # Number of cross-attention iterations
        super().__init__()
        
        # Initialize latent array
        self.latents = nn.Parameter(torch.randn(1, num_latents, latent_dim))
        
        # Cross attention: input sequence → latents
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=latent_dim,
            num_heads=8,
            batch_first=True
        )
        
        # Self attention layers for latent processing
        self.self_attention_layers = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=latent_dim,
                num_heads=8,
                batch_first=True
            ) for _ in range(num_self_attn)
        ])
        
        # Input projection
        self.input_projection = nn.Linear(input_dim, latent_dim)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(latent_dim)
        self.norm2 = nn.LayerNorm(latent_dim)
        
        self.num_cross_attn = num_cross_attn

    def forward(self, x):
        """Process input through cross-attention and self-attention iterations."""
        batch_size = x.shape[0]
        
        # Project input to latent dimension
        x = self.input_projection(x)
        
        # Expand latents for batch size
        latents = self.latents.expand(batch_size, -1, -1)
        
        # Cross attention iterations
        for _ in range(self.num_cross_attn):
            # Cross attention between inputs and latents
            latents_update, _ = self.cross_attention(
                query=latents,
                key=x,
                value=x
            )
            latents = latents + latents_update
            latents = self.norm1(latents)
            
            # Self attention processing
            for self_attn in self.self_attention_layers:
                latents_update, _ = self_attn(
                    query=latents,
                    key=latents,
                    value=latents
                )
                latents = latents + latents_update
                latents = self.norm2(latents)
        
        return latents

class PerceiverPathfinder(pl.LightningModule):
    def __init__(self, 
                 sequence_length=1024,
                 num_classes=2):
        super().__init__()
        self.save_hyperparameters()
        
        # Perceiver encoder
        self.encoder = PerceiverEncoder(
            input_dim=1,
            num_latents=64,
            latent_dim=128,
            num_self_attn=4,
            num_cross_attn=1
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(128 * 64, 512),  # Flatten and project
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )
        
        # Metrics
        self.train_accuracy = Accuracy(task="binary")
        self.val_accuracy = Accuracy(task="binary")
        self.test_accuracy = Accuracy(task="binary")

    def forward(self, x):
        # Get latent representations
        latents = self.encoder(x)
        
        # Flatten latents for classification
        latents_flat = latents.reshape(latents.shape[0], -1)
        
        # Classify
        return self.classifier(latents_flat)

    def training_step(self, batch, batch_idx):
        sequences, labels = batch
        logits = self(sequences)
        loss = F.cross_entropy(logits, labels)
        
        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, labels)
        
        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        sequences, labels = batch
        logits = self(sequences)
        loss = F.cross_entropy(logits, labels)
        
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, labels)
        
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        sequences, labels = batch
        logits = self(sequences)
        loss = F.cross_entropy(logits, labels)
        
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, labels)
        
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=LEARNING_RATE,
            weight_decay=0.01
        )
        
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=MAX_EPOCHS,
            eta_min=1e-6
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }

## Data Module Implementation

In [5]:
class PerceiverPathfinder(pl.LightningModule):
    def __init__(self, 
                 sequence_length=1024,
                 num_classes=2,
                 num_latents=128,
                 latent_dim=256,
                 num_self_attn=6,
                 num_cross_attn=2,
                 dropout=0.1):
        super().__init__()
        self.save_hyperparameters()
        
        # Input processing with improved embedding
        self.input_embedding = nn.Sequential(
            nn.Linear(1, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        
        # Initialize learnable latent vectors
        self.latents = nn.Parameter(torch.randn(1, num_latents, latent_dim))
        
        # Cross attention layers
        self.cross_attention_layers = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=latent_dim,
                num_heads=8,
                dropout=dropout,
                batch_first=True
            ) for _ in range(num_cross_attn)
        ])
        
        # Self attention layers
        self.self_attention_layers = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=latent_dim,
                num_heads=8,
                dropout=dropout,
                batch_first=True
            ) for _ in range(num_self_attn)
        ])
        
        # Layer normalization
        self.norm_layers = nn.ModuleList([
            nn.LayerNorm(latent_dim)
            for _ in range(num_self_attn + num_cross_attn)
        ])
        
        # Enhanced classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(latent_dim * num_latents),
            nn.Linear(latent_dim * num_latents, latent_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(latent_dim * 4, latent_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(latent_dim, num_classes)
        )
        
        # Metrics for tracking performance
        self.train_accuracy = Accuracy(task="binary")
        self.val_accuracy = Accuracy(task="binary")
        self.test_accuracy = Accuracy(task="binary")
        self.learning_rate = 2e-4

    def forward(self, x):
        batch_size = x.shape[0]
        
        # Create rich input embeddings
        x = self.input_embedding(x)
        
        # Expand latents for batch processing
        latents = self.latents.expand(batch_size, -1, -1)
        
        # Cross attention processing
        for i, cross_attn in enumerate(self.cross_attention_layers):
            latents_update, _ = cross_attn(latents, x, x)
            latents = latents + latents_update
            latents = self.norm_layers[i](latents)
        
        # Self attention processing
        offset = len(self.cross_attention_layers)
        for i, self_attn in enumerate(self.self_attention_layers):
            latents_update, _ = self_attn(latents, latents, latents)
            latents = latents + latents_update
            latents = self.norm_layers[offset + i](latents)
        
        # Prepare for classification
        latents = latents.reshape(batch_size, -1)
        return self.classifier(latents)

    def training_step(self, batch, batch_idx):
        """Defines the training loop behavior"""
        sequences, labels = batch
        logits = self(sequences)
        
        # Calculate loss
        loss = F.cross_entropy(logits, labels)
        
        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.train_accuracy(preds, labels)
        
        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        """Defines the validation loop behavior"""
        sequences, labels = batch
        logits = self(sequences)
        
        # Calculate loss
        loss = F.cross_entropy(logits, labels)
        
        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.val_accuracy(preds, labels)
        
        # Log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Defines the test loop behavior"""
        sequences, labels = batch
        logits = self(sequences)
        
        # Calculate loss
        loss = F.cross_entropy(logits, labels)
        
        # Calculate accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.test_accuracy(preds, labels)
        
        # Log metrics
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        """Defines the optimization algorithm and learning rate schedule"""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )
        
        # Calculate total steps for the scheduler
        steps_per_epoch = 1093  # Adjust based on your dataset size
        total_steps = 35 * steps_per_epoch
        
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.learning_rate,
            total_steps=total_steps,
            pct_start=0.3,
            div_factor=25,
            final_div_factor=1000,
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step"
            }
        }

## Training and Evaluation

In [6]:
def train_and_evaluate(data_module, difficulty):
    """Train and evaluate the Perceiver model for a specific difficulty level.
    
    Args:
        data_module: The data module containing the datasets
        difficulty: Which difficulty level to train on ('easy', 'medium', 'hard')
    """
    # Create model
    model = PerceiverPathfinder()

    # Configure training environment
    torch.backends.cudnn.benchmark = True  # May improve performance
    
    # Create callbacks for monitoring training
    callbacks = [
        pl.callbacks.ModelCheckpoint(
            dirpath=f'perceiver_checkpoints/{difficulty}',
            filename=f'perceiver_{difficulty}-{{epoch}}-{{val_acc:.2f}}',
            monitor='val_acc',
            mode='max',
            save_top_k=3
        ),
        pl.callbacks.EarlyStopping(
            monitor='val_acc',
            mode='max',
            patience=5,  # Stop if no improvement for 5 epochs
            verbose=True
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ]
    
    # Create trainer
    trainer = pl.Trainer(
        max_epochs=25,
        accelerator='gpu',
        devices=1,  # Use single GPU
        precision='16-mixed',
        callbacks=callbacks,
        logger=True,
        gradient_clip_val=1.0,
        log_every_n_steps=10,
        enable_progress_bar=True,
        strategy='auto'  # Let PyTorch Lightning choose the best strategy
    )
    
    # Train and test the model
    print(f"\nTraining Perceiver on {difficulty} dataset:")
    trainer.fit(
        model,
        train_dataloaders=data_module.train_dataloader(difficulty),
        val_dataloaders=data_module.val_dataloader(difficulty)
    )
    
    print(f"\nTesting Perceiver on {difficulty} dataset:")
    trainer.test(
        model,
        dataloaders=data_module.test_dataloader(difficulty)
    )
    
    return model, trainer

## Main Execution

In [7]:
def run_all_experiments():
    """Run experiments for all difficulty levels and collect results."""
    print("Initializing Perceiver experiments...")
    
    # Initialize data module
    data_module = PathfinderDataModule('/kaggle/input/merged')
    data_module.setup()
    
    # Store results for each difficulty
    results = {}
    
    try:
        for difficulty in ['easy', 'medium', 'hard']:
            print(f"\n{'='*50}")
            print(f"Starting Perceiver training for {difficulty} difficulty")
            print(f"{'='*50}")
            
            # Train and evaluate
            model, trainer = train_and_evaluate(data_module, difficulty)
            
            # Store results
            results[difficulty] = {
                'test_acc': trainer.callback_metrics.get('test_acc', 0).item(),
                'test_loss': trainer.callback_metrics.get('test_loss', 0).item()
            }
            
            # Print results for this difficulty
            print(f"\nResults for {difficulty}:")
            print(f"Test Accuracy: {results[difficulty]['test_acc']*100:.2f}%")
            print(f"Test Loss: {results[difficulty]['test_loss']:.4f}")
    
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        print("Partial results:", results)
    
    # Save results to CSV
    import pandas as pd
    results_df = pd.DataFrame([
        {
            'model': 'Perceiver',
            'difficulty': diff,
            'accuracy': metrics['test_acc'],
            'loss': metrics['test_loss']
        }
        for diff, metrics in results.items()
    ])
    results_df.to_csv(f'perceiver_results_{datetime.now().strftime("%Y%m%d_%H%M")}.csv', 
                      index=False)
    
    return results

# Run all experiments
if __name__ == "__main__":
    results = run_all_experiments()

Initializing Perceiver experiments...

Dataset Information (easy):
Total samples: 199800
Positive samples: 99985 (50.04%)
Negative samples: 99815 (49.96%)

Dataset Information (medium):
Total samples: 200000
Positive samples: 100222 (50.11%)
Negative samples: 99778 (49.89%)

Dataset Information (hard):
Total samples: 200000
Positive samples: 99920 (49.96%)
Negative samples: 100080 (50.04%)

Starting Perceiver training for easy difficulty

Training Perceiver on easy dataset:


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Testing Perceiver on easy dataset:


/opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]


Results for easy:
Test Accuracy: 50.29%
Test Loss: 0.6931

Starting Perceiver training for medium difficulty

Training Perceiver on medium dataset:


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Testing Perceiver on medium dataset:


Testing: |          | 0/? [00:00<?, ?it/s]


Results for medium:
Test Accuracy: 49.71%
Test Loss: 0.6931

Starting Perceiver training for hard difficulty

Training Perceiver on hard dataset:


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Testing Perceiver on hard dataset:


Testing: |          | 0/? [00:00<?, ?it/s]


Results for hard:
Test Accuracy: 49.56%
Test Loss: 0.6932
