# Liriscat paper experiments
### 1. Init
#### 1.1. Import libraries (necessary)

In [1]:
%load_ext autoreload
%autoreload 2

import os
# These environment variables are assigned before `liriscat` and `torch` are imported below.
# NOTE(review): presumably so the CUDA allocator / cuBLAS pick them up at initialisation — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up; assigning it here is inherited by
# child processes but does not change hashing in this already-running kernel — confirm intent.
os.environ["PYTHONHASHSEED"] = "0"
# NOTE(review): ":4096:8" is one of the values PyTorch's reproducibility notes list for
# deterministic cuBLAS — confirm that determinism is the goal here.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

import liriscat
# First seeding call, issued right after importing liriscat.
liriscat.utils.set_seed(0)

import logging
import gc
import json
import torch
# Second, identical seeding call, issued after `torch` has been imported in this cell.
# NOTE(review): presumably a deliberate re-seed after the torch import — confirm it is still needed.
liriscat.utils.set_seed(0)
import pandas as pd

#### 1.2. Set up the loggers (recommended)

In [2]:
# Configure the "liriscat" logger with both the verbose and the debug flags switched on.
liriscat.utils.setuplogger(
    verbose=True,
    log_name="liriscat",
    debug=True,
)

### 2. CDM prediction
#### 2.1. Training and testing, sequential version

In [3]:
# `warnings` is used further down to install the RuntimeWarning filters.
import warnings

# Run a garbage-collection pass, then ask torch to empty its CUDA cache.
gc.collect()
torch.cuda.empty_cache()

In [4]:
# Base evaluation configuration; experiment-specific values are overridden below.
config = liriscat.utils.generate_eval_config(
    load_params=True,
    esc='error',
    valid_metric='mi_acc',
    pred_metrics=["mi_acc"],
    profile_metrics=['meta_doa'],
    save_params=False,
    n_query=6,
    num_epochs=100,
    batch_size=512,
)
liriscat.utils.set_seed(config["seed"])

config["dataset_name"] = "assist0910"
logging.info(config["dataset_name"])

# Overrides, applied one key at a time in the order listed.
config_overrides = {
    'learning_rate': 0.0001,
    'inner_user_lr': 0.0016969685554352153,
    'lambda': 2.2656270501845414e-06,
    'meta_lr': 0.5,
    'learning_users_emb_lr': 0.0005,
    'patience': 20,  # an earlier value was 5
    'num_inner_users_epochs': 3,
    'd_in': 4,
    'num_responses': 12,
}
for option_name, option_value in config_overrides.items():
    config[option_name] = option_value

#pred_metrics,df_interp = test(config)

CUDA is available. Using GPU.
[INFO 05:05] math2


In [5]:
def pareto_index2(d):
    """Sum, over all steps, the product of both metrics' offsets from 0.5.

    Args:
        d: pair ``(acc, meta)``. Both members are indexed with the integers
            ``0 .. len(acc) - 1``; ``acc[i]`` must provide a ``'mi_acc'``
            entry and ``meta[i]`` a ``'meta_doa'`` entry.

    Returns:
        ``sum((0.5 - acc[i]['mi_acc']) * (0.5 - meta[i]['meta_doa']))`` over
        every index ``i``; ``0`` when ``acc`` is empty.
    """
    acc_by_step, meta_by_step = d[0], d[1]
    return sum(
        (0.5 - acc_by_step[step]['mi_acc']) * (0.5 - meta_by_step[step]['meta_doa'])
        for step in range(len(acc_by_step))
    )

In [6]:
logging.info(f'#### {config["dataset_name"]} ####')
logging.info(f'#### config : {config} ####')
config['embs_path']='../embs/'
config['params_path']='../ckpt/'

gc.collect()
torch.cuda.empty_cache()

# Load the preprocessed dataset artefacts from disk (nothing is downloaded here).
# RuntimeWarnings, including "invalid value encountered in divide", are silenced first.
warnings.filterwarnings("ignore", message="invalid value encountered in divide")
warnings.filterwarnings("ignore", category=RuntimeWarning)

