In [38]:
import os
import sys
import matplotlib

import matplotlib.pyplot
import matplotlib.pyplot as plt
import time
import datetime
import argparse
import numpy as np
import pandas as pd
from random import SystemRandom
from sklearn import model_selection
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.nn.functional import relu
import torch.optim as optim

import lib.utils as utils
from lib.plotting import *

from lib.rnn_baselines import *
from lib.ode_rnn import *
from lib.create_latent_ode_model import create_LatentODE_model
from lib.parse_datasets import parse_datasets
from lib.ode_func import ODEFunc, ODEFunc_w_Poisson
from lib.diffeq_solver import DiffeqSolver
from mujoco_physics import HopperPhysics
import os
import numpy as np

import torch
import torch.nn as nn

import lib.utils as utils
from lib.diffeq_solver import DiffeqSolver
from generate_timeseries import Periodic_1d
from torch.distributions import uniform

from torch.utils.data import DataLoader
from mujoco_physics import HopperPhysics
from physionet import variable_time_collate_fn, get_data_min_max
from person_activity import PersonActivity, variable_time_collate_fn_activity

from sklearn import model_selection
import random
from lib.utils import compute_loss_all_batches

In [39]:
# Generative model for noisy data based on ODE
parser = argparse.ArgumentParser('Latent ODE')
parser.add_argument('-n',  type=int, default=10000, help="Size of the dataset")
parser.add_argument('--niters', type=int, default=20)
parser.add_argument('--lr',  type=float, default=1e-2, help="Starting learning rate.")
parser.add_argument('-b', '--batch-size', type=int, default=256)
parser.add_argument('--viz', action='store_true', default=False, help="Show plots while training")

parser.add_argument('--save', type=str, default='experiments/', help="Path for save checkpoints")
parser.add_argument('--load', type=str, default=None, help="ID of the experiment to load for evaluation. If None, run a new experiment.")
parser.add_argument('-r', '--random-seed', type=int, default=1991, help="Random_seed")

parser.add_argument('--dataset', type=str, default='physionet', help="Dataset to load. Available: physionet, activity, hopper, periodic")
parser.add_argument('-s', '--sample-tp', type=float, default=None, help="Number of time points to sub-sample."
	"If > 1, subsample exact number of points. If the number is in [0,1], take a percentage of available points per time series. If None, do not subsample")

parser.add_argument('-c', '--cut-tp', type=int, default=None, help="Cut out the section of the timeline of the specified length (in number of points)."
	"Used for periodic function demo.")

parser.add_argument('--quantization', type=float, default=5, help="Quantization on the physionet dataset."
	"Value 1 means quantization by 1 hour, value 0.1 means quantization by 0.1 hour = 6 min")

parser.add_argument('--latent-ode', action='store_true', default=True,  help="Run Latent ODE seq2seq model")
parser.add_argument('--z0-encoder', type=str, default='odernn', help="Type of encoder for Latent ODE model: odernn or rnn")

parser.add_argument('--classic-rnn', action='store_true', help="Run RNN baseline: classic RNN that sees true points at every point. Used for interpolation only.")
parser.add_argument('--rnn-cell', default="gru", help="RNN Cell type. Available: gru (default), expdecay")
parser.add_argument('--input-decay', action='store_true', help="For RNN: use the input that is the weighted average of impirical mean and previous value (like in GRU-D)")

parser.add_argument('--ode-rnn', action='store_true', help="Run ODE-RNN baseline: RNN-style that sees true points at every point. Used for interpolation only.")

parser.add_argument('--rnn-vae', action='store_true', help="Run RNN baseline: seq2seq model with sampling of the h0 and ELBO loss.")

parser.add_argument('-l', '--latents', type=int, default=6, help="Size of the latent state")
parser.add_argument('--rec-dims', type=int, default=40, help="Dimensionality of the recognition model (ODE or RNN).")

parser.add_argument('--rec-layers', type=int, default=3, help="Number of layers in ODE func in recognition ODE")
parser.add_argument('--gen-layers', type=int, default=3, help="Number of layers in ODE func in generative ODE")

parser.add_argument('-u', '--units', type=int, default=50, help="Number of units per layer in ODE func")
parser.add_argument('-g', '--gru-units', type=int, default=100, help="Number of units per layer in each of GRU update networks")

