# Liriscat paper experiments
### 1. Init
#### 1.1. Import libraries (necessary)

In [1]:
%load_ext autoreload
%autoreload 2

import os
# Environment variables are set before `import torch` below so they are already
# present in the process environment when the CUDA libraries initialise.
# PyTorch CUDA caching-allocator option (see PyTorch "CUDA memory management" notes).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so assigning it
# here presumably only affects child processes, not this kernel — confirm intent.
os.environ["PYTHONHASHSEED"] = "0"
# cuBLAS workspace setting that PyTorch's reproducibility notes ask for when
# deterministic algorithms are wanted on CUDA.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

import liriscat
# First seeding, right after the project package is imported.
liriscat.utils.set_seed(0)

import logging
import gc
import json
import torch
# Second seeding with the same value, after torch has been imported.
# NOTE(review): presumably re-applied to guard against import-time RNG use — confirm
# whether both calls are needed.
liriscat.utils.set_seed(0)
import pandas as pd
from importlib import reload

#### 1.2. Set up the loggers (recommended)

In [2]:
# Configure logging through liriscat's helper: verbose output under the
# "liriscat" log name, with debug mode switched off.
liriscat.utils.setuplogger(
    verbose=True,
    log_name="liriscat",
    debug=False,
)

### 2. CDM prediction
#### 2.1. Training and testing, sequential version

In [3]:
# `warnings` is used by a later cell to install warning filters.
import warnings

# Start from a clean slate: run the Python garbage collector, then ask PyTorch
# to release the cached CUDA memory it is no longer using.
gc.collect()
torch.cuda.empty_cache()

In [4]:
# Build the evaluation configuration from liriscat's defaults, overriding the
# options relevant to this experiment.
config = liriscat.utils.generate_eval_config(
    load_params=True,
    esc='error',
    valid_metric='mi_acc',
    pred_metrics=["mi_acc"],
    profile_metrics=['meta_doa'],
    save_params=False,
    n_query=6,
    num_epochs=100,
    batch_size=512,
)
# Re-seed with the seed carried by the generated config.
liriscat.utils.set_seed(config["seed"])

# Dataset under study; logged so the run output identifies it.
config["dataset_name"] = "math2"
logging.info(config["dataset_name"])

# Experiment-specific hyper-parameters (applied in the same order as before so
# that any newly created key keeps its position in the config).
config.update({
    'learning_rate': 0.0001,
    'inner_user_lr': 0.016848380924625605,
    'lambda': 9.972254466547545e-06,
    'meta_lr': 0.5,
    'patience': 20,
    'num_inner_users_epochs': 3,
    'd_in': 4,
    'num_responses': 12,
})

CUDA is available. Using GPU.
[INFO 21:04] math2


In [5]:
# Where embeddings and checkpoints live. Set *before* logging the config so the
# logged configuration is exactly the one used by the rest of the notebook.
config['embs_path']='../embs/'
config['params_path']='../ckpt/'

logging.info(f'#### {config["dataset_name"]} ####')
logging.info(f'#### config : {config} ####')

# Free Python garbage and cached CUDA memory before loading the dataset artefacts.
gc.collect()
torch.cuda.empty_cache()

# Silence numerical warnings raised during metric computation.
# NOTE(review): the second filter ignores every RuntimeWarning, which makes the
# message-specific filter redundant if that warning is a RuntimeWarning (as
# NumPy's is) and may hide genuine numerical problems — consider narrowing it.
warnings.filterwarnings("ignore", message="invalid value encountered in divide")
warnings.filterwarnings("ignore", category=RuntimeWarning)

## Concept map format : {question_id : [category_id1, category_id2, ...]}
# Files are opened with `with` so the handles are closed deterministically
# instead of being left to the garbage collector.
with open(f'../datasets/2-preprocessed_data/{config["dataset_name"]}_concept_map.json', 'r') as concept_map_file:
    concept_map = json.load(concept_map_file)
