# Chapter 8 - Text Classification in Flair

This Jupyter notebook is a companion resource that makes it easier to follow the code examples from the book. It covers all practical code snippets and exercises found in Chapter 8 - Text Classification in Flair.

## Text classification in Flair

### Using Flair pre-trained text classification models

In [None]:
from flair.models import TextClassifier
from flair.data import Sentence

# Load the text classifier registered under the name "sentiment".
classifier = TextClassifier.load("sentiment")
sentence = Sentence("Flair is pretty cool!")
# The return value of predict() is not used; the outcome is read back from
# the sentence object below.
classifier.predict(sentence)

print(sentence.labels)

In [None]:
# Second example; reuses `classifier` and `Sentence` from the previous cell.
sentence = Sentence("I don't get enough sleep.")
classifier.predict(sentence)

print(sentence.labels)

### Document embeddings in Flair

#### DocumentPoolEmbeddings

In [None]:
from flair.data import Sentence
from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings

# Word-level embeddings loaded under the name 'glove'.
glove = WordEmbeddings('glove')
# Not used further in this cell; kept so the name exists once the cell has run.
document_embeddings = DocumentPoolEmbeddings([glove])

# Embed the individual tokens of a two-token sentence.
sentence = Sentence('two words')
glove.embed(sentence)

# Average the two token vectors by hand and show the result.
first_token = sentence[0]
second_token = sentence[1]
manual_average = (first_token.embedding + second_token.embedding) / 2
print(manual_average)

In [None]:
# Build a document-level embedding on top of `glove` and apply it to the
# `sentence` created in the previous cell.
document_embeddings = DocumentPoolEmbeddings([glove])
document_embeddings.embed(sentence)

# The embedding is read from the sentence itself rather than from its tokens.
print(sentence.embedding)

#### TransformerDocumentEmbeddings

In [None]:
from flair.embeddings import TransformerDocumentEmbeddings

# Document embeddings created from the 'roberta-base' model name.
embedding = TransformerDocumentEmbeddings('roberta-base')
# `Sentence` is not imported in this cell; it comes from an earlier cell.
sentence = Sentence('two words')

embedding.embed(sentence)

print(sentence.embedding)

## Training a text classifier in Flair

### Loading a tagged corpus

In [None]:
from flair.data import Corpus
from flair.datasets import IMDB
import _locale

# fix for rare Flair encoding bug - only needed on Windows
# NOTE(review): this overwrites a private attribute of the `_locale` module and
# is applied unconditionally, not only on Windows. It presumably has to run
# before the corpus is loaded below - confirm before reordering.
_locale._getdefaultlocale = (lambda *_: ['en_US', 'utf8'])

corpus = IMDB()
# Downsample with factor 0.05 (presumably keeps about 5% of the data - confirm
# against the Flair docs).
corpus.downsample(0.05)
# The label type is kept in a variable because later cells reuse it.
lbl_type = 'sentiment'
label_dict = corpus.make_label_dictionary(label_type=lbl_type)

#### Loading and preparing the document embeddings

In [None]:
from flair.embeddings import (DocumentPoolEmbeddings,
                              WordEmbeddings)

# Document embeddings built from a single 'glove' word-embedding source;
# the classifier cell that follows passes them to TextClassifier.
glove = WordEmbeddings('glove')
document_embeddings = DocumentPoolEmbeddings([glove])

#### Initializing the TextClassifier class

In [None]:
from flair.models import TextClassifier

# Combine the document embeddings with the label dictionary and label type
# defined in the earlier corpus-loading cell.
classifier = TextClassifier(document_embeddings,
                            label_dictionary=label_dict,
                            label_type=lbl_type)

#### Training the model

In [None]:
from flair.trainers import ModelTrainer

# Hyperparameters for this training run, gathered in one place.
train_options = {
    'learning_rate': 0.1,
    'mini_batch_size': 32,
    'max_epochs': 40,
}

# Train `classifier` on `corpus`; 'classifier' is the base path handed to
# train(), and a later cell loads 'classifier/best-model.pt' from it.
trainer = ModelTrainer(classifier, corpus)
trainer.train('classifier', **train_options)

#### Loading and using custom text classifiers

In [None]:
from flair.models import TextClassifier

# Load a model file from the 'classifier' directory, which is the base path
# passed to trainer.train() in the training cell.
classifier = TextClassifier.load('classifier/best-model.pt')

In [None]:
from flair.data import Sentence

