# Liriscat paper experiments
### 1. Init
#### 1.1. Import libraries (necessary)

In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["PYTHONHASHSEED"] = "0"
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

import liriscat
liriscat.utils.set_seed(0)

import logging
import gc
import json
import torch
liriscat.utils.set_seed(0)
import pandas as pd
from importlib import reload

#### 1.2. Set up the loggers (recommended)

In [2]:
# Configure the "liriscat" logger used throughout this notebook; the
# `[INFO hh:mm] ...` lines in the cell outputs below are emitted through it.
liriscat.utils.setuplogger(
    log_name="liriscat",
    verbose=True,
    debug=False,
)

### 2. CDM prediction
#### 2.1. Training and testing, sequential version

In [3]:
# Free unreferenced Python objects first, then return cached-but-unused CUDA
# blocks to the driver, so the experiments below start from a clean GPU state.
for release_memory in (gc.collect, torch.cuda.empty_cache):
    release_memory()

import warnings  # used further down to filter numeric RuntimeWarnings

In [None]:
# Build the evaluation config from the keyword arguments below, then seed
# every RNG with the config's seed.
config = liriscat.utils.generate_eval_config(
    load_params=True,
    esc='error',
    valid_metric='mi_acc',
    pred_metrics=["mi_acc"],
    profile_metrics=['meta_doa'],
    save_params=False,
    n_query=6,
    num_epochs=100,
    batch_size=512,
)
liriscat.utils.set_seed(config["seed"])

config["dataset_name"] = "math2"
logging.info(config["dataset_name"])

# Experiment-specific hyper-parameters for this dataset.
# NOTE(review): the long-decimal values look like the result of a
# hyper-parameter search — TODO record the search that produced them.
config.update({
    'learning_rate': 0.0001,
    'inner_user_lr': 0.016848380924625605,
    'lambda': 9.972254466547545e-06,
    'meta_lr': 0.5,
    'patience': 20,
    'num_inner_users_epochs': 3,
    'd_in': 4,
    'num_responses': 12,
})

CUDA is available. Using GPU.
[INFO 13:20] math2


In [5]:
config['embs_path'] = '../embs/'
config['params_path'] = '../ckpt/'

# Log only after all overrides so the logged config is the one actually used.
logging.info(f'#### {config["dataset_name"]} ####')
logging.info(f'#### config : {config} ####')

gc.collect()
torch.cuda.empty_cache()

# Silence numeric warnings (presumably raised by the DOA / rm metric
# computations — confirm). NOTE(review): the second filter silences *every*
# RuntimeWarning and subsumes the first; consider narrowing it.
warnings.filterwarnings("ignore", message="invalid value encountered in divide")
warnings.filterwarnings("ignore", category=RuntimeWarning)

preprocessed_dir = '../datasets/2-preprocessed_data'
dataset_prefix = f'{preprocessed_dir}/{config["dataset_name"]}'

## Concept map format : {question_id : [category_id1, category_id2, ...]}
# JSON object keys are always strings, so ids are cast back to int.
with open(f'{dataset_prefix}_concept_map.json', 'r') as fp:
    concept_map = json.load(fp)
concept_map = {int(k): [int(x) for x in v] for k, v in concept_map.items()}

## Metadata map format : {"num_user_id": ..., "num_item_id": ..., "num_dimension_id": ...}
with open(f'{dataset_prefix}_metadata.json', 'r') as fp:
    metadata = json.load(fp)

## Tensor containing the nb of modalities per question
# weights_only=True restricts unpickling to tensors and plain containers.
nb_modalities = torch.load(f'{dataset_prefix}_nb_modalities.pkl', weights_only=True)


[INFO 13:20] #### math2 ####
[INFO 13:20] #### config : {'seed': 0, 'dataset_name': 'math2', 'load_params': True, 'save_params': False, 'embs_path': '../embs/', 'params_path': '../ckpt/', 'early_stopping': True, 'esc': 'error', 'verbose_early_stopping': False, 'disable_tqdm': False, 'valid_metric': 'mi_acc', 'learning_rate': 0.0001, 'batch_size': 512, 'valid_batch_size': 10000, 'num_epochs': 100, 'eval_freq': 1, 'patience': 20, 'device': device(type='cuda'), 'lambda': 9.972254466547545e-06, 'tensorboard': False, 'flush_freq': True, 'pred_metrics': ['mi_acc'], 'profile_metrics': ['meta_doa'], 'num_responses': 12, 'low_mem': False, 'n_query': 6, 'CDM': 'impact', 'i_fold': 0, 'num_inner_users_epochs': 3, 'num_inner_epochs': 10, 'inner_lr': 0.0001, 'inner_user_lr': 0.016848380924625605, 'meta_trainer': 'Adam', 'num_workers': 0, 'pin_memory': False, 'debug': False, 'd_in': 4} ####


In [None]:
def read_split(dataset_name, split, i_fold):
    """Read one preprocessed split CSV.

    Args:
        dataset_name: dataset prefix, e.g. ``config["dataset_name"]``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        i_fold: cross-validation fold index.

    Returns:
        The split as a pandas DataFrame.
    """
    # NOTE(review): an earlier comment described the columns as
    # (user_id, question_id, response, category_id), but the dtype keys are
    # student_id / item_id / correct / dimension_id — confirm the CSV header.
    return pd.read_csv(
        f'../datasets/2-preprocessed_data/{dataset_name}_{split}_{i_fold}.csv',
        encoding='utf-8',
        dtype={'student_id': int, 'item_id': int, "correct": float, "dimension_id": int},
    )


