In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import BigBirdConfig, BigBirdModel
import torch.nn.functional as F
import math


import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Optional, Tuple
from pathlib import Path
import h5py


# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
_cuda_ready = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ready else torch.device("cpu")
print(f"Using device: {device}")
if _cuda_ready:
    # Record the model name of GPU 0 in the notebook output.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Shell out to the NVIDIA CLI to show the GPU model, driver version and current memory use.
!nvidia-smi

## Dataset

In [None]:
class PathfinderH5Dataset(Dataset):
    """Pathfinder samples read lazily from an HDF5 file.

    The file must contain an ``images`` and a ``labels`` dataset. The file is
    re-opened on every ``__getitem__`` call, so no open handle is stored on the
    instance.

    Args:
        h5_path: Path of the HDF5 file.
        indices: Optional sequence (list or 1-D tensor) of row numbers that
            restricts the dataset to a subset, e.g. a train/val/test split.
            ``None`` exposes every row. The attribute may also be assigned
            after construction (``dataset.indices = ...``).
    """

    def __init__(self, h5_path, indices=None):
        self.h5_path = h5_path
        self.indices = indices
        self._pos_emb = None  # filled lazily on first __getitem__
        with h5py.File(h5_path, 'r') as f:
            self.length = len(f['images'])

    def __len__(self):
        """Return the subset size when ``indices`` is set, else the file length."""
        selected = getattr(self, 'indices', None)
        return self.length if selected is None else len(selected)

    def __getitem__(self, idx):
        """Load one sample and return its features as a dict.

        Returns:
            dict with keys ``'image'`` (flattened, binarised pixels),
            ``'directions'`` (one 8-vector per pixel), ``'pos_emb'``
            (the 1024 x 128 sinusoidal table) and ``'label'`` (long tensor).
        """
        # Map the position inside the subset to the row number in the file;
        # without this every split would silently read the full file.
        selected = getattr(self, 'indices', None)
        if selected is not None:
            idx = int(selected[idx])

        with h5py.File(self.h5_path, 'r') as f:
            image = torch.from_numpy(f['images'][idx]).float()
            label = torch.tensor(f['labels'][idx]).long()

        # Binarise: pixels above mid-grey become 1.0, everything else 0.0.
        image = (image > 127.5).float()
        directions = self.compute_direction_embeddings(image)

        # The table is identical for every sample, so build it only once.
        # The same tensor object is returned for each item; treat it as read-only.
        pos_emb = getattr(self, '_pos_emb', None)
        if pos_emb is None:
            pos_emb = self.get_positional_encoding(1024, 256)
            self._pos_emb = pos_emb

        return {
            'image': image.view(-1),
            'directions': directions,
            'pos_emb': pos_emb,
            'label': label
        }

    def compute_direction_embeddings(self, image):
        """Return, for every pixel, the values of its 8 neighbours.

        Rows belonging to pixels that are not ``> 0`` are all zero. Neighbours
        outside the image count as 0 (zero padding).

        Args:
            image: 2-D tensor of shape ``(H, W)``.

        Returns:
            Tensor of shape ``(H * W, 8)`` in row-major pixel order; the last
            axis follows the offset order listed in ``directions`` below.
        """
        height, width = image.shape
        padded = F.pad(image, (1, 1, 1, 1))
        directions = [(-1, -1), (-1, 0), (-1, 1), (0, -1),
                      (0, 1), (1, -1), (1, 0), (1, 1)]

        # One shifted view of the padded image per direction replaces the
        # per-pixel Python loop: view[i, j] == padded[i + di + 1, j + dj + 1].
        neighbours = [
            padded[1 + di:1 + di + height, 1 + dj:1 + dj + width]
            for di, dj in directions
        ]
        dirs = torch.stack(neighbours, dim=-1).to(torch.get_default_dtype())

        # Keep neighbour values only where the centre pixel itself is active.
        active = (image > 0).unsqueeze(-1)
        dirs = torch.where(active, dirs, torch.zeros_like(dirs))
        return dirs.reshape(-1, 8)

    def get_positional_encoding(self, seq_len, d_model):
        """Build a sinusoidal position table and return its first 128 columns.

        Args:
            seq_len: Number of positions (rows).
            d_model: Width used for the frequency schedule; must be even.

        Returns:
            Tensor of shape ``(seq_len, min(d_model, 128))``. Even columns
            hold sines, odd columns cosines.
        """
        position = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe[:, :128]  # Match d_model dimension