# JSON object keys are always strings: convert keys and category ids to int.
concept_map = {int(k): [int(x) for x in v] for k, v in concept_map.items()}

## Metadata map format : {"num_user_id": ..., "num_item_id": ..., "num_dimension_id": ...}
with open(f'../datasets/2-preprocessed_data/{config["dataset_name"]}_metadata.json', 'r') as metadata_file:
    metadata = json.load(metadata_file)


## Tensor containing the nb of modalities per question
# `weights_only=True` restricts unpickling to tensors and other safe types.
nb_modalities = torch.load(f'../datasets/2-preprocessed_data/{config["dataset_name"]}_nb_modalities.pkl',weights_only=True)


[INFO 21:04] #### math2 ####
[INFO 21:05] #### config : {'seed': 0, 'dataset_name': 'math2', 'load_params': True, 'save_params': False, 'embs_path': '../embs/', 'params_path': '../ckpt/', 'early_stopping': True, 'esc': 'error', 'verbose_early_stopping': False, 'disable_tqdm': False, 'valid_metric': 'mi_acc', 'learning_rate': 0.0001, 'batch_size': 512, 'valid_batch_size': 10000, 'num_epochs': 100, 'eval_freq': 1, 'patience': 20, 'device': device(type='cuda'), 'lambda': 9.972254466547545e-06, 'tensorboard': False, 'flush_freq': True, 'pred_metrics': ['mi_acc'], 'profile_metrics': ['meta_doa'], 'num_responses': 12, 'low_mem': False, 'n_query': 6, 'CDM': 'impact', 'i_fold': 0, 'num_inner_users_epochs': 3, 'num_inner_epochs': 10, 'inner_lr': 0.0001, 'inner_user_lr': 0.016848380924625605, 'meta_lr': 0.5, 'meta_trainer': 'Adam', 'num_workers': 0, 'pin_memory': False, 'debug': False, 'd_in': 4} ####


In [6]:
# Column dtypes shared by the train / valid / test CSV splits.
SPLIT_DTYPES = {'student_id': int, 'item_id': int, "correct": float, "dimension_id": int}


def load_split(split, i_fold):
    """Read one preprocessed CSV split of the current dataset.

    Args:
        split: split name used in the file name ('train', 'valid' or 'test').
        i_fold: index of the cross-validation fold to load.

    Returns:
        The split as a pandas DataFrame with the dtypes from ``SPLIT_DTYPES``.
    """
    return pd.read_csv(
        f'../datasets/2-preprocessed_data/{config["dataset_name"]}_{split}_{i_fold}.csv',
        encoding='utf-8', dtype=SPLIT_DTYPES)


meta_trainers = ['Approx_GAP']
for meta_trainer in meta_trainers:
    config['meta_trainer'] = meta_trainer
    logging.info(f'#### meta_trainer : {config["meta_trainer"]} ####')

    for i_fold in range(1, 2):
        config['i_fold'] = i_fold
        logging.info(f'#### i_fold : {i_fold} ####')

        ## Dataframe columns : (student_id, item_id, correct, dimension_id)
        train_df, valid_df, test_df = [
            load_split(split_name, i_fold) for split_name in ('train', 'valid', 'test')
        ]

        # Wrap each split in the matching liriscat dataset class.
        train_data = liriscat.dataset.CATDataset(train_df, concept_map, metadata, config,nb_modalities)
        valid_data = liriscat.dataset.EvalDataset(valid_df, concept_map, metadata, config,nb_modalities)
        test_data = liriscat.dataset.EvalDataset(test_df, concept_map, metadata, config,nb_modalities)

        for seed in range(1):
            config['seed'] = seed
            logging.info(f'#### seed : {seed} ####')

            # Reset every dataset's RNG before each seeded run.
            for split_data in (train_data, valid_data, test_data):
                split_data.reset_rng()

            # Build and train the Random selection strategy.
            S = liriscat.selectionStrategy.Random(metadata,**config)
            S.init_models(train_data, valid_data)
            S.train(train_data, valid_data)

            # Re-seed and reset the strategy RNG before the test evaluation.
            liriscat.utils.set_seed(0)
            S.reset_rng()
            d = S.evaluate_test(test_data, train_data, valid_data)
            logging.info(d)
            logging.info(liriscat.utils.pareto_index(d))

    # Release cached CUDA memory once this meta-trainer is done.
    torch.cuda.empty_cache()

