In [None]:
from Dataset.DidiDataset import DidiTrajectoryDataset, collectFunc
from Models.TrajUNet import TrajUNet
from DiffusionManager import DiffusionManager

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from tqdm import tqdm

# Configs

In [None]:
import os

# --- Data ---
# Passed to DidiTrajectoryDataset; presumably the number of points per
# trajectory (NOTE(review): confirm against Dataset/DidiDataset).
traj_length = 120
# Root directory of the Didi dataset. Can be overridden with the
# DIDI_DATASET_ROOT environment variable so the notebook is not tied to one
# machine's absolute path; defaults to the original location.
dataset_root = os.environ.get('DIDI_DATASET_ROOT', 'E:/Data/Didi/xian/nov')

# --- Model (TrajUNet) and diffusion schedule (DiffusionManager) ---
stem_channels = 32
num_blocks = 4
# Number of diffusion steps; training samples timesteps from [0, max_diffusion_step).
max_diffusion_step = 300
res_blocks = 2
# Beta range handed to DiffusionManager (presumably the endpoints of the noise
# schedule — confirm in DiffusionManager).
min_beta = 0.0001
max_beta = 0.005

# --- Optimisation ---
init_lr = 1e-3

# --- Training loop ---
files_per_part = 2   # dataset files requested per call to dataset.loadNextParts
batch_size = 32
epochs = 100

# Seed torch's RNGs (CPU and CUDA) so timestep/noise sampling and DataLoader
# shuffling are reproducible between runs.
seed = 42
torch.manual_seed(seed)

# Prepare dataset, model, and optimizer

In [None]:
# Build the trajectory dataset rooted at `dataset_root`. `traj_length` is
# presumably the fixed number of points per trajectory (confirm in
# DidiTrajectoryDataset). The training cell streams its files in chunks via
# dataset.loadNextParts(...) and reads dataset.num_files.
dataset = DidiTrajectoryDataset(dataset_root, traj_length)

In [None]:
# Network that the training loop fits to the injected noise (model output is
# compared with epsilon by the loss). Arguments are positional, so their order
# must match TrajUNet's constructor — NOTE(review): confirm against Models/TrajUNet.
model = TrajUNet(stem_channels, max_diffusion_step, num_blocks, res_blocks)
# Diffusion helper built from the beta range and the number of diffusion steps;
# the training loop uses its diffusionForward(traj_0, t, epsilon) to build traj_t.
diff_manager = DiffusionManager(min_beta, max_beta, max_diffusion_step)

In [None]:
# Adam over every model parameter, starting from the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=init_lr)
# Mean squared error between the predicted and the true noise (averaged over all elements).
loss_func = nn.MSELoss(reduction='mean')

In [None]:
def _to_device(obj, device):
    """Move tensors (possibly nested in lists/tuples/dicts) to `device`.

    Non-tensor leaves are returned unchanged, so this is safe to call on a
    batch element whose exact structure (e.g. `attr` from collectFunc) is not
    known here. Tensors already on `device` are returned as-is by `.to()`.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: _to_device(value, device) for key, value in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, '_fields'):  # namedtuple
        return type(obj)(*(_to_device(item, device) for item in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(_to_device(item, device) for item in obj)
    return obj


# Use the GPU when one is present, otherwise fall back to CPU instead of
# crashing on a hard-coded .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# nn.Module.to() converts parameters in place, so the optimizer created earlier
# still references the same Parameter objects.
model.to(device)
model.train()

for e in range(epochs):
    n_files_load = 0
    total_num_files = dataset.num_files
    # NOTE(review): assumes loadNextParts returns falsy once every file has been
    # consumed and restarts from the first file on the next call; if it does not
    # reset, every epoch after the first would be empty — confirm in DidiTrajectoryDataset.
    while dataset.loadNextParts(files_per_part):
        n_files_load = min(n_files_load + files_per_part, total_num_files)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collectFunc)
        pbar = tqdm(dataloader, desc=f'Epoch {e} File {n_files_load}/{total_num_files}')
        for traj_0, attr in pbar:
            # Put the whole batch on the same device as the model.
            traj_0 = _to_device(traj_0, device)
            attr = _to_device(attr, device)

            # Diffusion forward: one random timestep per trajectory, plus Gaussian
            # noise shaped like traj_0; diffusionForward combines them into traj_t.
            t = torch.randint(0, max_diffusion_step, (traj_0.shape[0],), device=device)
            epsilon = torch.randn_like(traj_0)
            traj_t = diff_manager.diffusionForward(traj_0, t, epsilon)

            optimizer.zero_grad()

            # The model is trained to predict the noise that was injected.
            epsilon_pred = model(traj_t, t, attr)
            loss = loss_func(epsilon_pred, epsilon)

            loss.backward()
            optimizer.step()

            # Surface training progress instead of running silently.
            pbar.set_postfix(loss=f'{loss.item():.4f}')