# TREC-6: TEXT Classification + BERT + Ax

## Libraries

In [None]:
# !pip install transformers==4.8.2
# !pip install datasets==1.7.0
# !pip install ax-platform==0.1.20
# !pip install ipywidgets
# !jupyter nbextension enable --py widgetsnbextension

In [16]:
import os
import sys

In [17]:
import io
import re
import pickle
from timeit import default_timer as timer

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer
from transformers import BertModel
from transformers.data.data_collator import DataCollatorWithPadding

from ax import optimize
from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting

import esntorch.core.reservoir as res
import esntorch.core.learning_algo as la
import esntorch.core.pooling_strategy as ps
import esntorch.core.esn as esn

In [18]:
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
# Run on the GPU when PyTorch can see a CUDA device, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [20]:
SEED = 42

## Global variables

In [15]:
# '~' is expanded up front: open() does not expand it, so the raw string
# would point at a literal directory named '~'.
RESULTS_PATH = os.path.expanduser('~/Results/Ax_results/ESN')  # path of your result folder
CACHE_DIR = os.path.expanduser('~/Data/huggignface/')          # path of your dataset cache folder

PARAMS_FILE = 'trec-6_params.pkl'    # pickle holding the best hyper-parameters
RESULTS_FILE = 'trec-6_results.pkl'  # pickle holding the final results

## Dataset

In [12]:
# rename correct column as 'labels': depends on the dataset you load

def load_and_enrich_dataset(dataset_name, split, cache_dir):
    """Load one dataset split, tokenize it and add a per-sample length column.

    The 'label-coarse' column is renamed to 'labels', every 'text' entry is
    tokenized (truncated, not padded) with the module-level ``tokenizer``,
    and a 'lengths' column holding the number of non-zero token ids is added.

    Args:
        dataset_name: Name forwarded to ``load_dataset``.
        split: Split forwarded to ``load_dataset`` (e.g. 'train', 'test').
        cache_dir: Cache directory forwarded to ``load_dataset``.

    Returns:
        The enriched dataset, with 'input_ids', 'attention_mask' and
        'labels' exposed as torch tensors.
    """
    # Use the caller-supplied cache directory rather than the global constant.
    dataset = load_dataset(dataset_name, split=split, cache_dir=cache_dir)

    dataset = dataset.rename_column('label-coarse', 'labels')  # or 'label-fine'
    dataset = dataset.map(
        lambda e: tokenizer(e['text'], truncation=True, padding=False),
        batched=True,
    )
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    def add_lengths(sample):
        # Count token ids different from 0 (presumably the padding id).
        sample["lengths"] = (sample["input_ids"] != 0).sum()
        return sample

    dataset = dataset.map(add_lengths, batched=False)

    return dataset

In [None]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Every split is sorted by sequence length, so each batch groups samples of
# similar length (presumably to limit padding inside a batch).
full_train_dataset = load_and_enrich_dataset('trec', split='train', cache_dir=CACHE_DIR).sort("lengths")
# Seed the shuffle so the 80/20 train/validation split is reproducible.
train_val_datasets = full_train_dataset.train_test_split(train_size=0.8, shuffle=True, seed=SEED)
train_dataset = train_val_datasets['train'].sort("lengths")
val_dataset = train_val_datasets['test'].sort("lengths")

test_dataset = load_and_enrich_dataset('trec', split='test', cache_dir=CACHE_DIR).sort("lengths")

dataset_d = {
    'full_train': full_train_dataset,
    'train': train_dataset,
    'val': val_dataset,
    'test': test_dataset,
}

# One DataLoader per split; samples are padded batch by batch by the collator.
dataloader_d = {
    name: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=256,
        collate_fn=DataCollatorWithPadding(tokenizer),
    )
    for name, split_dataset in dataset_d.items()
}

In [14]:
dataset_d

{'full_train': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 5452
 }),
 'train': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 4361
 }),
 'val': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 1091
 }),
 'test': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 500
 })}

## Optimization

