In [1]:
import os
import torch
import argparse
import numpy as np
import torch.utils.data
from easydict import EasyDict as edict
from timeit import default_timer as timer

from utils.eval import Metric
from utils.gpu_dispatch import GPU
from utils.common_utils import dir_check, to_device, ws, unfold_dict, dict_merge, GpuId2CudaId, Logger

from algorithm.dataset import CleanDataset, TrafficDataset
from algorithm.diffGSL.model import DiffSTG
from typing import Tuple, Optional
def default_config(data='AIR_BJ'):
    config = edict()
    config.PATH_MOD = ws + '/output/model/'
    config.PATH_LOG = ws + '/output/log/'
    config.PATH_FORECAST = ws + '/output/forecast/'

    # Data Config
    config.data = edict()
    config.data.name = data
    config.data.path = ws + '/data/dataset/'
    config.graph_diffusion_step = 3


    config.data.feature_file = config.data.path + config.data.name + '/flow.npy'  # Add this line
    config.data.spatial = config.data.path + config.data.name + '/adj.npy'
    config.data.num_recent = 1

    # Data settings for different datasets
    if config.data.name == 'PEMS08':
        config.data.num_vertices = 170
        config.data.points_per_hour = 12
        config.data.val_start_idx = int(17856 * 0.6)
        config.data.test_start_idx = int(17856 * 0.8)

    if config.data.name == "AIR_BJ":
        config.data.num_vertices = 34
        config.data.points_per_hour = 1
        config.data.val_start_idx = int(8760 * 0.6)
        config.data.test_start_idx = int(8760 * 0.8)

    if config.data.name == 'AIR_GZ':
        config.data.num_vertices = 41
        config.data.points_per_hour = 1
        config.data.val_start_idx = int(8760 * 10 / 12)
        config.data.test_start_idx = int(8160 * 11 / 12)

    gpu_id = GPU().get_usefuel_gpu(max_memory=6000, condidate_gpu_id=[0])
    config.gpu_id = gpu_id
    if gpu_id != None:
        cuda_id = GpuId2CudaId(gpu_id)
        torch.cuda.set_device(f"cuda:{cuda_id}")
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('mps')
    # print(device)

    # Model config
    config.model = edict()
    config.model.T_p = 12
    config.model.T_h = 12
    config.model.V = config.data.num_vertices
    config.model.F = 1
    config.model.week_len = 7
    config.model.day_len = config.data.points_per_hour * 24
    config.model.device = device
    config.model.d_h = 32
    config.model.cheb_k = 3
    config.model.graph_diffusion_step = 3
    config.model.num_layers = 4

    # Diffusion model config
    config.model.N = 200
    config.model.sample_steps = 10
    config.model.epsilon_theta = 'GSTNet'
    config.model.is_label_condition = True
    config.model.beta_end = 0.02
    config.model.beta_schedule = 'quad'
    config.model.sample_strategy = 'ddpm'

    config.n_samples = 2
    config.model.channel_multipliers = [1, 2]
    config.model.supports_len = 2

    # Training config
    config.model_name = 'DiffSTG'
    config.is_test = False
    config.epoch = 50
    config.optimizer = "adam"
    config.lr = 1e-4
    config.batch_size = 32
    config.wd = 1e-5
    config.early_stop = 10
    config.start_epoch = 0
    config.device = device
    config.logger = Logger()

    if not os.path.exists(config.PATH_MOD):
        os.makedirs(config.PATH_MOD)
    if not os.path.exists(config.PATH_LOG):
        os.makedirs(config.PATH_LOG)
    if not os.path.exists(config.PATH_FORECAST):
        os.makedirs(config.PATH_FORECAST)
    return config

In [2]:
from algorithm.dataset import CleanDataset, TrafficDataset
from algorithm.diffstg.ugnet import UGnet
config = default_config("PEMS08")

clean_data = CleanDataset(config)
config.model.A = clean_data.adj
train_dataset = TrafficDataset(clean_data, (0 + config.model.T_p, config.data.val_start_idx - config.model.T_p + 1), config)
train_loader = torch.utils.data.DataLoader(train_dataset, config.batch_size, shuffle=True, pin_memory=True)

nvidia-smi > /Users/jcy/Desktop/DiffSTG-main/output/gpustat//gpustat.txt
wrong in load gpu info dict list index out of range
None gpu is avalible, try again later
sample num: 10690


sh: nvidia-smi: command not found


