# Liriscat paper experiments
### 1. Init
#### 1.1. Import libraries (necessary)

In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTHONHASHSEED"] = "0"
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

import liriscat
liriscat.utils.set_seed(0)

import logging
import gc
import json
import torch
liriscat.utils.set_seed(0)
import pandas as pd

#### 1.2. Set up the loggers (recommended)

In [2]:
# Configure the "liriscat" logger used throughout this notebook (debug output off).
liriscat.utils.setuplogger(log_name="liriscat", verbose=True, debug=False)

### 2. CDM prediction
#### 2.1. Training and testing, sequential version

In [3]:
# `warnings` is used by the data-loading cell to install warning filters.
import warnings

# Collect unreachable Python objects first so that the CUDA caching allocator
# can then hand any blocks they held back to the device.
gc.collect()
torch.cuda.empty_cache()

In [4]:
config = liriscat.utils.generate_eval_config(
    load_params=True,
    CDM='ncdm',
    esc='error',
    valid_metric='rmse',
    pred_metrics=["mi_acc"],
    profile_metrics=['meta_doa'],
    save_params=False,
    n_query=6,
    num_epochs=100,
    batch_size=512,
)
liriscat.utils.set_seed(config["seed"])

config["dataset_name"] = "algebra"
logging.info(config["dataset_name"])

# Hyper-parameter overrides for algebra / NCDM.
# TODO: record where the tuned learning rates and `lambda` come from
# (they look like hyper-parameter search results).
config.update({
    'learning_rate': 0.0001,
    'inner_user_lr': 0.0016969685554352153,
    'lambda': 2.2656270501845414e-06,
    'meta_lr': 0.5,
    'learning_users_emb_lr': 0.0005,
    'patience': 20,  # was 5 in an earlier run
    'num_inner_users_epochs': 3,
    'd_in': 4,
    'num_responses': 12,
})

CUDA is available. Using GPU.
[INFO 33:05] algebra


In [5]:
def pareto_index2(d, acc_key='mi_acc', meta_key='meta_doa', baseline=0.5):
    """Sum over evaluation steps of the product of each metric's offset from a baseline.

    For every step ``s`` present in the accuracy dict the term is
    ``(baseline - d_acc[s][acc_key]) * (baseline - d_meta[s][meta_key])``.
    The sign of each term depends on which side of ``baseline`` each metric
    falls.

    Parameters
    ----------
    d : tuple
        ``(d_acc, d_meta)`` as returned by ``S.evaluate_test``: two dicts keyed
        by step (0..n_query-1 in the logged run), whose values are dicts of
        metric name -> value.
    acc_key : str
        Metric name read from ``d_acc`` entries (default ``'mi_acc'``).
    meta_key : str
        Metric name read from ``d_meta`` entries (default ``'meta_doa'``).
    baseline : float
        Reference value both metrics are measured against (default 0.5).

    Returns
    -------
    float
        The summed products, or ``0`` when ``d_acc`` is empty.

    Raises
    ------
    KeyError
        If a step in ``d_acc`` is missing from ``d_meta`` or a metric key is absent.
    """
    d_acc, d_meta = d[0], d[1]
    # Iterate over the actual step keys rather than assuming they are 0..n-1.
    return sum(
        (baseline - acc_metrics[acc_key]) * (baseline - d_meta[step][meta_key])
        for step, acc_metrics in d_acc.items()
    )

In [6]:
# Set the artefact paths before logging so the logged config matches what is used.
config['embs_path'] = '../embs/'
config['params_path'] = '../ckpt/'

logging.info(f'#### {config["dataset_name"]} ####')
logging.info(f'#### config : {config} ####')

gc.collect()
torch.cuda.empty_cache()

# Silence numeric RuntimeWarnings such as "invalid value encountered in divide"
# (presumably raised during metric computation — confirm before relying on it).
warnings.filterwarnings("ignore", message="invalid value encountered in divide")
warnings.filterwarnings("ignore", category=RuntimeWarning)

PREPROCESSED_DIR = '../datasets/2-preprocessed_data'
dataset_prefix = f'{PREPROCESSED_DIR}/{config["dataset_name"]}'

## Concept map format : {question_id : [category_id1, category_id2, ...]}
# `with` closes the file handle (json.load(open(...)) leaked it).
with open(f'{dataset_prefix}_concept_map.json', 'r') as fh:
    raw_concept_map = json.load(fh)
# JSON object keys are always strings; convert ids back to int.
concept_map = {int(k): [int(x) for x in v] for k, v in raw_concept_map.items()}

## Metadata map format : {"num_user_id": ..., "num_item_id": ..., "num_dimension_id": ...}
with open(f'{dataset_prefix}_metadata.json', 'r') as fh:
    metadata = json.load(fh)

## Tensor containing the nb of modalities per question
nb_modalities = torch.load(f'{dataset_prefix}_nb_modalities.pkl', weights_only=True)


[INFO 33:05] #### algebra ####
[INFO 33:05] #### config : {'seed': 0, 'dataset_name': 'algebra', 'load_params': True, 'save_params': False, 'embs_path': '../embs/', 'params_path': '../ckpt/', 'early_stopping': True, 'esc': 'error', 'verbose_early_stopping': False, 'disable_tqdm': False, 'valid_metric': 'rmse', 'learning_rate': 0.0001, 'batch_size': 512, 'valid_batch_size': 10000, 'num_epochs': 100, 'eval_freq': 1, 'patience': 20, 'device': device(type='cuda'), 'lambda': 2.2656270501845414e-06, 'tensorboard': False, 'flush_freq': True, 'pred_metrics': ['mi_acc'], 'profile_metrics': ['meta_doa'], 'num_responses': 12, 'low_mem': False, 'n_query': 6, 'CDM': 'ncdm', 'i_fold': 0, 'num_inner_users_epochs': 3, 'num_inner_epochs': 10, 'inner_lr': 0.0001, 'inner_user_lr': 0.0016969685554352153, 'meta_lr': 0.5, 'meta_trainer': 'Adam', 'num_workers': 0, 'pin_memory': False, 'debug': False, 'learning_users_emb_lr': 0.0005, 'd_in': 4} ####


[INFO 38:14] #### config : {'seed': 0, 'dataset_name': 'math2', 'load_params': True, 'save_params': False, 'embs_path': '../embs/', 'params_path': '../ckpt/', 'early_stopping': True, 'esc': 'error', 'verbose_early_stopping': False, 'disable_tqdm': False, 'valid_metric': 'mi_acc', 'learning_rate': 0.0001, 'batch_size': 512, 'valid_batch_size': 10000, 'num_epochs': 100, 'eval_freq': 1, 'patience': 20, 'device': device(type='cuda'), 'lambda': 9.972254466547545e-06, 'tensorboard': False, 'flush_freq': True, 'pred_metrics': ['mi_acc'], 'profile_metrics': ['meta_doa'], 'num_responses': 12, 'low_mem': False, 'n_query': 6, 'CDM': 'impact', 'i_fold': 0, 'num_inner_users_epochs': 3, 'num_inner_epochs': 10, 'inner_lr': 0.0001, 'inner_user_lr': 0.016848380924625605, 'meta_lr': 0.5, 'meta_trainer': 'Adam', 'num_workers': 0, 'pin_memory': False, 'debug': False, 'learning_users_emb_lr': 0.001, 'd_in': 4} ####


In [7]:
# NCDM: for each meta-trainer / fold / seed, train the Random selection
# strategy, evaluate it on the test split and log the metrics and Pareto index.

# NOTE(review): the column comment below and these dtype keys disagree on the
# column names — confirm which matches the CSV headers; if the headers are the
# former, this dtype mapping may not be applied.
## Dataframe columns : (user_id, question_id, response, category_id)
SPLIT_DTYPES = {'student_id': int, 'item_id': int, "correct": float, "dimension_id": int}