## Concept map format : {question_id : [category_id1, category_id2, ...]}
# Opened through a context manager so the file handle is closed right after parsing.
with open(f'../datasets/2-preprocessed_data/{config["dataset_name"]}_concept_map.json', 'r') as concept_map_file:
    concept_map = json.load(concept_map_file)
# JSON object keys are strings: convert question ids and category ids to int.
concept_map = {int(k): [int(x) for x in v] for k, v in concept_map.items()}

## Metadata map format : {"num_user_id": ..., "num_item_id": ..., "num_dimension_id": ...}
with open(f'../datasets/2-preprocessed_data/{config["dataset_name"]}_metadata.json', 'r') as metadata_file:
    metadata = json.load(metadata_file)


## Tensor containing the nb of modalities per question
nb_modalities = torch.load(f'../datasets/2-preprocessed_data/{config["dataset_name"]}_nb_modalities.pkl',weights_only=True)


[INFO 41:03] #### assist0910 ####
[INFO 41:03] #### config : {'seed': 0, 'dataset_name': 'assist0910', 'load_params': True, 'save_params': False, 'embs_path': '../embs/', 'params_path': '../ckpt/', 'early_stopping': True, 'esc': 'error', 'verbose_early_stopping': False, 'disable_tqdm': False, 'valid_metric': 'mi_acc', 'learning_rate': 0.0001, 'batch_size': 512, 'valid_batch_size': 10000, 'num_epochs': 100, 'eval_freq': 1, 'patience': 20, 'device': device(type='cuda'), 'lambda': 2.2656270501845414e-06, 'tensorboard': False, 'flush_freq': True, 'pred_metrics': ['mi_acc'], 'profile_metrics': ['meta_doa'], 'num_responses': 12, 'low_mem': False, 'n_query': 6, 'CDM': 'impact', 'i_fold': 0, 'num_inner_users_epochs': 3, 'num_inner_epochs': 10, 'inner_lr': 0.0001, 'inner_user_lr': 0.0016969685554352153, 'meta_lr': 0.5, 'meta_trainer': 'Adam', 'num_workers': 0, 'pin_memory': False, 'debug': True, 'learning_users_emb_lr': 0.0005, 'd_in': 4} ####


[INFO 38:14] #### config : {'seed': 0, 'dataset_name': 'math2', 'load_params': True, 'save_params': False, 'embs_path': '../embs/', 'params_path': '../ckpt/', 'early_stopping': True, 'esc': 'error', 'verbose_early_stopping': False, 'disable_tqdm': False, 'valid_metric': 'mi_acc', 'learning_rate': 0.0001, 'batch_size': 512, 'valid_batch_size': 10000, 'num_epochs': 100, 'eval_freq': 1, 'patience': 20, 'device': device(type='cuda'), 'lambda': 9.972254466547545e-06, 'tensorboard': False, 'flush_freq': True, 'pred_metrics': ['mi_acc'], 'profile_metrics': ['meta_doa'], 'num_responses': 12, 'low_mem': False, 'n_query': 6, 'CDM': 'impact', 'i_fold': 0, 'num_inner_users_epochs': 3, 'num_inner_epochs': 10, 'inner_lr': 0.0001, 'inner_user_lr': 0.016848380924625605, 'meta_lr': 0.5, 'meta_trainer': 'Adam', 'num_workers': 0, 'pin_memory': False, 'debug': False, 'learning_users_emb_lr': 0.001, 'd_in': 4} ####


In [8]:
# Column dtypes shared by the train / valid / test CSV files of a fold.
split_dtypes = {'student_id': int, 'item_id': int, "correct": float, "dimension_id": int}