meta_trainers = ['Approx_GAP']
folds = range(1, 2)  # only fold 1 is run here
seeds = range(1)

for meta_trainer in meta_trainers:
    config['meta_trainer'] = meta_trainer
    logging.info(f'#### meta_trainer : {config["meta_trainer"]} ####')
    for i_fold in folds:
        config['i_fold'] = i_fold
        logging.info(f'#### i_fold : {i_fold} ####')

        train_df = read_split(config["dataset_name"], 'train', i_fold)
        valid_df = read_split(config["dataset_name"], 'valid', i_fold)
        test_df = read_split(config["dataset_name"], 'test', i_fold)

        train_data = liriscat.dataset.CATDataset(train_df, concept_map, metadata, config, nb_modalities)
        valid_data = liriscat.dataset.EvalDataset(valid_df, concept_map, metadata, config, nb_modalities)
        test_data = liriscat.dataset.EvalDataset(test_df, concept_map, metadata, config, nb_modalities)

        for seed in seeds:
            config['seed'] = seed
            logging.info(f'#### seed : {seed} ####')

            # Presumably restores each dataset's sampling RNG to its initial
            # state so every seed starts from the same data stream — confirm.
            train_data.reset_rng()
            valid_data.reset_rng()
            test_data.reset_rng()

            S = liriscat.selectionStrategy.Random(metadata, **config)
            S.init_models(train_data, valid_data)
            S.train(train_data, valid_data)

            # NOTE(review): evaluation is always re-seeded with 0, independent of
            # the training `seed` — confirm this is intended for seeds > 0.
            liriscat.utils.set_seed(0)
            S.reset_rng()
            d = S.evaluate_test(test_data, train_data, valid_data)
            logging.info(d)
            logging.info(liriscat.utils.pareto_index(d))

    torch.cuda.empty_cache()

[INFO 13:20] #### meta_trainer : Approx_GAP ####
[INFO 13:20] #### i_fold : 1 ####
[INFO 13:29] #### seed : 0 ####
[INFO 13:29] Random_cont_model
[INFO 13:29] compiling CDM model
[INFO 13:31] compiling selection model
[INFO 13:31] ------- START Training
[INFO 13:31] train on cuda


  0%|          | 0/100 [00:00<?, ?it/s]

[INFO 13:32] ----- User batch : 0
[INFO 13:37] ----- User batch : 1
[INFO 13:40] ----- User batch : 2
[INFO 13:43] ----- User batch : 3
[INFO 13:46] ----- User batch : 4
[INFO 13:48] - meta_params gn: [1.0539125204086304]
[INFO 13:48] - meta_params norm: [34.032470703125]
[INFO 13:48] - meta_params: 0.5
[INFO 13:54] rmse : 0.6702445149421692
[INFO 13:54] valid_rmse : 0.857090175151825
[INFO 13:54] valid_loss : 1.2114529609680176


  1%|          | 1/100 [00:22<36:33, 22.16s/it]

[INFO 13:55] ----- User batch : 0
[INFO 13:58] ----- User batch : 1
[INFO 14:02] ----- User batch : 2
[INFO 14:05] ----- User batch : 3
[INFO 14:09] ----- User batch : 4
[INFO 14:10] - meta_params gn: [0.999675452709198]
[INFO 14:10] - meta_params norm: [35.73713684082031]
[INFO 14:10] - meta_params: 0.5
[INFO 14:15] rmse : 0.6698727011680603
[INFO 14:15] valid_rmse : 0.8566147089004517
[INFO 14:15] valid_loss : 1.2094014883041382


  2%|▏         | 2/100 [00:43<35:28, 21.72s/it]

[INFO 14:16] ----- User batch : 0
[INFO 14:19] ----- User batch : 1
[INFO 14:23] ----- User batch : 2
[INFO 14:26] ----- User batch : 3
[INFO 14:30] ----- User batch : 4
[INFO 14:32] - meta_params gn: [1.0290346145629883]
[INFO 14:32] - meta_params norm: [37.426513671875]
[INFO 14:32] - meta_params: 0.5
[INFO 14:36] rmse : 0.6676375269889832
[INFO 14:36] valid_rmse : 0.8537564277648926
[INFO 14:36] valid_loss : 1.2065422534942627


  3%|▎         | 3/100 [01:04<34:19, 21.23s/it]

[INFO 14:37] ----- User batch : 0
[INFO 14:40] ----- User batch : 1
[INFO 14:44] ----- User batch : 2
[INFO 14:48] ----- User batch : 3
[INFO 14:51] ----- User batch : 4
[INFO 14:53] - meta_params gn: [0.9314559102058411]
[INFO 14:53] - meta_params norm: [39.079708099365234]
[INFO 14:53] - meta_params: 0.5
[INFO 14:58] rmse : 0.6680851578712463
[INFO 14:58] valid_rmse : 0.8543288111686707
[INFO 14:58] valid_loss : 1.205987811088562


  4%|▍         | 4/100 [01:25<34:12, 21.38s/it]

