In [1]:
import os
import torch
import argparse
import numpy as np
import torch.utils.data
from easydict import EasyDict as edict
from timeit import default_timer as timer

from utils.eval import Metric
from utils.gpu_dispatch import GPU
from utils.common_utils import dir_check, to_device, ws, unfold_dict, dict_merge, GpuId2CudaId, Logger

from algorithm.dataset import CleanDataset, TrafficDataset
from algorithm.diffGSL.model import DiffSTG, save2file

In [2]:
# Per-dataset settings: number of graph nodes, samples per hour, and the
# chronological split points (first index of the validation / test ranges).
DATASET_SETTINGS = {
    'PEMS08': dict(num_vertices=170, points_per_hour=12,
                   val_start_idx=int(17856 * 0.6),
                   test_start_idx=int(17856 * 0.8)),
    'AIR_BJ': dict(num_vertices=34, points_per_hour=1,
                   val_start_idx=int(8760 * 0.6),
                   test_start_idx=int(8760 * 0.8)),
    # Splits at 10/12 and 11/12 of 8760 hourly steps. Both split points must
    # use the same total length: computing the test split from 8160 instead
    # moves it to step 7480 and leaves only 180 validation steps.
    'AIR_GZ': dict(num_vertices=41, points_per_hour=1,
                   val_start_idx=int(8760 * 10 / 12),
                   test_start_idx=int(8760 * 11 / 12)),
}


def get_dataset_settings(name):
    """Return a fresh copy of the DATASET_SETTINGS entry for dataset `name`.

    Raises:
        ValueError: if `name` is not a key of DATASET_SETTINGS.
    """
    if name not in DATASET_SETTINGS:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of {sorted(DATASET_SETTINGS)}")
    return dict(DATASET_SETTINGS[name])


def default_config(data='AIR_BJ'):
    """Build the default DiffSTG experiment configuration for dataset `data`.

    Side effects: queries the project GPU helper, may select a CUDA device via
    ``torch.cuda.set_device``, and creates the output directories.

    Args:
        data: dataset name; one of the keys of DATASET_SETTINGS.

    Returns:
        edict with top-level paths and training options plus ``data`` and
        ``model`` sub-configs.

    Raises:
        ValueError: for an unsupported dataset name. This is raised before any
            GPU query or directory creation happens.
    """
    data_spec = get_dataset_settings(data)

    config = edict()
    config.PATH_MOD = ws + '/output/model/'
    config.PATH_LOG = ws + '/output/log/'
    config.PATH_FORECAST = ws + '/output/forecast/'

    # Data config
    config.data = edict()
    config.data.name = data
    config.data.path = ws + '/data/dataset/'
    config.graph_diffusion_step = 3

    dataset_dir = config.data.path + config.data.name
    config.data.feature_file = dataset_dir + '/flow.npy'
    config.data.spatial = dataset_dir + '/adj.npy'
    config.data.num_recent = 1
    # num_vertices, points_per_hour, val_start_idx, test_start_idx
    config.data.update(data_spec)

    gpu_id = GPU().get_usefuel_gpu(max_memory=6000, condidate_gpu_id=[0])
    config.gpu_id = gpu_id
    if gpu_id is not None:
        cuda_id = GpuId2CudaId(gpu_id)
        torch.cuda.set_device(f"cuda:{cuda_id}")
    # NOTE(review): computation is pinned to the CPU even when a GPU was
    # selected above. Use torch.device('cuda' if torch.cuda.is_available()
    # else 'cpu') to actually run on it.
    device = torch.device('cpu')

    # Model config
    config.model = edict()
    config.model.T_p = 12
    config.model.T_h = 12
    config.model.V = config.data.num_vertices
    config.model.F = 1
    config.model.week_len = 7
    config.model.day_len = config.data.points_per_hour * 24
    config.model.device = device
    config.model.d_h = 32
    # cheb_k is set both here and on config.model below; TODO confirm which
    # level consumers actually read before removing either.
    config.cheb_k = 3

    # Diffusion model config
    config.model.N = 200
    config.model.sample_steps = 10
    config.model.epsilon_theta = 'GSTNet'
    config.model.is_label_condition = True
    config.model.beta_end = 0.02
    config.model.beta_schedule = 'quad'
    config.model.sample_strategy = 'ddpm'

    config.n_samples = 2
    config.model.channel_multipliers = [1, 2]
    config.model.supports_len = 6
    config.model.mask_ratio = 0.0
    config.model.cheb_k = 3

    # Training config
    config.model_name = 'DiffSTG'
    config.is_test = False
    config.epoch = 300
    config.optimizer = "adam"
    config.lr = 1e-4
    config.batch_size = 32
    config.wd = 1e-5
    config.early_stop = 10
    config.start_epoch = 0
    config.device = device
    config.logger = Logger()

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    for out_dir in (config.PATH_MOD, config.PATH_LOG, config.PATH_FORECAST):
        os.makedirs(out_dir, exist_ok=True)
    return config

