# 🏷️ How to label your data and fine-tune a 🤗 sentiment classifier

This tutorial will show you how to fine-tune a sentiment classifier for your own domain starting with no labeled data.

Most online tutorials about fine-tuning models assume you already have a training dataset. You'll find many tutorials showing you how to fine-tune a pre-trained model with widely-used datasets, such as IMDB for sentiment analysis. 

However, very often **what you want is to fine-tune a model for your use case**. It's well-known that NLP model performance degrades with "out-of-domain" data. For example, a sentiment classifier trained on movie reviews (e.g., IMDB) will often perform poorly on customer requests.

In this tutorial, we will build a sentiment classifier for user requests in the banking domain. To do this, we will: 

- Start with the most popular sentiment classifier on the Hugging Face Hub (2.3 million downloads as of July 2021), which has been fine-tuned on the SST-2 sentiment dataset.

- Label a training dataset with banking user requests starting with the pre-trained sentiment classifier predictions.

- Fine-tune the pre-trained classifier with your training dataset.

- Label more data by correcting the predictions of the fine-tuned model.


![Labeling workflow](img/labeling_tutorial/workflow.svg "Labeling workflow")

## Preliminaries

### Dataset: BANKING77


In [7]:
from datasets import load_dataset

banking_ds = load_dataset("banking77") ; banking_ds



DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 10003
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 3080
    })
})

In [22]:
to_label1, to_label2 = banking_ds['train'].train_test_split(test_size=0.5).values()



### Model: DistilBERT sentiment classifier fine-tuned on SST-2

In [23]:
from transformers import pipeline

sentiment_classifier = pipeline(
    model="distilbert-base-uncased-finetuned-sst-2-english",
    task="sentiment-analysis", 
    return_all_scores=True,
)

## Run pre-trained model on raw dataset

In [24]:
import rubrix as rb

In [25]:
def predict(examples):
    return {"predictions": sentiment_classifier(examples['text'], truncation=True)}

In [26]:
to_label1 = to_label1.map(predict, batched=True, batch_size=4)  





In [29]:
records = []
for example in to_label1.shuffle():
    record = rb.TextClassificationRecord(
        inputs=example["text"],
        metadata={'category': example['label']},
        prediction=[(pred['label'], pred['score']) for pred in example['predictions']],
        prediction_agent="distilbert-base-uncased-finetuned-sst-2-english"
    )
    records.append(record)



In [30]:
rb.log(name='labeling_with_pretrained', records=records)

BulkResponse(dataset='labeling_with_pretrained', processed=5001, failed=0)
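The `prediction` field of each record above is built as a list of `(label, score)` tuples from the per-class output of the pipeline. The following is a minimal sketch of just that conversion in plain Python, so it runs without a model or a Rubrix server. The helper name `to_prediction_tuples` is hypothetical, and the scores are copied from the pre-trained model's output shown later in this tutorial.

```python
# Stand-in for one element of the pipeline output when
# return_all_scores=True: one dict per class.
example_predictions = [
    {"label": "NEGATIVE", "score": 0.9992493987083435},
    {"label": "POSITIVE", "score": 0.0007506058318540454},
]


def to_prediction_tuples(predictions):
    """Convert [{'label': ..., 'score': ...}, ...] into [(label, score), ...]."""
    return [(pred["label"], pred["score"]) for pred in predictions]


# Hypothetical helper; the tutorial does the same thing inline
# with a list comprehension.
prediction = to_prediction_tuples(example_predictions)
print(prediction)
```

Keeping the score of every class, rather than only the winning label, is what makes it possible to sort and filter records by model confidence while labeling.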

## Explore and label data with pretrained model


### Labeling random examples

![labeling](https://github.com/dvsrepo/imgs/raw/main/labeling_tutorial/1.gif "labeling")



### Labeling POSITIVE examples

![labeling](https://github.com/dvsrepo/imgs/raw/main/labeling_tutorial/2.gif "labeling")

## Fine-tune pre-trained model

In [31]:
rb_df = rb.load(name='labeling_with_pretrained')

In [37]:
rb_df = rb_df[rb_df.status == "Validated"] ; len(rb_df)

229

In [38]:
rb_df.head()

| | inputs | prediction | annotation | prediction_agent | annotation_agent | multi_label | explanation | id | metadata | status | event_timestamp | text |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 4771 | {'text': 'I saw there is a cash withdrawal fro... | [(NEGATIVE, 0.9997006654739381), (POSITIVE, 0.... | [NEGATIVE] | distilbert-base-uncased-finetuned-sst-2-english | .local-Rubrix | False | | 0001e324-3247-4716-addc-d9d9c83fd8f9 | {'category': 20} | Validated | | I saw there is a cash withdrawal from my accou... |
| 4772 | {'text': 'Why is it showing that my account ha... | [(NEGATIVE, 0.9991878271102901), (POSITIVE, 0.... | [NEGATIVE] | distilbert-base-uncased-finetuned-sst-2-english | .local-Rubrix | False | | 0017e5c9-c135-44b9-8efb-a17ffecdbe68 | {'category': 34} | Validated | | Why is it showing that my account has been cha... |
| 4773 | {'text': 'I thought I lost my card but I found... | [(POSITIVE, 0.9842885732650751), (NEGATIVE, 0.... | [POSITIVE] | distilbert-base-uncased-finetuned-sst-2-english | .local-Rubrix | False | | 0048ccce-8c9f-453d-81b1-a966695e579c | {'category': 13} | Validated | | I thought I lost my card but I found it today,... |
| 4774 | {'text': 'I wanted to top up my account and it... | [(NEGATIVE, 0.999732434749603), (POSITIVE, 0.0... | [NEGATIVE] | distilbert-base-uncased-finetuned-sst-2-english | .local-Rubrix | False | | 0046aadc-2344-40d2-a930-81f00687bf44 | {'category': 59} | Validated | | I wanted to top up my account and it doesn't l... |
| 4775 | {'text': 'I need to deposit my virtual card, h... | [(NEGATIVE, 0.9992493987083431), (POSITIVE, 0.... | [POSITIVE] | distilbert-base-uncased-finetuned-sst-2-english | .local-Rubrix | False | | 00071745-741d-4555-82b3-54d25db44c38 | {'category': 37} | Validated | | I need to deposit my virtual card, how do i do... |


In [40]:
from datasets import Dataset

rb_df['text'] = rb_df.inputs.transform(lambda r: r['text'])

rb_df['labels'] = rb_df.annotation.transform(lambda r: r[0])

label2id = {"NEGATIVE": 0, "POSITIVE": 1}


# create 🤗 dataset from pandas with labels as numeric ids
train_ds = Dataset.from_pandas(rb_df[['text', 'labels']])
train_ds = train_ds.map(lambda example: {'labels': label2id[example['labels']]})
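The cell above takes the first annotation of each record and maps it to a numeric id. Here is a minimal sketch of that mapping on plain Python lists, using the five annotations visible in the preview table. The `id2label` dictionary and the variable names are only for illustration.

```python
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
# Inverse mapping, handy for turning predicted ids back into label names.
id2label = {idx: label for label, idx in label2id.items()}

# Each annotation is a list of labels; for single-label data we keep the first.
annotations = [["NEGATIVE"], ["NEGATIVE"], ["POSITIVE"], ["NEGATIVE"], ["POSITIVE"]]

label_ids = [label2id[annotation[0]] for annotation in annotations]
print(label_ids)  # [0, 0, 1, 0, 1]

# Round trip back to label names as a sanity check.
round_trip = [id2label[i] for i in label_ids]
print(round_trip)
```

The ids have to agree with the order of the classification head we are going to fine-tune. We expect this checkpoint to use 0 for NEGATIVE and 1 for POSITIVE, but it is worth confirming with `model.config.id2label` once the model is loaded.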





In [41]:
train_ds = train_ds.train_test_split(test_size=0.2) ; train_ds

DatasetDict({
    train: Dataset({
        features: ['__index_level_0__', 'labels', 'text'],
        num_rows: 183
    })
    test: Dataset({
        features: ['__index_level_0__', 'labels', 'text'],
        num_rows: 46
    })
})
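The split sizes reported above (183 training and 46 evaluation examples out of 229) can be reproduced with a little arithmetic. The sketch below assumes the scikit-learn-style rounding rule, that is, the test size is rounded up and the training size is rounded down. That rule is an assumption about how `train_test_split` behaves, and `split_sizes` is a hypothetical helper rather than part of any library.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumed rule: ceil for the test split, floor for the train split.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = math.floor((1 - test_size) * n_samples)
    return n_train, n_test


print(split_sizes(229, 0.2))    # (183, 46)
print(split_sizes(10003, 0.5))  # (5001, 5002)
```

The second call matches the 50/50 split of the 10003 BANKING77 training examples used earlier, where 5001 records were logged to Rubrix.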

In [46]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
  
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

In [51]:
import numpy as np
from transformers import Trainer
from datasets import load_metric
from transformers import TrainingArguments

training_args = TrainingArguments(
    "trainer_experiment", 
    evaluation_strategy="epoch",
    logging_steps=30
)

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

train_dataset = train_ds['train'].map(tokenize_function, batched=True).shuffle(seed=42)
eval_dataset = train_ds['test'].map(tokenize_function, batched=True).shuffle(seed=42)

trainer = Trainer(
    args=training_args,
    model=model, 
    train_dataset=train_dataset, 
    eval_dataset=eval_dataset, 
    compute_metrics=compute_metrics,
)

trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Runtime | Samples Per Second |
|---|---|---|---|---|---|
| 1 | No log | 0.520485 | 0.76087 | 27.0565 | 1.7 |
| 2 | No log | 0.632395 | 0.73913 | 32.5403 | 1.414 |
| 3 | No log | 0.655101 | 0.76087 | 29.5319 | 1.558 |


TrainOutput(global_step=69, training_loss=0.39967987502830615, metrics={'train_runtime': 1322.9033, 'train_samples_per_second': 0.052, 'total_flos': 112921499105280, 'epoch': 3.0})
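The `compute_metrics` function used above picks the highest-scoring class for each example and compares it with the reference label. As a rough illustration of what that computes, here is a dependency-free sketch. The logits and labels are made up, and the `argmax` and `accuracy` helpers are hypothetical stand-ins for `np.argmax` and the accuracy metric.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits, labels):
    """Fraction of examples whose top-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Made-up logits for four examples and two classes (0=NEGATIVE, 1=POSITIVE).
logits = [[2.1, -1.3], [-0.4, 0.9], [1.5, 0.2], [-2.0, 3.1]]
labels = [0, 1, 1, 1]

print(accuracy(logits, labels))  # 0.75
```

With only 46 evaluation examples, each prediction moves the accuracy by about two points: the reported 0.76087 corresponds to 35 correct predictions out of 46, and 0.73913 to 34.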

In [53]:
model.save_pretrained('distilbert-base-uncased-banking77-sentiment')

## Testing the fine-tuned model and preparing a new dataset for labeling

In [54]:
ft_sentiment_classifier = pipeline(
    model=model, 
    tokenizer=tokenizer, 
    task="sentiment-analysis", 
    return_all_scores=True
)

In [58]:
ft_sentiment_classifier(
    'I need to deposit my virtual card, how do i do that.'
), sentiment_classifier(
    'I need to deposit my virtual card, how do i do that.'
)

([[{'label': 'NEGATIVE', 'score': 0.0009169924887828529},
   {'label': 'POSITIVE', 'score': 0.9990829825401306}]],
 [[{'label': 'NEGATIVE', 'score': 0.9992493987083435},
   {'label': 'POSITIVE', 'score': 0.0007506058318540454}]])
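To compare the two models at a glance, we can reduce each per-class output to its top label. This is a small sketch in plain Python using the scores printed above. The helper `top_label` is hypothetical and only needs the list-of-dicts shape that the pipelines return.

```python
def top_label(scores):
    """Return (label, score) for the highest-scoring class."""
    best = max(scores, key=lambda item: item["score"])
    return best["label"], best["score"]


# Scores copied from the output above for the same input sentence.
fine_tuned_scores = [
    {"label": "NEGATIVE", "score": 0.0009169924887828529},
    {"label": "POSITIVE", "score": 0.9990829825401306},
]
pretrained_scores = [
    {"label": "NEGATIVE", "score": 0.9992493987083435},
    {"label": "POSITIVE", "score": 0.0007506058318540454},
]

fine_tuned_label, _ = top_label(fine_tuned_scores)
pretrained_label, _ = top_label(pretrained_scores)
models_disagree = fine_tuned_label != pretrained_label

print(fine_tuned_label, pretrained_label, models_disagree)
```

Note that this sentence appears among the validated records shown earlier with a POSITIVE annotation, so the fine-tuned model may have seen it during training. It is best read as a sanity check rather than as evidence of generalization.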

In [59]:
rb_df = rb.load(name='labeling_with_pretrained')
rb_df = rb_df[rb_df.status == "Default"] ; len(rb_df)

4771

In [63]:
rb_df['text'] = rb_df.inputs.transform(lambda r: r['text'])

In [64]:
ds = Dataset.from_pandas(rb_df[['text']])

In [65]:
def predict(examples):
    return {"predictions": ft_sentiment_classifier(examples['text'])}

In [66]:
ds = ds.map(predict, batched=True, batch_size=8) 





In [67]:
records = []
for example in ds.shuffle():
    record = rb.TextClassificationRecord(
        inputs=example["text"],
        prediction=[(pred['label'], pred['score']) for pred in example['predictions']],
        prediction_agent="distilbert-base-uncased-banking77-sentiment"
    )
    records.append(record)

In [68]:
rb.log(name='labeling_with_finetuned', records=records)

BulkResponse(dataset='labeling_with_finetuned', processed=4771, failed=0)



## Explore and label data with the fine-tuned model

![labeling](https://github.com/dvsrepo/imgs/raw/main/labeling_tutorial/3.gif "labeling")


## Fine-tuning with more data

In [72]:
def prepare_train_df(dataset_name):
    # Load the Rubrix dataset and keep only the manually validated records
    rb_df = rb.load(name=dataset_name)
    rb_df = rb_df[rb_df.status == "Validated"].copy()
    # Extract the raw text and the first (single-label) annotation
    rb_df['text'] = rb_df.inputs.transform(lambda r: r['text'])
    rb_df['labels'] = rb_df.annotation.transform(lambda r: r[0])
    return rb_df

In [98]:
df = prepare_train_df('labeling_with_finetuned') ; len(df)

83

In [78]:
train_dataset

Dataset({
    features: ['__index_level_0__', 'attention_mask', 'input_ids', 'labels', 'text'],
    num_rows: 183
})

In [85]:
train_dataset = train_dataset.remove_columns('__index_level_0__')

In [96]:
for i,r in df.iterrows():
    tokenization = tokenizer(r["text"], padding="max_length", truncation=True)
    train_dataset = train_dataset.add_item({
        "attention_mask": tokenization["attention_mask"],
        "input_ids": tokenization["input_ids"],
        "labels": label2id[r['labels']],
        "text": r['text'],
    })

In [97]:
train_dataset

Dataset({
    features: ['attention_mask', 'input_ids', 'labels', 'text'],
    num_rows: 266
})

In [99]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

In [100]:
train_dataset = train_dataset.shuffle(seed=42)

trainer = Trainer(
    args=training_args,
    model=model, 
    train_dataset=train_dataset, 
    eval_dataset=eval_dataset, 
    compute_metrics=compute_metrics,
)

trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | Runtime | Samples Per Second |
|---|---|---|---|---|---|
| 1 | 0.9796 | 0.362587 | 0.847826 | 26.9248 | 1.708 |
| 2 | 0.3118 | 0.361794 | 0.869565 | 28.7295 | 1.601 |
| 3 | 0.1507 | 0.44119 | 0.847826 | 27.1541 | 1.694 |




TrainOutput(global_step=102, training_loss=0.4260212887151569, metrics={'train_runtime': 1794.0301, 'train_samples_per_second': 0.057, 'total_flos': 164137260994560, 'epoch': 3.0})
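The `global_step` values reported by the two training runs (69 and 102) follow from the training set sizes. The sketch below assumes a batch size of 8 and 3 epochs, which we believe are the `TrainingArguments` defaults, on a single device without gradient accumulation. `total_steps` is a hypothetical helper, not a `transformers` function.

```python
import math


def total_steps(n_examples, batch_size=8, epochs=3):
    """Optimizer steps = batches per epoch (last one may be partial) x epochs."""
    return math.ceil(n_examples / batch_size) * epochs


print(total_steps(183))  # 69: first run, 183 training examples
print(total_steps(266))  # 102: second run, 183 + 83 = 266 examples
```

Both values agree with the outputs above, which is a quick way to confirm that the 83 newly validated records were actually added to the training set.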
