In [None]:
import argparse
from diffusion_1d import Unet1D, TemporalUnet1D, GaussianDiffusion1D, Trainer1D
from filepath import EXP_PATH
import pprint as pp
import torch
from utils import Printer, make_dir
p = Printer(n_digits=6)
import os
import logging
import datetime
import sys
# Command-line flags for the 1-D diffusion training run. When executed inside
# Jupyter, the main cell parses an empty argv, so these defaults are used.
parser = argparse.ArgumentParser(description='Train 1D diffusion model')
parser.add_argument('--exp_id', default='inv_design', type=str,
                    help='experiment folder id')
parser.add_argument('--date_time', default='2023-04-23', type=str,
                    help='date for the experiment folder '
                         '(currently overwritten at runtime by the current timestamp)')
parser.add_argument('--dataset', default='nbody-4', type=str,
                    help='dataset to evaluate')
parser.add_argument('--model_type', default='temporal-unet1d', type=str,
                    help='model type.')
parser.add_argument('--batch_size', default=32, type=int,
                    help='size of batch of input to use')
parser.add_argument('--conditioned_steps', default=4, type=int,
                    help='conditioned steps')
parser.add_argument('--rollout_steps', default=20, type=int,
                    help='rollout steps')
parser.add_argument('--time_interval', default=4, type=int,
                    help='time interval')

In [None]:

if __name__ == "__main__":
    # ---- Parse flags (Jupyter vs. plain script) -----------------------------
    # `get_ipython` is a builtin only when running under IPython/Jupyter; in a
    # plain `python script.py` run it is undefined and raises NameError.
    try:
        ipython = get_ipython()  # noqa: F821
    except NameError:
        ipython = None
    is_jupyter = ipython is not None
    if is_jupyter:
        # Function form of `%matplotlib` / `%load_ext` / `%autoreload`, so the
        # file is still valid plain Python when exported and run as a script.
        ipython.run_line_magic('matplotlib', 'inline')
        ipython.run_line_magic('load_ext', 'autoreload')
        ipython.run_line_magic('autoreload', '2')
        # The kernel process's argv is not this script's CLI: use the defaults.
        FLAGS = parser.parse_args([])
    else:
        FLAGS = parser.parse_args()

    # The --date_time flag is always replaced by the current timestamp, which
    # makes each run's experiment folder name unique (to the second).
    current_datetime = datetime.datetime.now()
    formatted_datetime = current_datetime.strftime('%Y-%m-%d_%H-%M-%S')
    FLAGS.date_time = formatted_datetime
    pp.pprint(FLAGS.__dict__)

    # Dataset names look like "nbody-<N>". Use int() rather than eval() so a
    # malformed or hostile --dataset value cannot execute arbitrary code.
    n_bodies = int(FLAGS.dataset.split("-")[1])

    # ---- Denoising network --------------------------------------------------
    if FLAGS.model_type == "unet1d":
        model = Unet1D(
            dim=64,
            dim_mults=(1, 2, 4, 8),
            channels=FLAGS.conditioned_steps + FLAGS.rollout_steps,
        )
    elif FLAGS.model_type == "temporal-unet1d":
        model = TemporalUnet1D(
            horizon=FLAGS.conditioned_steps + FLAGS.rollout_steps,
            # presumably 4 state features per body — TODO confirm against the dataset
            transition_dim=n_bodies * 4,
            cond_dim=False,
            dim=64,
            dim_mults=(1, 2, 4, 8),
            attention=True,
        )
    else:
        # A bare `raise` here (no active exception) would only produce
        # "RuntimeError: No active exception to reraise".
        raise ValueError(
            f"Unknown --model_type {FLAGS.model_type!r}; "
            "expected 'unet1d' or 'temporal-unet1d'"
        )

    # ---- Diffusion wrapper --------------------------------------------------
    diffusion = GaussianDiffusion1D(
        model,
        image_size=n_bodies * 4,
        timesteps=1000,           # number of diffusion steps
        sampling_timesteps=250,   # fewer sampling steps (DDIM) for faster inference
        loss_type='l1',           # 'l1' or 'l2'
    )

    # ---- Output folder & training -------------------------------------------
    exp_dirname = f"{FLAGS.exp_id}_{FLAGS.date_time}/"
    results_folder = EXP_PATH + exp_dirname + f"1D_dataset_{FLAGS.dataset}_cond_{FLAGS.conditioned_steps}_roll_{FLAGS.rollout_steps}_itv_{FLAGS.time_interval}"
    make_dir(results_folder + "/test")
    trainer = Trainer1D(
        diffusion,
        FLAGS.dataset,
        train_batch_size=FLAGS.batch_size,
        train_lr=8e-5,
        # NOTE(review): 2 total steps and saving every step look like a
        # smoke-test configuration, not a real training run — confirm.
        train_num_steps=2,                # total training steps
        save_and_sample_every=1,          # save model every such steps
        gradient_accumulate_every=2,      # gradient accumulation steps
        ema_decay=0.995,                  # exponential moving average decay
        amp=True,                         # turn on mixed precision
        calculate_fid=False,              # whether to calculate fid during training
        conditioned_steps=FLAGS.conditioned_steps,
        rollout_steps=FLAGS.rollout_steps,
        time_interval=FLAGS.time_interval,
        results_folder=results_folder,
    )

    trainer.train()


In [None]:

import numpy as np
import matplotlib.pyplot as plt
# Plot a previously saved training-loss list, save the figure, and re-save the
# array under a new name.
# NOTE(review): the paths are hard-coded and date-stamped, and '/usert/...' is
# an absolute path (possibly a typo for a user directory) — confirm before re-running.
numpy_data = np.load("resultsloss_list_0808_temporal_unet1d_attention.npy")
# One x value per recorded entry: 0, 1, ..., N-1. The previous
# np.linspace(0, N, N) spaced points N/(N-1) apart, so x did not match indices.
x = np.arange(len(numpy_data))
# Use an explicit new figure rather than whatever pyplot figure is current.
fig, ax = plt.subplots()
ax.plot(x, numpy_data)
ax.set_xlabel('step')   # assumes one loss entry per logged training step — TODO confirm
ax.set_ylabel('loss')
ax.set_title('loss_list')
ax.grid(True)
fig.savefig('/usert/inverse_design/resultsloss_plot_0809_temporal-unet1d.png')
np.save('/usert/inverse_design/resultsloss_list_0809_temporal-unet1d.npy', numpy_data)

# Build a standalone TemporalUnet1D with the same architecture the training
# cell uses for model_type="temporal-unet1d" under the default flags, and
# display its repr.
model = TemporalUnet1D(
    horizon=4 + 20,         # conditioned_steps + rollout_steps (flag defaults)
    transition_dim=4 * 4,   # n_bodies * 4 for the default 'nbody-4' dataset
    cond_dim=False,
    dim=64,
    dim_mults=(1, 2, 4, 8),
    attention=True,
)
model