In [59]:
import os
import torch
import argparse
import numpy as np
import torch.utils.data
from easydict import EasyDict as edict
from timeit import default_timer as timer

from utils.eval import Metric
from utils.gpu_dispatch import GPU
from utils.common_utils import dir_check, to_device, ws, unfold_dict, dict_merge, GpuId2CudaId, Logger

from algorithm.dataset import CleanDataset, TrafficDataset
from algorithm.diffGSL.model import DiffSTG
from typing import Tuple, Optional

In [3]:
def default_config(data='AIR_BJ'):
    config = edict()
    config.PATH_MOD = ws + '/output/model/'
    config.PATH_LOG = ws + '/output/log/'
    config.PATH_FORECAST = ws + '/output/forecast/'

    # Data Config
    config.data = edict()
    config.data.name = data
    config.data.path = ws + '/data/dataset/'
    config.graph_diffusion_step = 3


    config.data.feature_file = config.data.path + config.data.name + '/flow.npy'  # Add this line
    config.data.spatial = config.data.path + config.data.name + '/adj.npy'
    config.data.num_recent = 1

    # Data settings for different datasets
    if config.data.name == 'PEMS08':
        config.data.num_vertices = 170
        config.data.points_per_hour = 12
        config.data.val_start_idx = int(17856 * 0.6)
        config.data.test_start_idx = int(17856 * 0.8)

    if config.data.name == "AIR_BJ":
        config.data.num_vertices = 34
        config.data.points_per_hour = 1
        config.data.val_start_idx = int(8760 * 0.6)
        config.data.test_start_idx = int(8760 * 0.8)

    if config.data.name == 'AIR_GZ':
        config.data.num_vertices = 41
        config.data.points_per_hour = 1
        config.data.val_start_idx = int(8760 * 10 / 12)
        config.data.test_start_idx = int(8160 * 11 / 12)

    gpu_id = GPU().get_usefuel_gpu(max_memory=6000, condidate_gpu_id=[0])
    config.gpu_id = gpu_id
    if gpu_id != None:
        cuda_id = GpuId2CudaId(gpu_id)
        torch.cuda.set_device(f"cuda:{cuda_id}")
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cpu')
    # print(device)

    # Model config
    config.model = edict()
    config.model.T_p = 12
    config.model.T_h = 12
    config.model.V = config.data.num_vertices
    config.model.F = 1
    config.model.week_len = 7
    config.model.day_len = config.data.points_per_hour * 24
    config.model.device = device
    config.model.d_h = 32
    config.cheb_k = 3

    # Diffusion model config
    config.model.N = 200
    config.model.sample_steps = 10
    config.model.epsilon_theta = 'GSTNet'
    config.model.is_label_condition = True
    config.model.beta_end = 0.02
    config.model.beta_schedule = 'quad'
    config.model.sample_strategy = 'ddpm'

    config.n_samples = 2
    config.model.channel_multipliers = [1, 2]
    config.model.supports_len = 2

    # Training config
    config.model_name = 'DiffSTG'
    config.is_test = False
    config.epoch = 300
    config.optimizer = "adam"
    config.lr = 1e-4
    config.batch_size = 32
    config.wd = 1e-5
    config.early_stop = 10
    config.start_epoch = 0
    config.device = device
    config.logger = Logger()

    if not os.path.exists(config.PATH_MOD):
        os.makedirs(config.PATH_MOD)
    if not os.path.exists(config.PATH_LOG):
        os.makedirs(config.PATH_LOG)
    if not os.path.exists(config.PATH_FORECAST):
        os.makedirs(config.PATH_FORECAST)
    return config

In [4]:
config = default_config("AIR_BJ")

clean_data = CleanDataset(config)
train_dataset = TrafficDataset(clean_data, (0 + config.model.T_p, config.data.val_start_idx - config.model.T_p + 1), config)
train_loader = torch.utils.data.DataLoader(train_dataset, config.batch_size, shuffle=True, pin_memory=True)

nvidia-smi > /Users/jcy/Desktop/DiffSTG-main/output/gpustat//gpustat.txt
wrong in load gpu info dict list index out of range
None gpu is avalible, try again later
sample num: 5233


sh: nvidia-smi: command not found


In [5]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x167f54850>

In [6]:
for i, batch in enumerate(train_loader):
    if i > 3 and config.is_test:break
    time_start =  timer()
    future, history, pos_w, pos_d = batch # future:(B, T_p, V, F), history: (B, T_h, V, F)

In [9]:
print(future.size())
print(history.size())

torch.Size([17, 12, 34, 1])
torch.Size([17, 12, 34, 1])


