In [None]:
import os
os.chdir("..")
from epilearn.data import UniversalDataset
from epilearn.utils import transforms
from epilearn.tasks.forecast import Forecast
from epilearn.tasks.detection import Detection
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np


## Models

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.utils import dense_to_sparse
from epilearn.models.Temporal.base import BaseModel
from epilearn.models.Spatial.base import BaseModel
from epilearn.models.SpatialTemporal.base import BaseModel

class CustomizedTemporal(BaseModel):
    """LSTM-based temporal model: encode an input sequence, decode the last step.

    The sequence is run through an ``nn.LSTM`` (``batch_first=True``) and the
    LSTM output at the final time step is projected by a linear layer to
    ``num_timesteps_output`` values.

    NOTE(review): this file imports ``BaseModel`` three times under the same
    name (Temporal, Spatial, SpatialTemporal); the last import wins, so this
    class currently derives from the SpatialTemporal ``BaseModel`` -- confirm
    that this is intended.

    Args:
        num_features: Size of the last dimension of the input (LSTM ``input_size``).
        num_timesteps_input: Length of the lookback window (stored as ``self.lookback``;
            not otherwise used by this class).
        num_timesteps_output: Forecast horizon, i.e. the size of the output's last dimension.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM is bidirectional.
        device: Device identifier forwarded to ``BaseModel`` and stored on the instance.
    """
    def __init__(self,
                num_features,
                num_timesteps_input,
                num_timesteps_output,
                hidden_size,
                num_layers,
                bidirectional,
                device = 'cpu'):
        super(CustomizedTemporal, self).__init__(device=device)
        self.num_feats = num_features
        self.hidden = hidden_size
        self.num_layers = num_layers
        self.bidirectional=bidirectional
        self.lookback = num_timesteps_input
        self.horizon = num_timesteps_output
        self.device = device

        # A bidirectional LSTM concatenates both directions, so its output
        # feature size is 2 * hidden_size; the linear head must match it.
        num_directions = 2 if self.bidirectional else 1

        self.lstm = nn.LSTM(input_size=self.num_feats, hidden_size=self.hidden, num_layers=self.num_layers, batch_first=True, bidirectional=self.bidirectional)
        self.fc = nn.Linear(self.hidden * num_directions, self.horizon)

    def forward(self, feature, graph=None, states=None, dynamic_graph=None, **kargs):
        """Run the LSTM over ``feature`` and decode its last time step.

        Args:
            feature: Tensor of shape (batch, seq_length, num_features).
            graph, states, dynamic_graph, **kargs: Accepted for interface
                compatibility; unused here.

        Returns:
            Tensor of shape (batch, num_timesteps_output).
        """
        # Forward propagate LSTM
        out, _ = self.lstm(feature)  # out: tensor of shape (batch, seq_length, hidden_size * num_directions)

        # Decode the output at the last time step
        out = self.fc(out[:, -1, :])

        return out

    def initialize(self):
        """Re-initialize the LSTM parameters in place.

        Input-to-hidden weights get Xavier-uniform init, hidden-to-hidden
        weights get orthogonal init, and biases are zeroed. The linear head
        ``self.fc`` is left untouched.
        """
        for name, param in self.lstm.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)


class CustomizedSpatial(BaseModel):
    """Single-layer GCN model followed by a linear read-out.

    Node features go through one ``GCNConv`` (``num_features`` ->
    ``hidden_size``) and the result is projected by a linear layer to
    ``num_timesteps_output`` values per node.

    NOTE(review): this file imports ``BaseModel`` three times under the same
    name; the last import (SpatialTemporal) wins, so that is the base class
    actually used here -- confirm that this is intended.

    Args:
        num_nodes: Number of graph nodes; used to reshape the input.
        num_features: Per-node input feature size (GCN ``in_channels``).
        num_timesteps_input: Stored as ``self.lookback``; not otherwise used by this class.
        num_timesteps_output: Size of the output's last dimension.
        hidden_size: GCN output channels.
        device: Device identifier forwarded to ``BaseModel`` and stored on the instance.
    """

    def __init__(self,
                 num_nodes,
                 num_features,
                 num_timesteps_input,
                 num_timesteps_output,
                 hidden_size,
                 device = 'cpu'):
        super().__init__(device=device)

        # Plain hyper-parameters.
        self.device = device
        self.num_nodes = num_nodes
        self.num_feats = num_features
        self.lookback = num_timesteps_input
        self.horizon = num_timesteps_output
        self.hidden = hidden_size

        # Learnable layers (created in the same order as before so parameter
        # registration and random initialization are unchanged).
        self.gcn = GCNConv(in_channels=num_features, out_channels=hidden_size)
        self.fc = nn.Linear(hidden_size, num_timesteps_output)

    def forward(self, feature, graph, states=None, dynamic_graph=None, **kargs):
        """Apply the GCN layer and the linear read-out.

        Args:
            feature: Input tensor; dimensions 1 and 2 are swapped and the result
                is reshaped to (-1, num_nodes, num_features).
                Presumably (batch, num_features, num_nodes) -- TODO confirm against caller.
            graph: Dense adjacency, converted to an edge index with
                ``dense_to_sparse``; the second value it returns is ignored.
            states, dynamic_graph, **kargs: Accepted for interface compatibility; unused here.

        Returns:
            Tensor whose last dimension is ``num_timesteps_output``.
        """
        swapped = feature.transpose(1, 2)
        node_inputs = swapped.reshape(-1, self.num_nodes, self.num_feats)

        edge_index, _unused = dense_to_sparse(graph)

        node_hidden = self.gcn(node_inputs, edge_index=edge_index)
        return self.fc(node_hidden)

    def initialize(self):
        """No custom initialization; layers keep their default init."""
        return None


