In [1]:
from functions_aa import *
import os

def get_dataloader_and_params(artificial_missing_rate=0.1, batch_size = 3):
    raw_df                                   = read_data()
    train_set_ids, val_set_ids, test_set_ids = generate_train_val_test(raw_df, train_size=.8)
    data_dict                                = make_datasets_dict(raw_df,train_set_ids,val_set_ids,test_set_ids, artificial_missing_rate = artificial_missing_rate)

    dataset_saving_dir = f"{os.getcwd()}/data"

    saving_into_h5(dataset_saving_dir, data_dict, classification_dataset=False)

    args, model_args       = get_args(seq_len = 9, feature_num  = 51, batch_size   = batch_size)

    masked_imputation_task = True
    model_type             = 'SAITS'
    args.batch_size        = batch_size
    args.device            = 'cuda'
    unified_dataloader     = UnifiedDataLoader(args.dataset_path, args.seq_len, args.feature_num, model_type, args.batch_size, args.num_workers, masked_imputation_task)
    return unified_dataloader, args, model_args

#print(f'All done. Saved to {dataset_saving_dir}.')

import torch.optim as optim
import copy
_, args, model_args               =  get_dataloader_and_params()
args.optimizer_type     = 'Adam'
args.epochs             = 10
args.eval_every_n_steps = 25

model = SAITS(**model_args)


min_total_loss_val = 100000

if 'cuda' in args.device and torch.cuda.is_available() and torch.cuda.is_initialized():
    model = model.to(args.device)

optimizer               = getattr(optim, args.optimizer_type)(model.parameters(), lr= args.lr)


  from .autonotebook import tqdm as notebook_tqdm


Cuda is ready


In [2]:
unified_dataloader, _, _               =  get_dataloader_and_params()
train_dataloader, val_dataloader       = unified_dataloader.get_train_val_dataloader()

In [3]:
indices, X, missing_mask, X_holdout, indicating_mask = next(iter(train_dataloader))

In [4]:
for epoch in range(args.epochs):
    
    unified_dataloader, _, _               =  get_dataloader_and_params()
    train_dataloader, val_dataloader       = unified_dataloader.get_train_val_dataloader()
    
    
    total_loss_val = 0
    for data in train_dataloader:
        model.train()
        
        indices, X, missing_mask, X_holdout, indicating_mask = map(lambda x: x.to(args.device), data)
        inputs = {'indices': indices, 'X': X, 'missing_mask': missing_mask,'X_holdout': X_holdout, 'indicating_mask': indicating_mask}
        results = model(inputs, 'train')
        results = result_processing(results, args)
        optimizer.zero_grad()
        results['total_loss'].backward()
        optimizer.step()
    if epoch % args.eval_every_n_steps ==0 or epoch == args.epochs-1:
        model.eval()
        with torch.no_grad():
            for idx, data in enumerate(val_dataloader):
                    #inputs, results = model_processing(data, model, 'val', args = args)
                indices, X, missing_mask, X_holdout, indicating_mask = map(lambda x: x.to(args.device), data)
                inputs = {'indices': indices, 'X': X, 'missing_mask': missing_mask,
                    'X_holdout': X_holdout, 'indicating_mask': indicating_mask}
                results = model(inputs, 'validation')
                results = result_processing(results, args)
                L = results['total_loss']
                
        total_loss_val += L
        if total_loss_val < min_total_loss_val:
            best_params        = copy.deepcopy(model.state_dict())
            torch.save(best_params, 'state_dict_best_trained_model.pth')
            min_total_loss_val = total_loss_val
            print(f'Best total_loss: {min_total_loss_val}, best_params updated!')
                
        #total_loss_val += L
        print(f'Epoch total_loss: {total_loss_val}')

Best total_loss: 36.43928146362305, best_params updated!
Epoch total_loss: 36.43928146362305
Best total_loss: 23.934600830078125, best_params updated!
Epoch total_loss: 23.934600830078125


In [5]:
import torch
#saved_parameters_path = f"{os.getcwd()}/state_dict_best_trained_model.pth"

torch.save(best_params, 'state_dict_best_trained_model.pth')