In [None]:
# Data Loading Cell - Memory-mapped loading, no data in RAM
import torch
import torch.nn.functional as F
from torch_geometric.nn import HGTConv
from torch_geometric.data import HeteroData
import os
import numpy as np
import pandas as pd
import gzip
import gc
import warnings
import h5py

# Suppress every Python warning for the rest of the session. This keeps the
# notebook output short, but it also hides warnings that may matter.
warnings.filterwarnings('ignore')

class DiskBasedOGBNMAG:
    """Memory-efficient OGBN-MAG dataset that keeps data on disk.

    On first use the raw gzip-compressed CSV files under
    ``<data_dir>/ogbn_mag/raw`` are converted into flat ``np.memmap`` files in
    ``<data_dir>/ogbn_mag_cache``. Later instantiations only re-open those
    files read-only.

    Attributes set by ``__init__``:
        num_papers, num_authors, num_fields, num_classes, feat_dim: ints
            loaded from ``metadata.npz``.
        train_idx, val_idx, test_idx: index tensors of the paper splits.
        paper_features: read-only memmap of shape (num_papers, feat_dim).
        paper_labels: read-only memmap of shape (num_papers,).
        cite_edges, author_edges, field_edges: read-only int64 memmaps of
            shape (2, num_edges).
    """

    # Node counts for the node types that have no feature file
    # (taken from the OGBN-MAG statistics).
    NUM_AUTHORS = 1134649
    NUM_FIELDS = 59965

    # Cached edge array name -> raw relation directory.
    _EDGE_RELATIONS = {
        'cite': 'paper___cites___paper',
        'author': 'author___writes___paper',
        'field': 'paper___has_topic___field_of_study',
    }

    # Number of edge columns scanned at once by get_edges_for_nodes.
    _EDGE_SCAN_CHUNK = 1_000_000

    def __init__(self, data_dir='./data'):
        """Prepare (if necessary) and open the memory-mapped dataset.

        Args:
            data_dir: Root directory holding the ``ogbn_mag`` download and the
                ``ogbn_mag_cache`` directory created by this class.
        """
        self.data_dir = data_dir
        self.cache_dir = os.path.join(data_dir, 'ogbn_mag_cache')
        os.makedirs(self.cache_dir, exist_ok=True)

        # Prepare data files
        self._prepare_data()

        # Open memory-mapped files
        self._open_mmap_files()

    @staticmethod
    def _read_csv_gz(path):
        """Read a header-less gzip-compressed CSV file into a NumPy array."""
        with gzip.open(path, 'rt') as f:
            return pd.read_csv(f, header=None).values

    def _write_memmap(self, filename, array, dtype):
        """Write ``array`` to ``<cache_dir>/<filename>`` as a flat memmap file."""
        mmap = np.memmap(os.path.join(self.cache_dir, filename),
                         dtype=dtype, mode='w+', shape=array.shape)
        mmap[:] = array
        mmap.flush()
        del mmap

    def _prepare_data(self):
        """Convert raw OGBN-MAG files to memory-mapped format if needed.

        The presence of ``metadata.npz`` marks a finished conversion, so it is
        written last. Each raw file is loaded fully into RAM once during the
        conversion; only the later reads are memory-mapped.
        """
        print("📥 Preparing memory-mapped OGBN-MAG data...")

        # Check if already prepared
        if os.path.exists(os.path.join(self.cache_dir, 'metadata.npz')):
            print("✅ Memory-mapped data already exists")
            return

        # Ensure raw data exists
        ogbn_dir = os.path.join(self.data_dir, 'ogbn_mag', 'raw')
        if not os.path.exists(ogbn_dir):
            print("📦 Downloading OGBN-MAG...")
            from ogb.nodeproppred import PygNodePropPredDataset
            temp_dataset = PygNodePropPredDataset('ogbn-mag', root=self.data_dir)
            del temp_dataset
            gc.collect()

        print("🔄 Converting to memory-mapped format...")

        # Paper features. The feature width is recorded *before* the array is
        # released: it is needed to re-open the flat file with the right shape.
        feat_file = os.path.join(ogbn_dir, 'node-feat', 'paper', 'node-feat.csv.gz')
        paper_features = self._read_csv_gz(feat_file).astype(np.float32)
        feat_dim = paper_features.shape[1]
        self._write_memmap('paper_features.dat', paper_features, 'float32')
        del paper_features

        # Paper labels
        label_file = os.path.join(ogbn_dir, 'node-label', 'paper', 'node-label.csv.gz')
        paper_labels = self._read_csv_gz(label_file).flatten().astype(np.int64)
        num_papers = len(paper_labels)
        num_classes = int(paper_labels.max()) + 1
        self._write_memmap('paper_labels.dat', paper_labels, 'int64')
        del paper_labels

        # Edges, one relation at a time, stored as (2, num_edges) int64 arrays
        edge_counts = {}
        for edge_type, rel_dir in self._EDGE_RELATIONS.items():
            edge_file = os.path.join(ogbn_dir, 'relations', rel_dir, 'edge.csv.gz')
            edges = self._read_csv_gz(edge_file).T.astype(np.int64)
            self._write_memmap(f'{edge_type}_edges.dat', edges, 'int64')
            edge_counts[edge_type] = edges.shape[1]
            del edges

        # Load splits
        split_dir = os.path.join(self.data_dir, 'ogbn_mag', 'split', 'time')
        if os.path.exists(split_dir):
            print("  Loading official splits...")
            train_idx = self._read_csv_gz(
                os.path.join(split_dir, 'paper', 'train.csv.gz')).flatten()
            val_idx = self._read_csv_gz(
                os.path.join(split_dir, 'paper', 'valid.csv.gz')).flatten()
            test_idx = self._read_csv_gz(
                os.path.join(split_dir, 'paper', 'test.csv.gz')).flatten()
        else:
            # Fallback: reproducible random 80/10/10 split
            indices = np.random.RandomState(42).permutation(num_papers)
            train_size = int(0.8 * num_papers)
            val_size = int(0.1 * num_papers)
            train_idx = indices[:train_size]
            val_idx = indices[train_size:train_size + val_size]
            test_idx = indices[train_size + val_size:]

        # Save metadata (edge_counts is a dict, so it is stored pickled)
        np.savez(os.path.join(self.cache_dir, 'metadata.npz'),
                 num_papers=num_papers,
                 num_authors=self.NUM_AUTHORS,
                 num_fields=self.NUM_FIELDS,
                 num_classes=num_classes,
                 feat_dim=feat_dim,
                 train_idx=train_idx,
                 val_idx=val_idx,
                 test_idx=test_idx,
                 edge_counts=edge_counts)

        print("✅ Memory-mapped data prepared!")
        gc.collect()

    def _open_mmap_files(self):
        """Open memory-mapped files for reading."""
        # NOTE: allow_pickle is required because edge_counts is stored as a
        # pickled dict. Only load metadata.npz files produced by this class.
        metadata = np.load(os.path.join(self.cache_dir, 'metadata.npz'), allow_pickle=True)
        self.num_papers = int(metadata['num_papers'])
        self.num_authors = int(metadata['num_authors'])
        self.num_fields = int(metadata['num_fields'])
        self.num_classes = int(metadata['num_classes'])
        self.feat_dim = int(metadata['feat_dim'])
        self.train_idx = torch.from_numpy(metadata['train_idx'])
        self.val_idx = torch.from_numpy(metadata['val_idx'])
        self.test_idx = torch.from_numpy(metadata['test_idx'])

        # Open memory-mapped arrays (read-only)
        self.paper_features = np.memmap(os.path.join(self.cache_dir, 'paper_features.dat'),
                                        dtype='float32', mode='r',
                                        shape=(self.num_papers, self.feat_dim))

        self.paper_labels = np.memmap(os.path.join(self.cache_dir, 'paper_labels.dat'),
                                      dtype='int64', mode='r', shape=(self.num_papers,))

        # Open edge files
        edge_counts = metadata['edge_counts'].item()
        self.cite_edges = np.memmap(os.path.join(self.cache_dir, 'cite_edges.dat'),
                                    dtype='int64', mode='r', shape=(2, edge_counts['cite']))
        self.author_edges = np.memmap(os.path.join(self.cache_dir, 'author_edges.dat'),
                                      dtype='int64', mode='r', shape=(2, edge_counts['author']))
        self.field_edges = np.memmap(os.path.join(self.cache_dir, 'field_edges.dat'),
                                     dtype='int64', mode='r', shape=(2, edge_counts['field']))

    def get_paper_batch(self, indices):
        """Load a batch of papers from disk.

        Args:
            indices: Paper indices as a tensor, NumPy array or list.

        Returns:
            Tuple ``(features, labels)`` of in-memory tensors holding copies of
            the requested rows.
        """
        # Convert to numpy array for indexing
        if isinstance(indices, torch.Tensor):
            indices = indices.cpu().numpy()

        # Load only requested features and labels
        features = torch.from_numpy(self.paper_features[indices].copy())
        labels = torch.from_numpy(self.paper_labels[indices].copy())

        return features, labels

    def get_edges_for_nodes(self, node_indices, edge_type='cite'):
        """Get edges connected to specific nodes.

        Args:
            node_indices: Node ids as a tensor, NumPy array or list.
            edge_type: ``'cite'`` or ``'author'``; any other value selects the
                field-of-study edges.

        Returns:
            Long tensor of shape (2, k) with every edge whose source or
            destination id is in ``node_indices``, in on-disk order.
        """
        if isinstance(node_indices, torch.Tensor):
            node_indices = node_indices.cpu().numpy()
        node_ids = np.unique(np.asarray(node_indices, dtype=np.int64))

        # Select appropriate edge array
        if edge_type == 'cite':
            edges = self.cite_edges
        elif edge_type == 'author':
            edges = self.author_edges
        else:
            edges = self.field_edges

        if node_ids.size == 0:
            return torch.empty(2, 0, dtype=torch.long)

        # Vectorised membership test, scanned in bounded chunks so that only a
        # slice of the memory-mapped edge list is resident at any time.
        # In production, you'd want an index structure for this.
        selected = []
        for start in range(0, edges.shape[1], self._EDGE_SCAN_CHUNK):
            chunk = np.asarray(edges[:, start:start + self._EDGE_SCAN_CHUNK])
            hit = np.isin(chunk, node_ids).any(axis=0)
            if hit.any():
                selected.append(chunk[:, hit])

        if not selected:
            return torch.empty(2, 0, dtype=torch.long)
        return torch.from_numpy(np.concatenate(selected, axis=1))