def load_split(dataset_name, split, i_fold):
    """Read the preprocessed ``split`` ('train' / 'valid' / 'test') CSV of ``dataset_name`` for fold ``i_fold``."""
    return pd.read_csv(
        f'../datasets/2-preprocessed_data/{dataset_name}_{split}_{i_fold}.csv',
        encoding='utf-8', dtype=SPLIT_DTYPES)


meta_trainers = ['GAP']
for meta_trainer in meta_trainers:
    config['meta_trainer'] = meta_trainer
    logging.info(f'#### meta_trainer : {config["meta_trainer"]} ####')
    for i_fold in range(1):
        config['i_fold'] = i_fold
        logging.info(f'#### i_fold : {i_fold} ####')

        train_df = load_split(config["dataset_name"], 'train', i_fold)
        valid_df = load_split(config["dataset_name"], 'valid', i_fold)
        test_df = load_split(config["dataset_name"], 'test', i_fold)

        train_data = liriscat.dataset.CATDataset(train_df, concept_map, metadata, config, nb_modalities)
        valid_data = liriscat.dataset.EvalDataset(valid_df, concept_map, metadata, config, nb_modalities)
        test_data = liriscat.dataset.EvalDataset(test_df, concept_map, metadata, config, nb_modalities)

        for seed in range(1):
            config['seed'] = seed
            logging.info(f'#### seed : {seed} ####')

            # Presumably restarts each dataset's sampling RNG so runs are comparable.
            train_data.reset_rng()
            valid_data.reset_rng()
            test_data.reset_rng()

            S = liriscat.selectionStrategy.Random(metadata, **config)
            S.init_models(train_data, valid_data)
            S.train(train_data, valid_data)
            # Evaluation is re-seeded with a fixed seed (0), independent of `seed`.
            liriscat.utils.set_seed(0)
            S.reset_rng()
            d = S.evaluate_test(test_data, train_data, valid_data)
            logging.info(d)
            logging.info(liriscat.utils.pareto_index(d))

            # Drop the model before emptying the CUDA cache; while `S` is still
            # referenced its GPU tensors cannot be released.
            del S
            gc.collect()
            torch.cuda.empty_cache()

[INFO 33:06] #### meta_trainer : GAP ####
[INFO 33:06] #### i_fold : 0 ####
[INFO 33:33] #### seed : 0 ####
[INFO 33:33] Random_cont_model
../ckpt/algebra_NCDM_fold_0_seed_0
[INFO 33:34] compiling selection model
[INFO 33:38] ------- START Training
[INFO 33:38] train on cuda


  0%|          | 0/100 [00:00<?, ?it/s]

[INFO 33:40] - meta_params: 0.5
[INFO 33:40] - cross_cond: 0.5
[INFO 33:40] - meta_lambda: 0.5
[INFO 33:40] - learning_users_emb: 0.0005
[INFO 33:40] rmse : 0.4891076982021332
[INFO 33:40] valid_rmse : 2.946431875228882
[INFO 33:40] valid_loss : 5.026264667510986


  1%|          | 1/100 [00:02<04:18,  2.61s/it]

[INFO 33:42] - meta_params: 0.5
[INFO 33:42] - cross_cond: 0.5
[INFO 33:42] - meta_lambda: 0.5
[INFO 33:42] - learning_users_emb: 0.0005
[INFO 33:43] rmse : 0.48902466893196106
[INFO 33:43] valid_rmse : 2.9459316730499268
[INFO 33:43] valid_loss : 5.021945476531982


  2%|▏         | 2/100 [00:04<04:00,  2.45s/it]

[INFO 33:44] - meta_params: 0.5
[INFO 33:44] - cross_cond: 0.5
[INFO 33:44] - meta_lambda: 0.5
[INFO 33:44] - learning_users_emb: 0.0005
[INFO 33:45] rmse : 0.4887843132019043
[INFO 33:45] valid_rmse : 2.944483757019043
[INFO 33:45] valid_loss : 5.0117878913879395


  3%|▎         | 3/100 [00:07<03:43,  2.30s/it]

[INFO 33:47] - meta_params: 0.5
[INFO 33:47] - cross_cond: 0.5
[INFO 33:47] - meta_lambda: 0.5
[INFO 33:47] - learning_users_emb: 0.0005
[INFO 33:47] rmse : 0.48861759901046753
[INFO 33:47] valid_rmse : 2.943479537963867
[INFO 33:47] valid_loss : 5.000763416290283


  4%|▍         | 4/100 [00:09<03:38,  2.28s/it]

[INFO 33:49] - meta_params: 0.5
[INFO 33:49] - cross_cond: 0.5
[INFO 33:49] - meta_lambda: 0.5
[INFO 33:49] - learning_users_emb: 0.0005
[INFO 33:49] rmse : 0.4886450171470642
[INFO 33:49] valid_rmse : 2.9436447620391846
[INFO 33:49] valid_loss : 4.9931159019470215


  5%|▌         | 5/100 [00:11<03:29,  2.21s/it]

[INFO 33:51] - meta_params: 0.5
[INFO 33:51] - cross_cond: 0.5
[INFO 33:51] - meta_lambda: 0.5
[INFO 33:51] - learning_users_emb: 0.0005
[INFO 33:51] rmse : 0.48816171288490295
[INFO 33:51] valid_rmse : 2.9407331943511963
[INFO 33:51] valid_loss : 4.973065376281738


  6%|▌         | 6/100 [00:13<03:21,  2.14s/it]

[INFO 33:53] - meta_params: 0.5
[INFO 33:53] - cross_cond: 0.5
[INFO 33:53] - meta_lambda: 0.5
[INFO 33:53] - learning_users_emb: 0.0005
[INFO 33:54] rmse : 0.4880612790584564
[INFO 33:54] valid_rmse : 2.9401283264160156
[INFO 33:54] valid_loss : 4.972350597381592


  7%|▋         | 7/100 [00:15<03:26,  2.22s/it]

[INFO 33:55] - meta_params: 0.5
[INFO 33:55] - cross_cond: 0.5
[INFO 33:55] - meta_lambda: 0.5
[INFO 33:55] - learning_users_emb: 0.0005
[INFO 33:56] rmse : 0.4876853823661804
[INFO 33:56] valid_rmse : 2.937863826751709
[INFO 33:56] valid_loss : 4.943924427032471


  8%|▊         | 8/100 [00:17<03:18,  2.16s/it]

[INFO 33:57] - meta_params: 0.5
[INFO 33:57] - cross_cond: 0.5
[INFO 33:57] - meta_lambda: 0.5
[INFO 33:57] - learning_users_emb: 0.0005
[INFO 33:58] rmse : 0.48827916383743286
[INFO 33:58] valid_rmse : 2.9414408206939697
[INFO 33:58] valid_loss : 4.974517822265625


  9%|▉         | 9/100 [00:19<03:14,  2.14s/it]

[INFO 33:59] - meta_params: 0.5
[INFO 33:59] - cross_cond: 0.5
[INFO 33:59] - meta_lambda: 0.5
[INFO 33:59] - learning_users_emb: 0.0005
[INFO 34:00] rmse : 0.4879290759563446
[INFO 34:00] valid_rmse : 2.9393317699432373
[INFO 34:00] valid_loss : 4.949010372161865


 10%|█         | 10/100 [00:21<03:10,  2.12s/it]

[INFO 34:01] - meta_params: 0.5
[INFO 34:01] - cross_cond: 0.5
[INFO 34:01] - meta_lambda: 0.5
[INFO 34:01] - learning_users_emb: 0.0005
[INFO 34:02] rmse : 0.48771777749061584
[INFO 34:02] valid_rmse : 2.938058853149414
[INFO 34:02] valid_loss : 4.949815273284912


 11%|█         | 11/100 [00:24<03:09,  2.13s/it]

[INFO 34:04] - meta_params: 0.25
[INFO 34:04] - cross_cond: 0.25
[INFO 34:04] - meta_lambda: 0.25
[INFO 34:04] - learning_users_emb: 0.00025
[INFO 34:04] rmse : 0.48781153559684753
[INFO 34:04] valid_rmse : 2.9386236667633057
[INFO 34:04] valid_loss : 4.9640679359436035


 12%|█▏        | 12/100 [00:26<03:06,  2.12s/it]

