In [7]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
class Temporal_Layer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super(Temporal_Layer, self).__init()
        self.conv1 = nn.Conv2d(in_channels, out_channels, (1, kernel))
        self.conv2 = nn.Conv2d(in_channels, out_channels, (1, kernel))
        
    def forward(self, x):
        x = x.permute(0,3,1,2)
        normal = self.conv1(x)
        sig = torch.sigmoid(self.conv2(x))
        out = normal * sig
        out = out.permute(0,2,3,1)
        # Convert back from NCHW to NHWC
        return out
        
        
class Stgcn_Block(nn.Module):
    def __init__(self, in_channels, spatial_channels, out_channels, nodes_num):
        super(Stgcn_Block, self).__init__()
        self.temporal_layer1 = Temporal_Layer(in_channels, out_channels, kernel = 3) 
        self.temporal_layer2 = Temporal_Layer(in_channles = spatial_channels, out_channels = out_channels, kernel = 3)
        
        self.weight = nn.Parameter(torch.FloatTensor(out_channels, spatial_channels))
        self.initialise_weight()
        
        self.batch_norm = nn.BatchNorm2d(num_nodes)
        
        
    def initialise_weight(self):
        std_dv = 1 / math.sqrt(self.weight.shape[1])
        self.weight.data.uniform(-std_dv, std_dv)
        
    def forward(self, x, adj_hat):
        # First temporal Block
        temporal_block1 = self.temporal_layer1
        
        #Spatial Graph Convolution
        t = temporal_block1.permute(1,0,2,3) #Converts tensor from nhwc to hnwc for multiplication with adj_matrix
        gconv1 = torch.einsum("ij, jklm -> kilm", [adj_hat, t]) #(h,h) * (h,n,w,c) -> (n,h,w,c)
        gconv2 = F.relu(torch.matmul(gconv1, self.weight))
        
        #Second Temporal Block
        temporal_block2 = self.temporal_layer2(gconv2) 
        
        out = self.batch_norm(temporal_block2)
        return out
  

class Stgcn_Model(nn.Module):
    
    def __init__(self, nodes_num, features_num, input_timesteps, num_output):
        super(Stgcn_Model, self).__init__()
        self.stgcn_block1 = Stgcn_Block(in_channels = features_num, spatial_channels = 16, out_channels = 64,
                                       nodes_num = nodes_num)
        
        self.stgcn_block2 = Stgcn_Block(in_channels = 64, spatial_channels = 16,  out_channels = 64,
                                       nodes_num = nodes_num)
        
        self.temporal_layer = Temporal_Layer(in_channels = 64, out_channels = 64, kernel = 3)
        self.fc = nn.Linear((input_timesteps * 64 * nodes_num), num_output)
        
    def forward(self, adj_hat, x):
        out1 = self.stgcn_block1(x, adj_hat)
        out2 = self.stgcn_block2(out1, adj_hat)
        out3 = self.temporal_layer(out2)
        out4 = self.fc(out3.reshape((out3.shape[0], out3.shape[1], -1)))
        return out4
        
        
        

In [11]:
def train(x_input, x_target, batch_size):
    num_samples = x_input.shape[0]
    shuffled_order = torch.randperm(num_samples)
    
    training_loss = []
    for i in range(math.ceil(num_samples / batch_size)):
        stgcn.train()
        optimizer.zero_grad()
        
        start = i * batch_size
        batch = shuffled_order[start:start+batch_size]
        x_batch = x_input[batch].to(device = device)
        y_batch = x_target[batch].to(device = device)
        
        
        out = stgcn(adj_mat, x_batch)
        loss = loss_criterion(out, y_batch)
        loss.backward()
        optimizer.step()
        
        training_loss.append(loss.detach().cpu().numpy())
        
    return sum(training_loss) / len(training_loss)
    
    
    
    

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 50
batch_size = 12
input_timesteps = 12
num_output = 3
adj_mat = np.random.rand(5)

stgcn = Stgcn_Model(nodes_num = adj_mat.shape[0], features_num = x_input.shape[3],
                    input_timesteps = input_timesteps, num_output = num_output).to(device = device)

optimizer = torch.optim.Adam(stgcn.parameters(), lr = 1e-3)
loss_criterion = nn.MSELoss()

training_loss = []
validation_loss = []


for epoch in range(epochs):
    loss = train(x_input, x_target, batch_size)
    training_loss.append(loss)
    
    with torch.no_grad():
        stgcn.eval()
        val_input = val_input.to(device= device)
        val_target = val_target.to(device = device)
        
        out = stgcn(adj_mat, val_input)
        val_loss = loss_criterion(out, val_target).to(device = 'cpu')
        validation_loss.append(val_loss.item())
        
    print("Training Loss: {}".format(loss))
    print("Validation Loss: {}".format(val_loss))
    
    
plt.plot(training_loss, label = 'Training Loss')
plt.plot(validation_loss, label = 'Validation Loss')
plt.legend()
plt.show()
    


NameError: name 'x_input' is not defined