# In [None]:
# OGBN-MAG Minibatch Training - Neighbor-Sampled Subgraphs
# The full graph is loaded into host memory once; each training step then runs on a small sampled subgraph.

import torch
import torch.nn.functional as F
from torch_geometric.nn import HGTConv
from torch_geometric.data import HeteroData
import os
import numpy as np
import pandas as pd
import gzip
import psutil
import gc
import warnings

# Suppress every warning category for the rest of the run, then announce startup.
warnings.simplefilter('ignore')
print("🚀 Starting memory-efficient OGBN-MAG minibatch training...")

# --- 1. Data Loading (unchanged) ---
def ensure_ogbn_data_exists(data_dir):
    """Return the OGBN-MAG raw-data directory under ``data_dir``.

    The directory is ``<data_dir>/ogbn_mag/raw``. When it is missing, a
    ``PygNodePropPredDataset('ogbn-mag')`` rooted at ``data_dir`` is
    instantiated (which is relied on to fetch the files) and immediately
    discarded. The returned path is not re-checked after the download.
    """
    raw_path = os.path.join(data_dir, 'ogbn_mag', 'raw')

    # Fast path: the raw files are already on disk.
    if os.path.exists(raw_path):
        return raw_path

    print("📦 Downloading OGBN-MAG (this may take a while)...")
    # Imported lazily so `ogb` is only required when a download is needed.
    from ogb.nodeproppred import PygNodePropPredDataset

    # The dataset object itself is not used; drop it and reclaim its memory.
    downloaded = PygNodePropPredDataset('ogbn-mag', root=data_dir)
    print("✅ Download complete!")
    del downloaded
    gc.collect()

    return raw_path

def _read_gz_csv(path):
    """Read a headerless gzip-compressed CSV file and return its values as a NumPy array."""
    with gzip.open(path, 'rt') as f:
        return pd.read_csv(f, header=None).values