[INFO 21:05] #### meta_trainer : Approx_GAP ####
[INFO 21:05] #### i_fold : 1 ####
[INFO 21:15] #### seed : 0 ####
[INFO 21:15] Random_cont_model
[INFO 21:15] compiling CDM model
[INFO 21:17] compiling selection model
[INFO 21:17] ------- START Training
[INFO 21:17] train on cuda


  0%|          | 0/100 [00:00<?, ?it/s]

[INFO 21:18] ----- User batch : 0
[INFO 21:23] ----- User batch : 1
[INFO 21:26] ----- User batch : 2
[INFO 21:29] ----- User batch : 3
[INFO 21:32] ----- User batch : 4
[INFO 21:34] - meta_params gn: [0.32099825143814087]
[INFO 21:34] - meta_params norm: [34.032470703125]
[INFO 21:34] - meta_params: 0.5
[INFO 21:41] rmse : 0.5917553305625916
[INFO 21:41] valid_rmse : 0.7567203640937805
[INFO 21:41] valid_loss : 0.9675600528717041


  1%|          | 1/100 [00:23<38:27, 23.31s/it]

[INFO 21:41] ----- User batch : 0
[INFO 21:45] ----- User batch : 1
[INFO 21:48] ----- User batch : 2
[INFO 21:52] ----- User batch : 3
[INFO 21:55] ----- User batch : 4
[INFO 21:57] - meta_params gn: [0.30382025241851807]
[INFO 21:57] - meta_params norm: [34.774452209472656]
[INFO 21:57] - meta_params: 0.5
[INFO 22:02] rmse : 0.5900689363479614
[INFO 22:02] valid_rmse : 0.7545638680458069
[INFO 22:02] valid_loss : 0.9673460721969604


  2%|▏         | 2/100 [00:44<35:39, 21.83s/it]

[INFO 22:02] ----- User batch : 0
[INFO 22:05] ----- User batch : 1
[INFO 22:09] ----- User batch : 2
[INFO 22:12] ----- User batch : 3
[INFO 22:15] ----- User batch : 4
[INFO 22:17] - meta_params gn: [0.24960941076278687]
[INFO 22:17] - meta_params norm: [35.700618743896484]
[INFO 22:17] - meta_params: 0.5
[INFO 22:22] rmse : 0.5819918513298035
[INFO 22:22] valid_rmse : 0.744235098361969
[INFO 22:22] valid_loss : 0.9673393368721008


  3%|▎         | 3/100 [01:04<34:08, 21.11s/it]

[INFO 22:22] ----- User batch : 0
[INFO 22:26] ----- User batch : 1
[INFO 22:30] ----- User batch : 2
[INFO 22:33] ----- User batch : 3
[INFO 22:36] ----- User batch : 4
[INFO 22:38] - meta_params gn: [0.2509273290634155]
[INFO 22:38] - meta_params norm: [36.580963134765625]
[INFO 22:38] - meta_params: 0.5
[INFO 22:43] rmse : 0.5885470509529114
[INFO 22:43] valid_rmse : 0.7526177167892456
[INFO 22:43] valid_loss : 0.9660078883171082


  4%|▍         | 4/100 [01:25<33:47, 21.12s/it]

