# Path setup and package imports

In [1]:
import pandas as pd
%load_ext autoreload
%autoreload 2
import sys
import os
import random

import torch
root_path = '../../../'  # path to project root

# Make the project's code directories importable. The three generic ones are
# appended (searched last); ptranking is inserted at the front so its modules
# win over any same-named module elsewhere on the path.
sys.path.extend(f'{root_path}/{subdir}' for subdir in ('code', 'code/core', 'code/datasets/'))
sys.path.insert(0, f'{root_path}/code/ptranking')

from core.ranking_utils import *
from core.mallows import *
from core.ws_ranking import *
from core.ws_real_workflow import * 
from datasets.imdb_tmdb_dataset import * 
from datasets.basic_clmn_dataset import * 
from core.labelling.feature_lf import *
from ptranking_wrapper import PtrankingWrapper
import datasets_factory 
import numpy as np 
import yaml
import matplotlib.pyplot as plt
import pickle

# Read config & basic setup

In [2]:
# Locate the experiment config relative to the project root and parse it.
config_file_path = f'{root_path}/configs/imdb-tmdb_ranking_experiment_play.yaml'

with open(config_file_path, 'r') as conf_file:
    conf = yaml.full_load(conf_file)
# Record where the project lives so downstream code can resolve relative paths.
conf['project_root'] = root_path

# Pull out the three config sections used by the cells below.
data_conf = conf['data_conf']
# NOTE(review): the original comment on this line was cut off mid-sentence
# ("For partial ranking experiments, we should give") — TODO complete it.
weak_sup_conf = conf['weak_sup_conf']
l2r_training_conf = conf['l2r_training_conf']
data_conf['project_root'] = root_path

In [3]:
# Override the config file: turn off the 'synthetic' weak-supervision setting
# (presumably selects real rather than synthetic labelling functions — TODO confirm).
weak_sup_conf.update(synthetic=False)

# Training and evaluation (mainly via PtrankingWrapper)

In [4]:
# Run the full pipeline (data -> optional weak labels -> L2R training) once per
# seed and pickle each run's training history.
N_SEEDS = 5               # number of independent runs
RESULTS_DIR = 'results'   # per-seed pickles are written here (relative to the notebook's cwd)
os.makedirs(RESULTS_DIR, exist_ok=True)  # open(..., 'wb') below fails if the folder is missing

use_weak_labels = l2r_training_conf['use_weak_labels']

for seed in range(N_SEEDS):
    # Apply the seed so that each "seed_k" run is actually reproducible.
    # NOTE(review): assumes dataset sampling / model init draw from the global
    # random / numpy / torch RNGs — TODO confirm in the project code.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    dataset = datasets_factory.create_dataset(data_conf)
    dataset.create_samples()

    if use_weak_labels:
        # Obtain weak labels and report their mean Kendall-tau distance to the
        # dataset's ground-truth rankings (dataset.Y).
        Y_tilde, thetas = get_weak_labels(dataset, weak_sup_conf, root_path=root_path)
        r_utils = RankingUtils(data_conf['dimension'])
        kt = r_utils.mean_kt_distance(Y_tilde, dataset.Y)
        print('kt distance: ', kt)
        dataset.set_Y_tilde(Y_tilde)
    else:
        kt = None

    ptwrapper = PtrankingWrapper(data_conf=data_conf, weak_sup_conf=weak_sup_conf,
                                 l2r_training_conf=l2r_training_conf, result_path=conf['results_path'],
                                 wl_kt_distance=kt)
    X_train, X_test, Y_train, Y_test = dataset.get_train_test_torch(use_weak_labels=use_weak_labels)
    ptwrapper.set_data(X_train=X_train, X_test=X_test,
                       Y_train=Y_train, Y_test=Y_test)
    model = ptwrapper.get_model()
    result = ptwrapper.train_model(model, verbose=1)

    with open(os.path.join(RESULTS_DIR, f'seed_{seed}.pickle'), 'wb') as f:
        pickle.dump(result, f)

    # `test_tau` behaves like a Kendall-tau *distance* (lower is better): in the
    # logged output, train tau falls as the loss falls. The best epoch is therefore
    # the minimum; `max(...)` would report the worst epoch.
    print(seed, min(result['test_tau']))

Generate samples...
Weak labels generated and saved in ../../../data/imdb-tmdb/processed/210513_dim-10_ntrain-500_ntest-1000_model-ListMLE_weaklabel-False/LFs/weak_labels.pkl
Use our weak supervision...train_method: triplet_opt,inference_rule: weighted_kemeny
kt distance:  0.3519333333333333
use_weak_labels:True, we will use weak labels
Training data shape, X_train.shape torch.Size([5000, 5, 255]) Y_train.shape torch.Size([5000, 5])
set_and_load_data in LTREvaluator
(5000, 5, 255) (5000, 5) (5000,)
data_dict {'data_id': 'imdb_tmdb', 'dir_data': 'data/imdb-tmdb/processed/210513_dim-10_ntrain-500_ntest-1000_model-ListMLE_weaklabel-False', 'min_docs': 10, 'min_rele': 1, 'scale_data': False, 'scaler_id': None, 'scaler_level': None, 'train_presort': True, 'validation_presort': True, 'test_presort': True, 'train_batch_size': 64, 'validation_batch_size': 1, 'test_batch_size': 1, 'unknown_as_zero': False, 'binary_rele': False, 'num_features': 255, 'has_comment': False, 'label_type': <LABEL_TYP

epoch 0, loss [1585.7478], train tau 0.050112903118133545, test_tau 0.3540000021457672,train_ndcg@1 tensor([0.9636]), test_ndcg@1 tensor([0.6457])
epoch 1, loss [1229.1354], train tau 0.045732468366622925, test_tau 0.3531000018119812,train_ndcg@1 tensor([0.9664]), test_ndcg@1 tensor([0.6430])
epoch 2, loss [1150.5269], train tau 0.04193222522735596, test_tau 0.3525000214576721,train_ndcg@1 tensor([0.9704]), test_ndcg@1 tensor([0.6435])
epoch 3, loss [1097.8229], train tau 0.039971619844436646, test_tau 0.3529999852180481,train_ndcg@1 tensor([0.9715]), test_ndcg@1 tensor([0.6472])


KeyboardInterrupt: 