# In [None]:
# OGBN-MAG Minibatch Training - Pragmatic Approach (Fixed Init)
# Simple data loading + PyG inheritance where needed = Working solution!

import torch
import torch.nn.functional as F
from torch_geometric.nn import HGTConv
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import HeteroData
import os
import h5py
import numpy as np
import pandas as pd
import gzip
import psutil
import gc
import warnings
import traceback

warnings.filterwarnings('ignore')
print("🚀 Starting pragmatic OGBN-MAG training...")

# --- 1. Super Simple Data Loader (Keep This Clean!) ---
def ensure_ogbn_data_exists(data_dir):
    """Return ``<data_dir>/ogbn_mag/raw``, downloading OGBN-MAG first if it is absent.

    Only the existence of that directory is checked; its contents are not
    validated.

    Args:
        data_dir: root directory handed to OGB as ``root``.

    Returns:
        Path of the raw-files directory.
    """
    raw_path = os.path.join(data_dir, 'ogbn_mag', 'raw')
    if os.path.exists(raw_path):
        return raw_path

    print("📦 Downloading OGBN-MAG (this may take a while)...")
    # Imported here so the `ogb` package is only required when a download is needed.
    from ogb.nodeproppred import PygNodePropPredDataset

    # Constructing the OGB dataset performs the download into data_dir itself
    # (not a temp location); the object is discarded because only the raw
    # CSVs are read afterwards.
    downloaded = PygNodePropPredDataset('ogbn-mag', root=data_dir)
    print("✅ Download complete!")
    del downloaded
    gc.collect()
    return raw_path

def _read_csv_gz(path, dtype):
    """Parse a headerless gzip-compressed CSV file into a NumPy array of ``dtype``."""
    with gzip.open(path, 'rt') as f:
        return pd.read_csv(f, header=None, dtype=dtype).values


def _read_edge_csv(raw_dir, relation):
    """Load ``relations/<relation>/edge.csv.gz`` as a C-contiguous int64 array.

    The CSV's columns become the rows of the result (row 0 = first column,
    row 1 = second column), i.e. the ``edge_index`` layout used by PyG.
    """
    path = os.path.join(raw_dir, 'relations', relation, 'edge.csv.gz')
    return np.ascontiguousarray(_read_csv_gz(path, np.int64).T)


