# Tutorial

This notebook provides a use case example of the ``EsnTorch`` library.
It described the implementation of an Echo State Network (ESN)
for text classification on the IMDB dataset.

The instantiation, training and evaluation of an ESN for text classification
is achieved via the following steps:
- Import the required modules
- Create the dataloaders
- Instantiate the ESN by specifying:
    - a reservoir
    - a loss function
    - a learning algorithm
- Train the ESN
- Training and testing results

## Librairies

In [1]:
# !pip3 install datasets==1.7.0

In [2]:
# Comment this if library is installed
import os
import sys
sys.path.insert(0, os.path.abspath(".."))

In [3]:
# import numpy as np
from sklearn.metrics import classification_report

import torch

from datasets import load_dataset, Dataset, concatenate_datasets

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

import esntorch.core.reservoir as res
import esntorch.core.learning_algo as la
import esntorch.core.merging_strategy as ms
import esntorch.core.esn as esn

In [4]:
%load_ext autoreload
%autoreload 2

In [6]:
import transformers
transformers.__version__

'4.15.0'

In [7]:
import datasets
datasets.__version__

'1.17.0'

In [12]:
import ax
ax.__version__

'0.1.20'

## Device and Seed

In [5]:
SEED = 42

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Load and Tokenize Data

In [7]:
# Custom functions for loading and preparing data

def tokenize(sample):
    """Tokenize sample"""
    
    sample = tokenizer(sample['text'], truncation=True, padding=False, return_length=True)
    
    return sample
    
def load_and_enrich_dataset(dataset_name, split, cache_dir):
    """
    Load dataset from the datasets library of HuggingFace.
    Tokenize and add length.
    """
    
    # Load dataset
    dataset = load_dataset(dataset_name, split=split, cache_dir=CACHE_DIR)
    
    # Rename label column for tokenization purposes
    dataset = dataset.rename_column('label', 'labels')
    
    # Tokenize data
    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.rename_column('length', 'lengths')
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lengths'])
    
    return dataset

In [8]:
# tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# CACHE_DIR = 'cache_dir/' # put your path here
CACHE_DIR = '/raid/home/jeremiec/huggingface_datasets' # XXX REMOVE THIS FOR DEPLOYMENT

# train set
full_train_dataset = load_and_enrich_dataset('imdb', split='train', cache_dir=CACHE_DIR).sort("lengths")

# mini train and val sets: 70%/30% of 30%
train_val_datasets = full_train_dataset.train_test_split(train_size=0.3, shuffle=True)
train_val_datasets = train_val_datasets['train'].train_test_split(train_size=0.7, shuffle=True)
train_dataset = train_val_datasets['train'].sort("lengths")
val_dataset = train_val_datasets['test'].sort("lengths")

# test set
test_dataset = load_and_enrich_dataset('imdb', split='test', cache_dir=CACHE_DIR).sort("lengths")

dataset_d = {
    'full_train': full_train_dataset,
    'train': train_dataset,
    'val': val_dataset,
    'test': test_dataset
    }

# dataloaders
dataloader_d = {}
for k, v in dataset_d.items():
    dataloader_d[k] = torch.utils.data.DataLoader(v, batch_size=256, collate_fn=DataCollatorWithPadding(tokenizer))

