In [51]:
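# Goal: a model that can take in an input of any length and predict an output of any length.
# This model: takes in 36-month inputs, predicts x-month outputs.
# NOTE(review): the cells below use 32-month sequences (SEQ_LEN = 32) -- confirm which length is intended.

# TODO: make this whole pipeline end-to-end.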
import numpy as np
import torch
import torch.nn as nn
import torch_geometric
import torch_geometric.nn

In [2]:
from torch_geometric.nn import GCNConv
# Node embedding network
class GCN(torch.nn.Module):
    """Node embedding network: a stack of GCNConv layers.

    Every layer except the last is followed by batch normalization, ReLU and
    dropout; the last layer's raw output is returned as the node embedding.

    Args:
        input_dim: Number of input features per node.
        emb_dim: Width of every hidden layer and of the returned embedding.
        num_layers: Total number of GCNConv layers. One layer is always
            created, so values below 1 behave like 1.
        dropout: Dropout probability applied after each hidden layer while
            the module is in training mode.
    """

    def __init__(self, input_dim=4, emb_dim=64, num_layers=5, dropout=0.3):
        super(GCN, self).__init__()

        hidden_convs = [GCNConv(in_channels=emb_dim, out_channels=emb_dim)
                        for _ in range(num_layers - 1)]
        self.convs = torch.nn.ModuleList(
            [GCNConv(in_channels=input_dim, out_channels=emb_dim)] + hidden_convs
        )
        # One batch norm per hidden layer; the final conv has none.
        self.bns = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(num_features=emb_dim) for _ in range(num_layers - 1)]
        )
        # Not used by forward(). An explicit dim avoids the deprecated
        # implicit-dim behaviour; for 2-D inputs this is the same axis the
        # implicit choice would pick.
        self.softmax = torch.nn.LogSoftmax(dim=-1)

        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialize the parameters of every conv and batch-norm layer."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        """Compute node embeddings.

        Args:
            x: Node feature tensor passed to the first conv layer.
            adj_t: Graph connectivity, forwarded unchanged to every conv layer.

        Returns:
            The output of the last conv layer (no norm / activation applied).
        """
        # zip stops at the shorter list (bns), leaving the last conv for the end.
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, adj_t)
            x = bn(x)
            x = torch.nn.functional.relu(x)
            # Honour the configured probability and switch dropout off in
            # eval mode; the functional default is p=0.5 and always-on.
            x = torch.nn.functional.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, adj_t)

In [60]:
from torch_geometric.nn import global_add_pool, global_mean_pool

# Graph embedding network
class GCN_Graph(torch.nn.Module):
    """Graph embedding network: GCN node embeddings -> mean pooling -> linear head.

    Args:
        hidden_dim: Width of the node embeddings and input width of the
            linear head.
        output_dim: Width of the per-graph output.
    """

    def __init__(self, hidden_dim, output_dim):
        super(GCN_Graph, self).__init__()
        # The node embedder's output width must match the linear head's input
        # width, so pass hidden_dim through instead of relying on GCN's default.
        self.gnn_node = GCN(emb_dim=hidden_dim)

        self.pool = global_mean_pool
        self.linear = torch.nn.Linear(hidden_dim, output_dim)

    def reset_parameters(self):
        """Re-initialize the node embedder and the linear head."""
        self.gnn_node.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, batch, edge_index=None):
        """Predict one output vector per graph.

        Args:
            batch: Batched graph object exposing `.x`, `.edge_index` and
                `.batch` attributes (presumably a torch_geometric Batch --
                confirm against the caller).
            edge_index: Optional connectivity override. When omitted,
                `batch.edge_index` is used.

        Returns:
            The linear head applied to the mean-pooled node embeddings.
        """
        node_features = batch.x
        if edge_index is None:
            edge_index = batch.edge_index
        graph_assignment = batch.batch

        # 1. Node embeddings from the GCN.
        node_emb = self.gnn_node(node_features, edge_index)
        # 2. Aggregate node embeddings per graph.
        graph_emb = self.pool(node_emb, graph_assignment)
        # 3. Per-graph prediction.
        return self.linear(graph_emb)

In [61]:
# Value fed to the decoder as its very first input step (see dec.forward).
START = 10
# Default sequence length used by this baseline model.
SEQ_LEN = 32
'''
The decoder takes in the final hidden state from the encoder (batch x 2*h_e, since the encoder is bidirectional).
It autoregressively predicts ONI scores for a set number of months into the future.
'''
class dec(nn.Module):
    """Autoregressive RNN decoder producing one scalar per output step.

    Starting from a given hidden state and a constant start value, each step
    runs an RNNCell, projects the new hidden state to a scalar, and feeds that
    scalar back in as the next step's input.

    Args:
        hidden_dim: Size of the hidden state (2 * h_e for the bidirectional
            encoder used alongside this class).
        output_len: Number of steps to predict.
        start_value: Value of the first input step. When None (the default),
            the module-level START constant is looked up at call time.
    """

    def __init__(self, hidden_dim, output_len, start_value=None):
        super(dec, self).__init__()
        # The cell input is the previous scalar prediction, hence input size 1.
        self.rnn_cell = nn.RNNCell(1, hidden_dim)
        self.linear_out = nn.Linear(hidden_dim, 1)
        self.output_len = output_len
        self.start_value = start_value

    def forward(self, initial_input):
        """Unroll the decoder.

        Args:
            initial_input: Initial hidden state, batch x hidden_dim.

        Returns:
            Tensor of shape batch x output_len x 1.
        """
        hidden = initial_input
        batch_size = hidden.size(0)
        start = START if self.start_value is None else self.start_value

        # Build the start token on the hidden state's device and dtype so the
        # module also works on GPU or in non-float32 precision.
        step_input = torch.full((batch_size, 1), start,
                                dtype=hidden.dtype, device=hidden.device)

        outputs = []
        for _ in range(self.output_len):
            hidden = self.rnn_cell(step_input, hidden)
            step_output = self.linear_out(hidden)
            outputs.append(step_output)
            step_input = step_output  # feed the prediction back in

        if not outputs:
            # torch.stack rejects an empty list; return an empty sequence instead.
            return hidden.new_zeros((batch_size, 0, 1))

        return torch.stack(outputs, dim=1)  # batch x output_len x 1

In [54]:
'''
Given a batch x seq_len x node_count x node_dim matrix, return a batch x seq_len matrix containing ONI predictions.

For this baseline model, we have a defined input length (seq_len = 32) and output length (output_length = 32).

GNNRNN first uses a GCN to create node embeddings for each of the 1728 (lat x long, 24 x 72) nodes in the world graph.
Then, each graph's set of node embeddings is passed through a linear layer to produce a graph embedding, yielding a
batch x seq_len x graph_emb_dim matrix. This is finally fed into the encoder-decoder RNN architecture to produce predictions.

The entire model is end-to-end differentiable.

NOTE: forward() as written below starts from the batch x seq_len x graph_emb_dim graph embeddings; the node and
graph embedders are stored on the module but are not yet applied inside forward().
'''
class GNNRNN(nn.Module):
    """Encoder-decoder RNN over a sequence of graph embeddings.

    A bidirectional RNN encodes the input sequence; its final forward and
    backward hidden states are concatenated and used as the initial hidden
    state of the autoregressive decoder.

    Args:
        node_embedder: Node embedding module. Stored as `self.ne`; not used
            by forward() as currently written.
        graph_embedder: Graph embedding module. Stored as `self.ge`; not used
            by forward() as currently written.
        input_dim: Graph embedding dimension (encoder input size).
        hidden_dim: Encoder hidden size per direction.
        input_length: Expected input sequence length. Recorded only; the
            encoder itself does not depend on it.
        output_length: Number of steps the decoder predicts.
    """

    def __init__(self, node_embedder, graph_embedder, input_dim, hidden_dim, input_length, output_length):
        super(GNNRNN, self).__init__()

        self.ne = node_embedder
        self.ge = graph_embedder

        # batch_first --> input is batch x seq_len x dim
        self.encoder = nn.RNN(input_size=input_dim, hidden_size=hidden_dim,
                              batch_first=True, bidirectional=True)
        # Decoder hidden size is 2 * hidden_dim because both encoder
        # directions are concatenated. Use the requested output length
        # rather than a module-level constant.
        self.decoder = dec(2 * hidden_dim, output_length)

        self.input_length = input_length
        self.output_length = output_length

    def forward(self, x):
        """Encode `x` and decode the predictions.

        Args:
            x: Tensor of shape batch x input_len x input_dim.

        Returns:
            Whatever the decoder returns for the concatenated final hidden
            states (one scalar per predicted step).
        """
        # final_hidden: 2 x batch x hidden_dim (one slice per direction).
        _, final_hidden = self.encoder(x)
        batch_size = final_hidden.size(1)
        # -> batch x 2 x hidden_dim -> batch x (2 * hidden_dim)
        encoded = final_hidden.transpose(0, 1).reshape(batch_size, -1)
        return self.decoder(encoded)

In [55]:
# Smoke test: 8 sequences of 32 graph embeddings (dim 64) in, 32 predicted steps out.
# The model class defined above is GNNRNN; it needs both input and output lengths.
test = torch.rand(8, 32, 64)
model = GNNRNN(
    node_embedder=None,
    graph_embedder=None,
    input_dim=64,
    hidden_dim=64,
    input_length=32,
    output_length=32,
)
out = model(test)
assert out.size() == (8, 32, 1), out.size()
print(out.size())

In [59]:
def train(model, dataloader, criterion, optimizer, epochs):
    """Run a plain supervised training loop.

    Puts `model` in training mode, then for each of `epochs` passes over
    `dataloader` performs, per (inputs, targets) batch: gradient reset,
    forward pass, loss computation, backward pass and one optimizer step.
    A progress line is printed on every tenth batch, starting with batch 0.
    Returns nothing.
    """
    model.train()
    for epoch in range(epochs):
        for batch_idx, batch_pair in enumerate(dataloader):
            features, labels = batch_pair

            # Clear gradients left over from the previous step.
            optimizer.zero_grad()

            # Forward pass and loss in one go.
            loss = criterion(model(features), labels)

            # Backpropagate, then apply the update.
            loss.backward()
            optimizer.step()

            # Only report on batches 0, 10, 20, ...
            if batch_idx % 10:
                continue
            print(f'Epoch [{epoch + 1}/{epochs}], Step [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')

encoder.weight_ih_l0 tensor([[ 0.0062,  0.0023, -0.0648,  ..., -0.1218, -0.0675,  0.0442],
        [-0.0951, -0.1173,  0.0773,  ...,  0.0954, -0.1030,  0.0092],
        [-0.0827, -0.0551, -0.0829,  ...,  0.0186, -0.0315, -0.1128],
        ...,
        [-0.0342,  0.1094,  0.0055,  ...,  0.0907, -0.0486,  0.0218],
        [-0.0228, -0.0009,  0.0543,  ..., -0.0164,  0.0778, -0.0741],
        [-0.1137, -0.0557,  0.0921,  ...,  0.0087, -0.0207,  0.0718]])
encoder.weight_hh_l0 tensor([[ 0.0686,  0.0714,  0.0008,  ...,  0.1163, -0.0897,  0.0671],
        [-0.0040,  0.1161,  0.0310,  ..., -0.0987,  0.0134,  0.0344],
        [-0.0323,  0.0837, -0.0737,  ...,  0.0769,  0.0600,  0.0093],
        ...,
        [ 0.0571, -0.1186,  0.0554,  ..., -0.0665, -0.0306,  0.0972],
        [-0.1001, -0.0739, -0.0878,  ...,  0.0763, -0.0215, -0.0108],
        [-0.0091,  0.1217,  0.0716,  ...,  0.0489,  0.0373,  0.0154]])
encoder.bias_ih_l0 tensor([-0.1185, -0.0184, -0.0161, -0.0008, -0.1083, -0.0288,  0.0399, 

torch.Size([8, 32, 1])