meta_trainers = ['GAP']
for meta_trainer in meta_trainers:
    config['meta_trainer'] = meta_trainer
    logging.info(f'#### meta_trainer : {config["meta_trainer"]} ####')
    for i_fold in range(1, 2):
        config['i_fold'] = i_fold

        logging.info(f'#### i_fold : {i_fold} ####')
        ## Dataframe columns : (student_id, item_id, correct, dimension_id)
        # One read per split; only the split name in the file name differs.
        train_df, valid_df, test_df = [
            pd.read_csv(
                f'../datasets/2-preprocessed_data/{config["dataset_name"]}_{split}_{i_fold}.csv',
                encoding='utf-8', dtype=split_dtypes)
            for split in ('train', 'valid', 'test')
        ]

        train_data = liriscat.dataset.CATDataset(train_df, concept_map, metadata, config, nb_modalities)
        valid_data = liriscat.dataset.EvalDataset(valid_df, concept_map, metadata, config, nb_modalities)
        test_data = liriscat.dataset.EvalDataset(test_df, concept_map, metadata, config, nb_modalities)

        for seed in range(1):
            config['seed'] = seed
            logging.info(f'#### seed : {seed} ####')

            train_data.reset_rng()
            valid_data.reset_rng()
            test_data.reset_rng()

            S = liriscat.selectionStrategy.Random(metadata, **config)
            S.init_models(train_data, valid_data)
            S.train(train_data, valid_data)
            # NOTE(review): re-seeds with the constant 0 rather than `seed`; the two coincide
            # while the loop is range(1) — confirm the intent before adding more seeds.
            liriscat.utils.set_seed(0)
            S.reset_rng()
            d = S.evaluate_test(test_data, train_data, valid_data)
            logging.info(d)
            logging.info(liriscat.utils.pareto_index(d))

    # Runs once per meta trainer, after all of its folds.
    torch.cuda.empty_cache()

[INFO 41:03] #### meta_trainer : GAP ####
[INFO 41:03] #### i_fold : 1 ####
[INFO 41:20] #### seed : 0 ####
[DEBUG 41:20] ------- Abstract model __init__()
[DEBUG 41:20] ----- Meta trainer init : GAP
[INFO 41:20] Random_cont_model
[DEBUG 41:20] ------- Initialize CDM and Selection strategy
[INFO 41:21] compiling CDM model
[INFO 41:25] compiling selection model
[INFO 41:25] ------- START Training
[INFO 41:25] train on cuda


  0%|          | 0/100 [00:00<?, ?it/s]

[INFO 05:35] - meta_params: 0.5
[INFO 05:35] - cross_cond: 0.5
[INFO 05:35] - meta_lambda: 0.5
[INFO 05:35] - learning_users_emb: 0.0005
[INFO 05:41] rmse : 0.5449409484863281
[INFO 05:41] valid_rmse : 0.6968554258346558
[INFO 05:41] valid_loss : 0.9390125274658203


  1%|          | 1/100 [00:21<35:47, 21.69s/it]

[INFO 05:55] - meta_params: 0.5
[INFO 05:55] - cross_cond: 0.5
[INFO 05:55] - meta_lambda: 0.5
[INFO 05:55] - learning_users_emb: 0.0005
[INFO 05:59] rmse : 0.554278552532196
[INFO 05:59] valid_rmse : 0.7087960839271545
[INFO 05:59] valid_loss : 0.9394581317901611


  2%|▏         | 2/100 [00:40<32:21, 19.82s/it]

[INFO 06:13] - meta_params: 0.5
[INFO 06:13] - cross_cond: 0.5
[INFO 06:13] - meta_lambda: 0.5
[INFO 06:13] - learning_users_emb: 0.0005
[INFO 06:18] rmse : 0.5512136816978455
[INFO 06:18] valid_rmse : 0.7048768401145935
[INFO 06:18] valid_loss : 0.9408632516860962


  3%|▎         | 3/100 [00:58<30:51, 19.09s/it]