[INFO 22:43] ----- User batch : 0
[INFO 22:47] ----- User batch : 1
[INFO 22:51] ----- User batch : 2
[INFO 22:55] ----- User batch : 3
[INFO 22:58] ----- User batch : 4
[INFO 23:00] - meta_params gn: [0.2598417401313782]
[INFO 23:00] - meta_params norm: [37.65623474121094]
[INFO 23:00] - meta_params: 0.5
[INFO 23:05] rmse : 0.5839575529098511
[INFO 23:05] valid_rmse : 0.7467487454414368
[INFO 23:05] valid_loss : 0.9647387862205505


  5%|▌         | 5/100 [01:47<33:58, 21.46s/it]

[INFO 23:06] ----- User batch : 0
[INFO 23:09] ----- User batch : 1
[INFO 23:13] ----- User batch : 2
[INFO 23:17] ----- User batch : 3
[INFO 23:21] ----- User batch : 4
[INFO 23:23] - meta_params gn: [0.21979454159736633]
[INFO 23:23] - meta_params norm: [38.707122802734375]
[INFO 23:23] - meta_params: 0.5
[INFO 23:27] rmse : 0.5862568020820618
[INFO 23:27] valid_rmse : 0.749688982963562
[INFO 23:27] valid_loss : 0.9663800001144409


  6%|▌         | 6/100 [02:10<34:10, 21.82s/it]

[INFO 23:28] ----- User batch : 0
[INFO 23:32] ----- User batch : 1
[INFO 23:36] ----- User batch : 2
[INFO 23:39] ----- User batch : 3
[INFO 23:43] ----- User batch : 4
[INFO 23:45] - meta_params gn: [0.2516552209854126]
[INFO 23:45] - meta_params norm: [39.7512092590332]
[INFO 23:45] - meta_params: 0.5
[INFO 23:50] rmse : 0.5853211879730225
[INFO 23:50] valid_rmse : 0.7484925389289856
[INFO 23:50] valid_loss : 0.9647696614265442


  7%|▋         | 7/100 [02:32<34:02, 21.96s/it]

[INFO 23:50] ----- User batch : 0
[INFO 23:54] ----- User batch : 1
[INFO 23:58] ----- User batch : 2
[INFO 24:02] ----- User batch : 3
[INFO 24:06] ----- User batch : 4
[INFO 24:07] - meta_params gn: [0.25256407260894775]
[INFO 24:07] - meta_params norm: [40.740230560302734]
[INFO 24:07] - meta_params: 0.5
[INFO 24:12] rmse : 0.5838722586631775
[INFO 24:12] valid_rmse : 0.7466397285461426
[INFO 24:12] valid_loss : 0.9649967551231384


  8%|▊         | 8/100 [02:54<33:54, 22.11s/it]

[INFO 24:13] ----- User batch : 0
[INFO 24:17] ----- User batch : 1
[INFO 24:20] ----- User batch : 2
[INFO 24:24] ----- User batch : 3
[INFO 24:28] ----- User batch : 4
[INFO 24:30] - meta_params gn: [0.23202098906040192]
[INFO 24:30] - meta_params norm: [41.81241226196289]
[INFO 24:30] - meta_params: 0.25
[INFO 24:35] rmse : 0.5854063034057617
[INFO 24:35] valid_rmse : 0.7486013770103455
[INFO 24:35] valid_loss : 0.966931164264679


  9%|▉         | 9/100 [03:17<33:49, 22.30s/it]

[INFO 24:35] ----- User batch : 0
[INFO 24:39] ----- User batch : 1
[INFO 24:42] ----- User batch : 2
[INFO 24:46] ----- User batch : 3
[INFO 24:49] ----- User batch : 4
[INFO 24:51] - meta_params gn: [0.20894889533519745]
[INFO 24:51] - meta_params norm: [42.34563064575195]
[INFO 24:51] - meta_params: 0.25
[INFO 24:56] rmse : 0.5753619074821472
[INFO 24:56] valid_rmse : 0.7357568740844727
[INFO 24:56] valid_loss : 0.961367130279541


 10%|█         | 10/100 [03:38<32:51, 21.90s/it]