class CustomizedSpatialTemporal(BaseModel):
    """GCN + LSTM spatial-temporal model.

    A ``GCNConv`` layer updates node features, an LSTM is run over each node's
    sequence, and the LSTM output at the last step is projected by a linear
    layer to ``num_channels_output * num_timesteps_output`` values per node.

    Args:
        num_nodes: Number of graph nodes.
        num_features: Per-node input feature size (GCN ``in_channels``).
        num_timesteps_input: Lookback window length (LSTM sequence length).
        num_timesteps_output: Forecast horizon.
        num_channels_output: Number of output channels per node and horizon step.
        hidden_size: GCN output channels and LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM is bidirectional.
        device: Device identifier forwarded to ``BaseModel`` and stored on the instance.
    """
    def __init__(self,
                num_nodes,
                num_features,
                num_timesteps_input,
                num_timesteps_output,
                num_channels_output,
                hidden_size,
                num_layers,
                bidirectional,
                device = 'cpu'):
        super(CustomizedSpatialTemporal, self).__init__(device=device)
        self.num_nodes = num_nodes
        self.num_feats = num_features
        self.hidden = hidden_size
        self.num_layers = num_layers
        self.bidirectional=bidirectional
        self.lookback = num_timesteps_input
        self.horizon = num_timesteps_output
        self.num_out_channels = num_channels_output
        self.device = device

        # A bidirectional LSTM concatenates both directions, so its output
        # feature size is 2 * hidden_size. Everything downstream of the LSTM
        # must use this width rather than the bare hidden size.
        self.lstm_out_size = self.hidden * (2 if self.bidirectional else 1)

        self.gcn = GCNConv(in_channels=self.num_feats, out_channels=self.hidden)
        self.lstm = nn.LSTM(input_size=self.hidden, hidden_size=self.hidden, num_layers=self.num_layers, batch_first=True, bidirectional=self.bidirectional)
        self.fc = nn.Linear(self.lstm_out_size, self.num_out_channels*self.horizon)

    def forward(self, feature, graph, states=None, dynamic_graph=None, **kargs):
        """Run GCN message passing, then a per-node LSTM, then the linear head.

        Args:
            feature: Input tensor, cast to float before the GCN.
                Assumes shape (batch, num_timesteps_input, num_nodes, num_features)
                -- TODO confirm against caller.
            graph: Dense adjacency, converted to an edge index with
                ``dense_to_sparse``; the second value it returns is ignored.
            states, dynamic_graph, **kargs: Accepted for interface compatibility; unused here.

        Returns:
            Tensor of shape (batch, horizon, num_nodes, num_channels_output).
        """
        # message passing to update node features
        edge_index, _ = dense_to_sparse(graph)

        x = self.gcn(feature.float(), edge_index=edge_index)

        # Swap dims 1 and 2 and flatten the leading dims so the LSTM sees
        # sequences of shape (lookback, hidden).
        x = x.transpose(1,2).reshape(-1, self.lookback, self.hidden)
        # Forward propagate LSTM
        out, _ = self.lstm(x)  # out: tensor of shape (batch, seq_length, hidden_size * num_directions)
        # Decode the output at the last time step
        out = out[:, -1, :]
        # Use the real LSTM output width here: reshaping with the bare hidden
        # size would silently fold the extra bidirectional features into the
        # batch dimension.
        out = out.reshape(-1, self.num_nodes, self.lstm_out_size)
        out = self.fc(out).reshape(-1, self.num_nodes, self.horizon, self.num_out_channels)

        return out.transpose(1,2) # return shape (batch, horizon, num_nodes, num_channels_output)

    def initialize(self):
        """No custom initialization; layers keep their default init."""
        pass

### Datasets