def load_ogbn_simple(data_dir):
    """Load OGBN-MAG directly from its raw CSV files.

    Args:
        data_dir: Root directory; raw files are expected under
            ``<data_dir>/ogbn_mag/raw`` (downloaded on demand by
            ``ensure_ogbn_data_exists``).

    Returns:
        dict with:
            'paper_features'      float32 tensor, one row per paper
            'paper_labels'        int64 tensor, one label per paper
            'paper_author_edges'  int64 tensor (2, E) read from the
                                  ``author___writes___paper`` relation
                                  (row 0 = author ids, row 1 = paper ids)
            'paper_field_edges'   int64 tensor (2, E) read from
                                  ``paper___has_topic___field_of_study``
                                  (row 0 = paper ids, row 1 = field ids)
            'paper_cite_edges'    int64 tensor (2, E) read from
                                  ``paper___cites___paper``
            'train_idx', 'val_idx', 'test_idx'  int64 index tensors
            'num_papers', 'num_authors', 'num_fields', 'num_classes'  ints

    Splits are read from ``<data_dir>/ogbn_mag/split/time/paper`` when that
    tree exists; otherwise a seeded (42) random 80/10/10 split is created.
    """
    print("📥 Loading OGBN-MAG from raw files...")

    # Ensure data exists
    raw_dir = ensure_ogbn_data_exists(data_dir)
    relations_dir = os.path.join(raw_dir, 'relations')

    # Load paper features
    feat_file = os.path.join(raw_dir, 'node-feat', 'paper', 'node-feat.csv.gz')
    print("  Loading paper features...")
    paper_features = _read_gz_csv(feat_file).astype(np.float32)

    # Load paper labels
    label_file = os.path.join(raw_dir, 'node-label', 'paper', 'node-label.csv.gz')
    print("  Loading paper labels...")
    paper_labels = _read_gz_csv(label_file).flatten().astype(np.int64)

    # Load citation edges (transposed so each array is shaped (2, num_edges))
    cite_file = os.path.join(relations_dir, 'paper___cites___paper', 'edge.csv.gz')
    print("  Loading citation edges...")
    cite_edges = _read_gz_csv(cite_file).T.astype(np.int64)

    # Load author-paper edges (for heterogeneous graph).
    # This must read author_file; reading the citation file here would make
    # the "authors" a second copy of the paper-paper graph.
    author_file = os.path.join(relations_dir, 'author___writes___paper', 'edge.csv.gz')
    print("  Loading author-paper edges...")
    author_paper_edges = _read_gz_csv(author_file).T.astype(np.int64)

    # Load field-paper edges
    field_file = os.path.join(relations_dir, 'paper___has_topic___field_of_study', 'edge.csv.gz')
    print("  Loading field-paper edges...")
    field_paper_edges = _read_gz_csv(field_file).T.astype(np.int64)

    # Calculate node counts (plain ints so they can be used as tensor sizes
    # and serialized without NumPy scalar types leaking out)
    num_papers = len(paper_features)
    num_authors = int(author_paper_edges[0].max()) + 1
    num_fields = int(field_paper_edges[1].max()) + 1
    num_classes = int(paper_labels.max()) + 1

    # Create train/val/test splits (use official split if available)
    split_dir = os.path.join(data_dir, 'ogbn_mag', 'split', 'time')
    if os.path.exists(split_dir):
        print("  Loading official splits...")
        paper_split_dir = os.path.join(split_dir, 'paper')
        train_idx = _read_gz_csv(os.path.join(paper_split_dir, 'train.csv.gz')).flatten()
        val_idx = _read_gz_csv(os.path.join(paper_split_dir, 'valid.csv.gz')).flatten()
        test_idx = _read_gz_csv(os.path.join(paper_split_dir, 'test.csv.gz')).flatten()
    else:
        print("  Creating random splits...")
        indices = np.random.RandomState(42).permutation(num_papers)
        train_size = int(0.8 * num_papers)
        val_size = int(0.1 * num_papers)
        train_idx = indices[:train_size]
        val_idx = indices[train_size:train_size + val_size]
        test_idx = indices[train_size + val_size:]

    # Index tensors are always int64, regardless of the platform's default
    # NumPy integer width.
    train_idx = train_idx.astype(np.int64)
    val_idx = val_idx.astype(np.int64)
    test_idx = test_idx.astype(np.int64)

    print(f"✅ Loaded: {num_papers} papers, {num_authors} authors, {num_fields} fields")
    print(f"   Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

    return {
        'paper_features': torch.from_numpy(paper_features),
        'paper_labels': torch.from_numpy(paper_labels),
        'paper_author_edges': torch.from_numpy(author_paper_edges),
        'paper_field_edges': torch.from_numpy(field_paper_edges),
        'paper_cite_edges': torch.from_numpy(cite_edges),
        'train_idx': torch.from_numpy(train_idx),
        'val_idx': torch.from_numpy(val_idx),
        'test_idx': torch.from_numpy(test_idx),
        'num_papers': num_papers,
        'num_authors': num_authors,
        'num_fields': num_fields,
        'num_classes': num_classes
    }

# --- 2. Create HeteroData (unchanged) ---
def create_pyg_hetero_data(data_dict):
    """Assemble a ``HeteroData`` graph from the dictionary of loaded tensors.

    Reads 'paper_features', 'paper_labels', 'paper_author_edges',
    'paper_field_edges', 'paper_cite_edges', 'num_authors', 'num_fields',
    'num_classes' and 'train_idx' from ``data_dict``. Authors and fields of
    study get random normal features (128 and 64 columns respectively). Each
    author/field relation is stored in both directions; citations are stored
    as given. ``num_classes`` and ``train_idx`` are attached to the returned
    object as extra attributes.
    """
    print("🔗 Setting up PyG-compatible heterogeneous data...")

    hetero = HeteroData()

    # Node features: real ones for papers, random placeholders elsewhere.
    hetero['paper'].x = data_dict['paper_features']
    hetero['paper'].y = data_dict['paper_labels']
    hetero['author'].x = torch.randn(data_dict['num_authors'], 128)
    hetero['field_of_study'].x = torch.randn(data_dict['num_fields'], 64)

    # Relations, in the order they should appear in `edge_types`.
    # flip(0) swaps the source/destination rows to build the reverse relation.
    writes = data_dict['paper_author_edges']
    has_topic = data_dict['paper_field_edges']
    cites = data_dict['paper_cite_edges']
    relations = [
        (('author', 'writes', 'paper'), writes),
        (('paper', 'written_by', 'author'), writes.flip(0)),
        (('paper', 'has_topic', 'field_of_study'), has_topic),
        (('field_of_study', 'topic_of', 'paper'), has_topic.flip(0)),
        (('paper', 'cites', 'paper'), cites),
    ]
    for edge_type, edge_index in relations:
        hetero[edge_type].edge_index = edge_index.contiguous()

    # Extra bookkeeping used by the training script.
    hetero.num_classes = data_dict['num_classes']
    hetero.train_idx = data_dict['train_idx']

    print("✅ PyG hetero data ready!")
    print(f"   Node types: {hetero.node_types}")
    print(f"   Edge types: {hetero.edge_types}")

    return hetero

# --- 3. Memory-Efficient Batch Sampler ---
class MemoryEfficientSampler:
    """Build neighbor-sampled minibatch subgraphs from a full ``HeteroData`` graph.

    Every batch contains the target papers, papers sampled along incoming
    ('paper', 'cites', 'paper') edges (one hop per entry of ``num_neighbors``),
    and a capped random subset of the authors and fields of study attached to
    those papers. Only edges whose two endpoints were both sampled are kept,
    renumbered to batch-local indices.

    Args:
        data: Full graph. Must expose ``x``/``y`` for 'paper', ``x`` for
            'author' and 'field_of_study', an ``edge_index`` for every entry
            of ``data.edge_types``, and the ('author', 'writes', 'paper') and
            ('paper', 'has_topic', 'field_of_study') relations.
        batch_size: Number of target papers per batch.
        num_neighbors: Per-hop cap on sampled citing papers per node.
            Defaults to ``[15, 10]``.
        max_authors: Cap on author nodes per batch.
        max_fields: Cap on field-of-study nodes per batch.
    """

    def __init__(self, data, batch_size=128, num_neighbors=None,
                 max_authors=1000, max_fields=200):
        self.data = data
        self.batch_size = batch_size
        # Copy so instances never share (or mutate) the caller's list.
        self.num_neighbors = [15, 10] if num_neighbors is None else list(num_neighbors)
        self.max_authors = max_authors
        self.max_fields = max_fields
        print(f"   Created memory-efficient sampler (batch_size={batch_size})")

    def sample_neighbors_tensor(self, node_ids, edge_index, num_samples, node_type='paper'):
        """Sample up to ``num_samples`` in-neighbors for each node in ``node_ids``.

        Args:
            node_ids: List or tensor of destination node ids.
            edge_index: (2, E) tensor; row 0 = source ids, row 1 = destination ids.
            num_samples: Maximum number of neighbors kept per destination node.
            node_type: Unused; kept for interface compatibility.

        Returns:
            Sorted 1-D long tensor with the distinct sampled source ids
            (empty when there are no input nodes or no matching edges).
        """
        node_ids = torch.as_tensor(node_ids, dtype=torch.long)
        if node_ids.numel() == 0:
            return torch.empty(0, dtype=torch.long)

        # Restrict to edges that end in one of the requested nodes.
        hits = torch.isin(edge_index[1], node_ids)
        if not hits.any():
            return torch.empty(0, dtype=torch.long)

        sources = edge_index[0, hits]
        destinations = edge_index[1, hits]

        # Only destinations that actually have incoming edges need a pass.
        sampled = []
        for dst in torch.unique(destinations):
            neighbors = sources[destinations == dst]
            if len(neighbors) > num_samples:
                pick = torch.randperm(len(neighbors))[:num_samples]
                neighbors = neighbors[pick]
            sampled.append(neighbors)

        return torch.unique(torch.cat(sampled))

    @staticmethod
    def _cap_nodes(node_ids, limit):
        """Randomly keep at most ``limit`` ids from a sorted id tensor.

        The result is re-sorted so it remains usable for binary-search
        relabelling in ``_relabel_edges``.
        """
        if len(node_ids) > limit:
            keep = torch.randperm(len(node_ids))[:limit]
            node_ids = node_ids[keep].sort().values
        return node_ids

    @staticmethod
    def _relabel_edges(edge_index, src_ids, dst_ids):
        """Keep edges with both endpoints sampled and renumber them locally.

        ``src_ids`` and ``dst_ids`` must be sorted and duplicate-free: the
        position of a global id inside them is then its batch-local index,
        which ``torch.searchsorted`` finds without any Python-level lookup.
        """
        if src_ids.numel() == 0 or dst_ids.numel() == 0:
            return torch.empty(2, 0, dtype=torch.long)

        keep = torch.isin(edge_index[0], src_ids) & torch.isin(edge_index[1], dst_ids)
        kept = edge_index[:, keep]
        if kept.shape[1] == 0:
            return torch.empty(2, 0, dtype=torch.long)

        return torch.stack([
            torch.searchsorted(src_ids, kept[0].contiguous()),
            torch.searchsorted(dst_ids, kept[1].contiguous()),
        ])

    def create_minibatch(self, target_nodes):
        """Create the sampled subgraph for one batch of target papers.

        Args:
            target_nodes: List or tensor of global paper ids to predict on.

        Returns:
            ``HeteroData`` with batch-local node features, labels for papers,
            one ``edge_index`` per edge type of the full graph, and
            ``batch['paper'].target_mask`` marking the rows that correspond
            to ``target_nodes``.
        """
        target_ids = torch.as_tensor(target_nodes, dtype=torch.long)

        # Expand along citation edges, one hop per entry of num_neighbors.
        cite_edges = self.data['paper', 'cites', 'paper'].edge_index
        frontier = target_ids
        paper_layers = [target_ids]
        for num_samples in self.num_neighbors:
            neighbors = self.sample_neighbors_tensor(frontier, cite_edges, num_samples)
            if len(neighbors) > 0:
                paper_layers.append(neighbors)
                frontier = neighbors

        # sorted=True: the id tensors double as lookup tables for relabelling.
        paper_ids = torch.unique(torch.cat(paper_layers), sorted=True)

        # Authors / fields attached to any sampled paper, capped to bound
        # the size of the batch.
        writes = self.data['author', 'writes', 'paper'].edge_index
        has_topic = self.data['paper', 'has_topic', 'field_of_study'].edge_index

        author_ids = torch.unique(writes[0, torch.isin(writes[1], paper_ids)], sorted=True)
        author_ids = self._cap_nodes(author_ids, self.max_authors)

        field_ids = torch.unique(has_topic[1, torch.isin(has_topic[0], paper_ids)], sorted=True)
        field_ids = self._cap_nodes(field_ids, self.max_fields)

        batch = HeteroData()

        # Indexing with an empty id tensor yields a (0, feature_dim) tensor,
        # so no special case is needed when a node type has no members.
        batch['paper'].x = self.data['paper'].x[paper_ids]
        batch['paper'].y = self.data['paper'].y[paper_ids]
        batch['author'].x = self.data['author'].x[author_ids]
        batch['field_of_study'].x = self.data['field_of_study'].x[field_ids]

        # Every edge type of the full graph gets an entry (possibly empty).
        ids_by_type = {
            'paper': paper_ids,
            'author': author_ids,
            'field_of_study': field_ids,
        }
        for edge_type in self.data.edge_types:
            src_type, _, dst_type = edge_type
            batch[edge_type].edge_index = self._relabel_edges(
                self.data[edge_type].edge_index,
                ids_by_type[src_type],
                ids_by_type[dst_type],
            )

        # Mark which sampled papers are the prediction targets.
        batch['paper'].target_mask = torch.isin(paper_ids, target_ids)

        return batch

    def get_batches(self, indices, shuffle=True):
        """Yield one sampled subgraph per ``batch_size`` slice of ``indices``.

        Args:
            indices: 1-D tensor (or array-like) of target paper ids.
            shuffle: When true, visit the ids in a fresh random order.
        """
        indices = torch.as_tensor(indices)
        if shuffle:
            indices = indices[torch.randperm(len(indices))]

        for start in range(0, len(indices), self.batch_size):
            yield self.create_minibatch(indices[start:start + self.batch_size])

# --- 4. Load Data ---
print("\n📂 Loading OGBN-MAG data...")
data_dict = load_ogbn_simple('./data')
data = create_pyg_hetero_data(data_dict)

# Paper ids that act as prediction targets during training
train_idx = data.train_idx
print(f"🎯 Training on {len(train_idx)} papers")

# --- 5. Model (unchanged) ---
print("\n🧠 Creating model...")
# Prefer the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"   Device: {device}")

hidden_dim = 128
heads = 4
num_classes = data.num_classes

# Round the output width up to the next multiple of `heads`
# (presumably because HGTConv splits its output channels evenly across
# heads - TODO confirm). The training loop slices the logits back down to
# data.num_classes before computing the loss.
padding = -num_classes % heads
if padding:
    adjusted_classes = num_classes + padding
    print(f"   Adjusting classes: {num_classes} → {adjusted_classes}")
    num_classes = adjusted_classes

class SimpleHGT(torch.nn.Module):
    """Two stacked ``HGTConv`` layers with ReLU and dropout (p=0.3) in between.

    Args:
        hidden_dim: Output width of the first layer / input width of the second.
        out_dim: Output width of the second layer.
        metadata: Graph metadata passed through to both ``HGTConv`` layers.
        heads: Number of attention heads used by both layers.
    """

    def __init__(self, hidden_dim, out_dim, metadata, heads=4):
        super().__init__()
        # -1 presumably requests lazy input-size inference, so the first
        # forward pass fixes the per-node-type input widths - TODO confirm.
        self.conv1 = HGTConv(-1, hidden_dim, metadata, heads=heads)
        self.conv2 = HGTConv(hidden_dim, out_dim, metadata, heads=heads)
        self.dropout = torch.nn.Dropout(0.3)

    def forward(self, x_dict, edge_index_dict):
        """Map per-node-type features to per-node-type outputs of width ``out_dim``."""
        hidden = self.conv1(x_dict, edge_index_dict)
        # Non-linearity followed by dropout, applied to every node type.
        hidden = {
            node_type: self.dropout(F.relu(features))
            for node_type, features in hidden.items()
        }
        return self.conv2(hidden, edge_index_dict)

model = SimpleHGT(hidden_dim, num_classes, data.metadata(), heads=heads).to(device)

# Initialize lazy parameters
# A throw-away forward pass on a tiny random graph materializes the model's
# lazily-sized weights before the optimizer collects model.parameters().
# Feature widths are read from `data` so they cannot drift from the real graph.
print("   Initializing model parameters...")
with torch.no_grad():
    dummy_nodes = 10
    dummy_x_dict = {
        node_type: torch.randn(dummy_nodes, data[node_type].x.shape[1]).to(device)
        for node_type in data.node_types
    }
    # Random edges only need valid endpoints, i.e. ids in [0, dummy_nodes).
    dummy_edge_dict = {
        edge_type: torch.randint(0, dummy_nodes, (2, 20)).to(device)
        for edge_type in data.edge_types
    }
    _ = model(dummy_x_dict, dummy_edge_dict)

print("✅ Model parameters initialized!")

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

# --- 6. Training with Memory-Efficient Sampling ---
print("\n🏃 Training with memory-efficient sampling...")

# Create memory-efficient sampler
sampler = MemoryEfficientSampler(data, batch_size=128, num_neighbors=[15, 10])

def get_memory_usage():
    """Return the resident set size (RSS) of the current process in GiB."""
    bytes_per_gib = 1024 ** 3
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_gib

print(f"Initial memory usage: {get_memory_usage():.2f} GB")

# Cap on minibatches per epoch so the demo finishes quickly.
max_batches_per_epoch = 50

for epoch in range(1, 4):
    print(f"\n=== Epoch {epoch} ===")
    model.train()
    total_loss = 0
    total_examples = 0
    batch_count = 0

    batches = sampler.get_batches(train_idx, shuffle=True)
    # range() is listed first so zip() stops *before* asking the generator
    # for a batch beyond the cap: no extra subgraph is sampled and discarded,
    # and batch_count ends at the number of batches actually processed.
    for batch_count, batch in zip(range(1, max_batches_per_epoch + 1), batches):
        try:
            # Move batch to device
            batch = batch.to(device)

            # Skip batches without target nodes before doing any model work
            target_mask = batch['paper'].target_mask
            if not target_mask.any():
                continue

            optimizer.zero_grad()

            # Forward pass
            out_dict = model(batch.x_dict, batch.edge_index_dict)

            # Score target papers only; the model's output width may have been
            # padded, so drop any columns beyond the real label classes.
            paper_out = out_dict['paper'][target_mask][:, :data.num_classes]
            paper_labels = batch['paper'].y[target_mask]

            # Calculate loss
            loss = F.cross_entropy(paper_out, paper_labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Track metrics (loss is a per-target mean, so weight it by count)
            num_targets = int(target_mask.sum())
            total_loss += loss.item() * num_targets
            total_examples += num_targets

            if batch_count % 10 == 0:
                avg_loss = total_loss / total_examples
                memory = get_memory_usage()
                print(f"   Batch {batch_count}: Loss={avg_loss:.4f}, Memory={memory:.2f}GB, "
                      f"Papers={len(batch['paper'].x)}, Authors={len(batch['author'].x)}, "
                      f"Fields={len(batch['field_of_study'].x)}")

        except Exception as e:
            # Deliberate best-effort: report the failing batch in full and
            # carry on with the next one instead of aborting the run.
            print(f"   Error in batch {batch_count}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Epoch summary
    if total_examples > 0:
        epoch_loss = total_loss / total_examples
        memory = get_memory_usage()
        print(f"✅ Epoch {epoch}: Loss={epoch_loss:.4f}, Memory={memory:.2f}GB, Batches={batch_count}")

print(f"\n🎉 Training complete!")
print(f"Final memory: {get_memory_usage():.2f}GB")

# Save model
# NOTE: 'num_classes' is the model's output width (possibly padded to a
# multiple of `heads`), i.e. the value needed to rebuild SimpleHGT.
torch.save({
    'model_state_dict': model.state_dict(),
    'hidden_dim': hidden_dim,
    'num_classes': num_classes,
    'heads': heads
}, 'memory_efficient_hgt_model.pt')

print("💾 Model saved!")
print("✅ Memory-efficient minibatch training successfully implemented!")