[INFO 34:06] - meta_params: 0.25
[INFO 34:06] - cross_cond: 0.25
[INFO 34:06] - meta_lambda: 0.25
[INFO 34:06] - learning_users_emb: 0.00025
[INFO 34:06] rmse : 0.48649489879608154
[INFO 34:06] valid_rmse : 2.930692195892334
[INFO 34:06] valid_loss : 4.895852088928223


 13%|█▎        | 13/100 [00:28<03:01,  2.08s/it]

[INFO 34:08] - meta_params: 0.25
[INFO 34:08] - cross_cond: 0.25
[INFO 34:08] - meta_lambda: 0.25
[INFO 34:08] - learning_users_emb: 0.00025
[INFO 34:08] rmse : 0.48682478070259094
[INFO 34:08] valid_rmse : 2.9326794147491455
[INFO 34:08] valid_loss : 4.892848491668701


 14%|█▍        | 14/100 [00:30<02:56,  2.06s/it]

[INFO 34:10] - meta_params: 0.25
[INFO 34:10] - cross_cond: 0.25
[INFO 34:10] - meta_lambda: 0.25
[INFO 34:10] - learning_users_emb: 0.00025
[INFO 34:10] rmse : 0.48728522658348083
[INFO 34:10] valid_rmse : 2.935453176498413
[INFO 34:10] valid_loss : 4.916186332702637


 15%|█▌        | 15/100 [00:32<02:58,  2.11s/it]

[INFO 34:12] - meta_params: 0.25
[INFO 34:12] - cross_cond: 0.25
[INFO 34:12] - meta_lambda: 0.25
[INFO 34:12] - learning_users_emb: 0.00025
[INFO 34:12] rmse : 0.48764878511428833
[INFO 34:12] valid_rmse : 2.93764328956604
[INFO 34:12] valid_loss : 4.940404891967773


 16%|█▌        | 16/100 [00:34<02:53,  2.07s/it]

[INFO 34:14] - meta_params: 0.25
[INFO 34:14] - cross_cond: 0.25
[INFO 34:14] - meta_lambda: 0.25
[INFO 34:14] - learning_users_emb: 0.00025
[INFO 34:14] rmse : 0.4869665801525116
[INFO 34:14] valid_rmse : 2.9335336685180664
[INFO 34:14] valid_loss : 4.914363384246826


 17%|█▋        | 17/100 [00:36<02:52,  2.08s/it]

[INFO 34:16] - meta_params: 0.125
[INFO 34:16] - cross_cond: 0.125
[INFO 34:16] - meta_lambda: 0.125
[INFO 34:16] - learning_users_emb: 0.000125
[INFO 34:16] rmse : 0.486261248588562
[INFO 34:16] valid_rmse : 2.9292845726013184
[INFO 34:16] valid_loss : 4.881582736968994


 18%|█▊        | 18/100 [00:38<02:48,  2.05s/it]

[INFO 34:18] - meta_params: 0.125
[INFO 34:18] - cross_cond: 0.125
[INFO 34:18] - meta_lambda: 0.125
[INFO 34:18] - learning_users_emb: 0.000125
[INFO 34:19] rmse : 0.48665544390678406
[INFO 34:19] valid_rmse : 2.931659460067749
[INFO 34:19] valid_loss : 4.904518127441406


 19%|█▉        | 19/100 [00:40<02:52,  2.13s/it]

[INFO 34:20] - meta_params: 0.125
[INFO 34:20] - cross_cond: 0.125
[INFO 34:20] - meta_lambda: 0.125
[INFO 34:20] - learning_users_emb: 0.000125
[INFO 34:21] rmse : 0.4861386716365814
[INFO 34:21] valid_rmse : 2.928546190261841
[INFO 34:21] valid_loss : 4.86037540435791


 20%|██        | 20/100 [00:42<02:48,  2.11s/it]

[INFO 34:22] - meta_params: 0.125
[INFO 34:22] - cross_cond: 0.125
[INFO 34:22] - meta_lambda: 0.125
[INFO 34:22] - learning_users_emb: 0.000125
[INFO 34:23] rmse : 0.48685315251350403
[INFO 34:23] valid_rmse : 2.9328503608703613
[INFO 34:23] valid_loss : 4.887227535247803


 21%|██        | 21/100 [00:44<02:42,  2.06s/it]

[INFO 34:24] - meta_params: 0.125
[INFO 34:24] - cross_cond: 0.125
[INFO 34:24] - meta_lambda: 0.125
[INFO 34:24] - learning_users_emb: 0.000125
[INFO 34:25] rmse : 0.4868716299533844
[INFO 34:25] valid_rmse : 2.9329617023468018
[INFO 34:25] valid_loss : 4.91365909576416


 22%|██▏       | 22/100 [00:46<02:40,  2.06s/it]

[INFO 34:26] - meta_params: 0.125
[INFO 34:26] - cross_cond: 0.125
[INFO 34:26] - meta_lambda: 0.125
[INFO 34:26] - learning_users_emb: 0.000125
[INFO 34:27] rmse : 0.4860232472419739
[INFO 34:27] valid_rmse : 2.9278509616851807
[INFO 34:27] valid_loss : 4.878360748291016


 23%|██▎       | 23/100 [00:48<02:35,  2.02s/it]

[INFO 34:28] - meta_params: 0.0625
[INFO 34:28] - cross_cond: 0.0625
[INFO 34:28] - meta_lambda: 0.0625
[INFO 34:28] - learning_users_emb: 6.25e-05
[INFO 34:29] rmse : 0.4875358045101166
[INFO 34:29] valid_rmse : 2.936962842941284
[INFO 34:29] valid_loss : 4.928546905517578


 24%|██▍       | 24/100 [00:50<02:32,  2.01s/it]

[INFO 34:30] - meta_params: 0.0625
[INFO 34:30] - cross_cond: 0.0625
[INFO 34:30] - meta_lambda: 0.0625
[INFO 34:30] - learning_users_emb: 6.25e-05
[INFO 34:31] rmse : 0.4864150881767273
[INFO 34:31] valid_rmse : 2.9302115440368652
[INFO 34:31] valid_loss : 4.894988536834717


 25%|██▌       | 25/100 [00:53<02:35,  2.08s/it]

[INFO 34:32] - meta_params: 0.0625
[INFO 34:32] - cross_cond: 0.0625
[INFO 34:32] - meta_lambda: 0.0625
[INFO 34:32] - learning_users_emb: 6.25e-05
[INFO 34:33] rmse : 0.4861401617527008
[INFO 34:33] valid_rmse : 2.9285552501678467
[INFO 34:33] valid_loss : 4.875359058380127


 26%|██▌       | 26/100 [00:55<02:30,  2.04s/it]

[INFO 34:34] - meta_params: 0.03125
[INFO 34:34] - cross_cond: 0.03125
[INFO 34:34] - meta_lambda: 0.03125
[INFO 34:34] - learning_users_emb: 3.125e-05
[INFO 34:35] rmse : 0.48687052726745605
[INFO 34:35] valid_rmse : 2.932955026626587
[INFO 34:35] valid_loss : 4.901092052459717


 27%|██▋       | 27/100 [00:56<02:27,  2.02s/it]

[INFO 34:36] - meta_params: 0.03125
[INFO 34:36] - cross_cond: 0.03125
[INFO 34:36] - meta_lambda: 0.03125
[INFO 34:36] - learning_users_emb: 3.125e-05
[INFO 34:37] rmse : 0.4860990643501282
[INFO 34:37] valid_rmse : 2.9283077716827393
[INFO 34:37] valid_loss : 4.886013507843018


 28%|██▊       | 28/100 [00:59<02:25,  2.02s/it]

[INFO 34:38] - meta_params: 0.03125
[INFO 34:38] - cross_cond: 0.03125
[INFO 34:38] - meta_lambda: 0.03125
[INFO 34:38] - learning_users_emb: 3.125e-05
[INFO 34:39] rmse : 0.4872371554374695
[INFO 34:39] valid_rmse : 2.935163736343384
[INFO 34:39] valid_loss : 4.919351577758789


 29%|██▉       | 29/100 [01:01<02:22,  2.01s/it]