def load_ogbn_simple(data_dir):
    """Load OGBN-MAG tensors straight from OGB's raw CSV files.

    If ``<data_dir>/ogbn_mag/raw`` is missing, ``ensure_ogbn_data_exists``
    downloads it through OGB first; parsing itself uses only gzip/pandas/NumPy.

    Args:
        data_dir: root directory. Raw files are read from
            ``<data_dir>/ogbn_mag/raw``; the official split, if present, from
            ``<data_dir>/ogbn_mag/split/time/paper``.

    Returns:
        dict with keys:
          - 'paper_features': float32 tensor, one row per paper.
          - 'paper_labels': int64 tensor of per-paper class ids.
          - 'paper_author_edges': int64 tensor from author___writes___paper.
            Despite the key name, row 0 holds author ids and row 1 paper ids.
          - 'paper_field_edges': int64 tensor; row 0 paper ids, row 1 field ids.
          - 'paper_cite_edges': int64 tensor from paper___cites___paper.
          - 'train_idx' / 'val_idx' / 'test_idx': paper index tensors.
          - 'num_papers', 'num_authors', 'num_fields', 'num_classes': ints.
    """
    print("📥 Loading OGBN-MAG from raw files...")

    raw_dir = ensure_ogbn_data_exists(data_dir)

    print("  Loading paper features...")
    # Request float32 from the parser rather than casting a float64 array
    # afterwards (which briefly held both copies), and force row-major layout:
    # pandas may hand back a column-major array, which would make per-node
    # row gathers during neighbour sampling strided.
    paper_features = np.ascontiguousarray(_read_csv_gz(
        os.path.join(raw_dir, 'node-feat', 'paper', 'node-feat.csv.gz'), np.float32))

    print("  Loading paper labels...")
    paper_labels = _read_csv_gz(
        os.path.join(raw_dir, 'node-label', 'paper', 'node-label.csv.gz'), np.int64).ravel()

    print("  Loading citation edges...")
    cite_edges = _read_edge_csv(raw_dir, 'paper___cites___paper')

    print("  Loading author-paper edges...")
    # Row 0 = author ids, row 1 = paper ids.
    author_paper_edges = _read_edge_csv(raw_dir, 'author___writes___paper')

    print("  Loading field-paper edges...")
    # Row 0 = paper ids, row 1 = field_of_study ids.
    field_paper_edges = _read_edge_csv(raw_dir, 'paper___has_topic___field_of_study')

    # Author/field counts are inferred from the largest id seen in the edge
    # lists, so a node whose id exceeds every edge endpoint would be missed.
    num_papers = len(paper_features)
    num_authors = int(author_paper_edges[0].max()) + 1
    num_fields = int(field_paper_edges[1].max()) + 1
    num_classes = int(paper_labels.max()) + 1

    # Prefer the official time-based split; otherwise fall back to a seeded split.
    split_dir = os.path.join(data_dir, 'ogbn_mag', 'split', 'time')
    if os.path.exists(split_dir):
        print("  Loading official splits...")
        train_idx, val_idx, test_idx = (
            _read_csv_gz(os.path.join(split_dir, 'paper', f'{name}.csv.gz'), np.int64).ravel()
            for name in ('train', 'valid', 'test')
        )
    else:
        print("  Creating random splits...")
        # Deterministic (seed 42) 80/10/remainder split of paper ids.
        indices = np.random.RandomState(42).permutation(num_papers)
        train_end = int(0.8 * num_papers)
        val_end = train_end + int(0.1 * num_papers)
        train_idx = indices[:train_end]
        val_idx = indices[train_end:val_end]
        test_idx = indices[val_end:]

    print(f"✅ Loaded: {num_papers} papers, {num_authors} authors, {num_fields} fields")
    print(f"   Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

    return {
        'paper_features': torch.from_numpy(paper_features),
        'paper_labels': torch.from_numpy(paper_labels),
        'paper_author_edges': torch.from_numpy(author_paper_edges),
        'paper_field_edges': torch.from_numpy(field_paper_edges),
        'paper_cite_edges': torch.from_numpy(cite_edges),
        'train_idx': torch.from_numpy(train_idx),
        'val_idx': torch.from_numpy(val_idx),
        'test_idx': torch.from_numpy(test_idx),
        'num_papers': num_papers,
        'num_authors': num_authors,
        'num_fields': num_fields,
        'num_classes': num_classes
    }

# --- 2. Create HeteroData the PyG Way ---
def create_pyg_hetero_data(data_dict):
    """Assemble a PyG ``HeteroData`` graph from the dict returned by ``load_ogbn_simple``.

    Node types: ``paper`` (loaded features and labels), ``author`` (random
    128-d features) and ``field_of_study`` (random 64-d features). Author and
    field relations are added in both directions; citations are added as-is.
    Boolean train/val/test masks are attached to papers, and ``num_classes``
    and ``train_idx`` are stored as graph-level attributes.
    """
    print("🔗 Setting up PyG-compatible heterogeneous data...")

    graph = HeteroData()

    # Node stores. The loader provides no author/field features, so random
    # vectors stand in for them (drawn author first, then field).
    paper_store = graph['paper']
    paper_store.x = data_dict['paper_features']
    paper_store.y = data_dict['paper_labels']
    graph['author'].x = torch.randn(data_dict['num_authors'], 128)
    graph['field_of_study'].x = torch.randn(data_dict['num_fields'], 64)

    # Edge stores, inserted in a fixed order (this order is what
    # graph.metadata() reports to the model).
    writes = data_dict['paper_author_edges']     # row 0 author, row 1 paper
    has_topic = data_dict['paper_field_edges']   # row 0 paper, row 1 field
    relations = (
        (('author', 'writes', 'paper'), writes),
        (('paper', 'written_by', 'author'), writes.flip(0)),
        (('paper', 'has_topic', 'field_of_study'), has_topic),
        (('field_of_study', 'topic_of', 'paper'), has_topic.flip(0)),
        (('paper', 'cites', 'paper'), data_dict['paper_cite_edges']),
    )
    for edge_type, edge_index in relations:
        graph[edge_type].edge_index = edge_index.contiguous()

    # One boolean mask per split over all papers.
    paper_count = data_dict['num_papers']
    for split in ('train', 'val', 'test'):
        mask = torch.zeros(paper_count, dtype=torch.bool)
        mask[data_dict[f'{split}_idx']] = True
        setattr(paper_store, f'{split}_mask', mask)

    # Graph-level extras read later by the training script.
    # NOTE(review): graph-level tensors such as train_idx may be carried into
    # every sampled batch and moved to the device with it — confirm.
    graph.num_classes = data_dict['num_classes']
    graph.train_idx = data_dict['train_idx']

    print(f"✅ PyG hetero data ready!")
    print(f"   Node types: {graph.node_types}")
    print(f"   Edge types: {graph.edge_types}")

    return graph

# --- 3. Simple Batch Manager (Keep This!) ---
class SimpleBatchManager:
    """Adjust a mini-batch size from a rolling window of memory readings.

    Each ``update`` call records one memory sample. Once three or more samples
    exist, the mean of the latest three is compared with ``target_memory``:
    above 85% of it the batch shrinks by 20% (never below 32); below 60% it
    grows by 10%.
    """

    def __init__(self, initial_size=256, target_memory_gb=6.0):
        # Current batch size; callers read this when building loaders.
        self.batch_size = initial_size
        # Memory budget in GB; the thresholds are fractions of it.
        self.target_memory = target_memory_gb
        # Every reading passed to update(), oldest first.
        self.memory_history = []

    def update(self, current_memory_gb):
        """Record ``current_memory_gb`` and maybe resize; return True iff the size changed."""
        self.memory_history.append(current_memory_gb)
        if len(self.memory_history) < 3:
            return False  # not enough samples to average yet

        recent_mean = sum(self.memory_history[-3:]) / 3
        previous = self.batch_size

        high_water = self.target_memory * 0.85
        low_water = self.target_memory * 0.6
        if recent_mean > high_water:
            new_size = max(32, int(previous * 0.8))
        elif recent_mean < low_water:
            new_size = int(previous * 1.1)
        else:
            new_size = previous
        self.batch_size = new_size

        changed = new_size != previous
        if changed:
            print(f"🔄 Batch size: {previous} → {new_size} (memory: {recent_mean:.1f}GB)")
        return changed

def get_memory_usage():
    """Return this process's resident set size (RSS) in GiB, as a float."""
    bytes_per_gib = 1024 ** 3
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_gib

# --- 4. Load Data ---
print("\n📂 Loading OGBN-MAG data...")
data_dict = load_ogbn_simple('./data')
data = create_pyg_hetero_data(data_dict)

# Seed papers for the NeighborLoader (stored on the graph by create_pyg_hetero_data).
train_idx = data.train_idx
print(f"🎯 Training on {len(train_idx)} papers")

# --- 5. Simple Model ---
print("\n🧠 Creating model...")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"   Device: {device}")