parser.add_argument('--poisson', action='store_true', help="Model poisson-process likelihood for the density of events in addition to reconstruction.")
parser.add_argument('--classif', action='store_true', help="Include binary classification loss -- used for Physionet dataset for hospiral mortality")

parser.add_argument('--linear-classif', action='store_true', help="If using a classifier, use a linear classifier instead of 1-layer NN")
parser.add_argument('--extrap', action='store_true', help="Set extrapolation mode. If this flag is not set, run interpolation mode.")

parser.add_argument('-t', '--timepoints', type=int, default=100, help="Total number of time-points")
parser.add_argument('--max-t',  type=float, default=5., help="We subsample points in the interval [0, args.max_tp]")
parser.add_argument('--noise-weight', type=float, default=0.01, help="Noise amplitude for generated traejctories")




class Args:
    def __init__(self):
        self.n = 10000
        self.niters = 1000
        self.lr = 1e-2
        self.batch_size = 128
        self.viz = True
        self.save = 'experiments/'
        self.load = None
        self.random_seed = 1991
        self.dataset = 'physionet'
        self.sample_tp = 0.6
        self.cut_tp = None
        self.quantization = 5
        self.latent_ode = True
        self.z0_encoder = 'odernn'
        self.classic_rnn = False
        self.rnn_cell = "gru"
        self.input_decay = False
        self.ode_rnn = False
        self.rnn_vae = False
        self.latents = 20
        self.rec_dims = 10
        self.rec_layers = 5
        self.gen_layers = 5
        self.units = 50
        self.gru_units = 100
        self.poisson = False
        self.classif = False
        self.linear_classif = False
        self.extrap = False
        self.timepoints = 100
        self.max_t = 5.
        self.noise_weight = 0.01

# args 객체를 생성하고 필요한 설정을 할당합니다
args = Args()

# 몇 가지 설정을 수정합니다
args.batch_size = 64
args.classif = False
args.quantization = 5
args.niters = 1000
args.n = 100
args.sample_tp = 0.6
args.latents = 20
args.rec_dims = 10
args.rec_layers = 5
args.gen_layers = 5
args.latent_ode = True
args.viz = True
args.max_t = 200

# 사용할 디바이스를 설정합니다
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 파일 이름을 설정합니다
file_name = 'taejun_sim_'

# 설정을 확인합니다
print("Batch size:", args.batch_size)
print("Number of iterations:", args.niters)
print("Learning rate:", args.lr)
print("Device:", device)



In [40]:
args

Namespace(batch_size=64, classic_rnn=False, classif=False, cut_tp=None, dataset='physionet', extrap=False, gen_layers=5, gru_units=100, input_decay=False, l=20, latent_ode=True, latents=6, linear_classif=False, load=None, lr=0.01, max_t=5.0, n=2, niters=1000, noise_weight=0.01, ode_rnn=False, poisson=False, quantization=5, random_seed=1991, rec_dims=10, rec_layers=5, rnn_cell='gru', rnn_vae=False, s=30, sample_tp=None, save='experiments/', timepoints=100, units=50, viz=True, z0_encoder='odernn')

In [41]:
import os
import torch
from tqdm import tqdm
from jdcal import jd2gcal
from datetime import datetime