Reusing dataset imdb (/raid/home/jeremiec/huggingface_datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
Loading cached processed dataset at /raid/home/jeremiec/huggingface_datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-ed0791a7ab7b5e93.arrow
Loading cached sorted indices for dataset at /raid/home/jeremiec/huggingface_datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-564008796428c775.arrow
Loading cached split indices for dataset at /raid/home/jeremiec/huggingface_datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-70d09293f42b8520.arrow and /raid/home/jeremiec/huggingface_datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-b4af90adcac04d03.arrow
Reusing dataset imdb (/raid/home/jeremiec/huggingface_datasets/imdb/plain_text/1.0.0/2fdd8b9

In [9]:
dataset_d

{'full_train': Dataset({
     features: ['attention_mask', 'input_ids', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 25000
 }),
 'train': Dataset({
     features: ['attention_mask', 'input_ids', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 5250
 }),
 'val': Dataset({
     features: ['attention_mask', 'input_ids', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 2250
 }),
 'test': Dataset({
     features: ['attention_mask', 'input_ids', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 25000
 })}

In [10]:
# Create dict of dataloaders

dataloader_d = {}

for k, v in dataset_d.items():
    dataloader_d[k] = torch.utils.data.DataLoader(v, batch_size=256, collate_fn=DataCollatorWithPadding(tokenizer))

In [11]:
dataloader_d

{'full_train': <torch.utils.data.dataloader.DataLoader at 0x7efe8e0b9dc0>,
 'train': <torch.utils.data.dataloader.DataLoader at 0x7efe8c116c10>,
 'val': <torch.utils.data.dataloader.DataLoader at 0x7efe8e183820>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7efe8e183100>}

## Model

In [12]:
# ESN parameters
esn_params = {
            'embedding_weights': 'bert-base-uncased', # TEXT.vocab.vectors,
            'distribution' : 'gaussian',              # uniform, gaussian
            'input_dim' : 768,                        # dim of BERT encoding!
            'reservoir_dim' : 500,
            'input_scaling' : 0.1,
            'bias_scaling' : 0.0, # can be None
            'sparsity' : 0.5,
            'spectral_radius' : 1.5,
            'leaking_rate': 0.5,
            'activation_function' : 'tanh',           # 'tanh', relu'
            'mean' : 0.0,
            'std' : 1.0,
            #'learning_algo' : None, # initialzed below
            #'criterion' : None,     # initialzed below
            #'optimizer' : None,     # initialzed below
            'merging_strategy' : 'mean',
            'bidirectional' : False, # True
            'mode' : 'esn', 
            'device' : device,
            'seed' : 888
             }

# Instantiate the ESN
ESN = esn.EchoStateNetwork(**esn_params)

# Define the learning algo of the ESN
ESN.learning_algo = la.RidgeRegression(alpha=0.1)

# Put the ESN on the device (CPU or GPU)
ESN = ESN.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model downloaded: bert-base-uncased


In [13]:
# Warm up the ESN on 3 sentences
nb_sentences = 10

for i in range(nb_sentences): 
    sentence = dataset_d["train"].select([i])
    dataloader_tmp = torch.utils.data.DataLoader(sentence, 
                                                 batch_size=1, 
                                                 collate_fn=DataCollatorWithPadding(tokenizer))  

    for sentence in dataloader_tmp:
        ESN.warm_up(sentence)

## Training

In [14]:
%%time
# training the ESN
ESN.fit(dataloader_d["train"])

CPU times: user 22.4 s, sys: 8.13 s, total: 30.5 s
Wall time: 30.7 s


In [15]:
# inspect X_, y_, etc.
X_ = ESN.learning_algo.X_
y_ = ESN.learning_algo.y_

In [16]:
X_

tensor([[-0.5354,  0.1846,  0.2846,  ..., -0.8567,  0.0376,  1.0000],
        [-0.5571, -0.7364,  0.0747,  ..., -0.9730,  0.2430,  1.0000],
        [-0.2634, -0.2606, -0.2487,  ..., -0.6620,  0.0821,  1.0000],
        ...,
        [ 0.1475, -0.2627,  0.0511,  ..., -0.6474, -0.3583,  1.0000],
        [-0.1446, -0.1679,  0.0573,  ..., -0.7840,  0.2140,  1.0000],
        [ 0.1112, -0.0971, -0.0931,  ..., -0.5763,  0.1195,  1.0000]],
       device='cuda:0')

In [17]:
LI = torch.eye(X_.size()[1], device=device) * ESN.learning_algo.alpha

In [18]:
Xt = torch.transpose(X_, 0, 1)

In [19]:
beta = torch.mm(Xt, X_) + LI

In [20]:
beta

tensor([[ 2.3624e+02,  1.6151e+02, -4.2428e+01,  ...,  5.2709e+02,
          1.9931e+00, -7.2348e+02],
        [ 1.6151e+02,  3.5610e+02, -5.1106e+01,  ...,  7.2167e+02,
         -5.5462e+00, -1.0334e+03],
        [-4.2428e+01, -5.1106e+01,  1.9947e+02,  ..., -2.1363e+01,
          4.3244e+01,  3.4506e+01],
        ...,
        [ 5.2709e+02,  7.2167e+02, -2.1363e+01,  ...,  2.6701e+03,
          2.7944e+01, -3.7022e+03],
        [ 1.9931e+00, -5.5462e+00,  4.3244e+01,  ...,  2.7944e+01,
          1.5077e+02, -5.9600e+01],
        [-7.2348e+02, -1.0334e+03,  3.4506e+01,  ..., -3.7022e+03,
         -5.9600e+01,  5.2501e+03]], device='cuda:0')

In [21]:
beta = torch.pinverse(beta)

## Results

In [22]:
# Train predictions and accuracy
train_pred, train_acc = ESN.predict(dataloader_d["train"], verbose=False)
train_acc.item()

89.71428680419922

In [23]:
# Test predictions and accuracy
test_pred, test_acc = ESN.predict(dataloader_d["val"], verbose=False)
test_acc.item()

87.68888854980469

In [24]:
# Test classification report
print(classification_report(test_pred.tolist(), 
                            dataset_d['val']['labels'].tolist(), 
                            digits=4))

              precision    recall  f1-score   support

           0     0.8712    0.8843    0.8777      1124
           1     0.8828    0.8694    0.8761      1126

    accuracy                         0.8769      2250
   macro avg     0.8770    0.8769    0.8769      2250
weighted avg     0.8770    0.8769    0.8769      2250