In [10]:
# get x0
x = torch.cat((history, future), dim=1).to(config.device) #  (B, T, V, F)
print(x.size())

torch.Size([17, 24, 34, 1])


In [12]:
mask =  torch.randint_like(history, low=0, high=100) < int(0.01 * 100)# mask the history in a ratio with mask_ratio
history[mask] = 0
x_masked = torch.cat((history, torch.zeros_like(future)), dim=1).to(config.device) # (B, T, V, F)

In [17]:
torch.sum(x_masked != 0)

tensor(6870)

In [18]:
# reshape
x = x.transpose(1,3) # (B, F, V, T)
x_masked = x_masked.transpose(1,3) # (B, F, V, T)

In [20]:
print(x.size())
print(x_masked.size())

torch.Size([17, 1, 34, 24])
torch.Size([17, 1, 34, 24])


In [60]:
# loss calculate

# loss = 10 * model.loss(x, (x_masked, pos_w.to(config.device), pos_d.to(config.device))) 
# x : xo , c = c: The condition, c is a tuple of torch tensor, here c = (feature, pos_w, pos_d)

In [None]:
'''
def loss(self, x0: torch.Tensor, c: Tuple):
    """
    Loss calculation
    x0: (B, ...)
    c: The condition, c is a tuple of torch tensor, here c = (feature, pos_w, pos_d)
    """
    #
    t = torch.randint(0, self.N, (x0.shape[0],), device=x0.device, dtype=torch.long)

    # Note that in the paper, t \in [1, T], but in the code, t \in [0, T-1]
    eps = torch.randn_like(x0)

    xt = self.q_xt_x0(x0, t, eps)

    # rint("eps_model_device:",self.eps_model.device)

    x, y_cov = c

    eps_theta = self.eps_model(xt, y_cov)
    loss = nn.MSELoss()

    # return F.mse_loss(eps, eps_theta)
    return loss(eps, eps_theta)
'''

In [37]:
t = torch.randint(0, 12, (x.shape[0],), device=x.device, dtype=torch.long)

In [38]:
t.size() # batch size 

torch.Size([17])

In [24]:
eps = torch.randn_like(x)

In [25]:
eps.size()

torch.Size([17, 1, 34, 24])

In [52]:
def gather(consts: torch.Tensor, t: torch.Tensor): # 이런 상수 텐서는 주로 모델의 계산 과정에서 사용되며, 특정 시간 스텝에 따라 필요한 값들을 가져올 때 사용
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)

In [39]:
beta = torch.linspace(0.0001 ** 0.5, 0.02 ** 0.5, 12) ** 2
beta = beta.to(config.device)

In [40]:
beta.size()

torch.Size([12])

In [61]:
beta = torch.linspace(0.0001 ** 0.5, 0.02 ** 0.5, 12) ** 2
beta = beta.to(config.device)
alpha = 1.0 - beta

alpha_bar = torch.cumprod(alpha, dim=0) # 누적 곱은 각 시간 스텝까지의 누적된 확률이나 가중치를 계산

In [42]:
alpha_bar

tensor([0.9999, 0.9994, 0.9983, 0.9962, 0.9928, 0.9880, 0.9814, 0.9728, 0.9620,
        0.9487, 0.9328, 0.9141])

In [43]:
def q_xt_x0(x0: torch.Tensor, t: torch.Tensor, eps): # forward diffusion process
    """
    Sample from  q(x_t|x_0) ~ N(x_t; \sqrt\bar\alpha_t * x_0, (1 - \bar\alpha_t)I) 
    """
    if eps is None:
        eps = torch.randn_like(x0)

    mean = gather(alpha_bar, t) ** 0.5 * x0
    var = 1 - gather(alpha_bar, t)

    return mean + eps * (var ** 0.5)

In [46]:
mean = gather(alpha_bar, t) ** 0.5 * x

In [47]:
mean.size()

torch.Size([17, 1, 34, 24])

In [51]:
t.size()

torch.Size([17])

In [53]:
gather(alpha_bar, t).size()

torch.Size([17, 1, 1, 1])

In [57]:
alpha_bar.size()

torch.Size([12])

In [58]:
t.size()

torch.Size([17])

In [54]:
c = alpha_bar.gather(-1, t)

In [55]:
c.size()

torch.Size([17])

In [48]:
var = 1 - gather(alpha_bar, t)

In [49]:
var.size()

torch.Size([17, 1, 1, 1])

In [44]:
xt = q_xt_x0(x, t, eps)

In [45]:
xt.size()

torch.Size([17, 1, 34, 24])