# Tutorial

This notebook walks through an example use of the ``EsnTorch`` library.
It describes the implementation of an **Echo State Network (ESN)**
for text classification on the **TREC-6** dataset.

The instantiation, training and evaluation of an ESN for text classification
are achieved via the following steps:
- Import the required modules
- Create the dataloaders
- Instantiate the ESN by specifying:
    - a reservoir
    - a loss function
    - a learning algorithm
- Train the ESN
- Evaluate the ESN on the training and test sets

## Libraries

In [9]:
# !pip install transformers==4.8.2
# !pip install datasets==1.7.0

In [10]:
# Comment this out if the library is already installed
import os
import sys
sys.path.insert(0, os.path.abspath(".."))
# sys.path.insert(0, os.path.abspath("../.."))

In [12]:
# import numpy as np
from sklearn.metrics import classification_report

import torch

from datasets import load_dataset, Dataset, concatenate_datasets

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

import esntorch.core.reservoir as res
import esntorch.core.learning_algo as la
import esntorch.core.pooling_strategy as ps
import esntorch.core.esn as esn

In [13]:
%load_ext autoreload
%autoreload 2



## Device and Seed

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Load and Tokenize Data

In [15]:
# Custom functions for loading and preparing data

def tokenize(sample):
    """Tokenize sample"""
    
    sample = tokenizer(sample['text'], truncation=True, padding=False, return_length=True)
    
    return sample
    
def load_and_prepare_dataset(dataset_name, split, cache_dir):
    """
    Load dataset from the datasets library of HuggingFace.
    Tokenize and add length.
    """
    
    # Load dataset
    dataset = load_dataset(dataset_name, split=split, cache_dir=cache_dir)
    
    # Rename label column for tokenization purposes (use 'label-fine' for fine-grained labels)
    dataset = dataset.rename_column('label-coarse', 'labels')
    
    # Tokenize data
    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.rename_column('length', 'lengths')
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lengths'])
    
    return dataset

In [16]:
# Load BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load and prepare data
CACHE_DIR = '.' # put your path here

full_dataset = load_and_prepare_dataset('trec', split=None, cache_dir=CACHE_DIR)
train_dataset = full_dataset['train'].sort("lengths")
test_dataset = full_dataset['test'].sort("lengths")

# Create dict of all datasets
dataset_d = {
    'train': train_dataset,
    'test': test_dataset
    }

Using custom data configuration default


Downloading and preparing dataset trec/default (download: 350.79 KiB, generated: 403.39 KiB, post-processed: Unknown size, total: 754.18 KiB) to ./trec/default/1.1.0/751da1ab101b8d297a3d6e9c79ee9b0173ff94c4497b75677b59b61d5467a9b9...



Dataset trec downloaded and prepared to ./trec/default/1.1.0/751da1ab101b8d297a3d6e9c79ee9b0173ff94c4497b75677b59b61d5467a9b9. Subsequent calls will reuse this data.






In [17]:
dataset_d

{'train': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 5452
 }),
 'test': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 500
 })}

In [43]:
# Create dict of dataloaders

dataloader_d = {}

for k, v in dataset_d.items():
    dataloader_d[k] = torch.utils.data.DataLoader(v, batch_size=256, collate_fn=DataCollatorWithPadding(tokenizer))

In [44]:
dataloader_d

{'train': <torch.utils.data.dataloader.DataLoader at 0x7fbab20842e0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7fbc59932910>}
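
The datasets above are sorted by length before being batched, and ``DataCollatorWithPadding`` pads each batch only up to its own longest sequence. The sketch below uses plain Python lists and a hypothetical ``pad_batch`` helper (not part of any library used here) to illustrate why this combination reduces the amount of padding. It is a toy illustration, not the collator's actual implementation.

```python
def pad_batch(sequences, pad_id=0):
    """Pad every sequence in the batch to the length of the longest one."""
    max_len = max(len(s) for s in sequences)
    return [s + [pad_id] * (max_len - len(s)) for s in sequences]

def count_padding(sequences, batch_size):
    """Total number of pad tokens added when batching sequences in the given order."""
    total = 0
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]
        padded = pad_batch(batch)
        total += sum(len(p) - len(s) for p, s in zip(padded, batch))
    return total

# Toy "tokenized sentences" of various lengths (token ids are all 1)
lengths = [3, 12, 4, 11, 5, 10, 2, 9]
sequences = [[1] * n for n in lengths]

unsorted_pad = count_padding(sequences, batch_size=4)
sorted_pad = count_padding(sorted(sequences, key=len), batch_size=4)
print(unsorted_pad, sorted_pad)  # 32 12
```

With sorted inputs, sentences of similar length end up in the same batch, so fewer padded positions have to be processed.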

## Model

In [45]:
# ESN parameters
esn_params = {
            'embedding': 'bert-base-uncased', # TEXT.vocab.vectors,
            'distribution' : 'uniform',               # uniform, gaussian
            'input_dim' : 768,                        # dim of BERT encoding!
            'reservoir_dim' : 1000,
            'bias_scaling' : 0., #1.0742377381236705, # 1.0,
            'sparsity' : 0.,
            'spectral_radius' : 0.7094538192983408, # 0.9,
            'leaking_rate': 0.17647315261153904, # 0.5,
            'activation_function' : 'tanh',
            'input_scaling' : 0.1, # 1.0,
            'mean' : 0.0,
            'std' : 1.0,
            #'learning_algo' : None,     # initialized below
            #'criterion' : None,         # initialized below
            #'optimizer' : None,         # initialized below
            'pooling_strategy' : 'mean',
            'bidirectional' : False,     # True
            'device' : device,
            'mode' : 'esn',              # 'no_layer', 'linear_layer'
            'seed' : 42345
             }

# Instantiate the ESN
ESN = esn.EchoStateNetwork(**esn_params)

# Define the learning algo of the ESN
# Ridge Regression
ESN.learning_algo = la.RidgeRegression(alpha=7.843536845714804)

# Logistic Regression (uncomment below)
# if esn_params['mode'] == 'no_layer':
#     input_dim = esn_params['input_dim']
# else:
#     input_dim = esn_params['reservoir_dim']
    
# ESN.learning_algo = la.LogisticRegression(input_dim=input_dim, output_dim=6)
# ESN.criterion = torch.nn.CrossEntropyLoss()                                  # loss
# ESN.optimizer = torch.optim.Adam(ESN.learning_algo.parameters(), lr=0.01)    # optimizer

# Put the ESN on the device (CPU or GPU)
ESN = ESN.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model downloaded: bert-base-uncased
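
To give an intuition for parameters such as ``leaking_rate``, ``input_scaling``, ``activation_function`` and ``pooling_strategy``, here is a minimal pure-Python sketch of a leaky-integrator reservoir followed by mean pooling. It follows the textbook ESN update and is only an assumption about what ``EsnTorch`` does internally. The helper names (``make_matrix``, ``reservoir_step``, ``mean_pool``) are hypothetical, the sizes are toy values, and the random vectors stand in for BERT embeddings.

```python
import math
import random

def make_matrix(rows, cols, scale, rng):
    """Dense matrix with entries drawn uniformly from [-scale, scale]."""
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]

def reservoir_step(state, u, w_in, w_res, leaking_rate):
    """One leaky-integrator update: x <- (1 - a) * x + a * tanh(W_in u + W x)."""
    new_state = []
    for i in range(len(state)):
        pre = sum(w * x for w, x in zip(w_in[i], u))
        pre += sum(w * x for w, x in zip(w_res[i], state))
        new_state.append((1 - leaking_rate) * state[i] + leaking_rate * math.tanh(pre))
    return new_state

def mean_pool(states):
    """Average the reservoir states over time to get one vector per sentence."""
    n = len(states)
    return [sum(col) / n for col in zip(*states)]

rng = random.Random(0)
input_dim, reservoir_dim = 4, 8  # toy sizes; the notebook uses 768 and 1000
w_in = make_matrix(reservoir_dim, input_dim, 0.1, rng)       # scale plays the role of 'input_scaling'
w_res = make_matrix(reservoir_dim, reservoir_dim, 0.1, rng)  # small weights, not a true spectral-radius rescaling
tokens = [[rng.gauss(0, 1) for _ in range(input_dim)] for _ in range(5)]  # stand-in for token embeddings

state = [0.0] * reservoir_dim
states = []
for u in tokens:
    state = reservoir_step(state, u, w_in, w_res, leaking_rate=0.18)
    states.append(state)

sentence_vector = mean_pool(states)
print(len(sentence_vector))  # 8
```

The reservoir weights are random and stay fixed. Only the readout that maps the pooled vector to a class is trained, which is why the learning algorithm is set separately from the reservoir parameters.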


In [46]:
# Warm up the ESN on multiple sentences
nb_sentences = 10

for i in range(nb_sentences): 
    sentence = dataset_d["train"].select([i])
    dataloader_tmp = torch.utils.data.DataLoader(sentence, 
                                                 batch_size=1, 
                                                 collate_fn=DataCollatorWithPadding(tokenizer))  

    for sentence in dataloader_tmp:
        ESN.warm_up(sentence)
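
In reservoir computing, a warm-up (or washout) phase is commonly used so that the reservoir state no longer depends on its arbitrary initial value. Whether ``ESN.warm_up`` does exactly this is an assumption. The sketch below only illustrates the general idea with a small, hypothetical, contractive reservoir: two different initial states driven by the same inputs end up in practically the same state.

```python
import math
import random

def step(state, u, w_in, w_res, a):
    """Leaky update x <- (1 - a) * x + a * tanh(w_in * u + W x) for a scalar input u."""
    return [
        (1 - a) * state[i]
        + a * math.tanh(w_in[i] * u + sum(w * x for w, x in zip(w_res[i], state)))
        for i in range(len(state))
    ]

rng = random.Random(1)
dim = 6
w_in = [rng.uniform(-1, 1) for _ in range(dim)]
# The absolute values in each row sum to at most 0.5, so the update is a contraction
w_res = [[rng.uniform(-0.5, 0.5) / dim for _ in range(dim)] for _ in range(dim)]
inputs = [rng.uniform(-1, 1) for _ in range(100)]

state_a = [0.0] * dim                               # "cold" start
state_b = [rng.uniform(-1, 1) for _ in range(dim)]  # arbitrary start

for u in inputs:
    state_a = step(state_a, u, w_in, w_res, a=0.5)
    state_b = step(state_b, u, w_in, w_res, a=0.5)

gap = max(abs(x - y) for x, y in zip(state_a, state_b))
print(gap < 1e-9)  # True
```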

## Training

In [47]:
%%time
# training the ESN
ESN.fit(dataloader_d["train"])

CPU times: user 9.51 s, sys: 1.1 s, total: 10.6 s
Wall time: 9.91 s
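
Training is fast because a ridge regression readout can be fitted in closed form rather than by gradient descent. The following is a minimal sketch of the textbook solution ``W = (XᵀX + αI)⁻¹ XᵀY`` with one-hot targets, restricted to two features so that the 2 x 2 inverse can be written out by hand. Whether ``la.RidgeRegression`` computes exactly this is an assumption; ``ridge_fit_2d`` and ``predict`` are hypothetical helpers and the data are made up.

```python
def ridge_fit_2d(features, targets, alpha):
    """Closed-form ridge regression W = (X^T X + alpha I)^-1 X^T Y for 2-D features."""
    # Regularized Gram matrix X^T X + alpha * I (2 x 2, symmetric)
    a = sum(x[0] * x[0] for x in features) + alpha
    b = sum(x[0] * x[1] for x in features)
    d = sum(x[1] * x[1] for x in features) + alpha
    det = a * d - b * b
    n_classes = len(targets[0])
    weights = [[0.0] * n_classes for _ in range(2)]
    for k in range(n_classes):
        # Right-hand side X^T y_k for class k
        r0 = sum(x[0] * y[k] for x, y in zip(features, targets))
        r1 = sum(x[1] * y[k] for x, y in zip(features, targets))
        # Apply the explicit inverse of the 2 x 2 Gram matrix
        weights[0][k] = (d * r0 - b * r1) / det
        weights[1][k] = (a * r1 - b * r0) / det
    return weights

def predict(x, weights):
    """Return the index of the class with the highest linear score."""
    scores = [x[0] * weights[0][k] + x[1] * weights[1][k] for k in range(len(weights[0]))]
    return scores.index(max(scores))

# Toy "pooled reservoir states" and their class labels
features = [[1.0, 0.1], [0.9, 0.0], [0.1, 1.0], [0.0, 0.8]]
labels = [0, 0, 1, 1]
targets = [[1.0 if k == y else 0.0 for k in range(2)] for y in labels]

weights = ridge_fit_2d(features, targets, alpha=0.1)
predictions = [predict(x, weights) for x in features]
print(predictions)  # [0, 0, 1, 1]
```

A larger ``alpha`` shrinks the weights toward zero, trading a little training accuracy for a more stable readout.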


In [15]:
# %%time
# # training the ESN (Logistic Regression, gradient descent)
# ESN.fit(dataloader_d["train"], epochs=10, iter_steps=10)

## Results

In [33]:
# Train predictions and accuracy
train_pred, train_acc = ESN.predict(dataloader_d["train"], verbose=False)
train_acc.item()

93.03008270263672

In [34]:
# Test predictions and accuracy
test_pred, test_acc = ESN.predict(dataloader_d["test"], verbose=False)
test_acc.item()

93.80000305175781

In [36]:
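# Test classification report
# Note: scikit-learn's signature is classification_report(y_true, y_pred).
# The predictions are passed first here, so in the output the 'precision' and
# 'recall' columns are swapped with respect to the usual convention, and
# 'support' counts predicted labels. Accuracy and f1-score are unaffected.
print(classification_report(test_pred.tolist(), 
                            dataset_d['test']['labels'].tolist(), 
                            digits=4))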

              precision    recall  f1-score   support

           0     0.9710    0.9054    0.9371       148
           1     0.8298    0.9286    0.8764        84
           2     0.7778    1.0000    0.8750         7
           3     0.9538    1.0000    0.9764        62
           4     0.9735    0.9565    0.9649       115
           5     0.9630    0.9286    0.9455        84

    accuracy                         0.9380       500
   macro avg     0.9115    0.9532    0.9292       500
weighted avg     0.9417    0.9380    0.9387       500
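
For readers who want to check how the per-class columns of such a report are defined, here is a small self-contained sketch that computes precision and recall by hand. ``precision_recall`` is a hypothetical helper and the labels are made up; it also shows that the two metrics trade places when the order of the arguments is reversed.

```python
def precision_recall(y_true, y_pred, label):
    """Per-class precision and recall from two equally long label lists."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
    predicted = sum(1 for p in y_pred if p == label)  # samples predicted as `label`
    actual = sum(1 for t in y_true if t == label)     # samples that truly are `label`
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    return precision, recall

y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 1]

print(precision_recall(y_true, y_pred, label=1))  # (0.5, 1.0)
print(precision_recall(y_pred, y_true, label=1))  # (1.0, 0.5)
```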

