# Introduction

<center><h3>**Welcome to the Summarization Notebook.**</h3></center>

In this assignment, you are going to train a neural network to summarize news articles.
Your neural network is going to learn from example, as we provide you with (article, summary) pairs.
We provide you with a **toy dataset** made of only articles about police related news.
Usual datasets can be 20x larger in size, but we have reduced it for computational purposes.

You will do this using a Transformer network, from the __[Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)__ paper.
In this assignment you will:
- Learn to process text into sub-word tokens, to avoid fixed vocabulary sizes, and UNK tokens.
- Implement the key conceptual blocks of a Transformer.
- Use a Transformer to read a news article, and produce a summary.
- Perform operations on learned word-vectors to examine what the model has learned.

    
** Before you start **

You should read the Attention is all you need paper.
We are providing you with skeleton code for the Transformer, but there will have to implement 5 conceptual blocks of the transformer yourself:
-  AttentionQKV: the Query, Key, Value attention mechanism at the center of the Transformer
- MultiHeadAttention: the multiple heads that enable each input to attend at many places at once.
- PositionEmbedding: the sinusoid-based position embedding of the Transformer.
- Encoder & Decoder: The encoder (that reads inputs, such as news articles), the decoder (that produces the output summary, one token at a time)
- Full Transformer: piecing it all together.

All dataset files should be placed in the `dataset/` folder of this assignment.

If you are using Google Colab, follow the instructions to mount your Google Drive onto the remote machine.

# Library imports

In [1]:
!pip install segtok
!pip install sentencepiece

Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/40/62/ac0a3bf69c1149e509bc32c3efaf0a3e309a51ea3dfa8f8f7a42895c99fa/sentencepiece-0.1.95-cp37-cp37m-macosx_10_6_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 1.2MB/s eta 0:00:01
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.95


Run the first of the following two cells if you are running the homework locally, and run the second cell if you are running the homework in Colab

In [2]:
DRIVE=False
root_folder = ""
dataset_folder = "dataset/"

In [None]:
from google.colab import drive
drive.mount('/content/drive')
root_folder = "/content/drive/My Drive/cs182_hw3/"
dataset_folder = "/content/drive/My Drive/cs182_hw3_public/dataset/"

In [3]:
# This cell autoreloads the notebook when you change you python file code.
# If you think the notebook did not reload, rerun this cell.
%load_ext autoreload
%autoreload 2

In [30]:
import os
import sys
sys.path.append(root_folder)
#from transformer import Transformer
import sentencepiece as spm
import torch as th
from torch import nn
from torch.nn import functional as F
from torch import optim
import numpy as np
import json
import capita
import os
from transformer_utils import set_device
import gc
from utils import validate_to_array, model_out_to_list

list_to_device = lambda th_obj: [tensor.to(device) for tensor in th_obj]
device = th.device("cuda" if th.cuda.is_available() else "cpu")

In [7]:
# Load the word piece model that will be used to tokenize the texts into
# word pieces with a vocabulary size of 10000
sp = spm.SentencePieceProcessor()
sp.Load(root_folder+"dataset/wp_vocab10000.model")

vocab = [line.split('\t')[0] for line in open(root_folder+"dataset/wp_vocab10000.vocab", "r")]
pad_index = vocab.index('#')

def pad_sequence(numerized, pad_index, to_length):
    pad = numerized[:to_length]
    padded = pad + [pad_index] * (to_length - len(pad))
    mask = [w != pad_index for w in padded]
    return padded, mask

# Building blocks of a Transformer


**TODO**:

Implement the 5 blocks of the Transformer. In order to finish this section, you should get very small error <1e-7 on each of the 5 checks in this section.


The Transformer is split into 3 files: transformer_attention.py, transformer_utils.py and transformer.py

Each section below gives you directions and a way to verify your code works properly.

You do not need to modify the rest of the code provided, but should read it to understand overall architecture.

Our Transformer is built as a Pytorch model, a standard that is good for you to get accustomed to.



## (1) Implementing the Query-Key-Value Attention (AttentionQKV)