[INFO 14:58] ----- User batch : 0
[INFO 15:02] ----- User batch : 1
[INFO 15:05] ----- User batch : 2
[INFO 15:09] ----- User batch : 3
[INFO 15:12] ----- User batch : 4
[INFO 15:14] - meta_params gn: [0.9522802829742432]
[INFO 15:14] - meta_params norm: [40.78749465942383]
[INFO 15:14] - meta_params: 0.5
[INFO 15:19] rmse : 0.6685324907302856
[INFO 15:19] valid_rmse : 0.8549008965492249
[INFO 15:19] valid_loss : 1.2023955583572388


  5%|▌         | 5/100 [01:46<33:35, 21.21s/it]

[INFO 15:19] ----- User batch : 0
[INFO 15:23] ----- User batch : 1
[INFO 15:26] ----- User batch : 2
[INFO 15:30] ----- User batch : 3
[INFO 15:33] ----- User batch : 4
[INFO 15:35] - meta_params gn: [0.9346821904182434]
[INFO 15:35] - meta_params norm: [42.49063491821289]
[INFO 15:35] - meta_params: 0.5
[INFO 15:40] rmse : 0.6689050793647766
[INFO 15:40] valid_rmse : 0.8553773164749146
[INFO 15:40] valid_loss : 1.20346999168396


  6%|▌         | 6/100 [02:08<33:17, 21.25s/it]

[INFO 15:40] ----- User batch : 0
[INFO 15:44] ----- User batch : 1
[INFO 15:48] ----- User batch : 2
[INFO 15:51] ----- User batch : 3
[INFO 15:55] ----- User batch : 4
[INFO 15:57] - meta_params gn: [0.9060955047607422]
[INFO 15:57] - meta_params norm: [44.183658599853516]
[INFO 15:57] - meta_params: 0.5
[INFO 16:01] rmse : 0.6655446290969849
[INFO 16:01] valid_rmse : 0.851080060005188
[INFO 16:01] valid_loss : 1.2026267051696777


  7%|▋         | 7/100 [02:29<32:59, 21.29s/it]

[INFO 16:02] ----- User batch : 0
[INFO 16:06] ----- User batch : 1
[INFO 16:09] ----- User batch : 2
[INFO 16:13] ----- User batch : 3
[INFO 16:16] ----- User batch : 4
[INFO 16:18] - meta_params gn: [0.9169228672981262]
[INFO 16:18] - meta_params norm: [45.8874626159668]
[INFO 16:18] - meta_params: 0.5
[INFO 16:23] rmse : 0.6675629019737244
[INFO 16:23] valid_rmse : 0.8536610007286072
[INFO 16:23] valid_loss : 1.2012722492218018


  8%|▊         | 8/100 [02:50<32:42, 21.33s/it]

[INFO 16:23] ----- User batch : 0
[INFO 16:27] ----- User batch : 1
[INFO 16:31] ----- User batch : 2
[INFO 16:34] ----- User batch : 3
[INFO 16:38] ----- User batch : 4
[INFO 16:40] - meta_params gn: [0.9223100543022156]
[INFO 16:40] - meta_params norm: [47.62602233886719]
[INFO 16:40] - meta_params: 0.5
[INFO 16:45] rmse : 0.6656194925308228
[INFO 16:45] valid_rmse : 0.8511757850646973
[INFO 16:45] valid_loss : 1.2014418840408325


  9%|▉         | 9/100 [03:12<32:34, 21.48s/it]

[INFO 16:45] ----- User batch : 0
[INFO 16:49] ----- User batch : 1
[INFO 16:52] ----- User batch : 2
[INFO 16:56] ----- User batch : 3
[INFO 16:59] ----- User batch : 4
[INFO 17:01] - meta_params gn: [0.8750215172767639]
[INFO 17:01] - meta_params norm: [49.36465072631836]
[INFO 17:01] - meta_params: 0.5
[INFO 17:06] rmse : 0.6664423942565918
[INFO 17:06] valid_rmse : 0.8522281050682068
[INFO 17:06] valid_loss : 1.1956381797790527


 10%|█         | 10/100 [03:33<31:59, 21.32s/it]

[INFO 17:06] ----- User batch : 0
[INFO 17:10] ----- User batch : 1
[INFO 17:13] ----- User batch : 2
[INFO 17:17] ----- User batch : 3
[INFO 17:20] ----- User batch : 4
[INFO 17:22] - meta_params gn: [0.8252658843994141]
[INFO 17:23] - meta_params norm: [51.08425521850586]
[INFO 17:23] - meta_params: 0.5
[INFO 17:27] rmse : 0.6671896576881409
[INFO 17:27] valid_rmse : 0.8531836867332458
[INFO 17:27] valid_loss : 1.2007783651351929


 11%|█         | 11/100 [03:55<31:49, 21.45s/it]

[INFO 17:28] ----- User batch : 0
[INFO 17:31] ----- User batch : 1
[INFO 17:35] ----- User batch : 2
[INFO 17:39] ----- User batch : 3
[INFO 17:42] ----- User batch : 4
[INFO 17:44] - meta_params gn: [0.8385043144226074]
[INFO 17:44] - meta_params norm: [52.809356689453125]
[INFO 17:44] - meta_params: 0.5
[INFO 17:49] rmse : 0.6652451157569885
[INFO 17:49] valid_rmse : 0.8506970405578613
[INFO 17:49] valid_loss : 1.1945230960845947


 12%|█▏        | 12/100 [04:16<31:23, 21.41s/it]

