In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
import numpy as np

# --- 1. Imports from your custom modules ---
# Ensure these files are in your python path
from models import MST_GNN
from dataset import InMemoryDynamicSP100  # Or DynamicSP100Stocks if using the on-the-fly version

# --- 2. Hyperparameters & Setup ---
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

# Dataset settings: window lengths handed to the dataset, loader batch size,
# and the fraction of samples placed in the training split.
PAST_WINDOW, FUTURE_WINDOW = 25, 1
BATCH_SIZE = 32
TRAIN_SPLIT = 0.8

# Model architecture settings passed to MST_GNN.
INPUT_FEATURES = 9  # per-node feature count, e.g. Open, Close, RSI, MACD...
HIDDEN_SIZE = 64
GRAPH_LAYERS, CROSS_LAYERS = 2, 2

# Optimisation settings: Adam step size and number of training epochs.
LEARNING_RATE = 0.001
EPOCHS = 100

# --- 3. Data Preparation ---

# Define the graph builder function (if using the InMemory logic)
def correlation_graph_builder(window_data, threshold=0.6, feature_index=0):
    """Build a binary adjacency matrix from pairwise series correlations.

    Two nodes are connected when the Pearson correlation of their selected
    feature series over the window is strictly greater than ``threshold``.

    Args:
        window_data: Array-like of shape (Nodes, Time, Features).
        threshold: Correlation cut-off for creating an edge. Defaults to 0.6.
        feature_index: Position on the feature axis of the series to
            correlate. Defaults to 0, which this script assumes is the
            close price -- TODO confirm against the dataset's feature order.

    Returns:
        Integer ``ndarray`` of shape (Nodes, Nodes) holding 0/1 entries with
        a zero diagonal (no self-loops).
    """
    series = np.asarray(window_data)[:, :, feature_index]  # (Nodes, Time)

    # A flat (zero-variance) series has an undefined correlation and NumPy
    # produces NaN for it. NaN > threshold is False, so such nodes simply end
    # up isolated; the floating-point warning carries no extra information.
    with np.errstate(divide='ignore', invalid='ignore'):
        # With a single node, corrcoef collapses to a scalar, which
        # fill_diagonal would reject; force a (1, 1) matrix instead.
        corr = np.atleast_2d(np.corrcoef(series))

    adj = (corr > threshold).astype(int)
    # Drop self-loops here (per the original note, the model adds its own).
    np.fill_diagonal(adj, 0)
    return adj

# Build the dataset; each sample's graph comes from correlation_graph_builder.
print("Loading Data...")
dataset = InMemoryDynamicSP100(
    root="./data",
    adj_calculator=correlation_graph_builder,
    past_window=PAST_WINDOW,
    future_window=FUTURE_WINDOW,
)

# Train/test split: the leading TRAIN_SPLIT fraction of samples trains the
# model and the remainder is held out.
# NOTE(review): this is a split "by time" only if the dataset is indexed in
# chronological order -- confirm against the dataset implementation.
num_train = int(TRAIN_SPLIT * len(dataset))
train_dataset, test_dataset = dataset[:num_train], dataset[num_train:]

# Mini-batch loaders: training batches are reshuffled, evaluation order is fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# --- 4. Model Initialization ---
print("Initializing Model...")
model = MST_GNN(
    in_features=INPUT_FEATURES,
    hidden_size=HIDDEN_SIZE,
    num_graph_layers=GRAPH_LAYERS,
    num_cross_layers=CROSS_LAYERS,
)
# Move the parameters onto the selected device before the optimizer sees them.
model = model.to(DEVICE)

# Loss works on raw logits: it applies the sigmoid internally before the
# binary cross-entropy, so the model must not apply one itself.
criterion = nn.BCEWithLogitsLoss()
# Adam over every model parameter.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# --- 5. Training Functions ---

def train():
    """Run one optimisation pass over ``train_loader``.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``DEVICE``. Targets are binarised on the fly: a
    label is 1 where ``batch.y`` is strictly positive and 0 otherwise.

    Returns:
        tuple[float, float]: ``(avg_loss, accuracy)``, both averaged per
        label over the whole epoch. Both are NaN if the loader yields no
        labels, instead of raising ``ZeroDivisionError``.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total_samples = 0

    for batch in train_loader:
        batch = batch.to(DEVICE)
        optimizer.zero_grad()

        # Forward pass; the original notes expect logits of shape
        # (Batch_Size * Num_Nodes, 1) -- TODO confirm against MST_GNN.
        logits = model(batch.x, batch.edge_index)

        # Up/down targets, laid out exactly like the logits so that the loss
        # and the accuracy comparison are element-wise (no broadcasting).
        labels = (batch.y > 0).float().view_as(logits)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # criterion is expected to return the batch mean (the default for
        # nn.BCEWithLogitsLoss). Weight it by the batch's label count so a
        # smaller final batch does not skew the epoch average.
        batch_count = labels.numel()
        loss_sum += loss.item() * batch_count

        # A sigmoid probability above 0.5 counts as an "up" prediction.
        preds = (torch.sigmoid(logits) > 0.5).float()
        correct += (preds == labels).sum().item()
        total_samples += batch_count

    if total_samples == 0:
        # Nothing was processed; report "undefined" rather than crash.
        return float("nan"), float("nan")

    avg_loss = loss_sum / total_samples
    accuracy = correct / total_samples
    return avg_loss, accuracy

def test():
    """Evaluate the model on ``test_loader`` without updating any weights.

    Relies on the module-level ``model``, ``criterion``, ``test_loader`` and
    ``DEVICE``. Targets are binarised on the fly: a label is 1 where
    ``batch.y`` is strictly positive and 0 otherwise.

    Returns:
        tuple[float, float]: ``(avg_loss, accuracy)``, both averaged per
        label over the whole loader. Both are NaN if the loader yields no
        labels, instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total_samples = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(DEVICE)

            logits = model(batch.x, batch.edge_index)
            # Lay the targets out exactly like the logits so the loss and the
            # accuracy comparison are element-wise (no broadcasting).
            labels = (batch.y > 0).float().view_as(logits)

            # criterion is expected to return the batch mean (the default
            # for nn.BCEWithLogitsLoss). Weight it by the batch's label count
            # so a smaller final batch does not skew the overall average.
            batch_count = labels.numel()
            loss_sum += criterion(logits, labels).item() * batch_count

            # A sigmoid probability above 0.5 counts as an "up" prediction.
            preds = (torch.sigmoid(logits) > 0.5).float()
            correct += (preds == labels).sum().item()
            total_samples += batch_count

    if total_samples == 0:
        # Nothing was evaluated; report "undefined" rather than crash.
        return float("nan"), float("nan")

    avg_loss = loss_sum / total_samples
    accuracy = correct / total_samples
    return avg_loss, accuracy

# --- 6. Main Execution Loop ---

print("Starting Training...")
best_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    # One optimisation pass, then an evaluation pass (evaluated left to right).
    (train_loss, train_acc), (test_loss, test_acc) = train(), test()

    print(f'Epoch {epoch:03d}: '
          f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | '
          f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

    # Only a strict improvement in test accuracy triggers a checkpoint.
    if not test_acc > best_acc:
        continue

    # Checkpoint the weights of the best-scoring epoch so far.
    best_acc = test_acc
    torch.save(model.state_dict(), 'best_mst_gnn_model.pth')
    print(f"  >>> New Best Model Saved! (Acc: {best_acc:.4f})")

print("Training Complete.")