In [None]:
import os
# Move the working directory up one level. The script paths opened by later
# cells ('prepare_experiment/...', 'run/...') are relative to this directory.
# NOTE(review): this is a relative chdir, so re-running the cell in the same
# kernel moves up one more level each time - restart the kernel before re-running.
os.chdir("../")
# Set R environment variables using the conda environment path
# NOTE(review): absolute, user-specific path - it has to be edited before the
# notebook can run on another machine or account.
r_home = '/sfs/gpfs/tardis/home/jq2uw/llm_nicu_vitalsigns/clip_env/lib/R'
os.environ['R_HOME'] = r_home
os.environ['R_LIBS'] = f"{r_home}/library"
os.environ['R_LIBS_USER'] = os.path.expanduser('~/R/goolf/4.3')
# Prepend R's lib directory while keeping any LD_LIBRARY_PATH that was already set.
# NOTE(review): changing LD_LIBRARY_PATH inside a running process generally only
# affects child processes - confirm this line has the intended effect.
os.environ['LD_LIBRARY_PATH'] = f"{r_home}/lib:" + os.environ.get('LD_LIBRARY_PATH', '')

import torch
# When a GPU is present, release unused memory held by PyTorch's caching
# allocator and collect memory freed by CUDA IPC (helps when re-running in a
# live kernel).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# Project-local modules, star-imported after the chdir and R environment setup
# above. Later cells use names that are not defined anywhere in this notebook
# (update_config, get_config_dict, CNNEncoder, TransformerDecoder,
# TextEncoderCNN, viz_generation_conditional, ...); they are expected to come
# from these modules.
from config import *
from data import *
from train import *
from eval import *
from generation import *
from vital2d import *
# `device` is not defined in this cell; it is expected to be exported by one of
# the star imports above.
print("using device: ", device)

# Configuration (customizable)

In [2]:

# ---------------------------------------------------------------------------
# Experiment identity
# ---------------------------------------------------------------------------
overwrite = True
model_name = 'test_mixture_syn'

# ---------------------------------------------------------------------------
# Text descriptions
# Each entry is a (description, number) tuple. The number is passed through
# to the project code unchanged (presumably a sampling weight - TODO confirm).
# ---------------------------------------------------------------------------
# Attribute 1: trend
text1 = ('No trend.', 1)
counter_text11 = ('The time series shows upward linear trend.', 1)
counter_text12 = ('The time series shows downward linear trend.', 1)
# Attribute 2: seasonality
text2 = ('No seasonal pattern.', 1)
counter_text2 = ('The time series exhibits a seasonal pattern.', 1)
# Attribute 3: mean shift
text3 = ('No sharp shifts.', 1)
counter_text31 = ('The mean of the time series shifts upwards.', 10)
counter_text32 = ('The mean of the time series shifts downwards.', 10)

# One inner list per attribute: the baseline text first, then its counterparts.
text_config = {
    'text_pairs': [
        [text1, counter_text11, counter_text12],
        [text2, counter_text2],
        [text3, counter_text31, counter_text32],
    ],
    'n': None,
}

# 1-based index of the attribute used as the evaluation target
# (3 -> third entry of text_config['text_pairs']).
attr_id = 3
# Description strings (first tuple element) of the selected attribute.
selected_attr_levels = [
    description for description, _number in text_config['text_pairs'][attr_id - 1]
]