[INFO 06:32] - meta_params: 0.5
[INFO 06:32] - cross_cond: 0.5
[INFO 06:32] - meta_lambda: 0.5
[INFO 06:32] - learning_users_emb: 0.0005
[INFO 06:36] rmse : 0.5538288950920105
[INFO 06:36] valid_rmse : 0.7082210779190063
[INFO 06:36] valid_loss : 0.9397975206375122


  4%|▍         | 4/100 [01:17<30:16, 18.92s/it]

[INFO 06:50] - meta_params: 0.25
[INFO 06:50] - cross_cond: 0.25
[INFO 06:50] - meta_lambda: 0.25
[INFO 06:50] - learning_users_emb: 0.00025
[INFO 06:55] rmse : 0.5370201468467712
[INFO 06:55] valid_rmse : 0.6867265105247498
[INFO 06:55] valid_loss : 0.936843991279602


  5%|▌         | 5/100 [01:35<29:43, 18.77s/it]

[INFO 07:09] - meta_params: 0.25
[INFO 07:09] - cross_cond: 0.25
[INFO 07:09] - meta_lambda: 0.25
[INFO 07:09] - learning_users_emb: 0.00025
[INFO 07:13] rmse : 0.5508520007133484
[INFO 07:13] valid_rmse : 0.7044143080711365
[INFO 07:13] valid_loss : 0.9387009739875793


  6%|▌         | 6/100 [01:54<29:19, 18.71s/it]

[INFO 07:27] - meta_params: 0.25
[INFO 07:27] - cross_cond: 0.25
[INFO 07:27] - meta_lambda: 0.25
[INFO 07:27] - learning_users_emb: 0.00025
[INFO 07:32] rmse : 0.5430174469947815
[INFO 07:32] valid_rmse : 0.6943957209587097
[INFO 07:32] valid_loss : 0.9376018047332764


  7%|▋         | 7/100 [02:12<28:54, 18.65s/it]

[INFO 07:46] - meta_params: 0.25
[INFO 07:46] - cross_cond: 0.25
[INFO 07:46] - meta_lambda: 0.25
[INFO 07:46] - learning_users_emb: 0.00025
[INFO 07:50] rmse : 0.545397937297821
[INFO 07:50] valid_rmse : 0.6974397897720337
[INFO 07:50] valid_loss : 0.9389498233795166


  8%|▊         | 8/100 [02:31<28:35, 18.65s/it]

[INFO 08:04] - meta_params: 0.125
[INFO 08:04] - cross_cond: 0.125
[INFO 08:04] - meta_lambda: 0.125
[INFO 08:05] - learning_users_emb: 0.000125
[INFO 08:09] rmse : 0.5460370182991028
[INFO 08:09] valid_rmse : 0.6982570290565491
[INFO 08:09] valid_loss : 0.940470814704895


  9%|▉         | 9/100 [02:49<28:12, 18.59s/it]

[INFO 08:23] - meta_params: 0.125
[INFO 08:23] - cross_cond: 0.125
[INFO 08:23] - meta_lambda: 0.125
[INFO 08:23] - learning_users_emb: 0.000125
[INFO 08:28] rmse : 0.5362773537635803
[INFO 08:28] valid_rmse : 0.6857766509056091
[INFO 08:28] valid_loss : 0.9353311657905579


 10%|█         | 10/100 [03:08<27:58, 18.65s/it]

[INFO 08:42] - meta_params: 0.125
[INFO 08:42] - cross_cond: 0.125
[INFO 08:42] - meta_lambda: 0.125
[INFO 08:42] - learning_users_emb: 0.000125
[INFO 08:47] rmse : 0.5597351789474487
[INFO 08:47] valid_rmse : 0.71577388048172
[INFO 08:47] valid_loss : 0.9416601657867432


 11%|█         | 11/100 [03:27<27:45, 18.71s/it]

