In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric_temporal.dataset import METRLADatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric.nn import GCNConv

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


Using device: cpu


In [3]:
class STTGCN(nn.Module):
    def __init__(self, num_nodes, input_dim, hidden_dim, output_dim, num_layers, dilation_factor=2):
        super(STTGCN, self).__init__()

        # Graph convolution layers
        self.gcn_layers = nn.ModuleList([GCNConv(input_dim if i == 0 else hidden_dim, hidden_dim) for i in range(num_layers)])

        # Temporal convolution
        self.temp_conv = nn.Conv2d(hidden_dim, hidden_dim, (1, 3), dilation=(1, dilation_factor), padding=(0, dilation_factor))

        # Fully connected output layer
        self.fc = nn.Linear(hidden_dim, output_dim)

        # Dynamic graph learning
        self.dynamic_adj = nn.Parameter(torch.randn(num_nodes, num_nodes), requires_grad=True)

        # Activation functions
        self.relu = nn.ReLU()

    def forward(self, x, edge_index, edge_weight=None):
        # Step 1: Dynamic Graph Learning - Adjust the edge weights dynamically
        adj = torch.softmax(self.dynamic_adj, dim=1)
        
        # Step 2: Apply Graph Convolution Layers with proper edge weight assignment
        for gcn in self.gcn_layers:
            if edge_weight is None:
                # Assign edge_weight from the learned adjacency matrix
                edge_weight = adj[edge_index[0], edge_index[1]]
            x = gcn(x, edge_index, edge_weight)
            x = self.relu(x)
        
        # Reshape for temporal convolution
        x = x.unsqueeze(0).permute(0, 2, 1).unsqueeze(3)  # Shape: [batch, channels, nodes, sequence_length]
        
        # Step 3: Temporal Convolution
        x = self.temp_conv(x)
        x = self.relu(x)
        x = x.squeeze(3).permute(0, 2, 1)  # Shape: [batch, nodes, channels]
        
        # Step 4: Output Layer
        x = self.fc(x)
        return x.squeeze(0)


In [4]:
loader = METRLADatasetLoader()
dataset = loader.get_dataset()

In [5]:
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

In [6]:
num_nodes = dataset[0].x.shape[0]
input_dim = dataset[0].x.shape[1]
hidden_dim = 32
output_dim = 1
num_layers = 2
dilation_factor = 2

In [7]:
model = STTGCN(num_nodes=num_nodes, input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim,
               num_layers=num_layers, dilation_factor=dilation_factor).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

In [8]:
def train(model, dataset, optimizer, criterion):
    model.train()
    total_loss = 0
    for snapshot in dataset:
        optimizer.zero_grad()
        
        # Forward pass
        x = snapshot.x.to(device)
        edge_index = snapshot.edge_index.to(device)
        edge_weight = snapshot.edge_attr.to(device) if snapshot.edge_attr is not None else None
        y = snapshot.y.to(device)
        
        # Forward pass through the model
        output = model(x, edge_index, edge_weight)
        
        # Compute loss
        loss = criterion(output, y)
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    return total_loss / len(dataset)

In [9]:
def test(model, dataset, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for snapshot in dataset:
            x = snapshot.x.to(device)
            edge_index = snapshot.edge_index.to(device)
            edge_weight = snapshot.edge_attr.to(device) if snapshot.edge_attr is not None else None
            y = snapshot.y.to(device)
            
            output = model(x, edge_index, edge_weight)
            loss = criterion(output, y)
            total_loss += loss.item()
    return total_loss / len(dataset)

In [10]:
# Check shape of edge_index
for snapshot in train_dataset:
    assert snapshot.edge_index.shape[0] == 2, f"Expected edge_index shape [2, num_edges], but got {snapshot.edge_index.shape}"

# Proceed with training as before
num_epochs = 50
for epoch in range(num_epochs):
    train_loss = train(model, train_dataset, optimizer, criterion)
    test_loss = test(model, test_dataset, criterion)
    print(f'Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')


IndexError: index 2 is out of bounds for dimension 0 with size 2