In [1]:
from typing import Iterator, List, Dict

In [2]:
import torch
import torch.optim as optim
import numpy as np

In [3]:
from allennlp.common.params import Params

In [4]:
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [43]:
from allennlp.data.dataset_readers import DatasetReader

In [6]:
from allennlp.common.file_utils import cached_path

In [7]:
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

In [8]:
from allennlp.data.vocabulary import Vocabulary

In [9]:
from allennlp.models import Model

In [10]:
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits

In [11]:
from allennlp.training.metrics import CategoricalAccuracy

In [12]:
from allennlp.data.iterators import BucketIterator

In [13]:
from allennlp.training.trainer import Trainer

In [14]:
from allennlp.predictors import SentenceTaggerPredictor

# Seed every RNG the notebook uses: torch for weight init, numpy for the
# shuffle inside Dataset_Reader._read (previously unseeded, so the
# chunking of names changed on every run).
SEED = 1
torch.manual_seed(SEED)
np.random.seed(SEED)

<torch._C.Generator at 0x7fbd6c104cf0>

In [34]:
from allennlp.data.token_indexers import TokenCharactersIndexer
import unicodedata, glob, os, string, random

In [83]:
# Methods from pytorch tutorial

# Characters kept after transliteration: ASCII letters plus " .,;'".
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

def unicodeToAscii(s):
    """Transliterate ``s`` to ASCII restricted to ``all_letters``.

    The string is NFD-normalized, combining marks (Unicode category 'Mn',
    e.g. accents) are dropped, and any remaining character not in
    ``all_letters`` is dropped too, e.g. 'Ślusàrski' -> 'Slusarski'.
    """
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

def readLines(filename):
    """Read a UTF-8 text file and return its lines transliterated to ASCII.

    Whitespace at the start/end of the whole file is stripped before
    splitting on '\\n', so a trailing newline does not yield an empty entry.
    """
    # Context manager closes the handle immediately instead of leaking it
    # until garbage collection.
    with open(filename, encoding='utf-8') as f:
        text = f.read()
    return [unicodeToAscii(line) for line in text.strip().split('\n')]

def findFiles(path):
    """Return the paths matching the glob pattern ``path``, sorted.

    ``glob.glob`` yields matches in arbitrary, filesystem-dependent order;
    sorting keeps the category order (and hence the seeded shuffle of the
    dataset) reproducible across machines.
    """
    return sorted(glob.glob(path))

