# ESNs for Text Classification

This notebook presents an example use case of the ``EsnTorch`` library: implementing an **Echo State Network (ESN)** for text classification on the **TREC-6** dataset (question classification).

Instantiating, training and evaluating an ESN for text classification involves the following steps:
1. Import libraries and modules
2. Load and prepare data
3. Instantiate the model:
    - specify parameters
    - specify learning algorithm
    - warm up
4. Train
5. Evaluate

## Libraries

In [11]:
# !pip install transformers==4.8.2
# !pip install datasets==1.7.0
# !pip install tqdm
# !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html

In [12]:
# To enable progress bars in jupyter:

# pip install ipywidgets
# jupyter nbextension enable --py widgetsnbextension
# conda install -c conda-forge nodejs
# jupyter labextension install @jupyter-widgets/jupyterlab-manager

In [13]:
# Comment this out if the library is already installed.
import os
import sys
sys.path.insert(0, os.path.abspath(".."))

In [14]:
from tqdm.notebook import tqdm_notebook
from sklearn.metrics import classification_report

import torch

from datasets import load_dataset, Dataset, concatenate_datasets

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

import esntorch.core.reservoir as res
import esntorch.core.learning_algo as la
import esntorch.core.pooling_strategy as ps
import esntorch.core.esn as esn

In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# Set device (cpu or gpu if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

## Load and prepare data

### Load and tokenize data

In [17]:
# Custom functions for loading and preparing data
def tokenize(sample):
    """Tokenize sample"""
    
    sample = tokenizer(sample['text'], truncation=True, padding=False, return_length=True)
    
    return sample
    
def load_and_prepare_dataset(dataset_name, split, cache_dir):
    """
    Load dataset from the datasets library of HuggingFace.
    Tokenize and add length.
    """
    
    # Load dataset
    dataset = load_dataset(dataset_name, split=split, cache_dir=cache_dir)
    
    # Rename label column for tokenization purposes (use 'label-fine' for fine-grained labels)
    dataset = dataset.rename_column('label-coarse', 'labels')
    
    # Tokenize data
    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.rename_column('length', 'lengths')
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lengths'])
    
    return dataset

In [18]:
# Load BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load and prepare data
CACHE_DIR = 'cache_dir/' # put your path here

full_dataset = load_and_prepare_dataset('trec', split=None, cache_dir=CACHE_DIR)
train_dataset = full_dataset['train'].sort("lengths")
test_dataset = full_dataset['test'].sort("lengths")

# Create dict of all datasets
dataset_d = {
    'train': train_dataset,
    'test': test_dataset
    }

Using custom data configuration default
Reusing dataset trec (cache_dir/trec/default/1.1.0/751da1ab101b8d297a3d6e9c79ee9b0173ff94c4497b75677b59b61d5467a9b9)



In [19]:
dataset_d

{'train': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 5452
 }),
 'test': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 500
 })}

### Create dataloaders

In [20]:
# Create dict of dataloaders

dataloader_d = {}

for k, v in dataset_d.items():
    dataloader_d[k] = torch.utils.data.DataLoader(v, batch_size=256, 
                                                  collate_fn=DataCollatorWithPadding(tokenizer))
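The datasets above are tokenized with ``padding=False`` and sorted by length, and ``DataCollatorWithPadding`` then pads each batch to its longest sequence. The following sketch, in plain Python on hypothetical toy lengths (not TREC data), illustrates why this combination presumably helps: batching consecutive samples of similar length needs far fewer pad tokens. The helpers ``pad_batch`` and ``count_padding`` are illustrative names, not part of any library.

```python
def pad_batch(sequences, pad_id=0):
    """Pad every sequence in a batch to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]


def count_padding(lengths, batch_size):
    """Count the pad tokens needed when consecutive samples are batched together."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += sum(max(batch) - n for n in batch)
    return total


# Hypothetical token counts for eight samples (illustration only).
lengths = [12, 3, 9, 4, 11, 5, 10, 4]

unsorted_padding = count_padding(lengths, batch_size=4)
sorted_padding = count_padding(sorted(lengths), batch_size=4)

print(pad_batch([[101, 7, 102], [101, 102]]))
print("pad tokens, original order:", unsorted_padding)
print("pad tokens, sorted by length:", sorted_padding)
```

On this toy input, the sorted ordering needs 10 pad tokens instead of 34.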

In [21]:
dataloader_d

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f851be9a710>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f851be9a590>}

## Instantiate the model

### ESN

Most parameters are self-explanatory for a reader familiar with ESNs. Please refer to the documentation for further details. 

The ``mode`` parameter selects the type of reservoir:
- ``recurrent_layer``: implements a **classical recurrent reservoir**, specified, among other parameters, by its ``dim``, ``sparsity``, ``spectral_radius``, ``leaking_rate`` and ``activation_function``.
- ``linear_layer``: implements a simple **linear layer** specified by its ``dim`` and ``activation_function``.
- ``no_layer``: implements **the absence of a reservoir**, meaning that the embedded inputs are fed directly to the learning algorithm.

Comparing the ``recurrent_layer`` and ``no_layer`` modes shows the impact of the reservoir on the results. Comparing the ``recurrent_layer`` and ``linear_layer`` modes shows how much the recurrence of the reservoir matters.
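To make the difference between the three modes concrete, here is a minimal sketch in plain Python. It assumes the textbook leaky-integrator update x_t = (1 - a) x_{t-1} + a tanh(W_in u_t + W x_{t-1}); this is not necessarily the exact computation performed by ``EsnTorch``. Bias terms and spectral-radius rescaling are omitted, and all function names are hypothetical.

```python
import math
import random


def make_matrix(rows, cols, rng, scale=1.0, sparsity=0.0):
    """Random Gaussian matrix; each entry is zeroed with probability `sparsity`."""
    return [[0.0 if rng.random() < sparsity else rng.gauss(0.0, 1.0) * scale
             for _ in range(cols)] for _ in range(rows)]


def matvec(matrix, vector):
    """Matrix-vector product on nested lists."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def reservoir_step(state, u, w_in, w_res, leaking_rate):
    """Recurrent mode: the new state mixes the previous state with the current input."""
    pre = [i + r for i, r in zip(matvec(w_in, u), matvec(w_res, state))]
    return [(1.0 - leaking_rate) * s + leaking_rate * math.tanh(p)
            for s, p in zip(state, pre)]


def linear_step(u, w_in):
    """Linear mode: each token is projected independently of the past."""
    return [math.tanh(p) for p in matvec(w_in, u)]


rng = random.Random(42)
emb_dim, dim = 4, 8  # toy sizes; the notebook uses BERT embeddings and dim=1000
w_in = make_matrix(dim, emb_dim, rng, scale=0.1)
w_res = make_matrix(dim, dim, rng, scale=0.1, sparsity=0.9)
tokens = [[rng.gauss(0.0, 1.0) for _ in range(emb_dim)] for _ in range(5)]

# recurrent_layer: states carry memory of earlier tokens
state = [0.0] * dim
recurrent_states = []
for u in tokens:
    state = reservoir_step(state, u, w_in, w_res, leaking_rate=0.5)
    recurrent_states.append(state)

# linear_layer: memoryless projection of each token
linear_states = [linear_step(u, w_in) for u in tokens]

# no_layer: the embeddings themselves are passed on
no_layer_states = tokens
```

In this sketch only ``recurrent_states`` depends on the order of the tokens, which is the property the ``recurrent_layer`` versus ``linear_layer`` comparison is meant to isolate.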

In [22]:
# ESN parameters
esn_params = {
            'embedding': 'bert-base-uncased', # name of Hugging Face model
            'dim': 1000,
            'sparsity': 0.9,
            'spectral_radius': 0.9,
            'leaking_rate': 0.5,
            'activation_function': 'tanh', # 'tanh', 'relu'
            'bias_scaling': 0.1,
            'input_scaling': 0.1,
            'mean': 0.0,
            'std': 1.0,     
            'learning_algo': None,         # initialized below
            'criterion': None,             # initialized below (only for learning algos trained with SGD)
            'optimizer': None,             # initialized below (only for learning algos trained with SGD)
            'pooling_strategy': 'mean',    # 'mean', 'last', None
            'bidirectional': False,        # True, False
            'mode' : 'recurrent_layer',    # 'no_layer', 'linear_layer', 'recurrent_layer'
            'device': device,  
            'seed': 42
            }

# Instantiate the ESN
ESN = esn.EchoStateNetwork(**esn_params)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model downloaded: bert-base-uncased
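The ``pooling_strategy`` parameter turns the sequence of per-token states into one vector per text. The sketch below illustrates the usual meaning of ``'mean'`` and ``'last'`` pooling on a padded sequence. It assumes that padded time steps are excluded using the sequence length; how ``EsnTorch`` handles this internally, and what ``None`` does, is not shown here. The function names are hypothetical.

```python
def mean_pool(states, length):
    """Average the states of the first `length` (non-padded) time steps."""
    valid = states[:length]
    dim = len(valid[0])
    return [sum(step[d] for step in valid) / length for d in range(dim)]


def last_pool(states, length):
    """Return the state at the last non-padded time step."""
    return states[length - 1]


# Toy sequence: four time steps of 2-dimensional states, the last one being padding.
states = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0], [9.0, 9.0]]
length = 3

mean_vector = mean_pool(states, length)
last_vector = last_pool(states, length)
print(mean_vector, last_vector)
```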


### Learning algorithm
Choose a learning algorithm by uncommenting its cell.

**Direct methods:**
The following learning algorithms are trained directly, without SGD:
- ``RidgeRegression`` (our implementation)
- ``RidgeRegression_skl`` (from scikit-learn)
- ``LinearSVC`` (from scikit-learn)
- ``LogisticRegression_skl`` (from scikit-learn)

In [23]:
ESN.learning_algo = la.RidgeRegression(alpha=10.0)

In [24]:
# ESN.learning_algo = la.RidgeRegression_skl(alpha=10.0)

In [25]:
# ESN.learning_algo = la.LinearSVC(C=1.0)

In [26]:
# ESN.learning_algo = la.LogisticRegression_skl()

**Gradient descent methods:**
The following learning algorithms are trained with SGD and require a criterion and an optimizer:
- ``LogisticRegression`` (our implementation)
- ``DeepNN`` (our implementation)

In [27]:
# if esn_params['mode'] == 'no_layer':
#     input_dim = ESN.layer.input_dim
# else:
#     input_dim = ESN.layer.dim

# if esn_params['bidirectional']:
#     input_dim *= 2 

In [28]:
# ESN.learning_algo = la.LogisticRegression(input_dim=input_dim, output_dim=6)

In [29]:
# ESN.learning_algo = la.DeepNN([input_dim, 512, 256, 6])

In [30]:
# # Needs criterion and optimizer

# ESN.criterion = torch.nn.CrossEntropyLoss()
# ESN.optimizer = torch.optim.Adam(ESN.learning_algo.parameters(), lr=0.01)

### Warm up

In [31]:
# Put model on device
ESN = ESN.to(device)

In [32]:
# Warm up the ESN on multiple sentences
if isinstance(ESN.layer, res.LayerRecurrent):
    ESN.warm_up(dataset_d['train'].select(range(10)))

## Training

For **direct methods**, the parameters ``epochs`` and ``iter_steps`` are ignored. 

In [39]:
ESN.fit(dataloader_d["train"], epochs=3, iter_steps=50)

Computing closed-form solution...



## Evaluation

In [24]:
# Train predictions and accuracy
train_pred, train_acc = ESN.predict(dataloader_d["train"], verbose=False)
train_acc


93.01173881144534

In [25]:
# Test predictions
test_pred, test_acc = ESN.predict(dataloader_d["test"], verbose=False)
test_acc


93.8

In [26]:
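# Test classification report
# Note: scikit-learn's signature is classification_report(y_true, y_pred).
# With the predictions passed first, as here, the precision and recall columns
# of the report below are swapped and `support` counts predicted labels.
print(classification_report(test_pred.tolist(), 
                            dataset_d['test']['labels'].tolist(), 
                            digits=4))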

              precision    recall  f1-score   support

           0     0.9855    0.9189    0.9510       148
           1     0.8191    0.9506    0.8800        81
           2     0.7778    1.0000    0.8750         7
           3     0.9692    0.9844    0.9767        64
           4     0.9735    0.9402    0.9565       117
           5     0.9383    0.9157    0.9268        83

    accuracy                         0.9380       500
   macro avg     0.9106    0.9516    0.9277       500
weighted avg     0.9429    0.9380    0.9390       500
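When reading a report like this, note that scikit-learn's ``classification_report`` expects the true labels first and the predictions second. The cell above passes the predictions first, so its precision and recall columns are interchanged relative to the usual convention. The sketch below shows the effect on toy labels with a hand-written helper (``precision_recall`` is a hypothetical name, not a scikit-learn function).

```python
def precision_recall(y_true, y_pred, label):
    """Per-class precision and recall computed from two label lists."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    predicted = sum(p == label for p in y_pred)
    actual = sum(t == label for t in y_true)
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    return precision, recall


# Toy labels (illustration only).
y_true = [0, 0, 0, 1, 1]
y_pred = [0, 0, 1, 1, 1]

correct_order = precision_recall(y_true, y_pred, label=1)
swapped_order = precision_recall(y_pred, y_true, label=1)
print("true labels first: ", correct_order)
print("predictions first: ", swapped_order)
```

The F1 score is symmetric in precision and recall, so per-class F1 and overall accuracy are unaffected by the argument order.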