[INFO 17:49] ----- User batch : 0
[INFO 17:52] ----- User batch : 1
[INFO 17:56] ----- User batch : 2
[INFO 17:59] ----- User batch : 3
[INFO 18:03] ----- User batch : 4
[INFO 18:05] - meta_params gn: [0.8141980171203613]
[INFO 18:05] - meta_params norm: [54.541961669921875]
[INFO 18:05] - meta_params: 0.5
[INFO 18:10] rmse : 0.6635953187942505
[INFO 18:10] valid_rmse : 0.8485873341560364
[INFO 18:10] valid_loss : 1.1951755285263062


 13%|█▎        | 13/100 [04:37<30:52, 21.30s/it]

[INFO 18:10] ----- User batch : 0
[INFO 18:14] ----- User batch : 1
[INFO 18:17] ----- User batch : 2
[INFO 18:21] ----- User batch : 3
[INFO 18:24] ----- User batch : 4
[INFO 18:26] - meta_params gn: [0.8061243295669556]
[INFO 18:26] - meta_params norm: [56.26485061645508]
[INFO 18:26] - meta_params: 0.5
[INFO 18:31] rmse : 0.6656194925308228
[INFO 18:31] valid_rmse : 0.8511757850646973
[INFO 18:31] valid_loss : 1.1969937086105347


 14%|█▍        | 14/100 [04:58<30:23, 21.20s/it]

[INFO 18:31] ----- User batch : 0
[INFO 18:35] ----- User batch : 1
[INFO 18:38] ----- User batch : 2
[INFO 18:42] ----- User batch : 3
[INFO 18:45] ----- User batch : 4
[INFO 18:47] - meta_params gn: [0.8100632429122925]
[INFO 18:47] - meta_params norm: [57.99504470825195]
[INFO 18:47] - meta_params: 0.5
[INFO 18:52] rmse : 0.6607359647750854
[INFO 18:52] valid_rmse : 0.84493088722229
[INFO 18:52] valid_loss : 1.1936310529708862


 15%|█▌        | 15/100 [05:19<30:03, 21.22s/it]

[INFO 18:52] ----- User batch : 0
[INFO 18:56] ----- User batch : 1
[INFO 18:59] ----- User batch : 2
[INFO 19:03] ----- User batch : 3
[INFO 19:06] ----- User batch : 4
[INFO 19:08] - meta_params gn: [0.8293323516845703]
[INFO 19:08] - meta_params norm: [59.73311996459961]
[INFO 19:08] - meta_params: 0.5
[INFO 19:13] rmse : 0.661715567111969
[INFO 19:13] valid_rmse : 0.8461835980415344
[INFO 19:13] valid_loss : 1.193700909614563


 16%|█▌        | 16/100 [05:40<29:33, 21.11s/it]

[INFO 19:13] ----- User batch : 0
[INFO 19:17] ----- User batch : 1
[INFO 19:21] ----- User batch : 2
[INFO 19:24] ----- User batch : 3
[INFO 19:28] ----- User batch : 4
[INFO 19:30] - meta_params gn: [0.8263159394264221]
[INFO 19:30] - meta_params norm: [61.461395263671875]
[INFO 19:30] - meta_params: 0.5
[INFO 19:34] rmse : 0.6586967706680298
[INFO 19:34] valid_rmse : 0.8423232436180115
[INFO 19:34] valid_loss : 1.1888935565948486


 17%|█▋        | 17/100 [06:02<29:23, 21.25s/it]

[INFO 19:35] ----- User batch : 0
[INFO 19:38] ----- User batch : 1
[INFO 19:42] ----- User batch : 2
[INFO 19:46] ----- User batch : 3
[INFO 19:49] ----- User batch : 4
[INFO 19:51] - meta_params gn: [0.833037793636322]
[INFO 19:51] - meta_params norm: [63.191192626953125]
[INFO 19:51] - meta_params: 0.5
[INFO 19:55] rmse : 0.6599059104919434
[INFO 19:55] valid_rmse : 0.8438694477081299
[INFO 19:55] valid_loss : 1.1932599544525146


 18%|█▊        | 18/100 [06:23<28:57, 21.19s/it]

[INFO 19:56] ----- User batch : 0
[INFO 19:59] ----- User batch : 1
[INFO 20:03] ----- User batch : 2
[INFO 20:07] ----- User batch : 3
[INFO 20:10] ----- User batch : 4
[INFO 20:12] - meta_params gn: [0.6743336319923401]
[INFO 20:12] - meta_params norm: [64.92317962646484]
[INFO 20:12] - meta_params: 0.5
[INFO 20:17] rmse : 0.6565753221511841
[INFO 20:17] valid_rmse : 0.8396103978157043
[INFO 20:17] valid_loss : 1.1881515979766846


 19%|█▉        | 19/100 [06:44<28:43, 21.27s/it]

[INFO 20:17] ----- User batch : 0
[INFO 20:21] ----- User batch : 1
[INFO 20:24] ----- User batch : 2
[INFO 20:28] ----- User batch : 3
[INFO 20:31] ----- User batch : 4
[INFO 20:33] - meta_params gn: [0.7052016854286194]
[INFO 20:33] - meta_params norm: [66.65099334716797]
[INFO 20:33] - meta_params: 0.5
[INFO 20:37] rmse : 0.6598303914070129
[INFO 20:37] valid_rmse : 0.8437728881835938
[INFO 20:37] valid_loss : 1.1841599941253662


 20%|██        | 20/100 [07:05<28:05, 21.07s/it]

