In [1]:
from models.models import *
from models.models_2D import *

# Experiment configuration shared by every model constructed below.
# Fix the torch RNG so the randomly initialised weights of the models built
# in the next cell are identical on every Restart & Run All.
seed = 42
torch.manual_seed(seed)

dropout = 0.03                  # dropout rate passed to every model
in_dim = 100                    # input feature size for the graph models
in_dim_2D = 52                  # input channels for the 2D models ('Zhang', 'ConvLSTM')
device = torch.device('cpu')    # all models are built on CPU
k_days = 7                      # passed to the models as their sequence length (n_sequences)
act_func = 'relu'               # activation name forwarded to the model constructors

In [2]:
# Registry of candidate architectures, keyed by the name used to select one
# (see `dico_model[name]` further down). Every entry is constructed eagerly,
# so running this cell instantiates all eight models.
# The graph models receive `in_dim` input features; 'Zhang' and 'ConvLSTM'
# receive `in_dim_2D` input channels instead.
dico_model = {
    # Layer sizes and attention heads are given as per-layer lists; the
    # printed repr further down shows the resulting three GATConv layers.
    'GAT': GAT(
        in_dim=[in_dim, 64, 64, 64],
        heads=[4, 4, 2],
        dropout=dropout,
        bias=True,
        device=device,
        act_func=act_func,
        n_sequences=k_days,
    ),

    # NOTE(review): key reads 'GATTCN' while the class is STGATCN; kept as-is
    # because other cells may look the model up by this exact key.
    'ST-GATTCN': STGATCN(
        n_sequences=k_days,
        num_of_layers=3,
        in_channels=in_dim,
        end_channels=64,
        skip_channels=32,
        residual_channels=32,
        dilation_channels=32,
        dropout=dropout,
        heads=6,
        act_func=act_func,
        device=device,
    ),

    # k_days is passed positionally here (presumably the sequence length,
    # as with n_sequences elsewhere -- confirm against the constructor).
    'ST-GATCONV': STGATCONV(
        k_days,
        num_of_layers=3,
        in_channels=in_dim,
        hidden_channels=32,
        residual_channels=64,
        end_channels=32,
        dropout=dropout,
        heads=6,
        act_func=act_func,
        device=device,
    ),

    # Same positional k_days convention as 'ST-GATCONV'; no heads argument.
    'ST-GCNCONV': STGCNCONV(
        k_days,
        num_of_layers=3,
        in_channels=in_dim,
        hidden_channels=64,
        residual_channels=64,
        end_channels=32,
        dropout=dropout,
        act_func=act_func,
        device=device,
    ),

    'ATGN': TemporalGNN(
        in_channels=in_dim,
        hidden_channels=64,
        out_channels=64,
        n_sequences=k_days,
        device=device,
        act_func=act_func,
        dropout=dropout,
    ),

    # NOTE(review): key reads 'GATLTSM' while the class is ST_GATLSTM; kept
    # as-is because other cells may look the model up by this exact key.
    'ST-GATLTSM': ST_GATLSTM(
        in_channels=in_dim,
        hidden_channels=[64, 64, 64],
        out_channels=64,
        end_channels=32,
        n_sequences=k_days,
        device=device,
        act_func=act_func,
        heads=6,
        dropout=dropout,
    ),

    # 2D model: uses in_dim_2D, and act_func is not passed to this constructor.
    'Zhang': Zhang(
        in_channels=in_dim_2D,
        hidden_channels=64,
        end_channels=128,
        dropout=dropout,
        binary=False,
        device=device,
        n_sequences=k_days,
    ),

    # 2D model: uses in_dim_2D with a per-layer hidden_dim list.
    'ConvLSTM': CONVLSTM(
        in_channels=in_dim_2D,
        hidden_dim=[32, 32, 32],
        end_channels=64,
        n_sequences=k_days,
        dropout=dropout,
        device=device,
        act_func=act_func,
    ),
}

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
# Select one architecture from the registry and render its module tree.
name = 'GAT'
if name not in dico_model:
    # A bare KeyError would not say which names are valid, and several keys
    # are spelled unlike their class names (e.g. 'ST-GATLTSM'), so list them.
    raise KeyError(f"Unknown model {name!r}; expected one of {list(dico_model)}")
display(dico_model[name])

GAT(
  (dropout): Dropout(p=0.03, inplace=False)
  (net): Sequential(
    (0) - GATConv(100, 64, heads=4): x, edge_index -> x
    (1) - ReLU(): x -> x
    (2) - GATConv(256, 64, heads=4): x, edge_index -> x
    (3) - ReLU(): x -> x
    (4) - GATConv(256, 64, heads=2): x, edge_index -> x
  )
  (output): OutputLayer(
    (fc): Linear(64, 64, bias=True)
    (activation): ReLU()
    (fc2): Linear(64, 1, bias=True)
    (output): ReLU()
  )
)