This part is located in AttentionQKV in transformer_attention.py. You must implement the call function of the class.
You will need to implement the mathematical procedure of AttentionQKV that is described in the [Attention is all you need paper](https://arxiv.org/pdf/1706.03762.pdf).

In [32]:
from transformer_attention import AttentionQKV

batch_size = 2;
n_queries = 3;
n_keyval = 5;
depth_k = 2;
depth_v = 2

with open(root_folder+"transformer_checks/attention_qkv_io.json", "r") as f:
    io = json.load(f)
    queries = th.tensor(io['queries'])
    keys = th.tensor(io['keys'])
    values = th.tensor(io['values'])
    expected_output  = th.tensor(io['output'])
    expected_weights = th.tensor(io['weights'])

attn_qkv = AttentionQKV()
output, weights = attn_qkv(queries, keys, values)
validate_to_array(model_out_to_list,((queries,keys,values),attn_qkv),'attentionqkv', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output)).item(), "(should be 0.0 or close to 0.0)")
print("Total error on the weights:",th.sum(th.abs(expected_weights-weights)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 2.8312206268310547e-07 (should be 0.0 or close to 0.0)
Total error on the weights: 2.849847078323364e-07 (should be 0.0 or close to 0.0)


## (2) Implementing Multi-head attention

This part is located in the class MultiHeadProjection in transformer_attention.py.
You must implement the call, \_split_heads, and \_combine_heads functions.

**Procedure**

The objective is to leverage the AttentionQKV class you already wrote.

Your input are the queries, keys, values as 3-d tensors (batch_size, sequence_length, feature_size).

Split them into 4-d tensors (batch_size, n_heads, sequence_length, new_feature_size). Where:
$$feature\_size = n\_heads * new_{feature\_size}.$$

You can then feed the split qkv to your implemented AttentionQKV, which will treat each head as an independent attention function.

Then the output must be combined back into a 3-d tensor.
You can test the validity of your implementation in the cell below.

In [33]:
from transformer_attention import MultiHeadProjection

batch_size = 2;
n_queries = 3;
n_heads = 4
n_keyval = 5;
depth_k = 8;
depth_v = 8;

with open(root_folder+"transformer_checks/multihead_io.json", "r") as f:
    io = json.load(f)
    queries = th.tensor(io['queries'])
    keys = th.tensor(io['keys'])
    values = th.tensor(io['values'])
    expected_output  = th.tensor(io['output'])

mhp = MultiHeadProjection(n_heads, (depth_k,depth_v))
multihead_output = mhp((queries, keys, values))
validate_to_array(model_out_to_list,(((queries,keys,values),),mhp),'multihead', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-multihead_output)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 1.5934929251670837e-06 (should be 0.0 or close to 0.0)


## (3) Position Embedding 

You must implement the FeedForward and PositionEmbedding classes in transformer.py.


The cell below helps you verify the validity of your implementation


In [34]:
from transformer import PositionEmbedding

batch_size = 2;
sequence_length = 3;
dim = 4;

with open(root_folder+"transformer_checks/position_embedding_io.json", "r") as f:
    io = json.load(f)
    inputs = th.tensor(io['inputs'])
    expected_output  = th.tensor(io['output'])

pos_emb = PositionEmbedding(dim)
(inputs,expected_output,pos_emb) = list_to_device((inputs,expected_output,pos_emb))
output_t = pos_emb(inputs)
validate_to_array(model_out_to_list,((inputs,),pos_emb),'position_embedding', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 2.980232238769531e-07 (should be 0.0 or close to 0.0)


## (4) Transformer Encoder / Transformer Decoder

You now have all the blocks needed to implement the Transformer.
For this part, you have to fill in 2 classes in the transformer.py file: TransformerEncoderBlock, TransformerDecoderBlock.

The code below will verify the accuracy of each block

In [35]:
from transformer import TransformerEncoderBlock

batch_size = 2
sequence_length = 5
hidden_size = 6
filter_size = 12
n_heads = 2

with open(root_folder+"transformer_checks/transformer_encoder_block_io_new.json", "r") as f:
    io = json.load(f)
    inputs = th.tensor(io['inputs'])
    expected_output = th.tensor(io['output'])
enc_block = TransformerEncoderBlock(input_size=6, n_heads=n_heads, filter_size=filter_size, hidden_size=hidden_size)
# th.save(enc_block.state_dict(),root_folder+"transformer_checks/transformer_encoder_block")
enc_block.load_state_dict(th.load(root_folder+"transformer_checks/transformer_encoder_block"))
(inputs,expected_output,enc_block) = list_to_device((inputs,expected_output,enc_block))
output_t = enc_block(inputs)
validate_to_array(model_out_to_list,((inputs,),enc_block),'encoder_block', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 5.8710575103759766e-06 (should be 0.0 or close to 0.0)


In [36]:
from transformer import TransformerDecoderBlock
batch_size = 2
encoder_length = 5
decoder_length = 3
hidden_size = 6
filter_size = 12
n_heads = 2

with open(root_folder+"transformer_checks/transformer_decoder_block_io_new.json", "r") as f:
    io = json.load(f)
    decoder_inputs = th.tensor(io['decoder_inputs'])
    encoder_output = th.tensor(io['encoder_output'])
    expected_output = th.tensor(io['output'])

dec_block = TransformerDecoderBlock(input_size=6, n_heads=n_heads, filter_size=filter_size, hidden_size=hidden_size)
dec_block.load_state_dict(th.load(root_folder+"transformer_checks/transformer_decoder_block"))
(decoder_inputs,encoder_output,expected_output,dec_block) = list_to_device((decoder_inputs,encoder_output,expected_output,dec_block))
output_t = dec_block(decoder_inputs, encoder_output)
validate_to_array(model_out_to_list,((decoder_inputs, encoder_output),dec_block),'decoder_block', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")


Total error on the output: 2.1904706954956055e-06 (should be 0.0 or close to 0.0)


## (5) Transformer

This is the final high-level function that pieces it all together.

You have to implement the call function of the Transformer class in the `transformer.py` file.

The block below verifies your implementation is correct.

In [37]:
from transformer import Transformer

batch_size = 2
vocab_size = 11
n_layers = 3
n_heads = 4
d_model = 8
d_filter = 16
input_length = 5
output_length = 3

with open(root_folder+"transformer_checks/transformer_io_new.json", "r") as f:
    io = json.load(f)
    enc_input = th.tensor(io['enc_input'])
    dec_input = th.tensor(io['dec_input'])
    enc_mask = th.tensor(io['enc_mask'])
    dec_mask = th.tensor(io['dec_mask'])
    expected_output = th.tensor(io['output'])
transformer = Transformer(vocab_size=vocab_size, n_layers=n_layers, n_heads=n_heads, d_model=d_model, d_filter=d_filter)
transformer.load_state_dict(th.load(root_folder+"transformer_checks/transformer"))
(enc_input,dec_input,enc_mask,dec_mask,expected_output,transformer) \
    = list_to_device((enc_input,dec_input,enc_mask,dec_mask,expected_output,transformer))
output_t = transformer(enc_input, target_sequence=dec_input, encoder_mask=enc_mask, decoder_mask=dec_mask)
validate_to_array(model_out_to_list, ((enc_input, dec_input, enc_mask, dec_mask),transformer),'transformer', root_folder)
print("Total error on the output:",th.sum(th.abs(expected_output-output_t)).item(), "(should be 0.0 or close to 0.0)")

Total error on the output: 5.692243576049805e-05 (should be 0.0 or close to 0.0)


# Training the model

Your objective is to train the Language on the dataset you are provided to reach a **validation loss <= 6.50**

Careful: we will be testing this loss on an unreleased test set, so make sure to evaluate properly on a validation set and not overfit.

You must save the model you want us to test under: models/final_transformer_summarization (the .index, .meta and .data files)

**Advice**:
- It should be possible to attain validation loss <= 6.50 with the model dimensions we've specified (n_layers=6, d_model=104, d_filter=416), but you can tune these hyperparameters. Increasing d_model will yield better model, at the cost of longer training time.
- You should try tuning the learning rate, as well as what optimizer you use.
- You might need to train for a few (up to 2 hours) to obtain our expected loss. Remember to tune your hyperparameters first, once you find ones that work well, let it train for longer.

**Dataset**: as in the previous notebook, make sure the dataset files are in the `dataset` folder. These can be found on the Google Drive.


In [48]:
with open(root_folder+"dataset/summarization_dataset_preprocessed.json", "r") as f:
    dataset = json.load(f)

# We load the dataset, and split it into 2 sub-datasets based on if they are training or validation.
# Feel free to split this dataset another way, but remember, a validation set is important, to have an idea of 
# the amount of overfitting that has occurred!

d_train = [d for d in dataset if d['cut'] == 'training']
d_valid = [d for d in dataset if d['cut'] == 'evaluation']

len(d_train), len(d_valid)

(61055, 1558)

In [49]:
# An example (article, summary) pair in the training data:

print(d_train[145]['story'])
print("=======================\n=======================")
print(d_train[145]['summary'])

Tbilisi, Georgia (CNN)Police have shot and killed a white tiger that killed a man Wednesday in Tbilisi, Georgia, a Ministry of Internal Affairs representative said, after severe flooding allowed hundreds of wild animals to escape the city zoo. 
The tiger attack happened at a warehouse in the city center. The animal had been unaccounted for since the weekend floods destroyed the zoo premises.
The man killed, who was 43, worked in a company based in the warehouse, the Ministry of Internal Affairs said. Doctors said he was attacked in the throat and died before reaching the hospital. 
Experts are still searching the warehouse, the ministry said, adding that earlier reports that the tiger had injured a second man were unfounded. 
The zoo administration said Wednesday that another tiger was still missing. It was unable to confirm if the creature was dead or had escaped alive.
Georgian Prime Minister Irakli Garibashvili apologized to the public, saying he had been misinformed by the zoo's ma

Similarly to the previous assignment, we create a function to get a random batch to train on, given a dataset.

In [50]:
def build_batch(dataset, batch_size):
    indices = list(np.random.randint(0, len(dataset), size=batch_size))
    
    batch = [dataset[i] for i in indices]
    batch_input = np.array([a['input'] for a in batch])
    batch_input_mask = np.array([a['input_mask'] for a in batch])
    batch_output = np.array([a['output'] for a in batch])
    batch_output_mask = np.array([a['output_mask'] for a in batch])
    
    return batch_input, batch_input_mask, batch_output, batch_output_mask

We now instantiate the Transformer with our sets of hyperparameters specific to the task of summarization.
In summarization, we are going to go from documents with up to 400 words, to documents with up to 100 words.
The vocabulary size is set for you, and is of 10,000 words (we are using WordPieces, [here is a paper about subword encoding](http://aclweb.org/anthology/P18-1007), if you are interested).

In [95]:
# Use this trainer to train a Transformer model

class TransformerTrainer(nn.Module):
    def __init__(self, vocab_size, d_model, input_length, output_length, n_layers, d_filter, dropout=0, learning_rate=1e-3):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, d_filter=d_filter)

        # Summarization loss
        criterion = nn.CrossEntropyLoss(reduce='none')
        self.loss_fn = lambda pred,target,mask: (criterion(pred.permute(0,2,1),target)*mask).sum()/mask.sum()
        self.learning_rate = learning_rate
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
    def forward(self,batch,optimize=True):
        pred_logits = self.model(**batch)
        target,mask = batch['target_sequence'],batch['decoder_mask']
        loss = self.loss_fn(pred_logits,target,mask)
        accuracy = (th.eq(pred_logits.argmax(dim=2,keepdim=False),target).float()*mask).sum()/mask.sum()
        
        if optimize:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
                
        return loss, accuracy

In [99]:
# Dataset related parameters
vocab_size = len(vocab)
ilength = 400 # Length of the article
olength  = 100 # Length of the summaries

# Model related parameters, feel free to modify these.
n_layers = 6
d_model  = 160
d_filter = 4*d_model
batch_size = 32

dropout = 0.2
learning_rate = 3e-3
trainer = TransformerTrainer(vocab_size, d_model, ilength, olength, n_layers, d_filter, dropout)
model_id = 'test1'
os.makedirs(root_folder+'models/part2/',exist_ok=True)

device = th.device("cuda" if th.cuda.is_available() else "cpu")
print(device)
set_device(device)

cpu


In [100]:
# Skeleton code, as in the previous notebook.
# Write code training code and save your best performing model on the
# validation set. We will be testing the loss on a held-out test dataset.
from tqdm import tqdm
gc.collect()
trainer.model.to(device)
trainer.model.train()
losses,accuracies = [],[]
t = tqdm(range(int(2e3)+1))
for i in t:
    # Create a random mini-batch from the training dataset
    batch = build_batch(d_train, batch_size)
    # Build the feed-dict connecting placeholders and mini-batch
    batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
    batch_input, batch_input_mask, batch_output, batch_output_mask \
                = list_to_device([batch_input, batch_input_mask, batch_output, batch_output_mask])
    batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
            'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}

    # Obtain the loss. Be careful when you use the train_op and not, as previously.
    train_loss, accuracy = trainer(batch)
    losses.append(train_loss.item()),accuracies.append(accuracy.item())
    if i % 100 == 0:
        t.set_description(f"Iteration: {i} Loss: {np.mean(losses[-10:])} Accuracy: {np.mean(accuracies[-10:])}")
    if i % 100 == 0:
        save_dict = dict(
            kwargs = dict(
                vocab_size=vocab_size,
                d_model=d_model,
                n_layers=n_layers, 
                d_filter=d_filter
            ),
            model_state_dict = trainer.model.state_dict(),
            notes = ""
        )
        th.save(save_dict, root_folder+f'models/part2/model_{model_id}.pt')








  0%|          | 0/2001 [00:00<?, ?it/s][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   0%|          | 0/2001 [00:17<?, ?it/s][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   0%|          | 1/2001 [00:17<9:45:38, 17.57s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   0%|          | 2/2001 [00:53<12:48:50, 23.08s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   0%|          | 3/2001 [01:04<10:47:29, 19.44s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   0%|          | 4/2001 [01:16<9:34:50, 17.27s/it] [A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   0%|          | 5/2001 [01:32<9:21:32, 16.88s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accurac

Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   3%|▎         | 54/2001 [14:10<5:49:16, 10.76s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   3%|▎         | 55/2001 [14:21<5:45:52, 10.66s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   3%|▎         | 56/2001 [14:32<5:51:01, 10.83s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   3%|▎         | 57/2001 [14:43<5:50:27, 10.82s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   3%|▎         | 58/2001 [14:53<5:49:18, 10.79s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   3%|▎         | 59/2001 [15:04<5:44:14, 10.64s/it][A[A[A[A[A[A[A






Iteration: 0 Loss: 86.82952880859375 Accuracy: 0.0017152659129351377:   3%|▎         | 60/2001 [15:15<5:47

Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   5%|▌         | 108/2001 [24:16<5:15:28, 10.00s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   5%|▌         | 109/2001 [24:25<5:04:18,  9.65s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   5%|▌         | 110/2001 [24:35<5:05:12,  9.68s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   6%|▌         | 111/2001 [24:44<5:04:17,  9.66s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   6%|▌         | 112/2001 [24:53<5:00:23,  9.54s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   6%|▌         | 113/2001 [25:03<4:56:24,  9.42s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   6%|▌         | 114/2001 [25:12<4:53:49,  9

Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   8%|▊         | 163/2001 [37:16<6:11:52, 12.14s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   8%|▊         | 164/2001 [37:28<6:08:41, 12.04s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   8%|▊         | 165/2001 [37:39<6:00:59, 11.80s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   8%|▊         | 166/2001 [37:51<6:05:10, 11.94s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   8%|▊         | 167/2001 [38:03<6:06:20, 11.98s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   8%|▊         | 168/2001 [38:17<6:19:30, 12.42s/it][A[A[A[A[A[A[A






Iteration: 100 Loss: 5.622122430801392 Accuracy: 0.172064845263958:   8%|▊         | 169/2001 [38:31<6:34:19, 12

Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  11%|█         | 217/2001 [52:13<8:12:47, 16.57s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  11%|█         | 218/2001 [52:32<8:37:27, 17.41s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  11%|█         | 219/2001 [52:49<8:28:23, 17.12s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  11%|█         | 220/2001 [53:04<8:07:44, 16.43s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  11%|█         | 221/2001 [53:18<7:45:00, 15.67s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  11%|█         | 222/2001 [53:38<8:29:33, 17.19s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  11%|█         | 223/2001 [53:52<7:54:35, 16

Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  14%|█▎        | 272/2001 [1:05:13<6:00:13, 12.50s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  14%|█▎        | 273/2001 [1:05:25<5:55:13, 12.33s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  14%|█▎        | 274/2001 [1:05:41<6:26:56, 13.44s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  14%|█▎        | 275/2001 [1:05:55<6:28:30, 13.51s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  14%|█▍        | 276/2001 [1:06:08<6:22:36, 13.31s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  14%|█▍        | 277/2001 [1:06:20<6:16:45, 13.11s/it][A[A[A[A[A[A[A






Iteration: 200 Loss: 4.704784822463989 Accuracy: 0.203373222053051:  14%|█▍        | 278/2001 [1:06:

Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  16%|█▌        | 325/2001 [1:17:25<6:16:33, 13.48s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  16%|█▋        | 326/2001 [1:17:42<6:48:27, 14.63s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  16%|█▋        | 327/2001 [1:18:00<7:09:24, 15.39s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  16%|█▋        | 328/2001 [1:18:14<6:57:59, 14.99s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  16%|█▋        | 329/2001 [1:18:24<6:19:09, 13.61s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  16%|█▋        | 330/2001 [1:18:34<5:52:03, 12.64s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  17%|█▋        | 3

Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  19%|█▉        | 378/2001 [1:27:44<5:26:25, 12.07s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  19%|█▉        | 379/2001 [1:27:57<5:37:01, 12.47s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  19%|█▉        | 380/2001 [1:28:09<5:31:52, 12.28s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  19%|█▉        | 381/2001 [1:28:21<5:30:37, 12.25s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  19%|█▉        | 382/2001 [1:28:33<5:27:36, 12.14s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  19%|█▉        | 383/2001 [1:28:47<5:42:27, 12.70s/it][A[A[A[A[A[A[A






Iteration: 300 Loss: 4.144280004501343 Accuracy: 0.24129003584384917:  19%|█▉        | 3

Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  21%|██▏       | 430/2001 [1:38:22<5:10:55, 11.88s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  22%|██▏       | 431/2001 [1:38:34<5:11:13, 11.89s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  22%|██▏       | 432/2001 [1:38:46<5:15:02, 12.05s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  22%|██▏       | 433/2001 [1:38:58<5:14:51, 12.05s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  22%|██▏       | 434/2001 [1:39:09<5:07:16, 11.77s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  22%|██▏       | 435/2001 [1:39:21<5:08:13, 11.81s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  22%|██▏       | 4

Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  24%|██▍       | 483/2001 [1:49:13<5:23:05, 12.77s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  24%|██▍       | 484/2001 [1:49:26<5:18:03, 12.58s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  24%|██▍       | 485/2001 [1:49:38<5:19:26, 12.64s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  24%|██▍       | 486/2001 [1:49:52<5:26:37, 12.94s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  24%|██▍       | 487/2001 [1:50:06<5:35:10, 13.28s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  24%|██▍       | 488/2001 [1:50:19<5:28:55, 13.04s/it][A[A[A[A[A[A[A






Iteration: 400 Loss: 4.215539169311524 Accuracy: 0.25520038306713105:  24%|██▍       | 4

Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  27%|██▋       | 535/2001 [2:00:04<3:59:05,  9.79s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  27%|██▋       | 536/2001 [2:00:14<3:57:30,  9.73s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  27%|██▋       | 537/2001 [2:00:23<3:57:12,  9.72s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  27%|██▋       | 538/2001 [2:00:33<3:56:11,  9.69s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  27%|██▋       | 539/2001 [2:00:45<4:09:54, 10.26s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  27%|██▋       | 540/2001 [2:00:56<4:14:54, 10.47s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  27%|██▋       | 5

Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  29%|██▉       | 588/2001 [2:09:45<4:27:40, 11.37s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  29%|██▉       | 589/2001 [2:09:55<4:23:13, 11.19s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  29%|██▉       | 590/2001 [2:10:08<4:30:25, 11.50s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  30%|██▉       | 591/2001 [2:10:19<4:29:30, 11.47s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  30%|██▉       | 592/2001 [2:10:30<4:26:35, 11.35s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  30%|██▉       | 593/2001 [2:10:41<4:23:28, 11.23s/it][A[A[A[A[A[A[A






Iteration: 500 Loss: 3.9482913255691527 Accuracy: 0.2586625799536705:  30%|██▉       | 5

Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  32%|███▏      | 640/2001 [2:19:51<4:36:47, 12.20s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  32%|███▏      | 641/2001 [2:20:03<4:35:33, 12.16s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  32%|███▏      | 642/2001 [2:20:15<4:31:21, 11.98s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  32%|███▏      | 643/2001 [2:20:26<4:27:57, 11.84s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  32%|███▏      | 644/2001 [2:20:38<4:26:42, 11.79s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  32%|███▏      | 645/2001 [2:20:49<4:23:24, 11.66s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  32%|███▏  

Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  35%|███▍      | 693/2001 [2:30:04<3:39:20, 10.06s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  35%|███▍      | 694/2001 [2:30:15<3:46:25, 10.39s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  35%|███▍      | 695/2001 [2:30:25<3:43:54, 10.29s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  35%|███▍      | 696/2001 [2:30:36<3:46:19, 10.41s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  35%|███▍      | 697/2001 [2:30:47<3:51:47, 10.67s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  35%|███▍      | 698/2001 [2:30:58<3:53:45, 10.76s/it][A[A[A[A[A[A[A






Iteration: 600 Loss: 3.8638356924057007 Accuracy: 0.26453557163476943:  35%|███▍  

Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  37%|███▋      | 745/2001 [2:39:33<3:49:49, 10.98s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  37%|███▋      | 746/2001 [2:39:44<3:50:09, 11.00s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  37%|███▋      | 747/2001 [2:39:56<3:54:27, 11.22s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  37%|███▋      | 748/2001 [2:40:07<3:52:39, 11.14s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  37%|███▋      | 749/2001 [2:40:18<3:52:32, 11.14s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  37%|███▋      | 750/2001 [2:40:29<3:50:22, 11.05s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  38%|███▊  

Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  40%|███▉      | 798/2001 [2:49:49<4:26:27, 13.29s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  40%|███▉      | 799/2001 [2:50:00<4:12:02, 12.58s/it][A[A[A[A[A[A[A






Iteration: 700 Loss: 3.7881422519683836 Accuracy: 0.27176029682159425:  40%|███▉      | 800/2001 [2:50:11<4:00:56, 12.04s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  40%|███▉      | 800/2001 [2:50:21<4:00:56, 12.04s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  40%|████      | 801/2001 [2:50:21<3:48:51, 11.44s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  40%|████      | 802/2001 [2:50:32<3:45:03, 11.26s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  40%|████  

Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  42%|████▏     | 850/2001 [2:59:19<3:56:34, 12.33s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  43%|████▎     | 851/2001 [2:59:30<3:46:56, 11.84s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  43%|████▎     | 852/2001 [2:59:40<3:36:14, 11.29s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  43%|████▎     | 853/2001 [2:59:49<3:24:34, 10.69s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  43%|████▎     | 854/2001 [2:59:59<3:18:37, 10.39s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  43%|████▎     | 855/2001 [3:00:09<3:17:56, 10.36s/it][A[A[A[A[A[A[A






Iteration: 800 Loss: 3.8827962636947633 Accuracy: 0.27081959545612333:  43%|████▎ 

Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  45%|████▌     | 902/2001 [3:08:43<3:23:05, 11.09s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  45%|████▌     | 903/2001 [3:08:54<3:22:43, 11.08s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  45%|████▌     | 904/2001 [3:09:05<3:21:18, 11.01s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  45%|████▌     | 905/2001 [3:09:16<3:21:14, 11.02s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  45%|████▌     | 906/2001 [3:09:27<3:21:25, 11.04s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  45%|████▌     | 907/2001 [3:09:38<3:21:06, 11.03s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  45%|████▌     | 908/2001

Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  48%|████▊     | 956/2001 [3:19:49<3:48:02, 13.09s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  48%|████▊     | 957/2001 [3:20:02<3:46:32, 13.02s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  48%|████▊     | 958/2001 [3:20:17<3:54:05, 13.47s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  48%|████▊     | 959/2001 [3:20:31<3:56:03, 13.59s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  48%|████▊     | 960/2001 [3:20:44<3:53:33, 13.46s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  48%|████▊     | 961/2001 [3:20:57<3:53:23, 13.47s/it][A[A[A[A[A[A[A






Iteration: 900 Loss: 3.769312787055969 Accuracy: 0.2774701401591301:  48%|████▊     | 962/2001

Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  50%|█████     | 1009/2001 [3:30:20<2:57:49, 10.76s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  50%|█████     | 1010/2001 [3:30:30<2:56:40, 10.70s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  51%|█████     | 1011/2001 [3:30:41<2:57:13, 10.74s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  51%|█████     | 1012/2001 [3:30:52<2:58:31, 10.83s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  51%|█████     | 1013/2001 [3:31:03<2:58:56, 10.87s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  51%|█████     | 1014/2001 [3:31:14<2:58:10, 10.83s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  51%|█████  

Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  53%|█████▎    | 1062/2001 [3:40:21<3:17:34, 12.62s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  53%|█████▎    | 1063/2001 [3:40:32<3:09:54, 12.15s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  53%|█████▎    | 1064/2001 [3:40:43<3:05:34, 11.88s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  53%|█████▎    | 1065/2001 [3:40:55<3:02:50, 11.72s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  53%|█████▎    | 1066/2001 [3:41:06<2:58:38, 11.46s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  53%|█████▎    | 1067/2001 [3:41:17<2:56:38, 11.35s/it][A[A[A[A[A[A[A






Iteration: 1000 Loss: 3.634372043609619 Accuracy: 0.2838315963745117:  53%|█████▎ 

Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  56%|█████▌    | 1114/2001 [3:50:53<4:17:20, 17.41s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  56%|█████▌    | 1115/2001 [3:51:11<4:19:25, 17.57s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  56%|█████▌    | 1116/2001 [3:51:27<4:10:31, 16.98s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  56%|█████▌    | 1117/2001 [3:51:41<3:57:32, 16.12s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  56%|█████▌    | 1118/2001 [3:51:55<3:46:41, 15.40s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  56%|█████▌    | 1119/2001 [3:52:09<3:39:12, 14.91s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  56%|█████▌ 

Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  58%|█████▊    | 1167/2001 [4:02:34<2:48:18, 12.11s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  58%|█████▊    | 1168/2001 [4:02:46<2:49:36, 12.22s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  58%|█████▊    | 1169/2001 [4:02:59<2:51:18, 12.35s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  58%|█████▊    | 1170/2001 [4:03:12<2:54:37, 12.61s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  59%|█████▊    | 1171/2001 [4:03:25<2:55:38, 12.70s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  59%|█████▊    | 1172/2001 [4:03:38<2:58:29, 12.92s/it][A[A[A[A[A[A[A






Iteration: 1100 Loss: 3.777234935760498 Accuracy: 0.2836142271757126:  59%|█████▊ 

Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  61%|██████    | 1219/2001 [4:14:47<3:19:40, 15.32s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  61%|██████    | 1220/2001 [4:15:01<3:17:39, 15.18s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  61%|██████    | 1221/2001 [4:15:16<3:15:34, 15.04s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  61%|██████    | 1222/2001 [4:15:32<3:18:07, 15.26s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  61%|██████    | 1223/2001 [4:15:47<3:18:41, 15.32s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  61%|██████    | 1224/2001 [4:16:03<3:17:47, 15.27s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  61%|██████    | 12

Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  64%|██████▎   | 1272/2001 [4:28:43<3:32:16, 17.47s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  64%|██████▎   | 1273/2001 [4:29:03<3:43:03, 18.38s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  64%|██████▎   | 1274/2001 [4:29:17<3:26:57, 17.08s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  64%|██████▎   | 1275/2001 [4:29:33<3:21:10, 16.63s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  64%|██████▍   | 1276/2001 [4:29:48<3:17:26, 16.34s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  64%|██████▍   | 1277/2001 [4:30:04<3:13:18, 16.02s/it][A[A[A[A[A[A[A






Iteration: 1200 Loss: 3.58487811088562 Accuracy: 0.2978665232658386:  64%|██████▍   | 12

Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  66%|██████▌   | 1324/2001 [4:43:11<3:22:42, 17.97s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  66%|██████▌   | 1325/2001 [4:43:27<3:15:48, 17.38s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  66%|██████▋   | 1326/2001 [4:43:45<3:16:24, 17.46s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  66%|██████▋   | 1327/2001 [4:44:01<3:11:22, 17.04s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  66%|██████▋   | 1328/2001 [4:44:17<3:05:51, 16.57s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  66%|██████▋   | 1329/2001 [4:44:33<3:03:43, 16.40s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  66%|

Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  69%|██████▉   | 1377/2001 [4:57:19<2:54:31, 16.78s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  69%|██████▉   | 1378/2001 [4:57:39<3:04:03, 17.73s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  69%|██████▉   | 1379/2001 [4:57:57<3:03:28, 17.70s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  69%|██████▉   | 1380/2001 [4:58:12<2:54:35, 16.87s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  69%|██████▉   | 1381/2001 [4:58:27<2:50:28, 16.50s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  69%|██████▉   | 1382/2001 [4:58:42<2:44:22, 15.93s/it][A[A[A[A[A[A[A






Iteration: 1300 Loss: 3.655331611633301 Accuracy: 0.29269460439682005:  69%|

Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  71%|███████▏  | 1429/2001 [5:11:05<2:31:50, 15.93s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  71%|███████▏  | 1430/2001 [5:11:21<2:30:43, 15.84s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  72%|███████▏  | 1431/2001 [5:11:41<2:41:42, 17.02s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  72%|███████▏  | 1432/2001 [5:12:01<2:51:25, 18.08s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  72%|███████▏  | 1433/2001 [5:12:17<2:43:20, 17.25s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  72%|███████▏  | 1434/2001 [5:12:31<2:35:43, 16.48s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086

Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  74%|███████▍  | 1481/2001 [5:25:34<2:16:01, 15.70s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  74%|███████▍  | 1482/2001 [5:25:50<2:17:03, 15.84s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  74%|███████▍  | 1483/2001 [5:26:06<2:17:04, 15.88s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  74%|███████▍  | 1484/2001 [5:26:22<2:19:15, 16.16s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  74%|███████▍  | 1485/2001 [5:26:39<2:20:57, 16.39s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086:  74%|███████▍  | 1486/2001 [5:26:56<2:21:05, 16.44s/it][A[A[A[A[A[A[A






Iteration: 1400 Loss: 3.6343838691711428 Accuracy: 0.29061495065689086

Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  77%|███████▋  | 1532/2001 [5:39:46<2:07:21, 16.29s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  77%|███████▋  | 1533/2001 [5:40:04<2:11:14, 16.83s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  77%|███████▋  | 1534/2001 [5:40:26<2:21:47, 18.22s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  77%|███████▋  | 1535/2001 [5:40:45<2:23:16, 18.45s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  77%|███████▋  | 1536/2001 [5:41:02<2:21:24, 18.25s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  77%|███████▋  | 1537/2001 [5:41:19<2:18:16, 17.88s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907

Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  79%|███████▉  | 1584/2001 [5:54:55<1:48:28, 15.61s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  79%|███████▉  | 1585/2001 [5:55:11<1:50:38, 15.96s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  79%|███████▉  | 1586/2001 [5:55:27<1:50:34, 15.99s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  79%|███████▉  | 1587/2001 [5:55:44<1:51:33, 16.17s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  79%|███████▉  | 1588/2001 [5:56:00<1:51:27, 16.19s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907:  79%|███████▉  | 1589/2001 [5:56:17<1:51:32, 16.24s/it][A[A[A[A[A[A[A






Iteration: 1500 Loss: 3.4379063844680786 Accuracy: 0.30604331791400907

Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  82%|████████▏ | 1635/2001 [6:08:40<1:38:07, 16.09s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  82%|████████▏ | 1636/2001 [6:08:57<1:38:55, 16.26s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  82%|████████▏ | 1637/2001 [6:09:12<1:37:02, 16.00s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  82%|████████▏ | 1638/2001 [6:09:27<1:35:23, 15.77s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  82%|████████▏ | 1639/2001 [6:09:42<1:34:02, 15.59s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  82%|████████▏ | 1640/2001 [6:09:58<1:34:03, 15.63s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837

Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  84%|████████▍ | 1687/2001 [6:23:11<1:27:17, 16.68s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  84%|████████▍ | 1688/2001 [6:23:26<1:25:26, 16.38s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  84%|████████▍ | 1689/2001 [6:23:45<1:28:10, 16.96s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  84%|████████▍ | 1690/2001 [6:24:03<1:30:10, 17.40s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  85%|████████▍ | 1691/2001 [6:24:19<1:27:29, 16.93s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837:  85%|████████▍ | 1692/2001 [6:24:36<1:26:56, 16.88s/it][A[A[A[A[A[A[A






Iteration: 1600 Loss: 3.6572481870651243 Accuracy: 0.29360421895980837

Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  87%|████████▋ | 1739/2001 [6:37:58<1:13:34, 16.85s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  87%|████████▋ | 1740/2001 [6:38:16<1:14:22, 17.10s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  87%|████████▋ | 1741/2001 [6:38:33<1:13:40, 17.00s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  87%|████████▋ | 1742/2001 [6:38:49<1:12:50, 16.87s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  87%|████████▋ | 1743/2001 [6:39:06<1:12:37, 16.89s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  87%|████████▋ | 1744/2001 [6:39:25<1:14:06, 17.30s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  87%|███████

Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  90%|████████▉ | 1792/2001 [6:53:31<58:32, 16.81s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  90%|████████▉ | 1793/2001 [6:53:47<58:01, 16.74s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  90%|████████▉ | 1794/2001 [6:54:04<57:20, 16.62s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  90%|████████▉ | 1795/2001 [6:54:27<1:03:57, 18.63s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  90%|████████▉ | 1796/2001 [6:54:50<1:07:42, 19.82s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  90%|████████▉ | 1797/2001 [6:55:09<1:06:35, 19.59s/it][A[A[A[A[A[A[A






Iteration: 1700 Loss: 3.582536220550537 Accuracy: 0.2954024612903595:  90%|████████▉ | 1

Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  92%|█████████▏| 1844/2001 [7:07:20<36:58, 14.13s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  92%|█████████▏| 1845/2001 [7:07:35<36:58, 14.22s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  92%|█████████▏| 1846/2001 [7:07:49<36:59, 14.32s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  92%|█████████▏| 1847/2001 [7:08:03<36:36, 14.26s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  92%|█████████▏| 1848/2001 [7:08:19<37:22, 14.65s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  92%|█████████▏| 1849/2001 [7:08:35<37:44, 14.90s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  92%|█████████▏| 

Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  95%|█████████▍| 1897/2001 [7:20:56<26:44, 15.43s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  95%|█████████▍| 1898/2001 [7:21:12<26:36, 15.50s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  95%|█████████▍| 1899/2001 [7:21:27<26:00, 15.30s/it][A[A[A[A[A[A[A






Iteration: 1800 Loss: 3.7218676567077638 Accuracy: 0.3038091391324997:  95%|█████████▍| 1900/2001 [7:21:40<24:55, 14.81s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  95%|█████████▍| 1900/2001 [7:21:55<24:55, 14.81s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  95%|█████████▌| 1901/2001 [7:21:55<24:33, 14.74s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  95%|█████████▌| 

Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  97%|█████████▋| 1949/2001 [7:33:53<12:31, 14.46s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  97%|█████████▋| 1950/2001 [7:34:07<12:14, 14.40s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  98%|█████████▊| 1951/2001 [7:34:22<12:05, 14.50s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  98%|█████████▊| 1952/2001 [7:34:37<11:54, 14.59s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  98%|█████████▊| 1953/2001 [7:34:51<11:32, 14.43s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  98%|█████████▊| 1954/2001 [7:35:06<11:23, 14.54s/it][A[A[A[A[A[A[A






Iteration: 1900 Loss: 3.5140673875808717 Accuracy: 0.3024378031492233:  98%|█████████▊| 

Iteration: 2000 Loss: 3.7551266670227053 Accuracy: 0.30180462896823884: 100%|██████████| 2001/2001 [7:46:48<00:00, 14.00s/it][A[A[A[A[A[A[A


# Using the Summarization model

Now that you have trained a Transformer to perform Summarization, we will use the model on news articles from the wild.

The three subsections below explore what the model has learned.

## The validation loss

Measure the validation loss of your model. This part could be used, as in our previous notebook, in deciding what is a likely, vs. unlikely summary for an article.

We will use the code here with the unreleased test-set to evaluate your model.

In [101]:
gc.collect()
model_id = "test1"
save_dict = th.load(root_folder+'models/part2/'+f"model_{model_id}.pt", map_location='cpu')
model = Transformer(**save_dict['kwargs'])
model.load_state_dict(save_dict['model_state_dict'])
set_device('cpu')
model.eval()
trainer.model = model

In [102]:
gc.collect()
losses = []
for i in range(100):
    batch = build_batch(d_valid, 1)
    # Build the feed-dict connecting placeholders and mini-batch
    batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
    batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
            'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}
    valid_loss, accuracy = trainer(batch,optimize=False)
    losses.append(float(valid_loss.cpu().item()))
print("Validation loss:", np.mean(losses))

Validation loss: 5.117078212499618


In [103]:
# Your best performing model should go here.
os.makedirs(root_folder+"best_models",exist_ok=True)
best_model_file = root_folder+"best_models/part2_best_model.pt"
th.save(save_dict,best_model_file)

## Generating an article's summary

This model we have built is meant to be used to generate summaries for new articles we do not have summaries for.
We got a [news article](https://www.chicagotribune.com/news/local/breaking/ct-met-officer-shot-20190309-story.html) from the Chicago Tribune about a police shooting, and want to use our model to produce a summary.

As you will see, our model is still limited in its ability, and will most likely not produce an interpretible summary, however, with more data and training, this model would be able to produce good summaries.

In [61]:
article_text = "A 34-year-old Chicago police officer has been shot in the shoulder during the execution of a search warrant in the Humboldt Park neighborhood, police say. The alleged shooter, a 19-year-old woman, was in custody. The shooting happened about 7:20 p.m. in the 2700 block of West Potomac Avenue, police said. The officer, part of the Grand Central District tactical unit, was taken to Stroger Hospital. While officers were serving a \"typical\" search warrant for \"narcotics and illegal weapons\" and were attempting to reach a rear door, \"a shot was fired,\" striking the tactical officer in the shoulder, said Chicago police Superintendent Eddie Johnson during a news briefing outside the hospital. He said the officer, who has about four or five years on the job, was \"stable\" but in critical condition. \"His family is here,\" Johnson said. \"He’s talking a lot and just wants the ordeal to be over.\" He said this incident serves as just another reminder of how dangerous a police officer’s job is. At the scene of the shooting, crime tape closed Potomac from Washtenaw Avenue to California Avenue and encompassed the alley west of the brick apartment building, south of Potomac. Dozens of officers stood in the alley, while even more walked up and down the street. Neighbors gathered at the edge of the yellow tape on the sidewalk along California and watched them work. Standing next to a man, a woman talked to police in the crime scene, across the street. \"We're not under arrest? We can go?\" the woman checked with officers. They told her she could go, and she and the man walked underneath the yellow tape and out of the crime scene."
input_length = 400
output_length = 100

# Process the capitalization with the preprocess_capitalization of the capita package.
article_text = capita.preprocess_capitalization(article_text)

# Numerize the tokens of the processed text using the loaded sentencepiece model.
numerized = sp.EncodeAsIds(article_text)
# Pad the sequence and keep the mask of the input
padded, mask = pad_sequence(numerized, pad_index, input_length)

# Making the news article into a batch of size one, to be fed to the neural network.
encoder_input = np.array([padded])
encoder_mask = np.array([mask])

decoded_so_far = [0]

for j in range(output_length):
    padded_decoder_input, decoder_mask = pad_sequence(decoded_so_far, pad_index, output_length)
    padded_decoder_input = [padded_decoder_input]
    decoder_mask = [decoder_mask]
    print("========================")
    print(padded_decoder_input)
    # Use the model to find the distrbution over the vocabulary for the next word
    batch = (encoder_input,encoder_mask,padded_decoder_input,decoder_mask)
    batch_input, batch_input_mask, batch_output, batch_output_mask = [th.tensor(tensor) for tensor in batch]
    batch = {'source_sequence': batch_input, 'target_sequence': batch_output,
            'encoder_mask': batch_input_mask, 'decoder_mask': batch_output_mask}
    logits = trainer.model(**batch).cpu().detach().numpy()

    chosen_words = np.argmax(logits, axis=2) # Take the argmax, getting the most likely next word
    decoded_so_far.append(int(chosen_words[0, j])) # We add it to the summary so far


print("The final summary:")
print("".join([vocab[i] for i in decoded_so_far]).replace("▁", " "))

[[0, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 99

[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998

[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998

[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3,

[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 107, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 

[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 107, 4, 7, 3, 19, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 107, 107, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 107, 4, 7, 3, 19, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 107, 107, 32, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998, 9998]]
[[0, 3, 3, 4, 15, 107, 10, 32, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 107, 4, 7, 3, 19, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 107, 107, 32, 32, 9998,

## Word vectors

The model we train learns word representations for each word in our vocabulary. A word represention is a vector of **dim** size.

It is common in NLP to inspect the word vectors, as some properties of language often appear in the embedding structure.


We are going to load the word embeddings learned by our model, and inspect it.
Because our network was not trained for long, we are going for the simplest patterns, but if we let the network train longer, it learns more complex, semantic patterns.

In [62]:
# We help you load the matrix, as it is hidden within the Transformer structure.
E = trainer.model.encoder.embedding_layer.embedding.weight.cpu().detach().numpy()

print("The embedding matrix has shape:", E.shape)
print("The vocabulary has length:", len(vocab))

The embedding matrix has shape: (10000, 160)
The vocabulary has length: 10000


Pronouns serve very similar purposes, therefore we should expect the representation of "he" and "she" to be similar, and have cosine similarity.

- **TODO**:  Find the cosine similarity between the vectors that represent words "she" and "he".
- **TODO**:  Find the cosine similarity between the vectors that represent words "more" and "less".

We can contrast that with the cosine similarity to a random, non-related word, like "ball", or "gorilla".
- **TODO**: Compute the cosine similarity between "she" and "ball".
- **TODO**: Compute the cosine similarity between "more" and "protest".



In [69]:
def cosine_sim(v1, v2):
    # TODO: Implement the cosine similarity of 2 vectors. Careful: the words might not have unit norm.
    return np.dot(v1, v2)/np.linalg.norm(v1)/np.linalg.norm(v2)

for w1, w2 in [("she", "he"), ("more", "less"), ("she", "ball"), ("more", "gorilla")]:
    w1_index = vocab.index('▁'+w1) # The index of the first  word in our vocabulary
    w2_index = vocab.index('▁'+w2) # The index of the second word in our vocabulary
    w1_vec = E[w1_index] # Get the embedding vector of the first  word
    w2_vec = E[w2_index] # Get the embedding vector of the second word
    
    print(w1," vs. ", w2, "similarity:",cosine_sim(w1_vec, w2_vec))
validate_to_array(lambda f,i: (f(*i),i), (cosine_sim,tuple(20*np.random.random((2,1000))-1)),'cosine_sim', root_folder)

she  vs.  he similarity: -0.024962215
more  vs.  less similarity: 0.14132142
she  vs.  ball similarity: 0.048659142
more  vs.  gorilla similarity: 0.06929617


These effects are unfortunately small, as we have only trained the network on a few hours on a few thousand articles.
However, the same model trained for longer on more data exhibits many interesting semantic and syntactic patterns, such as:

- Words vectors with high cosine similarity usually represent words that have semantic similarity (such as duck and pigeon)
- Analogies can occur, a famous case is that of: woman - man + king ≈ queen. Or france - paris + rome ≈ italy.

- Looking at top-k similar words can help find synonyms.

To read examples of more complex patterns that appear in word embedding spaces, read [this blog](https://explosion.ai/blog/sense2vec-with-spacy). To play with a live demo and try similarities on rich word embeddings, [go here.](https://explosion.ai/demos/sense2vec)