In [None]:
import dgl
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Define an MLP model
class MLP(nn.Module):
    """Two-layer perceptron: linear -> ReLU -> linear.

    The output is the raw result of the second linear layer (no softmax is
    applied here).
    """

    def __init__(self, in_feats, hidden_size, num_classes):
        """Create the two linear layers.

        Args:
            in_feats: Size of each input feature vector.
            hidden_size: Width of the hidden layer.
            num_classes: Number of output scores per input row.
        """
        super().__init__()
        self.linear1 = nn.Linear(in_feats, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the per-class scores for ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_feats``.

        Returns:
            Tensor whose last dimension has size ``num_classes``.
        """
        hidden = self.linear1(x)
        activated = torch.relu(hidden)
        return self.linear2(activated)

# Load the saved graph with features.
# Only the first graph stored in the file is used; the second return value
# of dgl.load_graphs is discarded.
bin_file_path = "/src/data/processed/oulad_graph_with_features.bin"
graph_list, _ = dgl.load_graphs(bin_file_path)
graph = graph_list[0]

# Define training parameters
num_classes = 4
hidden_size = 64
lr = 0.001
num_epochs = 10

# Prepare training and validation data
features = graph.nodes['student'].data['features']
num_students = graph.num_nodes('student')
# NOTE(review): the labels are drawn at random, not read from the graph, so
# the accuracies reported below can only reflect chance/memorisation.
# Replace with the real per-student labels once they are available.
labels = torch.randint(0, num_classes, (num_students,), dtype=torch.long)

# Create train-validation split.
# Despite their names, train_mask / val_mask hold node indices (a split of
# range(num_students)), not boolean masks.
train_mask, val_mask = train_test_split(range(num_students), test_size=0.2, random_state=42)
train_features = features[train_mask]
train_labels = labels[train_mask]
val_features = features[val_mask]
val_labels = labels[val_mask]

# Initialize the MLP model, loss function, and optimizer
model = MLP(features.shape[1], hidden_size, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Test setup.
# Built once, before training, so every epoch is scored against the same
# sample and the per-epoch "Test Accuracy" values are comparable.
# NOTE(review): this is synthetic random data, not a held-out test set.
test_features = torch.randn(10, features.shape[1])
test_labels = torch.randint(0, num_classes, (10,), dtype=torch.long)

# Train the MLP model (full-batch: one optimizer step per epoch)
for epoch in range(num_epochs):
    # Training phase
    model.train()
    optimizer.zero_grad()

    # Forward pass and loss on the whole training split
    logits = model(train_features)
    loss = criterion(logits, train_labels)

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    # Evaluation phase: no gradients are needed for any of the metrics below
    model.eval()
    with torch.no_grad():
        # Validation loss and accuracy
        val_logits = model(val_features)
        val_loss = criterion(val_logits, val_labels)
        val_pred = val_logits.argmax(dim=1)
        val_accuracy = accuracy_score(val_labels.numpy(), val_pred.numpy())

        # Training accuracy. This uses the weights *after* this epoch's
        # optimizer step, whereas `loss` above was computed before the step.
        train_logits = model(train_features)
        train_pred = train_logits.argmax(dim=1)
        train_accuracy = accuracy_score(train_labels.numpy(), train_pred.numpy())

        # Test accuracy on the fixed test sample
        test_logits = model(test_features)
        test_pred = test_logits.argmax(dim=1)
        test_accuracy = accuracy_score(test_labels.numpy(), test_pred.numpy())

    print(f"Epoch {epoch + 1}, Train Loss: {loss.item():.4f}, Train Accuracy: {train_accuracy * 100:.2f}%, "
          f"Val Loss: {val_loss.item():.4f}, Val Accuracy: {val_accuracy * 100:.2f}%, "
          f"Test Accuracy: {test_accuracy * 100:.2f}%")