[INFO 20:38] ----- User batch : 0
[INFO 20:41] ----- User batch : 1
[INFO 20:45] ----- User batch : 2
[INFO 20:49] ----- User batch : 3
[INFO 20:52] ----- User batch : 4
[INFO 20:54] - meta_params gn: [0.729849100112915]
[INFO 20:54] - meta_params norm: [68.38319396972656]
[INFO 20:54] - meta_params: 0.5
[INFO 20:59] rmse : 0.6553599834442139
[INFO 20:59] valid_rmse : 0.838056206703186
[INFO 20:59] valid_loss : 1.185343861579895


 21%|██        | 21/100 [07:27<27:55, 21.21s/it]

[INFO 20:59] ----- User batch : 0
[INFO 21:03] ----- User batch : 1
[INFO 21:07] ----- User batch : 2
[INFO 21:10] ----- User batch : 3
[INFO 21:13] ----- User batch : 4
[INFO 21:15] - meta_params gn: [0.7072213292121887]
[INFO 21:15] - meta_params norm: [70.11939239501953]
[INFO 21:15] - meta_params: 0.5
[INFO 21:20] rmse : 0.6579399108886719
[INFO 21:20] valid_rmse : 0.8413553833961487
[INFO 21:20] valid_loss : 1.18706476688385


 22%|██▏       | 22/100 [07:48<27:31, 21.17s/it]

[INFO 21:21] ----- User batch : 0
[INFO 21:24] ----- User batch : 1
[INFO 21:28] ----- User batch : 2
[INFO 21:31] ----- User batch : 3
[INFO 21:35] ----- User batch : 4
[INFO 21:37] - meta_params gn: [0.7726195454597473]
[INFO 21:37] - meta_params norm: [71.86056518554688]
[INFO 21:37] - meta_params: 0.5
[INFO 21:41] rmse : 0.6542947292327881
[INFO 21:41] valid_rmse : 0.8366940021514893
[INFO 21:41] valid_loss : 1.1808844804763794


 23%|██▎       | 23/100 [08:09<27:16, 21.25s/it]

[INFO 21:42] ----- User batch : 0
[INFO 21:45] ----- User batch : 1
[INFO 21:49] ----- User batch : 2
[INFO 21:53] ----- User batch : 3
[INFO 21:56] ----- User batch : 4
[INFO 21:58] - meta_params gn: [0.6661918759346008]
[INFO 21:58] - meta_params norm: [73.55872344970703]
[INFO 21:58] - meta_params: 0.5
[INFO 22:02] rmse : 0.6576368808746338
[INFO 22:02] valid_rmse : 0.8409678339958191
[INFO 22:03] valid_loss : 1.1883528232574463


 24%|██▍       | 24/100 [08:30<26:50, 21.19s/it]

[INFO 22:03] ----- User batch : 0
[INFO 22:06] ----- User batch : 1
[INFO 22:10] ----- User batch : 2
[INFO 22:14] ----- User batch : 3
[INFO 22:17] ----- User batch : 4
[INFO 22:19] - meta_params gn: [0.6565229892730713]
[INFO 22:19] - meta_params norm: [75.25756072998047]
[INFO 22:19] - meta_params: 0.5
[INFO 22:24] rmse : 0.6591504216194153
[INFO 22:24] valid_rmse : 0.8429033160209656
[INFO 22:24] valid_loss : 1.1845974922180176


 25%|██▌       | 25/100 [08:51<26:31, 21.22s/it]

[INFO 22:24] ----- User batch : 0
[INFO 22:28] ----- User batch : 1
[INFO 22:31] ----- User batch : 2
[INFO 22:35] ----- User batch : 3
[INFO 22:38] ----- User batch : 4
[INFO 22:40] - meta_params gn: [0.6940228343009949]
[INFO 22:40] - meta_params norm: [76.96266174316406]
[INFO 22:40] - meta_params: 0.5
[INFO 22:45] rmse : 0.6586967706680298
[INFO 22:45] valid_rmse : 0.8423232436180115
[INFO 22:45] valid_loss : 1.1771727800369263


 26%|██▌       | 26/100 [09:13<26:11, 21.24s/it]

[INFO 22:45] ----- User batch : 0
[INFO 22:49] ----- User batch : 1
[INFO 22:53] ----- User batch : 2
[INFO 22:57] ----- User batch : 3
[INFO 23:00] ----- User batch : 4
[INFO 23:02] - meta_params gn: [0.6352449655532837]
[INFO 23:02] - meta_params norm: [78.6818618774414]
[INFO 23:02] - meta_params: 0.5
[INFO 23:07] rmse : 0.6568029522895813
[INFO 23:07] valid_rmse : 0.8399014472961426
[INFO 23:07] valid_loss : 1.1827210187911987


 27%|██▋       | 27/100 [09:34<25:54, 21.30s/it]

[INFO 23:07] ----- User batch : 0
[INFO 23:10] ----- User batch : 1
[INFO 23:14] ----- User batch : 2
[INFO 23:17] ----- User batch : 3
[INFO 23:20] ----- User batch : 4
[INFO 23:22] - meta_params gn: [0.6145342588424683]
[INFO 23:22] - meta_params norm: [80.39930725097656]
[INFO 23:22] - meta_params: 0.5
[INFO 23:26] rmse : 0.6536852717399597
[INFO 23:26] valid_rmse : 0.835914671421051
[INFO 23:26] valid_loss : 1.1744189262390137


 28%|██▊       | 28/100 [09:54<24:56, 20.79s/it]