In [11]:
def fitness(leaking_rate,
            spectral_radius,
            input_scaling,
            bias_scaling,
            alpha,
            reservoir_dim,
            dataset_d,
            dataloader_d,
            seed_l=(1991, 420, 666, 1979, 7),  # 5 seeds
            return_test_acc=False):
    """Train and evaluate an echo state network once per seed.

    For every seed in ``seed_l`` a fresh ``esn.EchoStateNetwork`` is built
    with a ``la.RidgeRegression`` learning algorithm, warmed up on the first
    three sentences of ``dataset_d["train"]``, fitted and evaluated.

    Args:
        leaking_rate: Value of the ESN 'leaking_rate' parameter.
        spectral_radius: Value of the ESN 'spectral_radius' parameter.
        input_scaling: Value of the ESN 'input_scaling' parameter.
        bias_scaling: Value of the ESN 'bias_scaling' parameter.
        alpha: ``alpha`` argument of ``la.RidgeRegression``.
        reservoir_dim: Value of the ESN 'reservoir_dim' parameter.
        dataset_d: Mapping of split name to dataset; 'train' is used for
            the warm-up.
        dataloader_d: Mapping of split name to DataLoader; uses
            'train'/'val', or 'full_train'/'test' when
            ``return_test_acc`` is true.
        seed_l: Seeds to run; one model is trained per seed.
        return_test_acc: If false, fit on 'train' and score on 'val'.
            If true, fit on 'full_train', time the fit and score on 'test'.

    Returns:
        The mean accuracy over seeds when ``return_test_acc`` is false;
        otherwise the tuple ``(mean accuracy, accuracy std, mean fit time,
        fit time std)``, with times measured by ``timer`` around ``fit``.
    """
    acc_l = []
    time_l = []

    for seed in seed_l:

        # parameters
        esn_params = {
            'embedding': 'bert-base-uncased',
            'distribution': 'uniform',          # uniform, gaussian
            'input_dim': 768,                   # dim of encoding!
            'reservoir_dim': reservoir_dim,
            'bias_scaling': bias_scaling,
            'sparsity': 0.99,
            'spectral_radius': spectral_radius,
            'leaking_rate': leaking_rate,
            'activation_function': 'tanh',
            'input_scaling': input_scaling,
            'mean': 0.0,
            'std': 1.0,
            'learning_algo': None,              # set right after construction
            'criterion': None,
            'optimizer': None,
            'pooling_strategy': 'mean',
            'lexicon': None,
            'bidirectional': False,
            'device': device,
            'seed': seed,
        }

        # model
        ESN = esn.EchoStateNetwork(**esn_params)
        ESN.learning_algo = la.RidgeRegression(alpha=alpha)
        ESN = ESN.to(device)

        # warm up on the first sentences of the training set, one at a time
        nb_sentences = 3
        for i in range(nb_sentences):
            single_sentence = dataset_d["train"].select([i])
            warm_up_loader = torch.utils.data.DataLoader(
                single_sentence,
                batch_size=1,
                collate_fn=DataCollatorWithPadding(tokenizer),
            )
            for batch in warm_up_loader:
                ESN.warm_up(batch)

        # fit, then read the accuracy (second element of the predict output)
        if return_test_acc:
            t0 = timer()
            ESN.fit(dataloader_d["full_train"])
            time_l.append(timer() - t0)
            acc = ESN.predict(dataloader_d["test"], verbose=False)[1].item()
        else:
            ESN.fit(dataloader_d["train"])
            acc = ESN.predict(dataloader_d["val"], verbose=False)[1].item()

        acc_l.append(acc)

        # clean objects before the next seed to release memory
        del ESN.learning_algo
        del ESN.criterion
        del ESN.pooling_strategy
        del ESN
        torch.cuda.empty_cache()

    if return_test_acc:
        return np.mean(acc_l), np.std(acc_l), np.mean(time_l), np.std(time_l)
    return np.mean(acc_l)

In [12]:
# %%time

# fitness(leaking_rate=0.2, spectral_radius=1.1, input_scaling=0.8, bias_scaling=1.0, alpha=10, reservoir_dim=500, dataset_d=dataset_d, dataloader_d=dataloader_d)