[INFO 24:56] ----- User batch : 0
[INFO 25:00] ----- User batch : 1
[INFO 25:04] ----- User batch : 2
[INFO 25:08] ----- User batch : 3
[INFO 25:11] ----- User batch : 4
[INFO 25:13] - meta_params gn: [0.24030670523643494]
[INFO 25:13] - meta_params norm: [42.84825897216797]
[INFO 25:13] - meta_params: 0.25
[INFO 25:18] rmse : 0.5837015509605408
[INFO 25:18] valid_rmse : 0.7464213967323303
[INFO 25:18] valid_loss : 0.9669182896614075


 11%|█         | 11/100 [04:00<32:29, 21.91s/it]

[INFO 25:18] ----- User batch : 0
[INFO 25:22] ----- User batch : 1
[INFO 25:26] ----- User batch : 2
[INFO 25:29] ----- User batch : 3
[INFO 25:33] ----- User batch : 4
[INFO 25:35] - meta_params gn: [0.1800566464662552]
[INFO 25:35] - meta_params norm: [43.373558044433594]
[INFO 25:35] - meta_params: 0.25
[INFO 25:40] rmse : 0.5816492438316345
[INFO 25:40] valid_rmse : 0.7437969446182251
[INFO 25:40] valid_loss : 0.9637653231620789


 12%|█▏        | 12/100 [04:22<32:05, 21.88s/it]

[INFO 25:40] ----- User batch : 0
[INFO 25:44] ----- User batch : 1
[INFO 25:48] ----- User batch : 2
[INFO 25:51] ----- User batch : 3
[INFO 25:55] ----- User batch : 4
[INFO 25:57] - meta_params gn: [0.18259382247924805]
[INFO 25:57] - meta_params norm: [43.90277862548828]
[INFO 25:57] - meta_params: 0.25
[INFO 26:02] rmse : 0.5782126188278198
[INFO 26:02] valid_rmse : 0.7394022941589355
[INFO 26:02] valid_loss : 0.9642583131790161


 13%|█▎        | 13/100 [04:44<31:49, 21.95s/it]

[INFO 26:02] ----- User batch : 0
[INFO 26:06] ----- User batch : 1
[INFO 26:10] ----- User batch : 2
[INFO 26:14] ----- User batch : 3
[INFO 26:17] ----- User batch : 4
[INFO 26:19] - meta_params gn: [0.1730150729417801]
[INFO 26:19] - meta_params norm: [44.435298919677734]
[INFO 26:19] - meta_params: 0.125
[INFO 26:24] rmse : 0.587699830532074
[INFO 26:24] valid_rmse : 0.7515342831611633
[INFO 26:24] valid_loss : 0.9673075079917908


 14%|█▍        | 14/100 [05:06<31:36, 22.05s/it]

[INFO 26:25] ----- User batch : 0
[INFO 26:28] ----- User batch : 1
[INFO 26:33] ----- User batch : 2
[INFO 26:37] ----- User batch : 3
[INFO 26:40] ----- User batch : 4
[INFO 26:42] - meta_params gn: [0.21412493288516998]
[INFO 26:42] - meta_params norm: [44.707908630371094]
[INFO 26:42] - meta_params: 0.125
[INFO 26:47] rmse : 0.5832746028900146
[INFO 26:47] valid_rmse : 0.7458754181861877
[INFO 26:47] valid_loss : 0.9643703103065491


 15%|█▌        | 15/100 [05:29<31:42, 22.39s/it]

[INFO 26:48] ----- User batch : 0
[INFO 26:52] ----- User batch : 1
[INFO 26:55] ----- User batch : 2
[INFO 26:59] ----- User batch : 3
[INFO 27:03] ----- User batch : 4
[INFO 27:05] - meta_params gn: [0.22235871851444244]
[INFO 27:05] - meta_params norm: [44.989967346191406]
[INFO 27:05] - meta_params: 0.125
[INFO 27:09] rmse : 0.5880388617515564
[INFO 27:09] valid_rmse : 0.7519678473472595
[INFO 27:09] valid_loss : 0.9649600982666016


 16%|█▌        | 16/100 [05:51<31:15, 22.32s/it]