In [3]:
config = default_config("PEMS08")

clean_data = CleanDataset(config)
# (start, end) index range of the training split handed to TrafficDataset;
# the exact bound semantics are defined by TrafficDataset (TODO confirm).
train_dataset = TrafficDataset(
    clean_data,
    (config.model.T_p, config.data.val_start_idx - config.model.T_p + 1),
    config,
)
# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so only
# request it when the model actually runs on CUDA.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    config.batch_size,
    shuffle=True,
    pin_memory=config.device.type == 'cuda',
)
config.model.A = clean_data.adj

nvidia-smi > c:\Users\cyjun\OneDrive\바탕 화면\GSLDiff/output/gpustat//gpustat.txt
wrong in load gpu info dict [Errno 2] No such file or directory: 'c:\\Users\\cyjun\\OneDrive\\바탕 화면\\GSLDiff/output/gpustat//gpustat.txt'
None gpu is avalible, try again later
sample num: 10690


In [4]:
from algorithm.diffstg2.ugnet import UGnet

In [5]:
def gather(consts: torch.Tensor, t: torch.Tensor):
    """Select ``consts`` at indices ``t`` along the last axis, shaped for 4-D broadcasting.

    Args:
        consts: table of per-step values, indexed along its last dimension.
        t: integer index tensor (for example one diffusion step per sample).

    Returns:
        The selected values flattened to shape (numel, 1, 1, 1), so they
        broadcast against 4-D tensors.
    """
    return torch.gather(consts, dim=-1, index=t).reshape(-1, 1, 1, 1)

def q_xt_x0(x0: torch.Tensor, t: torch.Tensor, eps=None, alphas_cumprod=None):  # forward diffusion process
    r"""Sample x_t from the forward diffusion process q(x_t | x_0).

    q(x_t | x_0) = N(x_t; \sqrt{\bar\alpha_t} x_0, (1 - \bar\alpha_t) I)

    Args:
        x0: clean input with the batch on dim 0.
        t: long tensor of diffusion steps, one per sample in ``x0``.
        eps: noise to add. When None, a fresh standard-normal draw with the
            shape of ``x0`` is used.
        alphas_cumprod: table of cumulative products \bar\alpha. When None, the
            notebook-level ``alpha_bar`` is used (looked up at call time).

    Returns:
        sqrt(\bar\alpha_t) * x0 + sqrt(1 - \bar\alpha_t) * eps
    """
    if eps is None:
        eps = torch.randn_like(x0)
    if alphas_cumprod is None:
        alphas_cumprod = alpha_bar

    # Look up \bar\alpha_t once; gather reshapes it to (-1, 1, 1, 1) so it
    # broadcasts over x0.
    alpha_bar_t = gather(alphas_cumprod, t)
    mean = alpha_bar_t ** 0.5 * x0
    var = 1 - alpha_bar_t

    return mean + eps * (var ** 0.5)

