In [1]:
# run this to shorten the data import from the files
import os
cwd = os.path.dirname(os.getcwd())+'/'
path_data = os.path.join(os.path.dirname(os.getcwd()), 'datasets/')


In [14]:
from transformers import BertTokenizer, BertForSequenceClassification
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
texts = ['I love this!', 'This is terrible.', 'Amazing experience!', 'Not my cup of tea.']

labels = [1, 0, 1, 0]

In [17]:
# exercise 01

"""
Transfer learning using BERT

At PyBooks, the company has decided to leverage the power of the BERT model, a pre-trained transformer model, for sentiment analysis. BERT has seen remarkable performance across various NLP tasks, making it a prime candidate for this use case.

You're tasked with setting up a basic workflow using the BERT model from the transformers library for binary sentiment classification.

The following has been imported for you: BertTokenizer, BertForSequenceClassification, torch. The example data texts and corresponding labels are also preloaded.
"""

# Instructions

"""

    Load the bert-base-uncased tokenizer and model suitable for binary classification.

    Tokenize your dataset and prepare it for the model, ensuring it returns PyTorch tensors using the return_tensors argument.

    Setup the optimizer using model parameters.

"""

# solution

# Load the BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Tokenize your data and return PyTorch tensors
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=32)
inputs["labels"] = torch.tensor(labels)

# Setup the optimizer using model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
model.train()
for epoch in range(50):
    outputs = model(**inputs)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

#----------------------------------#

# Conclusion

"""
Nice! You've successfully fine-tuned the BERT model for sentiment analysis. This foundation will serve as a robust base for understanding user sentiments in book reviews. The output should show you model loss!
"""

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch: 1, Loss: 0.6816250681877136
Epoch: 2, Loss: 0.6234686374664307
Epoch: 3, Loss: 0.6200855374336243
Epoch: 4, Loss: 0.7368596196174622
Epoch: 5, Loss: 0.5780086517333984
Epoch: 6, Loss: 0.5112018585205078
Epoch: 7, Loss: 0.5353565216064453
Epoch: 8, Loss: 0.48216643929481506
Epoch: 9, Loss: 0.4854140877723694
Epoch: 10, Loss: 0.4816187918186188
Epoch: 11, Loss: 0.41129976511001587
Epoch: 12, Loss: 0.4665626287460327
Epoch: 13, Loss: 0.302004337310791
Epoch: 14, Loss: 0.30847564339637756
Epoch: 15, Loss: 0.3655702471733093
Epoch: 16, Loss: 0.3331689238548279
Epoch: 17, Loss: 0.27620846033096313
Epoch: 18, Loss: 0.33299770951271057
Epoch: 19, Loss: 0.28812164068222046
Epoch: 20, Loss: 0.244573175907135
Epoch: 21, Loss: 0.23351341485977173
Epoch: 22, Loss: 0.20445218682289124
Epoch: 23, Loss: 0.18575873970985413
Epoch: 24, Loss: 0.23773084580898285
Epoch: 25, Loss: 0.23708093166351318
Epoch: 26, Loss: 0.20038361847400665
Epoch: 27, Loss: 0.21345797181129456
Epoch: 28, Loss: 0.1612657

"\nNice! You've successfully fine-tuned the BERT model for sentiment analysis. This foundation will serve as a robust base for understanding user sentiments in book reviews. The output should show you model loss!\n"

In [18]:
# exercise 02

"""
Evaluating the BERT model

Having tokenized the sample reviews using BERT's tokenizer, it's now time to evaluate the BERT model with the samples at PyBooks. Additionally, you will evaluate the model's sentiment prediction on new data.

The following has been imported for you: BertTokenizer, BertForSequenceClassification, torch. The trained model instance is also preloaded. We will now test it on a new data sample.
"""

# Instructions

"""

    Prepare the evaluation text for the model by tokenizing it and returning PyTorch tensors.

    Convert the output logits to probabilities between zero and one.

    Display the sentiments from the probabilities.

"""

# solution

text = "I had an awesome day!"

# Tokenize the text and return PyTorch tensors
input_eval = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=32)
outputs_eval = model(**input_eval)

# Convert the output logits to probabilities
predictions = torch.nn.functional.softmax(outputs_eval.logits, dim=-1)

# Display the sentiments
predicted_label = 'positive' if torch.argmax(predictions) > 0 else 'negative'
print(f"Text: {text}\nSentiment: {predicted_label}")

#----------------------------------#

# Conclusion

"""
Impressive! Using the fine-tuned BERT model, you've accurately predicted sentiments of new texts. The printed sentiment should give you a glimpse of how the model perceives the provided text. With more training data and epochs the prediction can be improved. Remember, at PyBooks, understanding the sentiment of a review can be the key to the next bestseller recommendation!
"""

Text: I had an awesome day!
Sentiment: positive


"\nImpressive! Using the fine-tuned BERT model, you've accurately predicted sentiments of new texts. The printed sentiment should give you a glimpse of how the model perceives the provided text. With more training data and epochs the prediction can be improved. Remember, at PyBooks, understanding the sentiment of a review can be the key to the next bestseller recommendation!\n"

In [19]:
import torch.nn as nn
import torch.optim as optim

In [20]:
train_sentences = ['I love this product', 'This is terrible', 'Could be better']
train_labels = [1, 0, 0]

test_sentences = ['This is the best']
test_labels = [1]

In [21]:
# exercise 03

"""
Creating a transformer model

At PyBooks, the recommendation engine you're working on needs more refined capabilities to understand the sentiments of user reviews. You believe that using transformers, a state-of-the-art architecture, can help achieve this. You decide to build a transformer model that can encode the sentiments in the reviews to kickstart the project.

The following packages have been imported for you: torch, nn, optim.

The input data contains sentences such as : "I love this product", "This is terrible", "Could be better" … and their respective binary sentiment labels such as : 1, 0, 0, ...

The input data is split and converted to embeddings in the following variables: train_sentences, train_labels ,test_sentences,test_labels,token_embeddings
"""

# Instructions

"""

    Initialize the transformer encoder.

    Define the fully connected layer based on the number of sentiment classes.

    In the forward method, pass the input through the transformer encoder followed by the linear layer.

"""

# solution

class TransformerEncoder(nn.Module):
    def __init__(self, embed_size, heads, num_layers, dropout):
        super(TransformerEncoder, self).__init__()
        # Initialize the encoder 
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_size, nhead=heads),
            num_layers=num_layers)
        # Define the fully connected layer
        self.fc = nn.Linear(embed_size, 2)

    def forward(self, x):
        # Pass the input through the transformer encoder 
        x = self.encoder(x)
        x = x.mean(dim=1) 
        return self.fc(x)