[INFO 27:10] ----- User batch : 0
[INFO 27:14] ----- User batch : 1
[INFO 27:17] ----- User batch : 2
[INFO 27:21] ----- User batch : 3
[INFO 27:24] ----- User batch : 4
[INFO 27:26] - meta_params gn: [0.20049700140953064]
[INFO 27:26] - meta_params norm: [45.268333435058594]
[INFO 27:26] - meta_params: 0.0625
[INFO 27:31] rmse : 0.5798475742340088
[INFO 27:31] valid_rmse : 0.7414930462837219
[INFO 27:31] valid_loss : 0.9626057147979736


 17%|█▋        | 17/100 [06:13<30:31, 22.07s/it]

[INFO 27:31] ----- User batch : 0
[INFO 27:35] ----- User batch : 1
[INFO 27:38] ----- User batch : 2
[INFO 27:42] ----- User batch : 3
[INFO 27:45] ----- User batch : 4
[INFO 27:47] - meta_params gn: [0.17482729256153107]
[INFO 27:47] - meta_params norm: [45.40971374511719]
[INFO 27:47] - meta_params: 0.0625
[INFO 27:51] rmse : 0.5850657820701599
[INFO 27:51] valid_rmse : 0.7481659650802612
[INFO 27:51] valid_loss : 0.966075599193573


 18%|█▊        | 18/100 [06:33<29:29, 21.58s/it]

[INFO 27:52] ----- User batch : 0
[INFO 27:55] ----- User batch : 1
[INFO 27:58] ----- User batch : 2
[INFO 28:02] ----- User batch : 3
[INFO 28:05] ----- User batch : 4
[INFO 28:07] - meta_params gn: [0.16437658667564392]
[INFO 28:07] - meta_params norm: [45.544979095458984]
[INFO 28:07] - meta_params: 0.0625
[INFO 28:12] rmse : 0.5856615900993347
[INFO 28:12] valid_rmse : 0.7489278316497803
[INFO 28:12] valid_loss : 0.9658194184303284


 19%|█▉        | 19/100 [06:54<28:42, 21.26s/it]

[INFO 28:12] ----- User batch : 0
[INFO 28:16] ----- User batch : 1
[INFO 28:19] ----- User batch : 2
[INFO 28:22] ----- User batch : 3
[INFO 28:26] ----- User batch : 4
[INFO 28:27] - meta_params gn: [0.21306151151657104]
[INFO 28:27] - meta_params norm: [45.67259216308594]
[INFO 28:27] - meta_params: 0.03125
[INFO 28:32] rmse : 0.586596667766571
[INFO 28:32] valid_rmse : 0.7501236200332642
[INFO 28:32] valid_loss : 0.9622656106948853


 20%|██        | 20/100 [07:14<28:00, 21.00s/it]

[INFO 28:33] ----- User batch : 0
[INFO 28:36] ----- User batch : 1
[INFO 28:39] ----- User batch : 2
[INFO 28:43] ----- User batch : 3
[INFO 28:46] ----- User batch : 4
[INFO 28:48] - meta_params gn: [0.15843690931797028]
[INFO 28:48] - meta_params norm: [45.738792419433594]
[INFO 28:48] - meta_params: 0.03125
[INFO 28:53] rmse : 0.5825908184051514
[INFO 28:53] valid_rmse : 0.7450010180473328
[INFO 28:53] valid_loss : 0.9621067643165588


 21%|██        | 21/100 [07:35<27:25, 20.83s/it]

[INFO 28:53] ----- User batch : 0
[INFO 28:57] ----- User batch : 1
[INFO 29:00] ----- User batch : 2
[INFO 29:03] ----- User batch : 3
[INFO 29:07] ----- User batch : 4
[INFO 29:09] - meta_params gn: [0.15965402126312256]
[INFO 29:09] - meta_params norm: [45.80405044555664]
[INFO 29:09] - meta_params: 0.03125
[INFO 29:13] rmse : 0.5805345773696899
[INFO 29:14] valid_rmse : 0.7423715591430664
[INFO 29:14] valid_loss : 0.9636510014533997


 22%|██▏       | 22/100 [07:56<27:06, 20.85s/it]

