In [115]:
import tempfile
from typing import Dict, Iterable, List, Tuple

import allennlp
import torch
from allennlp.data import DataLoader, DatasetReader, Instance, Vocabulary
from allennlp.data.fields import LabelField, TextField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WhitespaceTokenizer
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder, Seq2VecEncoder
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.nn import util
from allennlp.training.trainer import GradientDescentTrainer, Trainer
from allennlp.training.optimizers import AdamOptimizer
from allennlp.training.metrics import CategoricalAccuracy

from os.path import join as pathjoin
import pandas as pd
from allennlp.predictors import TextClassifierPredictor
from allennlp.training.metrics import CategoricalAccuracy
import numpy as np

Let's download all the data.

In [147]:
DATA_DIR = '/home/mlepekhin/data'
MODELS_DIR = '/home/mlepekhin/models'
MODEL_ID = 'allennlp_simple'

In [99]:
class ClassificationDatasetReader(DatasetReader):
    def __init__(self,
                 lazy: bool = False,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_tokens: int = None):
        super().__init__(lazy)
        self.tokenizer = tokenizer or WhitespaceTokenizer()
        self.token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
        self.max_tokens = max_tokens
        
    def text_to_instance(self, strings: List[str], label: str = None) -> Instance:
        tokens = self.tokenizer.tokenize(strings)
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"text": sentence_field}
        if label is not None:
            fields["label"] = LabelField(label)
        return Instance(fields)

    def _read(self, file_path: str) -> Iterable[Instance]:
        dataset_df = pd.read_csv(file_path)
        for text, label in zip(dataset_df['text'], dataset_df['target']):
            tokens = self.tokenizer.tokenize(text)
            text_field = TextField(tokens, self.token_indexers)
            yield Instance({'text': text_field, 'label': LabelField(label)})

In [104]:
class SimpleClassifier(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 embedder: TextFieldEmbedder,
                 encoder: Seq2VecEncoder):
        super().__init__(vocab)
        self.embedder = embedder
        self.encoder = encoder
        num_labels = vocab.get_vocab_size("labels")
        self.classifier = torch.nn.Linear(encoder.get_output_dim(), num_labels)
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                text: Dict[str, torch.Tensor],
                label: torch.Tensor=None) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(text)
        # Shape: (batch_size, num_tokens)
        mask = util.get_text_field_mask(text)
        # Shape: (batch_size, encoding_dim)
        encoded_text = self.encoder(embedded_text, mask)
        # Shape: (batch_size, num_labels)
        logits = self.classifier(encoded_text)
        # Shape: (batch_size, num_labels)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        if label is not None:
            # Shape: (1,)
            loss = torch.nn.functional.cross_entropy(logits, label)
            self.accuracy(logits, label)
            return {'loss': loss, 'probs': probs}
        else:
            return {'probs': probs}
    
    def get_metrics(self, reset: bool = True) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}


def build_dataset_reader() -> DatasetReader:
    return ClassificationDatasetReader()


def read_data(reader: DatasetReader) -> Tuple[Iterable[Instance], Iterable[Instance]]:
    print("Reading data")
    training_data = reader.read(pathjoin(DATA_DIR, "multi_train"))
    validation_data = reader.read(pathjoin(DATA_DIR, "multi_test"))
    return training_data, validation_data


def build_vocab(instances: Iterable[Instance]) -> Vocabulary:
    print("Building the vocabulary")
    return Vocabulary.from_instances(instances)


def build_model(vocab: Vocabulary) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(embedding_dim=10, num_embeddings=vocab_size)})
    encoder = BagOfEmbeddingsEncoder(embedding_dim=10)
    return SimpleClassifier(vocab, embedder, encoder)


def run_training_loop():
    dataset_reader = build_dataset_reader()

    # These are a subclass of pytorch Datasets, with some allennlp-specific
    # functionality added.
    train_data, dev_data = read_data(dataset_reader)

    vocab = build_vocab(train_data + dev_data)
    model = build_model(vocab)

    # This is the allennlp-specific functionality in the Dataset object;
    # we need to be able convert strings in the data to integers, and this
    # is how we do it.
    train_data.index_with(vocab)
    dev_data.index_with(vocab)

    # These are again a subclass of pytorch DataLoaders, with an
    # allennlp-specific collate function, that runs our indexing and
    # batching code.
    train_loader, dev_loader = build_data_loaders(train_data, dev_data)

    # You obviously won't want to create a temporary file for your training
    # results, but for execution in binder for this course, we need to do this.
    with tempfile.TemporaryDirectory() as serialization_dir:
        trainer = build_trainer(
            model,
            serialization_dir,
            train_loader,
            dev_loader
        )
        print("Starting training")
        trainer.train()
        print("Finished training")
    return trainer