[INFO 34:40] - meta_params: 0.015625
[INFO 34:40] - cross_cond: 0.015625
[INFO 34:40] - meta_lambda: 0.015625
[INFO 34:40] - learning_users_emb: 1.5625e-05
[INFO 34:41] rmse : 0.4861816167831421
[INFO 34:41] valid_rmse : 2.928804874420166
[INFO 34:41] valid_loss : 4.8677978515625


 30%|███       | 30/100 [01:03<02:21,  2.03s/it]

[INFO 34:42] - meta_params: 0.015625
[INFO 34:42] - cross_cond: 0.015625
[INFO 34:42] - meta_lambda: 0.015625
[INFO 34:42] - learning_users_emb: 1.5625e-05
[INFO 34:43] rmse : 0.486569881439209
[INFO 34:43] valid_rmse : 2.9311439990997314
[INFO 34:43] valid_loss : 4.8879923820495605


 31%|███       | 31/100 [01:05<02:22,  2.06s/it]

[INFO 34:44] - meta_params: 0.015625
[INFO 34:44] - cross_cond: 0.015625
[INFO 34:44] - meta_lambda: 0.015625
[INFO 34:44] - learning_users_emb: 1.5625e-05
[INFO 34:45] rmse : 0.48684078454971313
[INFO 34:45] valid_rmse : 2.9327759742736816
[INFO 34:45] valid_loss : 4.888619899749756


 32%|███▏      | 32/100 [01:07<02:18,  2.04s/it]

[INFO 34:47] - meta_params: 0.0078125
[INFO 34:47] - cross_cond: 0.0078125
[INFO 34:47] - meta_lambda: 0.0078125
[INFO 34:47] - learning_users_emb: 7.8125e-06
[INFO 34:47] rmse : 0.4870133399963379
[INFO 34:47] valid_rmse : 2.9338152408599854
[INFO 34:47] valid_loss : 4.919463157653809


 33%|███▎      | 33/100 [01:09<02:16,  2.04s/it]

[INFO 34:49] - meta_params: 0.0078125
[INFO 34:49] - cross_cond: 0.0078125
[INFO 34:49] - meta_lambda: 0.0078125
[INFO 34:49] - learning_users_emb: 7.8125e-06
[INFO 34:49] rmse : 0.48680007457733154
[INFO 34:49] valid_rmse : 2.932530641555786
[INFO 34:49] valid_loss : 4.892086982727051


 34%|███▍      | 34/100 [01:11<02:12,  2.01s/it]

[INFO 34:51] - meta_params: 0.0078125
[INFO 34:51] - cross_cond: 0.0078125
[INFO 34:51] - meta_lambda: 0.0078125
[INFO 34:51] - learning_users_emb: 7.8125e-06
[INFO 34:51] rmse : 0.4869312345981598
[INFO 34:51] valid_rmse : 2.9333207607269287
[INFO 34:51] valid_loss : 4.907578468322754


 35%|███▌      | 35/100 [01:13<02:11,  2.02s/it]

[INFO 34:53] - meta_params: 0.00390625
[INFO 34:53] - cross_cond: 0.00390625
[INFO 34:53] - meta_lambda: 0.00390625
[INFO 34:53] - learning_users_emb: 3.90625e-06
[INFO 34:53] rmse : 0.4872041642665863
[INFO 34:53] valid_rmse : 2.934964895248413
[INFO 34:53] valid_loss : 4.927343845367432


 36%|███▌      | 36/100 [01:15<02:08,  2.00s/it]

[INFO 34:55] - meta_params: 0.00390625
[INFO 34:55] - cross_cond: 0.00390625
[INFO 34:55] - meta_lambda: 0.00390625
[INFO 34:55] - learning_users_emb: 3.90625e-06
[INFO 34:55] rmse : 0.48636412620544434
[INFO 34:55] valid_rmse : 2.9299044609069824
[INFO 34:55] valid_loss : 4.876018047332764


 37%|███▋      | 37/100 [01:17<02:10,  2.07s/it]

[INFO 34:57] - meta_params: 0.00390625
[INFO 34:57] - cross_cond: 0.00390625
[INFO 34:57] - meta_lambda: 0.00390625
[INFO 34:57] - learning_users_emb: 3.90625e-06
[INFO 34:57] rmse : 0.4866465628147125
[INFO 34:57] valid_rmse : 2.931605815887451
[INFO 34:57] valid_loss : 4.889874458312988


 38%|███▊      | 38/100 [01:19<02:08,  2.07s/it]

[INFO 34:59] - meta_params: 0.001953125
[INFO 34:59] - cross_cond: 0.001953125
[INFO 34:59] - meta_lambda: 0.001953125
[INFO 34:59] - learning_users_emb: 1.953125e-06
[INFO 34:59] rmse : 0.4865531027317047
[INFO 34:59] valid_rmse : 2.9310429096221924
[INFO 34:59] valid_loss : 4.886338233947754


 39%|███▉      | 39/100 [01:21<02:04,  2.04s/it]

[INFO 35:01] - meta_params: 0.001953125
[INFO 35:01] - cross_cond: 0.001953125
[INFO 35:01] - meta_lambda: 0.001953125
[INFO 35:01] - learning_users_emb: 1.953125e-06
[INFO 35:01] rmse : 0.48657432198524475
[INFO 35:01] valid_rmse : 2.931170701980591
[INFO 35:01] valid_loss : 4.88156270980835


 40%|████      | 40/100 [01:23<02:02,  2.05s/it]

[INFO 35:03] - meta_params: 0.001953125
[INFO 35:03] - cross_cond: 0.001953125
[INFO 35:03] - meta_lambda: 0.001953125
[INFO 35:03] - learning_users_emb: 1.953125e-06
[INFO 35:03] rmse : 0.4855756461620331
[INFO 35:03] valid_rmse : 2.925154447555542
[INFO 35:03] valid_loss : 4.847590923309326


 41%|████      | 41/100 [01:25<01:58,  2.01s/it]

[INFO 35:05] - meta_params: 0.001953125
[INFO 35:05] - cross_cond: 0.001953125
[INFO 35:05] - meta_lambda: 0.001953125
[INFO 35:05] - learning_users_emb: 1.953125e-06
[INFO 35:05] rmse : 0.48604488372802734
[INFO 35:05] valid_rmse : 2.927981376647949
[INFO 35:05] valid_loss : 4.861686706542969


 42%|████▏     | 42/100 [01:27<01:55,  1.99s/it]

[INFO 35:07] - meta_params: 0.001953125
[INFO 35:07] - cross_cond: 0.001953125
[INFO 35:07] - meta_lambda: 0.001953125
[INFO 35:07] - learning_users_emb: 1.953125e-06
[INFO 35:07] rmse : 0.4863715171813965
[INFO 35:07] valid_rmse : 2.9299490451812744
[INFO 35:07] valid_loss : 4.877751350402832


 43%|████▎     | 43/100 [01:29<01:56,  2.05s/it]

[INFO 35:09] - meta_params: 0.001953125
[INFO 35:09] - cross_cond: 0.001953125
[INFO 35:09] - meta_lambda: 0.001953125
[INFO 35:09] - learning_users_emb: 1.953125e-06
[INFO 35:09] rmse : 0.4873306155204773
[INFO 35:09] valid_rmse : 2.9357266426086426
[INFO 35:09] valid_loss : 4.921860694885254


 44%|████▍     | 44/100 [01:31<01:53,  2.02s/it]

[INFO 35:11] - meta_params: 0.0009765625
[INFO 35:11] - cross_cond: 0.0009765625
[INFO 35:11] - meta_lambda: 0.0009765625
[INFO 35:11] - learning_users_emb: 9.765625e-07
[INFO 35:11] rmse : 0.48672229051589966
[INFO 35:11] valid_rmse : 2.9320621490478516
[INFO 35:11] valid_loss : 4.8988871574401855


 45%|████▌     | 45/100 [01:33<01:50,  2.01s/it]