# Build the disk-backed dataset wrapper
print("🔄 Initializing disk-based OGBN-MAG dataset...")
disk_data = DiskBasedOGBNMAG('./data')

# Skeleton HeteroData object: it only exists so that metadata() lists every
# node type and every edge type the model will be built for.
data = HeteroData()
data.num_classes = disk_data.num_classes

print("📊 Initializing HeteroData with all node and edge types...")

# One placeholder node per node type (feature width given alongside)
for node_type, width in (('paper', 128), ('author', 128), ('field_of_study', 64)):
    data[node_type].x = torch.randn(1, width)

# One placeholder self-edge (0 -> 0) per edge type. The insertion order below
# is the order in which the edge types are listed afterwards.
for edge_type in (
    ('author', 'writes', 'paper'),
    ('paper', 'written_by', 'author'),
    ('paper', 'has_topic', 'field_of_study'),
    ('field_of_study', 'topic_of', 'paper'),
    ('paper', 'cites', 'paper'),
):
    data[edge_type].edge_index = torch.tensor([[0], [0]], dtype=torch.long)

# Show what was registered
print("✅ HeteroData metadata:")
print(f"   Node types: {data.node_types}")
print(f"   Edge types: {data.edge_types}")

# Keep a handle on the disk dataset next to the graph skeleton
data._disk_data = disk_data