## Tokenizer

In [None]:
class PathEndpointTokenizer:
    """Locates path endpoints in an image and marks them in a pixel sequence."""

    def __init__(self, max_seq_length=1024):
        # Stored on the instance; none of the methods below read it.
        self.max_seq_length = max_seq_length

    def find_endpoints(self, image):
        """Return the flat indices of the two largest pixel values in ``image``."""
        # The two brightest pixels are treated as the endpoints.
        top_two = torch.topk(image.view(-1), 2)
        return top_two.indices

    def add_endpoint_tokens(self, sequence, endpoint_positions):
        """Return a copy of ``sequence`` with 2.0 written at every endpoint position."""
        marked = sequence.clone()  # leave the caller's tensor untouched
        for position in endpoint_positions:
            marked[position] = 2.0  # Special value for endpoints
        return marked

## Model

In [None]:
class BigBirdPathfinder(pl.LightningModule):
    """BigBird block-sparse encoder with a two-class head for Pathfinder images.

    The module consumes batches shaped like the output of
    ``PathfinderH5Dataset``: a dict whose ``'image'`` entry is a
    ``[batch_size, seq_len]`` float tensor and whose ``'label'`` entry holds
    class indices. Only those two keys are read.

    Args:
        block_size: Block size passed to ``BigBirdConfig``.
        num_random_blocks: Number of random blocks passed to ``BigBirdConfig``.
        attention_probs_dropout_prob: Dropout on attention probabilities.
        hidden_dropout_prob: Dropout on hidden states.
        hidden_size: Encoder width; also the input width of the classifier head.
        num_attention_heads: Attention heads per layer.
        num_hidden_layers: Number of encoder layers.
    """

    def __init__(
        self,
        block_size=8,  # Local window size
        num_random_blocks=2,  # Number of random blocks for global attention
        attention_probs_dropout_prob=0.2,
        hidden_dropout_prob=0.2,
        hidden_size=64,
        num_attention_heads=2,
        num_hidden_layers=2
    ):
        super().__init__()
        self.save_hyperparameters()

        # BigBird configuration
        self.config = BigBirdConfig(
            attention_type="block_sparse",
            block_size=block_size,
            num_random_blocks=num_random_blocks,
            max_position_embeddings=1024,  # For 32x32 image
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            hidden_dropout_prob=hidden_dropout_prob,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=num_hidden_layers
        )

        # Initialize BigBird model
        self.model = BigBirdModel(self.config, add_pooling_layer=False)

        # Path-specific components
        self.endpoint_tokenizer = PathEndpointTokenizer()
        self.path_predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, 2)  # Binary classification
        )

        # Path-aware attention weights.
        # Not read by forward(); kept so existing checkpoints still load.
        self.path_attention_weights = nn.Parameter(torch.ones(num_hidden_layers))

    def forward(self, batch):
        """Return ``[batch_size, 2]`` class logits for ``batch['image']``.

        The batch is pushed through the encoder two samples at a time and the
        hidden state at sequence position 0 is fed to the classifier head.
        """
        images = batch['image']  # [batch_size, seq_len]
        batch_size, seq_len = images.shape

        # Same for every chunk, so build it once outside the loop.
        position_ids = torch.arange(seq_len, device=images.device).unsqueeze(0)

        chunk_size = 2  # samples per encoder call
        pooled = []
        for start in range(0, batch_size, chunk_size):
            sequence = images[start:start + chunk_size]
            # NOTE(review): every pixel is passed as a width-1 "embedding";
            # this only lines up with hidden_size if the encoder's embedding
            # layer broadcasts it -- confirm that is intended rather than a
            # learned pixel projection.
            output = self.model(
                inputs_embeds=sequence.unsqueeze(-1),  # [chunk, seq_len, 1]
                position_ids=position_ids.expand(sequence.size(0), -1),
                # Intermediate hidden states are not requested: only
                # last_hidden_state is used below.
            )
            pooled.append(output.last_hidden_state[:, 0])

        return self.path_predictor(torch.cat(pooled, dim=0))

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy, log them as ``{stage}_loss`` / ``{stage}_acc``."""
        logits = self(batch)
        labels = batch['label']
        loss = F.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()

        # Log metrics
        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` and return the loss."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` and return the loss."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Return AdamW plus a cosine-annealing schedule in Lightning's dict format."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=1e-4,
            weight_decay=0.01
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=200,
            eta_min=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # NOTE(review): CosineAnnealingLR.step() takes no metric, so
                # this key looks unnecessary -- confirm before removing.
                "monitor": "val_loss"
            }
        }

## Training

In [None]:
from torch.utils.tensorboard import SummaryWriter

def train_progressive_bigbird(
    max_epochs: int = 20,
    data_dir: str = '/kaggle/input/bigbird-dataset',
    seed: Optional[int] = 42,
):
    """Train one ``BigBirdPathfinder`` on the easy, medium and hard sets in turn.

    For each difficulty the file ``{data_dir}/merged_data_{difficulty}.h5`` is
    split 80/10/10 into train/val/test, a fresh ``pl.Trainer`` fits the same
    model instance, and the test split is evaluated afterwards.

    Args:
        max_epochs: Upper bound on epochs per difficulty (early stopping on
            ``val_loss`` may end a stage sooner).
        data_dir: Directory that holds the three HDF5 files.
        seed: Seed for the split permutation; ``None`` draws from the global
            torch RNG instead.

    Returns:
        The trained model.
    """
    model = BigBirdPathfinder()  # Initialize at start
    difficulties = ['easy', 'medium', 'hard']
    use_gpu = torch.cuda.is_available()

    for difficulty in difficulties:
        print(f"\nTraining on {difficulty} dataset...")
        file_path = f'{data_dir}/merged_data_{difficulty}.h5'

        # Read each array once for the statistics, then drop it again so the
        # raw data is not kept in memory during training.
        with h5py.File(file_path, 'r') as f:
            total = len(f['images'])
            pixels = f['images'][:]
            pixel_mean, pixel_std = pixels.mean(), pixels.std()
            del pixels
            labels = f['labels'][:]
            n_connected = int((labels == 1).sum())
            n_disconnected = int((labels == 0).sum())
            del labels

        train_size = int(0.8 * total)
        val_size = int(0.1 * total)
        test_size = total - train_size - val_size

        generator = None if seed is None else torch.Generator().manual_seed(seed)
        indices = torch.randperm(total, generator=generator)
        train_indices = indices[:train_size].tolist()
        val_indices = indices[train_size:train_size + val_size].tolist()
        test_indices = indices[train_size + val_size:].tolist()

        # Subset actually restricts each loader to its split. Merely setting an
        # ``indices`` attribute on the dataset leaves all three loaders
        # iterating over the whole file.
        full_dataset = PathfinderH5Dataset(file_path)
        train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
        val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
        test_dataset = torch.utils.data.Subset(full_dataset, test_indices)

        print(f"Dataset sizes for {difficulty}:")
        print(f"Total: {total:,}")
        print(f"Train: {train_size:,}")
        print(f"Val: {val_size:,}")
        print(f"Test: {test_size:,}")

        # Data stats
        print("\nPixel Statistics:")
        print(f"Mean: {pixel_mean:.2f}")
        print(f"Std: {pixel_std:.2f}")
        print(f"Connected paths: {n_connected:,}")
        print(f"Disconnected paths: {n_disconnected:,}")

        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)
        test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

        trainer = pl.Trainer(
            max_epochs=max_epochs,
            # Fall back to CPU (full precision) when no GPU is visible.
            accelerator='gpu' if use_gpu else 'cpu',
            # One device. A list such as [1] would select the GPU with
            # index 1, which does not exist on single-GPU hosts.
            devices=1,
            precision=16 if use_gpu else 32,
            callbacks=[
                pl.callbacks.EarlyStopping(monitor='val_loss', patience=5),
                pl.callbacks.TQDMProgressBar(refresh_rate=1)
            ],
            logger=pl.loggers.TensorBoardLogger('runs', name=f'bigbird_{difficulty}'),
            accumulate_grad_batches=4,
            limit_train_batches=0.25,
            limit_val_batches=0.2
        )

        print(f"\nStarting training for {difficulty}...")
        trainer.fit(model, train_loader, val_loader)
        test_result = trainer.test(model, test_loader)
        print(f"Test results for {difficulty}: {test_result}")

    return model

In [None]:
# Hand cached, currently unused CUDA memory back to the driver before training starts.
torch.cuda.empty_cache()
# Run the easy -> medium -> hard curriculum and keep the trained model in the notebook namespace.
model = train_progressive_bigbird()

In [None]:
# Load the TensorBoard notebook extension and open the dashboard on the logs written under runs/.
%load_ext tensorboard
%tensorboard --logdir runs/