[INFO 29:14] ----- User batch : 0
[INFO 29:18] ----- User batch : 1
[INFO 29:21] ----- User batch : 2
[INFO 29:25] ----- User batch : 3
[INFO 29:28] ----- User batch : 4
[INFO 29:30] - meta_params gn: [0.18625697493553162]
[INFO 29:30] - meta_params norm: [45.87118911743164]
[INFO 29:30] - meta_params: 0.015625
[INFO 29:35] rmse : 0.5842134952545166
[INFO 29:35] valid_rmse : 0.7470760941505432
[INFO 29:35] valid_loss : 0.9619596004486084


 23%|██▎       | 23/100 [08:17<26:59, 21.04s/it]

[INFO 29:35] ----- User batch : 0
[INFO 29:39] ----- User batch : 1
[INFO 29:43] ----- User batch : 2
[INFO 29:47] ----- User batch : 3
[INFO 29:50] ----- User batch : 4
[INFO 29:52] - meta_params gn: [0.20612940192222595]
[INFO 29:52] - meta_params norm: [45.9033203125]
[INFO 29:52] - meta_params: 0.015625
[INFO 29:57] rmse : 0.5856615900993347
[INFO 29:57] valid_rmse : 0.7489278316497803
[INFO 29:57] valid_loss : 0.9657606482505798


 24%|██▍       | 24/100 [08:39<26:56, 21.28s/it]

[INFO 29:57] ----- User batch : 0
[INFO 30:01] ----- User batch : 1
[INFO 30:04] ----- User batch : 2
[INFO 30:08] ----- User batch : 3
[INFO 30:11] ----- User batch : 4
[INFO 30:13] - meta_params gn: [0.17991210520267487]
[INFO 30:13] - meta_params norm: [45.93521499633789]
[INFO 30:13] - meta_params: 0.015625
[INFO 30:18] rmse : 0.5789015889167786
[INFO 30:18] valid_rmse : 0.7402833700180054
[INFO 30:18] valid_loss : 0.9651282429695129


 25%|██▌       | 25/100 [09:00<26:25, 21.13s/it]

[INFO 30:18] ----- User batch : 0
[INFO 30:22] ----- User batch : 1
[INFO 30:25] ----- User batch : 2
[INFO 30:29] ----- User batch : 3
[INFO 30:33] ----- User batch : 4
[INFO 30:34] - meta_params gn: [0.17730624973773956]
[INFO 30:34] - meta_params norm: [45.968502044677734]
[INFO 30:34] - meta_params: 0.0078125
[INFO 30:39] rmse : 0.58147794008255
[INFO 30:39] valid_rmse : 0.7435778975486755
[INFO 30:39] valid_loss : 0.9635178446769714


 26%|██▌       | 26/100 [09:21<26:09, 21.20s/it]

[INFO 30:39] ----- User batch : 0
[INFO 30:43] ----- User batch : 1
[INFO 30:46] ----- User batch : 2
[INFO 30:49] ----- User batch : 3
[INFO 30:52] ----- User batch : 4
[INFO 30:54] - meta_params gn: [0.17675575613975525]
[INFO 30:54] - meta_params norm: [45.98550796508789]
[INFO 30:54] - meta_params: 0.0078125
[INFO 30:59] rmse : 0.5779540538787842
[INFO 30:59] valid_rmse : 0.7390716671943665
[INFO 30:59] valid_loss : 0.9643872976303101


 27%|██▋       | 27/100 [09:41<25:13, 20.73s/it]