In [3]:
beta_start = 0.002
beta_end = 0.01
N = 200
beta = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, N) ** 2
beta = beta.to(config.device)

alpha = 1.0 - beta

alpha_bar = torch.cumprod(alpha, dim=0)

def gather(consts: torch.Tensor, t: torch.Tensor):
    c = consts.gather(-1, t)
    return c.reshape(-1, 1, 1, 1)

def q_xt_x0(x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor]=None):

    if eps is None:
        eps = torch.randn_like(x0)

    mean = gather(alpha_bar, t) ** 0.5 * x0
    var = 1 - gather(alpha_bar, t)

    return mean + eps * (var ** 0.5)

In [4]:
config.device

device(type='mps')

In [5]:
import torch.nn.functional as F

In [6]:
ugnet = UGnet(config.model)

In [7]:
#graphnet = GraphNet(config.model, rnn_num_units=32, input_dim=1, y_cov_dim=1,out_dim=1)
total_params = sum(p.numel() for p in ugnet.parameters())
print(f"total parameters of the model: {total_params}")
print()
for name, param in ugnet.named_parameters():
    print(name, param.shape)

total parameters of the model: 996793

down.0.res.tcn1.conv.weight torch.Size([32, 32, 3, 3])
down.0.res.tcn1.conv.bias torch.Size([32])
down.0.res.tcn2.conv.weight torch.Size([32, 32, 3, 3])
down.0.res.tcn2.conv.bias torch.Size([32])
down.0.res.t_conv.weight torch.Size([32, 32, 1, 1])
down.0.res.t_conv.bias torch.Size([32])
down.0.res.spatial.theta torch.Size([32, 32, 2])
down.0.res.spatial.b torch.Size([1, 32, 1, 1])
down.0.res.norm.weight torch.Size([170, 32])
down.0.res.norm.bias torch.Size([170, 32])
down.1.res.tcn1.conv.weight torch.Size([32, 32, 3, 3])
down.1.res.tcn1.conv.bias torch.Size([32])
down.1.res.tcn2.conv.weight torch.Size([32, 32, 3, 3])
down.1.res.tcn2.conv.bias torch.Size([32])
down.1.res.t_conv.weight torch.Size([32, 32, 1, 1])
down.1.res.t_conv.bias torch.Size([32])
down.1.res.spatial.theta torch.Size([32, 32, 2])
down.1.res.spatial.b torch.Size([1, 32, 1, 1])
down.1.res.norm.weight torch.Size([170, 32])
down.1.res.norm.bias torch.Size([170, 32])
down.2.conv.weigh

In [8]:
print(config.device)

mps


In [9]:
optimizer = torch.optim.Adam(ugnet.parameters(), lr=0.001, weight_decay=0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

ugnet = ugnet.to(config.device)


for epoch in range(config.epoch):
    n, avg_loss, time_lst = 0, 0, []
    for i, batch in enumerate(train_loader):
        print('start')
        # if i > 3 and config.is_test:break
        time_start =  timer()
        future, history, pos_w, pos_d = batch # future:(B, T_p, V, F), history: (B, T_h, V, F)

        x0 = torch.cat((history, future), dim=1).to(config.device) #  (B, T, V, F)

        # get x0_masked
        mask =  torch.randint_like(history, low=0, high=100) < int(0.0 * 100)# mask the history in a ratio with mask_ratio
        history[mask] = 0
        x_masked = torch.cat((history, torch.zeros_like(future)), dim=1).to(config.device) # (B, T, V, F)

        # reshape
        x0 = x0.transpose(1,3) # (B, F, V, T)
        x_masked = x_masked.transpose(1,3) # (B, F, V, T)

        t = torch.randint(0, N, (x0.shape[0],), device=config.device, dtype=torch.long)
        eps = torch.rand_like(x0)
        xt = q_xt_x0(x_masked, t, eps)
        print('start')
        eps_theta = ugnet(xt, t, (x_masked, pos_w.to(config.device), pos_d.to(config.device)))

        loss = 10 * F.mse_loss(eps_theta, eps)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n += 1
        avg_loss = avg_loss * (n - 1) / n + loss.item() / n

        time_lst.append((timer() - time_start))
        message = f"{i / len(train_loader) + epoch:6.1f}| {avg_loss:0.3f} {np.sum(time_lst):.1f}s"
        print('\r' + message, end='', flush=True)

start
start


: 

In [None]:
!free -mh

zsh:1: command not found: free