class Dataset_Reader(DatasetReader):
    """Reads the char-RNN-tutorial names data into AllenNLP instances.

    Every file matched by the glob passed to ``read`` holds one name per line;
    the file's base name (without extension) is the label of every name in it.
    All (name, label) pairs are pooled, shuffled with ``np.random`` and cut into
    chunks of ``chunk_size`` names. Each chunk becomes one instance with a
    "tokens" field, a "token_characters" field (same names, character-indexed)
    and a "labels" sequence field.

    NOTE(review): ``_read`` appends to the notebook-global ``all_categories``
    list and fills the notebook-global ``category_lines`` dict, so both must be
    defined before ``read`` is called. Kept for compatibility with later cells.
    """

    # Implementing data reader from character RNN tutorial
    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 chunk_size: int = 30) -> None:
        """
        Args:
            token_indexers: indexers for the word-level "tokens" field; defaults
                to ``{"tokens": SingleIdTokenIndexer()}``.
            chunk_size: number of names per instance (30 was the original
                hard-coded value).

        Raises:
            ValueError: if ``chunk_size`` is smaller than 1.
        """
        super().__init__(lazy=False)
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        # Character indexers based on AllenNLP tutorial
        self.token_character_indexers = {"token_characters": TokenCharactersIndexer()}
        self.chunk_size = chunk_size

    def text_to_instance(self, tokens: List[str], tags: List[str] = None) -> Instance:
        """Build an instance from a list of names and (optionally) their labels."""
        name_tokens = [Token(name) for name in tokens]
        name_token_field = TextField(name_tokens, self.token_indexers)
        # Same names again, indexed character by character.
        fields = {"tokens": name_token_field}
        name_token_character_field = TextField(name_tokens, self.token_character_indexers)
        fields["token_characters"] = name_token_character_field
        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=name_token_field)
            fields["labels"] = label_field

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per shuffled chunk of ``self.chunk_size`` names.

        Args:
            file_path: glob pattern matching the per-language name files,
                e.g. ``'data/names/*.txt'``.
        """
        labelled_names = []
        for filename in findFiles(file_path):
            category = os.path.splitext(os.path.basename(filename))[0]
            # Hidden state: these globals are defined in a later notebook cell.
            all_categories.append(category)
            lines = readLines(filename)
            category_lines[category] = lines
            labelled_names.extend((name, category) for name in lines)
        # Whole-language splits did not train well, so languages are mixed by
        # shuffling the pooled pairs before cutting fixed-size chunks.
        np.random.shuffle(labelled_names)
        # Chunking idea from https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks
        for start in range(0, len(labelled_names), self.chunk_size):
            chunk = labelled_names[start:start + self.chunk_size]
            names = [name for name, _ in chunk]
            tags = [category for _, category in chunk]
            yield self.text_to_instance(names, tags)

In [84]:
# Reader with its default single-id word indexer (token_indexers=None selects it).
dataset_reader = Dataset_Reader(token_indexers=None)



In [85]:
# Dataset_Reader._read fills these two globals while reading:
# category -> its list of names, and the categories in the order they were read.
category_lines = dict()
all_categories = list()

NAMES_GLOB = 'data/names/*.txt'
names_data = dataset_reader.read(NAMES_GLOB)


0it [00:00, ?it/s][A
669it [00:00, 6581.72it/s][A

In [86]:
class NamesClassifier(Model):
    """Sequence tagger that predicts a language label for every name.

    Each name is embedded by ``word_embeddings`` (fed both the word-id and the
    character-id tensors), contextualized by ``encoder`` and projected by a
    linear layer to one logit per entry of the vocabulary's 'labels' namespace.
    """

    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        encoder_dim = encoder.get_output_dim()
        label_count = vocab.get_vocab_size('labels')
        self.hidden2tag = torch.nn.Linear(in_features=encoder_dim,
                                          out_features=label_count)
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                tokens: Dict[str, torch.Tensor],
                token_characters: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Return ``{"tag_logits": ...}``, plus ``"loss"`` when labels are given.

        When ``labels`` is provided, the accuracy metric is also updated.
        """
        mask = get_text_field_mask(tokens)
        # The embedder expects a single dict holding both indexers' tensors,
        # so the two field dicts are merged before embedding.
        embedder_input = dict(tokens)
        embedder_input.update(token_characters)
        contextual = self.encoder(self.word_embeddings(embedder_input), mask)
        tag_logits = self.hidden2tag(contextual)

        if labels is None:
            return {"tag_logits": tag_logits}

        self.accuracy(tag_logits, labels, mask)
        loss = sequence_cross_entropy_with_logits(tag_logits, labels, mask)
        return {"tag_logits": tag_logits, "loss": loss}

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Report categorical accuracy, optionally resetting the running counts."""
        accuracy_value = self.accuracy.get_metric(reset)
        return {"accuracy": accuracy_value}

In [87]:
# Since we don't have different files for validation, use sklearn to split the data.
# random_state pins the split (same value as the torch seed) so re-running the
# notebook yields the same train/validation instances.
from sklearn.model_selection import train_test_split
train_set, val_set = train_test_split(names_data, test_size=0.2, random_state=1)
# Vocabulary is built from all instances (train + validation), so every name and
# label that appears in validation has an index.
vocab = Vocabulary.from_instances(names_data)


  0%|          | 0/669 [00:00<?, ?it/s][A
 61%|██████    | 405/669 [00:00<00:00, 4048.07it/s][A
100%|██████████| 669/669 [00:00<00:00, 3611.33it/s][A

In [88]:
vocab

Vocabulary with namespaces:  tokens, Size: 17423 || token_characters, Size: 57 || labels, Size: 18 || Non Padded Namespaces: {'*tags', '*labels'}

In [89]:
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from allennlp.modules.token_embedders.token_characters_encoder import TokenCharactersEncoder

# Define the model
WORD_EMBEDDING_DIM = 3
CHAR_EMBEDDING_DIM = 3
HIDDEN_DIM = 6
# Each name is a word vector concatenated with a character-RNN vector whose
# size is CHAR_EMBEDDING_DIM (the RNN's hidden size below).
EMBEDDING_DIM = WORD_EMBEDDING_DIM + CHAR_EMBEDDING_DIM

# Embeddings for characters and then for names
character_encoder = PytorchSeq2VecWrapper(torch.nn.RNN(CHAR_EMBEDDING_DIM, CHAR_EMBEDDING_DIM, batch_first=True))
# The character embeddings are the character RNN's input, so they must be
# CHAR_EMBEDDING_DIM wide. This was previously WORD_EMBEDDING_DIM, which only
# worked because both constants happen to equal 3.
token_character_embedding = Embedding(num_embeddings=vocab.get_vocab_size('token_characters'),
                                      embedding_dim=CHAR_EMBEDDING_DIM)
character_embeddings = TokenCharactersEncoder(token_character_embedding, character_encoder)

# Now names
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=WORD_EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding, "token_characters" : character_embeddings})
# Fail fast if the concatenated embedding no longer matches the LSTM input size.
assert word_embeddings.get_output_dim() == EMBEDDING_DIM, word_embeddings.get_output_dim()
lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
model = NamesClassifier(word_embeddings, lstm, vocab)

In [90]:
# Train on the first GPU when one exists; otherwise pass -1 (CPU) to the Trainer.
cuda_device = 0 if torch.cuda.is_available() else -1
if cuda_device >= 0:
    model = model.cuda(cuda_device)

In [92]:
# Train the model - 30 epochs seem to give a pretty good baseline accuracy - 0.7 val accuracy
# Plain SGD at lr=0.1. A batch holds 2 instances, and each instance is already a
# chunk of names, so a batch contains many names.
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Bucketing keys: number of names per instance, then characters per name.
bucket_sorting_keys = [
    ("tokens", "num_tokens"),
    ("token_characters", "num_token_characters"),
]
iterator = BucketIterator(batch_size=2, sorting_keys=bucket_sorting_keys)
iterator.index_with(vocab)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_set,
    validation_dataset=val_set,
    patience=10,      # early-stopping patience (epochs)
    num_epochs=30,
    cuda_device=cuda_device,
)
trainer.train()


  0%|          | 0/268 [00:00<?, ?it/s][A
accuracy: 0.5222, loss: 1.4904 ||:   1%|          | 3/268 [00:00<00:09, 27.78it/s][A
accuracy: 0.5619, loss: 1.4109 ||:   3%|▎         | 7/268 [00:00<00:09, 28.75it/s][A
accuracy: 0.5650, loss: 1.4230 ||:   4%|▎         | 10/268 [00:00<00:09, 27.75it/s][A
accuracy: 0.5625, loss: 1.4189 ||:   4%|▍         | 12/268 [00:00<00:11, 21.37it/s][A
accuracy: 0.5607, loss: 1.4317 ||:   5%|▌         | 14/268 [00:00<00:12, 19.61it/s][A
accuracy: 0.5559, loss: 1.4497 ||:   6%|▋         | 17/268 [00:00<00:11, 21.68it/s][A
accuracy: 0.5617, loss: 1.4420 ||:   7%|▋         | 20/268 [00:00<00:11, 22.40it/s][A
accuracy: 0.5601, loss: 1.4440 ||:   9%|▊         | 23/268 [00:00<00:10, 22.86it/s][A
accuracy: 0.5603, loss: 1.4412 ||:  10%|▉         | 26/268 [00:01<00:11, 21.33it/s][A
accuracy: 0.5615, loss: 1.4375 ||:  11%|█         | 29/268 [00:01<00:11, 21.36it/s][A
accuracy: 0.5603, loss: 1.4478 ||:  12%|█▏        | 32/268 [00:01<00:11, 20.70it/s][A
a

accuracy: 0.5515, loss: 1.4453 ||:  16%|█▋        | 11/67 [00:00<00:01, 44.40it/s][A
accuracy: 0.5433, loss: 1.4669 ||:  22%|██▏       | 15/67 [00:00<00:01, 41.57it/s][A
accuracy: 0.5447, loss: 1.4495 ||:  33%|███▎      | 22/67 [00:00<00:00, 45.36it/s][A
accuracy: 0.5417, loss: 1.4621 ||:  39%|███▉      | 26/67 [00:00<00:01, 40.84it/s][A
accuracy: 0.5394, loss: 1.4651 ||:  45%|████▍     | 30/67 [00:00<00:00, 38.90it/s][A
accuracy: 0.5412, loss: 1.4727 ||:  51%|█████     | 34/67 [00:00<00:00, 39.19it/s][A
accuracy: 0.5448, loss: 1.4747 ||:  72%|███████▏  | 48/67 [00:00<00:00, 49.55it/s][A
accuracy: 0.5481, loss: 1.4734 ||:  93%|█████████▎| 62/67 [00:01<00:00, 61.21it/s][A
accuracy: 0.5478, loss: 1.4724 ||: 100%|██████████| 67/67 [00:01<00:00, 64.29it/s][A
  0%|          | 0/268 [00:00<?, ?it/s][A
accuracy: 0.5333, loss: 1.4901 ||:   1%|▏         | 4/268 [00:00<00:07, 33.89it/s][A
accuracy: 0.5479, loss: 1.4573 ||:   3%|▎         | 8/268 [00:00<00:08, 30.98it/s][A
accuracy: 0

  0%|          | 0/67 [00:00<?, ?it/s][A
accuracy: 0.5222, loss: 1.4559 ||:   4%|▍         | 3/67 [00:00<00:02, 28.73it/s][A
accuracy: 0.5646, loss: 1.3812 ||:  12%|█▏        | 8/67 [00:00<00:01, 32.87it/s][A
accuracy: 0.5482, loss: 1.4735 ||:  28%|██▊       | 19/67 [00:00<00:01, 40.76it/s][A
accuracy: 0.5590, loss: 1.4449 ||:  36%|███▌      | 24/67 [00:00<00:01, 40.31it/s][A
accuracy: 0.5609, loss: 1.4377 ||:  43%|████▎     | 29/67 [00:00<00:00, 41.80it/s][A
accuracy: 0.5544, loss: 1.4422 ||:  51%|█████     | 34/67 [00:00<00:00, 41.88it/s][A
accuracy: 0.5481, loss: 1.4736 ||:  64%|██████▍   | 43/67 [00:00<00:00, 48.99it/s][A
accuracy: 0.5456, loss: 1.4798 ||:  73%|███████▎  | 49/67 [00:00<00:00, 48.87it/s][A
accuracy: 0.5494, loss: 1.4724 ||:  82%|████████▏ | 55/67 [00:01<00:00, 49.82it/s][A
accuracy: 0.5530, loss: 1.4625 ||:  91%|█████████ | 61/67 [00:01<00:00, 48.76it/s][A
accuracy: 0.5567, loss: 1.4577 ||: 100%|██████████| 67/67 [00:01<00:00, 53.47it/s][A
  0%|         

accuracy: 0.5695, loss: 1.4090 ||:  95%|█████████▍| 254/268 [00:11<00:00, 20.27it/s][A
accuracy: 0.5703, loss: 1.4073 ||:  96%|█████████▌| 257/268 [00:11<00:00, 20.40it/s][A
accuracy: 0.5703, loss: 1.4068 ||:  97%|█████████▋| 260/268 [00:11<00:00, 20.41it/s][A
accuracy: 0.5697, loss: 1.4077 ||:  98%|█████████▊| 263/268 [00:11<00:00, 21.01it/s][A
accuracy: 0.5698, loss: 1.4082 ||:  99%|█████████▉| 266/268 [00:11<00:00, 21.37it/s][A
accuracy: 0.5700, loss: 1.4074 ||: 100%|██████████| 268/268 [00:12<00:00, 22.17it/s][A
  0%|          | 0/67 [00:00<?, ?it/s][A
accuracy: 0.5625, loss: 1.4617 ||:  12%|█▏        | 8/67 [00:00<00:00, 73.08it/s][A
accuracy: 0.5774, loss: 1.4299 ||:  21%|██        | 14/67 [00:00<00:00, 66.83it/s][A
accuracy: 0.5850, loss: 1.4112 ||:  30%|██▉       | 20/67 [00:00<00:00, 62.22it/s][A
accuracy: 0.5846, loss: 1.4177 ||:  39%|███▉      | 26/67 [00:00<00:00, 58.94it/s][A
accuracy: 0.5839, loss: 1.4160 ||:  46%|████▋     | 31/67 [00:00<00:00, 55.55it/s][A
a

accuracy: 0.5790, loss: 1.3877 ||:  83%|████████▎ | 223/268 [00:09<00:01, 22.54it/s][A
accuracy: 0.5796, loss: 1.3860 ||:  84%|████████▍ | 226/268 [00:09<00:01, 23.78it/s][A
accuracy: 0.5797, loss: 1.3850 ||:  85%|████████▌ | 229/268 [00:10<00:01, 23.57it/s][A
accuracy: 0.5786, loss: 1.3885 ||:  87%|████████▋ | 232/268 [00:10<00:01, 23.60it/s][A
accuracy: 0.5794, loss: 1.3873 ||:  88%|████████▊ | 235/268 [00:10<00:01, 21.47it/s][A
accuracy: 0.5790, loss: 1.3890 ||:  89%|████████▉ | 238/268 [00:10<00:01, 19.61it/s][A
accuracy: 0.5788, loss: 1.3890 ||:  90%|████████▉ | 241/268 [00:10<00:01, 20.61it/s][A
accuracy: 0.5791, loss: 1.3876 ||:  91%|█████████ | 244/268 [00:10<00:01, 22.03it/s][A
accuracy: 0.5801, loss: 1.3836 ||:  93%|█████████▎| 248/268 [00:10<00:00, 25.05it/s][A
accuracy: 0.5799, loss: 1.3845 ||:  94%|█████████▎| 251/268 [00:11<00:00, 24.13it/s][A
accuracy: 0.5803, loss: 1.3836 ||:  95%|█████████▍| 254/268 [00:11<00:00, 24.16it/s][A
accuracy: 0.5805, loss: 1.3833 |

accuracy: 0.5877, loss: 1.3685 ||:  66%|██████▌   | 177/268 [00:08<00:04, 21.52it/s][A
accuracy: 0.5883, loss: 1.3659 ||:  67%|██████▋   | 180/268 [00:08<00:03, 22.65it/s][A
accuracy: 0.5886, loss: 1.3637 ||:  68%|██████▊   | 183/268 [00:08<00:04, 20.74it/s][A
accuracy: 0.5887, loss: 1.3616 ||:  69%|██████▉   | 186/268 [00:09<00:04, 19.95it/s][A
accuracy: 0.5886, loss: 1.3609 ||:  71%|███████   | 189/268 [00:09<00:04, 17.76it/s][A
accuracy: 0.5886, loss: 1.3608 ||:  71%|███████▏  | 191/268 [00:09<00:04, 18.11it/s][A
accuracy: 0.5873, loss: 1.3652 ||:  72%|███████▏  | 194/268 [00:09<00:03, 20.20it/s][A
accuracy: 0.5876, loss: 1.3636 ||:  74%|███████▎  | 197/268 [00:09<00:03, 21.00it/s][A
accuracy: 0.5882, loss: 1.3618 ||:  75%|███████▍  | 200/268 [00:09<00:03, 20.34it/s][A
accuracy: 0.5870, loss: 1.3659 ||:  76%|███████▌  | 203/268 [00:09<00:03, 19.42it/s][A
accuracy: 0.5869, loss: 1.3673 ||:  76%|███████▋  | 205/268 [00:10<00:03, 19.54it/s][A
accuracy: 0.5882, loss: 1.3639 |

accuracy: 0.5913, loss: 1.3473 ||:  54%|█████▎    | 144/268 [00:06<00:05, 21.80it/s][A
accuracy: 0.5903, loss: 1.3490 ||:  55%|█████▍    | 147/268 [00:06<00:05, 23.14it/s][A
accuracy: 0.5909, loss: 1.3464 ||:  56%|█████▌    | 150/268 [00:06<00:05, 22.35it/s][A
accuracy: 0.5916, loss: 1.3458 ||:  57%|█████▋    | 153/268 [00:07<00:05, 22.51it/s][A
accuracy: 0.5912, loss: 1.3487 ||:  58%|█████▊    | 156/268 [00:07<00:04, 22.66it/s][A
accuracy: 0.5907, loss: 1.3497 ||:  59%|█████▉    | 159/268 [00:07<00:04, 22.61it/s][A
accuracy: 0.5917, loss: 1.3473 ||:  60%|██████    | 162/268 [00:07<00:04, 23.60it/s][A
accuracy: 0.5919, loss: 1.3488 ||:  62%|██████▏   | 165/268 [00:07<00:04, 23.08it/s][A
accuracy: 0.5908, loss: 1.3512 ||:  63%|██████▎   | 168/268 [00:07<00:04, 22.86it/s][A
accuracy: 0.5904, loss: 1.3532 ||:  64%|██████▍   | 171/268 [00:07<00:04, 22.35it/s][A
accuracy: 0.5900, loss: 1.3559 ||:  65%|██████▍   | 174/268 [00:08<00:04, 22.33it/s][A
accuracy: 0.5904, loss: 1.3534 |

accuracy: 0.6002, loss: 1.3427 ||:  35%|███▌      | 94/268 [00:05<00:10, 17.35it/s][A
accuracy: 0.5988, loss: 1.3456 ||:  36%|███▌      | 96/268 [00:05<00:09, 17.41it/s][A
accuracy: 0.6003, loss: 1.3413 ||:  37%|███▋      | 99/268 [00:05<00:09, 17.97it/s][A
accuracy: 0.6007, loss: 1.3374 ||:  38%|███▊      | 102/268 [00:05<00:08, 19.84it/s][A
accuracy: 0.6010, loss: 1.3382 ||:  39%|███▉      | 105/268 [00:05<00:08, 18.27it/s][A
accuracy: 0.6005, loss: 1.3389 ||:  40%|███▉      | 107/268 [00:06<00:11, 13.55it/s][A
accuracy: 0.5997, loss: 1.3368 ||:  41%|████      | 109/268 [00:06<00:11, 13.65it/s][A
accuracy: 0.5995, loss: 1.3363 ||:  41%|████▏     | 111/268 [00:06<00:11, 13.33it/s][A
accuracy: 0.6000, loss: 1.3348 ||:  42%|████▏     | 113/268 [00:06<00:12, 12.83it/s][A
accuracy: 0.5999, loss: 1.3349 ||:  43%|████▎     | 115/268 [00:06<00:12, 12.00it/s][A
accuracy: 0.5999, loss: 1.3338 ||:  44%|████▎     | 117/268 [00:06<00:11, 12.87it/s][A
accuracy: 0.5999, loss: 1.3332 ||: 

accuracy: 0.6127, loss: 1.2814 ||:   6%|▋         | 17/268 [00:00<00:15, 16.73it/s][A
accuracy: 0.6042, loss: 1.3090 ||:   7%|▋         | 20/268 [00:01<00:13, 18.58it/s][A
accuracy: 0.6080, loss: 1.3074 ||:   9%|▊         | 23/268 [00:01<00:12, 19.14it/s][A
accuracy: 0.6080, loss: 1.3083 ||:   9%|▉         | 25/268 [00:01<00:12, 19.16it/s][A
accuracy: 0.6119, loss: 1.3071 ||:  10%|█         | 28/268 [00:01<00:11, 20.59it/s][A
accuracy: 0.6126, loss: 1.3109 ||:  12%|█▏        | 33/268 [00:01<00:09, 24.46it/s][A
accuracy: 0.6106, loss: 1.3163 ||:  13%|█▎        | 36/268 [00:01<00:09, 24.03it/s][A
accuracy: 0.6064, loss: 1.3239 ||:  15%|█▍        | 39/268 [00:01<00:09, 24.50it/s][A
accuracy: 0.6048, loss: 1.3241 ||:  16%|█▌        | 42/268 [00:01<00:09, 24.29it/s][A
accuracy: 0.6044, loss: 1.3261 ||:  17%|█▋        | 45/268 [00:02<00:09, 23.97it/s][A
accuracy: 0.6042, loss: 1.3311 ||:  18%|█▊        | 48/268 [00:02<00:09, 23.68it/s][A
accuracy: 0.6016, loss: 1.3363 ||:  19%|█▉ 

accuracy: 0.5778, loss: 1.3950 ||:   1%|          | 3/268 [00:00<00:11, 22.28it/s][A
accuracy: 0.5861, loss: 1.3441 ||:   2%|▏         | 6/268 [00:00<00:11, 22.91it/s][A
accuracy: 0.5963, loss: 1.3025 ||:   3%|▎         | 9/268 [00:00<00:10, 24.31it/s][A
accuracy: 0.5958, loss: 1.2938 ||:   4%|▍         | 12/268 [00:00<00:10, 24.50it/s][A
accuracy: 0.5978, loss: 1.3074 ||:   6%|▌         | 15/268 [00:00<00:10, 23.84it/s][A
accuracy: 0.5889, loss: 1.3418 ||:   7%|▋         | 18/268 [00:00<00:10, 24.35it/s][A
accuracy: 0.5903, loss: 1.3523 ||:   8%|▊         | 21/268 [00:00<00:10, 24.32it/s][A
accuracy: 0.5993, loss: 1.3438 ||:   9%|▉         | 24/268 [00:00<00:10, 23.62it/s][A
accuracy: 0.5938, loss: 1.3666 ||:  10%|█         | 27/268 [00:01<00:10, 23.97it/s][A
accuracy: 0.5921, loss: 1.3619 ||:  11%|█         | 30/268 [00:01<00:09, 24.43it/s][A
accuracy: 0.5929, loss: 1.3537 ||:  12%|█▏        | 33/268 [00:01<00:09, 24.78it/s][A
accuracy: 0.5972, loss: 1.3448 ||:  13%|█▎    

accuracy: 0.5625, loss: 1.4344 ||:  30%|██▉       | 20/67 [00:00<00:01, 44.80it/s][A
accuracy: 0.5713, loss: 1.4115 ||:  37%|███▋      | 25/67 [00:00<00:00, 44.41it/s][A
accuracy: 0.5772, loss: 1.3890 ||:  45%|████▍     | 30/67 [00:00<00:00, 44.88it/s][A
accuracy: 0.5767, loss: 1.4027 ||:  52%|█████▏    | 35/67 [00:00<00:00, 43.71it/s][A
accuracy: 0.5854, loss: 1.3875 ||:  60%|█████▉    | 40/67 [00:00<00:00, 43.72it/s][A
accuracy: 0.5922, loss: 1.3721 ||:  67%|██████▋   | 45/67 [00:00<00:00, 45.14it/s][A
accuracy: 0.5990, loss: 1.3578 ||:  76%|███████▌  | 51/67 [00:01<00:00, 47.30it/s][A
accuracy: 0.5962, loss: 1.3655 ||:  85%|████████▌ | 57/67 [00:01<00:00, 48.50it/s][A
accuracy: 0.5968, loss: 1.3601 ||:  93%|█████████▎| 62/67 [00:01<00:00, 48.02it/s][A
accuracy: 0.6017, loss: 1.3435 ||: 100%|██████████| 67/67 [00:01<00:00, 44.92it/s][A
  0%|          | 0/268 [00:00<?, ?it/s][A
accuracy: 0.6000, loss: 1.2002 ||:   1%|          | 2/268 [00:00<00:19, 13.88it/s][A
accuracy: 0

accuracy: 0.6144, loss: 1.2909 ||:  93%|█████████▎| 249/268 [00:10<00:00, 22.66it/s][A
accuracy: 0.6137, loss: 1.2933 ||:  94%|█████████▍| 252/268 [00:10<00:00, 22.50it/s][A
accuracy: 0.6138, loss: 1.2929 ||:  95%|█████████▌| 255/268 [00:11<00:00, 23.19it/s][A
accuracy: 0.6134, loss: 1.2936 ||:  96%|█████████▋| 258/268 [00:11<00:00, 23.57it/s][A
accuracy: 0.6132, loss: 1.2940 ||:  97%|█████████▋| 261/268 [00:11<00:00, 23.35it/s][A
accuracy: 0.6128, loss: 1.2949 ||:  99%|█████████▊| 264/268 [00:11<00:00, 20.26it/s][A
accuracy: 0.6134, loss: 1.2938 ||: 100%|█████████▉| 267/268 [00:11<00:00, 20.91it/s][A
accuracy: 0.6133, loss: 1.2935 ||: 100%|██████████| 268/268 [00:11<00:00, 22.94it/s][A
  0%|          | 0/67 [00:00<?, ?it/s][A
accuracy: 0.6300, loss: 1.2848 ||:   7%|▋         | 5/67 [00:00<00:01, 48.41it/s][A
accuracy: 0.6217, loss: 1.3000 ||:  15%|█▍        | 10/67 [00:00<00:01, 48.44it/s][A
accuracy: 0.6078, loss: 1.3299 ||:  22%|██▏       | 15/67 [00:00<00:01, 48.19it/s]

accuracy: 0.6182, loss: 1.2863 ||:  87%|████████▋ | 233/268 [00:10<00:01, 21.66it/s][A
accuracy: 0.6181, loss: 1.2863 ||:  88%|████████▊ | 236/268 [00:10<00:01, 17.88it/s][A
accuracy: 0.6178, loss: 1.2869 ||:  89%|████████▉ | 238/268 [00:10<00:01, 15.77it/s][A
accuracy: 0.6179, loss: 1.2874 ||:  90%|████████▉ | 240/268 [00:10<00:02, 12.87it/s][A
accuracy: 0.6179, loss: 1.2878 ||:  90%|█████████ | 242/268 [00:10<00:01, 13.94it/s][A
accuracy: 0.6185, loss: 1.2855 ||:  91%|█████████ | 244/268 [00:10<00:01, 12.42it/s][A
accuracy: 0.6183, loss: 1.2865 ||:  92%|█████████▏| 246/268 [00:11<00:01, 12.70it/s][A
accuracy: 0.6187, loss: 1.2860 ||:  93%|█████████▎| 249/268 [00:11<00:01, 15.07it/s][A
accuracy: 0.6190, loss: 1.2862 ||:  94%|█████████▍| 252/268 [00:11<00:00, 16.70it/s][A
accuracy: 0.6185, loss: 1.2865 ||:  95%|█████████▍| 254/268 [00:11<00:00, 16.51it/s][A
accuracy: 0.6188, loss: 1.2852 ||:  96%|█████████▌| 256/268 [00:11<00:00, 17.34it/s][A
accuracy: 0.6189, loss: 1.2861 |

accuracy: 0.6248, loss: 1.2809 ||:  69%|██████▊   | 184/268 [00:08<00:03, 21.49it/s][A
accuracy: 0.6247, loss: 1.2815 ||:  70%|██████▉   | 187/268 [00:08<00:03, 21.30it/s][A
accuracy: 0.6245, loss: 1.2832 ||:  71%|███████   | 190/268 [00:08<00:03, 20.94it/s][A
accuracy: 0.6242, loss: 1.2846 ||:  72%|███████▏  | 193/268 [00:08<00:03, 21.01it/s][A
accuracy: 0.6244, loss: 1.2835 ||:  73%|███████▎  | 196/268 [00:09<00:03, 21.43it/s][A
accuracy: 0.6258, loss: 1.2808 ||:  74%|███████▍  | 199/268 [00:09<00:03, 21.51it/s][A
accuracy: 0.6262, loss: 1.2785 ||:  75%|███████▌  | 202/268 [00:09<00:02, 22.92it/s][A
accuracy: 0.6260, loss: 1.2789 ||:  76%|███████▋  | 205/268 [00:09<00:02, 21.12it/s][A
accuracy: 0.6270, loss: 1.2769 ||:  78%|███████▊  | 208/268 [00:09<00:03, 19.61it/s][A
accuracy: 0.6274, loss: 1.2760 ||:  79%|███████▊  | 211/268 [00:09<00:02, 21.25it/s][A
accuracy: 0.6277, loss: 1.2767 ||:  80%|███████▉  | 214/268 [00:09<00:02, 21.80it/s][A
accuracy: 0.6278, loss: 1.2756 |

accuracy: 0.6239, loss: 1.2768 ||:  50%|████▉     | 133/268 [00:06<00:06, 20.20it/s][A
accuracy: 0.6251, loss: 1.2726 ||:  51%|█████     | 136/268 [00:06<00:06, 20.27it/s][A
accuracy: 0.6241, loss: 1.2764 ||:  52%|█████▏    | 139/268 [00:06<00:06, 20.03it/s][A
accuracy: 0.6233, loss: 1.2788 ||:  53%|█████▎    | 142/268 [00:07<00:06, 20.58it/s][A
accuracy: 0.6238, loss: 1.2795 ||:  54%|█████▍    | 145/268 [00:07<00:05, 21.42it/s][A
accuracy: 0.6244, loss: 1.2780 ||:  55%|█████▌    | 148/268 [00:07<00:05, 22.10it/s][A
accuracy: 0.6252, loss: 1.2740 ||:  56%|█████▋    | 151/268 [00:07<00:05, 22.21it/s][A
accuracy: 0.6246, loss: 1.2775 ||:  57%|█████▋    | 154/268 [00:07<00:05, 22.12it/s][A
accuracy: 0.6246, loss: 1.2773 ||:  59%|█████▊    | 157/268 [00:07<00:04, 22.58it/s][A
accuracy: 0.6249, loss: 1.2747 ||:  60%|█████▉    | 160/268 [00:07<00:05, 20.49it/s][A
accuracy: 0.6250, loss: 1.2767 ||:  61%|██████    | 163/268 [00:07<00:04, 21.63it/s][A
accuracy: 0.6247, loss: 1.2768 |

accuracy: 0.6334, loss: 1.2548 ||:  38%|███▊      | 103/268 [00:04<00:07, 21.00it/s][A
accuracy: 0.6328, loss: 1.2579 ||:  40%|███▉      | 106/268 [00:04<00:07, 21.18it/s][A
accuracy: 0.6327, loss: 1.2577 ||:  41%|████      | 109/268 [00:04<00:07, 21.76it/s][A
accuracy: 0.6342, loss: 1.2521 ||:  42%|████▏     | 112/268 [00:05<00:07, 21.56it/s][A
accuracy: 0.6330, loss: 1.2534 ||:  43%|████▎     | 115/268 [00:05<00:07, 21.66it/s][A
accuracy: 0.6337, loss: 1.2527 ||:  44%|████▍     | 118/268 [00:05<00:07, 21.35it/s][A
accuracy: 0.6337, loss: 1.2526 ||:  45%|████▌     | 121/268 [00:05<00:06, 21.98it/s][A
accuracy: 0.6346, loss: 1.2492 ||:  46%|████▋     | 124/268 [00:05<00:06, 21.46it/s][A
accuracy: 0.6347, loss: 1.2477 ||:  47%|████▋     | 127/268 [00:05<00:06, 21.70it/s][A
accuracy: 0.6364, loss: 1.2464 ||:  49%|████▊     | 130/268 [00:05<00:06, 22.01it/s][A
accuracy: 0.6359, loss: 1.2476 ||:  50%|████▉     | 133/268 [00:06<00:06, 21.66it/s][A
accuracy: 0.6359, loss: 1.2469 |

accuracy: 0.6373, loss: 1.2412 ||:  20%|██        | 54/268 [00:02<00:09, 22.05it/s][A
accuracy: 0.6386, loss: 1.2370 ||:  21%|██▏       | 57/268 [00:02<00:09, 21.42it/s][A
accuracy: 0.6406, loss: 1.2318 ||:  22%|██▏       | 60/268 [00:02<00:09, 21.03it/s][A
accuracy: 0.6387, loss: 1.2372 ||:  24%|██▎       | 63/268 [00:02<00:09, 21.18it/s][A
accuracy: 0.6372, loss: 1.2408 ||:  25%|██▍       | 66/268 [00:03<00:09, 21.68it/s][A
accuracy: 0.6341, loss: 1.2503 ||:  26%|██▌       | 69/268 [00:03<00:09, 21.21it/s][A
accuracy: 0.6345, loss: 1.2478 ||:  27%|██▋       | 72/268 [00:03<00:09, 21.05it/s][A
accuracy: 0.6347, loss: 1.2467 ||:  28%|██▊       | 75/268 [00:03<00:09, 20.27it/s][A
accuracy: 0.6355, loss: 1.2448 ||:  29%|██▉       | 78/268 [00:03<00:10, 18.54it/s][A
accuracy: 0.6377, loss: 1.2408 ||:  30%|██▉       | 80/268 [00:03<00:12, 15.63it/s][A
accuracy: 0.6382, loss: 1.2416 ||:  31%|███       | 82/268 [00:04<00:16, 11.25it/s][A
accuracy: 0.6375, loss: 1.2388 ||:  31%|███

accuracy: 0.6405, loss: 1.2422 ||:  96%|█████████▌| 256/268 [00:17<00:00, 13.95it/s][A
accuracy: 0.6410, loss: 1.2419 ||:  96%|█████████▋| 258/268 [00:17<00:00, 15.11it/s][A
accuracy: 0.6410, loss: 1.2420 ||:  97%|█████████▋| 260/268 [00:18<00:00, 15.71it/s][A
accuracy: 0.6408, loss: 1.2421 ||:  98%|█████████▊| 263/268 [00:18<00:00, 16.61it/s][A
accuracy: 0.6407, loss: 1.2422 ||:  99%|█████████▉| 265/268 [00:18<00:00, 15.85it/s][A
accuracy: 0.6406, loss: 1.2426 ||: 100%|█████████▉| 267/268 [00:18<00:00, 15.53it/s][A
accuracy: 0.6404, loss: 1.2429 ||: 100%|██████████| 268/268 [00:18<00:00, 14.43it/s][A
  0%|          | 0/67 [00:00<?, ?it/s][A
accuracy: 0.6458, loss: 1.2506 ||:   6%|▌         | 4/67 [00:00<00:02, 31.38it/s][A
accuracy: 0.6667, loss: 1.1951 ||:  12%|█▏        | 8/67 [00:00<00:01, 33.02it/s][A
accuracy: 0.6653, loss: 1.1899 ||:  18%|█▊        | 12/67 [00:00<00:01, 34.38it/s][A
accuracy: 0.6544, loss: 1.2395 ||:  28%|██▊       | 19/67 [00:00<00:01, 39.42it/s][A


accuracy: 0.6495, loss: 1.2268 ||:  74%|███████▍  | 199/268 [00:09<00:03, 19.55it/s][A
accuracy: 0.6497, loss: 1.2258 ||:  75%|███████▌  | 202/268 [00:09<00:03, 20.02it/s][A
accuracy: 0.6493, loss: 1.2274 ||:  76%|███████▋  | 205/268 [00:09<00:03, 20.87it/s][A
accuracy: 0.6488, loss: 1.2285 ||:  78%|███████▊  | 208/268 [00:10<00:03, 18.62it/s][A
accuracy: 0.6500, loss: 1.2257 ||:  78%|███████▊  | 210/268 [00:10<00:03, 16.55it/s][A
accuracy: 0.6502, loss: 1.2267 ||:  79%|███████▉  | 212/268 [00:10<00:03, 16.70it/s][A
accuracy: 0.6498, loss: 1.2286 ||:  80%|███████▉  | 214/268 [00:10<00:03, 17.34it/s][A
accuracy: 0.6498, loss: 1.2282 ||:  81%|████████  | 217/268 [00:10<00:02, 18.90it/s][A
accuracy: 0.6499, loss: 1.2275 ||:  82%|████████▏ | 219/268 [00:10<00:02, 19.16it/s][A
accuracy: 0.6502, loss: 1.2276 ||:  83%|████████▎ | 222/268 [00:10<00:02, 19.94it/s][A
accuracy: 0.6500, loss: 1.2273 ||:  84%|████████▍ | 225/268 [00:10<00:02, 19.51it/s][A
accuracy: 0.6500, loss: 1.2276 |

accuracy: 0.6588, loss: 1.1973 ||:  56%|█████▌    | 150/268 [00:06<00:05, 22.55it/s][A
accuracy: 0.6586, loss: 1.1967 ||:  57%|█████▋    | 153/268 [00:06<00:04, 23.30it/s][A
accuracy: 0.6576, loss: 1.1992 ||:  58%|█████▊    | 156/268 [00:06<00:04, 23.36it/s][A
accuracy: 0.6578, loss: 1.1978 ||:  59%|█████▉    | 159/268 [00:07<00:04, 23.01it/s][A
accuracy: 0.6575, loss: 1.1988 ||:  60%|██████    | 162/268 [00:07<00:04, 22.62it/s][A
accuracy: 0.6567, loss: 1.2002 ||:  62%|██████▏   | 165/268 [00:07<00:04, 22.79it/s][A
accuracy: 0.6567, loss: 1.1993 ||:  63%|██████▎   | 168/268 [00:07<00:04, 22.81it/s][A
accuracy: 0.6568, loss: 1.1990 ||:  64%|██████▍   | 171/268 [00:07<00:04, 22.96it/s][A
accuracy: 0.6566, loss: 1.2007 ||:  65%|██████▍   | 174/268 [00:07<00:04, 22.60it/s][A
accuracy: 0.6568, loss: 1.1996 ||:  66%|██████▌   | 177/268 [00:07<00:03, 22.77it/s][A
accuracy: 0.6574, loss: 1.1976 ||:  67%|██████▋   | 180/268 [00:08<00:03, 22.57it/s][A
accuracy: 0.6577, loss: 1.1969 |

accuracy: 0.6672, loss: 1.1794 ||:  37%|███▋      | 100/268 [00:05<00:07, 21.15it/s][A
accuracy: 0.6675, loss: 1.1791 ||:  38%|███▊      | 103/268 [00:05<00:07, 20.93it/s][A
accuracy: 0.6673, loss: 1.1775 ||:  40%|███▉      | 106/268 [00:05<00:07, 20.39it/s][A
accuracy: 0.6685, loss: 1.1769 ||:  41%|████      | 109/268 [00:05<00:07, 20.64it/s][A
accuracy: 0.6693, loss: 1.1726 ||:  42%|████▏     | 112/268 [00:05<00:07, 21.46it/s][A
accuracy: 0.6694, loss: 1.1726 ||:  43%|████▎     | 115/268 [00:05<00:07, 21.59it/s][A
accuracy: 0.6699, loss: 1.1708 ||:  44%|████▍     | 118/268 [00:05<00:06, 22.36it/s][A
accuracy: 0.6715, loss: 1.1675 ||:  45%|████▌     | 121/268 [00:06<00:06, 21.00it/s][A
accuracy: 0.6720, loss: 1.1675 ||:  46%|████▋     | 124/268 [00:06<00:06, 20.64it/s][A
accuracy: 0.6731, loss: 1.1650 ||:  47%|████▋     | 127/268 [00:06<00:06, 20.73it/s][A
accuracy: 0.6718, loss: 1.1683 ||:  49%|████▊     | 130/268 [00:06<00:06, 19.73it/s][A
accuracy: 0.6718, loss: 1.1683 |

accuracy: 0.6907, loss: 1.1563 ||:  10%|█         | 27/268 [00:02<00:16, 14.31it/s][A
accuracy: 0.6925, loss: 1.1495 ||:  11%|█         | 29/268 [00:02<00:15, 15.58it/s][A
accuracy: 0.6978, loss: 1.1345 ||:  12%|█▏        | 31/268 [00:02<00:15, 15.62it/s][A
accuracy: 0.6944, loss: 1.1374 ||:  12%|█▏        | 33/268 [00:02<00:14, 16.33it/s][A
accuracy: 0.6967, loss: 1.1267 ||:  13%|█▎        | 35/268 [00:02<00:13, 16.66it/s][A
accuracy: 0.6986, loss: 1.1162 ||:  14%|█▍        | 37/268 [00:02<00:14, 15.40it/s][A
accuracy: 0.6944, loss: 1.1198 ||:  15%|█▍        | 39/268 [00:02<00:14, 15.47it/s][A
accuracy: 0.6939, loss: 1.1157 ||:  15%|█▌        | 41/268 [00:02<00:15, 14.63it/s][A
accuracy: 0.6953, loss: 1.1147 ||:  16%|█▌        | 43/268 [00:03<00:14, 15.39it/s][A
accuracy: 0.6959, loss: 1.1123 ||:  17%|█▋        | 45/268 [00:03<00:13, 16.06it/s][A
accuracy: 0.6968, loss: 1.1108 ||:  18%|█▊        | 47/268 [00:03<00:14, 14.76it/s][A
accuracy: 0.6963, loss: 1.1169 ||:  18%|█▊ 

accuracy: 0.6910, loss: 1.1099 ||:  92%|█████████▏| 246/268 [00:15<00:02, 10.25it/s][A
accuracy: 0.6915, loss: 1.1090 ||:  93%|█████████▎| 248/268 [00:16<00:01, 10.66it/s][A
accuracy: 0.6917, loss: 1.1084 ||:  93%|█████████▎| 250/268 [00:16<00:01, 11.57it/s][A
accuracy: 0.6917, loss: 1.1082 ||:  94%|█████████▍| 252/268 [00:16<00:01, 12.08it/s][A
accuracy: 0.6919, loss: 1.1078 ||:  95%|█████████▍| 254/268 [00:16<00:01, 12.16it/s][A
accuracy: 0.6924, loss: 1.1059 ||:  96%|█████████▌| 256/268 [00:16<00:01, 11.97it/s][A
accuracy: 0.6924, loss: 1.1067 ||:  96%|█████████▋| 258/268 [00:16<00:00, 12.88it/s][A
accuracy: 0.6922, loss: 1.1074 ||:  97%|█████████▋| 260/268 [00:16<00:00, 11.77it/s][A
accuracy: 0.6927, loss: 1.1064 ||:  98%|█████████▊| 262/268 [00:17<00:00, 11.06it/s][A
accuracy: 0.6927, loss: 1.1072 ||:  99%|█████████▊| 264/268 [00:17<00:00, 11.33it/s][A
accuracy: 0.6928, loss: 1.1074 ||:  99%|█████████▉| 266/268 [00:17<00:00, 12.76it/s][A
accuracy: 0.6931, loss: 1.1065 |

accuracy: 0.7041, loss: 1.0728 ||:  47%|████▋     | 126/268 [00:09<00:12, 11.65it/s][A
accuracy: 0.7048, loss: 1.0711 ||:  48%|████▊     | 128/268 [00:09<00:10, 12.85it/s][A
accuracy: 0.7050, loss: 1.0712 ||:  49%|████▊     | 130/268 [00:09<00:09, 14.08it/s][A
accuracy: 0.7048, loss: 1.0729 ||:  49%|████▉     | 132/268 [00:09<00:10, 13.33it/s][A
accuracy: 0.7039, loss: 1.0748 ||:  50%|█████     | 134/268 [00:09<00:11, 12.16it/s][A
accuracy: 0.7043, loss: 1.0747 ||:  51%|█████     | 136/268 [00:09<00:10, 12.38it/s][A
accuracy: 0.7051, loss: 1.0730 ||:  51%|█████▏    | 138/268 [00:09<00:09, 13.22it/s][A
accuracy: 0.7053, loss: 1.0721 ||:  52%|█████▏    | 140/268 [00:10<00:08, 14.35it/s][A
accuracy: 0.7061, loss: 1.0696 ||:  53%|█████▎    | 142/268 [00:10<00:10, 12.24it/s][A
accuracy: 0.7062, loss: 1.0684 ||:  54%|█████▎    | 144/268 [00:10<00:10, 12.25it/s][A
accuracy: 0.7062, loss: 1.0668 ||:  54%|█████▍    | 146/268 [00:10<00:09, 13.13it/s][A
accuracy: 0.7062, loss: 1.0667 |

accuracy: 0.7047, loss: 1.0518 ||:  12%|█▏        | 32/268 [00:01<00:11, 20.00it/s][A
accuracy: 0.7005, loss: 1.0642 ||:  13%|█▎        | 35/268 [00:01<00:11, 20.98it/s][A
accuracy: 0.7061, loss: 1.0505 ||:  14%|█▍        | 38/268 [00:01<00:10, 21.45it/s][A
accuracy: 0.7037, loss: 1.0499 ||:  15%|█▌        | 41/268 [00:02<00:10, 21.95it/s][A
accuracy: 0.7023, loss: 1.0563 ||:  16%|█▋        | 44/268 [00:02<00:10, 21.47it/s][A
accuracy: 0.7071, loss: 1.0456 ||:  18%|█▊        | 47/268 [00:02<00:10, 21.91it/s][A
accuracy: 0.7067, loss: 1.0507 ||:  19%|█▊        | 50/268 [00:02<00:09, 22.36it/s][A
accuracy: 0.7075, loss: 1.0461 ||:  20%|█▉        | 53/268 [00:02<00:09, 22.10it/s][A
accuracy: 0.7080, loss: 1.0367 ||:  21%|██        | 56/268 [00:02<00:09, 22.57it/s][A
accuracy: 0.7079, loss: 1.0387 ||:  22%|██▏       | 59/268 [00:02<00:09, 22.31it/s][A
accuracy: 0.7089, loss: 1.0348 ||:  23%|██▎       | 62/268 [00:02<00:09, 22.29it/s][A
accuracy: 0.7067, loss: 1.0391 ||:  24%|██▍

accuracy: 0.6667, loss: 1.1453 ||:   1%|          | 2/268 [00:00<00:16, 16.25it/s][A
accuracy: 0.6958, loss: 1.0854 ||:   1%|▏         | 4/268 [00:00<00:15, 16.56it/s][A
accuracy: 0.7000, loss: 1.0981 ||:   3%|▎         | 7/268 [00:00<00:15, 17.28it/s][A
accuracy: 0.6833, loss: 1.1137 ||:   4%|▎         | 10/268 [00:00<00:13, 18.80it/s][A
accuracy: 0.6944, loss: 1.0754 ||:   4%|▍         | 12/268 [00:00<00:13, 19.04it/s][A
accuracy: 0.6967, loss: 1.0585 ||:   6%|▌         | 15/268 [00:00<00:12, 20.23it/s][A
accuracy: 0.7056, loss: 1.0529 ||:   7%|▋         | 18/268 [00:00<00:11, 21.14it/s][A
accuracy: 0.7127, loss: 1.0367 ||:   8%|▊         | 21/268 [00:01<00:12, 19.74it/s][A
accuracy: 0.7264, loss: 1.0011 ||:   9%|▉         | 24/268 [00:01<00:11, 20.46it/s][A
accuracy: 0.7253, loss: 1.0047 ||:  10%|█         | 27/268 [00:01<00:11, 20.93it/s][A
accuracy: 0.7183, loss: 1.0065 ||:  11%|█         | 30/268 [00:01<00:10, 22.03it/s][A
accuracy: 0.7187, loss: 1.0022 ||:  12%|█▏    

accuracy: 0.6978, loss: 1.1138 ||:  22%|██▏       | 15/67 [00:00<00:01, 41.20it/s][A
accuracy: 0.7000, loss: 1.0919 ||:  30%|██▉       | 20/67 [00:00<00:01, 41.70it/s][A
accuracy: 0.7027, loss: 1.0799 ||:  37%|███▋      | 25/67 [00:00<00:00, 43.35it/s][A
accuracy: 0.6983, loss: 1.0973 ||:  45%|████▍     | 30/67 [00:00<00:00, 44.60it/s][A
accuracy: 0.7065, loss: 1.0789 ||:  54%|█████▎    | 36/67 [00:00<00:00, 46.67it/s][A
accuracy: 0.7067, loss: 1.0744 ||:  63%|██████▎   | 42/67 [00:00<00:00, 49.37it/s][A
accuracy: 0.7089, loss: 1.0708 ||:  70%|███████   | 47/67 [00:00<00:00, 47.80it/s][A
accuracy: 0.7077, loss: 1.0708 ||:  78%|███████▊  | 52/67 [00:01<00:00, 46.53it/s][A
accuracy: 0.7056, loss: 1.0796 ||:  85%|████████▌ | 57/67 [00:01<00:00, 45.20it/s][A
accuracy: 0.6989, loss: 1.0909 ||:  93%|█████████▎| 62/67 [00:01<00:00, 45.44it/s][A
accuracy: 0.6988, loss: 1.0919 ||: 100%|██████████| 67/67 [00:01<00:00, 47.00it/s][A
  0%|          | 0/268 [00:00<?, ?it/s][A
accuracy: 0

accuracy: 0.7106, loss: 1.0061 ||:  91%|█████████ | 244/268 [00:11<00:01, 23.33it/s][A
accuracy: 0.7107, loss: 1.0067 ||:  92%|█████████▏| 247/268 [00:11<00:00, 23.80it/s][A
accuracy: 0.7103, loss: 1.0075 ||:  93%|█████████▎| 250/268 [00:11<00:00, 23.88it/s][A
accuracy: 0.7106, loss: 1.0076 ||:  94%|█████████▍| 253/268 [00:11<00:00, 24.31it/s][A
accuracy: 0.7103, loss: 1.0084 ||:  96%|█████████▌| 256/268 [00:11<00:00, 24.85it/s][A
accuracy: 0.7105, loss: 1.0080 ||:  97%|█████████▋| 259/268 [00:11<00:00, 24.02it/s][A
accuracy: 0.7102, loss: 1.0084 ||:  98%|█████████▊| 262/268 [00:11<00:00, 22.25it/s][A
accuracy: 0.7100, loss: 1.0084 ||:  99%|█████████▉| 265/268 [00:11<00:00, 22.78it/s][A
accuracy: 0.7099, loss: 1.0086 ||: 100%|██████████| 268/268 [00:12<00:00, 23.25it/s][A
  0%|          | 0/67 [00:00<?, ?it/s][A
accuracy: 0.7467, loss: 0.9202 ||:   7%|▋         | 5/67 [00:00<00:01, 45.68it/s][A
accuracy: 0.6867, loss: 1.1053 ||:  15%|█▍        | 10/67 [00:00<00:01, 46.71it/s

accuracy: 0.7103, loss: 0.9953 ||:  80%|████████  | 215/268 [00:09<00:02, 21.02it/s][A
accuracy: 0.7097, loss: 0.9979 ||:  81%|████████▏ | 218/268 [00:09<00:02, 22.44it/s][A
accuracy: 0.7095, loss: 0.9985 ||:  82%|████████▏ | 221/268 [00:09<00:02, 22.90it/s][A
accuracy: 0.7092, loss: 0.9993 ||:  84%|████████▎ | 224/268 [00:09<00:01, 23.89it/s][A
accuracy: 0.7085, loss: 1.0006 ||:  85%|████████▍ | 227/268 [00:10<00:01, 21.07it/s][A
accuracy: 0.7084, loss: 1.0004 ||:  86%|████████▌ | 230/268 [00:10<00:01, 20.80it/s][A
accuracy: 0.7085, loss: 1.0003 ||:  87%|████████▋ | 233/268 [00:10<00:01, 21.97it/s][A
accuracy: 0.7090, loss: 0.9991 ||:  88%|████████▊ | 236/268 [00:10<00:01, 22.43it/s][A
accuracy: 0.7089, loss: 0.9993 ||:  89%|████████▉ | 239/268 [00:10<00:01, 22.32it/s][A
accuracy: 0.7085, loss: 0.9996 ||:  90%|█████████ | 242/268 [00:10<00:01, 22.80it/s][A
accuracy: 0.7089, loss: 0.9984 ||:  91%|█████████▏| 245/268 [00:10<00:00, 23.21it/s][A
accuracy: 0.7085, loss: 0.9989 |

accuracy: 0.7121, loss: 0.9795 ||:  60%|██████    | 162/268 [00:08<00:06, 17.49it/s][A
accuracy: 0.7111, loss: 0.9832 ||:  62%|██████▏   | 165/268 [00:08<00:05, 18.90it/s][A
accuracy: 0.7105, loss: 0.9847 ||:  63%|██████▎   | 168/268 [00:08<00:04, 20.01it/s][A
accuracy: 0.7111, loss: 0.9832 ||:  64%|██████▍   | 171/268 [00:08<00:04, 20.42it/s][A
accuracy: 0.7105, loss: 0.9846 ||:  65%|██████▍   | 174/268 [00:08<00:04, 21.10it/s][A
accuracy: 0.7111, loss: 0.9832 ||:  66%|██████▌   | 177/268 [00:09<00:04, 20.97it/s][A
accuracy: 0.7109, loss: 0.9832 ||:  67%|██████▋   | 180/268 [00:09<00:04, 20.85it/s][A
accuracy: 0.7114, loss: 0.9808 ||:  68%|██████▊   | 183/268 [00:09<00:03, 21.63it/s][A
accuracy: 0.7113, loss: 0.9805 ||:  69%|██████▉   | 186/268 [00:09<00:03, 22.03it/s][A
accuracy: 0.7110, loss: 0.9811 ||:  71%|███████   | 189/268 [00:09<00:03, 22.38it/s][A
accuracy: 0.7109, loss: 0.9809 ||:  72%|███████▏  | 192/268 [00:09<00:03, 23.21it/s][A
accuracy: 0.7096, loss: 0.9845 |

accuracy: 0.7127, loss: 0.9714 ||:  48%|████▊     | 129/268 [00:05<00:05, 24.76it/s][A
accuracy: 0.7131, loss: 0.9718 ||:  49%|████▉     | 132/268 [00:05<00:05, 24.32it/s][A
accuracy: 0.7125, loss: 0.9743 ||:  50%|█████     | 135/268 [00:05<00:05, 24.84it/s][A
accuracy: 0.7132, loss: 0.9735 ||:  51%|█████▏    | 138/268 [00:05<00:05, 25.45it/s][A
accuracy: 0.7125, loss: 0.9756 ||:  53%|█████▎    | 141/268 [00:05<00:05, 24.02it/s][A
accuracy: 0.7121, loss: 0.9759 ||:  54%|█████▎    | 144/268 [00:06<00:05, 24.21it/s][A
accuracy: 0.7122, loss: 0.9774 ||:  55%|█████▍    | 147/268 [00:06<00:04, 24.29it/s][A
accuracy: 0.7127, loss: 0.9756 ||:  56%|█████▌    | 150/268 [00:06<00:04, 23.84it/s][A
accuracy: 0.7140, loss: 0.9729 ||:  57%|█████▋    | 153/268 [00:06<00:04, 23.98it/s][A
accuracy: 0.7135, loss: 0.9755 ||:  58%|█████▊    | 156/268 [00:06<00:04, 23.36it/s][A
accuracy: 0.7137, loss: 0.9757 ||:  59%|█████▉    | 159/268 [00:06<00:04, 23.20it/s][A
accuracy: 0.7128, loss: 0.9776 |

accuracy: 0.7055, loss: 0.9857 ||:  33%|███▎      | 88/268 [00:04<00:07, 23.78it/s][A
accuracy: 0.7066, loss: 0.9840 ||:  34%|███▍      | 91/268 [00:04<00:07, 24.36it/s][A
accuracy: 0.7091, loss: 0.9768 ||:  35%|███▌      | 94/268 [00:04<00:07, 24.35it/s][A
accuracy: 0.7112, loss: 0.9715 ||:  36%|███▌      | 97/268 [00:04<00:07, 23.71it/s][A
accuracy: 0.7104, loss: 0.9746 ||:  37%|███▋      | 100/268 [00:04<00:06, 24.06it/s][A
accuracy: 0.7127, loss: 0.9668 ||:  38%|███▊      | 103/268 [00:04<00:07, 22.21it/s][A
accuracy: 0.7128, loss: 0.9651 ||:  40%|███▉      | 106/268 [00:05<00:07, 22.75it/s][A
accuracy: 0.7133, loss: 0.9633 ||:  41%|████      | 109/268 [00:05<00:07, 22.70it/s][A
accuracy: 0.7121, loss: 0.9656 ||:  42%|████▏     | 112/268 [00:05<00:07, 21.34it/s][A
accuracy: 0.7115, loss: 0.9679 ||:  43%|████▎     | 115/268 [00:05<00:06, 22.11it/s][A
accuracy: 0.7120, loss: 0.9670 ||:  44%|████▍     | 118/268 [00:05<00:06, 22.30it/s][A
accuracy: 0.7119, loss: 0.9665 ||:  

accuracy: 0.7202, loss: 0.9538 ||:  18%|█▊        | 49/268 [00:02<00:09, 23.60it/s][A
accuracy: 0.7187, loss: 0.9620 ||:  19%|█▉        | 52/268 [00:02<00:08, 24.51it/s][A
accuracy: 0.7192, loss: 0.9571 ||:  21%|██        | 55/268 [00:02<00:08, 24.15it/s][A
accuracy: 0.7198, loss: 0.9585 ||:  22%|██▏       | 58/268 [00:02<00:08, 23.68it/s][A
accuracy: 0.7191, loss: 0.9593 ||:  23%|██▎       | 61/268 [00:02<00:08, 23.08it/s][A
accuracy: 0.7174, loss: 0.9612 ||:  24%|██▍       | 64/268 [00:02<00:08, 23.02it/s][A
accuracy: 0.7159, loss: 0.9666 ||:  25%|██▌       | 67/268 [00:02<00:08, 23.56it/s][A
accuracy: 0.7183, loss: 0.9579 ||:  26%|██▌       | 70/268 [00:03<00:08, 23.85it/s][A
accuracy: 0.7203, loss: 0.9533 ||:  27%|██▋       | 73/268 [00:03<00:08, 23.88it/s][A
accuracy: 0.7202, loss: 0.9526 ||:  28%|██▊       | 76/268 [00:03<00:08, 23.92it/s][A
accuracy: 0.7235, loss: 0.9429 ||:  29%|██▉       | 79/268 [00:03<00:07, 24.79it/s][A
accuracy: 0.7236, loss: 0.9399 ||:  31%|███

accuracy: 0.7378, loss: 0.8486 ||:   6%|▌         | 15/268 [00:00<00:12, 19.79it/s][A
accuracy: 0.7417, loss: 0.8486 ||:   7%|▋         | 18/268 [00:00<00:12, 20.68it/s][A
accuracy: 0.7405, loss: 0.8725 ||:   8%|▊         | 21/268 [00:00<00:11, 21.27it/s][A
accuracy: 0.7403, loss: 0.8852 ||:   9%|▉         | 24/268 [00:01<00:11, 22.02it/s][A
accuracy: 0.7358, loss: 0.8865 ||:  10%|█         | 27/268 [00:01<00:11, 21.56it/s][A
accuracy: 0.7350, loss: 0.8891 ||:  11%|█         | 30/268 [00:01<00:10, 21.90it/s][A
accuracy: 0.7338, loss: 0.8938 ||:  12%|█▏        | 33/268 [00:01<00:10, 23.11it/s][A
accuracy: 0.7333, loss: 0.8967 ||:  13%|█▎        | 36/268 [00:01<00:10, 23.14it/s][A
accuracy: 0.7333, loss: 0.8965 ||:  15%|█▍        | 39/268 [00:01<00:09, 23.32it/s][A
accuracy: 0.7345, loss: 0.8898 ||:  16%|█▌        | 42/268 [00:01<00:09, 23.50it/s][A
accuracy: 0.7330, loss: 0.8987 ||:  17%|█▋        | 45/268 [00:02<00:09, 23.20it/s][A
accuracy: 0.7291, loss: 0.9100 ||:  18%|█▊ 

accuracy: 0.6996, loss: 1.0550 ||:  63%|██████▎   | 42/67 [00:00<00:00, 47.92it/s][A
accuracy: 0.7011, loss: 1.0575 ||:  70%|███████   | 47/67 [00:00<00:00, 47.03it/s][A
accuracy: 0.7016, loss: 1.0532 ||:  78%|███████▊  | 52/67 [00:01<00:00, 47.19it/s][A
accuracy: 0.7014, loss: 1.0569 ||:  87%|████████▋ | 58/67 [00:01<00:00, 50.28it/s][A
accuracy: 0.7046, loss: 1.0451 ||:  97%|█████████▋| 65/67 [00:01<00:00, 53.23it/s][A
accuracy: 0.7030, loss: 1.0468 ||: 100%|██████████| 67/67 [00:01<00:00, 50.16it/s][A
  0%|          | 0/268 [00:00<?, ?it/s][A
accuracy: 0.7583, loss: 0.8135 ||:   1%|          | 2/268 [00:00<00:14, 18.93it/s][A
accuracy: 0.7400, loss: 0.8501 ||:   2%|▏         | 5/268 [00:00<00:13, 19.82it/s][A
accuracy: 0.7500, loss: 0.8547 ||:   3%|▎         | 8/268 [00:00<00:12, 21.06it/s][A
accuracy: 0.7485, loss: 0.8487 ||:   4%|▍         | 11/268 [00:00<00:11, 21.59it/s][A
accuracy: 0.7440, loss: 0.8613 ||:   5%|▌         | 14/268 [00:00<00:11, 22.05it/s][A
accuracy:

accuracy: 0.7296, loss: 0.9069 ||:  98%|█████████▊| 263/268 [00:11<00:00, 23.89it/s][A
accuracy: 0.7299, loss: 0.9058 ||:  99%|█████████▉| 266/268 [00:11<00:00, 23.56it/s][A
accuracy: 0.7304, loss: 0.9051 ||: 100%|██████████| 268/268 [00:11<00:00, 23.51it/s][A
  0%|          | 0/67 [00:00<?, ?it/s][A
accuracy: 0.7333, loss: 0.9521 ||:   7%|▋         | 5/67 [00:00<00:01, 46.98it/s][A
accuracy: 0.7030, loss: 1.0352 ||:  16%|█▋        | 11/67 [00:00<00:01, 49.47it/s][A
accuracy: 0.7245, loss: 0.9720 ||:  25%|██▌       | 17/67 [00:00<00:00, 51.99it/s][A
accuracy: 0.7043, loss: 1.0357 ||:  34%|███▍      | 23/67 [00:00<00:00, 53.79it/s][A
accuracy: 0.7103, loss: 1.0209 ||:  43%|████▎     | 29/67 [00:00<00:00, 53.26it/s][A
accuracy: 0.7088, loss: 1.0182 ||:  54%|█████▎    | 36/67 [00:00<00:00, 55.61it/s][A
accuracy: 0.7052, loss: 1.0350 ||:  63%|██████▎   | 42/67 [00:00<00:00, 53.69it/s][A
accuracy: 0.7000, loss: 1.0450 ||:  72%|███████▏  | 48/67 [00:00<00:00, 51.96it/s][A
accurac

{'best_epoch': 28,
 'peak_cpu_memory_MB': 780.648,
 'training_duration': '00:07:13',
 'training_start_epoch': 0,
 'training_epochs': 29,
 'epoch': 29,
 'training_accuracy': 0.7303805364940736,
 'training_loss': 0.9050629405833003,
 'training_cpu_memory_MB': 780.648,
 'validation_accuracy': 0.7007462686567164,
 'validation_loss': 1.049820852813436,
 'best_validation_accuracy': 0.7029850746268657,
 'best_validation_loss': 1.0467508195051507}

In [107]:
# Manually test predictions
from allennlp.predictors import Predictor

class OwnPredictor(Predictor):
    """Predictor that maps a list of names to predicted language labels.

    Constructed like any AllenNLP ``Predictor``:
    ``OwnPredictor(model, dataset_reader=dataset_reader)``.
    """

    def predict_language(self, names):
        """Predict a language label for each name in ``names``.

        Every name is first passed through ``unicodeToAscii``. This is the same
        normalization ``readLines`` applies to the training files, so accented
        input such as "Müller" is presented to the model as "Muller" rather than
        with characters that never occur in the training data.

        Args:
            names: list of name strings, one token per name.

        Returns:
            List of label strings from the model vocabulary's ``'labels'``
            namespace, one per input name. Returns ``[]`` for empty input.
        """
        if not names:
            # Nothing to tag; avoid building an instance with an empty sequence.
            return []
        # Match training-time preprocessing. If normalization would strip a name
        # entirely (e.g. no Latin letters), keep the raw name instead of
        # sending an empty token.
        normalized = [unicodeToAscii(name) or name for name in names]
        instance = self._dataset_reader.text_to_instance(normalized)
        output = self.predict_instance(instance)
        # 'tag_logits' holds per-name label scores (presumably
        # (num_names, num_labels) after the batch dim is removed — confirm);
        # the argmax over the last axis picks one label id per name.
        label_ids = np.argmax(output['tag_logits'], axis=-1)
        return [self._model.vocab.get_token_from_index(int(i), 'labels') for i in label_ids]

# Spot-check the trained tagger on a handful of hand-picked surnames.
sample_names = ["Kuznetsov", "Schneider", "Washington", "Lindemann", "Müller"]
predictor = OwnPredictor(model, dataset_reader=dataset_reader)
predicted_languages = predictor.predict_language(sample_names)
print(predicted_languages)

['Russian', 'English', 'English', 'English', 'English']