[INFO 35:13] - meta_params: 0.0009765625
[INFO 35:13] - cross_cond: 0.0009765625
[INFO 35:13] - meta_lambda: 0.0009765625
[INFO 35:13] - learning_users_emb: 9.765625e-07
[INFO 35:13] rmse : 0.485917866230011
[INFO 35:13] valid_rmse : 2.927216053009033
[INFO 35:13] valid_loss : 4.856006145477295


 46%|████▌     | 46/100 [01:35<01:48,  2.01s/it]

[INFO 35:15] - meta_params: 0.0009765625
[INFO 35:15] - cross_cond: 0.0009765625
[INFO 35:15] - meta_lambda: 0.0009765625
[INFO 35:15] - learning_users_emb: 9.765625e-07
[INFO 35:15] rmse : 0.4858495593070984
[INFO 35:15] valid_rmse : 2.926804542541504
[INFO 35:15] valid_loss : 4.850765228271484


 47%|████▋     | 47/100 [01:37<01:45,  1.99s/it]

[INFO 35:17] - meta_params: 0.00048828125
[INFO 35:17] - cross_cond: 0.00048828125
[INFO 35:17] - meta_lambda: 0.00048828125
[INFO 35:17] - learning_users_emb: 4.8828125e-07
[INFO 35:17] rmse : 0.4869682490825653
[INFO 35:17] valid_rmse : 2.9335436820983887
[INFO 35:17] valid_loss : 4.9103522300720215


 48%|████▊     | 48/100 [01:39<01:44,  2.00s/it]

[INFO 35:19] - meta_params: 0.00048828125
[INFO 35:19] - cross_cond: 0.00048828125
[INFO 35:19] - meta_lambda: 0.00048828125
[INFO 35:19] - learning_users_emb: 4.8828125e-07
[INFO 35:19] rmse : 0.4867984652519226
[INFO 35:19] valid_rmse : 2.932520866394043
[INFO 35:19] valid_loss : 4.9096503257751465


 49%|████▉     | 49/100 [01:41<01:40,  1.97s/it]

[INFO 35:21] - meta_params: 0.00048828125
[INFO 35:21] - cross_cond: 0.00048828125
[INFO 35:21] - meta_lambda: 0.00048828125
[INFO 35:21] - learning_users_emb: 4.8828125e-07
[INFO 35:21] rmse : 0.4868817627429962
[INFO 35:21] valid_rmse : 2.9330227375030518
[INFO 35:21] valid_loss : 4.913308143615723


 50%|█████     | 50/100 [01:43<01:38,  1.97s/it]

[INFO 35:23] - meta_params: 0.000244140625
[INFO 35:23] - cross_cond: 0.000244140625
[INFO 35:23] - meta_lambda: 0.000244140625
[INFO 35:23] - learning_users_emb: 2.44140625e-07
[INFO 35:23] rmse : 0.4865231215953827
[INFO 35:23] valid_rmse : 2.9308621883392334
[INFO 35:23] valid_loss : 4.887783050537109


 51%|█████     | 51/100 [01:45<01:36,  1.97s/it]

[INFO 35:25] - meta_params: 0.000244140625
[INFO 35:25] - cross_cond: 0.000244140625
[INFO 35:25] - meta_lambda: 0.000244140625
[INFO 35:25] - learning_users_emb: 2.44140625e-07
[INFO 35:25] rmse : 0.4868372082710266
[INFO 35:25] valid_rmse : 2.9327542781829834
[INFO 35:25] valid_loss : 4.900473117828369


 52%|█████▏    | 52/100 [01:47<01:34,  1.97s/it]

[INFO 35:27] - meta_params: 0.000244140625
[INFO 35:27] - cross_cond: 0.000244140625
[INFO 35:27] - meta_lambda: 0.000244140625
[INFO 35:27] - learning_users_emb: 2.44140625e-07
[INFO 35:27] rmse : 0.4864256978034973
[INFO 35:27] valid_rmse : 2.9302754402160645
[INFO 35:27] valid_loss : 4.898383617401123


 53%|█████▎    | 53/100 [01:49<01:32,  1.97s/it]

[INFO 35:29] - meta_params: 0.0001220703125
[INFO 35:29] - cross_cond: 0.0001220703125
[INFO 35:29] - meta_lambda: 0.0001220703125
[INFO 35:29] - learning_users_emb: 1.220703125e-07
[INFO 35:29] rmse : 0.4862603545188904
[INFO 35:29] valid_rmse : 2.929279327392578
[INFO 35:29] valid_loss : 4.8917646408081055


 54%|█████▍    | 54/100 [01:51<01:30,  1.96s/it]

[INFO 35:31] - meta_params: 0.0001220703125
[INFO 35:31] - cross_cond: 0.0001220703125
[INFO 35:31] - meta_lambda: 0.0001220703125
[INFO 35:31] - learning_users_emb: 1.220703125e-07
[INFO 35:31] rmse : 0.4869754910469055
[INFO 35:31] valid_rmse : 2.9335873126983643
[INFO 35:31] valid_loss : 4.898651123046875


 55%|█████▌    | 55/100 [01:53<01:27,  1.96s/it]

[INFO 35:33] - meta_params: 0.0001220703125
[INFO 35:33] - cross_cond: 0.0001220703125
[INFO 35:33] - meta_lambda: 0.0001220703125
[INFO 35:33] - learning_users_emb: 1.220703125e-07
[INFO 35:33] rmse : 0.486935555934906
[INFO 35:33] valid_rmse : 2.933346748352051
[INFO 35:33] valid_loss : 4.914449214935303


 55%|█████▌    | 55/100 [01:55<01:34,  2.09s/it]

[INFO 35:33] -- END Training --
[INFO 35:33] train on cuda



100%|██████████| 6/6 [00:00<00:00,  9.47it/s]


[INFO 35:57] ({0: {'mi_acc': 0.69744473695755}, 1: {'mi_acc': 0.69744473695755}, 2: {'mi_acc': 0.69744473695755}, 3: {'mi_acc': 0.69744473695755}, 4: {'mi_acc': 0.69744473695755}, 5: {'mi_acc': 0.6972299814224243}}, {0: {'meta_doa': np.float64(0.486973732475573)}, 1: {'meta_doa': np.float64(0.4880515469740171)}, 2: {'meta_doa': np.float64(0.489487092037853)}, 3: {'meta_doa': np.float64(0.4890770466088384)}, 4: {'meta_doa': np.float64(0.49200824003568105)}, 5: {'meta_doa': np.float64(0.4890502338331308)}})
[INFO 35:57] -0.01290107825768087


In [9]:
# Same pipeline as the GAP cell, with the MAML meta-trainer. Re-uses the
# train_data / valid_data / test_data of the last fold loaded by that cell.
config['meta_trainer'] = 'MAML'
logging.info(f'#### meta_trainer : {config["meta_trainer"]} ####')

for seed in range(1):
    config['seed'] = seed
    logging.info(f'#### seed : {seed} ####')

    train_data.reset_rng()
    valid_data.reset_rng()
    test_data.reset_rng()

    S = liriscat.selectionStrategy.Random(metadata, **config)
    S.init_models(train_data, valid_data)
    S.train(train_data, valid_data)
    # Evaluation is re-seeded with a fixed seed (0), independent of `seed`.
    liriscat.utils.set_seed(0)
    S.reset_rng()
    d = S.evaluate_test(test_data, train_data, valid_data)
    # Log (rather than print) for consistency with the GAP cell.
    logging.info(d)
    logging.info(liriscat.utils.pareto_index(d))

    # Release the model *before* emptying the CUDA cache: calling
    # empty_cache() while `S` is alive cannot free its GPU memory.
    del S
    gc.collect()
    torch.cuda.empty_cache()