[INFO 23:27] ----- User batch : 0
[INFO 23:30] ----- User batch : 1
[INFO 23:33] ----- User batch : 2
[INFO 23:36] ----- User batch : 3
[INFO 23:40] ----- User batch : 4
[INFO 23:41] - meta_params gn: [0.6304660439491272]
[INFO 23:41] - meta_params norm: [82.12349700927734]
[INFO 23:41] - meta_params: 0.5
[INFO 23:46] rmse : 0.6505526304244995
[INFO 23:46] valid_rmse : 0.8319087028503418
[INFO 23:46] valid_loss : 1.1764699220657349


 29%|██▉       | 29/100 [10:13<24:10, 20.43s/it]

[INFO 23:46] ----- User batch : 0
[INFO 23:49] ----- User batch : 1
[INFO 23:53] ----- User batch : 2
[INFO 23:56] ----- User batch : 3
[INFO 23:59] ----- User batch : 4
[INFO 24:01] - meta_params gn: [0.6921994686126709]
[INFO 24:01] - meta_params norm: [83.84481811523438]
[INFO 24:01] - meta_params: 0.5
[INFO 24:06] rmse : 0.6490957736968994
[INFO 24:06] valid_rmse : 0.830045759677887
[INFO 24:06] valid_loss : 1.17615807056427


 30%|███       | 30/100 [10:33<23:37, 20.26s/it]

[INFO 24:06] ----- User batch : 0
[INFO 24:09] ----- User batch : 1
[INFO 24:13] ----- User batch : 2
[INFO 24:16] ----- User batch : 3
[INFO 24:19] ----- User batch : 4
[INFO 24:21] - meta_params gn: [0.621532142162323]
[INFO 24:21] - meta_params norm: [85.56565856933594]
[INFO 24:21] - meta_params: 0.5
[INFO 24:26] rmse : 0.6528462767601013
[INFO 24:26] valid_rmse : 0.834841787815094
[INFO 24:26] valid_loss : 1.1792118549346924


 31%|███       | 31/100 [10:54<23:20, 20.30s/it]

[INFO 24:26] ----- User batch : 0
[INFO 24:30] ----- User batch : 1
[INFO 24:33] ----- User batch : 2
[INFO 24:37] ----- User batch : 3
[INFO 24:40] ----- User batch : 4
[INFO 24:42] - meta_params gn: [0.5886638760566711]
[INFO 24:42] - meta_params norm: [87.29208374023438]
[INFO 24:42] - meta_params: 0.25
[INFO 24:47] rmse : 0.6510120034217834
[INFO 24:47] valid_rmse : 0.832496166229248
[INFO 24:47] valid_loss : 1.1797019243240356


 32%|███▏      | 32/100 [11:15<23:13, 20.50s/it]

[INFO 24:47] ----- User batch : 0
[INFO 24:51] ----- User batch : 1
[INFO 24:54] ----- User batch : 2
[INFO 24:58] ----- User batch : 3
[INFO 25:01] ----- User batch : 4
[INFO 25:03] - meta_params gn: [0.5957353711128235]
[INFO 25:03] - meta_params norm: [88.15045166015625]
[INFO 25:03] - meta_params: 0.25
[INFO 25:07] rmse : 0.6511650681495667
[INFO 25:07] valid_rmse : 0.8326919078826904
[INFO 25:07] valid_loss : 1.1743756532669067


 33%|███▎      | 33/100 [11:35<22:51, 20.47s/it]

[INFO 25:08] ----- User batch : 0
[INFO 25:11] ----- User batch : 1
[INFO 25:15] ----- User batch : 2
[INFO 25:18] ----- User batch : 3
[INFO 25:21] ----- User batch : 4
[INFO 25:23] - meta_params gn: [0.6643566489219666]
[INFO 25:23] - meta_params norm: [89.00918579101562]
[INFO 25:23] - meta_params: 0.25
[INFO 25:28] rmse : 0.6552839875221252
[INFO 25:28] valid_rmse : 0.8379590511322021
[INFO 25:28] valid_loss : 1.1783827543258667


 34%|███▍      | 34/100 [11:55<22:29, 20.44s/it]

[INFO 25:28] ----- User batch : 0
[INFO 25:31] ----- User batch : 1
[INFO 25:35] ----- User batch : 2
[INFO 25:38] ----- User batch : 3
[INFO 25:41] ----- User batch : 4
[INFO 25:43] - meta_params gn: [0.6044062972068787]
[INFO 25:43] - meta_params norm: [89.8681869506836]
[INFO 25:43] - meta_params: 0.125
[INFO 25:47] rmse : 0.6552079319953918
[INFO 25:47] valid_rmse : 0.8378617763519287
[INFO 25:47] valid_loss : 1.1768302917480469


 35%|███▌      | 35/100 [12:15<21:55, 20.24s/it]