[INFO 09:01] - meta_params: 0.125
[INFO 09:01] - cross_cond: 0.125
[INFO 09:01] - meta_lambda: 0.125
[INFO 09:01] - learning_users_emb: 0.000125
[INFO 09:05] rmse : 0.5525679588317871
[INFO 09:05] valid_rmse : 0.7066086530685425
[INFO 09:05] valid_loss : 0.9380117058753967


 12%|█▏        | 12/100 [03:46<27:29, 18.75s/it]

[INFO 09:20] - meta_params: 0.125
[INFO 09:20] - cross_cond: 0.125
[INFO 09:20] - meta_lambda: 0.125
[INFO 09:20] - learning_users_emb: 0.000125
[INFO 09:24] rmse : 0.5465842485427856
[INFO 09:24] valid_rmse : 0.6989568471908569
[INFO 09:24] valid_loss : 0.9385941624641418


 13%|█▎        | 13/100 [04:04<27:07, 18.70s/it]

[INFO 09:38] - meta_params: 0.0625
[INFO 09:38] - cross_cond: 0.0625
[INFO 09:38] - meta_lambda: 0.0625
[INFO 09:38] - learning_users_emb: 6.25e-05
[INFO 09:43] rmse : 0.5503090620040894
[INFO 09:43] valid_rmse : 0.7037200331687927
[INFO 09:43] valid_loss : 0.9420244097709656


 14%|█▍        | 14/100 [04:23<26:51, 18.74s/it]

[INFO 09:57] - meta_params: 0.0625
[INFO 09:57] - cross_cond: 0.0625
[INFO 09:57] - meta_lambda: 0.0625
[INFO 09:57] - learning_users_emb: 6.25e-05
[INFO 10:02] rmse : 0.5545481443405151
[INFO 10:02] valid_rmse : 0.7091408371925354
[INFO 10:02] valid_loss : 0.9388293623924255


 15%|█▌        | 15/100 [04:42<26:35, 18.77s/it]

[INFO 10:16] - meta_params: 0.0625
[INFO 10:16] - cross_cond: 0.0625
[INFO 10:16] - meta_lambda: 0.0625
[INFO 10:16] - learning_users_emb: 6.25e-05
[INFO 10:21] rmse : 0.5448495149612427
[INFO 10:21] valid_rmse : 0.6967384815216064
[INFO 10:21] valid_loss : 0.9384444952011108


 16%|█▌        | 16/100 [05:01<26:22, 18.84s/it]

[INFO 10:35] - meta_params: 0.03125
[INFO 10:35] - cross_cond: 0.03125
[INFO 10:35] - meta_lambda: 0.03125
[INFO 10:35] - learning_users_emb: 3.125e-05
[INFO 10:39] rmse : 0.5437510013580322
[INFO 10:39] valid_rmse : 0.69533371925354
[INFO 10:39] valid_loss : 0.9371609687805176


 17%|█▋        | 17/100 [05:20<26:02, 18.82s/it]

[INFO 10:53] - meta_params: 0.03125
[INFO 10:53] - cross_cond: 0.03125
[INFO 10:53] - meta_lambda: 0.03125
[INFO 10:53] - learning_users_emb: 3.125e-05
[INFO 10:58] rmse : 0.5484952330589294
[INFO 10:58] valid_rmse : 0.7014005184173584
[INFO 10:58] valid_loss : 0.9402368068695068


 18%|█▊        | 18/100 [05:38<25:35, 18.73s/it]

[INFO 11:12] - meta_params: 0.03125
[INFO 11:12] - cross_cond: 0.03125
[INFO 11:12] - meta_lambda: 0.03125
[INFO 11:12] - learning_users_emb: 3.125e-05
[INFO 11:17] rmse : 0.5575053095817566
[INFO 11:17] valid_rmse : 0.7129223942756653
[INFO 11:17] valid_loss : 0.9405050277709961


 19%|█▉        | 19/100 [05:57<25:16, 18.73s/it]