model = TransformerEncoder(embed_size=512, heads=8, num_layers=3, dropout=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

#----------------------------------#

# Conclusion

"""
Fantastic! You've successfully created a Transformer model for sentiment analysis. With this architecture, you can encode and understand the nuances of reviews more effectively. Let's move on to training this model.
"""



"\nFantastic! You've successfully created a Transformer model for sentiment analysis. With this architecture, you can encode and understand the nuances of reviews more effectively. Let's move on to training this model.\n"

In [24]:
from torch import tensor

token_embeddings = {'I': tensor([[0.0853, 0.2057, 0.8046, 0.2828, 0.8155, 0.2743, 0.9170, 0.7140, 0.0304,
          0.7213, 0.5396, 0.9110, 0.7102, 0.6019, 0.1533, 0.1460, 0.5937, 0.0669,
          0.3502, 0.7979, 0.2620, 0.1561, 0.9213, 0.0265, 0.4499, 0.8657, 0.3645,
          0.0388, 0.2643, 0.6183, 0.6210, 0.5326, 0.5985, 0.5167, 0.6164, 0.7456,
          0.4870, 0.6180, 0.4759, 0.2585, 0.2806, 0.5770, 0.1741, 0.7249, 0.6110,
          0.8456, 0.2813, 0.0193, 0.3400, 0.1764, 0.8795, 0.9086, 0.0695, 0.7406,
          0.3931, 0.2374, 0.5236, 0.5884, 0.4193, 0.8450, 0.4999, 0.5154, 0.2725,
          0.0067, 0.8871, 0.1904, 0.5606, 0.6535, 0.0731, 0.4204, 0.5488, 0.1868,
          0.1094, 0.9935, 0.1202, 0.8779, 0.3257, 0.4929, 0.1829, 0.8570, 0.1582,
          0.6915, 0.3119, 0.3163, 0.3518, 0.3564, 0.0964, 0.2803, 0.7474, 0.9753,
          0.2208, 0.5510, 0.5935, 0.1838, 0.4768, 0.6959, 0.8245, 0.3615, 0.3842,
          0.6337, 0.2908, 0.1929, 0.0399, 0.5911, 0.8700, 0.1805, 0.9715, 0.7156,
          0.8599, 0.4019, 0.7425, 0.3918, 0.6254, 0.0888, 0.3349, 0.9160, 0.3359,
          0.5465, 0.4482, 0.0280, 0.9239, 0.5493, 0.2199, 0.4319, 0.1360, 0.9437,
          0.8985, 0.0775, 0.9257, 0.7552, 0.7653, 0.3163, 0.7731, 0.1595, 0.6066,
          0.4273, 0.8949, 0.1075, 0.6959, 0.0129, 0.0231, 0.1842, 0.3656, 0.3986,
          0.4604, 0.2785, 0.1204, 0.6831, 0.2701, 0.2599, 0.0534, 0.9153, 0.4211,
          0.8364, 0.2461, 0.2060, 0.2067, 0.6011, 0.6579, 0.3553, 0.7440, 0.3434,
          0.5602, 0.7849, 0.2194, 0.0026, 0.4429, 0.6730, 0.4594, 0.5548, 0.2347,
          0.0676, 0.3379, 0.9410, 0.3491, 0.3693, 0.8339, 0.2343, 0.1048, 0.2037,
          0.2462, 0.8216, 0.6360, 0.4583, 0.1292, 0.9099, 0.9494, 0.9149, 0.3931,
          0.3680, 0.2182, 0.5708, 0.2215, 0.2633, 0.0082, 0.3339, 0.2517, 0.7107,
          0.9364, 0.2077, 0.4278, 0.9798, 0.4041, 0.1532, 0.8598, 0.9309, 0.2863,
          0.1688, 0.9203, 0.4584, 0.7829, 0.8375, 0.9348, 0.7666, 0.0893, 0.1892,
          0.6989, 0.3454, 0.8447, 0.6798, 0.4781, 0.7268, 0.3875, 0.3177, 0.9693,
          0.8379, 0.5318, 0.9973, 0.6368, 0.4980, 0.4090, 0.1359, 0.5196, 0.2466,
          0.8668, 0.1822, 0.8732, 0.6006, 0.2396, 0.0730, 0.1854, 0.2982, 0.2144,
          0.7160, 0.8179, 0.1176, 0.1740, 0.2782, 0.2991, 0.4013, 0.0804, 0.7588,
          0.9926, 0.3408, 0.6186, 0.2656, 0.9376, 0.7353, 0.3930, 0.1065, 0.3465,
          0.9617, 0.4574, 0.0658, 0.4763, 0.4394, 0.1855, 0.3879, 0.4095, 0.9617,
          0.9795, 0.0730, 0.9656, 0.9170, 0.7049, 0.4574, 0.6070, 0.7356, 0.3238,
          0.3488, 0.8693, 0.8333, 0.1259, 0.2067, 0.9784, 0.4551, 0.7707, 0.4466,
          0.4970, 0.5477, 0.9288, 0.9687, 0.9467, 0.4073, 0.3795, 0.9182, 0.0792,
          0.0787, 0.4678, 0.1418, 0.2561, 0.5858, 0.5653, 0.4478, 0.7510, 0.2139,
          0.9986, 0.9244, 0.1748, 0.0611, 0.4510, 0.1100, 0.3203, 0.9271, 0.9986,
          0.4125, 0.6003, 0.7226, 0.0762, 0.2237, 0.8114, 0.6187, 0.2863, 0.0220,
          0.9183, 0.0143, 0.2284, 0.9345, 0.7852, 0.7250, 0.7945, 0.9335, 0.1584,
          0.8269, 0.9756, 0.4916, 0.3518, 0.8724, 0.4157, 0.4282, 0.4466, 0.7231,
          0.9244, 0.8999, 0.1611, 0.4001, 0.7450, 0.9809, 0.6599, 0.8516, 0.0586,
          0.3137, 0.8474, 0.5301, 0.4206, 0.3945, 0.3235, 0.1818, 0.3534, 0.0977,
          0.4626, 0.8904, 0.1929, 0.0905, 0.1165, 0.0856, 0.5241, 0.1211, 0.8545,
          0.8897, 0.1438, 0.4973, 0.9383, 0.0634, 0.6774, 0.1269, 0.1469, 0.1955,
          0.6177, 0.8114, 0.9090, 0.6191, 0.2583, 0.0637, 0.9804, 0.3672, 0.8342,
          0.1533, 0.9820, 0.5172, 0.5383, 0.8169, 0.0132, 0.1488, 0.9539, 0.7597,
          0.2730, 0.7923, 0.5642, 0.5999, 0.7241, 0.0681, 0.5954, 0.2816, 0.0987,
          0.2909, 0.4026, 0.5229, 0.9986, 0.5985, 0.1993, 0.2150, 0.3636, 0.2250,
          0.5858, 0.7490, 0.7828, 0.9677, 0.0634, 0.4220, 0.4487, 0.7586, 0.1433,
          0.8909, 0.2914, 0.4938, 0.2729, 0.5903, 0.0867, 0.8721, 0.2210, 0.4641,
          0.8509, 0.6425, 0.5586, 0.4561, 0.5634, 0.0205, 0.9594, 0.2871, 0.1343,
          0.5257, 0.2102, 0.4996, 0.7755, 0.7724, 0.7995, 0.9856, 0.6525, 0.5816,
          0.8549, 0.5278, 0.0957, 0.5074, 0.7417, 0.8509, 0.7082, 0.2497, 0.4659,
          0.2858, 0.0811, 0.8870, 0.5667, 0.1590, 0.3404, 0.4123, 0.1942, 0.2921,
          0.7993, 0.6292, 0.5678, 0.2713, 0.4189, 0.1761, 0.7051, 0.4123, 0.4482,
          0.4751, 0.1309, 0.1716, 0.2140, 0.4406, 0.0119, 0.9878, 0.5040, 0.0778,
          0.9401, 0.3427, 0.8406, 0.3405, 0.7980, 0.1054, 0.7023, 0.7700, 0.8412,
          0.0188, 0.6450, 0.0719, 0.8590, 0.3814, 0.6799, 0.8735, 0.5571, 0.8519,
          0.6969, 0.3577, 0.5458, 0.6296, 0.9949, 0.8963, 0.0073, 0.3958]]),
 'love': tensor([[3.6284e-01, 2.3154e-02, 2.6337e-01, 1.7345e-01, 1.6216e-01, 5.2091e-01,
          2.3351e-02, 9.9807e-01, 2.0579e-01, 1.4821e-01, 4.0756e-01, 6.4549e-01,
          1.9051e-01, 9.0507e-01, 5.7534e-01, 3.3057e-01, 6.1073e-01, 4.4476e-01,
          9.5151e-01, 7.7570e-01, 7.0230e-01, 4.0054e-01, 4.8791e-01, 6.5228e-01,
          5.4129e-02, 8.9736e-01, 1.3648e-03, 5.4171e-01, 1.9478e-01, 2.3144e-01,
          4.6105e-01, 3.5870e-01, 4.4645e-01, 3.8289e-01, 9.4930e-01, 7.9667e-01,
          2.7401e-01, 4.0041e-01, 3.3023e-02, 8.2151e-01, 2.4024e-01, 3.3517e-01,
          7.8629e-01, 7.1853e-02, 5.6120e-01, 8.4390e-01, 3.9144e-01, 7.8321e-01,
          6.7864e-01, 1.5575e-01, 9.0859e-01, 4.0503e-01, 2.0997e-01, 9.8058e-01,
          4.0286e-01, 8.2296e-01, 8.4852e-01, 9.2228e-01, 5.0724e-01, 2.8658e-01,
          5.4268e-01, 7.9620e-01, 2.6817e-01, 4.3123e-01, 1.1171e-02, 5.8252e-01,
          7.6674e-01, 4.9314e-02, 9.1531e-01, 9.9628e-01, 7.8890e-01, 7.1539e-01,
          6.4847e-01, 3.8875e-01, 9.8168e-01, 8.6509e-01, 1.7103e-01, 7.4730e-01,
          4.9865e-01, 5.2271e-01, 8.9986e-01, 7.8673e-01, 6.6448e-01, 5.4711e-01,
          5.9501e-01, 3.5695e-01, 1.7394e-01, 6.0894e-01, 5.0543e-01, 4.0172e-01,
          3.2366e-01, 8.8626e-01, 3.3826e-01, 2.4438e-06, 4.4866e-01, 8.9882e-01,
          2.5831e-02, 3.7999e-01, 3.7623e-01, 6.7137e-01, 2.1375e-01, 3.2104e-01,
          8.5708e-01, 3.9515e-01, 8.2327e-01, 2.2448e-01, 1.5764e-01, 8.5711e-01,
          5.5816e-01, 6.5027e-01, 4.2742e-01, 2.0926e-01, 7.7718e-01, 5.7950e-01,
          7.3660e-01, 8.6042e-01, 4.3815e-01, 2.7696e-01, 8.6967e-01, 4.3190e-01,
          6.1290e-01, 4.6694e-01, 9.7259e-01, 7.7465e-01, 6.0844e-01, 5.7683e-01,
          9.6053e-01, 7.1207e-01, 8.8106e-01, 1.3985e-01, 4.8683e-01, 8.8478e-01,
          9.2944e-01, 2.7325e-01, 5.8759e-01, 3.1168e-01, 1.9031e-01, 8.4362e-01,
          3.2074e-01, 2.5114e-01, 3.7986e-01, 5.9402e-01, 7.7321e-02, 4.5300e-01,
          9.1618e-01, 1.3463e-01, 4.7108e-01, 9.0283e-01, 7.6777e-01, 4.9431e-01,
          3.8720e-01, 9.8519e-01, 6.3109e-01, 3.7670e-01, 4.8880e-01, 1.8473e-01,
          6.6146e-01, 7.3275e-01, 6.2889e-01, 1.2067e-01, 7.4720e-01, 2.6012e-01,
          7.7149e-02, 5.4886e-01, 5.6048e-01, 1.0130e-01, 4.9167e-01, 5.5197e-01,
          9.3174e-01, 2.6070e-01, 8.1934e-01, 7.1675e-01, 9.4886e-01, 7.4611e-01,
          6.4336e-01, 6.0371e-01, 6.5223e-01, 3.8408e-01, 2.6833e-01, 3.5092e-01,
          2.1589e-01, 5.1512e-01, 7.4356e-02, 3.5561e-01, 5.2974e-01, 9.5343e-01,
          1.4256e-01, 9.3722e-01, 2.2913e-01, 8.1424e-01, 1.4074e-02, 6.2313e-01,
          7.0515e-01, 8.4868e-01, 7.0104e-01, 6.9378e-01, 7.3194e-01, 5.7701e-01,
          8.2713e-01, 8.4704e-01, 8.6033e-02, 9.9418e-01, 8.0317e-01, 2.2877e-01,
          3.8289e-01, 1.3534e-01, 4.2301e-01, 6.0456e-01, 3.2213e-01, 5.3385e-01,
          7.7966e-01, 1.7658e-01, 8.0513e-01, 6.2059e-01, 3.0893e-02, 4.7291e-01,
          7.9430e-01, 8.5669e-01, 9.1526e-01, 4.5298e-02, 6.2374e-01, 8.6623e-01,
          1.2900e-01, 3.7208e-01, 5.6696e-01, 1.6152e-01, 7.7639e-01, 2.4639e-01,
          4.2533e-01, 7.1675e-01, 7.6651e-01, 5.8663e-01, 8.1399e-01, 7.7266e-01,
          7.4340e-01, 7.3200e-01, 2.9563e-02, 2.0130e-01, 8.0724e-01, 5.3229e-01,
          3.9370e-01, 9.3142e-01, 5.3530e-01, 6.7946e-01, 4.2284e-01, 2.7244e-01,
          8.9897e-01, 7.2305e-01, 3.7785e-02, 5.3857e-01, 2.3721e-01, 4.6204e-01,
          3.3743e-01, 4.9571e-01, 7.0795e-01, 7.7339e-01, 5.7580e-01, 9.0633e-01,
          3.1101e-01, 4.8404e-01, 7.2685e-01, 1.0043e-01, 5.7684e-01, 1.0986e-01,
          1.6627e-01, 2.7824e-01, 9.2098e-01, 2.5373e-02, 9.0256e-02, 7.1284e-01,
          9.7597e-01, 8.2478e-02, 3.0384e-01, 6.9267e-01, 8.2306e-01, 7.9191e-01,
          9.9331e-01, 3.0980e-01, 7.2045e-01, 8.2027e-01, 6.4300e-01, 7.2316e-01,
          7.1902e-01, 9.9576e-01, 5.5377e-01, 2.2238e-01, 7.0674e-01, 8.3169e-01,
          1.8522e-01, 3.3046e-01, 4.4248e-01, 7.8935e-01, 2.6343e-01, 3.7640e-01,
          6.5580e-02, 4.3108e-01, 7.3951e-01, 9.6147e-01, 9.5936e-01, 2.5585e-01,
          1.2975e-01, 2.8679e-01, 7.7311e-01, 9.6823e-01, 6.7665e-02, 5.9448e-01,
          1.7762e-02, 1.3764e-01, 3.8910e-01, 5.8797e-01, 8.5799e-01, 8.1343e-01,
          5.8921e-01, 5.5113e-01, 9.6108e-01, 5.9016e-01, 2.4221e-01, 4.0691e-01,
          9.5544e-01, 1.5694e-01, 3.2366e-01, 3.5743e-01, 1.6135e-01, 3.8674e-01,
          8.2718e-01, 9.7463e-01, 2.8207e-01, 1.3329e-01, 4.1995e-01, 8.4739e-01,
          6.7025e-01, 6.3624e-01, 4.2583e-03, 1.2781e-01, 9.0348e-01, 6.3887e-01,
          6.6407e-01, 5.6605e-01, 6.1960e-01, 2.4126e-01, 1.5445e-01, 5.1137e-01,
          5.4997e-01, 8.5443e-01, 9.7447e-01, 3.0329e-01, 4.4029e-01, 1.0494e-01,
          4.3037e-01, 4.1764e-01, 1.5828e-01, 1.1458e-01, 9.9227e-01, 3.2794e-01,
          8.7223e-01, 9.8800e-01, 8.0386e-03, 4.8773e-01, 9.9816e-01, 8.8338e-01,
          8.2014e-01, 1.5026e-01, 9.3789e-01, 5.7878e-01, 1.0635e-02, 5.2203e-02,
          9.2780e-01, 8.6282e-01, 4.2968e-01, 7.4025e-01, 4.3499e-01, 1.1448e-01,
          5.6968e-01, 5.3053e-01, 7.7654e-01, 3.2689e-01, 8.5167e-01, 7.3258e-01,
          4.9866e-02, 7.9820e-01, 2.3463e-01, 5.4451e-01, 7.5304e-01, 4.7792e-03,
          1.6349e-01, 3.5259e-01, 3.0486e-01, 2.0570e-01, 6.4345e-01, 3.7716e-01,
          7.9528e-01, 8.9692e-01, 9.6268e-01, 2.0443e-01, 9.1013e-01, 2.0256e-01,
          2.5920e-02, 5.8424e-01, 3.0894e-01, 7.2405e-02, 8.6381e-01, 3.9566e-01,
          6.1221e-01, 6.0382e-01, 5.7222e-01, 2.7709e-01, 6.2846e-01, 4.4462e-01,
          7.7679e-02, 9.0988e-01, 9.9948e-01, 9.0802e-01, 5.7570e-01, 6.5708e-01,
          6.0688e-01, 1.7978e-01, 4.3260e-01, 8.6696e-01, 9.8685e-01, 4.4908e-01,
          5.9470e-01, 5.0440e-01, 1.0474e-01, 6.7835e-02, 7.0626e-01, 7.7168e-01,
          7.8853e-02, 4.9835e-01, 7.0396e-01, 8.0712e-01, 6.7114e-01, 6.7687e-01,
          9.5989e-01, 5.3756e-01, 6.5295e-01, 8.0306e-01, 2.5460e-01, 2.5019e-01,
          3.5158e-01, 5.1287e-01, 5.6016e-01, 6.1390e-01, 2.6805e-01, 8.6651e-02,
          2.0065e-01, 7.5220e-02, 2.7338e-01, 7.2639e-01, 9.9720e-01, 6.6406e-01,
          5.7528e-01, 2.2208e-01, 5.3138e-01, 8.1814e-01, 7.8272e-01, 9.0693e-01,
          5.8757e-01, 6.8250e-01, 5.5343e-01, 2.7522e-01, 8.1205e-01, 4.4467e-01,
          2.8675e-01, 7.5438e-01, 9.9899e-01, 2.1754e-01, 5.2981e-01, 2.3692e-02,
          8.9358e-01, 5.0169e-01, 3.1150e-01, 6.9560e-01, 9.9468e-01, 2.6121e-01,
          1.3324e-01, 5.4232e-01, 8.3067e-01, 9.1338e-01, 3.6671e-01, 1.4831e-01,
          5.2382e-01, 3.3574e-01, 9.7735e-01, 6.0138e-01, 6.0314e-01, 9.4558e-01,
          5.1844e-01, 7.7903e-01, 7.5838e-01, 2.9215e-01, 8.5636e-01, 7.1019e-01,
          8.9376e-01, 2.7421e-01, 3.9056e-01, 5.4653e-02, 6.1586e-01, 8.3270e-02,
          7.4035e-01, 1.3294e-01, 2.5873e-01, 5.3449e-01, 8.2023e-01, 4.3984e-01,
          7.5942e-01, 1.8341e-01, 6.9399e-01, 8.1754e-01, 1.9899e-01, 6.1296e-01,
          2.8700e-01, 1.2948e-01]]),
 'this': tensor([[0.1393, 0.5083, 0.9096, 0.7020, 0.0354, 0.5287, 0.3193, 0.8829, 0.1535,
          0.4220, 0.7973, 0.6898, 0.8220, 0.5788, 0.9604, 0.9439, 0.6119, 0.1887,
          0.6569, 0.9562, 0.3627, 0.1798, 0.6822, 0.6888, 0.4388, 0.0085, 0.1811,
          0.6615, 0.6178, 0.4544, 0.1008, 0.2333, 0.5066, 0.4606, 0.5809, 0.0437,
          0.0855, 0.7555, 0.3129, 0.2400, 0.0669, 0.1944, 0.5573, 0.1226, 0.7840,
          0.7860, 0.9048, 0.5675, 0.3874, 0.6583, 0.7650, 0.6027, 0.9554, 0.6443,
          0.6914, 0.4054, 0.4395, 0.8202, 0.8819, 0.8168, 0.2606, 0.7196, 0.3762,
          0.3257, 0.1802, 0.2191, 0.3427, 0.8890, 0.5437, 0.9263, 0.0313, 0.2668,
          0.6278, 0.1316, 0.8755, 0.5034, 0.4492, 0.1692, 0.0311, 0.9956, 0.2446,
          0.1939, 0.1712, 0.5323, 0.2820, 0.3809, 0.2541, 0.0635, 0.2202, 0.7344,
          0.6958, 0.3598, 0.8559, 0.4804, 0.5427, 0.2962, 0.7423, 0.1028, 0.6483,
          0.1090, 0.7447, 0.2122, 0.0398, 0.8054, 0.5337, 0.8548, 0.3935, 0.6289,
          0.5426, 0.7138, 0.3222, 0.9033, 0.9894, 0.4406, 0.3721, 0.2838, 0.2443,
          0.6788, 0.3227, 0.5614, 0.4274, 0.3119, 0.2023, 0.8644, 0.3707, 0.3047,
          0.6574, 0.2158, 0.9665, 0.6811, 0.0545, 0.3035, 0.3338, 0.7826, 0.1711,
          0.9079, 0.0211, 0.7322, 0.2176, 0.1750, 0.2028, 0.1425, 0.7105, 0.6901,
          0.6773, 0.4346, 0.3170, 0.2307, 0.8915, 0.4978, 0.8098, 0.0643, 0.4261,
          0.0734, 0.9664, 0.0531, 0.8678, 0.7669, 0.1593, 0.9740, 0.7501, 0.5179,
          0.0329, 0.4273, 0.1019, 0.8354, 0.5936, 0.9119, 0.0149, 0.9976, 0.5194,
          0.2072, 0.7592, 0.5627, 0.0387, 0.3589, 0.1352, 0.9743, 0.7691, 0.9912,
          0.5978, 0.5725, 0.2246, 0.2046, 0.9984, 0.2862, 0.5872, 0.2440, 0.9229,
          0.9124, 0.0031, 0.9625, 0.2870, 0.1041, 0.7540, 0.0756, 0.7488, 0.3075,
          0.1898, 0.2608, 0.1012, 0.5675, 0.7105, 0.5505, 0.3155, 0.5087, 0.0281,
          0.0151, 0.6625, 0.7362, 0.7404, 0.9664, 0.9865, 0.9292, 0.1601, 0.9355,
          0.1864, 0.5761, 0.0443, 0.6035, 0.2049, 0.3843, 0.2607, 0.6160, 0.7977,
          0.6899, 0.3098, 0.5478, 0.4929, 0.0071, 0.5762, 0.8721, 0.7873, 0.3058,
          0.5917, 0.4645, 0.2035, 0.7144, 0.1842, 0.8004, 0.9272, 0.3038, 0.8906,
          0.9392, 0.3335, 0.9405, 0.9774, 0.4383, 0.2684, 0.7451, 0.5587, 0.5861,
          0.9839, 0.8648, 0.9721, 0.6162, 0.4146, 0.9226, 0.8777, 0.3975, 0.4419,
          0.2843, 0.3018, 0.8192, 0.4630, 0.2595, 0.6915, 0.9237, 0.6456, 0.0054,
          0.1707, 0.7880, 0.7081, 0.9398, 0.8121, 0.8107, 0.3089, 0.6584, 0.8293,
          0.1473, 0.5205, 0.5811, 0.2414, 0.2388, 0.6536, 0.7533, 0.0231, 0.8642,
          0.4318, 0.5623, 0.8786, 0.1093, 0.4999, 0.6357, 0.2365, 0.9587, 0.8276,
          0.6893, 0.5776, 0.6956, 0.9345, 0.9140, 0.4000, 0.9320, 0.8559, 0.7022,
          0.4949, 0.9171, 0.8545, 0.6552, 0.3620, 0.4638, 0.8222, 0.5946, 0.4362,
          0.4735, 0.1451, 0.4655, 0.5864, 0.6453, 0.3832, 0.0545, 0.7205, 0.0011,
          0.2828, 0.1399, 0.3865, 0.4624, 0.4802, 0.0899, 0.9321, 0.8230, 0.0668,
          0.7346, 0.2352, 0.9249, 0.9775, 0.7391, 0.9066, 0.3927, 0.9911, 0.8212,
          0.0473, 0.2287, 0.8394, 0.8353, 0.0683, 0.0113, 0.4451, 0.4699, 0.1159,
          0.5725, 0.0673, 0.0931, 0.9720, 0.6494, 0.3146, 0.6183, 0.3334, 0.6214,
          0.0329, 0.6258, 0.7659, 0.8056, 0.6006, 0.3563, 0.2182, 0.5596, 0.7408,
          0.0665, 0.3185, 0.9149, 0.9976, 0.4451, 0.4730, 0.9571, 0.9895, 0.9825,
          0.3327, 0.3570, 0.1998, 0.2118, 0.9843, 0.9497, 0.8830, 0.5998, 0.4161,
          0.3853, 0.8136, 0.8864, 0.4736, 0.7545, 0.7768, 0.1187, 0.2964, 0.4637,
          0.6896, 0.3402, 0.3714, 0.6630, 0.4301, 0.6140, 0.3409, 0.3984, 0.3314,
          0.0482, 0.8662, 0.0180, 0.6702, 0.2219, 0.4566, 0.0776, 0.5543, 0.6058,
          0.0350, 0.6419, 0.2300, 0.7704, 0.4175, 0.1876, 0.7092, 0.8669, 0.4844,
          0.3398, 0.5607, 0.0080, 0.0612, 0.4939, 0.1239, 0.8578, 0.6937, 0.2772,
          0.0502, 0.3850, 0.9689, 0.9623, 0.4769, 0.9136, 0.1101, 0.0996, 0.5035,
          0.3428, 0.0633, 0.3397, 0.1233, 0.5144, 0.5925, 0.7865, 0.2696, 0.8296,
          0.1950, 0.8539, 0.3270, 0.0542, 0.9988, 0.9708, 0.5422, 0.7643, 0.9287,
          0.4905, 0.4120, 0.2286, 0.0802, 0.4261, 0.4105, 0.9346, 0.9181, 0.9785,
          0.9542, 0.5270, 0.5086, 0.0056, 0.8537, 0.9590, 0.7317, 0.7088, 0.2054,
          0.8393, 0.2235, 0.8355, 0.9063, 0.8746, 0.8053, 0.3331, 0.8734, 0.2112,
          0.3340, 0.9062, 0.2177, 0.5298, 0.0535, 0.5095, 0.4858, 0.5604, 0.4525,
          0.0823, 0.1423, 0.5584, 0.7483, 0.6326, 0.9568, 0.1898, 0.3503, 0.7035,
          0.8515, 0.7419, 0.6358, 0.6740, 0.8400, 0.3420, 0.6590, 0.9820]]),
 'product': tensor([[0.9927, 0.4105, 0.5736, 0.9000, 0.6684, 0.3299, 0.4028, 0.8779, 0.7229,
          0.2937, 0.3859, 0.9886, 0.4668, 0.9955, 0.8116, 0.8865, 0.1882, 0.3202,
          0.7449, 0.8631, 0.0811, 0.8307, 0.2316, 0.1005, 0.5443, 0.4902, 0.2390,
          0.2414, 0.0439, 0.2744, 0.7032, 0.7607, 0.1878, 0.1694, 0.6335, 0.7443,
          0.0013, 0.0464, 0.0814, 0.4256, 0.0748, 0.3207, 0.7336, 0.7642, 0.4928,
          0.1710, 0.8912, 0.1218, 0.6220, 0.1380, 0.2228, 0.3156, 0.3258, 0.2052,
          0.1323, 0.4103, 0.2767, 0.7824, 0.9627, 0.8004, 0.1549, 0.8023, 0.9817,
          0.1908, 0.1562, 0.3106, 0.3357, 0.4137, 0.0707, 0.6985, 0.4610, 0.5035,
          0.8670, 0.9922, 0.7197, 0.8948, 0.9566, 0.8465, 0.8136, 0.6795, 0.1761,
          0.7229, 0.3550, 0.6100, 0.7295, 0.6090, 0.5827, 0.4100, 0.1115, 0.0995,
          0.8274, 0.4692, 0.8838, 0.5386, 0.3595, 0.9360, 0.0635, 0.1450, 0.3926,
          0.8310, 0.9143, 0.3391, 0.2829, 0.3047, 0.2182, 0.1280, 0.9802, 0.9440,
          0.2920, 0.8299, 0.9144, 0.4076, 0.2970, 0.2294, 0.0244, 0.8343, 0.9275,
          0.3803, 0.6840, 0.9597, 0.0127, 0.7151, 0.8146, 0.4750, 0.7732, 0.0529,
          0.3670, 0.0474, 0.7812, 0.5963, 0.2915, 0.1508, 0.9005, 0.0653, 0.2741,
          0.4691, 0.2832, 0.8760, 0.3789, 0.5135, 0.9533, 0.6164, 0.4412, 0.9812,
          0.5579, 0.1630, 0.3495, 0.3664, 0.6503, 0.6115, 0.4864, 0.4534, 0.8528,
          0.9964, 0.8616, 0.1254, 0.4540, 0.6587, 0.3288, 0.5619, 0.4516, 0.6349,
          0.0433, 0.1487, 0.5109, 0.8212, 0.6952, 0.2594, 0.2266, 0.0285, 0.4658,
          0.2908, 0.3484, 0.6546, 0.5795, 0.9769, 0.1132, 0.3941, 0.2193, 0.3397,
          0.3034, 0.5778, 0.9488, 0.2916, 0.1261, 0.1648, 0.3977, 0.3013, 0.5378,
          0.5231, 0.3439, 0.7011, 0.9639, 0.0647, 0.4299, 0.6242, 0.0108, 0.6482,
          0.5631, 0.0433, 0.0072, 0.2714, 0.5814, 0.8436, 0.2873, 0.1828, 0.0207,
          0.3880, 0.0056, 0.0892, 0.3005, 0.0208, 0.8481, 0.4182, 0.5781, 0.1621,
          0.0307, 0.1083, 0.7884, 0.4944, 0.6945, 0.8205, 0.9389, 0.6657, 0.6617,
          0.5277, 0.2019, 0.6943, 0.1419, 0.1570, 0.9135, 0.8660, 0.9133, 0.2667,
          0.8609, 0.2811, 0.2786, 0.3341, 0.5719, 0.0939, 0.9498, 0.3742, 0.5940,
          0.5593, 0.1292, 0.9932, 0.7358, 0.2080, 0.1342, 0.3176, 0.8044, 0.4599,
          0.7427, 0.5190, 0.7722, 0.9083, 0.4219, 0.2230, 0.4722, 0.0647, 0.4213,
          0.7184, 0.3155, 0.5858, 0.0020, 0.2533, 0.8639, 0.2506, 0.8492, 0.0980,
          0.3719, 0.6480, 0.8644, 0.0336, 0.6201, 0.8107, 0.3128, 0.7263, 0.3085,
          0.5454, 0.8335, 0.9835, 0.1103, 0.5264, 0.6616, 0.4177, 0.4773, 0.9393,
          0.4111, 0.0221, 0.5058, 0.4673, 0.1570, 0.5689, 0.4166, 0.1995, 0.0650,
          0.2613, 0.9611, 0.0214, 0.1148, 0.1766, 0.0584, 0.4730, 0.3569, 0.8454,
          0.0632, 0.0399, 0.0247, 0.0149, 0.9001, 0.6240, 0.5441, 0.4143, 0.4645,
          0.9748, 0.8132, 0.1683, 0.0079, 0.3274, 0.5488, 0.0620, 0.7076, 0.5459,
          0.7736, 0.8970, 0.0664, 0.2715, 0.7240, 0.3534, 0.8843, 0.7607, 0.9868,
          0.9120, 0.1948, 0.3082, 0.2411, 0.9100, 0.9996, 0.4892, 0.0591, 0.1281,
          0.5867, 0.4562, 0.9857, 0.4760, 0.3758, 0.6793, 0.3710, 0.1643, 0.5407,
          0.7611, 0.5247, 0.9334, 0.5220, 0.0357, 0.0792, 0.7963, 0.8953, 0.6581,
          0.5346, 0.8336, 0.1668, 0.0208, 0.3757, 0.2779, 0.9020, 0.9528, 0.8381,
          0.9747, 0.4669, 0.7524, 0.9672, 0.7720, 0.5056, 0.0305, 0.2281, 0.1398,
          0.0985, 0.7852, 0.4494, 0.0256, 0.8155, 0.5599, 0.3325, 0.0545, 0.5642,
          0.0810, 0.6742, 0.0685, 0.0211, 0.1338, 0.9364, 0.7557, 0.0844, 0.3617,
          0.6427, 0.0342, 0.0620, 0.0150, 0.5888, 0.9062, 0.4438, 0.1808, 0.1418,
          0.8282, 0.4593, 0.2427, 0.5373, 0.5639, 0.8593, 0.7313, 0.6849, 0.2168,
          0.0089, 0.8486, 0.4724, 0.1263, 0.3424, 0.9206, 0.4476, 0.4695, 0.9656,
          0.4814, 0.6332, 0.1991, 0.4935, 0.4874, 0.2545, 0.3843, 0.9045, 0.9028,
          0.6466, 0.6465, 0.4072, 0.2926, 0.6517, 0.6999, 0.0480, 0.8689, 0.0674,
          0.4650, 0.8179, 0.2491, 0.7135, 0.7033, 0.5491, 0.0358, 0.2785, 0.2789,
          0.5865, 0.8865, 0.2058, 0.2914, 0.2535, 0.8645, 0.2759, 0.2899, 0.2394,
          0.3443, 0.6642, 0.4243, 0.9868, 0.4510, 0.4427, 0.3788, 0.2419, 0.0565,
          0.4474, 0.3293, 0.5249, 0.4127, 0.4185, 0.7206, 0.8775, 0.7980, 0.2288,
          0.7330, 0.6557, 0.6985, 0.0708, 0.6435, 0.5381, 0.2522, 0.4100, 0.3287,
          0.6213, 0.8059, 0.9592, 0.1105, 0.0664, 0.7876, 0.3266, 0.3887, 0.1051,
          0.1530, 0.7652, 0.4660, 0.1248, 0.4386, 0.3522, 0.4491, 0.0429, 0.1707,
          0.3630, 0.5893, 0.3732, 0.6447, 0.8251, 0.9409, 0.8528, 0.6812]]),
 'This': tensor([[0.2152, 0.4910, 0.0555, 0.3965, 0.1361, 0.6333, 0.7874, 0.3262, 0.9533,
          0.9151, 0.3858, 0.0729, 0.3577, 0.9574, 0.0961, 0.8687, 0.7329, 0.7161,
          0.0028, 0.1878, 0.1578, 0.9973, 0.7892, 0.1291, 0.5100, 0.2113, 0.7500,
          0.7518, 0.6969, 0.6678, 0.7071, 0.2619, 0.0511, 0.2766, 0.1498, 0.8881,
          0.8278, 0.5870, 0.4107, 0.1839, 0.3003, 0.1517, 0.9936, 0.0977, 0.8828,
          0.5137, 0.8722, 0.1515, 0.0191, 0.9137, 0.4782, 0.7324, 0.2876, 0.9247,
          0.3920, 0.8587, 0.8911, 0.7247, 0.4897, 0.8079, 0.2067, 0.8566, 0.8426,
          0.6401, 0.5114, 0.9205, 0.9048, 0.6201, 0.7892, 0.0939, 0.5736, 0.0632,
          0.0020, 0.8377, 0.5328, 0.5833, 0.5827, 0.7059, 0.0906, 0.0394, 0.5706,
          0.4251, 0.6540, 0.2355, 0.0076, 0.8425, 0.1208, 0.3536, 0.9235, 0.5518,
          0.7423, 0.0021, 0.9209, 0.0995, 0.1130, 0.9392, 0.9237, 0.0144, 0.2593,
          0.1682, 0.1468, 0.8025, 0.3270, 0.3097, 0.1350, 0.0175, 0.7042, 0.7312,
          0.4253, 0.6570, 0.8914, 0.1020, 0.2238, 0.9727, 0.4450, 0.1094, 0.7162,
          0.5889, 0.3312, 0.0515, 0.9068, 0.9424, 0.1927, 0.9883, 0.8952, 0.3729,
          0.6553, 0.3309, 0.7760, 0.4039, 0.2436, 0.7270, 0.3019, 0.1070, 0.0986,
          0.8868, 0.2003, 0.9761, 0.5982, 0.2740, 0.7520, 0.7361, 0.4371, 0.3868,
          0.4604, 0.1003, 0.7034, 0.1130, 0.6363, 0.7203, 0.4812, 0.8814, 0.2132,
          0.6425, 0.4852, 0.4419, 0.8754, 0.8906, 0.9516, 0.6100, 0.1276, 0.2572,
          0.5610, 0.3167, 0.0756, 0.6781, 0.1017, 0.7164, 0.7223, 0.4071, 0.1822,
          0.9271, 0.7066, 0.1327, 0.9106, 0.1312, 0.1806, 0.0825, 0.1450, 0.5540,
          0.0373, 0.6595, 0.4283, 0.6228, 0.5348, 0.8336, 0.7451, 0.5372, 0.3442,
          0.3210, 0.6286, 0.6358, 0.7323, 0.3296, 0.9592, 0.6645, 0.9122, 0.7191,
          0.2951, 0.5336, 0.2864, 0.5300, 0.7070, 0.2938, 0.4060, 0.1593, 0.9121,
          0.8149, 0.8110, 0.2749, 0.8883, 0.0882, 0.9458, 0.0273, 0.4802, 0.7409,
          0.5075, 0.8512, 0.8194, 0.7712, 0.1341, 0.5139, 0.8650, 0.9187, 0.5583,
          0.2440, 0.0451, 0.5297, 0.4512, 0.1062, 0.1365, 0.3634, 0.8234, 0.5071,
          0.9937, 0.0962, 0.1252, 0.0470, 0.0672, 0.5666, 0.9003, 0.9503, 0.6575,
          0.1329, 0.8781, 0.2141, 0.7359, 0.1347, 0.1087, 0.1305, 0.2735, 0.0370,
          0.1726, 0.9887, 0.8507, 0.9882, 0.7723, 0.0116, 0.2307, 0.4238, 0.5770,
          0.7707, 0.7474, 0.1533, 0.6648, 0.2473, 0.6814, 0.8091, 0.9181, 0.3027,
          0.8469, 0.3801, 0.7775, 0.9613, 0.5931, 0.6049, 0.1286, 0.3215, 0.6964,
          0.0776, 0.6338, 0.1262, 0.1429, 0.9059, 0.5601, 0.6938, 0.2413, 0.8679,
          0.7622, 0.1362, 0.0593, 0.4532, 0.9105, 0.0666, 0.5179, 0.2075, 0.2778,
          0.0293, 0.3237, 0.1635, 0.7588, 0.4740, 0.9886, 0.6257, 0.5496, 0.7624,
          0.9422, 0.4926, 0.3554, 0.8820, 0.5423, 0.3948, 0.3014, 0.4143, 0.8609,
          0.3807, 0.1924, 0.3661, 0.0327, 0.1262, 0.2671, 0.7440, 0.0733, 0.7692,
          0.5435, 0.6322, 0.1606, 0.7510, 0.7596, 0.5263, 0.3666, 0.8887, 0.5361,
          0.9546, 0.2630, 0.4031, 0.2907, 0.5361, 0.4828, 0.0992, 0.3085, 0.5243,
          0.5319, 0.0414, 0.0446, 0.5212, 0.8314, 0.0212, 0.1592, 0.5111, 0.1466,
          0.5830, 0.2077, 0.1974, 0.1766, 0.6371, 0.9192, 0.2996, 0.9078, 0.0788,
          0.4067, 0.2479, 0.8717, 0.7996, 0.1500, 0.0510, 0.9198, 0.4004, 0.7755,
          0.4903, 0.1689, 0.0805, 0.4850, 0.0605, 0.3738, 0.2233, 0.5349, 0.7080,
          0.4206, 0.8056, 0.8200, 0.3219, 0.2254, 0.2013, 0.2662, 0.8581, 0.1497,
          0.7531, 0.3715, 0.9141, 0.9732, 0.8252, 0.1709, 0.1940, 0.3923, 0.8825,
          0.7286, 0.3949, 0.3951, 0.3604, 0.0327, 0.2438, 0.1726, 0.5121, 0.6182,
          0.1685, 0.7596, 0.3971, 0.6145, 0.9902, 0.0415, 0.1528, 0.5079, 0.3273,
          0.9937, 0.3309, 0.0042, 0.0466, 0.4192, 0.7061, 0.3765, 0.6425, 0.8546,
          0.0054, 0.4596, 0.5703, 0.0192, 0.9992, 0.3187, 0.4929, 0.9040, 0.2235,
          0.2907, 0.1699, 0.3323, 0.4408, 0.6577, 0.4336, 0.6870, 0.8824, 0.9694,
          0.6050, 0.3246, 0.0858, 0.4210, 0.4980, 0.2536, 0.9378, 0.7191, 0.0937,
          0.4586, 0.7615, 0.1342, 0.6374, 0.8187, 0.9778, 0.1334, 0.9411, 0.4934,
          0.9919, 0.4043, 0.0480, 0.2767, 0.1688, 0.5246, 0.8663, 0.8355, 0.2361,
          0.5820, 0.7151, 0.0703, 0.8998, 0.5550, 0.3172, 0.2989, 0.0271, 0.6052,
          0.2344, 0.7189, 0.4151, 0.0471, 0.4596, 0.6512, 0.7085, 0.8491, 0.4976,
          0.3416, 0.9925, 0.8153, 0.8217, 0.5180, 0.1327, 0.6258, 0.8391, 0.1805,
          0.0258, 0.3462, 0.0355, 0.6398, 0.5492, 0.1894, 0.4301, 0.1394, 0.6999,
          0.0814, 0.3038, 0.1129, 0.7487, 0.6940, 0.6697, 0.8853, 0.9709]]),
 'is': tensor([[0.3030, 0.3308, 0.7832, 0.3432, 0.6770, 0.3670, 0.3881, 0.5339, 0.8059,
          0.7567, 0.2484, 0.1516, 0.0069, 0.9820, 0.7071, 0.4149, 0.3131, 0.5795,
          0.5673, 0.5513, 0.8398, 0.0250, 0.2461, 0.7250, 0.8397, 0.0186, 0.5228,
          0.9001, 0.5652, 0.9182, 0.3838, 0.5976, 0.1207, 0.6238, 0.1942, 0.7772,
          0.9609, 0.0902, 0.8891, 0.3836, 0.8912, 0.9970, 0.2842, 0.0047, 0.7782,
          0.8488, 0.8091, 0.4175, 0.8788, 0.8099, 0.1839, 0.1782, 0.2805, 0.9235,
          0.3507, 0.2638, 0.6751, 0.9793, 0.6889, 0.3542, 0.3580, 0.7339, 0.4761,
          0.0497, 0.4580, 0.5394, 0.0712, 0.7598, 0.8396, 0.3756, 0.7896, 0.7548,
          0.6100, 0.9465, 0.2353, 0.3506, 0.6422, 0.3654, 0.4055, 0.9402, 0.9394,
          0.6361, 0.0640, 0.8750, 0.2851, 0.9939, 0.8866, 0.8228, 0.4693, 0.1686,
          0.7326, 0.8852, 0.2968, 0.8017, 0.2980, 0.1850, 0.7003, 0.9568, 0.6192,
          0.9192, 0.6883, 0.2935, 0.4767, 0.1703, 0.1334, 0.4364, 0.1314, 0.4642,
          0.5792, 0.4950, 0.9560, 0.4920, 0.4559, 0.7454, 0.3997, 0.1666, 0.2848,
          0.6258, 0.8711, 0.0594, 0.7980, 0.0463, 0.8272, 0.3779, 0.4429, 0.3792,
          0.1710, 0.1790, 0.5089, 0.0110, 0.1680, 0.2550, 0.2730, 0.3912, 0.2280,
          0.2541, 0.5304, 0.3876, 0.2125, 0.9358, 0.4324, 0.9020, 0.4473, 0.5415,
          0.7728, 0.7709, 0.2019, 0.3371, 0.5245, 0.9472, 0.2037, 0.5409, 0.3365,
          0.5423, 0.6817, 0.8452, 0.4061, 0.6297, 0.7746, 0.8903, 0.1268, 0.6407,
          0.5464, 0.1595, 0.6258, 0.0076, 0.5248, 0.3342, 0.3432, 0.2398, 0.3887,
          0.9056, 0.7747, 0.4152, 0.8183, 0.5351, 0.7446, 0.5326, 0.6939, 0.4962,
          0.2977, 0.7936, 0.5097, 0.6000, 0.5075, 0.9794, 0.3765, 0.7564, 0.0796,
          0.5022, 0.5688, 0.1814, 0.0896, 0.4731, 0.1762, 0.4979, 0.4691, 0.4773,
          0.6575, 0.3425, 0.9001, 0.9767, 0.0455, 0.0552, 0.6302, 0.3974, 0.1900,
          0.9447, 0.1032, 0.3774, 0.1155, 0.9186, 0.4717, 0.6973, 0.9532, 0.1219,
          0.9037, 0.0897, 0.6306, 0.5702, 0.2796, 0.4481, 0.5185, 0.9985, 0.7970,
          0.0367, 0.5677, 0.5592, 0.7575, 0.6962, 0.3738, 0.3663, 0.9076, 0.1186,
          0.4297, 0.9126, 0.8874, 0.7698, 0.7921, 0.1246, 0.5731, 0.8023, 0.2925,
          0.5561, 0.5587, 0.9941, 0.4280, 0.5272, 0.9786, 0.0279, 0.2102, 0.3307,
          0.2436, 0.1124, 0.0205, 0.3292, 0.2912, 0.9293, 0.5807, 0.7178, 0.1287,
          0.8301, 0.2178, 0.0265, 0.1687, 0.0738, 0.3559, 0.6302, 0.7167, 0.6572,
          0.7784, 0.8716, 0.7321, 0.4485, 0.0304, 0.5796, 0.5450, 0.3381, 0.4131,
          0.3714, 0.1017, 0.2971, 0.0956, 0.3406, 0.6677, 0.5164, 0.3263, 0.7715,
          0.5499, 0.1281, 0.9226, 0.0208, 0.6375, 0.6504, 0.4582, 0.4054, 0.5053,
          0.3062, 0.5446, 0.6215, 0.5950, 0.1893, 0.9667, 0.5030, 0.3088, 0.8503,
          0.9194, 0.9775, 0.4632, 0.6326, 0.4333, 0.4616, 0.4370, 0.1435, 0.9571,
          0.6458, 0.4380, 0.8930, 0.4390, 0.2301, 0.7900, 0.8016, 0.4070, 0.7823,
          0.3949, 0.8369, 0.7405, 0.0283, 0.3859, 0.4705, 0.5998, 0.5095, 0.6932,
          0.6346, 0.8307, 0.2791, 0.1016, 0.1927, 0.8618, 0.5629, 0.6871, 0.4247,
          0.8271, 0.8167, 0.9779, 0.8438, 0.5253, 0.3824, 0.1098, 0.0614, 0.2007,
          0.1527, 0.5151, 0.8951, 0.1158, 0.3462, 0.1500, 0.0512, 0.6911, 0.8736,
          0.5641, 0.3516, 0.9979, 0.7322, 0.1876, 0.9804, 0.2973, 0.1629, 0.8169,
          0.2693, 0.5316, 0.8616, 0.0331, 0.7843, 0.7018, 0.7254, 0.2218, 0.8796,
          0.7971, 0.3426, 0.3906, 0.1256, 0.7299, 0.7841, 0.8594, 0.7824, 0.9124,
          0.5460, 0.2922, 0.6400, 0.8265, 0.7215, 0.1594, 0.4309, 0.3737, 0.1225,
          0.6984, 0.7075, 0.8405, 0.8004, 0.1104, 0.0504, 0.5213, 0.6874, 0.8836,
          0.6495, 0.7992, 0.1850, 0.1875, 0.1049, 0.6148, 0.4339, 0.2335, 0.4557,
          0.7283, 0.9005, 0.9859, 0.7186, 0.3309, 0.0278, 0.8650, 0.9548, 0.1230,
          0.0607, 0.2911, 0.4514, 0.0188, 0.2090, 0.7346, 0.3488, 0.1924, 0.0545,
          0.8849, 0.2997, 0.2874, 0.9031, 0.3072, 0.2281, 0.2750, 0.3751, 0.9533,
          0.4381, 0.2783, 0.5745, 0.7898, 0.6750, 0.0329, 0.3460, 0.9219, 0.2823,
          0.0343, 0.6328, 0.8489, 0.5083, 0.4986, 0.0184, 0.8050, 0.8948, 0.2779,
          0.7656, 0.6692, 0.7213, 0.2558, 0.5314, 0.3965, 0.5245, 0.4547, 0.2842,
          0.5265, 0.3705, 0.2827, 0.3007, 0.6895, 0.1027, 0.1035, 0.1607, 0.6210,
          0.8027, 0.5967, 0.9906, 0.4602, 0.9568, 0.0201, 0.3171, 0.5704, 0.7176,
          0.9687, 0.2517, 0.9186, 0.5749, 0.0969, 0.5811, 0.1428, 0.2323, 0.1737,
          0.8464, 0.8527, 0.9234, 0.2643, 0.6641, 0.3599, 0.7717, 0.8772, 0.1905,
          0.5952, 0.7481, 0.5463, 0.1475, 0.3530, 0.9743, 0.7039, 0.8784]]),
 'terrible': tensor([[0.5855, 0.2649, 0.6824, 0.9090, 0.6383, 0.5178, 0.6379, 0.7566, 0.3292,
          0.8116, 0.9876, 0.5941, 0.2604, 0.4238, 0.9928, 0.9324, 0.5274, 0.6950,
          0.3834, 0.9103, 0.9383, 0.5604, 0.5298, 0.9904, 0.8569, 0.0077, 0.4592,
          0.5985, 0.2869, 0.5493, 0.0760, 0.6456, 0.1898, 0.9665, 0.4362, 0.3686,
          0.5383, 0.4851, 0.4171, 0.4594, 0.6626, 0.2002, 0.3302, 0.8022, 0.8321,
          0.7010, 0.1242, 0.1441, 0.9586, 0.0031, 0.1399, 0.8919, 0.9431, 0.6082,
          0.4491, 0.3690, 0.2415, 0.6112, 0.0133, 0.6676, 0.6588, 0.6107, 0.4231,
          0.4005, 0.0541, 0.3298, 0.1880, 0.0603, 0.7156, 0.5508, 0.5516, 0.9828,
          0.8827, 0.7909, 0.9486, 0.9864, 0.2332, 0.8627, 0.6487, 0.2633, 0.5856,
          0.9812, 0.5241, 0.3584, 0.2319, 0.5050, 0.1728, 0.5274, 0.6476, 0.1978,
          0.7118, 0.0247, 0.8701, 0.9261, 0.4751, 0.0861, 0.1009, 0.3861, 0.5982,
          0.5560, 0.7288, 0.3407, 0.7503, 0.6544, 0.1126, 0.8854, 0.0364, 0.5838,
          0.9058, 0.6230, 0.3557, 0.4391, 0.8061, 0.7133, 0.8940, 0.6705, 0.2336,
          0.2213, 0.3162, 0.6932, 0.7178, 0.3505, 0.8724, 0.8739, 0.4647, 0.9227,
          0.1742, 0.1208, 0.7432, 0.1683, 0.9768, 0.8919, 0.3370, 0.2105, 0.2207,
          0.9693, 0.5896, 0.6219, 0.7924, 0.8424, 0.4448, 0.8825, 0.7603, 0.8306,
          0.8466, 0.2848, 0.0104, 0.5962, 0.1909, 0.1904, 0.2238, 0.6091, 0.6532,
          0.5450, 0.1255, 0.9367, 0.1976, 0.2342, 0.9886, 0.9213, 0.2628, 0.7929,
          0.9508, 0.7999, 0.2138, 0.2811, 0.3993, 0.3670, 0.3357, 0.5695, 0.6912,
          0.0586, 0.2857, 0.5345, 0.8322, 0.5131, 0.9281, 0.1931, 0.1783, 0.6845,
          0.6870, 0.2231, 0.0015, 0.7058, 0.9630, 0.7360, 0.4003, 0.7135, 0.9743,
          0.2934, 0.0876, 0.5936, 0.0575, 0.4957, 0.0220, 0.3590, 0.6742, 0.3618,
          0.2612, 0.3174, 0.7051, 0.0552, 0.0723, 0.9834, 0.6150, 0.7896, 0.7340,
          0.6538, 0.8430, 0.7414, 0.4367, 0.2961, 0.9244, 0.0072, 0.1066, 0.2035,
          0.5208, 0.6254, 0.3837, 0.9157, 0.1859, 0.3085, 0.6779, 0.3285, 0.7767,
          0.9036, 0.8840, 0.7617, 0.6616, 0.3860, 0.4358, 0.7442, 0.6677, 0.6069,
          0.8826, 0.4881, 0.8062, 0.6500, 0.4992, 0.3511, 0.8887, 0.1136, 0.2732,
          0.8419, 0.3908, 0.6561, 0.5277, 0.5864, 0.7776, 0.9331, 0.0249, 0.8702,
          0.0520, 0.6555, 0.2014, 0.0781, 0.0642, 0.9756, 0.3659, 0.4728, 0.6298,
          0.3379, 0.9141, 0.8182, 0.2483, 0.0842, 0.0370, 0.9076, 0.2064, 0.4826,
          0.8089, 0.6053, 0.2862, 0.0709, 0.8495, 0.9843, 0.3027, 0.2665, 0.1964,
          0.6025, 0.6537, 0.0985, 0.6722, 0.8917, 0.8172, 0.3095, 0.2762, 0.9319,
          0.7828, 0.2516, 0.0755, 0.1130, 0.5628, 0.8518, 0.3612, 0.2398, 0.4443,
          0.8719, 0.6261, 0.5423, 0.4970, 0.5225, 0.6409, 0.7047, 0.2922, 0.8604,
          0.4449, 0.1486, 0.1151, 0.2543, 0.4026, 0.4156, 0.1399, 0.9659, 0.1471,
          0.4652, 0.3409, 0.4933, 0.2884, 0.8774, 0.8924, 0.8515, 0.0341, 0.1024,
          0.2527, 0.7050, 0.6732, 0.4113, 0.8484, 0.6215, 0.0581, 0.3636, 0.4854,
          0.2169, 0.2375, 0.8805, 0.3963, 0.7790, 0.9966, 0.1225, 0.0534, 0.4620,
          0.3919, 0.5312, 0.6177, 0.3293, 0.5108, 0.5008, 0.0812, 0.6598, 0.9941,
          0.4989, 0.2650, 0.7844, 0.9584, 0.2155, 0.6457, 0.9917, 0.7261, 0.6733,
          0.1342, 0.5628, 0.3342, 0.0644, 0.7751, 0.6443, 0.6228, 0.7611, 0.6238,
          0.4463, 0.3433, 0.9535, 0.0861, 0.3667, 0.0380, 0.0787, 0.0663, 0.5057,
          0.3755, 0.8104, 0.1085, 0.8367, 0.5712, 0.6708, 0.8172, 0.6668, 0.5640,
          0.9033, 0.3231, 0.2261, 0.8740, 0.3122, 0.6422, 0.6556, 0.6612, 0.3385,
          0.4765, 0.6801, 0.0279, 0.9051, 0.7723, 0.7153, 0.2112, 0.2691, 0.5329,
          0.9408, 0.0958, 0.9265, 0.1441, 0.9691, 0.4358, 0.9721, 0.8310, 0.4063,
          0.1211, 0.1652, 0.9579, 0.2144, 0.2206, 0.4603, 0.0270, 0.6834, 0.3463,
          0.9988, 0.7037, 0.6884, 0.8875, 0.7038, 0.7547, 0.9815, 0.7266, 0.3914,
          0.8555, 0.3167, 0.6401, 0.2866, 0.5294, 0.3231, 0.5644, 0.0762, 0.1628,
          0.3396, 0.9040, 0.5072, 0.6015, 0.3013, 0.4791, 0.8451, 0.7284, 0.3798,
          0.1125, 0.4204, 0.4121, 0.5730, 0.4844, 0.9696, 0.3687, 0.6641, 0.3391,
          0.7629, 0.1631, 0.5782, 0.1350, 0.7313, 0.0616, 0.4614, 0.2980, 0.3153,
          0.1478, 0.0893, 0.2438, 0.4957, 0.1557, 0.6529, 0.9222, 0.2138, 0.6292,
          0.3316, 0.8129, 0.1526, 0.7863, 0.5225, 0.0305, 0.4374, 0.8123, 0.9999,
          0.0848, 0.3728, 0.5375, 0.6418, 0.8558, 0.8129, 0.1835, 0.4037, 0.8876,
          0.6153, 0.1861, 0.4165, 0.3945, 0.9856, 0.4220, 0.5546, 0.6610, 0.0379,
          0.7953, 0.5303, 0.4504, 0.6455, 0.4149, 0.8721, 0.9896, 0.1480]]),
 'Could': tensor([[0.1788, 0.6998, 0.0160, 0.1890, 0.0940, 0.7954, 0.3136, 0.7015, 0.4837,
          0.0895, 0.0797, 0.2087, 0.1809, 0.8542, 0.1615, 0.7618, 0.7180, 0.2226,
          0.7713, 0.5862, 0.1458, 0.5617, 0.0396, 0.9547, 0.2004, 0.7466, 0.1500,
          0.3556, 0.6314, 0.3376, 0.8765, 0.3809, 0.9087, 0.1862, 0.7172, 0.1081,
          0.9433, 0.0726, 0.7779, 0.4378, 0.5313, 0.2385, 0.2930, 0.4348, 0.4242,
          0.5042, 0.0158, 0.8615, 0.4147, 0.0137, 0.9537, 0.9339, 0.7837, 0.4780,
          0.1561, 0.0336, 0.7818, 0.0290, 0.5488, 0.5292, 0.6822, 0.0223, 0.0876,
          0.4658, 0.3108, 0.9164, 0.6847, 0.7528, 0.0842, 0.3358, 0.8691, 0.4841,
          0.1806, 0.9105, 0.3086, 0.1684, 0.4506, 0.0938, 0.8516, 0.9538, 0.0858,
          0.7284, 0.1106, 0.2432, 0.4368, 0.8324, 0.3482, 0.3612, 0.2654, 0.2153,
          0.5967, 0.8784, 0.6714, 0.1163, 0.0083, 0.0897, 0.2375, 0.5917, 0.8249,
          0.7432, 0.8105, 0.6229, 0.8174, 0.9466, 0.8997, 0.9439, 0.1664, 0.8562,
          0.6016, 0.0595, 0.6178, 0.9991, 0.4914, 0.3815, 0.1582, 0.7975, 0.5395,
          0.3392, 0.1656, 0.1989, 0.2411, 0.6063, 0.8013, 0.7509, 0.3138, 0.0218,
          0.6975, 0.5623, 0.4121, 0.0812, 0.6906, 0.6625, 0.9617, 0.9290, 0.9597,
          0.5498, 0.3561, 0.8026, 0.0862, 0.3672, 0.3850, 0.8285, 0.6327, 0.1792,
          0.8105, 0.6716, 0.4303, 0.2275, 0.6551, 0.3331, 0.8412, 0.4508, 0.3562,
          0.5563, 0.4268, 0.6573, 0.1124, 0.3116, 0.4107, 0.8613, 0.7302, 0.7061,
          0.5374, 0.6306, 0.6932, 0.6783, 0.9677, 0.5996, 0.2719, 0.9087, 0.9796,
          0.3444, 0.6142, 0.3418, 0.4510, 0.1936, 0.5485, 0.0666, 0.2781, 0.5119,
          0.2152, 0.9569, 0.9942, 0.3539, 0.1079, 0.8733, 0.3200, 0.1662, 0.5152,
          0.7219, 0.3570, 0.5602, 0.0545, 0.7758, 0.3827, 0.5439, 0.9522, 0.3348,
          0.7307, 0.0163, 0.3303, 0.3566, 0.5168, 0.8909, 0.3853, 0.8440, 0.0968,
          0.5692, 0.5761, 0.2599, 0.3675, 0.0491, 0.7342, 0.9233, 0.2408, 0.3962,
          0.1528, 0.1557, 0.9776, 0.4741, 0.6635, 0.6183, 0.7728, 0.2918, 0.4202,
          0.4036, 0.3600, 0.3673, 0.3926, 0.1015, 0.7987, 0.7533, 0.4966, 0.2592,
          0.5699, 0.4519, 0.9597, 0.8224, 0.1883, 0.1419, 0.3647, 0.2986, 0.2460,
          0.6712, 0.4328, 0.8304, 0.4397, 0.8851, 0.8078, 0.5798, 0.9265, 0.9872,
          0.3890, 0.8151, 0.2867, 0.8307, 0.3206, 0.8076, 0.6741, 0.1235, 0.7910,
          0.0214, 0.4734, 0.9259, 0.0419, 0.3624, 0.1569, 0.1643, 0.1188, 0.1038,
          0.8837, 0.7056, 0.9005, 0.8191, 0.7361, 0.2989, 0.0993, 0.8256, 0.3045,
          0.0634, 0.5866, 0.0510, 0.7024, 0.4911, 0.8372, 0.2586, 0.1036, 0.2710,
          0.7805, 0.5793, 0.5216, 0.7658, 0.8893, 0.1946, 0.0727, 0.4426, 0.1797,
          0.9500, 0.0181, 0.5121, 0.6000, 0.4243, 0.8502, 0.1441, 0.1795, 0.0567,
          0.4973, 0.3219, 0.4613, 0.6974, 0.4251, 0.6031, 0.8475, 0.0373, 0.6037,
          0.1798, 0.6151, 0.8773, 0.8287, 0.3949, 0.0762, 0.2927, 0.7165, 0.2687,
          0.1732, 0.0568, 0.5171, 0.4933, 0.6631, 0.3267, 0.6741, 0.0314, 0.3885,
          0.0408, 0.2617, 0.1316, 0.6290, 0.4145, 0.4740, 0.4730, 0.7773, 0.7539,
          0.8353, 0.2834, 0.0251, 0.8236, 0.1803, 0.6495, 0.8845, 0.2459, 0.5243,
          0.9615, 0.5642, 0.6382, 0.4954, 0.1166, 0.2645, 0.0150, 0.1523, 0.0036,
          0.2387, 0.5321, 0.7154, 0.1339, 0.2806, 0.6830, 0.5548, 0.6044, 0.1094,
          0.5140, 0.6117, 0.7182, 0.5254, 0.7846, 0.3237, 0.0938, 0.0333, 0.7609,
          0.3454, 0.8746, 0.5198, 0.3485, 0.1686, 0.4796, 0.7440, 0.1327, 0.0792,
          0.7244, 0.2773, 0.0737, 0.6383, 0.6044, 0.1403, 0.4693, 0.9916, 0.4180,
          0.8532, 0.2074, 0.8764, 0.2850, 0.5225, 0.5624, 0.9610, 0.6308, 0.2271,
          0.7003, 0.5000, 0.6851, 0.8271, 0.8589, 0.2921, 0.4742, 0.6079, 0.6293,
          0.4085, 0.2973, 0.8105, 0.7952, 0.2944, 0.6160, 0.0356, 0.5531, 0.3121,
          0.3025, 0.0667, 0.4196, 0.7257, 0.1835, 0.3341, 0.2990, 0.8497, 0.9446,
          0.9237, 0.1993, 0.0614, 0.3053, 0.9288, 0.1086, 0.1145, 0.3706, 0.0844,
          0.4932, 0.7999, 0.1245, 0.4541, 0.3333, 0.1474, 0.6295, 0.5708, 0.7338,
          0.4088, 0.3523, 0.2832, 0.3943, 0.3101, 0.1963, 0.8811, 0.0897, 0.3835,
          0.3020, 0.4282, 0.7927, 0.5443, 0.8612, 0.8743, 0.6701, 0.0074, 0.2139,
          0.7204, 0.0148, 0.2753, 0.4811, 0.9111, 0.2329, 0.4516, 0.5727, 0.3958,
          0.4733, 0.7981, 0.8306, 0.0941, 0.9213, 0.3481, 0.3871, 0.6267, 0.6558,
          0.9811, 0.1971, 0.1213, 0.5549, 0.8222, 0.8626, 0.6895, 0.2755, 0.5796,
          0.7274, 0.2519, 0.5316, 0.6140, 0.7786, 0.0670, 0.7691, 0.7789, 0.1022,
          0.4395, 0.4710, 0.4678, 0.0172, 0.4924, 0.9245, 0.8080, 0.9533]]),
 'be': tensor([[0.4259, 0.0446, 0.9524, 0.0021, 0.1511, 0.3268, 0.8524, 0.7908, 0.6162,
          0.3236, 0.4747, 0.8606, 0.3320, 0.7735, 0.7849, 0.8085, 0.5480, 0.6390,
          0.0715, 0.8693, 0.0222, 0.7354, 0.7427, 0.8260, 0.7553, 0.0352, 0.6216,
          0.5998, 0.4958, 0.2415, 0.9807, 0.8636, 0.6797, 0.6578, 0.3328, 0.8724,
          0.2234, 0.0625, 0.4281, 0.1008, 0.5074, 0.8294, 0.5806, 0.3830, 0.7060,
          0.7566, 0.7643, 0.9890, 0.2424, 0.4778, 0.2005, 0.4776, 0.4702, 0.3608,
          0.5671, 0.8091, 0.4186, 0.9737, 0.4332, 0.7190, 0.3768, 0.6399, 0.6785,
          0.5693, 0.5487, 0.2541, 0.6601, 0.2578, 0.4528, 0.3369, 0.6551, 0.5154,
          0.5596, 0.8055, 0.4951, 0.1091, 0.1779, 0.1638, 0.1562, 0.6283, 0.9421,
          0.6715, 0.4774, 0.7392, 0.1694, 0.6254, 0.5071, 0.3280, 0.0305, 0.1431,
          0.1214, 0.8620, 0.3829, 0.3721, 0.3436, 0.1989, 0.5390, 0.4001, 0.6280,
          0.6236, 0.9714, 0.6641, 0.2369, 0.2261, 0.8512, 0.0986, 0.7810, 0.2973,
          0.6482, 0.2837, 0.7455, 0.9275, 0.5073, 0.1076, 0.1933, 0.1886, 0.2278,
          0.3753, 0.0713, 0.0312, 0.3417, 0.7750, 0.4519, 0.1222, 0.7121, 0.2883,
          0.8322, 0.3292, 0.5157, 0.4997, 0.8104, 0.8594, 0.9138, 0.6142, 0.1827,
          0.5338, 0.2250, 0.7524, 0.9458, 0.5284, 0.4220, 0.6203, 0.7043, 0.2702,
          0.4812, 0.2725, 0.9084, 0.0365, 0.5207, 0.2783, 0.9685, 0.3701, 0.2096,
          0.8933, 0.4824, 0.2579, 0.9848, 0.6817, 0.0732, 0.7080, 0.2085, 0.4102,
          0.3211, 0.0898, 0.9920, 0.8842, 0.6913, 0.2279, 0.7270, 0.6377, 0.1534,
          0.9092, 0.9293, 0.4296, 0.4923, 0.8726, 0.2771, 0.3967, 0.1407, 0.9737,
          0.8254, 0.8946, 0.9787, 0.6632, 0.3734, 0.5907, 0.9940, 0.3351, 0.6804,
          0.2651, 0.7226, 0.2139, 0.9883, 0.0164, 0.9949, 0.3535, 0.4192, 0.7689,
          0.8477, 0.5729, 0.1691, 0.7991, 0.1324, 0.8607, 0.3467, 0.6266, 0.2704,
          0.5738, 0.3834, 0.2695, 0.8938, 0.4237, 0.2848, 0.4830, 0.0658, 0.2850,
          0.0720, 0.7665, 0.6722, 0.2854, 0.3721, 0.5166, 0.2293, 0.9633, 0.6571,
          0.6163, 0.8262, 0.5901, 0.4052, 0.4639, 0.7784, 0.5495, 0.8361, 0.9706,
          0.2000, 0.3211, 0.4290, 0.5401, 0.2917, 0.0627, 0.4948, 0.7593, 0.9441,
          0.7536, 0.2460, 0.7210, 0.4587, 0.6528, 0.7568, 0.2573, 0.9008, 0.1560,
          0.0111, 0.4836, 0.8161, 0.5433, 0.0499, 0.9480, 0.9068, 0.3802, 0.1391,
          0.4579, 0.7771, 0.4473, 0.0723, 0.6525, 0.0303, 0.5666, 0.0445, 0.9429,
          0.4105, 0.5869, 0.3013, 0.0978, 0.4466, 0.8126, 0.0124, 0.7008, 0.3104,
          0.3563, 0.4151, 0.3668, 0.8960, 0.7017, 0.7425, 0.8891, 0.8327, 0.4106,
          0.6540, 0.7920, 0.3264, 0.4959, 0.3416, 0.5613, 0.0157, 0.1712, 0.2564,
          0.8040, 0.2195, 0.4739, 0.6879, 0.1230, 0.4209, 0.3082, 0.5983, 0.8558,
          0.3691, 0.4640, 0.1684, 0.2421, 0.5704, 0.8536, 0.0094, 0.9899, 0.7508,
          0.2311, 0.8497, 0.4387, 0.4033, 0.9431, 0.3647, 0.4729, 0.5746, 0.2988,
          0.0768, 0.2213, 0.4358, 0.3477, 0.8455, 0.6141, 0.8072, 0.0190, 0.0566,
          0.4211, 0.1196, 0.1189, 0.1714, 0.9583, 0.3800, 0.7954, 0.9149, 0.7815,
          0.5894, 0.2918, 0.9469, 0.6777, 0.4352, 0.1712, 0.4616, 0.6909, 0.2448,
          0.5707, 0.4657, 0.1391, 0.2554, 0.4964, 0.7933, 0.9125, 0.3875, 0.5173,
          0.9583, 0.7972, 0.7364, 0.3878, 0.1677, 0.9579, 0.2579, 0.9124, 0.6810,
          0.6438, 0.4360, 0.5801, 0.5615, 0.5204, 0.7515, 0.0474, 0.4068, 0.4368,
          0.2929, 0.8861, 0.8575, 0.7709, 0.4987, 0.6291, 0.5330, 0.4880, 0.2685,
          0.8747, 0.5823, 0.2666, 0.3233, 0.6684, 0.5955, 0.3285, 0.8853, 0.3941,
          0.9921, 0.7938, 0.9142, 0.5566, 0.8785, 0.3083, 0.0873, 0.8995, 0.5198,
          0.1167, 0.3876, 0.1822, 0.6893, 0.1184, 0.9622, 0.4726, 0.0923, 0.8414,
          0.1237, 0.4730, 0.7193, 0.6529, 0.4518, 0.2103, 0.0439, 0.4147, 0.2596,
          0.0730, 0.8803, 0.4047, 0.2782, 0.6896, 0.3816, 0.8392, 0.0555, 0.1396,
          0.2262, 0.1190, 0.1720, 0.0657, 0.9231, 0.1371, 0.0011, 0.4935, 0.7888,
          0.2657, 0.8784, 0.5516, 0.8420, 0.9878, 0.9972, 0.6831, 0.8365, 0.3174,
          0.3071, 0.6530, 0.4448, 0.5184, 0.2356, 0.2549, 0.5443, 0.2946, 0.7365,
          0.2777, 0.1297, 0.6103, 0.3738, 0.7953, 0.6003, 0.9708, 0.3897, 0.6248,
          0.8176, 0.0244, 0.1821, 0.8671, 0.2927, 0.1544, 0.5942, 0.5398, 0.5749,
          0.1336, 0.1221, 0.2047, 0.8849, 0.8749, 0.4740, 0.7466, 0.6381, 0.3920,
          0.9456, 0.2977, 0.2140, 0.1522, 0.1489, 0.9544, 0.9647, 0.8917, 0.9450,
          0.2920, 0.2414, 0.1228, 0.2876, 0.0278, 0.2262, 0.0531, 0.0025, 0.4814,
          0.8346, 0.9718, 0.4833, 0.5596, 0.5231, 0.9197, 0.6634, 0.5031]]),
 'better': tensor([[8.0559e-01, 3.1079e-01, 5.6028e-02, 1.0601e-02, 3.5850e-01, 7.9549e-01,
          1.2711e-01, 6.6148e-01, 4.8979e-01, 1.4553e-01, 6.6997e-01, 6.4226e-01,
          2.1324e-01, 1.4310e-01, 6.0872e-01, 9.4765e-01, 6.1617e-02, 7.3873e-02,
          7.3152e-01, 5.0554e-01, 4.1034e-01, 2.6130e-01, 8.9329e-01, 3.8616e-01,
          8.3406e-01, 9.7408e-01, 7.2836e-01, 1.7194e-02, 4.3492e-01, 2.9678e-01,
          4.4808e-03, 3.8489e-01, 5.2455e-01, 3.4273e-01, 8.8719e-01, 2.6532e-01,
          7.9900e-01, 8.9404e-01, 3.0302e-01, 5.4006e-01, 8.8524e-01, 1.3292e-01,
          9.0553e-02, 9.6303e-02, 7.4187e-03, 5.7318e-01, 7.7479e-01, 7.3902e-01,
          3.2356e-01, 9.4293e-01, 4.8039e-01, 1.4785e-01, 8.6490e-01, 6.4608e-01,
          5.6074e-01, 9.0159e-01, 8.9594e-01, 7.4562e-01, 1.9953e-01, 3.5172e-01,
          1.2930e-01, 7.3532e-01, 3.8711e-01, 6.9405e-01, 4.9231e-01, 6.4351e-01,
          9.2878e-01, 3.8594e-01, 9.5903e-01, 2.0474e-01, 9.3852e-01, 4.6867e-01,
          3.2382e-01, 3.3148e-01, 7.4370e-01, 8.5372e-02, 4.7383e-01, 1.4697e-01,
          5.1866e-01, 5.8742e-01, 9.2905e-01, 4.7374e-01, 8.9554e-01, 1.6893e-01,
          2.3087e-02, 7.8440e-01, 2.9623e-02, 6.3534e-01, 4.7029e-01, 9.9077e-01,
          4.1011e-01, 8.6382e-01, 7.7611e-01, 1.1262e-01, 3.3860e-02, 3.2395e-01,
          7.5565e-01, 8.1029e-01, 7.2541e-01, 9.3518e-01, 4.6305e-01, 2.5759e-01,
          9.1802e-01, 7.0422e-01, 4.3576e-01, 3.5554e-01, 8.0382e-01, 2.6986e-01,
          4.4950e-01, 2.3300e-01, 9.0167e-01, 3.7970e-01, 3.3233e-01, 2.5714e-01,
          4.4902e-01, 3.2997e-01, 5.6966e-01, 5.2284e-01, 7.6356e-01, 2.0971e-01,
          6.1767e-01, 6.1274e-01, 6.7947e-01, 1.1205e-01, 2.6691e-01, 3.3141e-01,
          9.8605e-01, 7.7260e-01, 9.2324e-01, 7.9772e-01, 4.8901e-01, 2.8792e-01,
          1.1584e-01, 4.9979e-01, 4.7377e-01, 5.4883e-01, 7.9868e-01, 8.8789e-03,
          8.0802e-01, 6.5522e-01, 6.0236e-01, 1.4848e-02, 1.0937e-01, 8.9718e-01,
          2.3989e-01, 6.3060e-01, 2.0074e-01, 4.9786e-01, 7.9752e-01, 2.6748e-01,
          6.9593e-01, 5.7218e-01, 9.3317e-01, 7.1761e-01, 7.9264e-01, 6.0858e-01,
          1.5574e-01, 7.4566e-01, 7.0488e-02, 5.3332e-01, 8.2235e-01, 6.7551e-01,
          5.2493e-01, 9.9000e-01, 4.4549e-01, 5.0738e-01, 2.9558e-01, 6.7413e-01,
          9.4789e-01, 1.6541e-01, 8.3389e-01, 3.1217e-01, 3.3535e-01, 1.6414e-01,
          8.7943e-01, 6.4612e-01, 1.8041e-01, 7.6249e-02, 9.8930e-01, 4.9270e-01,
          9.1048e-01, 5.0686e-01, 5.6370e-01, 2.6645e-01, 4.1343e-01, 6.3957e-01,
          7.5649e-02, 3.0431e-01, 9.1317e-01, 6.1877e-03, 5.6369e-01, 5.0516e-02,
          8.3493e-01, 6.7099e-01, 1.4514e-01, 9.9899e-01, 8.3077e-01, 7.3516e-01,
          8.1379e-01, 8.2608e-01, 8.5239e-01, 5.7093e-02, 6.1294e-01, 4.9579e-01,
          1.0108e-01, 3.6106e-01, 9.5419e-01, 9.7343e-01, 8.2428e-01, 4.2032e-01,
          6.9386e-01, 1.6705e-01, 8.4456e-01, 7.3629e-01, 6.8664e-01, 7.0476e-01,
          1.4511e-01, 9.6826e-01, 4.6021e-01, 7.9869e-01, 6.3389e-01, 4.9370e-01,
          4.0169e-01, 9.2914e-01, 6.4539e-02, 1.3045e-01, 9.3815e-01, 5.5556e-01,
          9.0109e-02, 2.7734e-01, 5.8398e-01, 9.8506e-01, 5.8946e-01, 6.7270e-01,
          7.4010e-01, 3.3563e-01, 7.0808e-01, 2.5908e-01, 6.4426e-02, 2.0589e-01,
          1.1182e-01, 6.4991e-01, 1.6458e-01, 5.1466e-01, 3.1790e-01, 3.1662e-01,
          4.6003e-01, 7.2658e-05, 6.8131e-01, 4.0257e-01, 2.1960e-01, 6.4176e-02,
          4.2393e-01, 5.0610e-01, 8.0311e-01, 9.1126e-01, 5.8124e-02, 8.2545e-01,
          1.1527e-01, 9.8229e-01, 1.4608e-01, 5.4559e-01, 7.7621e-01, 2.0951e-01,
          2.9684e-01, 4.5306e-01, 6.7287e-01, 5.1184e-01, 3.2333e-01, 1.8697e-01,
          7.4090e-02, 2.4898e-01, 1.8556e-01, 5.2644e-01, 6.2876e-01, 1.7004e-01,
          7.1102e-01, 5.3458e-01, 4.6818e-01, 4.1697e-01, 6.8341e-02, 5.4751e-01,
          8.7448e-01, 9.4455e-02, 4.2057e-01, 5.4386e-02, 8.5642e-01, 7.0885e-01,
          7.5223e-01, 8.0065e-01, 4.6538e-01, 5.2892e-01, 9.6875e-01, 8.8735e-01,
          2.1390e-01, 8.2894e-01, 7.9962e-01, 4.7924e-01, 9.0915e-01, 2.4507e-01,
          2.3991e-01, 1.1877e-01, 6.4302e-01, 3.9908e-01, 7.8200e-01, 2.4800e-01,
          1.6455e-02, 6.6382e-01, 3.3685e-01, 1.1778e-01, 3.3815e-01, 4.6232e-02,
          4.8027e-01, 5.0941e-01, 5.5334e-01, 4.2506e-01, 7.1928e-02, 1.1600e-01,
          9.1019e-01, 2.2528e-01, 3.6414e-01, 9.2526e-01, 7.5269e-01, 9.7200e-01,
          7.4234e-01, 9.5537e-02, 3.4856e-01, 1.0327e-01, 8.6401e-01, 5.4606e-02,
          4.8839e-01, 4.1983e-01, 3.7903e-01, 4.2380e-01, 2.7893e-01, 8.7351e-01,
          1.6084e-01, 7.7649e-01, 8.2132e-01, 2.9091e-01, 6.7906e-01, 5.8851e-01,
          9.4994e-01, 5.1678e-01, 9.4358e-01, 9.8022e-01, 7.5631e-01, 6.6831e-02,
          3.1273e-01, 1.4209e-01, 3.2110e-02, 7.2068e-01, 1.4507e-01, 9.7812e-01,
          2.0234e-01, 7.8213e-02, 7.9383e-01, 6.1282e-01, 5.6981e-01, 7.5002e-01,
          2.5106e-01, 7.9204e-01, 9.4016e-01, 4.4273e-01, 5.7655e-01, 6.7380e-01,
          6.2354e-01, 3.5822e-02, 5.9152e-01, 1.4949e-01, 4.2843e-01, 8.0444e-01,
          6.0194e-01, 1.2333e-02, 5.5856e-01, 2.4781e-01, 8.5382e-01, 6.1294e-01,
          6.7451e-01, 7.2509e-01, 2.5884e-01, 6.0437e-01, 7.8035e-01, 8.3381e-01,
          6.1660e-01, 9.5430e-01, 7.6094e-01, 4.8242e-01, 1.0131e-01, 9.5591e-03,
          9.1491e-01, 7.5761e-01, 1.3844e-01, 5.5028e-01, 7.1723e-01, 9.5846e-01,
          2.3718e-01, 8.4171e-01, 2.7303e-01, 9.1117e-01, 2.8183e-01, 1.1714e-01,
          9.9398e-03, 1.5259e-01, 8.9360e-01, 6.0196e-01, 5.2251e-01, 8.1564e-01,
          5.6416e-01, 1.2118e-01, 2.2851e-01, 5.1707e-01, 4.8859e-02, 9.8628e-01,
          4.6055e-01, 8.0671e-01, 4.1338e-01, 5.4646e-01, 8.4825e-01, 3.2547e-01,
          1.2510e-01, 4.0173e-01, 1.4080e-01, 5.7927e-01, 8.0579e-02, 9.1139e-01,
          7.0695e-01, 7.8642e-01, 7.0261e-01, 7.8339e-01, 2.2369e-02, 8.6547e-01,
          2.5926e-02, 1.7647e-01, 6.8591e-02, 4.7940e-01, 7.7250e-01, 4.2383e-01,
          8.5070e-01, 7.0415e-01, 9.2839e-01, 3.3424e-01, 3.5805e-01, 5.0100e-01,
          2.0468e-01, 8.4265e-01, 5.9100e-01, 6.8993e-01, 9.5756e-01, 1.9643e-02,
          6.2030e-01, 4.3316e-01, 7.3525e-01, 5.6760e-01, 5.6718e-02, 4.7301e-01,
          9.9620e-01, 8.3605e-01, 3.7611e-01, 9.3674e-01, 3.2158e-02, 1.1762e-02,
          6.4258e-01, 7.9977e-01, 7.8799e-01, 2.5022e-01, 3.3654e-01, 5.9173e-01,
          2.8152e-01, 6.3955e-01, 8.6668e-01, 5.5823e-01, 8.4198e-01, 1.1537e-01,
          5.8704e-01, 1.2852e-01, 2.6263e-01, 8.5329e-01, 3.4093e-01, 3.8616e-01,
          9.7535e-01, 3.0644e-01, 7.8955e-01, 8.7749e-01, 4.4489e-01, 1.5628e-01,
          3.5852e-01, 6.9664e-01, 2.3336e-01, 2.2815e-01, 6.6574e-01, 3.1887e-01,
          4.6864e-01, 2.5437e-02, 4.4281e-01, 1.9499e-01, 5.3550e-01, 9.0887e-01,
          4.8409e-01, 7.6155e-01, 6.2040e-02, 2.0381e-01, 8.1838e-01, 1.2424e-01,
          9.9882e-01, 3.7003e-01, 6.1769e-01, 6.6185e-01, 1.4405e-02, 4.5349e-01,
          8.1813e-01, 5.3735e-01]]),
 'the': tensor([[7.1840e-01, 7.7338e-01, 8.7673e-01, 6.7132e-02, 3.8288e-02, 2.3667e-01,
          9.7728e-01, 3.8908e-01, 7.1819e-01, 2.7288e-01, 8.2627e-01, 8.3989e-01,
          4.4614e-01, 6.6495e-01, 9.4599e-01, 2.1418e-02, 1.4347e-01, 3.7995e-01,
          3.8929e-01, 5.1051e-01, 6.4330e-01, 6.0881e-01, 1.8741e-01, 5.5555e-01,
          7.2272e-01, 2.3085e-01, 6.2442e-01, 8.8043e-01, 6.9338e-01, 3.9472e-01,
          4.0467e-01, 4.5217e-01, 9.6289e-01, 1.8670e-01, 5.4521e-01, 1.0942e-01,
          2.9833e-01, 8.1512e-01, 2.5580e-01, 5.0049e-01, 6.3274e-01, 3.5642e-01,
          5.6065e-01, 5.4556e-01, 7.8556e-01, 3.5699e-01, 6.8016e-01, 1.2972e-01,
          3.6899e-01, 2.9032e-01, 9.7950e-01, 1.9539e-01, 4.8868e-01, 3.5121e-01,
          4.0649e-01, 1.4700e-01, 7.7626e-01, 1.4400e-01, 1.4335e-02, 8.5257e-01,
          5.4909e-01, 6.9023e-02, 9.3764e-01, 5.3508e-01, 9.8231e-01, 5.8992e-01,
          7.2729e-01, 4.9253e-01, 1.4412e-01, 2.1264e-01, 2.8610e-01, 4.3615e-01,
          1.0608e-01, 8.1135e-01, 2.0467e-01, 4.1088e-03, 2.5731e-02, 4.0456e-02,
          6.6288e-01, 6.9338e-01, 5.5029e-01, 2.0925e-01, 9.0276e-01, 1.1149e-01,
          8.6267e-02, 8.9037e-01, 1.2472e-01, 9.1535e-01, 4.5415e-01, 4.0818e-01,
          2.0396e-01, 6.1111e-01, 7.7290e-01, 3.6548e-01, 4.2381e-01, 7.2378e-02,
          5.6363e-01, 9.2064e-01, 7.1607e-01, 5.2356e-01, 3.7323e-01, 1.6813e-01,
          8.8534e-01, 4.0593e-01, 9.2707e-01, 6.5008e-01, 1.3383e-01, 4.7958e-01,
          1.9529e-01, 9.2881e-01, 9.7531e-01, 6.3649e-01, 4.8907e-01, 4.9299e-02,
          4.1259e-01, 8.2956e-01, 3.0565e-01, 9.1686e-01, 1.1644e-01, 5.1370e-01,
          1.2729e-01, 5.4562e-01, 8.2032e-01, 7.3722e-01, 5.2152e-01, 9.9628e-01,
          6.2727e-01, 6.5455e-01, 2.8654e-01, 1.6315e-01, 1.6087e-02, 7.6122e-01,
          1.4935e-01, 3.6509e-01, 7.9982e-02, 4.5024e-01, 3.0755e-01, 1.0502e-01,
          3.9230e-01, 9.6151e-01, 1.3762e-02, 7.9321e-01, 2.8358e-01, 9.8028e-01,
          7.0471e-01, 8.3850e-01, 8.0853e-01, 9.7964e-01, 9.1886e-01, 7.2591e-01,
          9.5182e-01, 5.9454e-01, 3.6842e-01, 8.2260e-01, 3.6362e-01, 5.1099e-01,
          4.6658e-01, 3.2152e-01, 7.6214e-01, 2.0699e-01, 1.4726e-01, 4.8888e-01,
          3.6083e-01, 3.2167e-02, 5.9175e-01, 3.5113e-01, 1.2702e-01, 5.1079e-01,
          2.7843e-01, 1.1979e-01, 8.7899e-01, 2.3941e-01, 1.4980e-01, 1.2037e-01,
          2.1316e-01, 3.9711e-01, 3.8577e-01, 3.1987e-01, 2.3427e-01, 5.8417e-01,
          3.7767e-01, 5.1054e-01, 7.7232e-02, 5.8387e-01, 6.8264e-01, 1.9045e-01,
          5.4768e-01, 5.5942e-01, 2.8774e-01, 7.9047e-01, 1.9944e-01, 9.4049e-01,
          3.2445e-01, 3.3877e-01, 6.3633e-01, 9.4373e-01, 1.2508e-01, 3.2732e-01,
          8.3756e-01, 7.6526e-01, 7.3927e-01, 4.0944e-01, 6.2295e-01, 7.2694e-01,
          8.2920e-01, 2.4822e-01, 4.0071e-01, 1.1894e-01, 4.0206e-01, 9.2455e-01,
          5.6509e-01, 2.3390e-01, 3.5575e-01, 9.6215e-01, 3.1350e-01, 7.7651e-01,
          9.4187e-01, 2.5965e-01, 9.2044e-01, 6.5031e-01, 8.2053e-01, 8.4794e-01,
          5.4828e-01, 4.0185e-02, 1.7616e-01, 5.2184e-01, 2.3190e-01, 1.8778e-01,
          9.2054e-01, 4.8606e-01, 5.1361e-01, 9.1763e-01, 8.1631e-01, 5.4024e-01,
          8.4045e-01, 6.5916e-01, 8.4600e-01, 9.2082e-01, 7.6153e-01, 6.3522e-02,
          1.7671e-01, 5.6072e-01, 1.5905e-01, 7.2006e-01, 4.3391e-01, 9.1709e-01,
          4.3677e-01, 7.6592e-05, 7.2817e-01, 5.4623e-02, 3.7463e-01, 4.3651e-01,
          9.1184e-01, 4.8384e-01, 5.0687e-03, 7.2494e-01, 1.2999e-01, 9.8654e-01,
          3.3011e-01, 5.4647e-01, 6.7433e-01, 4.0667e-01, 4.9045e-01, 9.6573e-01,
          8.2738e-01, 5.5243e-01, 1.7718e-01, 9.0797e-01, 8.8965e-01, 2.0178e-01,
          9.5757e-02, 2.8120e-01, 4.9798e-01, 1.5326e-01, 8.3055e-01, 8.6457e-01,
          7.9969e-02, 1.7904e-01, 7.8188e-02, 8.9807e-01, 5.8747e-01, 8.3733e-01,
          3.8565e-01, 1.5266e-01, 8.8976e-01, 6.2725e-01, 1.0543e-01, 7.7315e-01,
          7.3963e-01, 2.1477e-01, 7.6836e-01, 9.2767e-01, 6.1648e-01, 9.2392e-01,
          3.3208e-01, 8.9376e-01, 5.6489e-01, 3.0883e-01, 4.9281e-01, 9.3961e-01,
          8.1981e-01, 6.3638e-01, 8.8745e-01, 5.4526e-01, 4.9053e-01, 7.9725e-01,
          5.2858e-01, 2.7779e-01, 1.8729e-01, 8.3933e-02, 3.6367e-01, 5.5100e-01,
          2.8268e-01, 6.8347e-01, 6.8670e-01, 4.3244e-01, 4.2739e-02, 4.4663e-01,
          1.8882e-01, 4.3310e-01, 3.4636e-01, 6.5152e-01, 8.9151e-01, 3.1886e-01,
          1.5738e-01, 2.2903e-01, 8.2355e-01, 4.3578e-01, 6.9330e-01, 2.0353e-01,
          9.2070e-01, 1.6342e-01, 3.0455e-01, 7.4123e-01, 7.7341e-01, 7.8196e-01,
          9.7344e-01, 2.4251e-01, 6.8957e-01, 2.6401e-01, 6.3643e-01, 8.7621e-02,
          2.3297e-01, 5.6138e-01, 1.5684e-01, 8.0387e-01, 7.6650e-02, 2.6112e-01,
          8.8210e-01, 3.6256e-01, 9.9841e-01, 5.9225e-01, 9.5941e-01, 3.9421e-01,
          1.4235e-01, 2.3071e-01, 6.4362e-01, 9.4433e-01, 8.2584e-01, 9.5157e-01,
          5.0700e-01, 1.8443e-01, 7.3655e-01, 6.6424e-01, 5.9806e-01, 6.0882e-02,
          1.8460e-01, 8.8181e-01, 8.5082e-01, 6.6670e-01, 9.6290e-01, 2.5692e-01,
          4.1782e-01, 3.8015e-01, 9.2386e-02, 4.4776e-01, 8.8147e-01, 3.5820e-01,
          5.5182e-01, 6.2998e-01, 7.7095e-01, 1.9776e-01, 2.9303e-01, 6.4706e-01,
          6.1055e-01, 1.4056e-01, 7.0018e-01, 6.9095e-01, 8.0863e-01, 4.3507e-01,
          4.9761e-02, 1.7040e-01, 7.6384e-01, 3.3561e-01, 1.3727e-01, 4.4176e-01,
          4.5524e-02, 9.8697e-03, 3.5021e-01, 8.4212e-01, 3.9430e-01, 7.2686e-01,
          8.4674e-01, 4.9250e-01, 5.3553e-01, 4.7532e-01, 3.8304e-01, 4.6372e-01,
          3.4261e-01, 5.3908e-01, 3.7357e-01, 3.7999e-02, 9.4951e-01, 1.7841e-02,
          4.6648e-01, 2.9647e-01, 5.3898e-01, 4.2181e-01, 8.7135e-01, 4.3591e-02,
          7.9290e-01, 3.6870e-01, 4.8555e-01, 9.3457e-02, 7.4536e-01, 4.8156e-01,
          2.3233e-01, 3.2334e-01, 9.9272e-01, 4.7335e-01, 7.1260e-02, 6.1609e-01,
          6.9998e-01, 1.1199e-01, 6.2197e-01, 1.6617e-01, 7.3050e-01, 2.8466e-01,
          5.9996e-01, 1.9168e-01, 2.4003e-02, 6.5563e-01, 4.9900e-01, 2.5101e-01,
          8.3981e-01, 7.9246e-01, 7.7300e-01, 1.9648e-01, 3.8993e-01, 6.4507e-01,
          8.0284e-01, 3.1769e-02, 1.4212e-01, 6.0615e-01, 5.5965e-01, 1.0864e-01,
          8.7508e-01, 9.5445e-02, 7.5977e-01, 6.7958e-02, 4.5472e-01, 5.4031e-01,
          4.0682e-01, 4.9183e-01, 3.6144e-01, 2.1030e-01, 7.8168e-01, 6.2322e-01,
          7.6627e-01, 6.2419e-01, 3.6694e-01, 9.0619e-01, 1.0839e-01, 5.8833e-01,
          4.4795e-01, 4.9384e-02, 7.6691e-01, 2.7266e-01, 5.8624e-01, 3.7922e-01,
          5.5803e-01, 6.2642e-01, 4.8768e-01, 7.3950e-01, 7.6407e-01, 1.2518e-01,
          1.8207e-01, 1.0997e-01, 9.6424e-01, 2.2704e-01, 3.4412e-01, 6.5359e-01,
          8.0274e-01, 1.3493e-01, 2.9178e-03, 9.1747e-01, 1.8471e-01, 8.8337e-01,
          1.3594e-01, 6.4136e-01, 3.6794e-01, 8.9989e-02, 4.8156e-01, 6.0371e-01,
          2.5130e-02, 5.9591e-01, 6.9674e-01, 3.7529e-01, 2.5596e-01, 2.2914e-01,
          6.0645e-01, 8.1164e-01]]),
 'best': tensor([[4.0259e-01, 2.6733e-01, 6.6947e-01, 3.0438e-01, 3.5235e-01, 1.5854e-01,
          2.4253e-01, 6.1593e-01, 2.9808e-01, 8.4712e-01, 3.9959e-01, 8.9896e-01,
          9.1117e-01, 4.6500e-01, 1.8770e-01, 1.2655e-01, 3.4765e-01, 6.3233e-01,
          7.9288e-01, 1.1200e-01, 3.2495e-01, 2.9242e-01, 9.4287e-01, 6.7922e-01,
          1.6780e-01, 3.7313e-02, 5.1175e-01, 1.4107e-01, 9.0274e-01, 1.0418e-01,
          9.5133e-01, 6.0633e-02, 8.2186e-01, 4.2384e-01, 4.4708e-01, 7.4151e-01,
          2.2214e-01, 8.9054e-02, 5.7131e-01, 5.0907e-01, 6.4871e-01, 6.7990e-02,
          7.0874e-01, 7.3679e-01, 7.8155e-01, 4.4904e-02, 9.2028e-01, 4.4048e-01,
          4.5903e-02, 2.2152e-01, 4.6999e-01, 8.3788e-01, 1.7136e-01, 8.2537e-01,
          4.6803e-01, 7.7024e-01, 4.0445e-01, 3.6793e-01, 9.3816e-01, 7.0189e-01,
          3.7520e-01, 4.8734e-01, 5.7509e-02, 9.8492e-01, 9.6236e-01, 7.1405e-01,
          2.9598e-01, 3.3727e-01, 5.1162e-01, 3.8649e-01, 1.4404e-01, 4.2157e-01,
          7.1847e-01, 2.7349e-01, 6.9235e-01, 8.3867e-01, 6.0890e-01, 7.8846e-01,
          9.8521e-01, 5.4686e-01, 1.4134e-01, 6.2728e-01, 5.6019e-01, 3.1996e-01,
          5.3998e-01, 3.3028e-02, 9.0850e-01, 3.8336e-01, 4.8997e-01, 5.0786e-02,
          4.0996e-01, 8.6342e-01, 5.9602e-01, 5.0532e-02, 9.3803e-01, 3.4595e-01,
          7.8881e-01, 5.0838e-01, 6.7648e-01, 3.4233e-01, 4.6434e-01, 8.6880e-01,
          4.8683e-02, 1.1219e-01, 6.9597e-01, 2.5642e-01, 9.2668e-01, 9.9679e-02,
          7.8016e-01, 4.9688e-02, 8.3544e-01, 8.9185e-01, 5.0079e-01, 9.8121e-02,
          1.3874e-01, 4.1524e-01, 3.6530e-01, 5.6914e-01, 1.7391e-01, 4.9259e-01,
          8.5920e-02, 8.1704e-01, 7.4253e-01, 6.4722e-01, 5.1642e-01, 4.4749e-01,
          7.8668e-01, 3.5312e-01, 3.8716e-01, 6.6233e-01, 5.8588e-01, 1.8122e-01,
          7.1552e-01, 9.2032e-01, 6.1778e-01, 7.5500e-01, 6.8285e-01, 6.3541e-01,
          9.4300e-01, 5.9813e-01, 6.0206e-01, 5.7449e-01, 8.5439e-01, 2.4260e-01,
          4.1800e-01, 6.9756e-01, 2.3530e-01, 6.3199e-01, 8.3007e-01, 4.5415e-01,
          7.7702e-01, 7.2982e-01, 1.5903e-01, 8.7198e-01, 9.5660e-01, 4.4293e-01,
          8.5772e-01, 2.7440e-01, 1.8630e-01, 7.0106e-02, 1.8563e-01, 6.3644e-01,
          8.8426e-02, 8.8935e-01, 7.7345e-01, 2.1890e-01, 5.7373e-01, 8.1057e-01,
          8.1280e-01, 3.6632e-01, 8.9838e-01, 6.3273e-01, 1.1158e-01, 2.2609e-01,
          8.4928e-02, 7.7841e-01, 8.2815e-01, 1.6464e-01, 1.5969e-01, 4.0018e-02,
          6.5016e-01, 9.7102e-01, 4.5821e-01, 9.4256e-01, 9.6083e-01, 2.4425e-01,
          7.6293e-02, 6.9570e-01, 8.7143e-01, 6.6577e-01, 9.5211e-01, 3.7303e-01,
          4.6168e-02, 6.6735e-01, 9.2935e-01, 1.5296e-01, 5.1717e-01, 9.5416e-01,
          4.4116e-01, 5.5678e-01, 1.8386e-01, 1.1745e-01, 5.7748e-01, 8.4284e-01,
          4.7008e-01, 7.3362e-01, 5.4890e-01, 6.0634e-01, 9.9449e-01, 2.7201e-01,
          4.4197e-01, 5.5730e-01, 6.1462e-01, 6.6679e-01, 7.6868e-02, 2.0764e-01,
          7.6925e-01, 8.6299e-01, 9.1116e-01, 2.4143e-01, 2.2274e-01, 8.4563e-01,
          2.8279e-01, 7.2999e-01, 6.9952e-01, 5.9561e-01, 8.9736e-01, 6.2344e-01,
          5.2049e-01, 3.0299e-02, 8.0778e-01, 9.1270e-01, 1.2186e-01, 7.1941e-01,
          9.2692e-01, 8.6148e-01, 6.9103e-01, 3.9094e-01, 9.2692e-01, 8.6996e-01,
          4.0286e-01, 6.9535e-01, 2.6749e-01, 5.2382e-02, 3.8269e-01, 7.5663e-02,
          8.7671e-01, 2.9795e-01, 4.8355e-01, 4.4616e-02, 6.9960e-01, 5.9223e-01,
          4.9434e-01, 3.2469e-01, 7.8462e-01, 3.9616e-03, 6.2455e-01, 2.3131e-01,
          3.2802e-01, 5.5399e-01, 9.9225e-01, 9.2511e-01, 4.7620e-01, 2.0139e-01,
          4.9722e-01, 9.3483e-01, 3.2223e-01, 6.2205e-01, 5.3858e-01, 4.4507e-02,
          5.3362e-01, 6.1940e-01, 6.6380e-01, 1.0363e-01, 4.1321e-02, 3.8392e-02,
          2.5734e-01, 3.7534e-01, 4.2419e-01, 2.9276e-01, 3.6934e-01, 9.4461e-01,
          4.7225e-02, 5.7827e-01, 3.5153e-01, 6.0510e-01, 8.8451e-01, 7.8487e-01,
          5.7638e-01, 9.3526e-01, 9.2568e-01, 4.4647e-01, 6.8781e-01, 2.4766e-01,
          8.4528e-01, 3.3887e-01, 9.9332e-01, 3.1592e-01, 1.0693e-01, 5.3931e-01,
          6.0763e-01, 8.8199e-01, 1.7218e-01, 3.5889e-03, 3.0559e-01, 9.3777e-01,
          3.5501e-01, 2.5272e-01, 6.2234e-01, 5.5792e-01, 4.5468e-01, 4.0948e-01,
          2.9055e-01, 9.8439e-03, 2.4336e-01, 8.1169e-01, 4.1875e-01, 6.7015e-01,
          9.3507e-03, 2.5158e-01, 9.6145e-01, 1.3710e-01, 6.3734e-01, 7.3025e-01,
          9.1482e-01, 7.8041e-01, 9.1661e-01, 3.8479e-01, 2.9627e-01, 4.9819e-01,
          4.2965e-01, 5.6702e-01, 9.1540e-01, 2.4187e-01, 4.8407e-01, 4.2569e-01,
          9.3144e-01, 2.5211e-01, 3.2822e-01, 5.9368e-01, 2.8892e-01, 5.0113e-01,
          2.9161e-01, 2.9295e-01, 7.7374e-01, 3.5251e-01, 6.1450e-01, 2.1963e-01,
          3.9493e-01, 5.8554e-01, 8.6649e-02, 4.8263e-01, 8.8710e-01, 7.8400e-01,
          1.6956e-03, 7.1972e-01, 6.1527e-01, 5.3908e-02, 1.2874e-01, 7.6425e-01,
          2.7499e-02, 7.5466e-01, 1.2516e-01, 9.7461e-01, 7.7658e-01, 7.3800e-02,
          7.8404e-01, 8.3843e-02, 7.3888e-01, 9.8272e-01, 6.1329e-02, 2.4590e-01,
          8.0396e-01, 9.4918e-01, 2.8794e-01, 2.4190e-01, 6.9178e-02, 6.6753e-01,
          7.1654e-01, 6.5312e-01, 7.3100e-01, 5.4497e-01, 1.3812e-01, 5.2004e-01,
          3.6246e-01, 7.0083e-01, 4.0086e-01, 4.7319e-02, 1.5802e-01, 9.0083e-01,
          3.7054e-01, 9.4261e-02, 5.1712e-01, 9.9814e-01, 5.4027e-01, 8.7047e-01,
          2.9337e-01, 7.5962e-02, 7.5880e-01, 5.9591e-01, 5.2860e-01, 1.2146e-01,
          4.3342e-01, 9.8757e-01, 2.8284e-01, 2.7977e-01, 4.4233e-01, 7.7782e-01,
          8.7334e-01, 4.7263e-01, 5.8882e-01, 2.0797e-01, 7.3449e-01, 9.9074e-01,
          9.1840e-01, 3.6205e-01, 4.4807e-02, 8.3154e-01, 1.5899e-01, 1.9707e-01,
          8.0918e-01, 3.7621e-01, 3.9985e-01, 6.1168e-01, 4.6606e-01, 9.4034e-01,
          4.3830e-01, 1.7292e-01, 7.1325e-01, 6.4762e-01, 2.4917e-01, 8.4136e-01,
          5.3644e-01, 2.2851e-01, 3.1175e-02, 7.1844e-01, 2.5537e-01, 1.8867e-01,
          6.5454e-01, 4.3280e-01, 9.8126e-01, 6.5507e-01, 7.6583e-01, 2.4475e-01,
          2.0108e-01, 6.8190e-01, 8.6953e-01, 5.8021e-01, 2.0249e-01, 3.2345e-01,
          9.8828e-01, 1.9553e-01, 6.8106e-01, 9.3523e-01, 5.0367e-02, 6.2485e-01,
          2.2792e-01, 4.9810e-01, 4.1539e-01, 8.3473e-01, 7.5163e-01, 7.5147e-01,
          7.5522e-02, 7.8126e-01, 9.6220e-01, 9.3789e-01, 3.2885e-01, 5.0421e-01,
          1.8500e-01, 4.5145e-01, 3.5509e-01, 1.1354e-01, 4.6393e-01, 3.1882e-01,
          7.4172e-02, 8.7318e-02, 2.3920e-01, 9.7411e-02, 5.4623e-01, 3.4189e-04,
          3.2336e-01, 7.2842e-01, 8.2325e-02, 7.7559e-01, 2.4379e-01, 8.5758e-01,
          1.8102e-01, 2.3327e-01, 5.7683e-01, 6.1124e-01, 2.4198e-01, 7.2139e-02,
          6.1478e-02, 4.4682e-01, 1.1979e-01, 2.4950e-01, 9.1591e-02, 4.0937e-01,
          8.2345e-01, 1.9444e-01, 9.5916e-01, 1.2011e-01, 4.3093e-01, 2.6237e-01,
          1.0576e-01, 5.6762e-01, 8.3273e-01, 6.4544e-01, 1.9923e-01, 7.5069e-01,
          5.0031e-01, 9.2017e-01]])}

In [25]:
# exercise 04

"""
Training and testing the Transformer model

With the TransformerEncoder model in place, the next step at PyBooks is to train the model on sample reviews and evaluate its performance. Training on these sample reviews will help PyBooks understand the sentiment trends in their vast repository. By achieving a well-performing model, PyBooks can then automate sentiment analysis, ensuring readers get insightful recommendations and feedback.

The following packages have been imported for you: torch, nn, optim.

The model instance of the TransformerEncoder class, token_embeddings, and the train_sentences, train_labels ,test_sentences,test_labels are preloaded for you.
"""

# Instructions

"""

    In the training loop, split the sentences into tokens and stack the embeddings.

    Zero the gradients and perform a backward pass.

    In the predict function, deactivate the gradient computations then get the sentiment prediction.

"""

# solution

for epoch in range(5):  
    for sentence, label in zip(train_sentences, train_labels):
        # Split the sentences into tokens and stack the embeddings
        tokens = sentence.split()
        data = torch.stack([token_embeddings[token] for token in tokens], dim=1)
        output = model(data)
        loss = criterion(output, torch.tensor([label]))
        # Zero the gradients and perform a backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item()}")

def predict(sentence):
    model.eval()
    # Deactivate the gradient computations and get the sentiment prediction.
    with torch.no_grad():
        tokens = sentence.split()
        data = torch.stack([token_embeddings.get(token, torch.rand((1, 512))) for token in tokens], dim=1)
        output = model(data)
        predicted = torch.argmax(output, dim=1)
        return "Positive" if predicted.item() == 1 else "Negative"

sample_sentence = "This product can be better"
print(f"'{sample_sentence}' is {predict(sample_sentence)}")

#----------------------------------#

# Conclusion

"""
Excellent! You've successfully trained and tested the Transformer model. With such models in place, PyBooks' recommendation engine will provide even more accurate book recommendations based on user reviews. Well done!
"""

Epoch 0, Loss: 0.5908824801445007
Epoch 0, Loss: 13.509520530700684
Epoch 0, Loss: 4.571168899536133
Epoch 1, Loss: 5.174601078033447
Epoch 1, Loss: 0.28111374378204346
Epoch 1, Loss: 0.40869632363319397
Epoch 2, Loss: 1.8922356367111206
Epoch 2, Loss: 0.4770905673503876
Epoch 2, Loss: 0.6592643857002258
Epoch 3, Loss: 1.1888630390167236
Epoch 3, Loss: 0.5068765878677368
Epoch 3, Loss: 0.5032485127449036
Epoch 4, Loss: 1.4057996273040771
Epoch 4, Loss: 0.419771671295166
Epoch 4, Loss: 0.4097725450992584
'This product can be better' is Negative


"\nExcellent! You've successfully trained and tested the Transformer model. With such models in place, PyBooks' recommendation engine will provide even more accurate book recommendations based on user reviews. Well done!\n"

In [27]:
data = ['the cat sat on the mat', 'dogs are very loyal animals', 'parrots are colorful and noisy', 'whales are the largest mammals']
vocab = {'and', 'animals', 'are', 'cat', 'colorful', 'dogs', 'largest', 'loyal', 'mammals', 'mat', 'noisy', 'on', 'parrots', 'sat', 'the', 'very', 'whales'}
vocab_size = 17
word_to_ix = {'cat': 0, 'loyal': 1, 'very': 2, 'on': 3, 'sat': 4, 'are': 5, 'mat': 6, 'dogs': 7, 'the': 8, 'animals': 9, 'and': 10, 'noisy': 11, 'parrots': 12, 'colorful': 13, 'whales': 14, 'largest': 15, 'mammals': 16}
ix_to_word = {0: 'cat', 1: 'loyal', 2: 'very', 3: 'on', 4: 'sat', 5: 'are', 6: 'mat', 7: 'dogs', 8: 'the', 9: 'animals', 10: 'and', 11: 'noisy', 12: 'parrots', 13: 'colorful', 14: 'whales',
 15: 'largest', 16: 'mammals'}
input_data = [[8, 0, 4, 3, 8], [7, 5, 2, 1], [12, 5, 13, 10], [14, 5, 8, 15]]
target_data = [6, 9, 11, 16]
embedding_dim = 10 
hidden_dim = 16
pairs = [sentence.split() for sentence in data]

In [28]:
# exercise 05

"""
Creating a RNN model with attention

At PyBooks, the team has been exploring various deep learning architectures. After some research, you decide to implement an RNN with Attention mechanism to predict the next word in a sentence. You're given a dataset with sentences and a vocabulary created from them.

The following packages have been imported for you: torch, nn.

The following has been preloaded for you:

    vocab and vocab_size: The vocabulary set and its size
    word_to_ix and ix_to_word: dictionary for word to index and index to word mappings
    input_data and target_data: converted dataset to input-output pairs
    embedding_dim and hidden_dim: dimensions for embedding and RNN hidden state

You can inspect the data variable in the console to see the example sentences.
"""

# Instructions

"""

    Create an embedding layer for the vocabulary with the given embedding_dim.

    Apply a linear transformation to the RNN sequence output to get the attention scores.

    Get the attention weights from the score.

    Compute the context vector as the weighted sum of RNN outputs and attention weights

"""

# solution

class RNNWithAttentionModel(nn.Module):
    def __init__(self):
        super(RNNWithAttentionModel, self).__init__()
        # Create an embedding layer for the vocabulary
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        # Apply a linear transformation to get the attention scores
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, vocab_size)
    def forward(self, x):
        x = self.embeddings(x)
        out, _ = self.rnn(x)
        #  Get the attention weights
        attn_weights = torch.nn.functional.softmax(self.attention(out).squeeze(2), dim=1)
        # Compute the context vector 
        context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1)
        out = self.fc(context)
        return out
      
