# Tutorial 1: ESN with Ridge Regression

This notebook provides a use case example of the ``EsnTorch`` library. It describes the implementation of an **Echo State Network (ESN)** for text classification on the **TREC-6** dataset (question classification).

The instantiation, training and evaluation of an ESN for text classification
is achieved via the following steps:
- Import the required libraries and modules
- Load and prepare data
- Instantiate the model by:
    - specifying its constituting parameters
    - specifying its learning algorithm
    - warming up the model
- Train the ESN
- Evaluate the ESN on the train and test sets

## Libraries

In [1]:
# !pip install transformers==4.8.2
# !pip install datasets==1.7.0

In [2]:
# Comment out these lines if the library is installed!
import os
import sys
sys.path.insert(0, os.path.abspath(".."))

In [3]:
# import numpy as np
from sklearn.metrics import classification_report

import torch

from datasets import load_dataset, Dataset, concatenate_datasets

from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

import esntorch.core.reservoir as res
import esntorch.core.learning_algo as la
import esntorch.core.merging_strategy as ms
import esntorch.core.esn as esn

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
# Set device (cpu or gpu if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

## Load and prepare data

### Load and tokenize data

In [6]:
# Custom functions for loading and preparing data
def tokenize(sample):
    """Tokenize sample"""
    
    sample = tokenizer(sample['text'], truncation=True, padding=False, return_length=True)
    
    return sample
    
def load_and_prepare_dataset(dataset_name, split, cache_dir):
    """
    Load dataset from the datasets library of HuggingFace.
    Tokenize and add length.
    """
    
    # Load dataset
    dataset = load_dataset(dataset_name, split=split, cache_dir=cache_dir)
    
    # Rename label column for tokenization purposes (use 'label-fine' for fine-grained labels)
    dataset = dataset.rename_column('label-coarse', 'labels')
    
    # Tokenize data
    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.rename_column('length', 'lengths')
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels', 'lengths'])
    
    return dataset

In [7]:
# Load BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load and prepare data
CACHE_DIR = 'cache_dir/' # put your path here

full_dataset = load_and_prepare_dataset('trec', split=None, cache_dir=CACHE_DIR)
train_dataset = full_dataset['train'].sort("lengths")
test_dataset = full_dataset['test'].sort("lengths")

# Create dict of all datasets
dataset_d = {
    'train': train_dataset,
    'test': test_dataset
    }

Using custom data configuration default
Reusing dataset trec (cache_dir/trec/default/1.1.0/751da1ab101b8d297a3d6e9c79ee9b0173ff94c4497b75677b59b61d5467a9b9)


In [8]:
dataset_d

{'train': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 5452
 }),
 'test': Dataset({
     features: ['attention_mask', 'input_ids', 'label-fine', 'labels', 'lengths', 'text', 'token_type_ids'],
     num_rows: 500
 })}
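Both splits are sorted by ``lengths`` before batching. Dynamic padding pads every sequence to the longest one in its batch, so grouping questions of similar length reduces the number of padding tokens the reservoir has to process. The sketch below counts those padding tokens for a toy list of lengths, with and without sorting. The lengths and the helper ``padding_tokens`` are hypothetical and are not part of EsnTorch.

```python
def padding_tokens(lengths, batch_size):
    """Count pad tokens added when each batch is padded to its longest sequence."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += sum(max(batch) - n for n in batch)
    return total

# Hypothetical token lengths of eight questions
lengths = [12, 5, 30, 7, 25, 6, 11, 28]

print(padding_tokens(lengths, batch_size=4))          # unsorted: 108
print(padding_tokens(sorted(lengths), batch_size=4))  # sorted, as in the notebook: 40
```

Note that sorting also means the train batches are not shuffled, which is harmless here because ridge regression is fitted in closed form rather than by stochastic updates.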

### Create dataloaders

In [9]:
# Create dict of dataloaders

dataloader_d = {}

for k, v in dataset_d.items():
    dataloader_d[k] = torch.utils.data.DataLoader(v, batch_size=256, 
                                                  collate_fn=DataCollatorWithPadding(tokenizer))

In [10]:
dataloader_d

{'train': <torch.utils.data.dataloader.DataLoader at 0x15d9e8100>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x15d9e8b80>}
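``DataCollatorWithPadding`` pads each batch on the fly instead of padding the whole dataset to a single length. A minimal pure-Python sketch of the idea follows, assuming right-side padding with a pad id of 0 (the ``[PAD]`` id of ``bert-base-uncased``). The helper ``pad_batch`` is hypothetical and is not the ``transformers`` implementation, which also handles other fields such as ``token_type_ids``.

```python
def pad_batch(batch, pad_id=0):
    """Pad input_ids to the longest sequence and build the matching attention mask."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Two toy token-id sequences of different lengths
out = pad_batch([[101, 2054, 102], [101, 2129, 2172, 2003, 102]])
print(out["input_ids"])       # [[101, 2054, 102, 0, 0], [101, 2129, 2172, 2003, 102]]
print(out["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```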

## Instantiate the model

### ESN

Most parameters are self-explanatory for a reader familiar with ESNs. Please refer to the documentation for further details. 

The ``mode`` parameter represents the type of reservoir to be considered:
- ``recurrent_layer``: implements a **classical recurrent reservoir**, specified, among others, by its ``dim``, ``sparsity``, ``spectral_radius``, ``leaking_rate`` and ``activation_function``.
- ``linear_layer``: implements a simple **linear layer** specified by its ``dim`` and ``activation_function``.
- ``no_layer``: implements **the absence of reservoir**, meaning that the embedded inputs are fed directly to the learning algorithm.

Comparing the ``recurrent_layer`` and ``no_layer`` modes makes it possible to assess the impact of the reservoir itself on the results. Comparing the ``recurrent_layer`` and ``linear_layer`` modes makes it possible to assess the impact of the reservoir's recurrence.

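To make the three modes concrete, here is a NumPy sketch of how each could transform a sequence of embedded tokens, followed by ``mean`` merging over the real (non-padded) tokens. It assumes the standard leaky-integrator ESN update ``x_t = (1 - a) * x_{t-1} + a * f(W_in u_t + W x_{t-1} + b)`` with leaking rate ``a``. EsnTorch's exact formulation (for example, where scaling and bias enter) may differ, and the functions ``reservoir_states`` and ``mean_merge`` are hypothetical.

```python
import numpy as np

def reservoir_states(u, mode, W_in, W, b, leaking_rate=0.5, f=np.tanh):
    """Map embedded inputs u of shape (T, input_dim) to states of shape (T, dim)."""
    if mode == "no_layer":
        return u  # embeddings are passed through unchanged
    if mode == "linear_layer":
        return f(u @ W_in.T + b)  # each token is mapped independently, no memory
    # recurrent_layer: leaky-integrator update, starting from a zero state
    x = np.zeros(W.shape[0])
    states = []
    for u_t in u:
        x = (1 - leaking_rate) * x + leaking_rate * f(W_in @ u_t + W @ x + b)
        states.append(x)
    return np.stack(states)

def mean_merge(states, attention_mask):
    """Average the states of real tokens only (mask = 1), ignoring padding."""
    mask = np.asarray(attention_mask, dtype=float)[:, None]
    return (states * mask).sum(axis=0) / mask.sum()

rng = np.random.default_rng(0)
T, input_dim, dim = 5, 8, 20
u = rng.normal(size=(T, input_dim))
W_in = rng.uniform(-0.1, 0.1, size=(dim, input_dim))
W = rng.normal(size=(dim, dim)) * 0.1
b = rng.uniform(-0.1, 0.1, size=dim)
mask = [1, 1, 1, 0, 0]  # the last two positions are padding

for mode in ("recurrent_layer", "linear_layer", "no_layer"):
    feats = mean_merge(reservoir_states(u, mode, W_in, W, b), mask)
    print(mode, feats.shape)  # (20,) for both reservoir modes, (8,) for no_layer
```

The merged vector is what a sentence-level readout such as ridge regression would see: its size is ``dim`` with a reservoir and ``input_dim`` without one.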
In [11]:
# ESN parameters
esn_params = {
            'embedding_weights': 'bert-base-uncased',
            'input_dim': 768,  # dim of BERT encoding!
            'dim': 300,
            'bias_scaling': 0.1,
            'sparsity': 0.9,
            'spectral_radius': 0.9,
            'leaking_rate': 0.5,
            'activation_function': 'tanh',
            'input_scaling': 0.1,
            'mean': 0.0,
            'std': 1.0,     
            'learning_algo': None,       # initialized below
            'criterion': None,           # initialized below (only for learning algos trained with SGD)
            'optimizer': None,           # initialized below (only for learning algos trained with SGD)
            'merging_strategy': 'mean',
            'bidirectional': False,
            'mode' : 'recurrent_layer',  # 'no_layer', 'linear_layer', 'recurrent_layer'
            'device': device,
            'seed': 42345
            }

# Instantiate the ESN
ESN = esn.EchoStateNetwork(**esn_params)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model downloaded: bert-base-uncased
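The ``sparsity``, ``spectral_radius``, ``input_scaling``, ``bias_scaling``, ``mean`` and ``std`` parameters govern how the reservoir's random weights are drawn. A common recipe, sketched below in NumPy, is to draw a Gaussian matrix, zero out a fraction ``sparsity`` of its entries, and rescale it so that its largest absolute eigenvalue equals ``spectral_radius``. This illustrates the general technique under assumptions, not EsnTorch's code: the helper ``init_reservoir`` is hypothetical, and the library may use a different distribution or interpret ``sparsity`` differently.

```python
import numpy as np

def init_reservoir(input_dim, dim, sparsity=0.9, spectral_radius=0.9,
                   input_scaling=0.1, bias_scaling=0.1, mean=0.0, std=1.0, seed=42345):
    """Hypothetical reservoir initialization: returns (W_in, W, b)."""
    rng = np.random.default_rng(seed)
    # Recurrent weights: Gaussian entries, then a fraction `sparsity` set to zero
    W = rng.normal(mean, std, size=(dim, dim))
    W[rng.random((dim, dim)) < sparsity] = 0.0
    # Rescale so the largest |eigenvalue| equals the target spectral radius
    W *= spectral_radius / np.max(np.abs(np.linalg.eigvals(W)))
    # Input weights and bias, shrunk by their scaling factors
    W_in = input_scaling * rng.normal(mean, std, size=(dim, input_dim))
    b = bias_scaling * rng.normal(mean, std, size=dim)
    return W_in, W, b

W_in, W, b = init_reservoir(input_dim=768, dim=300)
print(W_in.shape, W.shape, b.shape)                    # (300, 768) (300, 300) (300,)
print(round(np.max(np.abs(np.linalg.eigvals(W))), 6))  # 0.9
print(round(float(np.mean(W == 0)), 2))                # close to 0.9
```

Keeping the spectral radius below 1 is the usual heuristic for the echo state property, so that the influence of early tokens fades over time instead of being amplified.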


### Learning algorithm

In [12]:
# Ridge regression
ESN.learning_algo = la.RidgeRegression(alpha=7.843536845714804)
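Ridge regression fits the readout in closed form rather than by gradient descent, which is why no ``criterion`` or ``optimizer`` is needed. The sketch below solves ``W = (X^T X + alpha I)^(-1) X^T Y`` on toy features and predicts the class with the largest score. It assumes one-hot targets and an appended bias column that, for simplicity, is regularized too; EsnTorch's ``RidgeRegression`` may handle these details differently. The helpers ``ridge_fit`` and ``ridge_predict`` are hypothetical.

```python
import numpy as np

def ridge_fit(X, y, n_classes, alpha=1.0):
    """Closed-form ridge readout on features X (n, d) with integer labels y (n,)."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])  # append a bias column
    Y = np.eye(n_classes)[y]                       # one-hot targets (n, n_classes)
    A = Xb.T @ Xb + alpha * np.eye(Xb.shape[1])
    return np.linalg.solve(A, Xb.T @ Y)            # solve the system instead of inverting A

def ridge_predict(X, W):
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    return np.argmax(Xb @ W, axis=1)

# Toy 300-dim "merged reservoir states" for six classes, like TREC-6's coarse labels
rng = np.random.default_rng(0)
centers = rng.normal(size=(6, 300)) * 3
y = rng.integers(0, 6, size=600)
X = centers[y] + rng.normal(size=(600, 300))

W = ridge_fit(X, y, n_classes=6, alpha=7.84)
print((ridge_predict(X, W) == y).mean())  # training accuracy on the toy data
```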

In [13]:
ESN = ESN.to(device)

### Warm up

In [14]:
# Warm up the ESN on multiple sentences
nb_sentences = 10

for i in range(nb_sentences): 
    sentence_ds = dataset_d["train"].select([i])
    dataloader_tmp = torch.utils.data.DataLoader(sentence_ds, 
                                                 batch_size=1, 
                                                 collate_fn=DataCollatorWithPadding(tokenizer))  

    for batch in dataloader_tmp:
        ESN.warm_up(batch)

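Warming up addresses the fact that a freshly initialized reservoir starts from an arbitrary (typically zero) state. A common ESN practice, sketched below in NumPy, is to drive the reservoir with a few inputs and reuse the resulting state as the initial state for later sequences. This is an assumption about what ``warm_up`` does conceptually, not its implementation; ``leaky_step`` and ``warm_up_state`` are hypothetical helpers, and the "sentences" are random embeddings.

```python
import numpy as np

def leaky_step(x, u_t, W_in, W, b, leaking_rate=0.5):
    """One leaky-integrator reservoir update with a tanh activation."""
    return (1 - leaking_rate) * x + leaking_rate * np.tanh(W_in @ u_t + W @ x + b)

def warm_up_state(sequences, W_in, W, b, leaking_rate=0.5):
    """Run the reservoir over warm-up sequences and return the final state."""
    x = np.zeros(W.shape[0])
    for seq in sequences:
        for u_t in seq:
            x = leaky_step(x, u_t, W_in, W, b, leaking_rate)
    return x

rng = np.random.default_rng(1)
dim, input_dim = 50, 16
W_in = rng.uniform(-0.1, 0.1, size=(dim, input_dim))
W = rng.normal(size=(dim, dim))
W *= 0.9 / np.max(np.abs(np.linalg.eigvals(W)))  # spectral radius 0.9
b = rng.uniform(-0.1, 0.1, size=dim)

# Ten hypothetical "sentences" of random embeddings, mirroring nb_sentences = 10
warm_up_sentences = [rng.normal(size=(rng.integers(5, 15), input_dim)) for _ in range(10)]
x0 = warm_up_state(warm_up_sentences, W_in, W, b)
print(np.linalg.norm(x0) > 0)  # the initial state is no longer all zeros
```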
## Training

In [15]:
%%time
# training the ESN
ESN.fit(dataloader_d["train"])

CPU times: user 2min 33s, sys: 9.57 s, total: 2min 42s
Wall time: 2min 16s


In [16]:
# %%time
# # training the ESN (Logistic Regression, gradient descent)
# ESN.fit(dataloader_d["train"], epochs=10, iter_steps=10)
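The commented-out cell refers to a readout trained by gradient descent (logistic regression) instead of the closed-form ridge solution; such a learning algorithm would need ``criterion`` and ``optimizer`` to be set. As a rough, library-independent illustration, here is multinomial logistic regression trained by full-batch gradient descent on fixed features in NumPy. The helper ``softmax_regression``, its learning rate and its epoch count are hypothetical and say nothing about the meaning of ``epochs`` or ``iter_steps`` in EsnTorch.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def softmax_regression(X, y, n_classes, lr=0.01, epochs=500):
    """Full-batch gradient descent on the mean cross-entropy loss."""
    n, d = X.shape
    W = np.zeros((d, n_classes))
    b = np.zeros(n_classes)
    Y = np.eye(n_classes)[y]
    for _ in range(epochs):
        P = softmax(X @ W + b)
        grad = (P - Y) / n           # gradient of the mean loss w.r.t. the logits
        W -= lr * X.T @ grad
        b -= lr * grad.sum(axis=0)
    return W, b

# Toy features for six classes
rng = np.random.default_rng(0)
centers = rng.normal(size=(6, 20)) * 3
y = rng.integers(0, 6, size=300)
X = centers[y] + rng.normal(size=(300, 20))

W, b = softmax_regression(X, y, n_classes=6)
print((np.argmax(X @ W + b, axis=1) == y).mean())  # training accuracy on the toy data
```

Unlike the ridge solution, this readout has to be iterated, which is the trade-off behind choosing between the two learning algorithms.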

## Evaluation

In [17]:
# Train predictions and accuracy
train_pred, train_acc = ESN.predict(dataloader_d["train"], verbose=False)

AttributeError: 'float' object has no attribute 'item'

In [None]:
# Train accuracy (convert to a Python float if predict returned a tensor)
train_acc = train_acc.item() if torch.is_tensor(train_acc) else train_acc
train_acc

In [None]:
# Test predictions
test_pred, test_acc = ESN.predict(dataloader_d["test"], verbose=False)

In [None]:
# Test accuracy (convert to a Python float if predict returned a tensor)
test_acc = test_acc.item() if torch.is_tensor(test_acc) else test_acc
test_acc

In [None]:
# Test classification report
print(classification_report(dataset_d['test']['labels'].tolist(),  # ground truth first
                            test_pred.tolist(),                    # predictions second
                            digits=4))
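``classification_report`` expects the ground-truth labels first and the predictions second; swapping them exchanges each class's precision and recall (and changes the reported support). The pure-Python sketch below computes accuracy and per-class precision, recall and F1 from their standard definitions, which makes that asymmetry visible. The helper ``per_class_metrics`` and the toy labels are hypothetical.

```python
def per_class_metrics(y_true, y_pred):
    """Accuracy plus per-class (precision, recall, F1) from two label lists."""
    labels = sorted(set(y_true) | set(y_pred))
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    report = {}
    for k in labels:
        tp = sum(t == k and p == k for t, p in zip(y_true, y_pred))
        n_pred = sum(p == k for p in y_pred)  # samples predicted as k
        n_true = sum(t == k for t in y_true)  # samples whose true label is k
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true if n_true else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[k] = (precision, recall, f1)
    return accuracy, report

y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]
acc, rep = per_class_metrics(y_true, y_pred)
print(acc)     # 0.8333333333333334
print(rep[0])  # class 0: precision 1.0, recall 2/3
print(per_class_metrics(y_pred, y_true)[1][0])  # arguments swapped: precision and recall swap
```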