In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/TrajectoryDiffusion

In [1]:
from Dataset.DidiDataset import DidiTrajectoryDataset, collectFunc
from Models.TrajUNet import TrajUNet
from DiffusionManager import DiffusionManager
from Utils import MovingAverage, saveModel, loadModel, exportONNX

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datetime import datetime
from os import makedirs

from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Configs

In [2]:
dataset_args = {
    "dataset_root": "E:/Data/Didi/xian/nov",
    "traj_length": 200,
    "lat_mean": 108.95038635089452,
    "lat_std": 0.02245034359640356,
    "lon_mean": 34.242824702030525,
    "lon_std": 0.019082048008517993
}

diffusion_args = {
    "min_beta": 0.0001,
    "max_beta": 0.005,
    "max_diffusion_step": 300,
}

model_args = {
    "channel_schedule": [128, 128, 256, 512, 1024],
    "diffusion_steps": diffusion_args["max_diffusion_step"],
    "res_blocks": 2,
}


init_lr = 1e-3

# Colab can have 51GB RAM or 12.7GB RAM, GPU is Tesla T4 which has 15GB RAM
files_per_part = 2
batch_size = 32
epochs = 100
log_interval = 100
save_interval = 10000

# Prepare

In [3]:
dataset = DidiTrajectoryDataset(**dataset_args)

In [3]:
model = TrajUNet(**model_args).cuda()
diff_manager = DiffusionManager(**diffusion_args)

In [4]:
sample_inputs = (torch.randn(1, 2, 200).cuda(), torch.tensor([0]).to(torch.long), torch.randn(1, 3).cuda())
exportONNX(model, sample_inputs, "trajunet.onnx")

  _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))


graph(%input_traj : Float(1, 2, 200, strides=[400, 200, 1], requires_grad=0, device=cuda:0),
      %time : Long(1, strides=[1], requires_grad=0, device=cpu),
      %attr : Float(1, 3, strides=[3, 1], requires_grad=0, device=cuda:0),
      %embed_block.time_embed_layers.0.weight : Float(128, 128, strides=[128, 1], requires_grad=1, device=cuda:0),
      %embed_block.time_embed_layers.0.bias : Float(128, strides=[1], requires_grad=1, device=cuda:0),
      %embed_block.time_embed_layers.2.weight : Float(128, 128, strides=[128, 1], requires_grad=1, device=cuda:0),
      %embed_block.time_embed_layers.2.bias : Float(128, strides=[1], requires_grad=1, device=cuda:0),
      %embed_block.attr_embed_layers.0.weight : Float(128, 3, strides=[3, 1], requires_grad=1, device=cuda:0),
      %embed_block.attr_embed_layers.0.bias : Float(128, strides=[1], requires_grad=1, device=cuda:0),
      %embed_block.attr_embed_layers.2.weight : Float(128, 128, strides=[128, 1], requires_grad=1, device=cuda:0),
  

In [5]:
loss_func = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=init_lr)

In [None]:
start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
makedirs(f"Runs/{start_time}")
writer = SummaryWriter(f"Runs/{start_time}")

# Train

In [6]:
global_it = 0

mov_avg_loss = MovingAverage(log_interval)

for e in range(epochs):
    n_files_load = 0
    total_num_files = dataset.n_files
    while dataset.loadNextFiles(files_per_part):
        n_files_load  = min(n_files_load + files_per_part, total_num_files)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collectFunc)
        pbar = tqdm(dataloader, desc=f'Epoch {e} File {n_files_load}/{total_num_files}')
        for traj_0, attr in pbar:
            # Diffusion forward
            t = torch.randint(0, diffusion_args["max_diffusion_step"], (traj_0.shape[0],)).cuda()
            epsilon = torch.randn_like(traj_0).cuda()
            traj_t = diff_manager.diffusionForward(traj_0, t, epsilon)

            optimizer.zero_grad()

            epsilon_pred = model(traj_t, t, attr)
            loss = loss_func(epsilon_pred, epsilon)

            loss.backward()
            optimizer.step()

            global_it += 1
            mov_avg_loss << loss.item()
            pbar.set_postfix_str(f'Loss: {mov_avg_loss:.5f}')

            if global_it % log_interval == 0:
                writer.add_scalar('Loss', mov_avg_loss, global_it)

            if global_it % save_interval == 0:
                saveModel(model, f"Runs/{start_time}/{model.__class__.__name__}_{global_it}.pth")
            

Loading E:/Data/Didi/xian/nov\gps_20161101.pt
Loading E:/Data/Didi/xian/nov\gps_20161102.pt


Epoch 0 File 2/30:   8%|â–Š         | 199/2457 [00:20<03:58,  9.48it/s, Loss: 1.00379]


KeyboardInterrupt: 