[INFO 37:11] #### seed : 0 ####
[INFO 37:11] Random_cont_model
../ckpt/algebra_NCDM_fold_0_seed_0
[INFO 37:12] compiling selection model
[INFO 37:12] ------- START Training
[INFO 37:12] train on cuda


  0%|          | 0/100 [00:00<?, ?it/s]

[INFO 37:14] - learning_users_emb: 0.5
[INFO 37:15] rmse : 0.5669613480567932
[INFO 37:15] valid_rmse : 3.4154298305511475
[INFO 37:15] valid_loss : 6.258544445037842


  1%|          | 1/100 [00:02<04:36,  2.79s/it]

[INFO 37:17] - learning_users_emb: 0.5
[INFO 37:18] rmse : 0.5090577602386475
[INFO 37:18] valid_rmse : 3.066612958908081
[INFO 37:18] valid_loss : 4.702308654785156


  2%|▏         | 2/100 [00:05<04:49,  2.96s/it]

[INFO 37:20] - learning_users_emb: 0.5
[INFO 37:21] rmse : 0.4506419599056244
[INFO 37:21] valid_rmse : 2.7147107124328613
[INFO 37:21] valid_loss : 3.8463807106018066


  3%|▎         | 3/100 [00:08<04:37,  2.86s/it]

[INFO 37:23] - learning_users_emb: 0.5
[INFO 37:23] rmse : 0.4596650004386902
[INFO 37:23] valid_rmse : 2.769066333770752
[INFO 37:23] valid_loss : 3.9456357955932617


  4%|▍         | 4/100 [00:11<04:34,  2.86s/it]

[INFO 37:26] - learning_users_emb: 0.5
[INFO 37:26] rmse : 0.4676283895969391
[INFO 37:26] valid_rmse : 2.8170385360717773
[INFO 37:26] valid_loss : 3.955911874771118


  5%|▌         | 5/100 [00:14<04:29,  2.84s/it]

[INFO 37:28] - learning_users_emb: 0.5
[INFO 37:29] rmse : 0.4597303867340088
[INFO 37:29] valid_rmse : 2.7694602012634277
[INFO 37:29] valid_loss : 3.8525912761688232


  6%|▌         | 6/100 [00:17<04:27,  2.84s/it]

[INFO 37:31] - learning_users_emb: 0.25
[INFO 37:32] rmse : 0.4561121165752411
[INFO 37:32] valid_rmse : 2.7476634979248047
[INFO 37:32] valid_loss : 3.803600549697876


  7%|▋         | 7/100 [00:19<04:23,  2.83s/it]

[INFO 37:34] - learning_users_emb: 0.25
[INFO 37:35] rmse : 0.4522190988063812
[INFO 37:35] valid_rmse : 2.7242114543914795
[INFO 37:35] valid_loss : 3.750572443008423


  8%|▊         | 8/100 [00:22<04:20,  2.84s/it]

[INFO 37:37] - learning_users_emb: 0.25
[INFO 37:37] rmse : 0.44968509674072266
[INFO 37:37] valid_rmse : 2.708946466445923
[INFO 37:37] valid_loss : 3.717383861541748


  9%|▉         | 9/100 [00:25<04:14,  2.80s/it]

[INFO 37:39] - learning_users_emb: 0.25
[INFO 37:40] rmse : 0.44884151220321655
[INFO 37:40] valid_rmse : 2.703864574432373
[INFO 37:40] valid_loss : 3.707214832305908


 10%|█         | 10/100 [00:28<04:08,  2.76s/it]

[INFO 37:42] - learning_users_emb: 0.25
[INFO 37:43] rmse : 0.44911453127861023
[INFO 37:43] valid_rmse : 2.7055091857910156
[INFO 37:43] valid_loss : 3.7107105255126953


 11%|█         | 11/100 [00:31<04:11,  2.82s/it]

[INFO 37:45] - learning_users_emb: 0.25
[INFO 37:46] rmse : 0.44894835352897644
[INFO 37:46] valid_rmse : 2.7045083045959473
[INFO 37:46] valid_loss : 3.7078027725219727


 12%|█▏        | 12/100 [00:33<04:03,  2.77s/it]

[INFO 37:48] - learning_users_emb: 0.25
[INFO 37:49] rmse : 0.4479454755783081
[INFO 37:49] valid_rmse : 2.6984667778015137
[INFO 37:49] valid_loss : 3.6927552223205566


 13%|█▎        | 13/100 [00:36<04:00,  2.77s/it]

[INFO 37:50] - learning_users_emb: 0.25
[INFO 37:51] rmse : 0.4465717673301697
[INFO 37:51] valid_rmse : 2.6901915073394775
[INFO 37:51] valid_loss : 3.6728482246398926


 14%|█▍        | 14/100 [00:39<03:55,  2.73s/it]

[INFO 37:53] - learning_users_emb: 0.25
[INFO 37:54] rmse : 0.4461248219013214
[INFO 37:54] valid_rmse : 2.6874990463256836
[INFO 37:54] valid_loss : 3.6658730506896973


 15%|█▌        | 15/100 [00:41<03:50,  2.71s/it]

[INFO 37:56] - learning_users_emb: 0.25
[INFO 37:56] rmse : 0.4465804398059845
[INFO 37:56] valid_rmse : 2.690243721008301
[INFO 37:56] valid_loss : 3.671351432800293


 16%|█▌        | 16/100 [00:44<03:45,  2.69s/it]

[INFO 37:58] - learning_users_emb: 0.25
[INFO 37:59] rmse : 0.44671133160591125
[INFO 37:59] valid_rmse : 2.6910321712493896
[INFO 37:59] valid_loss : 3.6718199253082275


 17%|█▋        | 17/100 [00:47<03:43,  2.70s/it]

[INFO 38:01] - learning_users_emb: 0.25
[INFO 38:02] rmse : 0.4463825225830078
[INFO 38:02] valid_rmse : 2.689051389694214
[INFO 38:02] valid_loss : 3.6660096645355225


 18%|█▊        | 18/100 [00:49<03:39,  2.68s/it]

[INFO 38:04] - learning_users_emb: 0.125
[INFO 38:04] rmse : 0.44602227210998535
[INFO 38:04] valid_rmse : 2.6868813037872314
[INFO 38:04] valid_loss : 3.6598691940307617


 19%|█▉        | 19/100 [00:52<03:35,  2.66s/it]

[INFO 38:06] - learning_users_emb: 0.125
[INFO 38:07] rmse : 0.44562697410583496
[INFO 38:07] valid_rmse : 2.684499979019165
[INFO 38:07] valid_loss : 3.653902769088745


 20%|██        | 20/100 [00:55<03:33,  2.67s/it]

[INFO 38:09] - learning_users_emb: 0.125
[INFO 38:10] rmse : 0.44529399275779724
[INFO 38:10] valid_rmse : 2.6824939250946045
[INFO 38:10] valid_loss : 3.6494829654693604


 21%|██        | 21/100 [00:57<03:30,  2.66s/it]

[INFO 38:12] - learning_users_emb: 0.125
[INFO 38:13] rmse : 0.44508281350135803
[INFO 38:13] valid_rmse : 2.6812217235565186
[INFO 38:13] valid_loss : 3.6470110416412354


 22%|██▏       | 22/100 [01:00<03:29,  2.69s/it]

[INFO 38:14] - learning_users_emb: 0.125
[INFO 38:15] rmse : 0.4450179636478424
[INFO 38:15] valid_rmse : 2.68083119392395
[INFO 38:15] valid_loss : 3.6463844776153564


 23%|██▎       | 23/100 [01:03<03:26,  2.69s/it]

[INFO 38:17] - learning_users_emb: 0.125
[INFO 38:18] rmse : 0.4451082944869995
[INFO 38:18] valid_rmse : 2.68137526512146
[INFO 38:18] valid_loss : 3.6475167274475098


 24%|██▍       | 24/100 [01:05<03:23,  2.68s/it]

[INFO 38:20] - learning_users_emb: 0.125
[INFO 38:21] rmse : 0.4451723098754883
[INFO 38:21] valid_rmse : 2.6817610263824463
[INFO 38:21] valid_loss : 3.647751569747925


 25%|██▌       | 25/100 [01:08<03:20,  2.68s/it]