# Split indices and class count used by the training cells
train_idx = disk_data.train_idx
val_idx = disk_data.val_idx
num_classes = disk_data.num_classes

print(f"\n✅ Disk-based data ready!")
print(f"   Papers: {disk_data.num_papers:,} (on disk)")
print(f"   Authors: {disk_data.num_authors:,}")
print(f"   Fields: {disk_data.num_fields:,}")
print(f"   Classes: {num_classes}")
print(f"   Memory usage: Minimal - data remains on disk")

# Legacy dictionary view kept for code that predates DiskBasedOGBNMAG
data_dict = dict(
    num_papers=disk_data.num_papers,
    num_authors=disk_data.num_authors,
    num_fields=disk_data.num_fields,
    paper_features=disk_data.paper_features,  # memory-mapped array, not a copy
)

In [None]:
# Model Definition and Disk-Based Sampler

# Disk-Based Memory-Efficient Sampler
class DiskBasedSampler:
    """Sampler that loads data from disk on-demand.

    Only the citation relation is sampled for real. Author and field-of-study
    nodes and their edges are random placeholders that exist so every batch
    exposes all node and edge types.

    Args:
        disk_data: Dataset object exposing ``cite_edges`` (a (2, E) integer
            array) and ``get_paper_batch(indices)``.
        batch_size: Number of target papers per minibatch.
        num_neighbors: Neighbors to sample per node at each hop; defaults to
            ``[15, 10]``.
    """

    def __init__(self, disk_data, batch_size=128, num_neighbors=None):
        self.disk_data = disk_data
        self.batch_size = batch_size
        # A fresh list per instance: a list default in the signature would be
        # shared (and mutable) across every sampler.
        self.num_neighbors = [15, 10] if num_neighbors is None else list(num_neighbors)

        # Create edge index for fast neighbor lookup (this does use some memory)
        # In production, you'd use a graph database or specialized index
        print("Building neighbor index...")
        self._build_neighbor_index()

    def _build_neighbor_index(self):
        """Build (or reuse) the in-memory index ``dst -> [src, ...]`` of citation edges.

        The index depends only on ``disk_data.cite_edges``, so it is cached on
        the dataset object and shared by every sampler created for it instead
        of being rebuilt per sampler.
        """
        cached = getattr(self.disk_data, '_cite_neighbor_index', None)
        if cached is not None:
            self.cite_neighbors = cached
            print(f"  Reusing cached index for {len(self.cite_neighbors)} nodes")
            return

        index = {}

        # Process in chunks to limit memory. Each chunk is converted to plain
        # Python ints in one call, which avoids per-element memmap indexing
        # and storing NumPy scalar objects in the index.
        chunk_size = 1000000
        cite_edges = self.disk_data.cite_edges
        num_edges = cite_edges.shape[1]

        for start in range(0, num_edges, chunk_size):
            chunk = np.asarray(cite_edges[:, start:start + chunk_size])
            for src, dst in zip(chunk[0].tolist(), chunk[1].tolist()):
                index.setdefault(dst, []).append(src)

        self.cite_neighbors = index
        try:
            self.disk_data._cite_neighbor_index = index
        except AttributeError:
            # Dataset object does not accept new attributes; skip caching.
            pass

        print(f"  Built index for {len(self.cite_neighbors)} nodes")

    def sample_neighbors(self, node_id, num_samples):
        """Sample up to ``num_samples`` indexed neighbors of ``node_id``.

        Returns:
            A new list of neighbor ids: all of them when the node has at most
            ``num_samples`` neighbors, otherwise a uniform random subset drawn
            without replacement. Empty when the node is not in the index.
        """
        neighbors = self.cite_neighbors.get(node_id)
        if not neighbors:
            return []

        if len(neighbors) <= num_samples:
            # Copy so callers can never alter the shared index.
            return list(neighbors)

        # Random sample
        picks = torch.randperm(len(neighbors))[:num_samples].tolist()
        return [neighbors[i] for i in picks]

    def _expand_neighborhood(self, target_nodes):
        """Return the sorted ids of the targets plus their sampled multi-hop neighbors."""
        visited = set(target_nodes)
        frontier = list(target_nodes)

        for num_samples in self.num_neighbors:
            next_frontier = set()
            for node in frontier:
                next_frontier.update(self.sample_neighbors(node, num_samples))

            visited.update(next_frontier)
            frontier = list(next_frontier)

        # Sorted ids make the memory-mapped feature read sequential on disk.
        return sorted(visited)

    @staticmethod
    def _identity_edges(count):
        """Return a (2, count) long tensor holding the edges ``i -> i`` for ``i < count``."""
        idx = torch.arange(count, dtype=torch.long)
        return torch.stack([idx, idx])

    def create_minibatch(self, target_nodes, force_edges=False):
        """Create a minibatch with ALL required edge types in EXACT metadata order.

        Args:
            target_nodes: Iterable of global paper ids to predict for.
            force_edges: When True and no citation edge exists among the
                sampled papers, insert a dummy pair of citation edges between
                the first two papers (if there are at least two).

        Returns:
            ``HeteroData`` batch whose ``batch['paper'].target_mask`` marks the
            rows that correspond to ``target_nodes``.
        """
        target_nodes = list(target_nodes)

        # Multi-hop sampling
        all_paper_nodes = self._expand_neighborhood(target_nodes)
        num_paper_nodes = len(all_paper_nodes)

        # Load features and labels from disk
        paper_features, paper_labels = self.disk_data.get_paper_batch(all_paper_nodes)

        # Create batch data structure
        batch = HeteroData()

        # Paper data
        batch['paper'].x = paper_features
        batch['paper'].y = paper_labels

        # Create proper dummy nodes for other types - HGTConv needs all node types
        num_authors = max(10, num_paper_nodes // 10)  # At least 10 authors
        num_fields = max(5, num_paper_nodes // 20)    # At least 5 fields

        batch['author'].x = torch.randn(num_authors, 128)
        batch['field_of_study'].x = torch.randn(num_fields, 64)

        # Global paper id -> row in this batch
        node_mapping = {old: new for new, old in enumerate(all_paper_nodes)}

        # CRITICAL: edge types are added in the same order as in the global
        # metadata: writes, written_by, has_topic, topic_of, cites.

        # 1 + 2. Placeholder author <-> paper edges: author i is linked to
        # paper i for as many pairs as both node sets allow.
        author_links = self._identity_edges(min(num_authors, num_paper_nodes))
        batch['author', 'writes', 'paper'].edge_index = author_links
        batch['paper', 'written_by', 'author'].edge_index = author_links.flip(0)

        # 3 + 4. Placeholder paper <-> field edges: paper i is linked to field i.
        field_links = self._identity_edges(min(num_fields, num_paper_nodes))
        batch['paper', 'has_topic', 'field_of_study'].edge_index = field_links
        batch['field_of_study', 'topic_of', 'paper'].edge_index = field_links.flip(0)

        # 5. Real citation edges among the sampled papers (LAST in metadata).
        # Every indexed edge whose both endpoints are in the batch is kept,
        # not only the ones that were followed during sampling.
        cite_edges = [
            [node_mapping[neighbor], node_mapping[node]]
            for node in all_paper_nodes
            for neighbor in self.cite_neighbors.get(node, ())
            if neighbor in node_mapping
        ]

        if cite_edges:
            batch['paper', 'cites', 'paper'].edge_index = (
                torch.tensor(cite_edges, dtype=torch.long).t().contiguous())
        elif force_edges and num_paper_nodes >= 2:
            batch['paper', 'cites', 'paper'].edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        else:
            batch['paper', 'cites', 'paper'].edge_index = torch.empty(2, 0, dtype=torch.long)

        # Mark target nodes
        target_mask = torch.zeros(num_paper_nodes, dtype=torch.bool)
        target_rows = [node_mapping[node] for node in target_nodes if node in node_mapping]
        if target_rows:
            target_mask[torch.tensor(target_rows, dtype=torch.long)] = True
        batch['paper'].target_mask = target_mask

        return batch

    def get_batches(self, indices, shuffle=True):
        """Generate minibatches covering ``indices``.

        Args:
            indices: 1-D tensor (or array-like) of global paper ids.
            shuffle: Randomly permute ``indices`` before batching.

        Yields:
            One ``HeteroData`` minibatch per ``batch_size`` target papers; the
            last batch may be smaller.
        """
        indices = torch.as_tensor(indices)
        if shuffle:
            perm = torch.randperm(len(indices))
            indices = indices[perm]

        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            yield self.create_minibatch(batch_indices.tolist())

# Research-Optimized HGT Model with better initialization
class ResearchOptimalHGT(torch.nn.Module):
    """Stack of HGTConv layers with LayerNorm, LeakyReLU, dropout and residuals.

    Layout: one input HGTConv, ``num_layers - 2`` hidden HGTConv layers with a
    linear residual projection each, and one single-head output HGTConv. For
    ``num_layers < 2`` the loop adds no hidden layer, so the model still has
    the input and the output layer.

    Args:
        in_dim: Unused; per-node-type input widths are fixed in ``in_dims``.
        hidden_dim: Width of the hidden representations.
        out_dim: Width of the output layer.
        metadata: ``(node_types, edge_types)`` tuple passed on to HGTConv.
        heads: Attention heads of the input and hidden layers.
        dropout: Dropout probability applied after each activation.
        num_layers: Total number of HGTConv layers (see layout above).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, metadata, heads=8, dropout=0.6, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = torch.nn.Dropout(dropout)

        # Store metadata for the input checks in forward()
        self.node_types = metadata[0]
        self.edge_types = metadata[1]

        # Define input dimensions for each node type
        self.in_dims = {
            'paper': 128,  # OGBN-MAG paper features
            'author': 128,  # Dummy features
            'field_of_study': 64  # Dummy features
        }

        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        self.residual_projs = torch.nn.ModuleList()

        # First layer - use actual dimensions
        self.convs.append(HGTConv(self.in_dims, hidden_dim, metadata, heads=heads))
        self.norms.append(torch.nn.LayerNorm(hidden_dim))

        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(HGTConv(hidden_dim, hidden_dim, metadata, heads=heads))
            self.norms.append(torch.nn.LayerNorm(hidden_dim))
            self.residual_projs.append(torch.nn.Linear(hidden_dim, hidden_dim))

        # Output layer
        self.convs.append(HGTConv(hidden_dim, out_dim, metadata, heads=1))

        self.use_residual = num_layers > 2

    def _norm_act_drop(self, x_dict, norm):
        """Apply ``norm``, LeakyReLU(0.2) and dropout to every node type."""
        return {
            key: self.dropout(F.leaky_relu(norm(x), negative_slope=0.2))
            for key, x in x_dict.items()
        }

    def forward(self, x_dict, edge_index_dict):
        """Run the layer stack.

        Args:
            x_dict: Node type -> feature tensor; must contain every node type
                from the metadata.
            edge_index_dict: Edge type -> edge index tensor. Edge types from
                the metadata that are missing are treated as empty; the
                caller's dict is left unmodified.

        Returns:
            Dict of node type -> output tensor of the last HGTConv layer.

        Raises:
            ValueError: If a node type from the metadata is missing in ``x_dict``.
        """
        # Ensure all node types are present
        for node_type in self.node_types:
            if node_type not in x_dict:
                raise ValueError(f"Missing node type '{node_type}' in x_dict")

        # Fill in missing edge types on a shallow copy so that the dict owned
        # by the caller is never mutated.
        edge_index_dict = dict(edge_index_dict)
        for edge_type in self.edge_types:
            if edge_type not in edge_index_dict:
                print(f"Warning: Missing edge type {edge_type}, adding empty tensor")
                device = next(iter(x_dict.values())).device
                edge_index_dict[edge_type] = torch.empty(2, 0, dtype=torch.long, device=device)

        # First layer
        x_dict = self._norm_act_drop(self.convs[0](x_dict, edge_index_dict), self.norms[0])

        # Hidden layers with residual
        for i in range(1, self.num_layers - 1):
            # No clone needed: none of the following ops works in place, so the
            # layer input stays intact for the residual branch.
            x_dict_res = x_dict

            x_dict = self._norm_act_drop(self.convs[i](x_dict, edge_index_dict), self.norms[i])

            if self.use_residual:
                for key in x_dict.keys():
                    if key in x_dict_res:
                        residual = self.residual_projs[i - 1](x_dict_res[key])
                        x_dict[key] = x_dict[key] + residual

        # Output layer
        x_dict = self.convs[-1](x_dict, edge_index_dict)

        return x_dict

print("✅ Model and disk-based sampler classes defined!")

# NOTE: the module-level HeteroData object 'data' carries dummy nodes and edges
# for every type, so data.metadata() already returns the complete metadata;
# there is no need to override data.metadata.

In [28]:
# FINAL TRAINING: Multi-GPU with Disk-Based Loading
# Uses GPUs 1 and 2 - Two RTX 2060 SUPER GPUs

import time
from datetime import datetime
from torch.nn.parallel import DataParallel
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm

print("🚀 FINAL MULTI-GPU TRAINING WITH DISK-BASED LOADING")
print("=" * 60)

# Use GPUs 1 and 2 - Two RTX 2060 SUPER GPUs
# NOTE(review): presumably this only takes effect if CUDA has not been
# initialised earlier in the session - confirm when re-running cells.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

# Debugging aids: CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous so
# errors are reported at the call that caused them (it slows training and does
# not affect NCCL); TORCH_USE_CUDA_DSA=0 leaves device-side assertions off.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '0'

torch.cuda.set_device(0)  # Device 0 now maps to physical GPU 1

# Verify GPU setup
num_gpus = torch.cuda.device_count()
print(f"Using {num_gpus} GPUs:")
for i in range(num_gpus):
    print(f"  Device {i}: {torch.cuda.get_device_name(i)} ({torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f}GB)")

# Force multi-GPU usage
use_multi_gpu = True

# FINAL CONFIGURATION - Optimized for two RTX 2060 SUPER GPUs
final_config = {
    # Research-critical parameters
    'hidden_dim': 256,
    'heads': 8,
    'dropout': 0.6,
    'num_layers': 3,
    'lr': 0.005,
    'weight_decay': 5e-4,
    'gradient_clip': 1.0,
    'label_smoothing': 0.1,
    'num_neighbors': [25, 20, 15],

    # GPU optimization for two RTX 2060 SUPER GPUs
    'batch_size_per_gpu': 192,  # Conservative for ~7.6GB GPUs
    'accumulation_steps': 4,    # Maintain effective batch size
    'use_amp': True,

    # Training parameters
    'max_epochs': 50,
    'validation_frequency': 5,
    'early_stopping_patience': 10,
    'checkpoint_dir': './final_checkpoints',
}

os.makedirs(final_config['checkpoint_dir'], exist_ok=True)

print("\n📊 Configuration:")
print(f"  Total batch size: {final_config['batch_size_per_gpu'] * num_gpus}")
print(f"  Data loading: From disk on-demand")
print(f"  Memory usage: Minimal")
print(f"  GPU setup: Two RTX 2060 SUPER GPUs")

# Create model
device = torch.device('cuda:0')
print("\n🧠 Setting up model...")

# Pad the model's output width up to a multiple of the head count, but keep
# num_classes at the true class count: the training and validation code slices
# the logits with [:, :num_classes], which is what removes the padded columns.
# (Overwriting num_classes here would turn that slice into a no-op and let the
# padded phantom classes take part in the loss and the argmax.)
model_out_dim = num_classes
if num_classes % final_config['heads'] != 0:
    model_out_dim = ((num_classes + final_config['heads'] - 1) // final_config['heads']) * final_config['heads']
    print(f"   Adjusting classes: {num_classes} → {model_out_dim}")

# DEBUG: Verify metadata before model creation
metadata = data.metadata()
print(f"   Model metadata: {metadata}")
print(f"   Node types: {metadata[0]}")
print(f"   Edge types: {metadata[1]}")

# Create research-optimal model with explicit metadata
model = ResearchOptimalHGT(
    in_dim=None,  # Not used anymore, dimensions are hardcoded in the model
    hidden_dim=final_config['hidden_dim'],
    out_dim=model_out_dim,
    metadata=metadata,  # Pass the actual metadata tuple
    heads=final_config['heads'],
    dropout=final_config['dropout'],
    num_layers=final_config['num_layers']
)

# Move model to GPU first
model = model.to(device)

# Initialize model with a batch that has edges
print("   Initializing model with sample batch...")
try:
    with torch.no_grad():
        # Get a sample batch with forced edges for initialization
        sample_sampler = DiskBasedSampler(disk_data, batch_size=64, num_neighbors=[5, 5])
        sample_batch = sample_sampler.create_minibatch(train_idx[:64].tolist(), force_edges=True)
        sample_batch = sample_batch.to(device)

        # Debug: Check edge types in batch
        print(f"   Sample batch edge types: {list(sample_batch.edge_index_dict.keys())}")
        print(f"   Expected edge types: {metadata[1]}")

        # Debug: Check HGTConv edge types map
        hgt_conv = model.convs[0]
        print(f"   HGTConv edge_types_map: {hgt_conv.edge_types_map}")

        # Run forward pass to initialize lazy modules
        _ = model(sample_batch.x_dict, sample_batch.edge_index_dict)
        print("   ✅ Model initialized successfully")
except Exception as e:
    print(f"   ❌ Model initialization failed: {e}")
    import traceback
    traceback.print_exc()
    raise

# Set up DataParallel for two RTX 2060 SUPER GPUs
# NOTE(review): torch.nn.DataParallel scatters tensor inputs along dim 0, which
# would split each (2, E) edge_index by rows and each feature matrix by nodes.
# The test forward pass below is therefore expected to fail and trigger the
# single-GPU fallback - confirm before relying on multi-GPU training.
if use_multi_gpu and num_gpus > 1:
    try:
        print("\n🔧 Setting up DataParallel for two RTX 2060 SUPER GPUs...")

        # Use all available GPUs
        device_ids = list(range(num_gpus))
        print(f"   Using device IDs: {device_ids}")

        # Wrap model with DataParallel
        model = DataParallel(model, device_ids=device_ids)

        # Quick test forward pass
        print("   Testing DataParallel setup...")
        with torch.no_grad():
            test_sampler = DiskBasedSampler(disk_data, batch_size=32, num_neighbors=[3, 3])
            test_batch = test_sampler.create_minibatch(train_idx[:32].tolist(), force_edges=True)
            test_batch = test_batch.to(device)
            _ = model(test_batch.x_dict, test_batch.edge_index_dict)

        print(f"✅ Model distributed across {num_gpus} RTX 2060 SUPER GPUs")
        actual_batch_size = final_config['batch_size_per_gpu'] * num_gpus

    except Exception as e:
        print(f"⚠️  DataParallel failed: {e}")
        print("📌 Falling back to single GPU")
        import traceback
        traceback.print_exc()

        # Remove DataParallel wrapper if it was added
        if hasattr(model, 'module'):
            model = model.module

        use_multi_gpu = False
        actual_batch_size = final_config['batch_size_per_gpu']
else:
    print("\n📌 Using single GPU training")
    actual_batch_size = final_config['batch_size_per_gpu']

print(f"   Final batch size: {actual_batch_size}")
print(f"   Multi-GPU enabled: {use_multi_gpu}")

# Optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=final_config['lr'],
    weight_decay=final_config['weight_decay']
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-5
)

# Create disk-based samplers
print("\n📊 Creating disk-based data samplers...")
train_sampler = DiskBasedSampler(
    disk_data,
    batch_size=actual_batch_size,
    num_neighbors=final_config['num_neighbors']
)

# Validation uses exactly the same sampler settings, and a sampler keeps no
# per-iteration state, so the training sampler is shared instead of building
# a second neighbor index over the full citation graph.
val_sampler = train_sampler

# Mixed precision scaler
scaler = GradScaler() if final_config['use_amp'] else None

# Memory monitoring for two RTX 2060 SUPER GPUs
def get_gpu_memory_str():
    """Return a short allocated/total memory summary for the GPUs in use.

    Reads the module-level ``use_multi_gpu`` and ``num_gpus``. When multi-GPU
    training is active every visible device is listed, separated by ``" | "``;
    otherwise only device 0 is reported.

    Returns:
        A string such as ``"GPU0: 1.2/7.6GB | GPU1: 0.9/7.6GB"``.
    """
    device_ids = range(num_gpus) if (use_multi_gpu and num_gpus > 1) else (0,)

    parts = []
    for dev in device_ids:
        alloc = torch.cuda.memory_allocated(dev) / 1024**3
        total = torch.cuda.get_device_properties(dev).total_memory / 1024**3
        parts.append(f"GPU{dev}: {alloc:.1f}/{total:.1f}GB")
    return " | ".join(parts)

# Training function
def _apply_accumulated_gradients():
    """Clip, apply and then clear the gradients accumulated on ``model``.

    Uses the module-level ``scaler`` (when AMP is enabled), ``optimizer``,
    ``model`` and ``final_config['gradient_clip']``.
    """
    if scaler:
        # Unscale first so the clipping threshold applies to the true
        # gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), final_config['gradient_clip'])
        scaler.step(optimizer)
        scaler.update()
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), final_config['gradient_clip'])
        optimizer.step()

    optimizer.zero_grad()


def train_epoch(epoch):
    """Train ``model`` for one epoch and return the mean loss per target paper.

    At most ``min(800, len(train_idx) // actual_batch_size)`` minibatches are
    drawn from ``train_sampler``. Gradients are accumulated over
    ``final_config['accumulation_steps']`` successfully processed minibatches
    before each optimizer step; a trailing partial window is applied at the
    end of the epoch instead of being discarded. A minibatch that raises is
    reported and skipped, and the gradients accumulated in its window are
    dropped because they may be incomplete.

    Args:
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        Average (unscaled) training loss over all processed target papers.
    """
    # Local import: itertools is not among the file's top-level imports.
    from itertools import islice

    model.train()
    total_loss = 0
    total_examples = 0
    accumulation_steps = final_config['accumulation_steps']

    batches_per_epoch = min(800, len(train_idx) // actual_batch_size)
    optimizer.zero_grad()

    # islice stops after exactly batches_per_epoch items, so no extra
    # minibatch is sampled from disk only to be thrown away.
    batch_iter = islice(train_sampler.get_batches(train_idx, shuffle=True), batches_per_epoch)
    pbar = tqdm(batch_iter, total=batches_per_epoch, desc=f'Epoch {epoch}')

    # Number of minibatches whose gradients are waiting for an optimizer step
    pending = 0

    for batch_idx, batch in enumerate(pbar):
        try:
            batch = batch.to(device, non_blocking=True)

            target_mask = batch['paper'].target_mask
            if target_mask.sum() == 0:
                continue

            # Mixed precision forward
            with autocast(enabled=final_config['use_amp']):
                out_dict = model(batch.x_dict, batch.edge_index_dict)

                paper_out = out_dict['paper'][target_mask][:, :num_classes]
                paper_labels = batch['paper'].y[target_mask]

                loss = F.cross_entropy(
                    paper_out,
                    paper_labels,
                    label_smoothing=final_config['label_smoothing']
                )
                # Scale so the accumulated gradient is an average over the window
                loss = loss / accumulation_steps

            # Backward
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            pending += 1

            # Gradient accumulation and step
            if pending >= accumulation_steps:
                _apply_accumulated_gradients()
                pending = 0

            # Metrics (undo the accumulation scaling for reporting)
            batch_size = target_mask.sum().item()
            total_loss += float(loss) * batch_size * accumulation_steps
            total_examples += batch_size

            # Update progress
            if batch_idx % 50 == 0:
                pbar.set_postfix({
                    'loss': f'{total_loss/max(1, total_examples):.4f}',
                    'lr': f'{optimizer.param_groups[0]["lr"]:.6f}',
                    'mem': get_gpu_memory_str()
                })

            # Periodic cache clearing
            if batch_idx % 100 == 0:
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"\nError in batch {batch_idx}: {e}")
            if "out of memory" in str(e).lower():
                print("⚠️  GPU OOM detected. Clearing cache and continuing...")
                torch.cuda.empty_cache()
            import traceback
            traceback.print_exc()
            # The failed step may have left partial gradients behind; drop
            # the current window rather than stepping on them.
            optimizer.zero_grad()
            pending = 0
            continue

    # Apply whatever is left from an incomplete accumulation window
    if pending:
        _apply_accumulated_gradients()

    pbar.close()
    return total_loss / max(1, total_examples)

@torch.no_grad()
def validate():
    """Evaluate ``model`` on up to 100 minibatches of the validation split.

    Uses the module-level ``model``, ``val_sampler``, ``val_idx``,
    ``actual_batch_size``, ``num_classes`` and ``device``. A minibatch that
    raises is reported and excluded from the metrics rather than silently
    ignored.

    Returns:
        Tuple ``(val_loss, val_acc)``: mean cross-entropy and accuracy over
        the evaluated target papers (both 0 when nothing could be evaluated).
    """
    # Local import: itertools is not among the file's top-level imports.
    from itertools import islice

    model.eval()
    total_loss = 0
    total_correct = 0
    total_examples = 0
    failed_batches = 0

    # Ceiling division, so a validation set smaller than one batch is still
    # evaluated instead of yielding zero batches.
    val_batches = min(100, -(-len(val_idx) // actual_batch_size))

    # islice stops after exactly val_batches items, so no extra minibatch is
    # sampled from disk only to be thrown away.
    batch_iter = islice(val_sampler.get_batches(val_idx, shuffle=False), val_batches)

    for i, batch in enumerate(tqdm(batch_iter, desc='Validating', total=val_batches)):
        try:
            batch = batch.to(device, non_blocking=True)

            target_mask = batch['paper'].target_mask
            if target_mask.sum() == 0:
                continue

            with autocast(enabled=final_config['use_amp']):
                out_dict = model(batch.x_dict, batch.edge_index_dict)

                paper_out = out_dict['paper'][target_mask][:, :num_classes]
                paper_labels = batch['paper'].y[target_mask]
                loss = F.cross_entropy(paper_out, paper_labels)

                pred = paper_out.argmax(dim=-1)
                correct = (pred == paper_labels).sum().item()

            batch_size = target_mask.sum().item()
            total_loss += float(loss) * batch_size
            total_correct += correct
            total_examples += batch_size

        except Exception as e:
            # Keep validating (best effort) but make the failure visible:
            # silently dropped batches would skew the reported metrics.
            failed_batches += 1
            print(f"\nError in validation batch {i}: {e}")
            if "out of memory" in str(e).lower():
                torch.cuda.empty_cache()
            continue

    if failed_batches:
        print(f"⚠️  {failed_batches} validation batch(es) failed and were skipped")

    val_loss = total_loss / max(1, total_examples)
    val_acc = total_correct / max(1, total_examples)
    return val_loss, val_acc

# Checkpoint management
def save_checkpoint(epoch, train_loss, val_loss, val_acc, is_best=False):
    """Write the current training state to ``latest.pt`` (and ``best.pt``).

    The checkpoint holds the epoch number, model/optimizer/scheduler state
    dicts, the three metrics passed in, and ``final_config``.

    Args:
        epoch: Epoch number that has just finished.
        train_loss: Training loss for that epoch.
        val_loss: Validation loss for that epoch.
        val_acc: Validation accuracy for that epoch.
        is_best: When true, the same checkpoint is also written to
            ``best.pt`` and a confirmation line is printed.
    """
    checkpoint_dir = final_config['checkpoint_dir']
    # torch.save does not create missing parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save the wrapped module's weights when the model exposes `.module`;
    # load_checkpoint applies the same unwrap when restoring.
    bare_model = model.module if hasattr(model, 'module') else model

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': bare_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_acc': val_acc,
        'config': final_config,
    }

    def _write_atomically(filename):
        # Write to a sibling temp file, then rename over the target, so an
        # interrupted save never leaves a truncated checkpoint behind.
        final_path = os.path.join(checkpoint_dir, filename)
        tmp_path = final_path + '.tmp'
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, final_path)

    _write_atomically('latest.pt')

    if is_best:
        _write_atomically('best.pt')
        print(f"💾 New best model saved! Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

def load_checkpoint():
    """Restore model, optimizer and scheduler state from ``latest.pt``.

    Returns:
        The epoch number stored in the checkpoint, or ``0`` when no
        ``latest.pt`` exists in ``final_config['checkpoint_dir']``.
    """
    latest_path = os.path.join(final_config['checkpoint_dir'], 'latest.pt')

    # Nothing to resume from: signal a fresh start.
    if not os.path.exists(latest_path):
        return 0

    print(f"📂 Loading checkpoint from {latest_path}")
    state = torch.load(latest_path, map_location=device)

    # Weights are loaded into the wrapped module when the model exposes
    # `.module`, otherwise into the model itself.
    weights_target = model.module if hasattr(model, 'module') else model
    weights_target.load_state_dict(state['model_state_dict'])

    optimizer.load_state_dict(state['optimizer_state_dict'])
    scheduler.load_state_dict(state['scheduler_state_dict'])
    return state['epoch']

# MAIN TRAINING LOOP
# Drives training end to end: optional resume from latest.pt, one
# train_epoch() per epoch, periodic validation with best-model tracking and
# early stopping, then a final summary.
print("\n" + "="*60)
print("🏃 STARTING DISK-BASED TRAINING")
print(f"   Mode: {'Multi-GPU' if use_multi_gpu else 'Single GPU'}")
print(f"   Batch size: {actual_batch_size}")
# Report the device names CUDA actually exposes instead of a hard-coded
# model name, which is wrong on machines with mixed GPUs.
if torch.cuda.is_available():
    gpu_names = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
else:
    gpu_names = []
gpu_summary = ', '.join(gpu_names) if gpu_names else 'no CUDA device visible'
print(f"   GPUs: {num_gpus} ({gpu_summary})")
print("="*60)

# Try to resume from checkpoint
start_epoch = load_checkpoint()

# Clear GPU cache before training
torch.cuda.empty_cache()

# Training history
best_val_loss = float('inf')
best_val_acc = 0.0
patience_counter = 0
training_start = datetime.now()

# When resuming, recover the best metrics recorded so far. Without this the
# first validation after a resume always counts as "best" and overwrites
# best.pt even if the model has become worse.
# NOTE: patience_counter is not stored in the checkpoint, so early-stopping
# patience still restarts from zero after a resume.
best_path = os.path.join(final_config['checkpoint_dir'], 'best.pt')
if start_epoch > 0 and os.path.exists(best_path):
    best_checkpoint = torch.load(best_path, map_location='cpu')
    best_val_loss = best_checkpoint.get('val_loss', best_val_loss)
    best_val_acc = best_checkpoint.get('val_acc', best_val_acc)
    del best_checkpoint
    print(f"   Resuming from epoch {start_epoch}: best val loss so far {best_val_loss:.4f}")

for epoch in range(start_epoch + 1, final_config['max_epochs'] + 1):
    epoch_start = time.time()
    
    # Train
    train_loss = train_epoch(epoch)
    
    # Step scheduler (once per epoch)
    scheduler.step()
    
    # Validate periodically
    if epoch % final_config['validation_frequency'] == 0:
        val_loss, val_acc = validate()
        
        # Model selection is driven by validation loss, not accuracy
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
            best_val_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Save checkpoint (checkpoints are only written on validation epochs)
        save_checkpoint(epoch, train_loss, val_loss, val_acc, is_best)
        
        # Print summary
        print(f"\n{'='*60}")
        print(f"Epoch {epoch} Summary:")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f} {'🏆 NEW BEST!' if is_best else ''}")
        print(f"  Val Accuracy: {val_acc:.4%}")
        print(f"  Time: {time.time() - epoch_start:.1f}s")
        print(f"  Memory: {get_gpu_memory_str()}")
        print(f"{'='*60}\n")
        
        # Early stopping: patience counts validation rounds without
        # improvement, not epochs
        if patience_counter >= final_config['early_stopping_patience']:
            print("🛑 Early stopping triggered!")
            break
    else:
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Time={time.time() - epoch_start:.1f}s")
    
    # Clear cache periodically
    if epoch % 5 == 0:
        torch.cuda.empty_cache()

# Training complete
total_time = (datetime.now() - training_start).total_seconds()
print(f"\n{'='*60}")
print(f"✅ TRAINING COMPLETE!")
print(f"  Total time: {total_time/3600:.2f} hours")
print(f"  Best validation loss: {best_val_loss:.4f}")
print(f"  Best validation accuracy: {best_val_acc:.4%}")
print(f"  Memory: {get_gpu_memory_str()}")
print(f"{'='*60}")

# Reset CUDA device selection
# NOTE(review): CUDA has already been used by this point, so this assignment
# presumably only affects processes started afterwards, not the current
# one — confirm that is the intent.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

🚀 FINAL MULTI-GPU TRAINING WITH DISK-BASED LOADING
Using 2 GPUs:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (95.0GB)
  Device 1: NVIDIA GeForce RTX 2060 SUPER (7.6GB)

📊 Configuration:
  Total batch size: 384
  Data loading: From disk on-demand
  Memory usage: Minimal
  GPU setup: Two RTX 2060 SUPER GPUs

🧠 Setting up model...
   Adjusting classes: 349 → 352
   Model metadata: ([], [])
   Node types: []
   Edge types: []
   Initializing model with sample batch...
Building neighbor index...
  Built index for 629169 nodes
   Sample batch edge types: [('author', 'writes', 'paper'), ('paper', 'written_by', 'author'), ('paper', 'has_topic', 'field_of_study'), ('field_of_study', 'topic_of', 'paper'), ('paper', 'cites', 'paper')]
   Expected edge types: []
   HGTConv edge_types_map: {}
   ❌ Model initialization failed: ('author', 'writes', 'paper')


Traceback (most recent call last):
  File "/tmp/ipykernel_58053/2799755780.py", line 112, in <module>
    _ = model(sample_batch.x_dict, sample_batch.edge_index_dict)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_58053/4268810785.py", line 213, in forward
    x_dict = self.convs[0](x_dict, edge_index_dict)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local

KeyError: ('author', 'writes', 'paper')