In [7]:
import torch
from torch_geometric.loader import DataLoader
from torch.optim import Adam
from sklearn.metrics import accuracy_score
from models import GiG  
import matplotlib.pyplot as plt
import pickle
%matplotlib inline

# Seed torch's RNGs (torch.manual_seed covers CPU and all CUDA devices) so that
# weight initialisation and DataLoader shuffling are reproducible across
# Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Hyper-parameters passed verbatim to GiG(config).
# NOTE(review): key semantics are defined by models.GiG, which is not visible
# here — the notes below are assumptions to confirm against that module.
config = {
    # Node-level (per-subgraph) encoder.
    "node_level_module": "GIN",
    "num_node_features": 7,  # presumably must equal x.shape[1] of the input graphs — confirm
    "node_layers": [32, 16],
    "node_level_hidden_layers_number": 2,
    "pooling": "mean",
    # Population-level graph learning.
    "population_level_module": "LGL",
    "population_layers": [16, 8],
    "temp": 0.5,
    "theta": 0.1,
    # GNN applied over the learned population graph.
    "gnn_type": "GraphConv",
    "gnn_layers": [8],
    "gnn_aggr": "mean",
    # Classification head.
    "classifier_layers": [8],
    "output_dim": 2405,  # presumably the number of target classes (logits fed to CrossEntropyLoss) — confirm
}

In [9]:
# Load the pre-built train/val/test subgraph splits and wrap each in a PyG DataLoader.
# NOTE(review): pickle.load can execute arbitrary code while deserialising — only
# open files this project produced itself, never pickles from an untrusted source.
def _unpickle(path):
    """Deserialize and return the single object stored in the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_pg_subgraph = _unpickle("./Graph Outputs/train_pg_subgraph.pkl")
val_pg_subgraph = _unpickle("./Graph Outputs/val_pg_subgraph.pkl")
test_pg_subgraph = _unpickle("./Graph Outputs/test_pg_subgraph.pkl")

# Only the training split is shuffled; evaluation splits keep a stable order.
batch_size = 8
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=is_training_split)
    for split, is_training_split in (
        (train_pg_subgraph, True),
        (val_pg_subgraph, False),
        (test_pg_subgraph, False),
    )
)

KeyboardInterrupt: 

In [None]:
@torch.no_grad()
def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` over every sample in ``loader``.

    Args:
        model: the GiG model. Like the training loop, this assumes ``model(batch)``
            returns a 5-tuple whose first element is the per-sample logits.
        loader: a torch_geometric DataLoader yielding batches with a ``y`` attribute.
            Assumes ``batch.y`` holds integer class indices (TODO confirm; the
            training loss would also accept probability targets).
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy in [0, 1] as a Python float.

    Raises:
        ValueError: if ``loader`` yields no samples (avoids a silent division by zero).

    Leaves ``model`` in eval mode; call ``model.train()`` before training again.
    """
    model.eval()
    correct = 0
    total = 0
    for batch in loader:
        batch = batch.to(device)
        predictions, _, _, _, _ = model(batch)
        predicted_labels = predictions.argmax(dim=-1)
        correct += (predicted_labels == batch.y).sum().item()
        total += batch.y.size(0)
    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return correct / total


# Initialize model, loss and optimizer (GPU when available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GiG(config).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

# Per-epoch history, appended to by the training loop.
losses = []
accuracies = []

In [None]:
# Train for `epochs` epochs, validating after each one, then plot the full history.
epochs = 3
for epoch in range(1, epochs + 1):
    # --- Training ---
    model.train()
    total_loss = 0.0
    total_samples = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Forward pass: only the first model output (the logits) feeds the loss.
        predictions, _, _, _, _ = model(batch)
        loss = criterion(predictions, batch.y)

        # Backward pass
        loss.backward()
        optimizer.step()

        # criterion returns the batch *mean*; re-weight by batch size so a short
        # final batch does not skew the epoch average.
        n_samples = batch.y.size(0)
        total_loss += loss.item() * n_samples
        total_samples += n_samples

    if total_samples == 0:
        raise RuntimeError("train_loader yielded no samples")
    avg_loss = total_loss / total_samples
    losses.append(avg_loss)

    # --- Validation ---
    val_acc = evaluate(model, val_loader, device)
    accuracies.append(val_acc)

    # Print progress
    print(f"Epoch {epoch}/{epochs}")
    print(f"Loss: {avg_loss:.4f}")
    print(f"Validation Accuracy: {val_acc:.4f}")

# Plot the whole history once, after training, instead of opening a new figure
# every epoch. Epochs are numbered from 1 to match the printed log, and markers
# keep single-epoch histories visible. Each series gets its own x range in case
# an interrupted run left the two lists at different lengths.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 4))
loss_ax.plot(range(1, len(losses) + 1), losses, marker="o")
loss_ax.set(title="Training Loss", xlabel="Epoch", ylabel="Loss")
acc_ax.plot(range(1, len(accuracies) + 1), accuracies, marker="o")
acc_ax.set(title="Validation Accuracy", xlabel="Epoch", ylabel="Accuracy")
fig.tight_layout()
plt.show()

In [None]:
# Score the final (last-epoch) weights once on the held-out test split.
test_accuracy = evaluate(model, test_loader, device)
final_report = f"Final Test Accuracy: {test_accuracy:.4f}"
print(final_report)