# ESNs for Text Classification

This notebook provides a use case example of the ``EsnTorch`` library. It describes the implementation of an **Echo State Network (ESN)** for text classification on the **TREC-6** dataset (question classification).

The instantiation, training and evaluation of an ESN for text classification are carried out in the following steps:
1. Import libraries and modules
2. Load and prepare data
3. Instantiate the model:
    - specify parameters
    - specify learning algorithm
    - warm up
4. Train
5. Evaluate

## Libraries

In [None]:
# !pip install transformers==4.8.2
# !pip install datasets==1.7.0
# !pip install tqdm
# !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
# To enable progress bars in jupyter:

# pip install ipywidgets
# jupyter nbextension enable --py widgetsnbextension
# conda install -c conda-forge nodejs
# jupyter labextension install @jupyter-widgets/jupyterlab-manager

In [None]:
# Comment out this cell if esntorch is installed as a package (it only adds the parent directory to the Python path)!
import os
import sys
sys.path.insert(0, os.path.abspath(".."))

In [4]:
from tqdm.notebook import tqdm_notebook
from sklearn.metrics import classification_report

import torch

from datasets import load_dataset, Dataset, concatenate_datasets

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

import esntorch.core.reservoir as res
import esntorch.core.learning_algo as la
import esntorch.core.merging_strategy as ms
import esntorch.core.esn as esn

In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
# Set device (cpu or gpu if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Load and prepare data

### Load and tokenize data

In [7]:
# Custom functions for loading and preparing data
def tokenize(sample):
    """Tokenize sample"""
    
    sample = tokenizer(sample['text'], truncation=True, padding=False, return_length=True)
    
    return sample
    
def load_and_prepare_dataset(dataset_name, split, cache_dir):
    """
    Load dataset from the datasets library of HuggingFace.
    Tokenize and add length.
    """
    
    # Load dataset
    dataset = load_dataset(dataset_name, split=split, cache_dir=cache_dir)
    
    # Rename label column for tokenization purposes (use 'label-fine' for fine-grained labels)
    dataset = dataset.rename_column('label-coarse', 'labels')
    
    # Tokenize data
    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.rename_column('length', 'lengths')
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lengths'])
    
    return dataset

In [8]:
# Load BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load and prepare data
CACHE_DIR = 'cache_dir/' # put your path here

full_dataset = load_and_prepare_dataset('trec', split=None, cache_dir=CACHE_DIR)
train_dataset = full_dataset['train'].sort("lengths")
test_dataset = full_dataset['test'].sort("lengths")

# Create dict of all datasets
dataset_d = {
    'train': train_dataset,
    'test': test_dataset
    }

Using custom data configuration default
Reusing dataset trec (cache_dir/trec/default/1.1.0/751da1ab101b8d297a3d6e9c79ee9b0173ff94c4497b75677b59b61d5467a9b9)



In [9]:
dataset_d

{'train': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 5452
 }),
 'test': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 500
 })}

### Create dataloaders

In [10]:
# Create dict of dataloaders

dataloader_d = {}

for k, v in dataset_d.items():
    dataloader_d[k] = torch.utils.data.DataLoader(v, batch_size=256, 
                                                  collate_fn=DataCollatorWithPadding(tokenizer))

In [11]:
dataloader_d

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f0f3f10da30>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f0f300267c0>}
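Because the train and test sets are sorted by ``lengths`` and the collator pads each batch only up to its own longest sequence, sentences of similar length end up in the same batch, which limits padding. A minimal pure-Python sketch of that effect, on made-up sequence lengths (``padded_tokens`` is a hypothetical helper, not part of any library):

```python
import random


def padded_tokens(lengths, batch_size):
    """Count the token slots used when each consecutive batch is padded
    to the length of its longest sequence (dynamic padding)."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += max(batch) * len(batch)
    return total


# Hypothetical sequence lengths standing in for tokenized questions
random.seed(0)
lengths = [random.randint(5, 40) for _ in range(1000)]

real_tokens = sum(lengths)
unsorted_slots = padded_tokens(lengths, batch_size=256)
sorted_slots = padded_tokens(sorted(lengths), batch_size=256)

print(f"real tokens:           {real_tokens}")
print(f"slots without sorting: {unsorted_slots}")
print(f"slots with sorting:    {sorted_slots}")
```

Fewer padded slots means less wasted computation, which matters for a reservoir that is run step by step over every position of the batch.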

## Instantiate the model

### ESN

Most parameters are self-explanatory for a reader familiar with ESNs. Please refer to the documentation for further details. 

The ``mode`` parameter represents the type of reservoir to be considered:
- ``recurrent_layer``: implements a **classical recurrent reservoir**, specified among others by its ``dim``, ``sparsity``, ``spectral_radius``, ``leaking_rate`` and ``activation_function``.
- ``linear_layer``: implements a simple **linear layer** specified by its ``dim`` and ``activation_function``.
- ``no_layer``: implements **the absence of reservoir**, meaning that the embedded inputs are directly fed to the learning algorithms.

Comparing the ``recurrent_layer`` and ``no_layer`` modes makes it possible to assess the impact of the reservoir on the results. Comparing the ``recurrent_layer`` and ``linear_layer`` modes makes it possible to assess the importance of the reservoir's recurrence.
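To make the three modes concrete, here is a minimal numpy sketch of how each one could map a sequence of token embeddings to states. It assumes the textbook leaky-integrator ESN update ``x_t = (1 - a) x_{t-1} + a f(W_in u_t + W x_{t-1} + b)``; the function ``layer_states`` and the toy sizes are illustrative, and ``EsnTorch``'s actual implementation may differ in details such as where the bias and the leaking rate enter.

```python
import numpy as np


def layer_states(inputs, mode, w_in=None, w_res=None, bias=None,
                 leaking_rate=0.5, activation=np.tanh):
    """Return one state per time step for a single embedded sequence.

    inputs: array of shape (seq_len, input_dim), e.g. token embeddings.
    Sketch only: assumes the classical leaky-integrator ESN update.
    """
    if mode == "no_layer":
        # Embeddings are passed through unchanged
        return inputs
    if mode == "linear_layer":
        # Fixed random projection + nonlinearity, applied to each token independently
        return activation(inputs @ w_in.T + bias)
    if mode == "recurrent_layer":
        state = np.zeros(w_res.shape[0])
        states = []
        for u in inputs:
            # x_t = (1 - a) x_{t-1} + a f(W_in u_t + W x_{t-1} + b)
            pre = w_in @ u + w_res @ state + bias
            state = (1 - leaking_rate) * state + leaking_rate * activation(pre)
            states.append(state)
        return np.stack(states)
    raise ValueError(f"unknown mode: {mode}")


rng = np.random.default_rng(42)
seq_len, input_dim, dim = 5, 8, 20   # toy sizes (BERT embeddings would give input_dim=768)
inputs = rng.normal(size=(seq_len, input_dim))
w_in = rng.uniform(-0.1, 0.1, size=(dim, input_dim))
w_res = rng.normal(size=(dim, dim)) * 0.1
bias = rng.uniform(-0.1, 0.1, size=dim)

for mode in ["no_layer", "linear_layer", "recurrent_layer"]:
    print(mode, layer_states(inputs, mode, w_in, w_res, bias).shape)
```

The linear layer treats each token independently, so reversing the input sequence simply reverses its output; the recurrent layer lacks this property, and that difference is what the ``recurrent_layer`` vs ``linear_layer`` comparison probes.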

In [12]:
# ESN parameters
esn_params = {
            'embedding_weights': 'bert-base-uncased', # name of Hugging Face model
            'dim': 1000,
            'sparsity': 0.9,
            'spectral_radius': 0.9,
            'leaking_rate': 0.5,
            'activation_function': 'tanh', # 'tanh', 'relu'
            'bias_scaling': 0.1,
            'input_scaling': 0.1,
            'mean': 0.0,
            'std': 1.0,     
            'learning_algo': None,         # initialized below
            'criterion': None,             # initialized below (only for learning algos trained by gradient descent)
            'optimizer': None,             # initialized below (only for learning algos trained by gradient descent)
            'merging_strategy': None,      # 'mean', 'last', None
            'bidirectional': True,         # True, False
            'mode' : 'recurrent_layer',    # 'no_layer', 'linear_layer', 'recurrent_layer'
            'device': device,  
            'seed': 42
            }

# Instantiate the ESN
ESN = esn.EchoStateNetwork(**esn_params)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model downloaded: bert-base-uncased
dim and input_dim 1000 768
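The printed ``dim and input_dim 1000 768`` means a reservoir of 1000 units driven by 768-dimensional BERT embeddings. A minimal numpy sketch of one common way to build such a reservoir from the parameters above. It assumes that ``sparsity`` is the fraction of zero recurrent weights, that recurrent weights are drawn from N(``mean``, ``std``) before rescaling to the target ``spectral_radius``, and that input weights and biases are uniform in [-``input_scaling``, ``input_scaling``] and [-``bias_scaling``, ``bias_scaling``]. ``init_reservoir`` is a hypothetical helper, not ``EsnTorch``'s code.

```python
import numpy as np


def init_reservoir(input_dim, dim, sparsity=0.9, spectral_radius=0.9,
                   input_scaling=0.1, bias_scaling=0.1, mean=0.0, std=1.0, seed=42):
    """Sketch of a classical ESN initialization (assumed, not the library's code)."""
    rng = np.random.default_rng(seed)
    # Recurrent weights from N(mean, std); a fraction `sparsity` of them set to zero
    w_res = rng.normal(mean, std, size=(dim, dim))
    w_res[rng.random((dim, dim)) < sparsity] = 0.0
    # Rescale so that the largest absolute eigenvalue equals spectral_radius
    current_radius = np.max(np.abs(np.linalg.eigvals(w_res)))
    w_res *= spectral_radius / current_radius
    # Input weights and bias drawn uniformly in [-scaling, scaling]
    w_in = rng.uniform(-input_scaling, input_scaling, size=(dim, input_dim))
    bias = rng.uniform(-bias_scaling, bias_scaling, size=dim)
    return w_in, w_res, bias


# Same sizes as in the notebook: dim=1000, BERT input_dim=768
w_in, w_res, bias = init_reservoir(input_dim=768, dim=1000)
final_radius = np.max(np.abs(np.linalg.eigvals(w_res)))
print(w_in.shape, w_res.shape, bias.shape)
print("spectral radius:", final_radius)
print("fraction of zero recurrent weights:", np.mean(w_res == 0))
```

Rescaling by the spectral radius is what keeps the reservoir's dynamics from exploding, while the two scaling factors control how strongly the embeddings and the bias drive the units.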


### Learning algorithm
Choose your learning algo by uncommenting its associated cell.

**Direct methods:**
To be used with the following learning algos:
- ``RidgeRegression`` (our implementation)
- ``RidgeRegression_skl`` (from scikit-learn)
- ``LinearSVC`` (from scikit-learn)
- ``LogisticRegression_skl`` (from scikit-learn)

In [13]:
# ESN.learning_algo = la.RidgeRegression(alpha=10.0)

In [14]:
# ESN.learning_algo = la.RidgeRegression_skl(alpha=10.0)

In [15]:
# ESN.learning_algo = la.LinearSVC(C=1.0)

In [16]:
# ESN.learning_algo = la.LogisticRegression_skl()
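Direct methods fit the readout in closed form rather than by gradient descent. For ridge regression on one-hot targets, the standard solution is ``W = (X^T X + alpha I)^{-1} X^T Y``. The numpy sketch below illustrates it with hypothetical helpers ``ridge_fit`` and ``ridge_predict`` and toy Gaussian data standing in for reservoir states; it is not the code of ``la.RidgeRegression``, and for simplicity it appends a bias column that is regularized along with the other weights.

```python
import numpy as np


def ridge_fit(states, labels, n_classes, alpha=10.0):
    """Closed-form ridge regression on one-hot targets (bias column appended)."""
    X = np.hstack([states, np.ones((states.shape[0], 1))])
    Y = np.eye(n_classes)[labels]                      # one-hot targets
    # W = (X^T X + alpha I)^{-1} X^T Y, solved without forming the inverse
    A = X.T @ X + alpha * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ Y)


def ridge_predict(W, states):
    X = np.hstack([states, np.ones((states.shape[0], 1))])
    return np.argmax(X @ W, axis=1)                    # class with the highest score


rng = np.random.default_rng(0)
centers = rng.normal(size=(6, 50)) * 3                 # 6 toy classes, as in TREC-6
labels = rng.integers(0, 6, size=600)
states = centers[labels] + rng.normal(size=(600, 50))  # stand-in for reservoir states

W = ridge_fit(states, labels, n_classes=6, alpha=10.0)
print("train accuracy:", np.mean(ridge_predict(W, states) == labels))
```

Since a single linear solve replaces the training loop, ``epochs`` and ``iter_steps`` have no role for such methods.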

**Gradient descent methods:**
To be used with the following learning algos:
- ``LogisticRegression`` (our implementation)
- ``DeepNN`` (our implementation)

In [17]:
if esn_params['mode'] == 'no_layer':
    input_dim = ESN.layer.input_dim
else:
    input_dim = ESN.layer.dim

if esn_params['bidirectional']:
    input_dim *= 2 
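With ``bidirectional=True``, the readout input is twice the reservoir size (2 × 1000 = 2000, the width of the ``FINAL STATES`` tensors printed during training). With ``merging_strategy=None``, the number of rows of those tensors changes from batch to batch (1266, 1536, ...), which suggests that one state per token, rather than one per sentence, is passed to the learning algorithm. The numpy sketch below illustrates both ideas under stated assumptions: the backward pass reuses the same weights, and ``'last'`` takes the final state of each direction. ``run_reservoir``, ``bidirectional_states`` and ``merge`` are hypothetical helpers.

```python
import numpy as np


def run_reservoir(inputs, w_in, w_res, leaking_rate=0.5):
    """Leaky tanh reservoir over one sequence; returns (seq_len, dim) states."""
    state = np.zeros(w_res.shape[0])
    states = []
    for u in inputs:
        state = (1 - leaking_rate) * state + leaking_rate * np.tanh(w_in @ u + w_res @ state)
        states.append(state)
    return np.stack(states)


def bidirectional_states(inputs, w_in, w_res):
    forward = run_reservoir(inputs, w_in, w_res)
    # Run on the reversed sequence, then flip back so row t matches token t
    backward = run_reservoir(inputs[::-1], w_in, w_res)[::-1]
    return np.concatenate([forward, backward], axis=1)   # (seq_len, 2 * dim)


def merge(states, strategy):
    if strategy == "mean":
        return states.mean(axis=0, keepdims=True)         # one vector per sentence
    if strategy == "last":
        # Final state of each direction: forward ends at the last token,
        # backward ends at the first token (assumed convention)
        half = states.shape[1] // 2
        return np.concatenate([states[-1, :half], states[0, half:]])[None, :]
    return states                                          # None: one vector per token


rng = np.random.default_rng(1)
dim, input_dim = 10, 4
w_in = rng.uniform(-0.1, 0.1, size=(dim, input_dim))
w_res = rng.normal(size=(dim, dim)) * 0.05
sentence = rng.normal(size=(7, input_dim))                # a 7-token toy sentence

states = bidirectional_states(sentence, w_in, w_res)
for strategy in ["mean", "last", None]:
    print(strategy, merge(states, strategy).shape)
```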

In [18]:
ESN.learning_algo = la.LogisticRegression(input_dim=input_dim, output_dim=6)

In [19]:
# ESN.learning_algo = la.DeepNN([input_dim, 512, 256, 6])

In [20]:
# Gradient descent methods need a criterion and an optimizer

ESN.criterion = torch.nn.CrossEntropyLoss()
ESN.optimizer = torch.optim.Adam(ESN.learning_algo.parameters(), lr=0.01)
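``LogisticRegression`` is trained by minimizing a cross-entropy loss with a gradient-based optimizer. The sketch below shows the underlying computation in numpy: a softmax (multinomial logistic) regression trained by plain full-batch gradient descent on toy data. ``train_logreg`` is a hypothetical stand-in for illustration; the notebook itself uses ``torch.nn.CrossEntropyLoss`` with ``torch.optim.Adam`` on mini-batches.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def train_logreg(X, y, n_classes, lr=0.1, steps=200):
    """Multinomial logistic regression trained by full-batch gradient descent."""
    W = np.zeros((X.shape[1], n_classes))
    b = np.zeros(n_classes)
    Y = np.eye(n_classes)[y]
    losses = []
    for _ in range(steps):
        P = softmax(X @ W + b)
        losses.append(-np.mean(np.sum(Y * np.log(P + 1e-12), axis=1)))
        # Gradient of the mean cross-entropy w.r.t. the logits is (P - Y) / n
        G = (P - Y) / X.shape[0]
        W -= lr * X.T @ G
        b -= lr * G.sum(axis=0)
    return W, b, losses


rng = np.random.default_rng(0)
centers = rng.normal(size=(6, 20)) * 2
y = rng.integers(0, 6, size=300)
X = centers[y] + rng.normal(size=(300, 20))

W, b, losses = train_logreg(X, y, n_classes=6)
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
print("train accuracy:", np.mean(np.argmax(X @ W + b, axis=1) == y))
```

With zero-initialized weights every class gets probability 1/6, so the first recorded loss is ``log 6 ≈ 1.792``.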

### Warm up

In [21]:
# Put model on device
ESN = ESN.to(device)

In [22]:
# Warm up the ESN on multiple sentences
if isinstance(ESN.layer, res.LayerRecurrent):
    ESN.warm_up(dataset_d['train'].select(range(10)))
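The notebook does not show what ``warm_up`` does internally. In ESN practice, a warm-up (or washout) phase drives the reservoir with real inputs so that its state no longer depends on the arbitrary initial state; this is the echo state property. A numpy sketch of that property, assuming a leaky tanh reservoir with spectral radius 0.9 and leaking rate 0.5 as in ``esn_params`` (``step`` is a hypothetical helper):

```python
import numpy as np


def step(state, u, w_in, w_res, leaking_rate=0.5):
    """One leaky-integrator reservoir update."""
    return (1 - leaking_rate) * state + leaking_rate * np.tanh(w_in @ u + w_res @ state)


rng = np.random.default_rng(0)
dim, input_dim = 100, 16
w_res = rng.normal(size=(dim, dim))
w_res *= 0.9 / np.max(np.abs(np.linalg.eigvals(w_res)))   # spectral radius 0.9
w_in = rng.uniform(-0.1, 0.1, size=(dim, input_dim))

# Two copies of the reservoir start from very different states
# and are driven by the same input sequence
a = np.zeros(dim)
b = rng.uniform(-1, 1, size=dim)
gaps = []
for u in rng.normal(size=(200, input_dim)):
    a, b = step(a, u, w_in, w_res), step(b, u, w_in, w_res)
    gaps.append(np.linalg.norm(a - b))

print(f"gap after 1 step: {gaps[0]:.3f}, after 200 steps: {gaps[-1]:.2e}")
```

The gap between the two trajectories shrinks toward zero, so after enough input the state reflects the input history rather than the starting point.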

## Training

For **direct methods**, the parameters ``epochs`` and ``iter_steps`` are ignored. 

In [23]:
ESN.fit(dataloader_d["train"], epochs=3, iter_steps=50)

Performing gradient descent...


FINAL STATES torch.Size([1266, 2000])
FINAL STATES torch.Size([1536, 2000])
FINAL STATES torch.Size([1775, 2000])
FINAL STATES torch.Size([1826, 2000])
FINAL STATES torch.Size([2048, 2000])
FINAL STATES torch.Size([2080, 2000])
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2544, 2000])
FINAL STATES torch.Size([2560, 2000])
FINAL STATES torch.Size([2752, 2000])
FINAL STATES torch.Size([2816, 2000])
FINAL STATES torch.Size([3010, 2000])
FINAL STATES torch.Size([3072, 2000])
FINAL STATES torch.Size([3290, 2000])
FINAL STATES torch.Size([3410, 2000])
FINAL STATES torch.Size([3630, 2000])
FINAL STATES torch.Size([3900, 2000])
FINAL STATES torch.Size([4197, 2000])
FINAL STATES torch.Size([4621, 2000])
FINAL STATES torch.Size([5436, 2000])
FINAL STATES torch.Size([2161, 2000])


FINAL STATES torch.Size([1266, 2000])
FINAL STATES torch.Size([1536, 2000])
FINAL STATES torch.Size([1775, 2000])
FINAL STATES torch.Size([1826, 2000])
FINAL STATES torch.Size([2048, 2000])
FINAL STATES torch.Size([2080, 2000])
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2544, 2000])
FINAL STATES torch.Size([2560, 2000])
FINAL STATES torch.Size([2752, 2000])
FINAL STATES torch.Size([2816, 2000])
FINAL STATES torch.Size([3010, 2000])
FINAL STATES torch.Size([3072, 2000])
FINAL STATES torch.Size([3290, 2000])
FINAL STATES torch.Size([3410, 2000])
FINAL STATES torch.Size([3630, 2000])
FINAL STATES torch.Size([3900, 2000])
FINAL STATES torch.Size([4197, 2000])
FINAL STATES torch.Size([4621, 2000])
FINAL STATES torch.Size([5436, 2000])
FINAL STATES torch.Size([2161, 2000])


FINAL STATES torch.Size([1266, 2000])
FINAL STATES torch.Size([1536, 2000])
FINAL STATES torch.Size([1775, 2000])
FINAL STATES torch.Size([1826, 2000])
FINAL STATES torch.Size([2048, 2000])
FINAL STATES torch.Size([2080, 2000])
Iteration: 50 Loss: 1.0678523778915405
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2544, 2000])
FINAL STATES torch.Size([2560, 2000])
FINAL STATES torch.Size([2752, 2000])
FINAL STATES torch.Size([2816, 2000])
FINAL STATES torch.Size([3010, 2000])
FINAL STATES torch.Size([3072, 2000])
FINAL STATES torch.Size([3290, 2000])
FINAL STATES torch.Size([3410, 2000])
FINAL STATES torch.Size([3630, 2000])
FINAL STATES torch.Size([3900, 2000])
FINAL STATES torch.Size([4197, 2000])
FINAL STATES torch.Size([4621, 2000])
FINAL STATES torch.Size([5436, 2000])
FINAL STATES torch.Size([2161, 2000])


Training complete.


[1.0678523778915405]
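The output suggests that ``iter_steps`` sets how often the loss is recorded: the train loader has ceil(5452 / 256) = 22 batches, so 3 epochs make 66 iterations, and only iteration 50 is logged, which matches the single-element list returned by ``fit``. This reading is inferred from the printed output, not from the library's documentation; ``logged_iterations`` below is a hypothetical helper that reproduces the bookkeeping.

```python
import math


def logged_iterations(n_samples, batch_size, epochs, iter_steps):
    """Iterations at which a loss would be recorded if it is logged every
    `iter_steps` batches (hypothetical, inferred from the notebook output)."""
    batches_per_epoch = math.ceil(n_samples / batch_size)
    total_iterations = epochs * batches_per_epoch
    return [i for i in range(1, total_iterations + 1) if i % iter_steps == 0]


print(math.ceil(5452 / 256))                # batches per epoch
print(logged_iterations(5452, 256, 3, 50))  # iterations with a recorded loss
```

Under this reading, training for more epochs or lowering ``iter_steps`` would produce a longer loss history to inspect.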

## Evaluation

In [24]:
# Train predictions and accuracy
train_pred, train_acc = ESN.predict(dataloader_d["train"], verbose=False)

FINAL STATES torch.Size([1266, 2000])
FINAL STATES torch.Size([1536, 2000])
FINAL STATES torch.Size([1775, 2000])
FINAL STATES torch.Size([1826, 2000])
FINAL STATES torch.Size([2048, 2000])
FINAL STATES torch.Size([2080, 2000])
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2304, 2000])
FINAL STATES torch.Size([2544, 2000])
FINAL STATES torch.Size([2560, 2000])
FINAL STATES torch.Size([2752, 2000])
FINAL STATES torch.Size([2816, 2000])
FINAL STATES torch.Size([3010, 2000])
FINAL STATES torch.Size([3072, 2000])
FINAL STATES torch.Size([3290, 2000])
FINAL STATES torch.Size([3410, 2000])
FINAL STATES torch.Size([3630, 2000])
FINAL STATES torch.Size([3900, 2000])
FINAL STATES torch.Size([4197, 2000])
FINAL STATES torch.Size([4621, 2000])
FINAL STATES torch.Size([5436, 2000])
FINAL STATES torch.Size([2161, 2000])



In [25]:
# Test predictions
test_pred, test_acc = ESN.predict(dataloader_d["test"], verbose=False)

FINAL STATES torch.Size([1582, 2000])
FINAL STATES torch.Size([2580, 2000])



In [26]:
# Test classification report
print(classification_report(y_true=dataset_d['test']['labels'].tolist(), 
                            y_pred=test_pred.tolist(), 
                            digits=4))

              precision    recall  f1-score   support

           0     0.6026    1.0000    0.7520       138
           1     0.6964    0.4149    0.5200        94
           2     1.0000    0.2222    0.3636         9
           3     0.9792    0.7231    0.8319        65
           4     0.9423    0.8673    0.9032       113
           5     0.9836    0.7407    0.8451        81

    accuracy                         0.7680       500
   macro avg     0.8674    0.6614    0.7026       500
weighted avg     0.8148    0.7680    0.7610       500
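``classification_report`` takes ``y_true`` first and ``y_pred`` second. Precision divides the true positives of a class by the number of times it was predicted, and recall divides them by the number of true instances, so swapping the arguments swaps the two columns and makes ``support`` count predictions instead of true labels. A small from-scratch sketch makes this visible (``per_class_precision_recall`` is a hypothetical helper, not a scikit-learn function):

```python
from collections import Counter


def per_class_precision_recall(y_true, y_pred):
    """Return {class: (precision, recall)} computed from scratch."""
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    n_pred = Counter(y_pred)   # how often each class was predicted
    n_true = Counter(y_true)   # how often each class actually occurs
    classes = sorted(set(y_true) | set(y_pred))
    return {c: (tp[c] / n_pred[c] if n_pred[c] else 0.0,
                tp[c] / n_true[c] if n_true[c] else 0.0)
            for c in classes}


y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 1]

print(per_class_precision_recall(y_true, y_pred))
print(per_class_precision_recall(y_pred, y_true))   # arguments swapped
```

In the swapped call, every (precision, recall) pair comes out reversed, which is why the argument order must be checked before reading a report.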



In [27]:
ESN.layer.input_dim, ESN.layer.dim, input_dim

(768, 1000, 2000)