In [13]:
def wrapped_fitness(d, return_test_acc=False):
    """Call ``fitness`` with hyper-parameters taken from the dict ``d``.

    ``d`` must provide 'leaking_rate', 'spectral_radius', 'input_scaling',
    'bias_scaling', 'alpha' and 'reservoir_dim'. The datasets and
    dataloaders come from the module-level ``dataset_d`` / ``dataloader_d``.
    Returns whatever ``fitness`` returns for the given ``return_test_acc``.
    """
    hyper_param_names = (
        'leaking_rate',
        'spectral_radius',
        'input_scaling',
        'bias_scaling',
        'alpha',
        'reservoir_dim',  # fixed per run by the optimization loop
    )
    hyper_params = {name: d[name] for name in hyper_param_names}

    return fitness(dataset_d=dataset_d,
                   dataloader_d=dataloader_d,
                   return_test_acc=return_test_acc,
                   **hyper_params)

In [14]:
# *** WARNING *** DO NOT EXECUTE THE NEXT CELLS IN BIDIRECTIONAL MODE (OPTIMIZATION ALREADY DONE)

In [None]:
# Run one Ax optimization per reservoir size and keep the best setting of each.
best_params_d = {}

for res_dim in (500, 1000, 3000, 5000):

    best_parameters, best_values, experiment, model = optimize(
        parameters=[
            {"name": "leaking_rate", "type": "range", "value_type": "float",
             "bounds": [0.0, 0.999]},
            {"name": "spectral_radius", "type": "range", "value_type": "float",
             "bounds": [0.2, 1.7]},
            {"name": "input_scaling", "type": "range", "value_type": "float",
             "bounds": [0.1, 3.0]},
            {"name": "bias_scaling", "type": "range", "value_type": "float",
             "bounds": [0.1, 3.0]},
            # ridge penalty, searched on a log scale
            {"name": "alpha", "type": "range", "value_type": "float",
             "bounds": [1e-3, 1e3], "log_scale": True},
            # the reservoir size is held fixed within one optimization run
            {"name": "reservoir_dim", "type": "fixed", "value_type": "int",
             "value": res_dim},
        ],
        evaluation_function=wrapped_fitness,
        objective_name='val_accuracy',
        minimize=False,  # the objective is maximized
        total_trials=40,
    )

    # results (the fitted Ax model is deliberately not stored)
    best_params_d[res_dim] = {
        'best_parameters': best_parameters,
        'best_values': best_values,
        'experiment': experiment,
    }

In [None]:
best_params_d

## Results

In [None]:
# best parameters

# open() neither expands '~' nor creates missing folders, so do both first.
os.makedirs(os.path.expanduser(RESULTS_PATH), exist_ok=True)
with open(os.path.join(os.path.expanduser(RESULTS_PATH), PARAMS_FILE), 'wb') as fh:
    pickle.dump(best_params_d, fh)

In [None]:
# # load results
# with open(os.path.join(RESULTS_PATH, PARAMS_FILE), 'rb') as fh:
#     best_params_d = pickle.load(fh)

In [None]:
best_params_d

In [None]:
# results

# Re-run the best configuration of each reservoir size on the test set.
results_d = {}

for res_dim in (500, 1000, 3000, 5000):

    tuned_parameters = best_params_d[res_dim]['best_parameters']
    # (mean accuracy, accuracy std, mean fit time, fit time std)
    mean_acc, std_acc, mean_time, std_time = wrapped_fitness(tuned_parameters, return_test_acc=True)
    results_d[res_dim] = (mean_acc, std_acc, mean_time, std_time)
    print("Experiment finished.")

In [None]:
results_d

In [None]:
# open() neither expands '~' nor creates missing folders, so do both first.
os.makedirs(os.path.expanduser(RESULTS_PATH), exist_ok=True)
with open(os.path.join(os.path.expanduser(RESULTS_PATH), RESULTS_FILE), 'wb') as fh:
    pickle.dump(results_d, fh)

In [32]:
# load results
# '~' is expanded explicitly because open() does not do it.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open(os.path.join(os.path.expanduser(RESULTS_PATH), 'trec-6_results_2.pkl'), 'rb') as fh:
    results_d = pickle.load(fh)

In [33]:
results_d

{500: (90.44000244140625,
  0.44542175195073835,
  9.471520644892006,
  0.10720533827426228),
 1000: (92.36000366210938,
  0.7525970822215208,
  9.629882322065532,
  0.1428496800140904),
 3000: (94.6800048828125,
  0.515362598957664,
  10.678268241602927,
  0.14556552515205334),
 5000: (95.28000335693359,
  0.44899883203884566,
  16.242420702800153,
  0.3827186644692675)}