# Text classification using ElMo, BERT, GPT2

## Import packages

In [0]:
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns
import keras.layers as layers
from keras.models import Model
from keras import backend as K
np.random.seed(10)

## Downloading the model

In [0]:
!pip install git+https://github.com/zalandoresearch/flair.git

Collecting git+https://github.com/zalandoresearch/flair.git
  Cloning https://github.com/zalandoresearch/flair.git to /tmp/pip-req-build-c5fux2nl
  Running command git clone -q https://github.com/zalandoresearch/flair.git /tmp/pip-req-build-c5fux2nl
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: flair
  Building wheel for flair (PEP 517) ... [?25l[?25hdone
  Created wheel for flair: filename=flair-0.4.3-cp36-none-any.whl size=113131 sha256=288fe9da66cea325a2dfcc2fa33a347ca1eead0ba717f360e3346dc7f12de2e1
  Stored in directory: /tmp/pip-ephem-wheel-cache-gc8eebxq/wheels/6a/78/0f/399330241d3bc69458cc4fe320dcdfbf818f9887803f0294e7
Successfully built flair


In [0]:
import flair

In [0]:
from flair.embeddings import BertEmbeddings,ELMoEmbeddings,OpenAIGPTEmbeddings

# init embedding
embed = BertEmbeddings()

100%|██████████| 231508/231508 [00:00<00:00, 1217314.09B/s]
100%|██████████| 313/313 [00:00<00:00, 78064.88B/s]
100%|██████████| 440473133/440473133 [00:12<00:00, 34378837.40B/s]


In [0]:
from flair.models import TextClassifier
from flair.data import Sentence
classifier = TextClassifier.load('en-sentiment')
sentence = Sentence('Flair is pretty neat!')
classifier.predict(sentence)
# print sentence with predicted labels
print('Sentence above is: ', sentence.labels)

2019-09-25 02:43:34,037 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/models-v0.4/classy-imdb-en-rnn-cuda%3A0/imdb-v0.4.pt not found in cache, downloading to /tmp/tmp1emr07fz


100%|██████████| 1501979561/1501979561 [01:28<00:00, 17015115.62B/s]

2019-09-25 02:45:02,986 copying /tmp/tmp1emr07fz to cache at /root/.flair/models/imdb-v0.4.pt





2019-09-25 02:45:10,784 removing temp file /tmp/tmp1emr07fz
2019-09-25 02:45:10,971 loading file /root/.flair/models/imdb-v0.4.pt
Sentence above is:  [POSITIVE (0.6636107563972473)]


In [0]:
!ls

sample_data  test_data.txt  train_5500.txt


## Download data

In [0]:
!wget https://raw.githubusercontent.com/Tony607/Keras-Text-Transfer-Learning/master/train_5500.txt
!wget https://raw.githubusercontent.com/Tony607/Keras-Text-Transfer-Learning/master/test_data.txt

--2019-09-25 02:45:29--  https://raw.githubusercontent.com/Tony607/Keras-Text-Transfer-Learning/master/train_5500.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 335860 (328K) [text/plain]
Saving to: ‘train_5500.txt’


2019-09-25 02:45:29 (11.3 MB/s) - ‘train_5500.txt’ saved [335860/335860]

--2019-09-25 02:45:31--  https://raw.githubusercontent.com/Tony607/Keras-Text-Transfer-Learning/master/test_data.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 23354 (23K) [text/plain]
Saving to: ‘test_data.txt’


2019-09-25 02:45:31 (3.17 MB/s

## Decscription of Data

The dataset we use is the TREC Question Classification dataset, There are entirely 5452 training and 500 test samples, that is 5452 + 500 questions each categorized into one of the six labels.

- ABBR - 'abbreviation': expression abbreviated, etc.
- DESC - 'description and abstract concepts': manner of an action, description of sth. etc.
- ENTY - 'entities': animals, colors, events, food, etc.
- HUM - 'human beings': a group or organization of persons, an individual, etc.
- LOC - 'locations': cities, countries, etc.
- NUM - 'numeric values': postcodes, dates, speed,temperature, etc

In [0]:
#Extract lines from .txt and convert to dataframe

def get_dataframe(filename):
    lines = open(filename, 'r').read().splitlines()
    data = []
    for i in range(0, len(lines)):
        label = lines[i].split(' ')[0]
        label = label.split(":")[0]
        text = ' '.join(lines[i].split(' ')[1:])
        text = re.sub('[^A-Za-z0-9 ,\?\'\"-._\+\!/\`@=;:]+', '', text)
        data.append([label, text])

    df = pd.DataFrame(data, columns=['label', 'text'])
    df.label = df.label.astype('category')
    return df



In [0]:
#Assign train data
df_train = get_dataframe('train_5500.txt')
print(df_train.head())
df_train.to_csv('train.csv')
df_test = get_dataframe('test_data.txt')
print(df_test.head())
df_test.to_csv('test.csv')

  label                                               text
0  DESC  How did serfdom develop in and then leave Russ...
1  ENTY   What films featured the character Popeye Doyle ?
2  DESC  How can I find a list of celebrities ' real na...
3  ENTY  What fowl grabs the spotlight after the Chines...
4  ABBR                    What is the full form of .com ?
  label                                      text
0   NUM      How far is it from Denver to Aspen ?
1   LOC  What county is Modesto , California in ?
2   HUM                         Who was Galileo ?
3  DESC                         What is an atom ?
4   NUM          When did Hawaii become a state ?


In [0]:
#Number of categories in dataset
category_counts = len(df_train.label.cat.categories)
category_counts

6

In [0]:
import pandas as pd
data = pd.read_csv("./train.csv", encoding='latin-1').sample(frac=1).drop_duplicates()
data = data[['label', 'text']].rename(columns={"v1":"label", "v2":"text"})
 
data['label'] = '__label__' + data['label'].astype(str)

data.iloc[0:int(len(data)*0.8)].to_csv('train.csv', sep='\t', index = False, header = False, columns=['label', 'text'])
data.iloc[int(len(data)*0.8):int(len(data)*0.9)].to_csv('test.csv', sep='\t', index = False, header = False, columns=['label', 'text'])
data.iloc[int(len(data)*0.9):].to_csv('dev.csv', sep='\t', index = False, header = False, columns=['label', 'text'])

In [0]:
from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentLSTMEmbeddings, BertEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path


corpus = NLPTaskDataFetcher.load_classification_corpus(Path('./'), test_file='test.csv', dev_file='dev.csv', train_file='train.csv')
word_embeddings = [WordEmbeddings('glove'), BertEmbeddings('bert-base-uncased'), FlairEmbeddings('news-backward-fast')]
document_embeddings = DocumentLSTMEmbeddings(word_embeddings, hidden_size=512, reproject_words=True, reproject_words_dimension=256)


classifier = TextClassifier(document_embeddings, label_dictionary=corpus.make_label_dictionary(), multi_label=False)

trainer = ModelTrainer(classifier, corpus)
trainer.train('./', max_epochs=10)

2019-09-25 02:45:39,245 Reading data from .
2019-09-25 02:45:39,247 Train: train.csv
2019-09-25 02:45:39,248 Dev: dev.csv
2019-09-25 02:45:39,250 Test: test.csv


  
  train_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
  test_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
  dev_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc


2019-09-25 02:45:40,794 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/glove.gensim.vectors.npy not found in cache, downloading to /tmp/tmp7bd8xn68


100%|██████████| 160000128/160000128 [00:10<00:00, 14875900.75B/s]

2019-09-25 02:45:52,265 copying /tmp/tmp7bd8xn68 to cache at /root/.flair/embeddings/glove.gensim.vectors.npy





2019-09-25 02:45:52,677 removing temp file /tmp/tmp7bd8xn68
2019-09-25 02:45:53,312 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/glove.gensim not found in cache, downloading to /tmp/tmpnqy5yd5v


100%|██████████| 21494764/21494764 [00:02<00:00, 8721699.00B/s]

2019-09-25 02:45:56,445 copying /tmp/tmpnqy5yd5v to cache at /root/.flair/embeddings/glove.gensim





2019-09-25 02:45:56,476 removing temp file /tmp/tmpnqy5yd5v


  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


2019-09-25 02:46:02,893 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/lm-news-english-backward-1024-v0.2rc.pt not found in cache, downloading to /tmp/tmp150ghs0u


100%|██████████| 19689779/19689779 [00:02<00:00, 7946637.99B/s]

2019-09-25 02:46:06,035 copying /tmp/tmp150ghs0u to cache at /root/.flair/embeddings/lm-news-english-backward-1024-v0.2rc.pt
2019-09-25 02:46:06,060 removing temp file /tmp/tmp150ghs0u





2019-09-25 02:46:06,873 Computing label dictionary. Progress:


  # Remove the CWD from sys.path while we load stuff.
100%|██████████| 4361/4361 [00:00<00:00, 234748.39it/s]

2019-09-25 02:46:06,898 [b'HUM', b'LOC', b'NUM', b'ABBR', b'ENTY', b'DESC']
2019-09-25 02:46:06,914 ----------------------------------------------------------------------------------------------------
2019-09-25 02:46:06,919 Model: "TextClassifier(
  (document_embeddings): DocumentLSTMEmbeddings(
    (embeddings): StackedEmbeddings(
      (list_embedding_0): WordEmbeddings('glove')
      (list_embedding_1): BertEmbeddings(
        (model): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(30522, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0): BertLayer(
                (attention): BertAttention(
                  (self): BertSelfAttention(
 




2019-09-25 02:46:10,812 epoch 1 - iter 0/137 - loss 1.80171144 - samples/sec: 107.74
2019-09-25 02:47:01,253 epoch 1 - iter 13/137 - loss 2.26821761 - samples/sec: 8.27
2019-09-25 02:47:53,379 epoch 1 - iter 26/137 - loss 2.10939177 - samples/sec: 8.00
2019-09-25 02:48:45,516 epoch 1 - iter 39/137 - loss 1.99575513 - samples/sec: 7.99
2019-09-25 02:49:36,032 epoch 1 - iter 52/137 - loss 1.94803076 - samples/sec: 8.25
2019-09-25 02:50:28,462 epoch 1 - iter 65/137 - loss 1.83922567 - samples/sec: 7.95
2019-09-25 02:51:16,016 epoch 1 - iter 78/137 - loss 1.80030416 - samples/sec: 8.77
2019-09-25 02:52:08,141 epoch 1 - iter 91/137 - loss 1.72820347 - samples/sec: 8.00
2019-09-25 02:53:00,244 epoch 1 - iter 104/137 - loss 1.69269149 - samples/sec: 8.00
2019-09-25 02:53:50,726 epoch 1 - iter 117/137 - loss 1.64486988 - samples/sec: 8.26
2019-09-25 02:54:41,693 epoch 1 - iter 130/137 - loss 1.59708728 - samples/sec: 8.18
2019-09-25 02:55:02,360 ------------------------------------------------

  result = unpickler.load()


2019-09-25 03:04:04,235 0.9138	0.9138	0.9138
2019-09-25 03:04:04,236 
MICRO_AVG: acc 0.8412 - f1-score 0.9138
MACRO_AVG: acc 0.7806 - f1-score 0.8656833333333332
ABBR       tp: 4 - fp: 1 - fn: 4 - tn: 536 - precision: 0.8000 - recall: 0.5000 - accuracy: 0.4444 - f1-score: 0.6154
DESC       tp: 98 - fp: 12 - fn: 16 - tn: 419 - precision: 0.8909 - recall: 0.8596 - accuracy: 0.7778 - f1-score: 0.8750
ENTY       tp: 86 - fp: 21 - fn: 9 - tn: 429 - precision: 0.8037 - recall: 0.9053 - accuracy: 0.7414 - f1-score: 0.8515
HUM        tp: 121 - fp: 3 - fn: 7 - tn: 414 - precision: 0.9758 - recall: 0.9453 - accuracy: 0.9237 - f1-score: 0.9603
LOC        tp: 86 - fp: 3 - fn: 11 - tn: 445 - precision: 0.9663 - recall: 0.8866 - accuracy: 0.8600 - f1-score: 0.9247
NUM        tp: 103 - fp: 7 - fn: 0 - tn: 435 - precision: 0.9364 - recall: 1.0000 - accuracy: 0.9364 - f1-score: 0.9672
2019-09-25 03:04:04,241 -----------------------------------------------------------------------------------------------

{'test_score': 0.9138,
 'dev_score_history': [0.6319,
  0.7125,
  0.7143,
  0.6447,
  0.7967,
  0.8462,
  0.8626,
  0.9066,
  0.8974,
  0.9139],
 'train_loss_history': [1.583923603061342,
  0.9675393302510255,
  0.7819692168357598,
  0.6384952681778121,
  0.5309577492901879,
  0.43726150593618407,
  0.37412173333611803,
  0.2846662643073249,
  0.24728865875271114,
  0.2069828187244652],
 'dev_loss_history': [tensor(0.8732),
  tensor(0.7114),
  tensor(0.7528),
  tensor(1.0911),
  tensor(0.4972),
  tensor(0.4475),
  tensor(0.4473),
  tensor(0.2831),
  tensor(0.2592),
  tensor(0.2423)]}

In [0]:
from flair.models import TextClassifier
from flair.data import Sentence
classifier = TextClassifier.load('./best-model.pt')
sentence = Sentence("What is the full form of .col?")
classifier.predict(sentence)
print(sentence.labels)

2019-09-25 03:05:15,188 loading file ./best-model.pt


  result = unpickler.load()


[ABBR (0.5854447484016418)]


In [0]:
!pip install allennlp

Collecting allennlp
[?25l  Downloading https://files.pythonhosted.org/packages/3f/bc/e30325523363215c503171822f09436adcfbc74f426ad62496276f1ac4c0/allennlp-0.8.5-py3-none-any.whl (7.4MB)
[K     |████████████████████████████████| 7.5MB 4.2MB/s 
Collecting jsonnet>=0.10.0; sys_platform != "win32" (from allennlp)
[?25l  Downloading https://files.pythonhosted.org/packages/fe/a6/e69e38f1f259fcf8532d8bd2c4bc88764f42d7b35a41423a7f4b035cc5ce/jsonnet-0.14.0.tar.gz (253kB)
[K     |████████████████████████████████| 256kB 44.0MB/s 
[?25hCollecting flaky (from allennlp)
  Downloading https://files.pythonhosted.org/packages/fe/12/0f169abf1aa07c7edef4855cca53703d2e6b7ecbded7829588ac7e7e3424/flaky-3.6.1-py2.py3-none-any.whl
Collecting overrides (from allennlp)
  Downloading https://files.pythonhosted.org/packages/de/55/3100c6d14c1ed177492fcf8f07c4a7d2d6c996c0a7fc6a9a0a41308e7eec/overrides-1.9.tar.gz
Collecting word2number>=1.1 (from allennlp)
  Downloading https://files.pythonhosted.org/packages/4

In [0]:
from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentLSTMEmbeddings, BertEmbeddings, ELMoEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path


corpus = NLPTaskDataFetcher.load_classification_corpus(Path('./'), test_file='test.csv', dev_file='dev.csv', train_file='train.csv')
word_embeddings = [WordEmbeddings('glove'), ELMoEmbeddings(), FlairEmbeddings('news-backward-fast')]
document_embeddings = DocumentLSTMEmbeddings(word_embeddings, hidden_size=512, reproject_words=True, reproject_words_dimension=256)


classifier = TextClassifier(document_embeddings, label_dictionary=corpus.make_label_dictionary(), multi_label=False)
trainer = ModelTrainer(classifier, corpus)
trainer.train('./', max_epochs=10)

2019-09-25 03:06:36,779 Reading data from .
2019-09-25 03:06:36,780 Train: train.csv
2019-09-25 03:06:36,781 Dev: dev.csv
2019-09-25 03:06:36,784 Test: test.csv


  
  train_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
  test_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
  dev_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL
100%|██████████| 336/336 [00:00<00:00, 474666.94B/s]
100%|██████████| 374434792/374434792 [00:07<00:00, 51693678.58B/s]


2019-09-25 03:07:06,373 Computing label dictionary. Progress:


  # Remove the CWD from sys.path while we load stuff.
100%|██████████| 4361/4361 [00:00<00:00, 237054.47it/s]

2019-09-25 03:07:06,398 [b'HUM', b'LOC', b'NUM', b'ABBR', b'ENTY', b'DESC']
2019-09-25 03:07:06,466 ----------------------------------------------------------------------------------------------------
2019-09-25 03:07:06,467 Model: "TextClassifier(
  (document_embeddings): DocumentLSTMEmbeddings(
    (embeddings): StackedEmbeddings(
      (list_embedding_0): WordEmbeddings('glove')
      (list_embedding_1): ELMoEmbeddings(model=1-elmo-original)
      (list_embedding_2): FlairEmbeddings(
        (lm): LanguageModel(
          (drop): Dropout(p=0.25)
          (encoder): Embedding(275, 100)
          (rnn): LSTM(100, 1024)
          (decoder): Linear(in_features=1024, out_features=275, bias=True)
        )
      )
    )
    (word_reprojection_map): Linear(in_features=4196, out_features=256, bias=True)
    (rnn): GRU(256, 512)
    (dropout): Dropout(p=0.5)
  )
  (decoder): Linear(in_features=512, out_features=6, bias=True)
  (loss_function): CrossEntropyLoss()
)"
2019-09-25 03:07:06,468 -




2019-09-25 03:07:12,068 epoch 1 - iter 0/137 - loss 1.91484761 - samples/sec: 74.84
2019-09-25 03:08:12,691 epoch 1 - iter 13/137 - loss 2.08656985 - samples/sec: 6.86
2019-09-25 03:09:13,455 epoch 1 - iter 26/137 - loss 1.98000769 - samples/sec: 6.85
2019-09-25 03:10:15,857 epoch 1 - iter 39/137 - loss 1.91287484 - samples/sec: 6.67
2019-09-25 03:11:11,251 epoch 1 - iter 52/137 - loss 1.79564144 - samples/sec: 7.51
2019-09-25 03:12:10,509 epoch 1 - iter 65/137 - loss 1.74097147 - samples/sec: 7.02
2019-09-25 03:13:07,229 epoch 1 - iter 78/137 - loss 1.68728253 - samples/sec: 7.34
2019-09-25 03:14:04,707 epoch 1 - iter 91/137 - loss 1.63104457 - samples/sec: 7.24
2019-09-25 03:14:58,317 epoch 1 - iter 104/137 - loss 1.56980777 - samples/sec: 7.76
2019-09-25 03:15:55,117 epoch 1 - iter 117/137 - loss 1.52973626 - samples/sec: 7.33
2019-09-25 03:16:52,279 epoch 1 - iter 130/137 - loss 1.49661636 - samples/sec: 7.28
2019-09-25 03:17:14,696 -------------------------------------------------

  result = unpickler.load()


2019-09-25 03:26:16,224 0.9266	0.9266	0.9266
2019-09-25 03:26:16,226 
MICRO_AVG: acc 0.8632 - f1-score 0.9266
MACRO_AVG: acc 0.8701 - f1-score 0.9283833333333332
ABBR       tp: 8 - fp: 1 - fn: 0 - tn: 536 - precision: 0.8889 - recall: 1.0000 - accuracy: 0.8889 - f1-score: 0.9412
DESC       tp: 92 - fp: 2 - fn: 22 - tn: 429 - precision: 0.9787 - recall: 0.8070 - accuracy: 0.7931 - f1-score: 0.8846
ENTY       tp: 89 - fp: 27 - fn: 6 - tn: 423 - precision: 0.7672 - recall: 0.9368 - accuracy: 0.7295 - f1-score: 0.8436
HUM        tp: 120 - fp: 3 - fn: 8 - tn: 414 - precision: 0.9756 - recall: 0.9375 - accuracy: 0.9160 - f1-score: 0.9562
LOC        tp: 94 - fp: 5 - fn: 3 - tn: 443 - precision: 0.9495 - recall: 0.9691 - accuracy: 0.9216 - f1-score: 0.9592
NUM        tp: 102 - fp: 2 - fn: 1 - tn: 440 - precision: 0.9808 - recall: 0.9903 - accuracy: 0.9714 - f1-score: 0.9855
2019-09-25 03:26:16,227 -------------------------------------------------------------------------------------------------

{'test_score': 0.9266,
 'dev_score_history': [0.5549,
  0.6978,
  0.6172,
  0.8388,
  0.8974,
  0.8974,
  0.9139,
  0.8974,
  0.8938,
  0.9249],
 'train_loss_history': [1.4833933989496997,
  0.8925306229260717,
  0.5600962136348668,
  0.34675389033381954,
  0.2526464889465022,
  0.19069881340230468,
  0.15098025739519266,
  0.10854372590861834,
  0.09353612496578781,
  0.07901349325365231],
 'dev_loss_history': [tensor(1.2081),
  tensor(0.8407),
  tensor(1.3467),
  tensor(0.5290),
  tensor(0.3598),
  tensor(0.3356),
  tensor(0.2985),
  tensor(0.3592),
  tensor(0.4098),
  tensor(0.3187)]}

In [0]:

from flair.models import TextClassifier
from flair.data import Sentence
classifier = TextClassifier.load('./best-model.pt')
sentence = Sentence("How many flowers are there in this park?")
classifier.predict(sentence)
print(sentence.labels)

2019-09-25 03:31:27,086 loading file ./best-model.pt


  result = unpickler.load()


[NUM (0.9999963045120239)]


In [0]:
from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentLSTMEmbeddings, BertEmbeddings, ELMoEmbeddings, OpenAIGPTEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path


corpus = NLPTaskDataFetcher.load_classification_corpus(Path('./'), test_file='test.csv', dev_file='dev.csv', train_file='train.csv')
word_embeddings = [OpenAIGPTEmbeddings()]
document_embeddings = DocumentRNNEmbeddings(word_embeddings, hidden_size=1024, reproject_words=True, reproject_words_dimension=32)


classifier = TextClassifier(document_embeddings, label_dictionary=corpus.make_label_dictionary(), multi_label=False)
trainer = ModelTrainer(classifier, corpus)
trainer.train('./', max_epochs=10)

2019-09-24 11:53:36,816 Reading data from .
2019-09-24 11:53:36,818 Train: train.csv
2019-09-24 11:53:36,820 Dev: dev.csv
2019-09-24 11:53:36,821 Test: test.csv


  
  train_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
  test_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
  dev_file, tokenizer=tokenizer, max_tokens_per_doc=max_tokens_per_doc
100%|██████████| 815973/815973 [00:00<00:00, 5613426.57B/s]
100%|██████████| 458495/458495 [00:00<00:00, 3796334.51B/s]
ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.
100%|██████████| 273/273 [00:00<00:00, 59206.05B/s]
100%|██████████| 478750579/478750579 [00:09<00:00, 51783555.13B/s]


2019-09-24 11:53:55,035 Computing label dictionary. Progress:


100%|██████████| 4361/4361 [00:00<00:00, 268142.78it/s]

2019-09-24 11:53:55,057 [b'HUM', b'LOC', b'NUM', b'ABBR', b'ENTY', b'DESC']
2019-09-24 11:53:55,062 ----------------------------------------------------------------------------------------------------
2019-09-24 11:53:55,064 Model: "TextClassifier(
  (document_embeddings): DocumentRNNEmbeddings(
    (embeddings): StackedEmbeddings(
      (list_embedding_0): OpenAIGPTEmbeddings(
        model=0-openai-gpt
        (model): OpenAIGPTModel(
          (tokens_embed): Embedding(40478, 768)
          (positions_embed): Embedding(512, 768)
          (drop): Dropout(p=0.1)
          (h): ModuleList(
            (0): Block(
              (attn): Attention(
                (c_attn): Conv1D()
                (c_proj): Conv1D()
                (attn_dropout): Dropout(p=0.1)
                (resid_dropout): Dropout(p=0.1)
              )
              (ln_1): LayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
              (mlp): MLP(
                (c_fc): Conv1D()
                (c




2019-09-24 11:53:59,673 epoch 1 - iter 0/137 - loss 1.80280530 - samples/sec: 90.95
2019-09-24 11:54:54,549 epoch 1 - iter 13/137 - loss 1.72893476 - samples/sec: 7.59
2019-09-24 11:55:56,126 epoch 1 - iter 26/137 - loss 1.68462484 - samples/sec: 6.76
2019-09-24 11:56:56,275 epoch 1 - iter 39/137 - loss 1.67050674 - samples/sec: 6.92
2019-09-24 11:57:51,298 epoch 1 - iter 52/137 - loss 1.64905939 - samples/sec: 7.57
2019-09-24 11:58:53,562 epoch 1 - iter 65/137 - loss 1.63552886 - samples/sec: 6.69


IndexError: ignored