[INFO 38:22] - learning_users_emb: 0.125
[INFO 38:23] rmse : 0.4449920654296875
[INFO 38:23] valid_rmse : 2.6806750297546387
[INFO 38:23] valid_loss : 3.644580602645874


 26%|██▌       | 26/100 [01:11<03:17,  2.67s/it]

[INFO 38:25] - learning_users_emb: 0.125
[INFO 38:26] rmse : 0.44465139508247375
[INFO 38:26] valid_rmse : 2.6786229610443115
[INFO 38:26] valid_loss : 3.6392905712127686


 27%|██▋       | 27/100 [01:13<03:13,  2.65s/it]

[INFO 38:28] - learning_users_emb: 0.125
[INFO 38:28] rmse : 0.4442954659461975
[INFO 38:28] valid_rmse : 2.676478862762451
[INFO 38:28] valid_loss : 3.6337239742279053


 28%|██▊       | 28/100 [01:16<03:10,  2.65s/it]

[INFO 38:30] - learning_users_emb: 0.125
[INFO 38:31] rmse : 0.4440152943134308
[INFO 38:31] valid_rmse : 2.674790859222412
[INFO 38:31] valid_loss : 3.629537582397461


 29%|██▉       | 29/100 [01:19<03:08,  2.66s/it]

[INFO 38:33] - learning_users_emb: 0.125
[INFO 38:34] rmse : 0.443987637758255
[INFO 38:34] valid_rmse : 2.674624443054199
[INFO 38:34] valid_loss : 3.6286356449127197


 30%|███       | 30/100 [01:21<03:05,  2.65s/it]

[INFO 38:36] - learning_users_emb: 0.125
[INFO 38:36] rmse : 0.4440348446369171
[INFO 38:36] valid_rmse : 2.6749086380004883
[INFO 38:36] valid_loss : 3.6287858486175537


 31%|███       | 31/100 [01:24<03:02,  2.64s/it]

[INFO 38:38] - learning_users_emb: 0.125
[INFO 38:39] rmse : 0.4440179467201233
[INFO 38:39] valid_rmse : 2.674807071685791
[INFO 38:39] valid_loss : 3.6279146671295166


 32%|███▏      | 32/100 [01:27<02:59,  2.64s/it]

[INFO 38:41] - learning_users_emb: 0.125
[INFO 38:42] rmse : 0.44388213753700256
[INFO 38:42] valid_rmse : 2.6739888191223145
[INFO 38:42] valid_loss : 3.625288486480713


 33%|███▎      | 33/100 [01:29<02:57,  2.65s/it]

[INFO 38:44] - learning_users_emb: 0.125
[INFO 38:44] rmse : 0.44364267587661743
[INFO 38:44] valid_rmse : 2.67254638671875
[INFO 38:44] valid_loss : 3.6212306022644043


 34%|███▍      | 34/100 [01:32<02:53,  2.63s/it]

[INFO 38:46] - learning_users_emb: 0.125
[INFO 38:47] rmse : 0.4432556927204132
[INFO 38:47] valid_rmse : 2.670215129852295
[INFO 38:47] valid_loss : 3.615617513656616


 35%|███▌      | 35/100 [01:34<02:51,  2.63s/it]

[INFO 38:49] - learning_users_emb: 0.125
[INFO 38:50] rmse : 0.44276097416877747
[INFO 38:50] valid_rmse : 2.6672348976135254
[INFO 38:50] valid_loss : 3.6084325313568115


 36%|███▌      | 36/100 [01:37<02:48,  2.63s/it]

[INFO 38:51] - learning_users_emb: 0.125
[INFO 38:52] rmse : 0.4422556757926941
[INFO 38:52] valid_rmse : 2.6641907691955566
[INFO 38:52] valid_loss : 3.6015284061431885


 37%|███▋      | 37/100 [01:40<02:45,  2.63s/it]

[INFO 38:54] - learning_users_emb: 0.125
[INFO 38:55] rmse : 0.4419209957122803
[INFO 38:55] valid_rmse : 2.662174701690674
[INFO 38:55] valid_loss : 3.5977120399475098


 38%|███▊      | 38/100 [01:42<02:43,  2.64s/it]

[INFO 38:57] - learning_users_emb: 0.125
[INFO 38:57] rmse : 0.4418443441390991
[INFO 38:57] valid_rmse : 2.661712884902954
[INFO 38:57] valid_loss : 3.5974276065826416


 39%|███▉      | 39/100 [01:45<02:41,  2.64s/it]

[INFO 38:59] - learning_users_emb: 0.125
[INFO 39:00] rmse : 0.4420427083969116
[INFO 39:00] valid_rmse : 2.662907838821411
[INFO 39:00] valid_loss : 3.6001946926116943


 40%|████      | 40/100 [01:48<02:38,  2.64s/it]

[INFO 39:02] - learning_users_emb: 0.125
[INFO 39:03] rmse : 0.4424031972885132
[INFO 39:03] valid_rmse : 2.6650795936584473
[INFO 39:03] valid_loss : 3.6051974296569824


 41%|████      | 41/100 [01:51<02:40,  2.72s/it]

[INFO 39:05] - learning_users_emb: 0.0625
[INFO 39:06] rmse : 0.4424974322319031
[INFO 39:06] valid_rmse : 2.665647268295288
[INFO 39:06] valid_loss : 3.6064157485961914


 42%|████▏     | 42/100 [01:53<02:36,  2.69s/it]

[INFO 39:07] - learning_users_emb: 0.0625
[INFO 39:08] rmse : 0.4424394965171814
[INFO 39:08] valid_rmse : 2.6652982234954834
[INFO 39:08] valid_loss : 3.6052494049072266


 43%|████▎     | 43/100 [01:56<02:31,  2.66s/it]

[INFO 39:10] - learning_users_emb: 0.0625
[INFO 39:11] rmse : 0.4422804117202759
[INFO 39:11] valid_rmse : 2.664339780807495
[INFO 39:11] valid_loss : 3.6029069423675537


 44%|████▍     | 44/100 [01:58<02:28,  2.65s/it]

[INFO 39:12] - learning_users_emb: 0.03125
[INFO 39:13] rmse : 0.4421996474266052
[INFO 39:13] valid_rmse : 2.663853406906128
[INFO 39:13] valid_loss : 3.6014745235443115


 45%|████▌     | 45/100 [02:01<02:20,  2.55s/it]

[INFO 39:15] - learning_users_emb: 0.03125
[INFO 39:16] rmse : 0.4420846104621887
[INFO 39:16] valid_rmse : 2.6631603240966797
[INFO 39:16] valid_loss : 3.5998172760009766


 46%|████▌     | 46/100 [02:03<02:18,  2.56s/it]

[INFO 39:18] - learning_users_emb: 0.03125
[INFO 39:18] rmse : 0.44198861718177795
[INFO 39:18] valid_rmse : 2.6625821590423584
[INFO 39:18] valid_loss : 3.598261833190918


 47%|████▋     | 47/100 [02:06<02:16,  2.58s/it]

[INFO 39:20] - learning_users_emb: 0.015625
[INFO 39:21] rmse : 0.4419458508491516
[INFO 39:21] valid_rmse : 2.6623244285583496
[INFO 39:21] valid_loss : 3.5975544452667236


 48%|████▊     | 48/100 [02:09<02:14,  2.59s/it]

[INFO 39:23] - learning_users_emb: 0.015625
[INFO 39:24] rmse : 0.4419134259223938
[INFO 39:24] valid_rmse : 2.6621291637420654
[INFO 39:24] valid_loss : 3.5969557762145996


 49%|████▉     | 49/100 [02:11<02:12,  2.60s/it]

[INFO 39:26] - learning_users_emb: 0.015625
[INFO 39:26] rmse : 0.44188162684440613
[INFO 39:26] valid_rmse : 2.6619374752044678
[INFO 39:26] valid_loss : 3.5964481830596924


 50%|█████     | 50/100 [02:14<02:10,  2.61s/it]