class CustomClass(object):
    params = ['Magnitude']  # Uncertainty_of_Magnitude 제외

    params_dict = {k: i for i, k in enumerate(params)}

    def __init__(self, root, train=True, preprocess=False,
                 quantization=1, n_samples=None, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")):
        self.root = root
        self.train = train
        self.quantization = quantization
        self.device = device
        self.reduce = "average"
        if preprocess:
            self.preprocess()

        self.data = torch.load(os.path.join(self.root, 'lc_' + str(self.quantization) + '.pt'))
        self.labels = torch.zeros(len(self.data))

        if n_samples is not None:
            self.data = self.data[:n_samples]
            self.labels = self.labels[:n_samples]

    def jd_to_total_hours(self,jd):
        jd_int = int(jd)
        jd_frac = jd - jd_int
        # jd2gcal 함수를 사용하여 율리우스 일을 그레고리언 날짜로 변환
        # 이 예에서는 jd2gcal 함수의 정의나 import 방법이 제공되지 않았으므로, 해당 함수의 정확한 동작을 가정합니다.
        year, month, day, fraction = jd2gcal(jd_int, jd_frac)
        # 연도 계산이 특정 요구 사항에 맞게 조정되어 있습니다.
        year = (year + 4553) + 2000
        hours = int(fraction * 24)
        converted_datetime = datetime(year, month, day, hours)
        unix_start = datetime(1970, 1, 1)
        total_hours = int((converted_datetime - unix_start).total_seconds() / 86400)
        return total_hours

    def preprocess(self):
        simulation_data_root = self.root 
        data_list = [f for f in os.listdir(simulation_data_root) if f.endswith('.lc')]
        data_list.sort()
        light_curves = []

        for name in tqdm(data_list):
            lc_name = name.partition('.')[0]
            with open(os.path.join(simulation_data_root, name)) as f:
                # next(f)  # 헤더 건너뛰기
                magnitudes = [float(line.rstrip().split(' ')[1]) for line in f]
                global_min = min(magnitudes)
                global_max = max(magnitudes)

            # 파일을 다시 열어서 데이터 처리
            with open(os.path.join(simulation_data_root, name)) as f:
                next(f)
                lines = f.readlines()
                prev_time = 0
                tt = [0.]
                vals = [torch.zeros(len(self.params), device=self.device)]
                mask = [torch.zeros(len(self.params), device=self.device)]
                for line in lines:
                    time, magnitude = line.rstrip().split(' ')[:2]
                    time = float(time)
                    magnitude = float(magnitude)
                    time = self.jd_to_total_hours(time)

                    # 파일별 최소/최대값으로 스케일링
                    scaled_magnitude = (magnitude - global_min) / (global_max - global_min) if global_max > global_min else 0.0

                    if time != prev_time:
                        tt.append(time)
                        vals.append(torch.zeros(len(self.params), device=self.device))
                        mask.append(torch.zeros(len(self.params), device=self.device))
                        prev_time = time

                    vals[-1][0] = scaled_magnitude  # 스케일된 값 사용
                    mask[-1][0] = 1

            tt = torch.tensor(tt, device=self.device)[1:]
            vals = torch.stack(vals)[1:]
            mask = torch.stack(mask)[1:]
            labels = None
            light_curves.append((lc_name, tt, vals, mask, labels))

        torch.save(light_curves, os.path.join(self.root, 'lc_' + str(self.quantization) + '.pt'))
        print('Done!')



    # def preprocess(self):
    #     simulation_data_root = self.root 
    #     data_list = [f for f in os.listdir(simulation_data_root) if f.endswith('.txt')]
    #     data_list.sort()
    #     light_curves = []

    #     # 전체 데이터셋에 대한 최소값과 최대값을 초기화
    #     global_min = float('inf')
    #     global_max = float('-inf')

    #     # 최소값과 최대값 찾기
    #     for name in tqdm(data_list):
    #         with open(os.path.join(simulation_data_root, name)) as f:
    #             next(f)  # 헤더 건너뛰기
    #             for line in f:
    #                 _, magnitude = line.rstrip().split(' ')[:2]
    #                 magnitude = float(magnitude)
    #                 global_min = min(global_min, magnitude)
    #                 global_max = max(global_max, magnitude)

    #     # 데이터 스케일링 및 처리
    #     for name in tqdm(data_list):
    #         lc_name = name.partition('.')[0]
    #         with open(os.path.join(simulation_data_root, name)) as f:
    #             next(f)
    #             lines = f.readlines()
    #             prev_time = 0
    #             tt = [0.]
    #             vals = [torch.zeros(len(self.params), device=self.device)]
    #             mask = [torch.zeros(len(self.params), device=self.device)]
    #             for line in lines:
    #                 time, magnitude = line.rstrip().split(' ')[:2]
    #                 time = float(time)
    #                 magnitude = float(magnitude)

    #                 # 스케일링: (magnitude - global_min) / (global_max - global_min)
    #                 scaled_magnitude = (magnitude - global_min) / (global_max - global_min)

    #                 if time != prev_time:
    #                     tt.append(time)
    #                     vals.append(torch.zeros(len(self.params), device=self.device))
    #                     mask.append(torch.zeros(len(self.params), device=self.device))
    #                     prev_time = time

    #                 vals[-1][0] = scaled_magnitude  # 스케일된 값 사용
    #                 mask[-1][0] = 1

    #         tt = torch.tensor(tt, device=self.device)[1:]
    #         vals = torch.stack(vals)[1:]
    #         mask = torch.stack(mask)[1:]
    #         labels = None
    #         light_curves.append((lc_name, tt, vals, mask, labels))

    #     torch.save(light_curves, os.path.join(self.root, 'lc_' + str(self.quantization) + '.pt'))
    #     print('Done!')


    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Split: {}\n'.format('train' if self.train else 'test')
        fmt_str += '    Root Location: {}\n'.format(self.root)
        fmt_str += '    Quantization: {}\n'.format(self.quantization)
        fmt_str += '    Reduce: {}\n'.format(self.reduce)
        return fmt_str

    def visualize(self, timesteps, data, mask, plot_name):
        width = 15
        height = 15

        non_zero_attributes = (torch.sum(mask,0) > 2).numpy()
        non_zero_idx = [i for i in range(len(non_zero_attributes)) if non_zero_attributes[i] == 1.]
        n_non_zero = sum(non_zero_attributes)

        mask = mask[:, non_zero_idx]
        data = data[:, non_zero_idx]

        params_non_zero = [self.params[i] for i in non_zero_idx]
        params_dict = {k: i for i, k in enumerate(params_non_zero)}

        n_col = 3
        n_row = n_non_zero // n_col + (n_non_zero % n_col > 0)
        fig, ax_list = plt.subplots(n_row, n_col, figsize=(width, height), facecolor='white')

        #for i in range(len(self.params)):
        for i in range(n_non_zero):
            param = params_non_zero[i]
            param_id = params_dict[param]

            tp_mask = mask[:,param_id].long()

            tp_cur_param = timesteps[tp_mask == 1.]
            data_cur_param = data[tp_mask == 1., param_id]

            ax_list[i // n_col, i % n_col].plot(tp_cur_param.numpy(), data_cur_param.numpy(),  marker='o') 
            ax_list[i // n_col, i % n_col].set_title(param)

        fig.tight_layout()
        fig.savefig(plot_name)
        plt.close(fig)


            # 예제 사용
            # device 설정: CUDA 사용 가능한 경우 CUDA 사용, 그렇지 않으면 CPU 사용
            #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            # CustomClass 객체 생성 예제
            #simulation_sdss = CustomClass(root='/home/intern/SSD/intern/taejun/test', train=True, preprocess=True,quantization=1, device=device)

            # 객체 생성 시, preprocess=True로 설정하여 데이터 전처리 및 저장이 이루어집니다.


In [42]:
#코드에 구현된거 하지만 잘 안된다 수정해서 해보자
def variable_time_collate_fn(batch, args, device = device, data_type = "train", 
	data_min = None, data_max = None):
	"""
	Expects a batch of time series data in the form of (record_id, tt, vals, mask, labels) where
		- record_id is a patient id
		- tt is a 1-dimensional tensor containing T time values of observations.
		- vals is a (T, D) tensor containing observed values for D variables.
		- mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise.
		- labels is a list of labels for the current patient, if labels are available. Otherwise None.
	Returns:
		combined_tt: The union of all time observations.
		combined_vals: (M, T, D) tensor containing the observed values.
		combined_mask: (M, T, D) tensor containing 1 where values were observed and 0 otherwise.
	"""
	D = batch[0][2].shape[1]
	combined_tt, inverse_indices = torch.unique(torch.cat([ex[1] for ex in batch]), sorted=True, return_inverse=True)
	combined_tt = combined_tt.to(device)

	offset = 0
	combined_vals = torch.zeros([len(batch), len(combined_tt), D]).to(device)
	combined_mask = torch.zeros([len(batch), len(combined_tt), D]).to(device)
	
	combined_labels = None
	N_labels = 1

	combined_labels = torch.zeros(len(batch), N_labels) + torch.tensor(float('nan'))
	combined_labels = combined_labels.to(device = device)
	
	for b, (record_id, tt, vals, mask, labels) in enumerate(batch):
		tt = tt.to(device)
		vals = vals.to(device)
		mask = mask.to(device)
		if labels is not None:
			labels = labels.to(device)

		indices = inverse_indices[offset:offset + len(tt)]
		offset += len(tt)

		combined_vals[b, indices] = vals
		combined_mask[b, indices] = mask

		if labels is not None:
			combined_labels[b] = labels

	combined_vals, _, _ = utils.normalize_masked_data(combined_vals, combined_mask, 
	  	att_min = data_min, att_max = data_max)

	# if torch.max(combined_tt) != 0.:
	# 	combined_tt = combined_tt / torch.max(combined_tt)
	
	data_dict = {
		"data": combined_vals, 
		"time_steps": combined_tt,
		"mask": combined_mask,
		"labels": combined_labels}
	
	data_dict = utils.split_and_subsample_batch(data_dict, args, data_type = data_type)
	return data_dict

In [43]:
# first_batch = next(iter(train_dataloader))

# # 첫 번째 배치의 실제 내용을 출력하여 어떤 키가 있는지 확인
# print(first_batch.keys())

# # 첫 번째 배치의 구성요소들의 shape를 출력
# print("observed_data:", first_batch['observed_data'].shape)
# print("observed_tp:", first_batch['observed_tp'].shape)
# print("data_to_predict:", first_batch['data_to_predict'].shape)
# print("tp_to_predict:", first_batch['tp_to_predict'].shape)
# print("observed_mask:", first_batch['observed_mask'].shape)
# print("mask_predicted_data", first_batch['mask_predicted_data'].shape)

In [44]:
# num_zeros = torch.sum(first_batch['observed_mask']== 0).item()  # 마스크에서 0의 개수
# num_ones = torch.sum(first_batch['observed_mask']== 1).item()

# print(num_zeros)
# print(num_ones)

In [45]:
# # DataLoader에서 첫 번째 배치 데이터를 가져옴
# batch_data = next(iter(train_dataloader))
# print(batch_data.keys())
# # 마스킹된 데이터 확인
# mask = batch_data["mask"]  # "mask" 키를 사용하여 마스크 텐서를 가져옴

# # 마스크 텐서에서 0과 1의 개수를 세어서 출력
# num_zeros = torch.sum(mask == 0).item()  # 마스크에서 0의 개수
# num_ones = torch.sum(mask == 1).item()  # 마스크에서 1의 개수

# print(f"마스크에서 0의 개수: {num_zeros}")
# print(f"마스크에서 1의 개수: {num_ones}")

# # 옵셔널: 마스킹된 데이터 시각화
# # 데이터와 마스크를 시각화하는 코드를 추가할 수 있음


In [46]:
pt_file_path = '/home/intern/SSD/intern/taejun/data_normal/lc_5.pt'

# .pt 파일 로드
light_curves_data = torch.load(pt_file_path)

# 로드된 데이터의 타입과 크기 확인
print(f"Loaded data type: {type(light_curves_data)}")
print(f"Number of light curves in the dataset: {len(light_curves_data)}")

# 첫 번째 광도곡선 데이터의 구조 확인
first_light_curve = light_curves_data[0]
print(f"Structure of a single light curve data: {type(first_light_curve)}")
print(f"Record ID: {first_light_curve[0]}")
print(f"Time stamps tensor shape: {first_light_curve[1].shape}")
print(f"Magnitude tensor shape: {first_light_curve[2].shape}")
print(f"Mask tensor shape: {first_light_curve[3].shape}")
#print(f"Labels tensor shape: {first_light_curve[4].shape}")

# 첫 번째 광도곡선 데이터의 일부 내용 출력
print("\nSample data from the first light curve:")
print(f"Time stamps: {first_light_curve[1]}")  # 처음 5개의 시간 스탬프
print(f"Magnitude: {first_light_curve[2]}")  # 처음 5개의 광도 값


Loaded data type: <class 'list'>
Number of light curves in the dataset: 10
Structure of a single light curve data: <class 'tuple'>
Record ID: time_series_band_1
Time stamps tensor shape: torch.Size([99])
Magnitude tensor shape: torch.Size([99, 1])
Mask tensor shape: torch.Size([99, 1])

Sample data from the first light curve:
Time stamps: tensor([  1.0101,   2.0202,   3.0303,   4.0404,   5.0505,   6.0606,   7.0707,
          8.0808,   9.0909,  10.1010,  11.1111,  12.1212,  13.1313,  14.1414,
         15.1515,  16.1616,  17.1717,  18.1818,  19.1919,  20.2020,  21.2121,
         22.2222,  23.2323,  24.2424,  25.2525,  26.2626,  27.2727,  28.2828,
         29.2929,  30.3030,  31.3131,  32.3232,  33.3333,  34.3434,  35.3535,
         36.3636,  37.3737,  38.3838,  39.3939,  40.4040,  41.4141,  42.4242,
         43.4343,  44.4444,  45.4545,  46.4646,  47.4747,  48.4848,  49.4949,
         50.5051,  51.5152,  52.5253,  53.5354,  54.5455,  55.5556,  56.5657,
         57.5758,  58.5859,  59.596

In [47]:
train_dataset_obj = CustomClass(root='/home/intern/SSD/intern/taejun/test_normal', train=True, preprocess=True,quantization=5,n_samples = min(800, args.n), device=device)
		# Use custom collate_fn to combine samples with arbitrary time observations.
		# Returns the dataset along with mask and time steps


# Combine and shuffle samples from physionet Train and physionet Test
total_dataset = train_dataset_obj[:len(train_dataset_obj)]



# Shuffle and split
train_data, test_data = model_selection.train_test_split(total_dataset, train_size= 1, 
    random_state = 42, shuffle = True)

record_id, tt, vals, mask, labels = train_data[0]

n_samples = len(total_dataset)
input_dim = vals.size(-1)

batch_size = min(min(len(train_dataset_obj), args.batch_size), args.n)
data_min, data_max = get_data_min_max(total_dataset)

train_dataloader = DataLoader(train_data, batch_size= batch_size, shuffle=False, 
    collate_fn= lambda batch: variable_time_collate_fn(batch, args, device, data_type = "train",
        data_min = data_min, data_max = data_max))
test_dataloader = DataLoader(train_data, batch_size = batch_size, shuffle=False, 
    collate_fn= lambda batch: variable_time_collate_fn(batch, args, device, data_type = "train",
        data_min = data_min, data_max = data_max))

attr_names = train_dataset_obj.params
data_objects = {"dataset_obj": train_dataset_obj, 
            "train_dataloader": utils.inf_generator(train_dataloader), 
            "test_dataloader": utils.inf_generator(test_dataloader),
            "input_dim": input_dim,
            "n_train_batches": len(train_dataloader),
            "n_test_batches": len(test_dataloader),
            "attr": attr_names, #optional
            }


0it [00:00, ?it/s]

Done!





ValueError: train_size=1 should be either positive and smaller than the number of samples 0 or a float in the (0, 1) range

In [None]:
# train_data

In [None]:
# test_data

In [None]:
# # `train_dataloader`에서 첫 번째 배치를 가져와서 형태와 내용을 확인합니다.
# print("Train DataLoader:")
# for i, batch in enumerate(test_dataloader):
#     print(f"Batch {i}:")
#     print("observed_data.shape:", batch['observed_data'].shape)
#     print("observed_data:", batch['observed_data'])
#     # 여기에서 필요한 다른 키들도 확인할 수 있습니다.
#     # 예를
# # 

In [None]:
import os

# Path to the directory where you want to save the file
directory = "experiments"

# Check if the directory exists
if not os.path.exists(directory):
    # If the directory does not exist, create it
    os.makedirs(directory)

In [None]:
if __name__ == '__main__':
	torch.manual_seed(args.random_seed)
	np.random.seed(args.random_seed)

	experimentID = args.load
	if experimentID is None:
		# Make a new experiment ID
		experimentID = int(SystemRandom().random()*100000)
	ckpt_path = os.path.join(args.save, "experiment_" + str(experimentID) + '.ckpt')

	start = time.time()
	print("Sampling dataset of {} training examples".format(args.n))
	
	input_command = sys.argv
	ind = [i for i in range(len(input_command)) if input_command[i] == "--load"]
	if len(ind) == 1:
		ind = ind[0]
		input_command = input_command[:ind] + input_command[(ind+2):]
	input_command = " ".join(input_command)

	utils.makedirs("results/")

	##################################################################
	data_obj = data_objects
	input_dim = data_obj["input_dim"]

	classif_per_tp = False
	if ("classif_per_tp" in data_obj):
		# do classification per time point rather than on a time series as a whole
		classif_per_tp = data_obj["classif_per_tp"]

	if args.classif and (args.dataset == "hopper" or args.dataset == "periodic"):
		raise Exception("Classification task is not available for MuJoCo and 1d datasets")

	n_labels = 1
	if args.classif:
		if ("n_labels" in data_obj):
			n_labels = data_obj["n_labels"]
		else:
			raise Exception("Please provide number of labels for classification task")

	##################################################################
	# Create the model
	obsrv_std = 0.01
	obsrv_std = torch.Tensor([obsrv_std]).to(device)
	z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))

	if args.rnn_vae:
		if args.poisson:
			print("Poisson process likelihood not implemented for RNN-VAE: ignoring --poisson")

		# Create RNN-VAE model
		model = RNN_VAE(input_dim, args.latents, 
			device = device, 
			rec_dims = args.rec_dims, 
			concat_mask = True, 
			obsrv_std = obsrv_std,
			z0_prior = z0_prior,
			use_binary_classif = args.classif,
			classif_per_tp = classif_per_tp,
			linear_classifier = args.linear_classif,
			n_units = args.units,
			input_space_decay = args.input_decay,
			cell = args.rnn_cell,
			n_labels = n_labels,
			train_classif_w_reconstr = (args.dataset == "physionet")
			).to(device)
		
	elif args.classic_rnn:
		if args.poisson:
			print("Poisson process likelihood not implemented for RNN: ignoring --poisson")

		if args.extrap:
			raise Exception("Extrapolation for standard RNN not implemented")
		# Create RNN model
		model = Classic_RNN(input_dim, args.latents, device, 
			concat_mask = True, obsrv_std = obsrv_std,
			n_units = args.units,
			use_binary_classif = args.classif,
			classif_per_tp = classif_per_tp,
			linear_classifier = args.linear_classif,
			input_space_decay = args.input_decay,
			cell = args.rnn_cell,
			n_labels = n_labels,
			train_classif_w_reconstr = (args.dataset == "physionet")
			).to(device)
		
	elif args.ode_rnn:
		# Create ODE-GRU model
		n_ode_gru_dims = args.latents
				
		if args.poisson:
			print("Poisson process likelihood not implemented for ODE-RNN: ignoring --poisson")

		if args.extrap:
			raise Exception("Extrapolation for ODE-RNN not implemented")

		ode_func_net = utils.create_net(n_ode_gru_dims, n_ode_gru_dims, 
			n_layers = args.rec_layers, n_units = args.units, nonlinear = nn.Tanh)

		rec_ode_func = ODEFunc(
			input_dim = input_dim, 
			latent_dim = n_ode_gru_dims,
			ode_func_net = ode_func_net,
			device = device).to(device)

		z0_diffeq_solver = DiffeqSolver(input_dim, rec_ode_func, "euler", args.latents, 
			odeint_rtol = 1e-3, odeint_atol = 1e-4, device = device)
	
		model = ODE_RNN(input_dim, n_ode_gru_dims, device = device, 
			z0_diffeq_solver = z0_diffeq_solver, n_gru_units = args.gru_units,
			concat_mask = True, obsrv_std = obsrv_std,
			use_binary_classif = args.classif,
			classif_per_tp = classif_per_tp,
			n_labels = n_labels,
			train_classif_w_reconstr = (args.dataset == "physionet")
			).to(device)
	elif args.latent_ode:
		model = create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, 
			classif_per_tp = classif_per_tp,
			n_labels = n_labels)
	else:
		raise Exception("Model not specified")

	##################################################################

	if args.viz:
		viz = Visualizations(device)

	##################################################################
	
	#Load checkpoint and evaluate the model
	if args.load is not None:
		utils.get_ckpt_model(ckpt_path, model, device)
		exit()

	##################################################################
	# Training

	log_path = "logs/" + file_name + "_" + str(experimentID) + ".log"
	if not os.path.exists("logs/"):
		utils.makedirs("logs/")
	logger = utils.get_logger(logpath=log_path, filepath=os.path.abspath(file_name))
	logger.info(input_command)

	optimizer = optim.Adamax(model.parameters(), lr=args.lr)

	num_batches = data_obj["n_train_batches"]

	for itr in range(1, num_batches * (args.niters + 1)):
		optimizer.zero_grad()
		utils.update_learning_rate(optimizer, decay_rate = 0.999, lowest = args.lr / 10)

		wait_until_kl_inc = 10
		if itr // num_batches < wait_until_kl_inc:
			kl_coef = 0.
		else:
			kl_coef = (1-0.99** (itr // num_batches - wait_until_kl_inc))

		batch_dict = utils.get_next_batch(data_obj["train_dataloader"])
		# print(batch_dict)
		train_res = model.compute_all_losses(batch_dict, n_traj_samples = 10, kl_coef = kl_coef)
		train_res["loss"].backward()
		optimizer.step()

		n_iters_to_viz = 10
		if itr % (n_iters_to_viz * num_batches) == 0:
			with torch.no_grad():

				test_res = compute_loss_all_batches(model, 
					data_obj["test_dataloader"], args,
					n_batches = data_obj["n_test_batches"],
					experimentID = experimentID,
					device = device,
					n_traj_samples = 10, kl_coef = kl_coef)

				message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
					itr//num_batches, 
					test_res["loss"].detach(), test_res["likelihood"].detach(), 
					test_res["kl_first_p"], test_res["std_first_p"],
					train_res['loss'].detach(),train_res['likelihood'].detach())
				sample_test = utils.get_next_batch(data_obj["test_dataloader"])
				no_zero_count = torch.sum(sample_test['observed_data']!= 0).item()
				zero_count = torch.sum(sample_test['observed_data'] == 0).item()
				print("zero_count 개수: ",zero_count)
				print("no_zero_count 개수: ",no_zero_count)
				plot_name_test = "test_2024_02_15_resultFinal_test{:04d}".format(itr//num_batches)
				plot_name_train = "train_2024_02_15_resultFinal_train{:04d}".format(itr//num_batches)
				Visualizations(device).draw_all_plots_one_dim(sample_test ,model.to(device), plot_name=plot_name_test, save = True)
				Visualizations(device).draw_all_plots_one_dim(batch_dict ,model.to(device), plot_name=plot_name_train, save = True)
				
		 	
				logger.info("Experiment " + str(experimentID))
				logger.info(message)
				logger.info("KL coef: {}".format(kl_coef))
				logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
				logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))
				
				

				if "mse" in test_res:
					logger.info("Test MSE: {:.4f}".format(test_res["mse"]))

				

			


			



/media/usr/SSD/intern/taejun/latent_ode/taejun_sim_
/home/intern/anaconda3/envs/taejun/lib/python3.8/site-packages/ipykernel_launcher.py --f=/home/intern/.local/share/jupyter/runtime/kernel-v2-3275204OMSDdZLhQHy2.json


Sampling dataset of 2 training examples
1
2
3


KeyError: 'n_train_batches'

In [None]:
with torch.no_grad():

    test_res = compute_loss_all_batches(model, 
        data_obj["test_dataloader"], args,
        n_batches = data_obj["n_test_batches"],
        experimentID = experimentID,
        device = device,
        n_traj_samples = 5, kl_coef = kl_coef)

    message = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
        itr//num_batches, 
        test_res["loss"].detach(), test_res["likelihood"].detach(), 
        test_res["kl_first_p"], test_res["std_first_p"])
    sample_test = utils.get_next_batch(data_obj["test_dataloader"])
    no_zero_count = torch.sum(sample_test['observed_data']!= 0).item()
    zero_count = torch.sum(sample_test['observed_data'] == 0).item()
    print("zero_count 개수: ",zero_count)
    print("no_zero_count 개수: ",no_zero_count)
    plot_name = "test_2024_02_15_10_{:04d}".format(itr//num_batches)
    viz.draw_all_plots_one_dim(sample_test ,model.to(device), plot_name=plot_name, save = True)
    

    logger.info("Experiment " + str(experimentID))
    logger.info(message)
    logger.info("KL coef: {}".format(kl_coef))
    logger.info("Train loss (one batch): {}".format(train_res["loss"].detach()))
    logger.info("Train CE loss (one batch): {}".format(train_res["ce_loss"].detach()))
    
    

    if "mse" in test_res:
        logger.info("Test MSE: {:.4f}".format(test_res["mse"]))


Computing loss... 0
zero_count 개수:  1015
no_zero_count 개수:  740
shape torch.Size([10, 1, 100, 1])


Experiment 89815
Epoch 0993 [Test seq (cond on sampled tp)] | Loss 370.338989 | Likelihood -481.314362 | KL fp 3.4721 | FP STD 0.5245|
KL coef: 0.9999487851046791
Train loss (one batch): 73.70389556884766
Train CE loss (one batch): 0.0
Test MSE: 0.0970


In [None]:

viz.draw_all_plots_one_dim(batch_dict ,model.to(device), plot_name=plot_name, save = True)

RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

In [None]:
print(traj_from_prior.shape)

NameError: name 'traj_from_prior' is not defined

In [None]:
batch_dict['observed_data'].shape

torch.Size([4, 523, 1])

In [None]:
data_obj["test_dataloader"]

<generator object inf_generator at 0x7fcfd36ec190>