[INFO 11:31] - meta_params: 0.015625
[INFO 11:31] - cross_cond: 0.015625
[INFO 11:31] - meta_lambda: 0.015625
[INFO 11:31] - learning_users_emb: 1.5625e-05
[INFO 11:35] rmse : 0.5499467849731445
[INFO 11:35] valid_rmse : 0.7032567262649536
[INFO 11:35] valid_loss : 0.935683012008667


 20%|██        | 20/100 [06:16<24:57, 18.71s/it]

[INFO 11:50] - meta_params: 0.015625
[INFO 11:50] - cross_cond: 0.015625
[INFO 11:50] - meta_lambda: 0.015625
[INFO 11:50] - learning_users_emb: 1.5625e-05
[INFO 11:54] rmse : 0.5442090034484863
[INFO 11:54] valid_rmse : 0.6959194540977478
[INFO 11:54] valid_loss : 0.9363977909088135


 21%|██        | 21/100 [06:34<24:38, 18.71s/it]

[INFO 12:08] - meta_params: 0.015625
[INFO 12:08] - cross_cond: 0.015625
[INFO 12:08] - meta_lambda: 0.015625
[INFO 12:08] - learning_users_emb: 1.5625e-05
[INFO 12:13] rmse : 0.5479499101638794
[INFO 12:13] valid_rmse : 0.7007032036781311
[INFO 12:13] valid_loss : 0.9382624626159668


 22%|██▏       | 22/100 [06:53<24:13, 18.63s/it]

[INFO 12:27] - meta_params: 0.0078125
[INFO 12:27] - cross_cond: 0.0078125
[INFO 12:27] - meta_lambda: 0.0078125
[INFO 12:27] - learning_users_emb: 7.8125e-06
[INFO 12:31] rmse : 0.5485860705375671
[INFO 12:31] valid_rmse : 0.7015166878700256
[INFO 12:31] valid_loss : 0.9363723397254944


 23%|██▎       | 23/100 [07:12<23:54, 18.62s/it]

[INFO 12:45] - meta_params: 0.0078125
[INFO 12:45] - cross_cond: 0.0078125
[INFO 12:45] - meta_lambda: 0.0078125
[INFO 12:45] - learning_users_emb: 7.8125e-06
[INFO 12:50] rmse : 0.5551767349243164
[INFO 12:50] valid_rmse : 0.7099446654319763
[INFO 12:50] valid_loss : 0.9400185942649841


 24%|██▍       | 24/100 [07:30<23:33, 18.59s/it]

[INFO 13:04] - meta_params: 0.0078125
[INFO 13:04] - cross_cond: 0.0078125
[INFO 13:04] - meta_lambda: 0.0078125
[INFO 13:04] - learning_users_emb: 7.8125e-06
[INFO 13:08] rmse : 0.5521169304847717
[INFO 13:08] valid_rmse : 0.706031858921051
[INFO 13:08] valid_loss : 0.9399572014808655


 25%|██▌       | 25/100 [07:49<23:15, 18.61s/it]

[INFO 13:23] - meta_params: 0.00390625
[INFO 13:23] - cross_cond: 0.00390625
[INFO 13:23] - meta_lambda: 0.00390625
[INFO 13:23] - learning_users_emb: 3.90625e-06
[INFO 13:27] rmse : 0.548313558101654
[INFO 13:27] valid_rmse : 0.7011682391166687
[INFO 13:27] valid_loss : 0.9368614554405212


 26%|██▌       | 26/100 [08:07<22:59, 18.64s/it]

[INFO 13:41] - meta_params: 0.00390625
[INFO 13:41] - cross_cond: 0.00390625
[INFO 13:41] - meta_lambda: 0.00390625
[INFO 13:41] - learning_users_emb: 3.90625e-06
[INFO 13:46] rmse : 0.5437510013580322
[INFO 13:46] valid_rmse : 0.69533371925354
[INFO 13:46] valid_loss : 0.9385684132575989


 27%|██▋       | 27/100 [08:26<22:41, 18.65s/it]