# The other `build_*` methods are things we've seen before, so they are
# in the setup section above.
def build_data_loaders(
    train_data: torch.utils.data.Dataset,
    dev_data: torch.utils.data.Dataset,
) -> Tuple[allennlp.data.DataLoader, allennlp.data.DataLoader]:
    # Note that DataLoader is imported from allennlp above, *not* torch.
    # We need to get the allennlp-specific collate function, which is
    # what actually does indexing and batching.
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    dev_loader = DataLoader(dev_data, batch_size=32, shuffle=False)
    return train_loader, dev_loader


def build_trainer(
    model: Model,
    serialization_dir: str,
    train_loader: DataLoader,
    dev_loader: DataLoader,
    num_epochs: int = 1
) -> Trainer:
    parameters = [
        [n, p]
        for n, p in model.named_parameters() if p.requires_grad
    ]
    optimizer = AdamOptimizer(parameters)
    trainer = GradientDescentTrainer(
        model=model,
        serialization_dir=serialization_dir,
        data_loader=train_loader,
        validation_data_loader=dev_loader,
        num_epochs=num_epochs,
        optimizer=optimizer,
    )
    return trainer

In [105]:
#trainer = run_training_loop()
'''
dataset_reader = build_dataset_reader()

train_data, dev_data = read_data(dataset_reader)

vocab = build_vocab(train_data + dev_data)

train_data.index_with(vocab)
dev_data.index_with(vocab)
'''

'\ndataset_reader = build_dataset_reader()\n\ntrain_data, dev_data = read_data(dataset_reader)\n\nvocab = build_vocab(train_data + dev_data)\n\ntrain_data.index_with(vocab)\ndev_data.index_with(vocab)\n'

In [106]:
model = build_model(vocab)

Building the model


In [107]:
train_loader, dev_loader = build_data_loaders(train_data, dev_data)

# You obviously won't want to create a temporary file for your training
# results, but for execution in binder for this course, we need to do this.
with tempfile.TemporaryDirectory() as serialization_dir:
    trainer = build_trainer(
        model,
        serialization_dir,
        train_loader,
        dev_loader,
        5
    )
    print("Starting training")
    trainer.train()
    print("Finished training")

You provided a validation dataset but patience was set to None, meaning that early stopping is disabled


Starting training


HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))


Finished training


In [148]:
import torch

torch.save(model.state_dict, pathjoin(MODELS_DIR, MODEL_ID, 'model'))

In [149]:
vocab.save_to_files(pathjoin(MODELS_DIR, MODEL_ID, 'vocab'))

In [155]:
id_to_label = []
with open(pathjoin(MODELS_DIR, MODEL_ID, 'vocab', 'labels.txt')) as vocab_in:
    for line in vocab_in:
        id_to_label.append(line.strip())
print(id_to_label)

['A1', 'A8', 'A12', 'A14', 'A16', 'A7', 'A11', 'A17', 'A4', 'A9', 'A22']


In [158]:
def predict_classes(sentence_list):
    predictor = TextClassifierPredictor(model, dataset_reader=build_dataset_reader())
    result = [id_to_label[np.argmax(predictor.predict(sentence)["probs"])]\
              for sentence in sentence_list]
    return result

In [159]:
predict_classes(['Здесь должно быть ваше сообщение',
                 'Коты - лучшие домашние животные. К такому выводу пришли эксперты из издания New York Times',
                 'It is no more than what it is.',
                 'Жила я как-то с парнем. Я только вот на днях уволилась с работы, так как мне тяжело было работать сутки через сутки, должна была выходить на другую работу. И именно в этот период, мне сильно поплохело, начались жуткие головные боли, слабость, обмороки. Парень настоял, что нужно срочно вызывать врача. Приехала скорая, фельдшер мужик лет 50 весь седой. Позадавал вопросы мне, где болит, как болит, и зачем болит? Кто такая вообще по жизни, и чем занимаюсь? Смерил давление, температуру, написал что-то в своих бумагах, и дав лишь рекомендацию: "больше отдыхайте, пейте воду, гуляйте на свежем воздухе" пошёл на выход.'])

['A1', 'A8', 'A1', 'A11']

This is the time to interpret our simple classifier.

In [166]:
from allennlp.interpret.saliency_interpreters import SmoothGradient

In [168]:
smooth_grad_interpr = SmoothGradient(predictor)

In [177]:
smooth_grad_interpr.saliency_interpret_from_json({'sentence': 'Жила я как-то с парнем. Я только вот на днях уволилась с работы, так как мне тяжело было работать сутки через сутки, должна была выходить на другую работу. И именно в этот период, мне сильно поплохело, начались жуткие головные боли, слабость, обмороки. Парень настоял, что нужно срочно вызывать врача. Приехала скорая, фельдшер мужик лет 50 весь седой. Позадавал вопросы мне, где болит, как болит, и зачем болит? Кто такая вообще по жизни, и чем занимаюсь? Смерил давление, температуру, написал что-то в своих бумагах, и дав лишь рекомендацию: "больше отдыхайте, пейте воду, гуляйте на свежем воздухе" пошёл на выход.'})

{'instance_1': {'grad_input_1': [0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614602,
   0.010204082900614