[INFO 25:48] ----- User batch : 0
[INFO 25:51] ----- User batch : 1
[INFO 25:55] ----- User batch : 2
[INFO 25:58] ----- User batch : 3
[INFO 26:01] ----- User batch : 4
[INFO 26:03] - meta_params gn: [0.5988017916679382]
[INFO 26:03] - meta_params norm: [90.29784393310547]
[INFO 26:03] - meta_params: 0.125
[INFO 26:07] rmse : 0.6560438871383667
[INFO 26:07] valid_rmse : 0.8389307856559753
[INFO 26:07] valid_loss : 1.1756902933120728


 36%|███▌      | 36/100 [12:35<21:27, 20.12s/it]

[INFO 26:08] ----- User batch : 0
[INFO 26:11] ----- User batch : 1
[INFO 26:14] ----- User batch : 2
[INFO 26:18] ----- User batch : 3
[INFO 26:21] ----- User batch : 4
[INFO 26:23] - meta_params gn: [0.5940858125686646]
[INFO 26:23] - meta_params norm: [90.7286376953125]
[INFO 26:23] - meta_params: 0.125
[INFO 26:28] rmse : 0.654066264629364
[INFO 26:28] valid_rmse : 0.8364018797874451
[INFO 26:28] valid_loss : 1.1762202978134155


 37%|███▋      | 37/100 [12:55<21:12, 20.19s/it]

[INFO 26:28] ----- User batch : 0
[INFO 26:32] ----- User batch : 1
[INFO 26:35] ----- User batch : 2
[INFO 26:39] ----- User batch : 3
[INFO 26:42] ----- User batch : 4
[INFO 26:44] - meta_params gn: [0.607503354549408]
[INFO 26:44] - meta_params norm: [91.15703582763672]
[INFO 26:44] - meta_params: 0.0625
[INFO 26:48] rmse : 0.6549797654151917
[INFO 26:48] valid_rmse : 0.8375700116157532
[INFO 26:48] valid_loss : 1.1743037700653076


 38%|███▊      | 38/100 [13:16<20:52, 20.20s/it]

[INFO 26:48] ----- User batch : 0
[INFO 26:52] ----- User batch : 1
[INFO 26:56] ----- User batch : 2
[INFO 26:59] ----- User batch : 3
[INFO 27:02] ----- User batch : 4
[INFO 27:04] - meta_params gn: [0.6315391063690186]
[INFO 27:04] - meta_params norm: [91.37257385253906]
[INFO 27:04] - meta_params: 0.0625
[INFO 27:09] rmse : 0.6571821570396423
[INFO 27:09] valid_rmse : 0.8403863906860352
[INFO 27:09] valid_loss : 1.1757551431655884


 39%|███▉      | 39/100 [13:36<20:45, 20.41s/it]

[INFO 27:09] ----- User batch : 0
[INFO 27:13] ----- User batch : 1
[INFO 27:16] ----- User batch : 2
[INFO 27:19] ----- User batch : 3
[INFO 27:22] ----- User batch : 4
[INFO 27:24] - meta_params gn: [0.5955454707145691]
[INFO 27:24] - meta_params norm: [91.58595275878906]
[INFO 27:24] - meta_params: 0.0625
[INFO 27:29] rmse : 0.6575611233711243
[INFO 27:29] valid_rmse : 0.8408709764480591
[INFO 27:29] valid_loss : 1.174181580543518


 40%|████      | 40/100 [13:56<20:18, 20.31s/it]

[INFO 27:29] ----- User batch : 0
[INFO 27:33] ----- User batch : 1
[INFO 27:36] ----- User batch : 2
[INFO 27:40] ----- User batch : 3
[INFO 27:43] ----- User batch : 4
[INFO 27:45] - meta_params gn: [0.6062400937080383]
[INFO 27:45] - meta_params norm: [91.80094909667969]
[INFO 27:45] - meta_params: 0.0625
[INFO 27:49] rmse : 0.6538376808166504
[INFO 27:49] valid_rmse : 0.8361095786094666
[INFO 27:49] valid_loss : 1.174543857574463


 41%|████      | 41/100 [14:17<19:56, 20.29s/it]

[INFO 27:50] ----- User batch : 0
[INFO 27:53] ----- User batch : 1
[INFO 27:56] ----- User batch : 2
[INFO 28:00] ----- User batch : 3
[INFO 28:03] ----- User batch : 4
[INFO 28:05] - meta_params gn: [0.5785790681838989]
[INFO 28:05] - meta_params norm: [92.01519012451172]
[INFO 28:05] - meta_params: 0.0625
[INFO 28:09] rmse : 0.654066264629364
[INFO 28:09] valid_rmse : 0.8364018797874451
[INFO 28:09] valid_loss : 1.1767113208770752


 42%|████▏     | 42/100 [14:37<19:32, 20.21s/it]

[INFO 28:10] ----- User batch : 0
[INFO 28:13] ----- User batch : 1
[INFO 28:16] ----- User batch : 2
[INFO 28:20] ----- User batch : 3
[INFO 28:23] ----- User batch : 4
[INFO 28:25] - meta_params gn: [0.574317991733551]
[INFO 28:25] - meta_params norm: [92.2272720336914]
[INFO 28:25] - meta_params: 0.0625
[INFO 28:29] rmse : 0.6544470191001892
[INFO 28:29] valid_rmse : 0.8368887305259705
[INFO 28:29] valid_loss : 1.176623821258545


 43%|████▎     | 43/100 [14:57<19:09, 20.17s/it]