# Classify a one-word input with the loaded custom classifier.
sentence = Sentence("great")
classifier.predict(sentence)

print(sentence.labels)

In [None]:
# Same check with a different one-word input.
sentence = Sentence("bad")
classifier.predict(sentence)

print(sentence.labels)

In [None]:
# A multi-word input that contains neither "great" nor "bad".
sentence = Sentence("Not quite my cup of tea")
classifier.predict(sentence)

print(sentence.labels)

## Text classifiers that require little to no training data

### Zero-shot classification with TARS

#### Loading the pre-trained TARS model for the right language

In [None]:
from flair.models import TARSClassifier

# Load the TARS model registered under the name 'tars-base'.
tars = TARSClassifier.load('tars-base')

#### Defining the classes

In [None]:
# Candidate class names passed to tars.predict_zero_shot() in the cells below.
classes = ["positive", "negative"]

#### Generating zero-shot predictions with TARS 

In [None]:
from flair.data import Sentence 

sentence = Sentence("great")
# Zero-shot prediction: the candidate classes are supplied with the call.
tars.predict_zero_shot(sentence, classes)

# Prints the whole sentence object rather than only sentence.labels.
print(sentence)

In [None]:
# Same zero-shot call with a different one-word input.
sentence = Sentence("bad")
tars.predict_zero_shot(sentence, classes)

print(sentence)

In [None]:
# Zero-shot call on a multi-word input.
sentence = Sentence("Not quite my cup of tea")
tars.predict_zero_shot(sentence, classes)

print(sentence)

### Few-shot classification with TARS (requires GPU)

#### Loading a tagged corpus

In [None]:
from flair.data import Corpus
from flair.datasets import IMDB
import _locale

# fix for rare Flair encoding bug - only needed on Windows
# NOTE(review): this overwrites a private attribute of the `_locale` module and
# is applied unconditionally, not only on Windows. It presumably has to run
# before the corpus is loaded below - confirm before reordering.
_locale._getdefaultlocale = (lambda *_: ['en_US', 'utf8'])

corpus = IMDB()
# Downsample with factor 0.01 (presumably keeps about 1% of the data - confirm
# against the Flair docs); the earlier classifier section used 0.05.
corpus.downsample(0.01)
# Rebinds `lbl_type` and `label_dict` for the few-shot cells that follow.
lbl_type = 'sentiment'
label_dict = corpus.make_label_dictionary(label_type=lbl_type)

#### Loading a TARS model and setting the task

In [None]:
from flair.models import TARSClassifier
from flair.data import Sentence

# Call load() on the class itself, as the zero-shot section does.
# TARSClassifier().load(...) would first construct a default model
# only to discard it.
tars = TARSClassifier.load('tars-base')

# Register the sentiment task using the label dictionary and label type from
# the corpus-loading cell, and make it the model's active task.
tars.add_and_switch_to_new_task(
    task_name="sentiment_analysis",
    label_dictionary=label_dict,
    label_type=lbl_type)

#### Training few-shot TARS models

In [None]:
from flair.trainers import ModelTrainer

# Hyperparameters for the few-shot fine-tuning run, gathered in one place.
few_shot_options = {
    'learning_rate': 0.02,
    'mini_batch_size': 16,
    'mini_batch_chunk_size': 4,
    'max_epochs': 6,
}

# Fine-tune `tars` on `corpus`; output goes under 'few-shot-tars', from which
# a later cell loads 'few-shot-tars/best-model.pt'.
trainer = ModelTrainer(tars, corpus)
trainer.train(base_path='few-shot-tars', **few_shot_options)

#### Using custom few-shot TARS models

In [None]:
from flair.models import TARSClassifier

# Load the fine-tuned model from the 'few-shot-tars' training directory.
# load() is called on the class directly, so no throwaway default
# TARSClassifier instance is constructed first.
tars = TARSClassifier.load('few-shot-tars/best-model.pt')

In [None]:
from flair.data import Sentence

sentence = Sentence("great")
# Use plain predict(), like the other examples in this section, instead of
# predict_zero_shot(): the few-shot model is queried without an ad-hoc class
# list, so this cell no longer depends on `classes` from the zero-shot section.
tars.predict(sentence)

print(sentence)

In [None]:
# Classify with the loaded few-shot model; no candidate class list is passed.
sentence = Sentence("bad")
tars.predict(sentence)

print(sentence)

In [None]:
# Multi-word input, classified the same way as the previous example.
sentence = Sentence("Not quite my cup of tea")
tars.predict(sentence)

print(sentence)