In [1]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [2]:
from config import get_configs
from utils import load_data,load_data_week,metric,masked_mse
from dataset import ODDataset,ODWEEKDataset
from models import MYMODEL, MYMODEL2, MYMODEL3,MYMODEL4
from lib import train,test

In [3]:
# Load the run configuration from the project's config module.
args = get_configs()
# Override the model variant for this notebook; "cross_day_week" is the name
# matched by the MYMODEL4 branch of the model-selection cell below.
args.model_name = "cross_day_week"

In [4]:

# Load the seven data tensors used by the rest of the notebook. Their shapes
# are inspected one by one in the following cells.
x, c, te, tc, adj, ac, y = load_data_week(args)


In [5]:
# Shape check for `te`; its size along dim 2 is later passed to the model as t_in_channels.
te.shape

torch.Size([726, 24, 31])

In [6]:
# Shape check for `x`; dim 2 is later used as num_nodes and dim 3 as o_in_channels.
x.shape

torch.Size([726, 24, 66, 132])

In [7]:
# Shape check for `adj`, one of the inputs handed to the dataset split.
adj.shape

torch.Size([726, 24, 66, 66])

In [8]:
# Shape check for `y`, the last tensor returned by load_data_week.
y.shape

torch.Size([726, 24, 66, 66])

In [9]:
# Shape check for `c`, one of the inputs handed to the dataset split.
c.shape

torch.Size([726, 8, 12, 66, 132])

In [10]:
# Shape check for `tc`, one of the inputs handed to the dataset split.
tc.shape

torch.Size([726, 8, 12, 31])

In [11]:
# Shape check for `ac`, one of the inputs handed to the dataset split.
ac.shape

torch.Size([726, 8, 12, 66, 66])

In [12]:
# Two-stage random split with a fixed seed. The first stage carves the
# validation set out of the full data; the second stage carves the test set
# out of what remains. Both stages use the same fraction and seed.
split_options = dict(test_size=args.test_size, random_state=42)

# train_test_split returns [a_keep, a_held, b_keep, b_held, ...], so the even
# positions are the retained parts and the odd positions are the held-out parts.
first_split = train_test_split(x, c, adj, ac, te, tc, y, **split_options)
(x_valid, c_valid, adj_valid, ac_valid,
 te_valid, tc_valid, y_valid) = first_split[1::2]

second_split = train_test_split(*first_split[0::2], **split_options)
(x_train, c_train, adj_train, ac_train,
 te_train, tc_train, y_train) = second_split[0::2]
(x_test, c_test, adj_test, ac_test,
 te_test, tc_test, y_test) = second_split[1::2]

# Wrap each subset in the project's dataset class (argument order matters).
train_dataset = ODDataset(
    x_train, c_train, adj_train, ac_train, te_train, tc_train, y_train)
valid_dataset = ODDataset(
    x_valid, c_valid, adj_valid, ac_valid, te_valid, tc_valid, y_valid)
test_dataset = ODDataset(
    x_test, c_test, adj_test, ac_test, te_test, tc_test, y_test)

