In [1]:
import argparse

# Run configuration. It is kept as a plain dict (`args_dict`) and also exposed as
# an attribute-style namespace (`args`), so later cells can write `args.lr` etc.
args_dict = dict(
    dataset_dir="../pusht_cchi_v7_replay.zarr.zip",  # dataset archive handed to get_dataloader
    save_dir="./checkpoints/",                       # root folder for training checkpoints
    batch_size=64,
    lr=1e-4,
    num_epochs=1000,
    diffusion_iters=1000,  # used as DDPMScheduler's num_train_timesteps
)

args = argparse.Namespace(**args_dict)

In [2]:
from diffusers import get_scheduler
from diffusers import DDPMScheduler
import os
import torch
from pushTImageDataset import get_dataloader
from train_ddp import get_nets

# Horizon lengths (in timesteps) shared by the network factory and the dataloader.
pred_horizon = 16
obs_horizon = 2
action_horizon = 8

# Seed torch's global RNG before building the networks, so that torch-side
# randomness (weight init, and any torch-driven sampling in the dataloader) is reproducible.
seed = 0
torch.manual_seed(seed)

nets, ema = get_nets(obs_horizon=obs_horizon)

dataloader = get_dataloader(
    dataset_path=args.dataset_dir,
    pred_horizon=pred_horizon,
    obs_horizon=obs_horizon,
    action_horizon=action_horizon,
    batch_size=args.batch_size,
)

optimizer = torch.optim.AdamW(params=nets.parameters(), lr=args.lr, weight_decay=1e-6)

# The training loop steps the LR scheduler once per batch, so the schedule
# spans len(dataloader) * num_epochs steps. The first `num_warmup_steps` of those are warmup.
num_warmup_steps = 500
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=len(dataloader) * args.num_epochs,
)

noise_scheduler = DDPMScheduler(
    num_train_timesteps=args.diffusion_iters,
    beta_schedule='squaredcos_cap_v2',
    clip_sample=True,
    prediction_type='epsilon',
)

save_dir = args.save_dir
# exist_ok=True avoids the race between an exists() check and makedirs().
os.makedirs(save_dir, exist_ok=True)


In [3]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

def _unwrap_parallel(module):
    """Return the inner module if `module` is a (Distributed)DataParallel wrapper, else `module`.

    Saving the inner module's state_dict keeps checkpoint keys free of the
    ``module.`` prefix that the wrapper adds, so checkpoints have the same layout
    regardless of how many GPUs were used for training.
    """
    if isinstance(module, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return module.module
    return module


def train_loop(nets, dataloader, optimizer, lr_scheduler, ema, noise_scheduler, num_epochs, save_directory,
               save_every=10):
    """Train the noise-prediction network with an epsilon-prediction MSE objective.

    For each batch:
    - Images are encoded by ``nets['vision_encoder']``.
    - The encoded features are concatenated with the agent position and flattened into a global conditioning vector.
    - ``nets['noise_pred_net']`` is trained to predict the Gaussian noise that
      ``noise_scheduler.add_noise`` mixed into the ground-truth actions at a random timestep.

    Args:
        nets: Mapping holding ``'vision_encoder'`` and ``'noise_pred_net'`` modules.
        dataloader: Iterable of batch dicts with ``'image'``, ``'agent_pos'`` and
            ``'action'`` tensors. The image is assumed to be (B, T_obs, ...) and the
            agent position (B, T_obs, D) — TODO confirm against the dataset.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: LR scheduler stepped once per batch, after the optimizer.
        ema: Currently unused; EMA weights are NOT updated (see TODO below).
        noise_scheduler: Diffusion scheduler exposing ``config.num_train_timesteps``
            and ``add_noise(samples, noise, timesteps)``.
        num_epochs: Number of passes over ``dataloader``.
        save_directory: Root directory. Each checkpoint is written to
            ``<save_directory>/<timestamp>-<epoch_idx>/model.pth``.
        save_every: Save a checkpoint every ``save_every`` epochs. A falsy value
            disables periodic saves. The final epoch is always checkpointed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vision_encoder = nets['vision_encoder'].to(device)
    noise_pred_net = nets['noise_pred_net'].to(device)

    if torch.cuda.device_count() > 1:
        vision_encoder = nn.DataParallel(vision_encoder)
        noise_pred_net = nn.DataParallel(noise_pred_net)

    for epoch_idx in range(num_epochs):
        epoch_loss = []
        # NOTE(review): the encoder runs in eval mode while noise_pred_net trains.
        # Its parameters are still updated if they are in `optimizer` — confirm intended.
        vision_encoder.eval()
        noise_pred_net.train()

        for nbatch in tqdm(dataloader, desc=f'Epoch {epoch_idx + 1}/{num_epochs}'):
            nimage = nbatch['image'].to(device)
            nagent_pos = nbatch['agent_pos'].to(device)
            naction = nbatch['action'].to(device)
            B = nagent_pos.shape[0]

            # Encode every frame independently, then restore the leading two dims.
            image_features = vision_encoder(nimage.flatten(end_dim=1))
            image_features = image_features.reshape(*nimage.shape[:2], -1)

            obs_features = torch.cat([image_features, nagent_pos], dim=-1)
            obs_cond = obs_features.flatten(start_dim=1)

            noise = torch.randn(naction.shape, device=device)
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (B,), device=device).long()

            noisy_actions = noise_scheduler.add_noise(naction, noise, timesteps)
            noise_pred = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)

            loss = nn.functional.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            # TODO: `ema` is accepted but never updated. Add the EMA step here once
            # the EMA object's API (as returned by get_nets) is confirmed.

            epoch_loss.append(loss.item())

        # Guard against an empty dataloader: np.mean([]) warns and returns nan.
        avg_loss = float(np.mean(epoch_loss)) if epoch_loss else float('nan')
        print(f'Epoch {epoch_idx+1}/{num_epochs} - Loss: {avg_loss}')

        is_periodic = bool(save_every) and (epoch_idx + 1) % save_every == 0
        is_last_epoch = epoch_idx + 1 == num_epochs
        if is_periodic or is_last_epoch:
            time_now = time.strftime("%Y%m%d-%H%M%S") + "-" + str(epoch_idx)
            save_directory_epoch = os.path.join(save_directory, time_now)
            os.makedirs(save_directory_epoch, exist_ok=True)

            # Save the unwrapped modules so keys never carry a DataParallel "module." prefix.
            torch.save({
                'vision_encoder': _unwrap_parallel(vision_encoder).state_dict(),
                'noise_pred_net': _unwrap_parallel(noise_pred_net).state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': lr_scheduler.state_dict()
            }, os.path.join(save_directory_epoch, 'model.pth'))

            print(f"Saved model to {save_directory_epoch}, loss: {avg_loss}")

In [None]:
# Launch training with the objects built in the setup cell.
train_loop(
    nets=nets,
    dataloader=dataloader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    ema=ema,
    noise_scheduler=noise_scheduler,
    num_epochs=args.num_epochs,
    save_directory=args.save_dir,
)