# Model hyper-parameters.
hidden_dim, heads = 128, 4
num_classes = data.num_classes

# HGTConv needs its output width to split evenly across attention heads, so
# pad the class count up to the next multiple of `heads`. The training step
# slices the logits back to data.num_classes before computing the loss.
if num_classes % heads:
    adjusted_classes = -(-num_classes // heads) * heads  # ceiling to a multiple of heads
    print(f"   Adjusting classes: {num_classes} → {adjusted_classes}")
    num_classes = adjusted_classes

class SimpleHGT(torch.nn.Module):
    """Two-layer HGT encoder over a heterogeneous graph.

    Layer 1 maps every node type to ``hidden_dim`` (input widths inferred
    lazily via ``-1``), followed by ReLU and dropout (p=0.3). Layer 2 maps to
    ``out_dim``. ``forward`` returns a dict of per-node-type outputs.
    """

    def __init__(self, hidden_dim, out_dim, metadata, heads=4):
        super().__init__()
        # -1: per-node-type input widths are resolved on the first forward call.
        self.conv1 = HGTConv(-1, hidden_dim, metadata, heads=heads)
        self.conv2 = HGTConv(hidden_dim, out_dim, metadata, heads=heads)
        self.dropout = torch.nn.Dropout(0.3)

    def forward(self, x_dict, edge_index_dict):
        hidden = self.conv1(x_dict, edge_index_dict)
        # ReLU followed by dropout, applied node type by node type.
        activated = {}
        for node_type, features in hidden.items():
            activated[node_type] = self.dropout(F.relu(features))
        return self.conv2(activated, edge_index_dict)

model = SimpleHGT(hidden_dim, num_classes, data.metadata(), heads=heads).to(device)

# conv1 was built with in_channels=-1, so its input projections hold no real
# weights until they see a batch. One tiny forward pass materialises them.
print("   Initializing model parameters...")
with torch.no_grad():
    # 10 dummy nodes per type. Feature widths come from the graph itself, so
    # this stays in sync with create_pyg_hetero_data for every node type
    # instead of duplicating its hard-coded sizes.
    dummy_x_dict = {
        node_type: torch.randn(10, data[node_type].x.size(1)).to(device)
        for node_type in data.node_types
    }
    # 20 random edges among the dummy nodes for every relation.
    dummy_edge_dict = {
        edge_type: torch.randint(0, 10, (2, 20)).to(device)
        for edge_type in data.edge_types
    }
    _ = model(dummy_x_dict, dummy_edge_dict)

print("✅ Model parameters initialized!")

# Build the optimizer only after the dry run so it receives the materialised
# parameters (and so numel() below works on every parameter).
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

# --- 6. Fixed Training Function ---
batch_manager = SimpleBatchManager(initial_size=128, target_memory_gb=6.0)

def train_epoch(max_batches=5):
    """Run one (optionally truncated) training epoch over the training papers.

    A new NeighborLoader is built on every call, so it uses the current
    ``batch_manager.batch_size``.

    Args:
        max_batches: maximum number of mini-batches to train on (default 5,
            a debugging limit). ``None`` iterates the whole loader.

    Returns:
        ``(avg_loss, changed)``. ``avg_loss`` is the example-weighted mean loss
        over the trained batches, or 0.0 if none ran. ``changed`` is always
        False. NOTE(review): looks like a placeholder for a
        "batch size changed" flag.

    Raises:
        Re-raises any exception from loader construction or a training step,
        after printing its traceback.
    """
    from itertools import islice

    model.train()
    total_loss = total_examples = 0
    # The model may emit more logits than there are classes (padded up to a
    # multiple of `heads`); only the first `real_classes` columns are scored.
    real_classes = data.num_classes

    print("   Creating NeighborLoader...")
    try:
        loader = NeighborLoader(
            data,
            num_neighbors=[15, 10],
            batch_size=batch_manager.batch_size,
            input_nodes=('paper', train_idx),
            shuffle=True,
            num_workers=0
        )
        print("   ✅ NeighborLoader created successfully")
    except Exception as e:
        print(f"   ❌ Failed to create NeighborLoader: {e}")
        traceback.print_exc()
        raise

    batch_count = 0
    print("   Starting batch iteration...")

    # islice stops *before* requesting batch max_batches+1. No extra
    # neighbourhood is sampled and thrown away, and batch_count equals the
    # number of batches actually trained.
    for batch_count, batch in enumerate(islice(loader, max_batches), start=1):
        try:
            batch = batch.to(device)
            optimizer.zero_grad()

            out_dict = model(batch.x_dict, batch.edge_index_dict)

            # The seed papers come first in the batch and their count is
            # recorded as batch_size. The remaining papers are sampled
            # neighbours.
            if hasattr(batch['paper'], 'batch_size'):
                batch_size = batch['paper'].batch_size
            else:
                # NOTE(review): this fallback counts *every* sampled paper, so
                # neighbour papers (possibly val/test) would be trained on too.
                # Presumably unreachable with NeighborLoader — confirm.
                batch_size = batch['paper'].y.size(0)

            paper_out = out_dict['paper'][:batch_size, :real_classes]
            paper_labels = batch['paper'].y[:batch_size]
            loss = F.cross_entropy(paper_out, paper_labels)

            loss.backward()
            optimizer.step()

            total_loss += float(loss) * batch_size
            total_examples += batch_size

            if batch_count == 1:
                print(f"     First batch: Loss={loss:.4f}, Batch size={batch_size}")

            # Periodically release cached memory on longer runs.
            if batch_count % 10 == 0:
                gc.collect()
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

        except Exception as e:
            print(f"   Error in batch {batch_count}: {e}")
            traceback.print_exc()
            raise

    avg_loss = total_loss / max(1, total_examples)
    print(f"   Processed {batch_count} batches, {total_examples} examples, avg loss: {avg_loss:.4f}")
    return avg_loss, False

# --- 7. Training Loop ---
print("\n🏃 Training...")

# Set when an epoch raises, so the closing messages don't claim success.
training_failed = False
for epoch in range(1, 4):
    print(f"\n=== Epoch {epoch} ===")

    try:
        loss, _ = train_epoch()
        memory_usage = get_memory_usage()
        print(f"✅ Epoch {epoch}: Loss={loss:.4f}, Memory={memory_usage:.1f}GB")
        # Feed the adaptive batch sizer (previously never called). A new size
        # takes effect in the next train_epoch(), which builds a fresh loader.
        batch_manager.update(memory_usage)
    except Exception as e:
        print(f"❌ Error in epoch {epoch}: {e}")
        training_failed = True
        break

if training_failed:
    print("\n⚠️ Training stopped early because of an error.")
else:
    print("\n🎉 Training complete!")
print(f"Final memory: {get_memory_usage():.1f}GB")

# Save the model even if training stopped early, so a partial run is kept.
# `num_classes` is the (possibly padded) output width the model was built
# with; `original_num_classes` is the number of real label classes, needed to
# slice logits at inference time.
torch.save({
    'model_state_dict': model.state_dict(),
    'batch_size': batch_manager.batch_size,
    'hidden_dim': hidden_dim,
    'num_classes': num_classes,
    'original_num_classes': data.num_classes,
    'heads': heads
}, 'pragmatic_hgt_model.pt')

print("💾 Model saved as 'pragmatic_hgt_model.pt'")
if not training_failed:
    print("✅ Minibatch training working!")