In [None]:
import yaml
import torch

from torch_geometric.loader import DataLoader

from models import SWEGNN, GCN
from data import TemporalGraphDataset

In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Experiment configuration. Keys used in this notebook: node_features, edge_features,
# dataset_parameters, SWEGNN, training_parameters.
# Explicit encoding: the default for open() is locale-dependent (PEP 597).
with open('config.yaml', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# load() returns the list-like dataset plus a dict of feature counts used to build the model.
dataset, info = TemporalGraphDataset(
    node_features=config['node_features'],
    edge_features=config['edge_features'],
    **config['dataset_parameters'],
).load()

# Fail fast here rather than with an opaque error in the split / DataLoader cells.
assert len(dataset) > 0, 'TemporalGraphDataset.load() returned an empty dataset'

In [None]:
# Show the first sample graph, then the feature-count metadata, one per line.
print(dataset[0], info, sep='\n')

Data(x=[1268, 6], edge_index=[2612, 2], edge_attr=[2612, 8], y=[1268, 1], pos=[1268, 2])
{'num_static_node_features': 3, 'num_dynamic_node_features': 1, 'num_static_edge_features': 5, 'num_dynamic_edge_features': 1}


In [29]:
TRAIN_FRACTION = 0.8  # 80% train, 20% test
BATCH_SIZE = 32
SEED = 42

# Seeds torch's global RNG: makes the training-loader shuffle order reproducible, and
# also anything torch-random run after this cell (e.g. weight initialisation).
torch.manual_seed(SEED)

# Index-based split (no shuffling before slicing): the first samples train, the rest test.
num_train = int(len(dataset) * TRAIN_FRACTION)
assert 0 < num_train < len(dataset), (
    f'split leaves an empty set: {num_train} train samples out of {len(dataset)}'
)

train_dataset = dataset[:num_train]
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation does not need shuffling; a fixed order keeps test runs deterministic.
test_dataset = dataset[num_train:]
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Feature counts reported by TemporalGraphDataset.load(), forwarded to the model constructor.
base_model_params = {
    'static_node_features': info['num_static_node_features'],
    'dynamic_node_features': info['num_dynamic_node_features'],
    'static_edge_features': info['num_static_edge_features'],
    'dynamic_edge_features': info['num_dynamic_edge_features'],
    'device': device,
}
swe_gnn_params = config['SWEGNN']

# SWEGNN receives `device`, but how it uses it is not visible here (models.py);
# `.to(device)` guarantees the parameters actually live on that device.
model = SWEGNN(**swe_gnn_params, **base_model_params).to(device)


In [None]:
lr_info = config['training_parameters']
optimizer = torch.optim.Adam(model.parameters(), lr=lr_info['learning_rate'], weight_decay=lr_info['weight_decay'])
num_epochs = 10

# TODO: confirm the loss for this task. MSE is a placeholder that presumes `y`
# (one column per node, e.g. y=[1268, 1] for dataset[0]) is a continuous regression target.
criterion = torch.nn.MSELoss()


def _model_device():
    """Return the device of the model's parameters (falls back to the global `device`)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else device


def train(epochs=None):
    """Train `model` on `train_loader`, printing the mean batch loss after every epoch.

    Args:
        epochs: Number of passes over `train_loader`; defaults to `num_epochs`.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    if epochs is None:
        epochs = num_epochs
    if len(train_loader) == 0:
        raise ValueError('train_loader yields no batches; check the train/test split')

    target_device = _model_device()
    epoch_losses = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        # torch_geometric's DataLoader yields a collated Batch per step, not an
        # (inputs, labels) tuple; the targets are carried in `batch.y`.
        for batch in train_loader:
            batch = batch.to(target_device)
            optimizer.zero_grad()

            # assumes SWEGNN.forward accepts a PyG Batch and returns a tensor shaped
            # like batch.y — TODO confirm against models.py
            outputs = model(batch)
            # Distinct name so the criterion is not shadowed (that was an UnboundLocalError).
            batch_loss = criterion(outputs, batch.y)

            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        epoch_loss = running_loss / len(train_loader)
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch + 1}/{epochs}], Training Loss: {epoch_loss:.4f}')
    return epoch_losses


def test():
    """Evaluate `model` on `test_loader`; print and return the mean batch loss.

    Raises:
        ValueError: if `test_loader` yields no batches.
    """
    if len(test_loader) == 0:
        raise ValueError('test_loader yields no batches; check the train/test split')

    target_device = _model_device()
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(target_device)
            outputs = model(batch)
            val_loss += criterion(outputs, batch.y).item()

    # No classification accuracy here: `y` has a single column, so an argmax over
    # dim 1 is always 0, and comparing [N] predictions with [N, 1] labels broadcasts
    # to [N, N]. That metric is meaningless for this target.
    val_loss /= len(test_loader)
    print(f'Validation Loss: {val_loss:.4f}')
    return val_loss