update_config(
    # Eval settings (clip)
    # ts2txt
    y_col=f'segment{attr_id}',
    y_levels=list(selected_attr_levels),       # independent copy per setting
    y_pred_levels=list(selected_attr_levels),
    # txt2ts
    txt2ts_y_cols=['segment1', 'segment2', 'segment3'],

    # Data settings
    text_col='text',  # alternative column: 'ts_description'
    seq_length=200,
    downsample=True,
    downsample_size=15000,
    downsample_levels=list(selected_attr_levels),
    # 'label' is the same as the default "by_label" target
    custom_target_cols=['segment1', 'segment2', 'segment3', 'label'],

    # Model settings
    model_name=model_name,
    # '3d' is not a valid Python identifier, so it is passed via dict unpacking.
    **{'3d': False},
    embedded_dim=512,
    concat_embeddings=False,
    clip_mu=False,
    variational=False,
    train_type='joint',  # or 'vae', 'clip'
    clip_target_type='by_target',  # or 'by_label'

    # Train settings
    batch_size=512,
    init_lr=1e-4,
    patience=100,
    alpha=1 / 100,
    num_saves=5,
    num_epochs=500,

    # Text configuration
    text_config=text_config,
)
# Snapshot of the configuration as a plain dict; later cells read from it.
config_dict = get_config_dict()



# Data

In [None]:
# Run the data-preparation scripts inside this notebook's namespace, so the
# names they define are available to later cells. The paths are relative to
# the working directory set by os.chdir("../") in the first cell.
#
# Each script is compiled with its real path so that an error raised inside it
# produces a traceback showing the file name and source line rather than the
# anonymous "<string>" that exec() on a raw string reports.
data_scripts = (
    'prepare_experiment/synthetic.py',  # change data preparation for a given experiment
    'run/inputs.py',                    # prepare model inputs
)
for script_path in data_scripts:
    # Read the whole file first so the handle is closed before the script runs.
    with open(script_path, 'r') as script_file:
        script_source = script_file.read()
    exec(compile(script_source, script_path, 'exec'))

# Model (customizable)

In [None]:
# Customize the encoders and the decoder here.
# `ts_f_dim` and `tx_f_dim` are not defined in this cell; they are expected to
# come from the data cell above (run/inputs.py). Their second dimension sets
# the model input sizes.
ts_encoder = CNNEncoder(
    ts_dim=ts_f_dim.shape[1],
    output_dim=config_dict['embedded_dim'],
    num_channels=[16],
    kernel_size=50,
    dropout=0,
)
# NOTE(review): the decoder width is the embedding size plus 2. The reason for
# the extra two is not visible here (presumably two additional values are
# concatenated to the embedding) - confirm against TransformerDecoder / run/model.py.
ts_decoder = TransformerDecoder(
    ts_dim=ts_f_dim.shape[1],
    output_dim=config_dict['embedded_dim'] + 2,
    nhead=8,
    num_layers=6,
    dim_feedforward=512,
    dropout=0.1,
)
text_encoder = TextEncoderCNN(
    text_dim=tx_f_dim.shape[1],
    output_dim=config_dict['embedded_dim'],
    num_channels=[16],
    kernel_size=50,
    dropout=0,
)

# Optional override of the flag set in the configuration cell:
# overwrite = False

# Run the model-construction script in this notebook's namespace; it can see
# the encoder/decoder objects defined above. Compiling with the real path makes
# tracebacks from inside the script show the file name and source line instead
# of "<string>".
script_path = 'run/model.py'
with open(script_path, 'r') as script_file:
    script_source = script_file.read()
exec(compile(script_source, script_path, 'exec'))


# Train

In [None]:
# Optional override of the flag set in the configuration cell:
# overwrite = False

# Run the training script in this notebook's namespace. Compiling with the real
# path makes tracebacks from inside the script show the file name and source
# line instead of the anonymous "<string>" that exec() on a raw string reports.
script_path = 'run/train.py'
with open(script_path, 'r') as script_file:
    script_source = script_file.read()
exec(compile(script_source, script_path, 'exec'))


# Generation

In [None]:
# Identifier forwarded to the visualization helper as `tid`.
tid = 0

# Conditional-generation plot. `df_train` and `model` are not defined in this
# cell; they are expected to come from the data/model/train cells above.
# Marginal variant, kept for reference:
#   viz_generation_marginal(df_train, model, config_dict, tid=tid)
viz_generation_conditional(
    df_train,
    model,
    config_dict,
    tid=tid,
)