# All three loaders iterate in a fixed order.
# NOTE(review): the training loader is not reshuffled between epochs; the only
# shuffling is the one-off permutation done by train_test_split — confirm this
# is intended.
loader_options = dict(batch_size=args.batch_size, shuffle=False)
train_dataloader = DataLoader(train_dataset, **loader_options)
valid_dataloader = DataLoader(valid_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

In [13]:

# Load the two static embedding tables and place them on the configured device.
# `se` is read from a NumPy file; `pe` is read from a CSV whose first column
# is used as the row index and therefore is not part of the values.
se_array = np.load(args.se_path)
se = torch.Tensor(se_array).to(args.device)

pe_frame = pd.read_csv(args.pe_path, index_col=0)
pe = torch.Tensor(pe_frame.values).to(args.device)

In [14]:
# Build the model variant selected by args.model_name.
#
# Each entry maps a model name to (model class, whether the constructor also
# takes `alpha`). All variants share the same core constructor arguments.
model_variants = {
    "vanilla": (MYMODEL, False),
    "cross_day": (MYMODEL2, True),
    "cross_day_att": (MYMODEL3, False),
    "cross_day_week": (MYMODEL4, True),
}

# Fail loudly on an unknown name. Without this check `model` would be left
# undefined, or in a live kernel would silently keep a stale model from an
# earlier run.
if args.model_name not in model_variants:
    raise ValueError(
        f"Unknown model_name {args.model_name!r}; "
        f"expected one of {sorted(model_variants)}"
    )

model_class, takes_alpha = model_variants[args.model_name]

# Input widths are read off the loaded tensors so they always match the data.
model_kwargs = dict(
    s_in_channels=se.size(1),
    p_in_channels=pe.size(1),
    t_in_channels=te.size(2),
    o_in_channels=x.size(3),
    num_nodes=x.size(2),
    in_seq_len=args.in_seq_len,
    out_seq_len=args.out_seq_len,
    num_encoder_layers=args.num_encoder_layers,
    d_model=args.d_model,
    hidden_channels=args.hidden_channels,
    heads=args.heads,
    dropout=args.dropout,
)
# Only the "cross_day" and "cross_day_week" variants receive `alpha`.
if takes_alpha:
    model_kwargs["alpha"] = args.alpha

model = model_class(**model_kwargs).to(args.device)

In [15]:
# Report the model size: first every parameter element, then only the
# elements of parameters that require gradients.
total_params = sum(map(torch.numel, model.parameters()))
print(f'{total_params:,} total parameters.')

total_trainable_params = sum(
    map(torch.numel, filter(lambda p: p.requires_grad, model.parameters()))
)
print(f'{total_trainable_params:,} training parameters.')

7,860,611 total parameters.
7,860,611 training parameters.


In [16]:
# Print the module hierarchy of the instantiated model for a visual check.
print(model)

MYMODEL4(
  (spgat): SPGAT(
    (od_emb): ODGAT(
      (gat): GAT(
        (conv1): GATConv(132, 128, heads=8)
        (fc1): Linear(in_features=1024, out_features=128, bias=True)
        (conv2): GATConv(128, 128, heads=8)
        (fc2): Linear(in_features=1024, out_features=128, bias=True)
      )
    )
  )
  (sta): STA(
    (tf): Linear(in_features=31, out_features=128, bias=True)
    (sf): Linear(in_features=128, out_features=128, bias=True)
    (pf): Linear(in_features=13, out_features=128, bias=True)
    (fc): Linear(in_features=128, out_features=128, bias=True)
    (fc1): Linear(in_features=256, out_features=128, bias=True)
    (positional_encoding): PositionalEncoding()
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_featur

In [17]:
# Optimisation setup.
# Loss: the masked MSE helper imported from utils.
criterion = masked_mse
# Optimiser: Adam over every model parameter at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=args.lr)
# Schedule: multiply the learning rate by args.gamma every args.step_size
# scheduler steps.
scheduler = StepLR(
    optimizer=optimizer,
    step_size=args.step_size,
    gamma=args.gamma,
)

In [18]:
# Run the training loop from lib. The validation loader is the only
# evaluation loader passed in.
# NOTE(review): the "test loss" lines in the log are presumably computed on
# valid_dataloader rather than on the held-out test set — confirm in lib.train.
train(model,train_dataloader,se,pe,valid_dataloader,optimizer,criterion,scheduler,args)

Epoch: 00, Loss: 4.8763
test loss: 6.4976
Epoch: 01, Loss: 4.2600
test loss: 3.8986
Epoch: 02, Loss: 3.5784
test loss: 3.5931
Epoch: 03, Loss: 3.3666
test loss: 3.4109
Epoch: 04, Loss: 3.2582
test loss: 3.3439
Epoch: 05, Loss: 3.1940
test loss: 3.3435
Epoch: 06, Loss: 3.1370
test loss: 3.3296
Epoch: 07, Loss: 3.0741
test loss: 3.2144
Epoch: 08, Loss: 3.0124
test loss: 3.1164
Epoch: 09, Loss: 2.9742
test loss: 3.0840
Epoch: 10, Loss: 2.9402
Epoch: 11, Loss: 2.9169
test loss: 3.0751
Epoch: 12, Loss: 2.8987
test loss: 3.0583
Epoch: 13, Loss: 2.8849
test loss: 3.0483
Epoch: 14, Loss: 2.8738
test loss: 3.0470
Epoch: 15, Loss: 2.8656
Epoch: 16, Loss: 2.8581
Epoch: 17, Loss: 2.8514
Epoch: 18, Loss: 2.8456
Epoch: 19, Loss: 2.8391
Epoch: 20, Loss: 2.8318
Epoch: 21, Loss: 2.8256
Epoch: 22, Loss: 2.8207
test loss: 3.0450
Epoch: 23, Loss: 2.8171
test loss: 3.0407
Epoch: 24, Loss: 2.8145
test loss: 3.0383
Epoch: 25, Loss: 2.8138
test loss: 3.0357
Epoch: 26, Loss: 2.8146
test loss: 3.0035
Epoch: 27,

In [19]:
# model = torch.load(f"{args.model_path}/model_{args.model_name}_epoch_{args.epoch}_batchsize_{args.batch_size}_lr_{args.lr}.pth").to(args.device)

In [20]:
# test_loss,p = test(model,test_dataloader,se,pe,criterion,args)

In [21]:
# test_loss

In [22]:
# y_test = y_test.to(args.device)

In [23]:
# mse,rmse,mae,mape = metric(p,y_test)


In [24]:
# mse

In [25]:
# mape

In [26]:
# rmse

In [27]:
# mae

In [28]:
# args.device