attention_model = RNNWithAttentionModel()
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
print("Model Instantiated")

#----------------------------------#

# Conclusion

"""
Great job! You've incorporated the attention mechanism into the RNN, which helps the model to focus on relevant parts of the input when making a prediction. Let's see how it performs in the next exercise!
"""

Model Instantiated


"\nGreat job! You've incorporated the attention mechanism into the RNN, which helps the model to focus on relevant parts of the input when making a prediction. Let's see how it performs in the next exercise!\n"

In [29]:
vocab = set(' '.join(data).split())
word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {i: word for word, i in word_to_ix.items()}
pairs = [sentence.split() for sentence in data]
input_data = [[word_to_ix[word] for word in sentence[:-1]] for sentence in pairs]
target_data = [word_to_ix[sentence[-1]] for sentence in pairs]
inputs = [torch.tensor(seq, dtype=torch.long) for seq in input_data]
targets = torch.tensor(target_data, dtype=torch.long)

In [31]:
def pad_sequences(batch):
    max_len = max([len(seq) for seq in batch])
    return torch.stack([torch.cat([seq, torch.zeros(max_len-len(seq)).long()]) for seq in batch])

In [36]:
import torch
import torch.nn as nn

class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super(RNNModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.embeddings(x)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x

# Example usage
vocab_size = 17
embedding_dim = 10
hidden_dim = 16
output_dim = 17

rnn_model = RNNModel(vocab_size, embedding_dim, hidden_dim, output_dim)
print(model)

RNNModel(
  (embeddings): Embedding(17, 10)
  (rnn): RNN(10, 16, batch_first=True)
  (fc): Linear(in_features=16, out_features=17, bias=True)
)


In [42]:
# exercise 06

"""
Training and testing the RNN model with attention

At PyBooks, the team had previously built an RNN model for word prediction without the attention mechanism. This initial model, referred to as rnn_model, has already been trained and its instance is preloaded. Your task now is to train the new RNNWithAttentionModel and compare its predictions with that of the earlier rnn_model.

The following has been preloaded for you:

    inputs: list of input sequences as tensors
    targets: tensor containing target words for each input sequence
    optimizer: Adam optimizer function
    criterion: CrossEntropyLoss function
    pad_sequences: function to pad input sequences for batching
    attention_model: defined model class from the previous exercise
    rnn_model:trained RNN model from the team at PyBooks

"""

# Instructions

"""

    Set the RNN model to evaluation mode before testing it with the test data.

    Get the RNN output by passing the appropriate input to the RNN model.

    Extract the word with the highest prediction score from the RNN output.

    Similarly, for the attention model, extract the word with the highest prediction score from the attention output.

"""

# solution
epochs = 300
for epoch in range(epochs):
    attention_model.train()
    optimizer.zero_grad()
    padded_inputs = pad_sequences(inputs)
    outputs = attention_model(padded_inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

for input_seq, target in zip(input_data, target_data):
    input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
   	
    #  Set the RNN model to evaluation mode
    rnn_model.eval()
    # Get the RNN output by passing the appropriate input 
    rnn_output = rnn_model(input_test)
    # Extract the word with the highest prediction score 
    rnn_prediction = ix_to_word[torch.argmax(rnn_output).item()]

    attention_model.eval()
    attention_output = attention_model(input_test)
    # Extract the word with the highest prediction score
    attention_prediction = ix_to_word[torch.argmax(attention_output).item()]

    print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
    print(f"Target: {ix_to_word[target]}")
    print(f"RNN prediction: {rnn_prediction}")
    print(f"RNN with Attention prediction: {attention_prediction}")

#----------------------------------#

# Conclusion

"""
Excellent! With the attention mechanism, the model can potentially focus better on the context of the input sequence when predicting the next word compared to RNN. The results make it clear that using Attention layer has improved the results significantly!
"""

KeyError: 57

# Adversarial attack classification

Imagine you're a Data Scientist on a mission to safeguard machine learning models from malicious attacks. In order to do so, you need to be aware of the different attacks that you and your model could encounter. Being aware of these vulnerabilities will allow you to protect your models against these adversarial threats.

![Answer](/home/nero/Documents/Estudos/DataCamp/Python/courses/deep-learning-for-text-with-pytorch/ch_04.png)

# Safeguarding AI at PyBooks

You are responsible for implementing a new chatbot at PyBooks to assist users on the platform. However, you've been alerted about potential adversarial attacks against text-based AI systems. To ensure the chatbot remains robust against such threats, you decide to research various strategies and techniques.

Which of the following approaches would be LEAST effective in safeguarding the chatbot against adversarial attacks?

### Possible Answers


    Implementing regularization to prevent the model from overfitting on the training data
    
    
    Utilizing gradient masking to make gradient information less usable for potential attackers
    
    
    Increasing the chatbot's response speed by using a high-performance GPU {Answer}