[INFO 14:00] - meta_params: 0.00390625
[INFO 14:00] - cross_cond: 0.00390625
[INFO 14:00] - meta_lambda: 0.00390625
[INFO 14:00] - learning_users_emb: 3.90625e-06
[INFO 14:04] rmse : 0.5365560054779053
[INFO 14:04] valid_rmse : 0.6861329674720764
[INFO 14:04] valid_loss : 0.9339613318443298


 28%|██▊       | 28/100 [08:45<22:23, 18.66s/it]

[INFO 14:18] - meta_params: 0.00390625
[INFO 14:18] - cross_cond: 0.00390625
[INFO 14:18] - meta_lambda: 0.00390625
[INFO 14:18] - learning_users_emb: 3.90625e-06
[INFO 14:23] rmse : 0.5497655272483826
[INFO 14:23] valid_rmse : 0.7030249834060669
[INFO 14:23] valid_loss : 0.9366678595542908


 29%|██▉       | 29/100 [09:03<22:02, 18.63s/it]

[INFO 14:37] - meta_params: 0.00390625
[INFO 14:37] - cross_cond: 0.00390625
[INFO 14:37] - meta_lambda: 0.00390625
[INFO 14:37] - learning_users_emb: 3.90625e-06
[INFO 14:42] rmse : 0.5473130345344543
[INFO 14:42] valid_rmse : 0.6998887658119202
[INFO 14:42] valid_loss : 0.9386559128761292


 29%|██▉       | 29/100 [09:22<22:57, 19.40s/it]

[INFO 14:42] -- END Training --
[INFO 14:42] train on cuda



100%|██████████| 6/6 [00:04<00:00,  1.44it/s]


[INFO 14:53] ({0: {'mi_acc': 0.6947683095932007}, 1: {'mi_acc': 0.6942700147628784}, 2: {'mi_acc': 0.6940707564353943}, 3: {'mi_acc': 0.6961634159088135}, 4: {'mi_acc': 0.6958644390106201}, 5: {'mi_acc': 0.6982560753822327}}, {0: {'meta_doa': 0.5289896884487395}, 1: {'meta_doa': 0.5510916448164791}, 2: {'meta_doa': 0.5684174768159787}, 3: {'meta_doa': 0.5906893303588627}, 4: {'meta_doa': 0.6148654919621281}, 5: {'meta_doa': 0.6317061722410084}})
[INFO 14:53] 0.09524922147326688


In [27]:
# Single-seed run of the Random selection strategy with the 'GAP' meta trainer.
# Relies on `train_data`, `valid_data`, `test_data` and `metadata` already existing in the kernel.
# Alternative setting: config['meta_trainer'] = 'Adam'
config['meta_trainer'] = 'GAP'
for seed in range(1):
    config['seed'] = seed
    logging.info(f'#### seed : {seed} ####')

    # reset_rng() on each split, in train / valid / test order.
    for split_data in (train_data, valid_data, test_data):
        split_data.reset_rng()

    S = liriscat.selectionStrategy.Random(metadata, **config)
    S.init_models(train_data, valid_data)
    S.train(train_data, valid_data)

    # Re-seed globally and reset the strategy's RNG before the test evaluation.
    liriscat.utils.set_seed(0)
    S.reset_rng()
    d = S.evaluate_test(test_data, train_data, valid_data)

    print(liriscat.utils.pareto_index(d))
    print(d)
    torch.cuda.empty_cache()

[INFO 04:26] #### seed : 0 ####
[INFO 04:26] Random_cont_model
[INFO 04:26] compiling CDM model
[INFO 04:26] compiling selection model
[INFO 04:26] ------- START Training
[INFO 04:26] train on cuda


  0%|          | 0/100 [00:00<?, ?it/s]

[DEBUG 08:24] ------- Epoch : 0
[INFO 08:24] ----- User batch : 0
[DEBUG 08:24] --- Query nb : 0
[DEBUG 08:24] - Update users 
[DEBUG 13:26] --- Query nb : 1
[DEBUG 13:26] - Update users 
