In [1]:
import math
import os
from tqdm import tqdm

from ray import tune
import numpy as np
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.nn.utils import clip_grad_norm
from torch.utils.data import random_split


from network.model import MyModel
from network.loss import MyLoss
from dataset.HighD_Dataset_DGL_1graph import HighD_Dataset
from dgl.dataloading import GraphDataLoader
from utils.util import AverageMeter

Using backend: pytorch


In [5]:
def eval(model, val_dataloader, device, rng=None, mask_step=9):
    """Run one validation pass over ``val_dataloader`` and return the average loss.

    NOTE: this name shadows Python's builtin ``eval`` inside the notebook; it is
    kept unchanged so existing callers keep working.

    For every sample, a synthetic mask is built that zeroes one randomly chosen
    index along X's axis 1 at time step ``mask_step``. The model is then run on
    the masked input and scored with ``MyLoss`` against ``Y`` using the same mask.
    The mask delivered by the dataloader is currently ignored: this is a
    temporary testing setup.

    Args:
        model: network called as ``model(graph, X_masked)``; moved to ``device``
            and put into eval mode.
        val_dataloader: iterable yielding ``(graph, X, Y, mask)`` tuples. It is
            assumed to use batch size 1, because the leading batch dimension is
            dropped with ``[0, ...]``.
        device: torch device (or device string) to run on.
        rng: optional ``numpy.random.Generator`` used to pick the masked index.
            Pass e.g. ``np.random.default_rng(0)`` for a reproducible validation
            loss. When None, the global ``np.random`` state is used, as before.
        mask_step: index along X's axis 0 whose entry is masked. The default of 9
            is the last step when the dataset uses ``X_len=10``.

    Returns:
        ``losses.avg`` of the ``AverageMeter`` that accumulated each non-empty
        sample's ``loss.item()``. With no non-empty samples, the result is
        whatever ``AverageMeter.avg`` reports without updates (TODO confirm).
    """
    losses = AverageMeter()
    model.to(device)
    model.eval()
    criterion = MyLoss()

    # no_grad: evaluation never back-propagates, so building the autograd graph
    # for every sample only wastes memory and time.
    with torch.no_grad(), tqdm(total=len(val_dataloader), desc='Validation round') as pbar:
        for graph, X, Y, _mask in val_dataloader:
            # Not using batched evaluation for now: strip the size-1 batch dimension.
            graph = graph.to(device)
            X = X[0, ...].to(device)
            Y = Y[0, ...].to(device)
            if X.shape[1] == 0:
                # Nothing to evaluate in this sample. Still advance the bar so it
                # reaches its total.
                pbar.update()
                continue

            # TEMPORARY: synthetic mask for testing. The original comment gives
            # X as [10, N, 2] — TODO confirm the feature size (the model uses num_feats=4).
            # ones_like already places the mask on X's device.
            mask = torch.ones_like(X, dtype=torch.uint8)
            if rng is None:
                masked_idx = np.random.randint(mask.shape[1])
            else:
                masked_idx = rng.integers(mask.shape[1])
            mask[mask_step, masked_idx, :] = 0

            output = model(graph, X * mask)  # original note: [1, N, 2] — TODO confirm
            loss = criterion(output, Y, mask)

            losses.update(loss.item())
            pbar.set_description('Loss: {:.2f}'.format(loss.item()))
            pbar.update()

    return losses.avg

In [6]:
# Validation run: build the data_02 loader, restore a trained checkpoint, and
# report the average validation loss computed by `eval`.
SEQ_LEN = 10  # input window length; must agree between the dataset (X_len) and the model (seq_len)
CKPT_PATH = 'ckpts/0329_testgpu/CP_epoch15_loss_2.6713295542438136.pth'
device = 'cpu'

HighD_dataset = HighD_Dataset(X_len=SEQ_LEN,X_step=1,Y_len=1,Y_step=1,diff=9,name='data_02',raw_dir='./dataset/')
val_dataloader = GraphDataLoader(HighD_dataset, batch_size=1, shuffle=False)
print("Dataset Ready!")
model = MyModel(num_feats=4, output_dim=4, hidden_size=64, num_layers=3,seq_len=SEQ_LEN, horizon=1, device=device)
# map_location remaps stored tensors onto `device`. Without it, a checkpoint saved
# from a GPU run (the directory name suggests one) fails to load on a CPU-only machine.
model.load_state_dict(torch.load(CKPT_PATH, map_location=device))
val_loss = eval(model, val_dataloader, device)
print('validation loss = {}'.format(val_loss))

Loss: 5.00:   0%|          | 1/25067 [00:00<1:05:11,  6.41it/s]Dataset Ready!
Loss: 4.31: 100%|██████████| 25067/25067 [1:08:22<00:00,  6.11it/s]validation loss = 5.138872074645197



In [8]:
# Manual iterator over the validation loader, used for inspecting single samples
# step by step. Re-running this cell restarts from the first sample.
loader = enumerate(val_dataloader, start=0)

In [10]:
# Pull the next (index, (graph, X, Y, mask)) item from the manual iterator.
i, (graph, X, Y, _mask) = next(loader)

In [None]:
# Scratch inspection of the sample fetched above, using the same synthetic masking
# as the validation loop. The size-1 batch dimension is stripped first.
# NOTE(review): not idempotent — re-running this cell strips another leading dimension.
X, Y = X[0, ...], Y[0, ...]

# All-ones uint8 mask, then zero one randomly chosen index along axis 1
# (presumably an agent — confirm) at time step 9.
mask = torch.ones_like(X, dtype=torch.uint8)
rand_pos = np.random.randint(mask.shape[1])
mask[9, rand_pos, :] = 0