[INFO 30:59] ----- User batch : 0
[INFO 31:02] ----- User batch : 1
[INFO 31:06] ----- User batch : 2
[INFO 31:09] ----- User batch : 3
[INFO 31:12] ----- User batch : 4
[INFO 31:14] - meta_params gn: [0.15672211349010468]
[INFO 31:14] - meta_params norm: [46.00244140625]
[INFO 31:14] - meta_params: 0.0078125
[INFO 31:19] rmse : 0.5819918513298035
[INFO 31:19] valid_rmse : 0.744235098361969
[INFO 31:19] valid_loss : 0.9605883359909058


 28%|██▊       | 28/100 [10:01<24:41, 20.58s/it]

[INFO 31:20] ----- User batch : 0
[INFO 31:23] ----- User batch : 1
[INFO 31:27] ----- User batch : 2
[INFO 31:31] ----- User batch : 3
[INFO 31:34] ----- User batch : 4
[INFO 31:36] - meta_params gn: [0.18341712653636932]
[INFO 31:36] - meta_params norm: [46.01991653442383]
[INFO 31:36] - meta_params: 0.0078125
[INFO 31:41] rmse : 0.5789876580238342
[INFO 31:41] valid_rmse : 0.7403934001922607
[INFO 31:41] valid_loss : 0.9628831148147583


 29%|██▉       | 29/100 [10:23<24:46, 20.94s/it]

[INFO 31:41] ----- User batch : 0
[INFO 31:45] ----- User batch : 1
[INFO 31:48] ----- User batch : 2
[INFO 31:52] ----- User batch : 3
[INFO 31:55] ----- User batch : 4
[INFO 31:57] - meta_params gn: [0.17030194401741028]
[INFO 31:57] - meta_params norm: [46.03704833984375]
[INFO 31:57] - meta_params: 0.0078125
[INFO 32:01] rmse : 0.584810197353363
[INFO 32:01] valid_rmse : 0.747839093208313
[INFO 32:01] valid_loss : 0.9649723768234253


 29%|██▉       | 29/100 [10:43<26:15, 22.20s/it]

[INFO 32:01] -- END Training --
[INFO 32:01] train on cuda
Learning users emb shape: torch.Size([3911, 16])



100%|██████████| 6/6 [00:04<00:00,  1.42it/s]


[INFO 32:12] ({0: {'mi_acc': 0.635077178478241}, 1: {'mi_acc': 0.6367712616920471}, 2: {'mi_acc': 0.642551064491272}, 3: {'mi_acc': 0.6440458297729492}, 4: {'mi_acc': 0.651419997215271}, 5: {'mi_acc': 0.653911292552948}}, {0: {'meta_doa': 0.5062931470726217}, 1: {'meta_doa': 0.5107671239843226}, 2: {'meta_doa': 0.5167609327968822}, 3: {'meta_doa': 0.5239388405435574}, 4: {'meta_doa': 0.5335979268072653}, 5: {'meta_doa': 0.543656102709063}})
[INFO 32:12] 0.01996683782378124


In [None]:
# Re-run training/evaluation on the datasets built by the previous cell.
# NOTE: relies on `train_data`, `valid_data`, `test_data`, `metadata` and `config`
# still being present in the kernel; none of them is defined in this cell.
config['meta_trainer'] = 'Approx_GAP'  # alternative tried: 'Adam'
for seed in range(1):
    config['seed'] = seed
    logging.info(f'#### seed : {seed} ####')

    # Reset every dataset's RNG before the run.
    for split_data in (train_data, valid_data, test_data):
        split_data.reset_rng()

    # Build and train the Random selection strategy.
    S = liriscat.selectionStrategy.Random(metadata,**config)
    S.init_models(train_data, valid_data)
    S.train(train_data, valid_data)

    # Re-seed and reset the strategy RNG before the test evaluation.
    liriscat.utils.set_seed(0)
    S.reset_rng()
    d = S.evaluate_test(test_data, train_data, valid_data)
    print(liriscat.utils.pareto_index(d))
    print(d)
    torch.cuda.empty_cache()