[INFO 39:28] - learning_users_emb: 0.015625
[INFO 39:29] rmse : 0.4418759346008301
[INFO 39:29] valid_rmse : 2.6619033813476562
[INFO 39:29] valid_loss : 3.5961105823516846


 51%|█████     | 51/100 [02:16<02:07,  2.61s/it]

[INFO 39:31] - learning_users_emb: 0.015625
[INFO 39:31] rmse : 0.44186410307884216
[INFO 39:31] valid_rmse : 2.661832094192505
[INFO 39:31] valid_loss : 3.5958616733551025


 52%|█████▏    | 52/100 [02:19<02:04,  2.60s/it]

[INFO 39:33] - learning_users_emb: 0.015625
[INFO 39:34] rmse : 0.44185876846313477
[INFO 39:34] valid_rmse : 2.661799907684326
[INFO 39:34] valid_loss : 3.595569372177124


 53%|█████▎    | 53/100 [02:22<02:02,  2.60s/it]

[INFO 39:36] - learning_users_emb: 0.015625
[INFO 39:37] rmse : 0.4418698847293854
[INFO 39:37] valid_rmse : 2.6618669033050537
[INFO 39:37] valid_loss : 3.595484972000122


 54%|█████▍    | 54/100 [02:24<01:58,  2.59s/it]

[INFO 39:38] - learning_users_emb: 0.015625
[INFO 39:39] rmse : 0.4418908953666687
[INFO 39:39] valid_rmse : 2.6619935035705566
[INFO 39:39] valid_loss : 3.595421075820923


 55%|█████▌    | 55/100 [02:27<01:56,  2.59s/it]

[INFO 39:41] - learning_users_emb: 0.015625
[INFO 39:42] rmse : 0.44189217686653137
[INFO 39:42] valid_rmse : 2.662001132965088
[INFO 39:42] valid_loss : 3.595299482345581


 56%|█████▌    | 56/100 [02:29<01:53,  2.58s/it]

[INFO 39:44] - learning_users_emb: 0.015625
[INFO 39:44] rmse : 0.44189947843551636
[INFO 39:44] valid_rmse : 2.6620450019836426
[INFO 39:44] valid_loss : 3.5952422618865967


 57%|█████▋    | 57/100 [02:32<01:51,  2.59s/it]

[INFO 39:46] - learning_users_emb: 0.0078125
[INFO 39:47] rmse : 0.4419003427028656
[INFO 39:47] valid_rmse : 2.662050247192383
[INFO 39:47] valid_loss : 3.59509015083313


 58%|█████▊    | 58/100 [02:34<01:48,  2.58s/it]

[INFO 39:49] - learning_users_emb: 0.0078125
[INFO 39:50] rmse : 0.44190770387649536
[INFO 39:50] valid_rmse : 2.6620945930480957
[INFO 39:50] valid_loss : 3.595012903213501


 58%|█████▊    | 58/100 [02:37<01:54,  2.72s/it]

[INFO 39:50] -- END Training --
[INFO 39:50] train on cuda



100%|██████████| 6/6 [00:00<00:00,  8.92it/s]


-0.022887408128349388
({0: {'mi_acc': 0.7017393112182617}, 1: {'mi_acc': 0.7017393112182617}, 2: {'mi_acc': 0.7017393112182617}, 3: {'mi_acc': 0.7017393112182617}, 4: {'mi_acc': 0.7017393112182617}, 5: {'mi_acc': 0.7017393112182617}}, {0: {'meta_doa': np.float64(0.4731912698967316)}, 1: {'meta_doa': np.float64(0.4785325609134498)}, 2: {'meta_doa': np.float64(0.488984180994137)}, 3: {'meta_doa': np.float64(0.4829888078191557)}, 4: {'meta_doa': np.float64(0.4831064777484556)}, 5: {'meta_doa': np.float64(0.4797462898645038)}})


In [None]:
# Free the last selection strategy and release cached GPU memory before starting a new run.
# Guarded so the cell does not raise NameError on a fresh kernel (or when run twice),
# where `S` has not been created yet.
if 'S' in globals():
    del S
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Train and evaluate the Random selection strategy for every meta-trainer, fold and seed.
# Relies on `config`, `concept_map`, `metadata` and `nb_modalities` defined in earlier cells.
meta_trainers = ['GAP']
folds = range(1, 2)  # only fold 1 is evaluated in this run
seeds = range(1)

## Dataframe columns : (user_id, question_id, response, category_id)
# NOTE(review): the dtype keys below (student_id, item_id, correct, dimension_id) differ from
# the column names listed above; confirm which set matches the CSV header.
csv_dtypes = {'student_id': int, 'item_id': int, "correct": float, "dimension_id": int}


def load_split(split: str, i_fold: int) -> pd.DataFrame:
    """Read one preprocessed split ('train', 'valid' or 'test') of the current dataset for a fold."""
    return pd.read_csv(
        f'../datasets/2-preprocessed_data/{config["dataset_name"]}_{split}_{i_fold}.csv',
        encoding='utf-8', dtype=csv_dtypes)


for meta_trainer in meta_trainers:
    config['meta_trainer'] = meta_trainer
    logging.info(f'#### meta_trainer : {config["meta_trainer"]} ####')
    for i_fold in folds:
        config['i_fold'] = i_fold
        logging.info(f'#### i_fold : {i_fold} ####')

        train_df = load_split('train', i_fold)
        valid_df = load_split('valid', i_fold)
        test_df = load_split('test', i_fold)

        train_data = liriscat.dataset.CATDataset(train_df, concept_map, metadata, config, nb_modalities)
        valid_data = liriscat.dataset.EvalDataset(valid_df, concept_map, metadata, config, nb_modalities)
        test_data = liriscat.dataset.EvalDataset(test_df, concept_map, metadata, config, nb_modalities)

        for seed in seeds:
            config['seed'] = seed
            logging.info(f'#### seed : {seed} ####')

            train_data.reset_rng()
            valid_data.reset_rng()
            test_data.reset_rng()

            # Drop any previous strategy first so two models are never held on the GPU at once.
            S = None
            gc.collect()
            torch.cuda.empty_cache()

            # Reseed the global RNGs so training depends only on `seed`,
            # not on how much randomness earlier cells consumed.
            liriscat.utils.set_seed(seed)
            S = liriscat.selectionStrategy.Random(metadata, **config)
            S.init_models(train_data, valid_data)
            S.train(train_data, valid_data)

            # Evaluation always runs under seed 0 (presumably so test results are comparable
            # across training seeds).
            liriscat.utils.set_seed(0)
            S.reset_rng()
            d = S.evaluate_test(test_data, train_data, valid_data)
            logging.info(d)
            logging.info(liriscat.utils.pareto_index(d))

    torch.cuda.empty_cache()

In [None]:
# Single-run variant of the loop above. It reuses `train_data`, `valid_data` and `test_data`
# built by the previous cell, so that cell must be run first.
config['meta_trainer'] = 'GAP'  # 'Adam' is the alternative meta-trainer
for seed in range(1):
    config['seed'] = seed
    logging.info(f'#### seed : {seed} ####')

    train_data.reset_rng()
    valid_data.reset_rng()
    test_data.reset_rng()

    # Drop any previous strategy first so two models are never held on the GPU at once.
    S = None
    gc.collect()
    torch.cuda.empty_cache()

    # Reseed the global RNGs so training depends only on `seed`,
    # not on how much randomness earlier cells consumed.
    liriscat.utils.set_seed(seed)
    S = liriscat.selectionStrategy.Random(metadata, **config)
    S.init_models(train_data, valid_data)
    S.train(train_data, valid_data)

    # Evaluation always runs under seed 0 (presumably so test results are comparable
    # across training seeds).
    liriscat.utils.set_seed(0)
    S.reset_rng()
    d = S.evaluate_test(test_data, train_data, valid_data)
    logging.info(d)
    logging.info(liriscat.utils.pareto_index(d))
    torch.cuda.empty_cache()