[INFO 28:30] ----- User batch : 0
[INFO 28:33] ----- User batch : 1
[INFO 28:37] ----- User batch : 2
[INFO 28:40] ----- User batch : 3
[INFO 28:43] ----- User batch : 4
[INFO 28:45] - meta_params gn: [0.6087666153907776]
[INFO 28:45] - meta_params norm: [92.43732452392578]
[INFO 28:45] - meta_params: 0.03125
[INFO 28:50] rmse : 0.6555120348930359
[INFO 28:50] valid_rmse : 0.8382506966590881
[INFO 28:50] valid_loss : 1.1724828481674194


 44%|████▍     | 44/100 [15:17<18:57, 20.32s/it]

[INFO 28:50] ----- User batch : 0
[INFO 28:54] ----- User batch : 1
[INFO 28:57] ----- User batch : 2
[INFO 29:00] ----- User batch : 3
[INFO 29:03] ----- User batch : 4
[INFO 29:05] - meta_params gn: [0.5661097764968872]
[INFO 29:05] - meta_params norm: [92.54194641113281]
[INFO 29:05] - meta_params: 0.03125
[INFO 29:09] rmse : 0.6567270755767822
[INFO 29:09] valid_rmse : 0.8398044109344482
[INFO 29:09] valid_loss : 1.1752777099609375


 45%|████▌     | 45/100 [15:37<18:24, 20.08s/it]

[INFO 29:10] ----- User batch : 0
[INFO 29:13] ----- User batch : 1
[INFO 29:17] ----- User batch : 2
[INFO 29:20] ----- User batch : 3
[INFO 29:23] ----- User batch : 4
[INFO 29:25] - meta_params gn: [0.593576967716217]
[INFO 29:25] - meta_params norm: [92.64635467529297]
[INFO 29:25] - meta_params: 0.03125
[INFO 29:29] rmse : 0.6514710783958435
[INFO 29:29] valid_rmse : 0.8330832123756409
[INFO 29:29] valid_loss : 1.1748018264770508


 46%|████▌     | 46/100 [15:57<18:04, 20.08s/it]

[INFO 29:30] ----- User batch : 0
[INFO 29:33] ----- User batch : 1
[INFO 29:36] ----- User batch : 2
[INFO 29:40] ----- User batch : 3
[INFO 29:43] ----- User batch : 4
[INFO 29:45] - meta_params gn: [0.5933699607849121]
[INFO 29:45] - meta_params norm: [92.7505874633789]
[INFO 29:45] - meta_params: 0.03125
[INFO 29:49] rmse : 0.653380274772644
[INFO 29:49] valid_rmse : 0.8355246186256409
[INFO 29:49] valid_loss : 1.1731940507888794


 47%|████▋     | 47/100 [16:17<17:35, 19.92s/it]

[INFO 29:49] ----- User batch : 0
[INFO 29:53] ----- User batch : 1
[INFO 29:56] ----- User batch : 2
[INFO 29:59] ----- User batch : 3
[INFO 30:03] ----- User batch : 4
[INFO 30:04] - meta_params gn: [0.5804523825645447]
[INFO 30:04] - meta_params norm: [92.8545913696289]
[INFO 30:04] - meta_params: 0.015625
[INFO 30:09] rmse : 0.6526936292648315
[INFO 30:09] valid_rmse : 0.8346465826034546
[INFO 30:09] valid_loss : 1.1744719743728638


 48%|████▊     | 48/100 [16:36<17:09, 19.79s/it]

[INFO 30:09] ----- User batch : 0
[INFO 30:13] ----- User batch : 1
[INFO 30:16] ----- User batch : 2
[INFO 30:20] ----- User batch : 3
[INFO 30:23] ----- User batch : 4
[INFO 30:25] - meta_params gn: [0.6483296155929565]
[INFO 30:25] - meta_params norm: [92.90681457519531]
[INFO 30:25] - meta_params: 0.015625
[INFO 30:29] rmse : 0.6521590352058411
[INFO 30:29] valid_rmse : 0.8339629173278809
[INFO 30:29] valid_loss : 1.1745935678482056


 49%|████▉     | 49/100 [16:57<16:58, 19.98s/it]

[INFO 30:29] ----- User batch : 0
[INFO 30:33] ----- User batch : 1
[INFO 30:36] ----- User batch : 2
[INFO 30:39] ----- User batch : 3


In [None]:
# Re-run training + evaluation with the Approx_GAP meta-trainer.
# NOTE(review): this cell reuses `train_data`, `valid_data` and `test_data`
# left in the kernel by the fold loop above (i.e. the *last* fold it
# processed); it raises NameError on a fresh kernel unless that cell ran first.
config['meta_trainer'] = 'Approx_GAP'
for seed in range(1):
    config['seed'] = seed
    logging.info(f'#### seed : {seed} ####')

    train_data.reset_rng()
    valid_data.reset_rng()
    test_data.reset_rng()

    S = liriscat.selectionStrategy.Random(metadata, **config)
    S.init_models(train_data, valid_data)
    S.train(train_data, valid_data)

    # Evaluation is re-seeded with 0, matching the fold loop above.
    liriscat.utils.set_seed(0)
    S.reset_rng()
    d = S.evaluate_test(test_data, train_data, valid_data)
    # Report through the logger, in the same order as the fold loop.
    logging.info(d)
    logging.info(liriscat.utils.pareto_index(d))
    torch.cuda.empty_cache()