# Tutorial

This notebook provides a use case example of the ``EsnTorch`` library.
It describes the implementation of the so-called Custom Baseline (CBS)
for text classification on the IMDB dataset.

The instantiation, training and evaluation of the CBS for text classification
is achieved via the following steps:
- Import the required modules
- Create the dataloaders
- Instantiate the CBS by specifying:
    - a reservoir (not recurrent)
    - a loss function
    - a learning algorithm
- Train the CBS
- Training and testing results

## Libraries

In [None]:
# !pip install transformers==4.8.2
# !pip install datasets==1.7.0

In [None]:
# Comment this cell out if the library is installed.
# It puts the parent directory first on the import path, so that modules
# located there (presumably the local `esntorch` checkout) can be imported.
import os
import sys
sys.path.insert(0, os.path.abspath(os.pardir))

In [None]:
# import numpy as np
from sklearn.metrics import classification_report

import torch

from datasets import load_dataset, Dataset, concatenate_datasets

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

import esntorch.core.reservoir as res
import esntorch.core.learning_algo as la
import esntorch.core.merging_strategy as ms
import esntorch.core.esn as esn
import esntorch.core.baseline as bs

## Device and Seed

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

## Load and Tokenize Data

In [None]:
# Custom functions for loading and preparing data

def tokenize(sample):
    """
    Tokenize the 'text' field of a sample with the notebook-level `tokenizer`.

    Truncation is enabled, no padding is applied at this stage, and
    `return_length=True` asks the tokenizer to also report the length of
    each encoded text.

    Returns the tokenizer output unchanged.
    """

    return tokenizer(sample['text'], truncation=True, padding=False, return_length=True)
    
def load_and_prepare_dataset(dataset_name, split, cache_dir):
    """
    Load a dataset from the HuggingFace `datasets` library, tokenize it and
    add the length of each tokenized text.

    Parameters
    ----------
    dataset_name : str
        Name forwarded to `load_dataset` (e.g. 'imdb').
    split : str or None
        Split forwarded to `load_dataset`. This notebook passes None and
        then indexes the result with 'train' and 'test'.
    cache_dir : str
        Cache directory forwarded to `load_dataset`.

    Returns
    -------
    The object returned by `load_dataset`, with 'label' renamed to 'labels',
    the outputs of `tokenize` added, 'length' renamed to 'lengths', and the
    format set to torch for the columns
    'input_ids', 'attention_mask', 'labels' and 'lengths'.
    """

    # Load dataset.
    # Use the `cache_dir` argument rather than the notebook-level CACHE_DIR
    # global, so that the value passed by the caller is actually honored and
    # the function does not depend on cell execution order.
    dataset = load_dataset(dataset_name, split=split, cache_dir=cache_dir)

    # Rename label column (for tokenization purposes)
    dataset = dataset.rename_column('label', 'labels')

    # Tokenize data (in batches); `tokenize` also produces a 'length' column
    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.rename_column('length', 'lengths')

    # Expose only the columns needed downstream, as torch tensors
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lengths'])

    return dataset

In [None]:
# Directory used as cache by the `datasets` library
CACHE_DIR = 'cache_dir/' # put your path here

# Load BERT tokenizer (`tokenize` reads it through the global name `tokenizer`)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load and prepare all splits at once (split=None), then sort each split
# by the length of its tokenized texts
full_dataset = load_and_prepare_dataset('imdb', split=None, cache_dir=CACHE_DIR)
train_dataset, test_dataset = (
    full_dataset[split_name].sort("lengths") for split_name in ('train', 'test')
)

# Create dict of all datasets
dataset_d = dict(train=train_dataset, test=test_dataset)

In [None]:
dataset_d

In [None]:
# Create dict of dataloaders: one per split, under the same keys as `dataset_d`.
# Each batch is padded by its own DataCollatorWithPadding instance.
# The dataloaders are not shuffled, so batches follow the dataset order.
dataloader_d = {
    split_name: torch.utils.data.DataLoader(
        split_dataset,
        batch_size=256,
        collate_fn=DataCollatorWithPadding(tokenizer),
    )
    for split_name, split_dataset in dataset_d.items()
}

In [None]:
dataloader_d

## Model

In [None]:
# CBS parameters.
# 'learning_algo' is not passed here: it is assigned right after the
# instantiation below. 'criterion' and 'optimizer' are not set in this cell.
cbs_params = dict(
    embedding_weights='bert-base-uncased',  # TEXT.vocab.vectors,
    input_dim=768,                          # dim of BERT encoding!
    reservoir_dim=1000,
    bias_scaling=1.0,
    input_scaling=1.0,
    activation_function='relu',             # alternative: 'tanh'
    merging_strategy='mean',
    bidirectional=False,                    # alternative: True
    device=device,
    seed=42,
)

# Instantiate the CBS
CBS = bs.CustomBaseline(**cbs_params)

# Define the learning algo of the CBS
CBS.learning_algo = la.RidgeRegression(alpha=10)

# Put the CBS on the device (CPU or GPU)
CBS = CBS.to(device)

## Training

In [None]:
%%time
# Train the CBS on the training dataloader
# (its learning algorithm is the one assigned to `CBS.learning_algo`)
CBS.fit(dataloader_d["train"])

## Results

In [None]:
# Train predictions and accuracy.
# `predict` is unpacked into (predictions, accuracy); `.item()` extracts the
# accuracy as a plain number for display (presumably from a 0-dim tensor).
train_pred, train_acc = CBS.predict(dataloader_d["train"], verbose=False)
train_acc.item()

In [None]:
# Test predictions and accuracy.
# `predict` is unpacked into (predictions, accuracy); `.item()` extracts the
# accuracy as a plain number for display (presumably from a 0-dim tensor).
test_pred, test_acc = CBS.predict(dataloader_d["test"], verbose=False)
test_acc.item()

In [None]:
# Test classification report.
# scikit-learn expects classification_report(y_true, y_pred): ground-truth
# labels first, predictions second. Passing them the other way round swaps
# precision and recall and reports the support of the predictions instead of
# the true classes, so both are passed by keyword here.
# The pairing relies on the test dataloader not being shuffled, so that
# `test_pred` follows the order of `dataset_d['test']`.
print(classification_report(y_true=dataset_d['test']['labels'].tolist(),
                            y_pred=test_pred.tolist(),
                            digits=4))