In [None]:
import datetime
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from model_definitions import GRU_submodel
from utils import get_ade, get_fde
def seed_worker(worker_id):
    """Seed numpy and Python's ``random`` from torch's per-process initial seed.

    Written in the shape of a ``DataLoader(worker_init_fn=...)`` hook; none of
    the DataLoaders visible in this notebook pass it yet.

    Args:
        worker_id: DataLoader worker index (unused; the seed comes from torch).
    """
    # numpy only accepts 32-bit seeds, so fold torch's 64-bit seed down.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# Start timestamp; the training cell prints an end timestamp for the total run time.
print(datetime.datetime.now())

SEED = 0
# Anomaly detection re-checks every backward pass and slows training considerably;
# enable it only while hunting NaN/inf gradients. It does not change numerics.
DETECT_ANOMALY = False

# cuBLAS reads this variable when it is first initialised, so set it before any
# CUDA work and before turning on deterministic mode (needed for deterministic
# cuBLAS matmuls on CUDA >= 10.2).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

# Seed every RNG the notebook may touch, not just torch's.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# --- Training hyper-parameters (shared by every sub-model) -----------------
epoch_num = 3_000              # passes over each sub-dataset
learning_rate = 1e-3           # RMSprop step size
embed_dim = hidden_dim = 128   # widths handed to GRU_submodel
dropout_rate = 0.5
pred_len = 20                  # forwarded to GRU_submodel (presumably predicted steps — TODO confirm)
batch_size = 512
numrc = 7                      # number of sub-datasets; one model is trained per index

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Inputs and targets, indexed by sub-dataset number in the training cell.
x = _load_pickle('datas/ATC_PoPPL_sub_x.pickle')
y = _load_pickle('datas/ATC_PoPPL_sub_y.pickle')

In [None]:
def _train_one_epoch(net, loader, lossfn, optimizer):
    """Run one optimisation pass of ``net`` over ``loader``; return the mean batch loss."""
    net.train()
    batch_losses = []
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        output = net(inputs)
        loss = lossfn(output, labels.float())
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return np.mean(batch_losses)


def _evaluate_ade(net, loader):
    """Return ``(ade, n_batches)`` for ``net`` over ``loader``.

    Per-batch ADEs are averaged weighted by batch size, which assumes ``get_ade``
    returns a per-sample mean for its batch (TODO confirm in utils.get_ade).
    Returns ``(nan, 0)`` when the loader yields no batches.
    """
    net.eval()
    batch_ades, batch_sizes = [], []
    with torch.no_grad():  # evaluation needs no autograd graph
        for inputs, labels in loader:
            output = net(inputs.to(device))
            batch_ades.append(get_ade(output.cpu().numpy(), labels.numpy()))
            batch_sizes.append(len(labels))
    if not batch_ades:
        return float('nan'), 0
    return np.average(batch_ades, weights=batch_sizes), len(batch_ades)


eval_every = 50        # evaluate on the held-out split every this many epochs
train_fraction = 0.8   # share of each sub-dataset used for training
split_seed = 0         # fixed so every epoch evaluates on the same held-out samples
subnet_dir = "./subnets"
os.makedirs(subnet_dir, exist_ok=True)  # fail fast instead of after hours of training

for rc in range(numrc):
    # Values are divided by 1000 (presumably a unit conversion — TODO confirm).
    # Both tensors are float32 so inputs and targets share one dtype.
    tensor_x = torch.as_tensor(x[rc], dtype=torch.float32) / 10 / 100
    tensor_y = torch.as_tensor(y[rc], dtype=torch.float32) / 10 / 100
    datas = TensorDataset(tensor_x, tensor_y)
    train_size = int(train_fraction * len(datas))
    test_size = len(datas) - train_size

    # Split once, before training, so the model never sees the held-out samples.
    train_dataset, test_dataset = torch.utils.data.random_split(
        datas, [train_size, test_size],
        generator=torch.Generator().manual_seed(split_seed))
    train = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
    test = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    net = GRU_submodel(embed_dim, hidden_dim, dropout_rate, pred_len)
    net.to(device)
    lossfn = nn.MSELoss()
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate)

    for epoch in range(epoch_num):
        mean_loss = _train_one_epoch(net, train, lossfn, optimizer)
        print('[%d] loss: %.3f' % (epoch + 1, mean_loss))
        # Also evaluate after the final epoch so the saved model has a reported ADE.
        if epoch % eval_every == 0 or epoch == epoch_num - 1:
            ade, n_batches = _evaluate_ade(net, test)
            print('[%d, %5d] ADE: %.3f' % (epoch + 1, n_batches, ade))

    torch.save(net.state_dict(),
               os.path.join(subnet_dir, "ATC_PoPPL_submodel_" + str(rc) + ".pth"))
print(datetime.datetime.now())