# 'quad' noise schedule: betas evenly spaced in sqrt-space between beta_start
# and config.model.beta_end over config.model.N steps, then squared.
assert config.model.beta_schedule == 'quad', (
    f"this cell only implements the 'quad' schedule, got {config.model.beta_schedule!r}")
beta_start = config.model.get('beta_start', 0.001)
beta = (torch.linspace(beta_start ** 0.5, config.model.beta_end ** 0.5, config.model.N) ** 2).to(config.device)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)  # \bar\alpha_t = prod_{s <= t} alpha_s

In [6]:
# Noise-prediction network. Put it on the same device as the inputs built in
# the next cell so the forward pass never mixes devices.
eps_model = UGnet(config.model).to(config.device)

In [7]:
# Draw batches up to index DEBUG_BATCH_IDX and keep the last one's tensors
# (x0, x_masked, eps, t, xt, pos_w, pos_d) for the shape checks below.
DEBUG_BATCH_IDX = 2

for i, batch in enumerate(train_loader):
    future, history, pos_w, pos_d = batch
    # Clean sequence: history followed by the target horizon.
    x0 = torch.cat((history, future), dim=1).to(config.device)  # (B, T, V, F)
    # Zero history entries with probability int(mask_ratio * 100) / 100
    # (never when mask_ratio == 0). x0 is unaffected: torch.cat copied history.
    mask = torch.randint_like(history, low=0, high=100) < int(config.model.mask_ratio * 100)
    history[mask] = 0
    # Conditioning input: masked history with the future replaced by zeros.
    x_masked = torch.cat((history, torch.zeros_like(future)), dim=1).to(config.device)  # (B, T, V, F)
    # Keep every model input on the same device as x0 / x_masked.
    pos_w = pos_w.to(config.device)
    pos_d = pos_d.to(config.device)

    x0 = x0.transpose(1, 3)  # (B, F, V, T)
    x_masked = x_masked.transpose(1, 3)  # (B, F, V, T)
    eps = torch.randn_like(x0)
    t = torch.randint(0, config.model.N, (x0.shape[0],), device=x0.device, dtype=torch.long)
    xt = q_xt_x0(x0, t, eps)

    if i == DEBUG_BATCH_IDX:
        break


In [8]:
t  # diffusion steps sampled for the last batch drawn in the loop above

tensor([156, 179, 147, 148, 115,  13, 115, 191,   2, 127,  61,  58,  83, 152,
         13,  93, 195,  91,  72, 103,  83,  52,   7,  51,  88,  24,  75,  75,
        111, 144,  12, 195])

In [9]:
print(t.size())         # sampled diffusion step per sample, shape (B,)
print(eps.size())       # Gaussian noise mixed into x0: the target eps_model should predict
print(xt.size())        # noised sample x_t drawn from q(x_t | x_0), the model input
print(x_masked.size())  # conditioning input: masked history + zeroed future

# Model input, noise target and condition must all agree on (B, F, V, T).
assert eps.shape == xt.shape == x_masked.shape
assert t.shape == (xt.shape[0],)

torch.Size([32])
torch.Size([32, 1, 170, 24])
torch.Size([32, 1, 170, 24])
torch.Size([32, 1, 170, 24])


In [10]:
# Conditioning tuple passed as the model's third argument.
model_condition = (x_masked, pos_w, pos_d)
eps_theta = eps_model(xt, t, model_condition)

start DownSample
torch.Size([32, 64, 170, 24])
done middle
torch.Size([32, 128, 170, 24])
torch.Size([32, 128, 170, 24])
torch.Size([32, 96, 170, 24])
start UpSample
torch.Size([32, 64, 170, 48])
torch.Size([32, 64, 170, 48])
torch.Size([32, 64, 170, 48])


In [11]:
eps_theta.shape  # compare with eps.size() printed above

torch.Size([32, 1, 170, 24])