In [10]:
# Import the necessary functions
# !pip install torchtext

import nltk
import torch
from torchtext.data.utils import get_tokenizer
from nltk.probability import FreqDist
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords


text = "In the city of Dataville, a data analyst named Alex explores hidden insights within vast data. With determination, Alex uncovers patterns, cleanses the data, and unlocks innovation. Join this adventure to unleash the power of data-driven decisions."

# Initialize the tokenizer and tokenize the text
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(text)

tokens

['in',
 'the',
 'city',
 'of',
 'dataville',
 ',',
 'a',
 'data',
 'analyst',
 'named',
 'alex',
 'explores',
 'hidden',
 'insights',
 'within',
 'vast',
 'data',
 '.',
 'with',
 'determination',
 ',',
 'alex',
 'uncovers',
 'patterns',
 ',',
 'cleanses',
 'the',
 'data',
 ',',
 'and',
 'unlocks',
 'innovation',
 '.',
 'join',
 'this',
 'adventure',
 'to',
 'unleash',
 'the',
 'power',
 'of',
 'data-driven',
 'decisions',
 '.']

In [5]:
threshold = 1
# Remove rare words and print common tokens
freq_dist = FreqDist(tokens)
common_tokens = [token for token in tokens if freq_dist[token] > threshold]
print(common_tokens)

['the', 'of', ',', 'data', 'alex', 'data', '.', ',', 'alex', ',', 'the', 'data', ',', '.', 'the', 'of', '.']


In [6]:
text = 'The moor is very sparsely inhabited, and those who live near each other are thrown very much together. For this reason I saw a good deal of Sir Charles Baskerville. With the exception of Mr. Frankland, of Lafter Hall, and Mr. Stapleton, the naturalist, there are no other men of education within many miles. Sir Charles was a retiring man, but the chance of his illness brought us together, and a community of interests in science kept us so. He had brought back much scientific information from South Africa, and many a charming evening we have spent together discussing the comparative anatomy of the Bushman and the Hottentot.'

In [8]:
# Initialize and tokenize the text
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(text)

tokens[:10]

['the',
 'moor',
 'is',
 'very',
 'sparsely',
 'inhabited',
 ',',
 'and',
 'those',
 'who']

In [11]:
# Remove any stopwords
stop_words = set(stopwords.words("english"))
filtered_tokens = [token for token in tokens if token.lower() not in stop_words]

filtered_tokens

['city',
 'dataville',
 ',',
 'data',
 'analyst',
 'named',
 'alex',
 'explores',
 'hidden',
 'insights',
 'within',
 'vast',
 'data',
 '.',
 'determination',
 ',',
 'alex',
 'uncovers',
 'patterns',
 ',',
 'cleanses',
 'data',
 ',',
 'unlocks',
 'innovation',
 '.',
 'join',
 'adventure',
 'unleash',
 'power',
 'data-driven',
 'decisions',
 '.']

In [12]:
# Perform stemming on the filtered tokens
stemmer = PorterStemmer()
stemmed_tokens = [stemmer.stem(token) for token in filtered_tokens]
print(stemmed_tokens)

['citi', 'datavil', ',', 'data', 'analyst', 'name', 'alex', 'explor', 'hidden', 'insight', 'within', 'vast', 'data', '.', 'determin', ',', 'alex', 'uncov', 'pattern', ',', 'cleans', 'data', ',', 'unlock', 'innov', '.', 'join', 'adventur', 'unleash', 'power', 'data-driven', 'decis', '.']


In [13]:
genres = ['Fiction','Non-fiction','Biography', 'Children','Mystery']

# Define the size of the vocabulary
vocab_size = len(genres)

# Create one-hot vectors
one_hot_vectors = torch.eye(vocab_size)

# Create a dictionary mapping genres to their one-hot vectors
one_hot_dict = {genre: one_hot_vectors[i] for i, genre in enumerate(genres)}

one_hot_dict

{'Fiction': tensor([1., 0., 0., 0., 0.]),
 'Non-fiction': tensor([0., 1., 0., 0., 0.]),
 'Biography': tensor([0., 0., 1., 0., 0.]),
 'Children': tensor([0., 0., 0., 1., 0.]),
 'Mystery': tensor([0., 0., 0., 0., 1.])}

In [14]:
for genre, vector in one_hot_dict.items():
    print(f'{genre}: {vector.numpy()}')

Fiction: [1. 0. 0. 0. 0.]
Non-fiction: [0. 1. 0. 0. 0.]
Biography: [0. 0. 1. 0. 0.]
Children: [0. 0. 0. 1. 0.]
Mystery: [0. 0. 0. 0. 1.]


In [15]:
# Import from sklearn
from sklearn.feature_extraction.text import CountVectorizer

titles = ['The Great Gatsby','To Kill a Mockingbird','1984','The Catcher in the Rye','The Hobbit', 'Great Expectations']

# Initialize Bag-of-words with the list of book titles
vectorizer = CountVectorizer()
bow_encoded_titles = vectorizer.fit_transform(titles)

bow_encoded_titles

<6x12 sparse matrix of type '<class 'numpy.int64'>'
	with 15 stored elements in Compressed Sparse Row format>

In [16]:
# Extract and print the first five features
print(vectorizer.get_feature_names_out()[:5])
print(bow_encoded_titles.toarray()[0, :5])

['1984' 'catcher' 'expectations' 'gatsby' 'great']
[0 0 0 1 1]


In [18]:
print(vectorizer.get_feature_names_out())
print(bow_encoded_titles.toarray())

['1984' 'catcher' 'expectations' 'gatsby' 'great' 'hobbit' 'in' 'kill'
 'mockingbird' 'rye' 'the' 'to']
[[0 0 0 1 1 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1 1 0 0 1]
 [1 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 1 0 0 1 2 0]
 [0 0 0 0 0 1 0 0 0 0 1 0]
 [0 0 1 0 1 0 0 0 0 0 0 0]]


In [20]:
descriptions = ['A portrait of the Jazz Age in all of its decadence and excess.',
 'A gripping, heart-wrenching, and wholly remarkable tale of coming-of-age in a South poisoned by virulent prejudice.',
 'A startling and haunting vision of the world.',
 'A story of lost innocence.',
 'A timeless adventure story.']

In [21]:
# Importing TF-IDF from sklearn
from sklearn.feature_extraction.text import TfidfVectorizer

# Initialize TF-IDF encoding vectorizer
vectorizer = TfidfVectorizer()
tfidf_encoded_descriptions = vectorizer.fit_transform(descriptions)

tfidf_encoded_descriptions

<5x32 sparse matrix of type '<class 'numpy.float64'>'
	with 41 stored elements in Compressed Sparse Row format>

In [22]:
# Extract and print the first five features
print(vectorizer.get_feature_names_out()[:5])
print(tfidf_encoded_descriptions.toarray()[0, :5])

['adventure' 'age' 'all' 'and' 'by']
[0.         0.25943581 0.321564   0.21535516 0.        ]


In [23]:
print(vectorizer.get_feature_names_out())
print(tfidf_encoded_descriptions.toarray())

['adventure' 'age' 'all' 'and' 'by' 'coming' 'decadence' 'excess'
 'gripping' 'haunting' 'heart' 'in' 'innocence' 'its' 'jazz' 'lost' 'of'
 'poisoned' 'portrait' 'prejudice' 'remarkable' 'south' 'startling'
 'story' 'tale' 'the' 'timeless' 'virulent' 'vision' 'wholly' 'world'
 'wrenching']
[[0.         0.25943581 0.321564   0.21535516 0.         0.
  0.321564   0.321564   0.         0.         0.         0.25943581
  0.         0.321564   0.321564   0.         0.36232709 0.
  0.321564   0.         0.         0.         0.         0.
  0.         0.25943581 0.         0.         0.         0.
  0.         0.        ]
 [0.         0.20817488 0.         0.17280396 0.2580274  0.2580274
  0.         0.         0.2580274  0.         0.2580274  0.20817488
  0.         0.         0.         0.         0.29073627 0.2580274
  0.         0.2580274  0.2580274  0.2580274  0.         0.
  0.2580274  0.         0.         0.2580274  0.         0.2580274
  0.         0.2580274 ]
 [0.         0.       

In [24]:
# Create a list of stopwords
stop_words = set(stopwords.words("english"))

# Initialize the tokenizer and stemmer
tokenizer = get_tokenizer("basic_english")
stemmer = PorterStemmer() 

# Complete the function to preprocess sentences
def preprocess_sentences(sentences):
    processed_sentences = []
    for sentence in sentences:
        sentence = sentence.lower()
        tokens = tokenizer(sentence)
        tokens = [token for token in tokens if token not in stop_words]
        tokens = [stemmer.stem(token) for token in tokens]
        processed_sentences.append(' '.join(tokens))
    return processed_sentences


In [36]:
with open('shakespeare.txt', 'r', encoding='utf-8') as file:
    shakespeare = file.read().split('.')
    


In [56]:
shakespeare[:5]

['The Project Gutenberg eBook of The Complete Works of William Shakespeare, by William Shakespeare\n\nThis eBook is for the use of anyone anywhere in the United States and\nmost other parts of the world at no cost and with almost no restrictions\nwhatsoever',
 ' You may copy it, give it away or re-use it under the terms\nof the Project Gutenberg License included with this eBook or online at\nwww',
 'gutenberg',
 'org',
 ' If you are not located in the United States, you\nwill have to check the laws of the country where you are located before\nusing this eBook']

In [58]:
processed_shakespeare = preprocess_sentences(shakespeare)
print(processed_shakespeare[:5]) 

['', '', '', '', '']


In [59]:
# Import libraries
from torch.utils.data import Dataset, DataLoader

# Define your Dataset class
class ShakespeareDataset(Dataset):
    # it is mandatory to define these three methods when you extend the Dataset class in PyTorch.
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]


In [83]:
# threshold to remove all tokens that appear less than or equal to threshold times in the sentence

filter_threshod = 1

In [84]:
def preprocess_sentences(sentences):
    processed_sentences = []
    for sentence in sentences:
        sentence = sentence.lower()
        tokens = tokenizer(sentence)
        tokens = [token for token in tokens if token not in stop_words]
        tokens = [stemmer.stem(token) for token in tokens]
        freq_dist = FreqDist(tokens)
        threshold = filter_threshod
        tokens = [token for token in tokens if freq_dist[token] > threshold]
        processed_sentences.append(' '.join(tokens))
    return processed_sentences


In [85]:
def encode_sentences(sentences):
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(sentences)
    encoded_sentences = X.toarray()
    return encoded_sentences, vectorizer


In [79]:
def extract_sentences(data):
    sentences = re.findall(r'[A-Z][^.!?]*[.!?]', data)
    return sentences


In [86]:
    
# Complete the text processing pipeline
def text_processing_pipeline(sentences):
    processed_sentences = preprocess_sentences(sentences)
    encoded_sentences, vectorizer = encode_sentences(processed_sentences)
    dataset = ShakespeareDataset(encoded_sentences)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    return dataloader, vectorizer


In [91]:
dataloader, vectorizer = text_processing_pipeline(processed_shakespeare)

# Print the vectorizer's feature names and the first 100 components of the first item
print(vectorizer.get_feature_names_out()[:100]) 
next(iter(dataloader))[0, :100]


['_hic' '_solus_' 'abat' 'abhorson' 'accommod' 'adder' 'adieu' 'aemiliu'
 'afterward' 'age' 'agreement' 'alack' 'alexa' 'alexand' 'andronicu'
 'angel' 'angelo' 'anoth' 'another' 'antoni' 'antonio' 'ari' 'ariel' 'arm'
 'art' 'ass' 'athen' 'attend' 'aufidiu' 'awak' 'away' 'ay' 'back' 'bad'
 'banish' 'baptista' 'barren' 'base' 'bassanio' 'bastard' 'bawd' 'bear'
 'beard' 'beast' 'beaufort' 'beauti' 'bed' 'better' 'bianca' 'bid'
 'blanch' 'blood' 'blow' 'boar' 'bolingbrok' 'book' 'born' 'boy'
 'brabantio' 'brave' 'break' 'breath' 'britain' 'broken' 'brook' 'brother'
 'brutu' 'buck' 'build' 'buy' 'cade' 'caesar' 'caiu' 'call' 'camillo'
 'cannot' 'caphi' 'captain' 'capulet' 'care' 'cargo' 'carri' 'cassio'
 'cassiu' 'catch' 'cau' 'cawdor' 'ceremoni' 'challeng' 'charl' 'cheek'
 'child' 'choo' 'clarenc' 'claudio' 'cleopatra' 'clifford' 'clink' 'coach'
 'come']


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

In [146]:
# Map a unique index to each word
words = ["This", "book", "was", "fantastic", "I", "really", "love", "science", "fiction", "but", "the", "protagonist", "was", "rude", "sometimes"]

word_to_idx = {word: i for i, word in enumerate(words)}

word_to_idx

{'This': 0,
 'book': 1,
 'was': 12,
 'fantastic': 3,
 'I': 4,
 'really': 5,
 'love': 6,
 'science': 7,
 'fiction': 8,
 'but': 9,
 'the': 10,
 'protagonist': 11,
 'rude': 13,
 'sometimes': 14}

In [147]:
# Convert word_to_idx to a tensor
inputs = torch.LongTensor([word_to_idx[w] for w in words])

inputs

tensor([ 0,  1, 12,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

In [148]:
# Pass the tensor to the embedding layer (creates an embedding of 10 values for each of the 15 words)
output = embedding(inputs)
print(output)

tensor([[ 3.2591e-01,  3.8871e-01, -1.0202e+00, -5.4242e-01, -3.8684e-02,
          1.1134e+00, -8.0390e-01,  4.2076e-01, -2.5217e+00, -1.4594e+00],
        [ 7.5349e-01, -1.7978e+00,  7.3524e-01,  2.4658e-01,  4.0211e-01,
          9.9624e-02,  1.7329e+00,  3.7023e-01, -6.5028e-01, -1.2494e+00],
        [ 1.3734e-01,  7.9268e-01,  9.0330e-01, -1.7315e+00,  7.3597e-02,
         -4.3615e-01,  3.9540e-02, -6.2524e-01, -6.6301e-01,  7.3274e-01],
        [ 2.5612e-01,  4.8039e-01,  1.1930e-01,  1.2893e+00,  7.9125e-01,
          2.8376e-02, -8.4166e-01,  2.0709e+00, -2.0737e+00,  1.3860e-01],
        [ 1.1167e-01,  2.1041e-01,  7.5730e-01,  3.9823e-01, -4.8729e-01,
          1.1858e-01, -9.9135e-01,  1.2653e+00, -5.5789e-01,  1.4413e-01],
        [ 6.9417e-01,  2.9416e-01, -1.0067e+00, -8.5463e-01, -1.0690e+00,
         -7.6825e-02,  9.2565e-01,  6.5659e-01,  6.8181e-01, -4.4510e-01],
        [-5.9533e-01,  1.1413e-02, -1.4163e+00,  5.8150e-01,  8.2895e-01,
         -6.7301e-01, -1.8067e-0

In [None]:
####################################################################################################################333

In [164]:
data = [(['I', 'love', 'this', 'first','book'], 1),
     (['This', 'is', 'an', 'amazing', 'novel'], 1),
     (['I', 'really', 'like', 'this', 'story'], 1),
     (['I', 'do', 'not', 'like', 'this', 'book'], 0),
     (['I', 'hate', 'this', 'novel'], 0),
     (['This', 'is', 'a', 'terrible', 'story'], 0)]

In [165]:
# unique_words = []

# for sub_list, _ in data:
#     for word in sub_list:
#         unique_words.append(word)

# list(set(unique_words))

unique_words = list(set(word for sub_list, _ in data for word in sub_list))

vocab_size = len(unique_words)

vocab_size, unique_words

(18,
 ['an',
  'amazing',
  'story',
  'a',
  'this',
  'like',
  'really',
  'I',
  'novel',
  'hate',
  'terrible',
  'This',
  'first',
  'love',
  'do',
  'is',
  'book',
  'not'])

In [166]:
import torch.nn as nn

embed_dim = 10

# Initialize embedding layer with input and output dimensions
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

embedding

Embedding(18, 10)

In [187]:
import torch.nn.functional as F

class TextClassificationCNN(nn.Module):
    
    def __init__(self, vocab_size, embed_dim):
        super(TextClassificationCNN, self).__init__() # initializes the base class nn.Module
        # Initialize the embedding layer 
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(embed_dim, 2) # binary classification
        
#         the permute(0, 2, 1) method, it is used to rearrange the dimensions of the tensor to match the expected input of the convolutional layer. The numbers 0, 2, and 1 are the indices of the original dimensions.
# In PyTorch, the input to a Conv1d layer should be in the form (batch_size, num_channels, length). In your case:
# 0 corresponds to the batch size dimension (the number of sentences processed at once).
# 2 corresponds to the embedding dimension (the number of channels, which is the size of the word embeddings).
# 1 corresponds to the sequence length dimension (the length of each sentence).
        
    def forward(self, text):
        embedded = self.embedding(text).permute(0, 2, 1) # Match tensors to convolution layer's expected input (batch_size, num_channels, length)
        # Pass the embedded text through the convolutional layer and apply a ReLU
        conved = F.relu(self.conv(embedded)) # Extract important features with ReLU
        conved = conved.mean(dim=2) # Eliminate extra layers and dimensions
        return self.fc(conved)

In [188]:
word_to_ix = {word: i for i, word in enumerate(unique_words)}

word_to_ix

{'an': 0,
 'amazing': 1,
 'story': 2,
 'a': 3,
 'this': 4,
 'like': 5,
 'really': 6,
 'I': 7,
 'novel': 8,
 'hate': 9,
 'terrible': 10,
 'This': 11,
 'first': 12,
 'love': 13,
 'do': 14,
 'is': 15,
 'book': 16,
 'not': 17}

In [189]:
# each word w in the sentence is replaced by its index in word_to_ix. If a word is not found in word_to_ix, the get method returns 0 as a default value. The result is a list of indices that represent the sentene (applied later in the model training)

torch.LongTensor([word_to_ix.get(w, 0) for sentence, _ in data for w in sentence]).unsqueeze(0) # gets the indexed in the word_to_ix vocabulary of the words in the data

tensor([[ 7, 13,  4, 12, 16, 11, 15,  0,  1,  8,  7,  6,  5,  4,  2,  7, 14, 17,
          5,  4, 16,  7,  9,  4,  8, 11, 15,  3, 10,  2]])

In [190]:
model = TextClassificationCNN(vocab_size, embed_dim)

model

TextClassificationCNN(
  (embedding): Embedding(18, 10)
  (conv): Conv1d(10, 10, kernel_size=(3,), stride=(1,), padding=(1,))
  (fc): Linear(in_features=10, out_features=2, bias=True)
)

In [191]:
import torch.optim as optim

# Define the loss function
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10):
    for sentence, label in data:     
        # Clear the gradients
        model.zero_grad()
        sentence = torch.LongTensor([word_to_ix.get(w, 0) for w in sentence]).unsqueeze(0) # provides the indexed in the word_to_ix vocabulary of the sentences in data (sentences encoding). Necessary before applying the self.embedding to the text in the model
        label = torch.LongTensor([int(label)])
        outputs = model(sentence)
        loss = criterion(outputs, label)
        loss.backward()
        # print(loss)
        # Update the parameters
        optimizer.step()
print(sentence, output.shape, label) # just last one
print('Training complete!')

tensor([[11, 15,  3, 10,  2]]) torch.Size([15, 10]) tensor([0])
Training complete!


In [193]:
output

tensor([[ 3.2591e-01,  3.8871e-01, -1.0202e+00, -5.4242e-01, -3.8684e-02,
          1.1134e+00, -8.0390e-01,  4.2076e-01, -2.5217e+00, -1.4594e+00],
        [ 7.5349e-01, -1.7978e+00,  7.3524e-01,  2.4658e-01,  4.0211e-01,
          9.9624e-02,  1.7329e+00,  3.7023e-01, -6.5028e-01, -1.2494e+00],
        [ 1.3734e-01,  7.9268e-01,  9.0330e-01, -1.7315e+00,  7.3597e-02,
         -4.3615e-01,  3.9540e-02, -6.2524e-01, -6.6301e-01,  7.3274e-01],
        [ 2.5612e-01,  4.8039e-01,  1.1930e-01,  1.2893e+00,  7.9125e-01,
          2.8376e-02, -8.4166e-01,  2.0709e+00, -2.0737e+00,  1.3860e-01],
        [ 1.1167e-01,  2.1041e-01,  7.5730e-01,  3.9823e-01, -4.8729e-01,
          1.1858e-01, -9.9135e-01,  1.2653e+00, -5.5789e-01,  1.4413e-01],
        [ 6.9417e-01,  2.9416e-01, -1.0067e+00, -8.5463e-01, -1.0690e+00,
         -7.6825e-02,  9.2565e-01,  6.5659e-01,  6.8181e-01, -4.4510e-01],
        [-5.9533e-01,  1.1413e-02, -1.4163e+00,  5.8150e-01,  8.2895e-01,
         -6.7301e-01, -1.8067e-0

In [195]:
book_reviews = [
    "I love this book".split(),
    "I do not like this book".split()
]

book_reviews

[['I', 'love', 'this', 'book'], ['I', 'do', 'not', 'like', 'this', 'book']]

In [200]:
for review in book_reviews:
    # Convert the review words into tensor form
    input_tensor = torch.tensor([word_to_ix[w] for w in review], dtype=torch.long).unsqueeze(0) 
    print('input_tensor',input_tensor)
    
    # Get the model's output
    outputs = model(input_tensor)
    print('outputs',outputs)
    # Find the index of the most likely sentiment category
    _, predicted_label = torch.max(outputs.data, 1)
    print('predicted_label',predicted_label)
    
    # Convert the predicted label into a sentiment string
    sentiment = "Positive" if predicted_label.item() == 1 else "Negative"
    print(f"Book Review: {' '.join(review)}")
    print(f"Sentiment: {sentiment}\n")

input_tensor tensor([[ 7, 13,  4, 16]])
outputs tensor([[-0.0373, -0.0905]], grad_fn=<AddmmBackward0>)
predicted_label tensor([0])
Book Review: I love this book
Sentiment: Negative

input_tensor tensor([[ 7, 14, 17,  5,  4, 16]])
outputs tensor([[ 0.2010, -0.2404]], grad_fn=<AddmmBackward0>)
predicted_label tensor([0])
Book Review: I do not like this book
Sentiment: Negative



In [182]:
# Complete the RNN class
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h0)
        out = out[:, -1, :] 
        out = self.fc(out)
        return out


In [214]:
input_size = 6
hidden_size = 32
num_layers = 2
num_classes = 3

# Initialize the model
rnn_model = RNNModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_model.parameters(), lr=0.01)

rnn_model

RNNModel(
  (rnn): RNN(6, 32, num_layers=2, batch_first=True)
  (fc): Linear(in_features=32, out_features=3, bias=True)
)

In [212]:
X_train_seq = torch.tensor([[[0., 0., 0., 1., 0., 1.]],

        [[0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0.]]])

In [217]:

y_train_seq = torch.tensor([2, 2, 2, 0, 2, 1])

In [218]:
# Train the model for ten epochs and zero the gradients
for epoch in range(10): 
    optimizer.zero_grad()
    outputs = rnn_model(X_train_seq)
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 1.1831812858581543
Epoch: 2, Loss: 1.054624319076538
Epoch: 3, Loss: 0.9578561186790466
Epoch: 4, Loss: 0.8862207531929016
Epoch: 5, Loss: 0.8430938720703125
Epoch: 6, Loss: 0.8329666256904602
Epoch: 7, Loss: 0.8425272107124329
Epoch: 8, Loss: 0.8445902466773987
Epoch: 9, Loss: 0.8297855257987976
Epoch: 10, Loss: 0.8045610785484314


In [219]:
# Initialize the LSTM and the output layer with parameters
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :] 
        out = self.fc(out)
        return out



In [220]:
# Initialize model with required parameters
lstm_model = LSTMModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lstm_model.parameters(), lr=0.01)



In [221]:
# Train the model by passing the correct parameters and zeroing the gradient
for epoch in range(10): 
    optimizer.zero_grad()
    outputs = lstm_model(X_train_seq)
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 1.1166517734527588
Epoch: 2, Loss: 1.0943870544433594
Epoch: 3, Loss: 1.0730739831924438
Epoch: 4, Loss: 1.0514944791793823
Epoch: 5, Loss: 1.0285530090332031
Epoch: 6, Loss: 1.0035464763641357
Epoch: 7, Loss: 0.9762029647827148
Epoch: 8, Loss: 0.9467628002166748
Epoch: 9, Loss: 0.9162300229072571
Epoch: 10, Loss: 0.886702299118042


In [222]:
# Complete the GRU model
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)       
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) 
        out, _ = self.gru(x, h0)
        out = out[:, -1, :] 
        out = self.fc(out)
        return out



In [223]:
# Initialize the model
gru_model = GRUModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gru_model.parameters(), lr=0.01)



In [224]:
# Train the model and backpropagate the loss after initialization
for epoch in range(15): 
    optimizer.zero_grad()
    outputs = gru_model(X_train_seq)
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 1.0819989442825317
Epoch: 2, Loss: 1.0384217500686646
Epoch: 3, Loss: 0.9954107403755188
Epoch: 4, Loss: 0.952071487903595
Epoch: 5, Loss: 0.9101083278656006
Epoch: 6, Loss: 0.8736910820007324
Epoch: 7, Loss: 0.8495261073112488
Epoch: 8, Loss: 0.8441532254219055
Epoch: 9, Loss: 0.8538398146629333
Epoch: 10, Loss: 0.8614375591278076
Epoch: 11, Loss: 0.8570571541786194
Epoch: 12, Loss: 0.8426758646965027
Epoch: 13, Loss: 0.8241392970085144
Epoch: 14, Loss: 0.8065580725669861
Epoch: 15, Loss: 0.7926747798919678


In [226]:
from torchmetrics import Accuracy, Precision, Recall, F1Score

# Create an instance of the metrics
accuracy = Accuracy(task="multiclass", num_classes=3)
precision = Precision(task="multiclass", num_classes=3)
recall = Recall(task="multiclass", num_classes=3)
f1 = F1Score(task="multiclass", num_classes=3)



  from .autonotebook import tqdm as notebook_tqdm


In [228]:
X_test_seq = torch.tensor([[[0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 1., 1., 0.]],

        [[0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]]])

In [229]:
# Generate the predictions
outputs = rnn_model(X_test_seq)
_, predicted = torch.max(outputs, 1)

outputs,predicted

(tensor([[-0.5673, -0.6103,  0.8328],
         [-0.8141, -0.4340,  0.9553],
         [-0.5673, -0.6103,  0.8328],
         [-0.5673, -0.6103,  0.8328],
         [-0.9195, -0.4185,  0.9284],
         [-0.5673, -0.6103,  0.8328]], grad_fn=<AddmmBackward0>),
 tensor([2, 2, 2, 2, 2, 2]))

In [236]:
y_test_seq = torch.tensor([2, 1, 0, 1, 2, 2])

In [237]:
# Calculate the metrics
accuracy_score = accuracy(predicted, y_test_seq)
precision_score = precision(predicted, y_test_seq)
recall_score = recall(predicted, y_test_seq)
f1_score = f1(predicted, y_test_seq)
print("RNN Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_score, precision_score, recall_score, f1_score))

RNN Model - Accuracy: 0.5, Precision: 0.5, Recall: 0.5, F1 Score: 0.5


In [238]:
outputs = lstm_model(X_test_seq)
_, y_pred_lstm = torch.max(outputs, 1)

In [240]:

# Calculate metrics for the LSTM model
accuracy_1 = accuracy(y_pred_lstm, y_test_seq)
precision_1 = precision(y_pred_lstm, y_test_seq)
recall_1 = recall(y_pred_lstm, y_test_seq)
f1_1 = f1(y_pred_lstm, y_test_seq)
print("LSTM Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_1, precision_1, recall_1, f1_1))


LSTM Model - Accuracy: 0.5, Precision: 0.5, Recall: 0.5, F1 Score: 0.5


In [242]:
outputs = gru_model(X_test_seq)
_, y_pred_gru = torch.max(outputs, 1)

In [243]:
# Calculate metrics for the GRU model
accuracy_2 = accuracy(y_pred_gru, y_test_seq)
precision_2 = precision(y_pred_gru, y_test_seq)
recall_2 = recall(y_pred_gru, y_test_seq)
f1_2 = f1(y_pred_gru, y_test_seq)
print("GRU Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_2, precision_2, recall_2, f1_2))

GRU Model - Accuracy: 0.5, Precision: 0.5, Recall: 0.5, F1 Score: 0.5


In [244]:
data = 'The rabbit-hole went straight on like a tunnel for some way, and then dipped suddenly down, so suddenly that Alice had not a moment to think about stopping herself before she found herself falling down a very deep well.'

In [275]:
# Include an RNN layer and linear layer in RNNmodel class

class RNNmodel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNmodel, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h0)  
        out = self.fc(out[:, -1, :])  # out[:, -1, :] means “for each sequence in the batch, take the last hidden state” (In the context of RNNs, the last hidden state often contains information about the entire sequence because it has seen all the previous inputs in the sequence. This is why it’s commonly used for sequence classification tasks.)
        return out


In [276]:
# Instantiate the RNN model
model = RNNmodel(len(chars), 16, len(chars))
model

RNNmodel(
  (rnn): RNN(28, 16, batch_first=True)
  (fc): Linear(in_features=16, out_features=28, bias=True)
)

In [277]:
chars = ['y', ',','b','w','a','d','e','p','o','k','n','v','s','T','l','m','.','g','c','h','u','f','-','A',' ','r','t','i']
len(chars)

28

In [278]:
char_to_ix = {chars[i]:i for i in range(len(chars))}
char_to_ix

{'y': 0,
 ',': 1,
 'b': 2,
 'w': 3,
 'a': 4,
 'd': 5,
 'e': 6,
 'p': 7,
 'o': 8,
 'k': 9,
 'n': 10,
 'v': 11,
 's': 12,
 'T': 13,
 'l': 14,
 'm': 15,
 '.': 16,
 'g': 17,
 'c': 18,
 'h': 19,
 'u': 20,
 'f': 21,
 '-': 22,
 'A': 23,
 ' ': 24,
 'r': 25,
 't': 26,
 'i': 27}

In [279]:
# Instantiate the loss function
criterion = nn.CrossEntropyLoss()

# Instantiate the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    weight_decay: 0
)

In [282]:
inputs = torch.tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [291]:
inputs.shape # batch_size is 218, sequence_length is 1, and number_of_features (characters) is 28

torch.Size([218, 1, 28])

In [283]:
targets = torch.tensor([19,  6, 24, 25,  4,  2,  2, 27, 26, 22, 19,  8, 14,  6, 24,  3,  6, 10,
        26, 24, 12, 26, 25,  4, 27, 17, 19, 26, 24,  8, 10, 24, 14, 27,  9,  6,
        24,  4, 24, 26, 20, 10, 10,  6, 14, 24, 21,  8, 25, 24, 12,  8, 15,  6,
        24,  3,  4,  0,  1, 24,  4, 10,  5, 24, 26, 19,  6, 10, 24,  5, 27,  7,
         7,  6,  5, 24, 12, 20,  5,  5,  6, 10, 14,  0, 24,  5,  8,  3, 10,  1,
        24, 12,  8, 24, 12, 20,  5,  5,  6, 10, 14,  0, 24, 26, 19,  4, 26, 24,
        23, 14, 27, 18,  6, 24, 19,  4,  5, 24, 10,  8, 26, 24,  4, 24, 15,  8,
        15,  6, 10, 26, 24, 26,  8, 24, 26, 19, 27, 10,  9, 24,  4,  2,  8, 20,
        26, 24, 12, 26,  8,  7,  7, 27, 10, 17, 24, 19,  6, 25, 12,  6, 14, 21,
        24,  2,  6, 21,  8, 25,  6, 24, 12, 19,  6, 24, 21,  8, 20, 10,  5, 24,
        19,  6, 25, 12,  6, 14, 21, 24, 21,  4, 14, 14, 27, 10, 17, 24,  5,  8,
         3, 10, 24,  4, 24, 11,  6, 25,  0, 24,  5,  6,  6,  7, 24,  3,  6, 14,
        14, 16])

targets.shape

torch.Size([218])

In [284]:
# Train the model
for epoch in range(100):
    model.train()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch+1) % 10 == 0:
        print(f'Epoch {epoch+1}/100, Loss: {loss.item()}')


Epoch 10/100, Loss: 3.0296435356140137
Epoch 20/100, Loss: 2.730748414993286
Epoch 30/100, Loss: 2.5233685970306396
Epoch 40/100, Loss: 2.31036114692688
Epoch 50/100, Loss: 2.1362075805664062
Epoch 60/100, Loss: 2.0067527294158936
Epoch 70/100, Loss: 1.9184000492095947
Epoch 80/100, Loss: 1.8595623970031738
Epoch 90/100, Loss: 1.8199223279953003
Epoch 100/100, Loss: 1.793260931968689


In [285]:
print(outputs.shape) # outputs the probabilities of all possible characters

outputs

torch.Size([218, 28])


tensor([[-2.5415,  0.0058, -0.7994,  ...,  1.6438,  0.8553, -2.4360],
        [-0.8541, -1.9664,  1.2513,  ..., -3.4895,  2.6112,  2.6951],
        [-0.3373, -0.9688, -1.0104,  ...,  2.8426, -0.4365, -0.6327],
        ...,
        [-0.3373, -0.9688, -1.0104,  ...,  2.8426, -0.4365, -0.6327],
        [ 2.9877, -0.8687,  0.7187,  ..., -0.9503, -1.8045,  3.4610],
        [ 2.9877, -0.8687,  0.7187,  ..., -0.9503, -1.8045,  3.4610]],
       grad_fn=<AddmmBackward0>)

In [286]:
ix_to_char = {i:chars[i] for i in range(len(chars))}
ix_to_char

{0: 'y',
 1: ',',
 2: 'b',
 3: 'w',
 4: 'a',
 5: 'd',
 6: 'e',
 7: 'p',
 8: 'o',
 9: 'k',
 10: 'n',
 11: 'v',
 12: 's',
 13: 'T',
 14: 'l',
 15: 'm',
 16: '.',
 17: 'g',
 18: 'c',
 19: 'h',
 20: 'u',
 21: 'f',
 22: '-',
 23: 'A',
 24: ' ',
 25: 'r',
 26: 't',
 27: 'i'}

In [288]:
# Test the model
model.eval()
test_input = char_to_ix['r']
test_input = nn.functional.one_hot(torch.tensor(test_input).view(-1, 1), num_classes=len(chars)).float()
print(test_input.shape)
test_input

torch.Size([1, 1, 28])


tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]])

In [289]:
predicted_output = model(test_input)
predicted_char_ix = torch.argmax(predicted_output, 1).item()
print(f"Test Input: 'r', Predicted Output: '{ix_to_char[predicted_char_ix]}'")
predicted_output

Test Input: 'r', Predicted Output: 'a'


tensor([[ 2.6673, -3.4252,  0.7880, -0.7315,  3.4651, -0.1600,  2.6154, -0.6251,
          0.8910, -4.3891, -1.1892, -1.2683,  3.3465, -2.5057, -0.6777, -1.6488,
         -1.3966, -4.9496, -2.9807, -0.9279, -2.5757, -0.2586, -1.8951, -1.4640,
          2.3371, -0.2853, -0.8855,  1.4694]], grad_fn=<AddmmBackward0>)

In [290]:
import torch
import torch.nn as nn

# Define the generator class
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(seq_length,seq_length), nn.Sigmoid()) # the generator is designed to generate samples of size seq_length. The output of the generator is passed through a Sigmoid activation function, which squashes the output values to be between 0 and 1. So, the generated samples will indeed have values ranging from 0 to 1
    def forward(self, x):
        return self.model(x)

# Define the discriminator networks
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid())
    def forward(self, x):
        return self.model(x)

In [7]:
seq_length = 5 #: Length of each synthetic data sequence
num_sequences = 100 #: Total number of sequences generated
num_epochs = 50 #: Number of complete passes through the dataset
print_every = 10 #: Output display frequency, showing results every 10 epochs

In [8]:
generator = Generator()
generator

Generator(
  (model): Sequential(
    (0): Linear(in_features=5, out_features=5, bias=True)
    (1): Sigmoid()
  )
)

In [10]:
discriminator = Discriminator()
discriminator

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=5, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

In [11]:
# Define the loss function and optimizer
criterion = nn.BCELoss()
optimizer_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_disc = torch.optim.Adam(discriminator.parameters(), lr=0.001)

In [16]:
data = torch.tensor([[0., 1., 0., 0., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 0., 1., 1., 0.],
                    [0., 1., 0., 0., 1.],
                    [0., 1., 1., 1., 0.],
                    [0., 0., 1., 0., 1.],
                    [0., 0., 0., 1., 1.],
                    [0., 1., 0., 1., 1.],
                    [1., 1., 1., 0., 0.],
                    [0., 1., 1., 1., 0.],
                    [1., 0., 0., 0., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 1., 1., 0., 1.],
                    [1., 0., 1., 1., 0.],
                    [0., 0., 1., 0., 0.],
                    [1., 0., 1., 0., 1.],
                    [0., 0., 0., 1., 0.],
                    [1., 1., 0., 1., 1.],
                    [1., 1., 1., 0., 0.],
                    [1., 1., 0., 0., 0.],
                    [0., 1., 0., 1., 1.],
                    [1., 1., 1., 0., 1.],
                    [1., 1., 0., 1., 0.],
                    [0., 1., 1., 1., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [1., 0., 0., 0., 0.],
                    [0., 1., 1., 0., 0.],
                    [1., 0., 0., 1., 0.],
                    [0., 1., 1., 1., 1.],
                    [1., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.],
                    [1., 1., 0., 1., 0.],
                    [1., 0., 0., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [0., 0., 1., 0., 0.],
                    [0., 1., 0., 1., 0.],
                    [0., 1., 0., 0., 1.],
                    [1., 0., 1., 0., 1.],
                    [1., 0., 1., 1., 0.],
                    [1., 1., 1., 0., 0.],
                    [0., 0., 0., 0., 1.],
                    [1., 1., 1., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [1., 1., 0., 1., 1.],
                    [0., 1., 1., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 0., 1., 0., 1.],
                    [0., 0., 1., 0., 1.],
                    [1., 0., 1., 0., 1.],
                    [0., 0., 1., 1., 0.],
                    [1., 0., 0., 0., 1.],
                    [0., 1., 0., 1., 1.],
                    [1., 0., 1., 1., 0.],
                    [1., 1., 1., 0., 1.],
                    [0., 0., 0., 0., 1.],
                    [1., 1., 0., 1., 0.],
                    [1., 1., 1., 0., 0.],
                    [0., 1., 0., 1., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 1., 1., 1., 1.],
                    [1., 0., 0., 1., 1.],
                    [1., 0., 1., 0., 0.],
                    [0., 0., 0., 0., 1.],
                    [0., 1., 0., 1., 1.],
                    [1., 0., 1., 1., 0.],
                    [0., 0., 0., 1., 1.],
                    [0., 1., 1., 1., 1.],
                    [1., 0., 1., 1., 1.],
                    [0., 1., 0., 1., 0.],
                    [1., 0., 1., 0., 1.],
                    [1., 1., 0., 1., 0.],
                    [0., 0., 0., 1., 1.],
                    [0., 1., 1., 1., 0.],
                    [0., 0., 0., 0., 0.],
                    [0., 0., 1., 1., 1.],
                    [1., 1., 1., 0., 1.],
                    [1., 0., 0., 0., 0.],
                    [1., 1., 1., 1., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 1., 0., 1., 0.],
                    [1., 1., 1., 1., 1.],
                    [0., 0., 0., 0., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 0., 0., 1., 1.],
                    [0., 0., 0., 0., 0.],
                    [0., 1., 0., 1., 1.],
                    [0., 0., 0., 0., 0.],
                    [1., 1., 1., 1., 0.],
                    [1., 1., 0., 1., 0.],
                    [1., 1., 1., 0., 1.],
                    [1., 0., 0., 0., 1.],
                    [0., 1., 0., 1., 1.],
                    [0., 1., 0., 1., 0.],
                    [1., 0., 1., 0., 1.],
                    [1., 0., 1., 1., 0.],
                    [1., 0., 1., 1., 1.],
                    [0., 1., 1., 0., 1.]])

In [138]:
test_data = torch.tensor([1., 0., 1., 1., 0.])
test_real = test_data.unsqueeze(0)
test_noise = torch.rand((1, seq_length))
test_noise = generator(test_noise)
disc_test_real = discriminator(test_real)
disc_test_noise = discriminator(test_noise.detach())
disc_test_real, torch.ones_like(disc_test_real)
disc_test_noise, torch.zeros_like(disc_test_noise)

(tensor([[0.4962]], grad_fn=<SigmoidBackward0>), tensor([[0.]]))

In [139]:
# Start of the training process
for epoch in range(num_epochs):
    # For each epoch, the model goes through the entire dataset
    for real_data in data:
        # The real data is unsqueezed to add an extra dimension: unsqueeze(0) is used to create a batch with size 1 from the single sample of real_data. This ensures that the data can be correctly processed by the discriminator and generator, which are designed to handle batches of data.
        real_data = real_data.unsqueeze(0)
        # A random noise vector is generated
        noise = torch.rand((1, seq_length))
        # This noise vector is then passed through the generator to create the fake data
        fake_data = generator(noise)
        # The discriminator is then used to classify the real and fake data (outputs '1' if considers the input real, and outputs '0' if considers the input fake: ideally, disc_real should be always classified to 1, and disc_fake should be always classified to 0)
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.detach()) # we want to use the output of the generator 'fake_data' to feed into the discriminator but don’t want these operations to influence the gradients of the generator (because here we're training the discriminator, and not the generator)
        # The loss for the discriminator is calculated as the sum of the losses for the real and fake data
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        # The gradients are then backpropagated through the discriminator
        optimizer_disc.zero_grad()
        loss_disc.backward()
        # The discriminator's weights are updated
        optimizer_disc.step()

        # The generator is trained here
        disc_fake = discriminator(fake_data) # In this case (when we train the generator), we need to calculate gradients for the generator because we want to update its parameters to minimize this loss. So, we don’t use detach() when passing the fake data through the discriminator.
        # The loss for the generator is calculated
        loss_gen = criterion(disc_fake, torch.ones_like(disc_fake)) # the generator loss is small if it's able to fool the discriminator (so if the outputs of the discriminator for the fake data is 1)
        # The gradients are backpropagated through the generator
        optimizer_gen.zero_grad()
        loss_gen.backward()
        # The generator's weights are updated
        optimizer_gen.step()

    # The losses for the generator and the discriminator are printed every 'print_every' epochs
    if (epoch+1) % print_every == 0:
        print(f"Epoch {epoch+1}/{num_epochs}:\t Generator loss: {loss_gen.item()}\t Discriminator loss: {loss_disc.item()}")


Epoch 10/50:	 Generator loss: 0.6851096749305725	 Discriminator loss: 1.358232855796814
Epoch 20/50:	 Generator loss: 0.6987009644508362	 Discriminator loss: 1.3918246030807495
Epoch 30/50:	 Generator loss: 0.6911744475364685	 Discriminator loss: 1.3963098526000977
Epoch 40/50:	 Generator loss: 0.688890814781189	 Discriminator loss: 1.3672797679901123
Epoch 50/50:	 Generator loss: 0.6931207180023193	 Discriminator loss: 1.423922061920166


In [141]:
# After the training process, the real data is printed
print("\nReal data: ")
print(data[:5])

# The generator is used to generate data and this generated data is printed
print("\nGenerated data: ")
for _ in range(5):
    noise = torch.rand((1, seq_length))
    generated_data = generator(noise)
    # The generated data is detached from its computation graph and rounded before printing
    print(noise)
    print(torch.round(generated_data).detach())



Real data: 
tensor([[0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 1.],
        [0., 0., 1., 1., 0.],
        [0., 1., 0., 0., 1.],
        [0., 1., 1., 1., 0.]])

Generated data: 
tensor([[0.6386, 0.8570, 0.9376, 0.7493, 0.2763]])
tensor([[1., 0., 1., 0., 1.]])
tensor([[0.6004, 0.4461, 0.4721, 0.6752, 0.0456]])
tensor([[1., 0., 1., 0., 1.]])
tensor([[0.4224, 0.9064, 0.4799, 0.4294, 0.1463]])
tensor([[1., 0., 1., 0., 1.]])
tensor([[0.9468, 0.9875, 0.4684, 0.8525, 0.7690]])
tensor([[1., 0., 1., 0., 1.]])
tensor([[0.6852, 0.2630, 0.9440, 0.4777, 0.2111]])
tensor([[1., 0., 1., 0., 0.]])


In [145]:
from transformers import GPT2Tokenizer,GPT2LMHeadModel 

# Initialize the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Initialize the pre-trained model
model = GPT2LMHeadModel.from_pretrained('gpt2')

seed_text = "I opened the window"

# Encode the seed text to get input tensors
input_ids = tokenizer.encode(seed_text, return_tensors='pt')

# Generate text from the model
output = model.generate(input_ids, max_length=100, temperature=0.7, no_repeat_ngram_size=2, pad_token_id=tokenizer.eos_token_id) 

generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)

I opened the window and saw a man standing in the middle of the street. He was wearing a black hoodie and a white shirt. I asked him what he was doing and he said he had a gun.

I asked if he knew who he is and what his intentions were. The man said that he wanted to kill me. Then he started to walk away. When I looked back at him, he looked like he didn't know what to do. It was like a nightmare.


In [152]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Initalize tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_prompt = "translate English to French: 'Hello, give me the pen'"

# Encode the input prompt using the tokenizer
input_ids = tokenizer.encode(input_prompt, return_tensors="pt")

# Generate the translated ouput
output = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated text:",generated_text)

Generated text: "Jo, donnez-moi le stylo"


In [155]:
from torchmetrics.text import BLEUScore, ROUGEScore

reference_text = "Once upon a time, there was a little girl who lived in a village near the forest."
generated_text = "Once upon a time, the world was a place of great beauty and great danger. The world of the gods was the place where the great gods were born, and where they were to live."

# Initialize BLEU and ROUGE scorers
bleu = BLEUScore()
rouge = ROUGEScore()

# Calculate the BLEU and ROUGE scores
bleu_score = bleu([generated_text], [[reference_text]])
rouge_score = rouge([generated_text], [[reference_text]])

# Print the BLEU and ROUGE scores
print("BLEU Score:", bleu_score.item())
print("ROUGE Score:", rouge_score)

BLEU Score: 0.08170417696237564
ROUGE Score: {'rouge1_fmeasure': tensor(0.2692), 'rouge1_precision': tensor(0.2000), 'rouge1_recall': tensor(0.4118), 'rouge2_fmeasure': tensor(0.1600), 'rouge2_precision': tensor(0.1176), 'rouge2_recall': tensor(0.2500), 'rougeL_fmeasure': tensor(0.2692), 'rougeL_precision': tensor(0.2000), 'rougeL_recall': tensor(0.4118), 'rougeLsum_fmeasure': tensor(0.2692), 'rougeLsum_precision': tensor(0.2000), 'rougeLsum_recall': tensor(0.4118)}


In [4]:
from transformers import BertTokenizer, BertForSequenceClassification

# Load the BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

model

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [6]:
import torch

# Setup the optimizer using model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 1e-05
    maximize: False
    weight_decay: 0.01
)

In [10]:
texts = ['I love this!',
 'This is terrible.',
 'Amazing experience!',
 'Not my cup of tea.']

labels = [1, 0, 1, 0]

In [11]:
# Tokenize your data and return PyTorch tensors
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=32)
inputs["labels"] = torch.tensor(labels)

inputs # The tokenizer also adds these special tokens ([CLS] and [SEP]) to each sentence, which is why the maximum number of tokens is 8 (in the sencence 'Not my cup of tea.' --> [ 101, 2025, 2026, 2452, 1997, 5572, 1012,  102]). 

{'input_ids': tensor([[ 101, 1045, 2293, 2023,  999,  102,    0,    0],
        [ 101, 2023, 2003, 6659, 1012,  102,    0,    0],
        [ 101, 6429, 3325,  999,  102,    0,    0,    0],
        [ 101, 2025, 2026, 2452, 1997, 5572, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([1, 0, 1, 0])}

In [13]:
model.train()
for epoch in range(2):
    outputs = model(**inputs)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 0.6970258951187134
Epoch: 2, Loss: 0.6908831596374512


In [32]:
text = "I had a bad day!"

# Tokenize the text and return PyTorch tensors
input_eval = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=32)
print('input_eval',input_eval)
outputs_eval = model(**input_eval)
print('outputs_eval',outputs_eval)


input_eval {'input_ids': tensor([[ 101, 1045, 2018, 1037, 2919, 2154,  999,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
outputs_eval SequenceClassifierOutput(loss=None, logits=tensor([[ 0.0975, -0.0061]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [192]:
outputs_eval.logits

tensor([[ 0.0975, -0.0061]], grad_fn=<AddmmBackward0>)

In [33]:
# Convert the output logits to probabilities
predictions = torch.nn.functional.sigmoid(outputs_eval.logits)

predictions

tensor([[0.5243, 0.4985]], grad_fn=<SigmoidBackward0>)

In [34]:
# Convert the output logits to probabilities
predictions = torch.nn.functional.softmax(outputs_eval.logits, dim=-1)

predictions

tensor([[0.5259, 0.4741]], grad_fn=<SoftmaxBackward0>)

In [35]:
# Display the sentiments
predicted_label = 'positive' if torch.argmax(predictions) > 0 else 'negative'
print(f"Text: {text}\nSentiment: {predicted_label}")

Text: I had a bad day!
Sentiment: negative


In [95]:
import torch.nn as nn

class TransformerEncoder(nn.Module):
    
    def __init__(self, embed_size, heads, num_layers, dropout): # heads is the number of parallel attention layers (or “heads”) in the multi-head attention mechanism. Each head learns a different type of attention and then the model combines the results from all heads.
        super(TransformerEncoder, self).__init__()
        # Initialize the encoder 
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_size, nhead=heads),
            num_layers=num_layers)
        # Define the fully connected layer
        self.fc = nn.Linear(embed_size, 2)

    def forward(self, x):
        # Pass the input through the transformer encoder 
        x = self.encoder(x)
        x = x.mean(dim=1) # mean often calculated when you want to convert a sequence output (for example, a sequence of word embeddings) into a single vector that can be used for classification. The mean operation effectively captures the ‘average’ representation of all the words in the sequence.
        return self.fc(x)



In [131]:
model = TransformerEncoder(embed_size=512, heads=8, num_layers=3, dropout=0.5)
model

TransformerEncoder(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=512, out_features=2, bias=True)
)

In [132]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [133]:
train_sentences = ['I love this product', 'This is terrible', 'Could be better']
train_labels = [1, 0, 0]

In [120]:
token_embeddings ={'I': torch.tensor([[0.4724, 0.5557, 0.4722, 0.5729, 0.5307, 0.9750, 0.8660, 0.7654, 0.4658,
          0.2715, 0.6433, 0.6470, 0.0429, 0.2430, 0.8058, 0.8350, 0.3995, 0.5057,
          0.4583, 0.3466, 0.4189, 0.5420, 0.3751, 0.0058, 0.2638, 0.0964, 0.5558,
          0.2919, 0.0440, 0.0876, 0.8482, 0.1768, 0.0022, 0.6446, 0.8069, 0.7994,
          0.7561, 0.4435, 0.6643, 0.1541, 0.0802, 0.9589, 0.6049, 0.2023, 0.5180,
          0.9577, 0.9393, 0.3826, 0.9703, 0.4290, 0.7170, 0.0488, 0.0913, 0.6126,
          0.1487, 0.7760, 0.9713, 0.2769, 0.6659, 0.1106, 0.9111, 0.3310, 0.7131,
          0.6214, 0.7096, 0.0343, 0.9956, 0.6589, 0.4008, 0.5218, 0.5822, 0.5437,
          0.5674, 0.3276, 0.5489, 0.7580, 0.6542, 0.3743, 0.8366, 0.8950, 0.3374,
          0.7953, 0.5709, 0.2408, 0.1108, 0.1644, 0.7186, 0.3075, 0.4607, 0.1386,
          0.4156, 0.1327, 0.8914, 0.1970, 0.3939, 0.7728, 0.4647, 0.1176, 0.7855,
          0.9218, 0.8032, 0.1213, 0.8072, 0.5489, 0.1751, 0.1391, 0.5596, 0.9518,
          0.4132, 0.7992, 0.4523, 0.1384, 0.1769, 0.0551, 0.7344, 0.3447, 0.2006,
          0.3341, 0.4965, 0.8561, 0.6418, 0.7935, 0.3037, 0.3758, 0.4237, 0.4331,
          0.7871, 0.4673, 0.5429, 0.5747, 0.9529, 0.8706, 0.2877, 0.0265, 0.3033,
          0.3784, 0.8570, 0.7634, 0.5767, 0.0142, 0.5429, 0.2438, 0.0991, 0.7664,
          0.3974, 0.2828, 0.4412, 0.6825, 0.8435, 0.4237, 0.0114, 0.7765, 0.8383,
          0.2938, 0.5893, 0.1870, 0.7938, 0.9394, 0.0395, 0.0248, 0.5547, 0.9700,
          0.1410, 0.9871, 0.6451, 0.3110, 0.1185, 0.7799, 0.6624, 0.2143, 0.9620,
          0.4180, 0.5459, 0.6799, 0.0414, 0.9062, 0.3307, 0.8211, 0.1592, 0.0665,
          0.3260, 0.4353, 0.0146, 0.4330, 0.5727, 0.2316, 0.7482, 0.3631, 0.2879,
          0.3853, 0.6852, 0.1552, 0.6271, 0.1736, 0.0279, 0.0316, 0.2504, 0.2129,
          0.4054, 0.7876, 0.2000, 0.7274, 0.8623, 0.9110, 0.8755, 0.6636, 0.5635,
          0.8232, 0.9384, 0.6593, 0.6574, 0.6949, 0.4896, 0.4622, 0.3240, 0.0959,
          0.9655, 0.3602, 0.4973, 0.1227, 0.0671, 0.6328, 0.7856, 0.6733, 0.7388,
          0.5601, 0.1548, 0.6681, 0.4272, 0.6121, 0.0556, 0.5449, 0.3339, 0.2709,
          0.8426, 0.9179, 0.4196, 0.6820, 0.3140, 0.4514, 0.5685, 0.5238, 0.5477,
          0.8930, 0.7630, 0.7924, 0.2452, 0.2047, 0.9237, 0.4576, 0.3021, 0.1066,
          0.2762, 0.8644, 0.2386, 0.9506, 0.0805, 0.1870, 0.7464, 0.9408, 0.1643,
          0.1643, 0.1321, 0.7372, 0.2180, 0.6723, 0.3140, 0.1717, 0.6649, 0.9463,
          0.7329, 0.3882, 0.6719, 0.8262, 0.0635, 0.1892, 0.9027, 0.4663, 0.4737,
          0.9851, 0.3101, 0.8937, 0.0725, 0.7412, 0.4631, 0.9490, 0.8089, 0.7272,
          0.6552, 0.8817, 0.7996, 0.6931, 0.6791, 0.7083, 0.9935, 0.0685, 0.7763,
          0.1802, 0.7639, 0.4640, 0.0369, 0.1886, 0.9038, 0.6049, 0.0194, 0.0413,
          0.5101, 0.4828, 0.2194, 0.5816, 0.8162, 0.3853, 0.7772, 0.7672, 0.9650,
          0.7339, 0.3081, 0.5660, 0.0896, 0.0315, 0.2151, 0.8712, 0.6279, 0.3944,
          0.2379, 0.2325, 0.4734, 0.5488, 0.7478, 0.3705, 0.4324, 0.0179, 0.6656,
          0.7680, 0.7536, 0.1472, 0.4044, 0.9004, 0.7693, 0.1209, 0.9309, 0.4581,
          0.1227, 0.1692, 0.9075, 0.1840, 0.0885, 0.8661, 0.3267, 0.3525, 0.1377,
          0.8145, 0.0948, 0.1702, 0.4904, 0.9424, 0.1265, 0.8899, 0.5025, 0.1213,
          0.5150, 0.8924, 0.0663, 0.9330, 0.6946, 0.5691, 0.3163, 0.6854, 0.7642,
          0.1268, 0.2413, 0.2092, 0.0776, 0.5629, 0.7003, 0.9788, 0.9653, 0.1146,
          0.2849, 0.7719, 0.5733, 0.6383, 0.4382, 0.1576, 0.6963, 0.4031, 0.2137,
          0.1679, 0.2471, 0.8445, 0.2093, 0.3184, 0.0368, 0.4316, 0.0578, 0.2638,
          0.0537, 0.4201, 0.3404, 0.2197, 0.2578, 0.3701, 0.6253, 0.1885, 0.6228,
          0.0451, 0.7500, 0.7943, 0.9623, 0.1766, 0.8034, 0.3898, 0.4364, 0.3338,
          0.4876, 0.7123, 0.0904, 0.6353, 0.6971, 0.5448, 0.0879, 0.1763, 0.5327,
          0.8414, 0.0726, 0.9174, 0.0661, 0.8022, 0.2021, 0.4678, 0.1743, 0.6966,
          0.4367, 0.9017, 0.8206, 0.1048, 0.6924, 0.3328, 0.9457, 0.0208, 0.7397,
          0.2351, 0.2538, 0.0788, 0.7390, 0.4120, 0.5591, 0.9879, 0.3683, 0.9222,
          0.4069, 0.6487, 0.7284, 0.9301, 0.3184, 0.8739, 0.5629, 0.0029, 0.7355,
          0.3861, 0.7867, 0.2815, 0.7718, 0.4998, 0.4605, 0.8324, 0.6346, 0.2724,
          0.6385, 0.7876, 0.9332, 0.5319, 0.9088, 0.6663, 0.6732, 0.5441, 0.0050,
          0.8056, 0.7128, 0.5121, 0.9000, 0.3297, 0.7064, 0.6252, 0.0995, 0.0409,
          0.1861, 0.7331, 0.6377, 0.2895, 0.3633, 0.4640, 0.1234, 0.1535, 0.1658,
          0.6956, 0.5257, 0.2348, 0.9876, 0.1082, 0.1669, 0.2644, 0.3871, 0.2160,
          0.8936, 0.3694, 0.9051, 0.2250, 0.5377, 0.5699, 0.0535, 0.9260]]),
 'love': torch.tensor([[3.2036e-01, 3.6033e-01, 9.2827e-01, 3.4356e-01, 6.3648e-01, 2.5919e-01,
          3.4643e-01, 9.0817e-01, 2.7483e-02, 1.5359e-02, 6.5812e-01, 2.4595e-01,
          6.5682e-01, 1.3490e-01, 4.1557e-01, 9.8591e-01, 5.1079e-01, 3.8979e-01,
          3.5355e-01, 4.0449e-01, 5.0372e-01, 1.1598e-01, 4.4389e-01, 1.7553e-02,
          6.0470e-01, 9.6230e-01, 6.2912e-01, 9.9945e-01, 9.9020e-01, 2.4198e-01,
          2.2602e-01, 7.4740e-02, 2.7860e-01, 8.2168e-01, 4.2395e-01, 6.5530e-01,
          6.0443e-01, 4.3612e-01, 5.9571e-01, 1.0397e-01, 3.9118e-01, 1.6040e-01,
          3.3506e-01, 4.6747e-01, 7.1252e-01, 3.3054e-01, 9.8692e-01, 7.4542e-01,
          5.3865e-01, 9.3861e-01, 5.5719e-01, 5.7676e-01, 1.4788e-02, 9.8216e-02,
          4.2174e-01, 4.5211e-01, 9.6997e-01, 5.4045e-01, 3.8381e-01, 3.4024e-01,
          5.4042e-01, 6.5878e-01, 7.3857e-01, 9.0276e-01, 6.5755e-01, 9.6673e-01,
          6.6789e-01, 9.6161e-01, 5.7773e-01, 8.8480e-01, 4.5164e-01, 5.0816e-01,
          3.8049e-01, 1.2847e-01, 6.7156e-01, 6.8460e-01, 7.7264e-01, 2.5562e-01,
          7.0243e-01, 4.1561e-01, 3.5926e-01, 2.9409e-01, 2.3800e-01, 6.3042e-01,
          2.7775e-01, 6.8889e-01, 2.1369e-01, 8.9696e-01, 3.3322e-01, 6.9139e-01,
          4.5683e-01, 4.1063e-01, 1.2425e-01, 8.8525e-01, 4.5305e-01, 3.9848e-01,
          3.8640e-01, 3.2918e-01, 4.1729e-01, 2.4877e-01, 8.1608e-01, 2.7557e-01,
          2.5460e-01, 3.0192e-01, 3.0015e-01, 2.8333e-01, 4.8435e-01, 6.7392e-01,
          2.8385e-01, 7.9593e-01, 7.7256e-01, 1.6829e-02, 1.0197e-01, 3.9928e-02,
          9.5749e-01, 2.6436e-01, 6.4600e-01, 9.0869e-01, 6.2970e-01, 5.1019e-01,
          1.4711e-01, 2.3966e-01, 8.9120e-01, 9.8037e-01, 3.1893e-01, 5.1067e-01,
          5.1696e-01, 7.0687e-02, 7.1846e-01, 4.6714e-01, 8.0869e-01, 7.9392e-01,
          2.7444e-01, 5.3310e-01, 3.3955e-02, 5.0137e-01, 1.5094e-01, 6.5118e-01,
          4.8986e-01, 6.7954e-01, 6.5513e-01, 7.9500e-01, 1.0295e-01, 4.1340e-01,
          3.9518e-01, 6.7223e-02, 8.5010e-01, 4.9664e-01, 3.8700e-01, 5.7411e-01,
          2.5954e-01, 6.1618e-01, 3.3472e-01, 2.4972e-01, 8.3231e-01, 8.4710e-01,
          1.4280e-01, 6.7241e-01, 6.6574e-01, 7.2743e-01, 1.8835e-01, 4.8666e-01,
          8.4496e-01, 4.7414e-01, 3.0836e-01, 2.4621e-01, 4.0414e-01, 6.6992e-01,
          3.8960e-01, 8.9032e-01, 8.7877e-01, 9.5696e-01, 3.4313e-01, 5.3502e-01,
          7.4658e-01, 4.6924e-01, 2.3842e-01, 4.1500e-01, 9.6542e-01, 2.2353e-02,
          8.8304e-01, 2.8826e-01, 9.1021e-01, 8.2349e-01, 6.7239e-01, 4.8306e-01,
          8.0997e-01, 9.8335e-01, 5.4800e-01, 9.6262e-01, 1.4451e-01, 1.9257e-01,
          6.8108e-01, 8.7262e-01, 8.3641e-01, 6.0762e-01, 9.0738e-01, 2.3181e-01,
          4.5551e-01, 2.6621e-01, 6.8178e-01, 9.7016e-01, 7.4338e-01, 4.2097e-01,
          3.7421e-02, 7.7656e-01, 5.6367e-01, 6.7437e-01, 8.1986e-01, 9.4200e-01,
          7.3774e-01, 6.8157e-01, 6.0357e-01, 8.1498e-01, 2.8587e-01, 6.7951e-01,
          9.5034e-01, 8.1549e-01, 4.5450e-01, 9.3813e-01, 7.5757e-01, 3.3047e-01,
          5.9874e-01, 2.0421e-01, 6.6549e-01, 4.3022e-01, 3.2654e-01, 7.3410e-01,
          4.4518e-01, 1.7529e-01, 7.3638e-02, 9.4260e-01, 2.1356e-01, 9.4875e-01,
          8.3785e-01, 8.8319e-01, 9.2358e-01, 5.8857e-01, 7.2927e-01, 6.1131e-01,
          7.9429e-01, 3.6426e-01, 8.5038e-01, 6.4151e-01, 9.0555e-01, 7.9682e-01,
          3.4857e-01, 1.2264e-01, 2.0829e-01, 8.2768e-01, 3.9958e-01, 2.7688e-01,
          1.2426e-01, 9.4627e-01, 8.6441e-01, 7.2914e-01, 8.5059e-01, 3.5931e-01,
          7.5534e-01, 8.5380e-01, 8.9516e-01, 2.5662e-02, 1.3208e-01, 6.3183e-01,
          3.2510e-01, 5.9104e-01, 6.3298e-01, 4.7886e-02, 6.7838e-01, 1.1205e-01,
          2.7382e-01, 5.7415e-01, 5.6295e-01, 7.1518e-01, 4.0177e-01, 7.9128e-01,
          7.4653e-01, 2.4288e-01, 8.6657e-01, 6.1789e-01, 7.2088e-01, 8.5209e-01,
          2.5409e-01, 6.2985e-01, 2.9316e-01, 9.2840e-01, 8.3907e-01, 1.1965e-02,
          3.2071e-02, 8.2066e-01, 3.3579e-01, 3.4440e-01, 8.6374e-01, 5.7256e-01,
          2.3639e-01, 5.5647e-01, 8.9342e-01, 1.5993e-02, 1.5245e-02, 2.4741e-01,
          2.7773e-01, 3.2219e-01, 1.4215e-01, 1.6074e-01, 7.3274e-01, 8.4506e-01,
          7.9081e-02, 2.9455e-02, 4.3403e-01, 8.3863e-01, 2.2666e-01, 8.7737e-01,
          9.7102e-01, 8.7803e-01, 4.9161e-01, 5.6402e-01, 6.7922e-01, 3.6035e-02,
          4.9079e-02, 2.6811e-01, 5.1638e-01, 3.7400e-01, 1.1757e-01, 5.1941e-01,
          2.9464e-01, 7.4052e-01, 6.6815e-01, 6.8369e-02, 8.0137e-01, 6.2780e-01,
          1.4823e-01, 7.2719e-01, 5.1970e-01, 1.8128e-01, 2.1048e-01, 9.5972e-01,
          6.6425e-01, 3.2566e-01, 3.1512e-01, 8.3926e-01, 2.1656e-01, 4.5021e-01,
          4.9758e-01, 8.1101e-01, 7.0378e-01, 4.7874e-01, 7.7223e-02, 7.7517e-01,
          1.7812e-01, 1.9574e-01, 1.4003e-01, 1.6532e-01, 7.9411e-01, 1.8380e-01,
          2.8949e-01, 6.5157e-02, 5.3994e-01, 5.9100e-01, 5.6086e-01, 7.3259e-01,
          5.6936e-01, 3.8767e-01, 1.4458e-01, 6.9578e-01, 4.6396e-03, 4.1772e-01,
          1.7515e-01, 6.3771e-02, 7.2628e-01, 3.2371e-01, 5.9699e-01, 3.9950e-01,
          8.9041e-01, 4.5183e-01, 1.6112e-01, 6.7934e-01, 2.0272e-01, 7.7799e-01,
          3.4494e-01, 3.1030e-01, 1.7292e-01, 3.2287e-01, 3.0720e-01, 9.5505e-01,
          2.4141e-01, 5.1393e-01, 6.3598e-01, 8.3891e-01, 3.8464e-01, 4.5765e-01,
          8.7500e-02, 6.8632e-01, 4.6075e-02, 7.0483e-01, 4.8131e-01, 2.8899e-01,
          1.0429e-01, 8.3112e-01, 6.9843e-01, 7.4178e-01, 7.4304e-01, 4.4626e-01,
          3.0765e-01, 1.5243e-01, 1.2381e-01, 9.9301e-01, 6.2549e-01, 1.3343e-02,
          5.9856e-01, 2.5435e-01, 7.9039e-02, 3.1165e-01, 2.2349e-01, 2.4348e-04,
          1.5493e-01, 1.1544e-01, 8.0152e-01, 6.2459e-01, 5.9629e-01, 5.1095e-01,
          3.2502e-01, 7.4785e-01, 5.3887e-01, 2.0316e-01, 8.8245e-01, 6.1536e-01,
          8.3389e-01, 3.9671e-01, 1.3894e-01, 7.7571e-01, 5.1049e-01, 9.7087e-01,
          3.9999e-01, 8.4705e-01, 7.3011e-03, 4.4728e-01, 8.9888e-01, 8.8678e-01,
          6.1526e-01, 3.1312e-01, 7.8370e-01, 4.6366e-01, 1.5103e-01, 7.3700e-01,
          6.6630e-01, 7.7368e-01, 1.9991e-01, 4.7436e-01, 9.7392e-01, 6.2332e-02,
          4.6405e-01, 6.5271e-01, 2.0993e-01, 4.3995e-01, 2.8320e-01, 1.4868e-01,
          1.3027e-01, 1.8744e-01, 5.2102e-01, 6.3710e-02, 7.0005e-01, 5.1757e-01,
          1.5453e-01, 2.1893e-01, 2.2883e-02, 4.6362e-01, 4.9436e-01, 3.3523e-02,
          1.8064e-01, 4.2461e-02, 9.7187e-01, 5.6076e-01, 1.5483e-01, 5.8740e-01,
          1.0005e-01, 6.9574e-01, 4.8209e-01, 5.4963e-01, 6.5602e-01, 3.2378e-01,
          1.0627e-01, 6.2712e-01, 8.9167e-01, 3.1560e-01, 4.9958e-01, 1.5861e-01,
          9.3944e-01, 7.6696e-01, 6.9313e-01, 6.8548e-01, 6.9551e-01, 9.3898e-02,
          7.3607e-01, 6.5923e-01, 6.7167e-01, 6.9484e-01, 7.5993e-01, 8.2951e-02,
          8.0416e-01, 2.8756e-01, 2.2189e-03, 2.4566e-01, 7.2745e-01, 8.4095e-02,
          3.7950e-01, 6.5057e-01, 2.4089e-01, 6.5357e-01, 9.1263e-01, 6.4453e-01,
          1.0763e-01, 7.5557e-02]]),
 'this': torch.tensor([[2.3834e-01, 2.5048e-01, 9.1889e-02, 7.7154e-02, 4.8834e-01, 4.4117e-01,
          3.2949e-02, 1.8108e-01, 4.1188e-01, 9.4939e-02, 3.0998e-01, 2.1373e-01,
          1.2518e-01, 9.3681e-01, 5.3646e-02, 2.2375e-01, 9.6855e-01, 5.5084e-03,
          9.6530e-02, 7.9439e-01, 9.6744e-02, 7.2176e-01, 7.7603e-01, 2.7444e-01,
          8.7958e-01, 7.9074e-01, 2.1711e-01, 3.7925e-01, 7.9808e-01, 1.5514e-02,
          1.1102e-01, 1.8307e-01, 3.4134e-01, 2.1153e-01, 6.9718e-01, 5.0599e-01,
          1.7129e-01, 1.9581e-01, 1.2978e-01, 8.1371e-01, 2.8152e-01, 4.5804e-01,
          1.7619e-02, 1.3737e-01, 2.3440e-01, 7.5935e-01, 7.3699e-01, 2.6848e-01,
          9.7470e-01, 6.2552e-01, 8.8066e-01, 7.8964e-01, 5.2471e-01, 6.3017e-01,
          2.6484e-01, 2.4195e-01, 3.6620e-01, 9.5759e-01, 6.2869e-02, 4.4051e-01,
          2.0519e-01, 8.0272e-01, 1.0123e-01, 7.1511e-01, 7.1802e-01, 1.4923e-02,
          6.3503e-01, 2.8518e-01, 9.6902e-01, 6.4722e-01, 5.5112e-01, 6.6546e-01,
          4.0018e-01, 9.0346e-01, 9.5718e-01, 2.8202e-01, 7.6771e-01, 7.4522e-01,
          9.6753e-01, 4.5011e-01, 5.3971e-01, 1.5135e-01, 5.8569e-01, 7.4927e-02,
          6.7146e-01, 9.3391e-01, 2.0750e-01, 4.7298e-01, 8.6783e-02, 9.0324e-01,
          6.7819e-01, 4.1562e-01, 6.8531e-01, 3.9879e-01, 4.6090e-01, 2.4863e-01,
          8.4769e-01, 7.7058e-01, 6.3566e-01, 3.5460e-01, 4.2630e-01, 7.7159e-02,
          7.1251e-03, 3.2415e-01, 1.8531e-01, 5.3628e-01, 7.6089e-01, 5.6363e-01,
          2.9058e-01, 5.1506e-01, 5.7332e-01, 1.6020e-01, 1.3075e-01, 4.5874e-01,
          7.0538e-01, 4.7480e-01, 7.1244e-01, 1.5305e-01, 3.1989e-01, 7.0412e-01,
          9.8758e-01, 1.6023e-01, 1.1655e-01, 8.8878e-01, 8.9673e-01, 9.9173e-01,
          7.1017e-01, 1.4529e-01, 4.0468e-01, 9.8473e-01, 1.7083e-01, 4.1306e-01,
          8.5934e-01, 3.1694e-01, 8.3548e-01, 2.4473e-02, 2.6511e-01, 9.0026e-01,
          4.4355e-01, 1.8225e-01, 8.4104e-01, 2.3228e-01, 8.0243e-01, 4.2564e-02,
          6.0276e-01, 5.1179e-01, 4.0776e-01, 4.0208e-02, 3.8869e-01, 6.7575e-01,
          4.5273e-01, 3.3064e-01, 1.3681e-01, 4.9038e-01, 6.4764e-01, 5.4506e-01,
          3.2953e-01, 6.9231e-01, 6.5872e-01, 9.0196e-01, 9.8145e-01, 6.2322e-01,
          5.4279e-01, 2.0333e-01, 9.3420e-01, 6.8019e-01, 6.0696e-01, 3.8711e-01,
          1.8040e-01, 3.6956e-01, 6.1600e-01, 5.9196e-01, 4.3761e-01, 5.0209e-01,
          4.7001e-01, 9.1031e-02, 2.8849e-01, 1.2570e-01, 8.7897e-01, 7.7505e-01,
          2.1794e-01, 3.2482e-01, 7.0190e-01, 1.4723e-01, 1.8510e-03, 3.8921e-02,
          8.4882e-02, 9.6002e-01, 4.3158e-01, 7.2317e-01, 3.3874e-01, 5.0219e-01,
          4.8820e-01, 7.0739e-01, 6.4521e-01, 8.2328e-01, 6.8012e-02, 5.9618e-01,
          9.0562e-01, 8.0369e-01, 7.8883e-01, 8.7944e-02, 3.1810e-01, 4.3308e-01,
          4.1438e-01, 4.5590e-01, 1.6325e-02, 7.5163e-01, 8.1975e-01, 5.2752e-01,
          6.8534e-01, 5.8764e-01, 3.8127e-01, 6.8877e-01, 7.3490e-01, 2.9773e-01,
          8.6387e-02, 1.5450e-04, 7.7004e-01, 8.8278e-01, 3.7338e-01, 7.6670e-01,
          5.2876e-01, 1.3236e-01, 2.1873e-01, 4.4841e-01, 7.9660e-01, 4.1513e-01,
          4.0437e-01, 2.7567e-02, 9.8496e-02, 7.4830e-02, 8.3625e-01, 3.7104e-01,
          3.7507e-01, 5.1112e-01, 5.6286e-01, 2.4002e-01, 1.7788e-01, 7.9530e-01,
          5.9706e-01, 8.7680e-01, 2.4964e-01, 1.8672e-01, 6.2742e-01, 1.8672e-01,
          4.8149e-01, 3.8974e-01, 6.1149e-01, 2.3530e-01, 4.3513e-01, 9.2844e-01,
          8.7355e-01, 5.8490e-01, 3.8869e-01, 2.9123e-01, 2.2444e-01, 2.4557e-01,
          1.8971e-01, 6.3291e-01, 5.7825e-01, 1.9538e-01, 3.2718e-01, 8.0991e-01,
          2.4614e-01, 9.6690e-01, 2.5721e-01, 6.6186e-01, 6.4650e-01, 6.5364e-01,
          7.3134e-01, 5.7764e-02, 8.1875e-01, 8.7189e-01, 9.3024e-01, 9.1922e-03,
          6.4661e-01, 5.7329e-01, 5.1283e-01, 9.2082e-01, 2.8085e-01, 7.1727e-01,
          5.5221e-01, 2.1619e-01, 8.8908e-01, 1.4249e-01, 9.8798e-01, 6.7706e-02,
          6.1414e-02, 9.0230e-01, 4.7719e-01, 5.9354e-01, 2.2006e-01, 1.7854e-01,
          5.4807e-01, 1.6212e-01, 4.8793e-01, 9.1171e-02, 2.7598e-01, 9.2543e-01,
          9.0900e-01, 6.2575e-01, 2.7230e-01, 4.3957e-01, 5.3926e-01, 1.6238e-01,
          2.3286e-01, 4.0653e-01, 8.2895e-01, 3.8395e-01, 3.6242e-02, 2.6546e-01,
          6.9696e-01, 5.6493e-01, 3.6387e-01, 2.0331e-01, 2.9618e-01, 3.1240e-01,
          8.0438e-01, 1.4411e-02, 2.0690e-01, 4.9828e-02, 7.4176e-01, 8.9414e-01,
          1.6232e-01, 6.9453e-01, 2.4496e-01, 7.2695e-01, 3.1371e-01, 1.6423e-01,
          8.2990e-01, 7.6511e-01, 1.1560e-01, 6.6207e-01, 5.7184e-01, 8.7853e-01,
          3.6006e-01, 7.6209e-01, 8.7367e-01, 1.3602e-01, 1.6703e-01, 9.3932e-01,
          5.9761e-01, 1.1540e-01, 4.2187e-01, 6.2786e-01, 8.2692e-03, 8.3559e-01,
          5.3440e-01, 9.2945e-01, 9.4611e-01, 2.0628e-01, 2.9009e-01, 1.2663e-01,
          5.1369e-01, 6.4294e-01, 2.8032e-01, 4.7652e-01, 8.6540e-02, 9.8866e-01,
          7.8649e-01, 1.0309e-01, 8.6168e-01, 1.3252e-01, 7.0204e-01, 3.6252e-01,
          2.4521e-03, 3.4949e-01, 9.1321e-01, 3.7368e-01, 9.5342e-01, 6.6693e-01,
          4.0180e-01, 9.0954e-01, 4.5349e-01, 1.8211e-01, 3.8974e-01, 8.7656e-01,
          8.2184e-01, 2.5087e-01, 2.0753e-01, 5.1800e-01, 9.0333e-01, 7.7790e-01,
          6.9512e-01, 9.7000e-01, 5.2832e-01, 7.4470e-01, 4.9449e-01, 3.9221e-01,
          4.1792e-01, 6.2521e-01, 3.7504e-01, 3.4006e-01, 9.4855e-01, 1.4327e-01,
          8.7429e-01, 4.5282e-01, 6.4084e-01, 3.1007e-01, 8.8488e-01, 8.3841e-01,
          4.0371e-01, 8.6133e-02, 4.1151e-02, 2.8639e-01, 8.4822e-01, 8.3192e-01,
          3.4426e-01, 8.6391e-01, 4.3782e-01, 8.5459e-01, 2.2976e-01, 5.6649e-01,
          5.7197e-01, 6.5351e-01, 6.4927e-01, 8.0816e-02, 3.5992e-01, 3.2558e-01,
          3.0448e-01, 5.6042e-01, 9.2897e-01, 7.4001e-01, 9.9393e-02, 9.1969e-01,
          8.1155e-01, 5.0638e-01, 9.9376e-01, 1.1867e-01, 4.3993e-01, 3.9479e-01,
          7.1838e-01, 4.5425e-01, 5.9763e-01, 2.3276e-01, 9.8227e-01, 5.1888e-01,
          3.8313e-01, 2.9809e-01, 7.3204e-01, 1.7975e-01, 7.7474e-01, 8.1722e-02,
          9.1943e-01, 6.2019e-02, 6.8040e-01, 4.1891e-02, 1.8909e-01, 3.6659e-01,
          1.5216e-01, 8.3889e-01, 6.6816e-01, 3.8773e-01, 1.7230e-01, 8.7945e-01,
          8.9937e-01, 9.9909e-01, 1.9301e-01, 8.4865e-01, 3.1516e-01, 6.4874e-01,
          2.0096e-01, 7.8183e-01, 8.3092e-01, 1.6649e-01, 5.2124e-01, 9.4041e-01,
          3.7401e-01, 4.5719e-01, 1.9411e-01, 8.7794e-01, 9.8365e-01, 8.0980e-01,
          4.3908e-01, 1.3309e-01, 1.7369e-01, 7.9279e-01, 1.8691e-01, 9.0799e-01,
          8.7665e-01, 4.9022e-01, 9.6228e-01, 4.8299e-01, 6.2248e-01, 6.7255e-01,
          9.3441e-01, 1.0349e-01, 2.4897e-02, 4.1212e-01, 6.3234e-01, 2.5786e-01,
          3.3592e-01, 5.9787e-01, 2.9734e-01, 1.5933e-01, 6.9656e-01, 3.3539e-01,
          6.0382e-01, 4.0489e-02, 9.4427e-01, 7.0399e-01, 9.7509e-01, 2.7535e-01,
          6.0003e-01, 2.1621e-01, 3.7047e-01, 2.9239e-01, 8.3524e-01, 2.8134e-01,
          6.4615e-01, 4.4557e-01]]),
 'product': torch.tensor([[0.5108, 0.7993, 0.8076, 0.5352, 0.6881, 0.1335, 0.8137, 0.4771, 0.3603,
          0.0385, 0.8816, 0.9965, 0.6874, 0.2078, 0.7019, 0.4928, 0.2177, 0.1466,
          0.5589, 0.7569, 0.5678, 0.7334, 0.6665, 0.5072, 0.0926, 0.5915, 0.3222,
          0.0314, 0.1826, 0.3329, 0.6625, 0.9319, 0.8918, 0.9692, 0.1958, 0.2499,
          0.2114, 0.1061, 0.3678, 0.6517, 0.3705, 0.2774, 0.1997, 0.1722, 0.4375,
          0.0560, 0.1735, 0.6355, 0.8052, 0.7221, 0.3530, 0.4905, 0.3837, 0.3587,
          0.1871, 0.8423, 0.0070, 0.5395, 0.7842, 0.7795, 0.9685, 0.4980, 0.0770,
          0.0802, 0.6175, 0.0924, 0.6124, 0.1134, 0.1233, 0.3998, 0.8811, 0.1805,
          0.1826, 0.8242, 0.0356, 0.6864, 0.7735, 0.5652, 0.7355, 0.8956, 0.1143,
          0.5132, 0.3144, 0.5012, 0.7095, 0.7166, 0.4243, 0.0551, 0.8649, 0.4434,
          0.9055, 0.9376, 0.0852, 0.0085, 0.3880, 0.5659, 0.0826, 0.2086, 0.6840,
          0.2655, 0.9274, 0.9774, 0.0433, 0.5743, 0.0316, 0.3999, 0.0362, 0.2083,
          0.3480, 0.4323, 0.3134, 0.7161, 0.5370, 0.3315, 0.8539, 0.0808, 0.5527,
          0.1707, 0.0943, 0.6615, 0.8986, 0.3749, 0.4048, 0.1727, 0.7543, 0.6848,
          0.5429, 0.1158, 0.1485, 0.4566, 0.3118, 0.8719, 0.8251, 0.1197, 0.4378,
          0.9718, 0.9142, 0.5751, 0.2795, 0.2541, 0.3485, 0.4654, 0.2810, 0.4617,
          0.9448, 0.6992, 0.2232, 0.1509, 0.9022, 0.8340, 0.5495, 0.8270, 0.9021,
          0.7553, 0.4747, 0.2773, 0.9246, 0.5678, 0.5317, 0.8261, 0.7073, 0.8788,
          0.6781, 0.6795, 0.6281, 0.5670, 0.2893, 0.6563, 0.5372, 0.7557, 0.0528,
          0.0132, 0.4193, 0.2834, 0.0314, 0.9816, 0.0946, 0.9042, 0.9212, 0.7647,
          0.7443, 0.2062, 0.1724, 0.1622, 0.9054, 0.9625, 0.6267, 0.0988, 0.4917,
          0.9187, 0.4653, 0.0321, 0.8561, 0.5118, 0.0064, 0.7592, 0.8740, 0.0782,
          0.1019, 0.4344, 0.8520, 0.4809, 0.0024, 0.7326, 0.2159, 0.9625, 0.8401,
          0.6395, 0.1883, 0.7577, 0.1200, 0.1415, 0.0111, 0.7296, 0.5919, 0.0056,
          0.3379, 0.2009, 0.1173, 0.1305, 0.9759, 0.1268, 0.6186, 0.4308, 0.6682,
          0.2766, 0.4535, 0.0437, 0.4206, 0.8386, 0.3333, 0.5122, 0.8233, 0.6864,
          0.2949, 0.3040, 0.6765, 0.2892, 0.2207, 0.0527, 0.4469, 0.4885, 0.3977,
          0.9217, 0.7529, 0.6436, 0.7286, 0.7513, 0.9788, 0.8396, 0.6648, 0.7731,
          0.5094, 0.8019, 0.8974, 0.5564, 0.9813, 0.7105, 0.0341, 0.4003, 0.3161,
          0.0876, 0.1631, 0.6662, 0.2070, 0.5212, 0.5279, 0.8915, 0.4153, 0.4827,
          0.5655, 0.9799, 0.0035, 0.1689, 0.1773, 0.9576, 0.7220, 0.7918, 0.1595,
          0.4797, 0.4315, 0.3080, 0.4786, 0.2525, 0.4901, 0.0867, 0.6684, 0.4075,
          0.3058, 0.2165, 0.9092, 0.2604, 0.9786, 0.7664, 0.2160, 0.8083, 0.2175,
          0.4576, 0.2821, 0.2615, 0.6921, 0.2060, 0.6388, 0.2936, 0.5950, 0.0716,
          0.3853, 0.0776, 0.0675, 0.6603, 0.4530, 0.9902, 0.9847, 0.3668, 0.1783,
          0.0186, 0.4156, 0.4535, 0.0346, 0.7022, 0.2839, 0.0680, 0.2701, 0.2465,
          0.5444, 0.9480, 0.3222, 0.6444, 0.2507, 0.2428, 0.3895, 0.8633, 0.4557,
          0.6273, 0.3195, 0.0766, 0.3592, 0.6771, 0.3942, 0.8866, 0.1908, 0.9203,
          0.1920, 0.5709, 0.1080, 0.4631, 0.0551, 0.9140, 0.9144, 0.3181, 0.1457,
          0.0767, 0.0625, 0.1011, 0.2256, 0.2607, 0.8497, 0.3363, 0.1066, 0.4194,
          0.2146, 0.4918, 0.1768, 0.5998, 0.9880, 0.7018, 0.3286, 0.8675, 0.4083,
          0.5989, 0.4296, 0.0787, 0.8425, 0.9654, 0.4749, 0.1796, 0.2707, 0.4878,
          0.9066, 0.6144, 0.3793, 0.9710, 0.2659, 0.1688, 0.4766, 0.9390, 0.9746,
          0.6612, 0.5210, 0.8174, 0.0817, 0.4480, 0.3723, 0.8497, 0.7554, 0.7120,
          0.0507, 0.1371, 0.1812, 0.5930, 0.1883, 0.5102, 0.2088, 0.9384, 0.9120,
          0.6411, 0.1707, 0.0505, 0.9953, 0.8755, 0.9990, 0.6896, 0.8225, 0.1086,
          0.8288, 0.1703, 0.8176, 0.8255, 0.7759, 0.8922, 0.7665, 0.8233, 0.9064,
          0.7186, 0.8493, 0.4159, 0.2685, 0.8269, 0.1548, 0.8270, 0.2821, 0.2721,
          0.4116, 0.7817, 0.1506, 0.6916, 0.8400, 0.9705, 0.7171, 0.3510, 0.7372,
          0.6776, 0.5449, 0.2135, 0.2012, 0.2945, 0.3784, 0.0802, 0.8924, 0.4532,
          0.1245, 0.4633, 0.7195, 0.0272, 0.5921, 0.9748, 0.0303, 0.5029, 0.7887,
          0.0270, 0.2720, 0.7939, 0.3954, 0.2153, 0.2720, 0.2541, 0.6063, 0.4117,
          0.9491, 0.8256, 0.6818, 0.8923, 0.6214, 0.6204, 0.6143, 0.4884, 0.5305,
          0.3288, 0.1548, 0.0552, 0.6427, 0.9119, 0.0478, 0.2553, 0.4950, 0.1361,
          0.1141, 0.4404, 0.5236, 0.5245, 0.0885, 0.9483, 0.4375, 0.5139, 0.3130,
          0.6574, 0.7260, 0.7809, 0.7781, 0.8147, 0.8539, 0.1322, 0.0999, 0.5371,
          0.2843, 0.3520, 0.5735, 0.5253, 0.0748, 0.8882, 0.9587, 0.4700]]),
 'This': torch.tensor([[0.2248, 0.4858, 0.0610, 0.0072, 0.2561, 0.7310, 0.6469, 0.6411, 0.4331,
          0.6399, 0.1656, 0.8470, 0.7747, 0.4729, 0.0839, 0.6056, 0.0808, 0.9941,
          0.7675, 0.1654, 0.0089, 0.3187, 0.4696, 0.7128, 0.2816, 0.2717, 0.5359,
          0.0413, 0.3954, 0.9961, 0.7721, 0.4449, 0.3185, 0.9173, 0.9555, 0.2784,
          0.6115, 0.8476, 0.6630, 0.7509, 0.2549, 0.4932, 0.2792, 0.0388, 0.2209,
          0.6950, 0.4984, 0.0718, 0.7660, 0.6148, 0.7144, 0.3288, 0.4364, 0.8557,
          0.1786, 0.0506, 0.9551, 0.5597, 0.3935, 0.2922, 0.5388, 0.5003, 0.3378,
          0.4471, 0.8627, 0.3171, 0.2250, 0.6224, 0.3919, 0.4558, 0.8042, 0.4875,
          0.9772, 0.8102, 0.6468, 0.9453, 0.6968, 0.5684, 0.5672, 0.8795, 0.0875,
          0.2888, 0.5695, 0.6052, 0.3755, 0.4399, 0.5742, 0.6131, 0.3240, 0.9283,
          0.0588, 0.4731, 0.2398, 0.6862, 0.8133, 0.1043, 0.5347, 0.1752, 0.7478,
          0.5152, 0.4210, 0.4048, 0.8671, 0.6170, 0.1705, 0.2661, 0.0923, 0.8033,
          0.1079, 0.6451, 0.1133, 0.9510, 0.6895, 0.2633, 0.9034, 0.8869, 0.9177,
          0.8553, 0.1942, 0.0898, 0.6227, 0.5198, 0.6848, 0.2466, 0.7018, 0.8542,
          0.8716, 0.0312, 0.6517, 0.0608, 0.2671, 0.8671, 0.9678, 0.7249, 0.7643,
          0.8332, 0.6799, 0.0362, 0.3332, 0.3639, 0.2020, 0.5066, 0.9132, 0.4054,
          0.0706, 0.3994, 0.9586, 0.5017, 0.3218, 0.3318, 0.3664, 0.9906, 0.5814,
          0.7940, 0.3585, 0.6299, 0.4868, 0.0491, 0.1595, 0.9084, 0.6982, 0.2737,
          0.5417, 0.7566, 0.3013, 0.3690, 0.3418, 0.3367, 0.4075, 0.2819, 0.7287,
          0.1821, 0.5974, 0.1119, 0.0850, 0.6994, 0.9105, 0.2528, 0.2019, 0.8415,
          0.2500, 0.0597, 0.0018, 0.6785, 0.1260, 0.6017, 0.0370, 0.7476, 0.6227,
          0.6628, 0.2159, 0.4813, 0.4184, 0.6922, 0.2708, 0.4281, 0.2135, 0.6178,
          0.7389, 0.9249, 0.7923, 0.0028, 0.0245, 0.6237, 0.3240, 0.2183, 0.5664,
          0.8640, 0.5382, 0.3356, 0.9940, 0.0656, 0.4319, 0.2010, 0.8589, 0.3636,
          0.2364, 0.2254, 0.6884, 0.1014, 0.2585, 0.2993, 0.0238, 0.7983, 0.8561,
          0.0887, 0.1851, 0.2525, 0.9084, 0.2051, 0.3051, 0.2598, 0.9531, 0.1940,
          0.4657, 0.8140, 0.4549, 0.2485, 0.6583, 0.6890, 0.5616, 0.3892, 0.3330,
          0.9846, 0.1303, 0.6202, 0.9521, 0.9663, 0.4305, 0.8687, 0.3405, 0.0993,
          0.7113, 0.9898, 0.4816, 0.9414, 0.2424, 0.9827, 0.3743, 0.3154, 0.7324,
          0.8021, 0.3423, 0.7664, 0.7707, 0.5037, 0.1781, 0.7815, 0.6518, 0.2870,
          0.0881, 0.4409, 0.9458, 0.7662, 0.3114, 0.1732, 0.9724, 0.1104, 0.4479,
          0.1833, 0.8241, 0.9095, 0.7091, 0.6992, 0.5029, 0.1564, 0.3943, 0.3606,
          0.5206, 0.1025, 0.8886, 0.2387, 0.3508, 0.7941, 0.0946, 0.9677, 0.2328,
          0.5115, 0.9790, 0.2842, 0.4362, 0.1328, 0.5611, 0.9872, 0.9939, 0.9684,
          0.7537, 0.0578, 0.4865, 0.3991, 0.4752, 0.4808, 0.0492, 0.6284, 0.1333,
          0.0776, 0.5015, 0.3144, 0.8844, 0.8478, 0.2169, 0.2225, 0.6585, 0.6769,
          0.0182, 0.5547, 0.1667, 0.9435, 0.6420, 0.6871, 0.3862, 0.3084, 0.3961,
          0.6315, 0.2913, 0.3155, 0.0387, 0.5785, 0.6469, 0.5161, 0.9279, 0.6305,
          0.8677, 0.2347, 0.2828, 0.5034, 0.7280, 0.2909, 0.4102, 0.8347, 0.9537,
          0.8334, 0.6135, 0.0560, 0.6534, 0.6173, 0.1142, 0.8485, 0.5866, 0.7655,
          0.8633, 0.4718, 0.3233, 0.0745, 0.8405, 0.6677, 0.3787, 0.0750, 0.7488,
          0.6179, 0.6625, 0.1893, 0.8404, 0.8891, 0.6379, 0.9141, 0.4587, 0.7793,
          0.1907, 0.6993, 0.1517, 0.7396, 0.7155, 0.6754, 0.5097, 0.5558, 0.2970,
          0.8431, 0.1765, 0.8117, 0.4339, 0.6287, 0.9208, 0.4443, 0.3363, 0.9762,
          0.8655, 0.0505, 0.0356, 0.3316, 0.5330, 0.8179, 0.9641, 0.4843, 0.9154,
          0.8642, 0.2807, 0.2142, 0.3661, 0.1049, 0.7550, 0.6954, 0.6302, 0.2846,
          0.7754, 0.2950, 0.3257, 0.6207, 0.9584, 0.7062, 0.8902, 0.2497, 0.8398,
          0.9634, 0.8687, 0.4274, 0.0893, 0.0214, 0.6714, 0.7177, 0.3361, 0.1432,
          0.3087, 0.0540, 0.7353, 0.1374, 0.1679, 0.1501, 0.1812, 0.2332, 0.0484,
          0.1901, 0.5976, 0.6068, 0.8271, 0.5582, 0.0960, 0.7159, 0.8361, 0.3635,
          0.3741, 0.3156, 0.1266, 0.9314, 0.6437, 0.3529, 0.7109, 0.6901, 0.9831,
          0.4754, 0.8258, 0.4818, 0.9382, 0.4374, 0.1871, 0.5238, 0.9499, 0.2324,
          0.9481, 0.3602, 0.0429, 0.1903, 0.2871, 0.8435, 0.7431, 0.5332, 0.9504,
          0.9968, 0.9339, 0.1620, 0.7451, 0.5944, 0.9011, 0.2341, 0.1349, 0.3809,
          0.6127, 0.5068, 0.5659, 0.4173, 0.3846, 0.1351, 0.6232, 0.4660, 0.3224,
          0.2920, 0.7953, 0.2064, 0.7803, 0.4722, 0.1676, 0.6890, 0.5277, 0.4861,
          0.0038, 0.4383, 0.4843, 0.4675, 0.3551, 0.7525, 0.4248, 0.6625]]),
 'is': torch.tensor([[0.0026, 0.8060, 0.7927, 0.0784, 0.9617, 0.3939, 0.1783, 0.6040, 0.0124,
          0.3951, 0.3400, 0.6844, 0.8172, 0.3113, 0.1920, 0.9447, 0.9491, 0.8553,
          0.1697, 0.2167, 0.7509, 0.8408, 0.4000, 0.6712, 0.5125, 0.5933, 0.9376,
          0.1690, 0.9770, 0.2244, 0.5283, 0.6171, 0.0650, 0.9045, 0.5266, 0.5392,
          0.5300, 0.8904, 0.4243, 0.2166, 0.6814, 0.3655, 0.9260, 0.7236, 0.8685,
          0.1579, 0.5437, 0.1139, 0.4532, 0.8535, 0.8615, 0.0641, 0.9755, 0.6644,
          0.7966, 0.4300, 0.1444, 0.9832, 0.7803, 0.7099, 0.1849, 0.6074, 0.2968,
          0.2390, 0.9256, 0.5090, 0.5335, 0.5682, 0.8080, 0.7092, 0.5111, 0.2262,
          0.6341, 0.8827, 0.6711, 0.5110, 0.9956, 0.3688, 0.0581, 0.9938, 0.4476,
          0.6758, 0.6678, 0.2662, 0.4976, 0.5128, 0.6120, 0.5594, 0.6962, 0.6923,
          0.9291, 0.8710, 0.2456, 0.2003, 0.4130, 0.0650, 0.2300, 0.9542, 0.1111,
          0.4843, 0.5536, 0.2206, 0.9171, 0.9699, 0.2396, 0.1387, 0.4855, 0.9057,
          0.5814, 0.2292, 0.6864, 0.8850, 0.4129, 0.7262, 0.7160, 0.9010, 0.7174,
          0.9225, 0.7992, 0.4639, 0.7217, 0.3371, 0.5990, 0.2081, 0.9401, 0.7332,
          0.8140, 0.9540, 0.2875, 0.1549, 0.9937, 0.2203, 0.7481, 0.3182, 0.5123,
          0.1044, 0.0720, 0.1171, 0.4433, 0.0567, 0.4334, 0.8900, 0.1516, 0.3533,
          0.7828, 0.3059, 0.8707, 0.4107, 0.3093, 0.2446, 0.3520, 0.6471, 0.1537,
          0.1814, 0.0406, 0.9700, 0.4989, 0.0611, 0.2780, 0.2569, 0.6084, 0.6924,
          0.1736, 0.9758, 0.4097, 0.2306, 0.9215, 0.6930, 0.7001, 0.0696, 0.2467,
          0.8027, 0.1521, 0.1544, 0.9697, 0.4959, 0.2810, 0.6840, 0.7083, 0.4961,
          0.0463, 0.2415, 0.3513, 0.0083, 0.1810, 0.8128, 0.1221, 0.8114, 0.8781,
          0.9532, 0.2621, 0.2273, 0.5627, 0.7622, 0.4741, 0.1379, 0.4840, 0.3044,
          0.3298, 0.6728, 0.8695, 0.4655, 0.9724, 0.3189, 0.8980, 0.9873, 0.8151,
          0.9559, 0.3048, 0.9843, 0.2951, 0.0903, 0.8958, 0.3684, 0.3172, 0.5664,
          0.6841, 0.2787, 0.7804, 0.9562, 0.6620, 0.7012, 0.7573, 0.4819, 0.8010,
          0.1517, 0.9660, 0.0963, 0.9154, 0.6214, 0.7684, 0.4629, 0.0792, 0.7194,
          0.3256, 0.6697, 0.7902, 0.9090, 0.4473, 0.6023, 0.2114, 0.5108, 0.0973,
          0.9818, 0.7036, 0.7295, 0.4361, 0.7554, 0.1709, 0.9639, 0.1931, 0.5587,
          0.2502, 0.3586, 0.3924, 0.9617, 0.0213, 0.0905, 0.2491, 0.6512, 0.1091,
          0.3362, 0.9086, 0.1172, 0.3735, 0.4711, 0.9003, 0.8515, 0.8931, 0.7978,
          0.6633, 0.8591, 0.0581, 0.2238, 0.0095, 0.9259, 0.0119, 0.6266, 0.9093,
          0.8172, 0.6359, 0.0285, 0.4835, 0.3080, 0.9961, 0.4698, 0.9505, 0.1868,
          0.8149, 0.8074, 0.3978, 0.4487, 0.0574, 0.2969, 0.8122, 0.8513, 0.9715,
          0.8798, 0.7094, 0.3759, 0.6082, 0.9613, 0.7087, 0.4967, 0.7506, 0.4810,
          0.8156, 0.2568, 0.3060, 0.4520, 0.6947, 0.3110, 0.5286, 0.0745, 0.9135,
          0.5526, 0.8481, 0.7803, 0.5323, 0.7076, 0.0488, 0.4188, 0.9791, 0.1409,
          0.0323, 0.2377, 0.4272, 0.0680, 0.4589, 0.0490, 0.8103, 0.1941, 0.4011,
          0.1293, 0.1687, 0.7422, 0.5197, 0.7510, 0.8835, 0.9634, 0.4551, 0.5655,
          0.7575, 0.7493, 0.8119, 0.2747, 0.5500, 0.6550, 0.8824, 0.1122, 0.4636,
          0.2849, 0.1003, 0.4689, 0.7917, 0.3727, 0.3250, 0.1077, 0.4403, 0.9268,
          0.2544, 0.3251, 0.9788, 0.0487, 0.9143, 0.2875, 0.0252, 0.8565, 0.8033,
          0.5525, 0.4248, 0.0762, 0.0527, 0.6032, 0.0013, 0.1244, 0.5527, 0.7574,
          0.9782, 0.2310, 0.0321, 0.3405, 0.1030, 0.9392, 0.7279, 0.2065, 0.9110,
          0.9081, 0.3298, 0.2100, 0.7566, 0.6685, 0.9885, 0.1330, 0.3973, 0.8027,
          0.8990, 0.4052, 0.5396, 0.5343, 0.4880, 0.7448, 0.5016, 0.2536, 0.5969,
          0.2527, 0.2544, 0.2929, 0.6525, 0.9329, 0.0826, 0.1138, 0.4918, 0.9369,
          0.0952, 0.0402, 0.7507, 0.1637, 0.9818, 0.6416, 0.3913, 0.9609, 0.9263,
          0.0754, 0.6493, 0.1704, 0.5153, 0.1942, 0.2214, 0.2341, 0.6186, 0.3534,
          0.3202, 0.2680, 0.6710, 0.2278, 0.6716, 0.2177, 0.3136, 0.3609, 0.4507,
          0.0464, 0.9182, 0.8906, 0.4243, 0.2050, 0.7149, 0.8892, 0.1937, 0.7806,
          0.9846, 0.9497, 0.2632, 0.1738, 0.3715, 0.1478, 0.0643, 0.4642, 0.1080,
          0.8413, 0.0684, 0.3270, 0.4696, 0.6147, 0.0384, 0.6408, 0.4687, 0.4358,
          0.4708, 0.7737, 0.5283, 0.3253, 0.8622, 0.3254, 0.7229, 0.5068, 0.4733,
          0.9962, 0.2806, 0.3127, 0.4211, 0.1972, 0.0972, 0.2623, 0.8703, 0.8119,
          0.3626, 0.4270, 0.0995, 0.2329, 0.8120, 0.2760, 0.8255, 0.9049, 0.0337,
          0.4108, 0.1214, 0.8381, 0.2178, 0.3745, 0.7456, 0.6328, 0.6773, 0.1513,
          0.3442, 0.4791, 0.2585, 0.7205, 0.0941, 0.9816, 0.9795, 0.8004]]),
 'terrible': torch.tensor([[0.5523, 0.4149, 0.0186, 0.2560, 0.2986, 0.0754, 0.0296, 0.5580, 0.8416,
          0.1429, 0.5841, 0.6495, 0.1802, 0.7146, 0.3612, 0.2959, 0.7153, 0.4348,
          0.0517, 0.2159, 0.7591, 0.3811, 0.0253, 0.8596, 0.3041, 0.7413, 0.8322,
          0.3177, 0.5817, 0.2567, 0.8583, 0.6845, 0.1650, 0.7577, 0.2362, 0.2992,
          0.1458, 0.2599, 0.7303, 0.6802, 0.5333, 0.5370, 0.9252, 0.8683, 0.2681,
          0.2785, 0.7669, 0.8508, 0.7444, 0.9648, 0.8511, 0.3789, 0.9473, 0.3583,
          0.1244, 0.5066, 0.3170, 0.3415, 0.2545, 0.4864, 0.7063, 0.8559, 0.3086,
          0.1759, 0.3919, 0.7776, 0.3229, 0.6943, 0.4784, 0.1552, 0.1232, 0.2487,
          0.5579, 0.0738, 0.8034, 0.5057, 0.2114, 0.0857, 0.6179, 0.3921, 0.7421,
          0.3360, 0.3470, 0.1018, 0.9532, 0.9754, 0.7582, 0.1067, 0.1988, 0.3016,
          0.9997, 0.3150, 0.0052, 0.0622, 0.7997, 0.1646, 0.8734, 0.8571, 0.3925,
          0.7243, 0.3955, 0.2342, 0.6386, 0.8146, 0.4044, 0.1372, 0.4087, 0.0556,
          0.2570, 0.6860, 0.8734, 0.5767, 0.3946, 0.5609, 0.6867, 0.8558, 0.7364,
          0.5127, 0.6551, 0.0336, 0.8470, 0.1011, 0.3193, 0.9532, 0.3510, 0.0291,
          0.6345, 0.8233, 0.9885, 0.8543, 0.6251, 0.4637, 0.3407, 0.1970, 0.4628,
          0.4314, 0.5429, 0.0200, 0.2057, 0.2699, 0.2603, 0.7731, 0.4188, 0.2349,
          0.3554, 0.7730, 0.6092, 0.8464, 0.8754, 0.0792, 0.0996, 0.8236, 0.1241,
          0.3484, 0.3874, 0.7867, 0.8613, 0.0679, 0.8443, 0.4321, 0.3383, 0.4456,
          0.6620, 0.5558, 0.8730, 0.6300, 0.4126, 0.8802, 0.9023, 0.8029, 0.2289,
          0.9616, 0.7171, 0.3805, 0.1273, 0.4762, 0.6380, 0.1074, 0.7116, 0.7568,
          0.9664, 0.4555, 0.4461, 0.2980, 0.5538, 0.9447, 0.4658, 0.5033, 0.7186,
          0.3986, 0.8468, 0.4561, 0.7683, 0.5046, 0.3105, 0.6091, 0.4414, 0.9287,
          0.3114, 0.5034, 0.4947, 0.0928, 0.5987, 0.9878, 0.7708, 0.2167, 0.3314,
          0.4791, 0.1962, 0.2097, 0.8418, 0.8454, 0.7949, 0.7993, 0.5623, 0.7008,
          0.5998, 0.6374, 0.7868, 0.7646, 0.2125, 0.2404, 0.6368, 0.1608, 0.4637,
          0.5356, 0.5560, 0.0533, 0.8781, 0.8861, 0.8839, 0.8157, 0.2287, 0.8870,
          0.3179, 0.6259, 0.8521, 0.2439, 0.4309, 0.3295, 0.0721, 0.6607, 0.1049,
          0.4648, 0.7879, 0.4765, 0.1751, 0.3233, 0.4474, 0.0408, 0.1717, 0.8864,
          0.7225, 0.8781, 0.7528, 0.6790, 0.4865, 0.6397, 0.4553, 0.6954, 0.3503,
          0.0279, 0.0370, 0.2373, 0.1352, 0.1834, 0.9798, 0.7214, 0.6153, 0.9547,
          0.8362, 0.1957, 0.8715, 0.4910, 0.3479, 0.7633, 0.7959, 0.0768, 0.6532,
          0.6756, 0.1833, 0.7756, 0.3782, 0.8885, 0.1601, 0.6248, 0.7914, 0.2447,
          0.9643, 0.1641, 0.1960, 0.3142, 0.8497, 0.7692, 0.4148, 0.5645, 0.2861,
          0.8301, 0.5275, 0.4069, 0.2255, 0.4588, 0.9161, 0.0212, 0.9594, 0.5608,
          0.2856, 0.4118, 0.5717, 0.4053, 0.3725, 0.6181, 0.8097, 0.2320, 0.6827,
          0.4311, 0.5850, 0.8171, 0.9098, 0.4531, 0.6049, 0.4428, 0.0786, 0.9915,
          0.4767, 0.9512, 0.9415, 0.0566, 0.2261, 0.1614, 0.5240, 0.2939, 0.1055,
          0.8272, 0.4679, 0.1289, 0.0561, 0.5763, 0.2815, 0.3537, 0.3647, 0.2792,
          0.6071, 0.6311, 0.2997, 0.3510, 0.5364, 0.6956, 0.5540, 0.5275, 0.6332,
          0.9663, 0.1408, 0.0908, 0.2963, 0.5781, 0.8933, 0.2211, 0.1359, 0.7311,
          0.6622, 0.6811, 0.7278, 0.4269, 0.4792, 0.3314, 0.9913, 0.4176, 0.8088,
          0.7123, 0.2219, 0.1081, 0.6701, 0.9128, 0.8742, 0.7799, 0.1087, 0.7602,
          0.4765, 0.5599, 0.3778, 0.7835, 0.7416, 0.4967, 0.8465, 0.6610, 0.0118,
          0.4012, 0.0807, 0.6654, 0.3300, 0.9447, 0.8444, 0.9952, 0.6089, 0.8506,
          0.0272, 0.0817, 0.9477, 0.6054, 0.5720, 0.1311, 0.2826, 0.2175, 0.9731,
          0.9307, 0.9070, 0.5332, 0.4569, 0.8323, 0.8524, 0.2591, 0.3923, 0.0861,
          0.3480, 0.5497, 0.9886, 0.7892, 0.0666, 0.7629, 0.3936, 0.9050, 0.5102,
          0.8146, 0.8673, 0.6551, 0.9025, 0.5316, 0.9759, 0.0890, 0.8204, 0.0063,
          0.2444, 0.6867, 0.9873, 0.8542, 0.7711, 0.2343, 0.0126, 0.7447, 0.0585,
          0.8468, 0.9256, 0.1251, 0.3194, 0.4010, 0.8175, 0.5524, 0.7526, 0.2585,
          0.1968, 0.2391, 0.4596, 0.7381, 0.4674, 0.9282, 0.7560, 0.9033, 0.1105,
          0.1124, 0.9345, 0.9601, 0.4350, 0.2009, 0.6215, 0.3788, 0.3866, 0.1059,
          0.6639, 0.2691, 0.9363, 0.0640, 0.3698, 0.5512, 0.2813, 0.4612, 0.3614,
          0.9617, 0.1098, 0.3573, 0.7476, 0.2778, 0.4446, 0.4800, 0.5695, 0.9733,
          0.7856, 0.5769, 0.2235, 0.8669, 0.8972, 0.4727, 0.6187, 0.7987, 0.5810,
          0.4046, 0.0150, 0.8423, 0.1119, 0.4724, 0.0362, 0.7425, 0.2532, 0.9425,
          0.8911, 0.6786, 0.6890, 0.4385, 0.3932, 0.2724, 0.5307, 0.5135]]),
 'Could': torch.tensor([[1.8315e-01, 1.4335e-01, 1.3423e-01, 6.6931e-01, 7.5425e-01, 2.9169e-01,
          4.2456e-01, 9.4742e-01, 7.6759e-01, 2.8962e-01, 1.0699e-01, 5.4077e-01,
          8.2682e-01, 5.3464e-01, 8.0377e-01, 9.0830e-01, 1.7134e-01, 4.0039e-01,
          5.4113e-01, 1.8482e-01, 5.0454e-01, 8.0079e-01, 3.5137e-01, 8.3197e-01,
          8.1924e-02, 1.4373e-01, 8.5260e-01, 3.4334e-01, 2.5339e-01, 8.6552e-01,
          1.6757e-01, 5.9924e-01, 7.8876e-02, 7.1025e-01, 9.0658e-02, 5.4242e-01,
          6.1267e-02, 7.3266e-01, 2.3330e-01, 8.9600e-01, 6.8027e-01, 4.1855e-01,
          1.1461e-01, 8.8274e-01, 3.1031e-02, 6.0669e-01, 6.1448e-01, 3.9792e-01,
          4.0767e-02, 5.6953e-01, 4.5928e-01, 6.0103e-01, 9.0842e-02, 2.1534e-01,
          5.8592e-01, 1.7863e-02, 6.0003e-01, 6.5383e-01, 6.6990e-01, 8.0898e-01,
          6.2625e-01, 1.7709e-01, 9.9022e-02, 9.3180e-03, 1.5291e-02, 1.5776e-01,
          9.0039e-04, 1.7730e-01, 1.2565e-01, 7.0286e-01, 1.4087e-01, 3.7857e-01,
          5.1320e-01, 5.7794e-01, 2.9110e-01, 6.8667e-01, 9.8989e-01, 6.1238e-01,
          3.1135e-02, 6.4720e-02, 2.9514e-01, 8.8598e-01, 4.0626e-01, 2.6915e-01,
          5.2302e-01, 2.6997e-02, 2.8319e-01, 2.0869e-01, 8.6057e-01, 6.6321e-02,
          9.4918e-01, 1.4964e-01, 3.3833e-01, 7.7162e-01, 2.9781e-01, 7.0799e-01,
          9.1179e-01, 3.6024e-01, 7.7561e-01, 2.7175e-01, 2.0056e-01, 5.2111e-01,
          8.9640e-02, 4.9417e-01, 9.3660e-02, 7.0906e-01, 1.5799e-01, 7.0432e-01,
          6.0949e-01, 9.3324e-01, 1.7197e-01, 9.5500e-01, 1.5667e-01, 1.2106e-01,
          2.2878e-01, 1.6822e-01, 8.8953e-02, 6.2613e-01, 3.0348e-01, 7.7485e-01,
          3.2701e-01, 6.9710e-01, 9.2204e-03, 3.5730e-01, 8.3131e-01, 4.4178e-01,
          7.4831e-01, 7.0168e-01, 1.4445e-01, 8.5451e-01, 2.0448e-01, 1.0934e-01,
          9.2108e-01, 9.7882e-01, 7.1854e-01, 9.9592e-01, 7.7601e-01, 7.4920e-01,
          4.9649e-02, 1.6668e-01, 2.7702e-01, 4.4802e-01, 6.4391e-02, 5.6904e-01,
          4.9359e-01, 3.2791e-01, 3.8297e-01, 4.0255e-01, 7.9060e-01, 1.7349e-01,
          5.5629e-01, 2.6756e-01, 2.3608e-01, 7.2974e-03, 8.4451e-02, 8.5660e-01,
          9.7647e-01, 5.1044e-01, 2.1751e-01, 2.2804e-01, 4.2032e-01, 9.0452e-01,
          1.2030e-01, 1.4719e-01, 1.3199e-01, 7.2010e-01, 4.1696e-01, 9.1149e-01,
          2.9796e-01, 6.0440e-01, 5.6533e-01, 4.0515e-01, 3.0983e-01, 4.5693e-01,
          3.4566e-01, 7.3905e-01, 2.0152e-01, 1.2233e-01, 3.9534e-01, 7.4433e-01,
          5.7207e-01, 7.3077e-01, 8.2748e-01, 9.2568e-01, 1.8322e-01, 4.0016e-01,
          8.8206e-01, 8.2758e-01, 8.8241e-01, 8.3885e-01, 8.6631e-01, 6.5085e-01,
          3.0178e-01, 2.4533e-01, 6.2018e-02, 7.3480e-01, 6.0241e-01, 6.2882e-01,
          3.7189e-02, 1.4155e-01, 2.5188e-01, 4.5747e-01, 7.6640e-01, 8.6379e-01,
          5.0352e-01, 6.9164e-01, 4.4640e-01, 8.3796e-01, 2.3279e-02, 2.4266e-01,
          1.5911e-01, 2.9448e-01, 4.3712e-01, 5.8846e-01, 8.3390e-01, 6.0577e-01,
          5.6659e-02, 7.3884e-01, 1.1563e-01, 9.9946e-01, 5.8747e-01, 7.4030e-01,
          9.7608e-01, 9.6837e-01, 8.1035e-01, 5.6657e-02, 8.2303e-01, 3.4613e-01,
          5.6750e-01, 5.6155e-01, 5.9419e-01, 8.6684e-01, 2.7618e-01, 9.4254e-01,
          4.1192e-01, 1.4546e-01, 8.6641e-01, 9.4435e-01, 6.3321e-01, 5.3931e-01,
          4.1183e-01, 9.0797e-01, 5.7037e-01, 3.2724e-01, 9.5367e-02, 3.2045e-01,
          5.6791e-01, 1.0936e-01, 9.1669e-01, 2.4321e-01, 3.0326e-01, 5.0873e-01,
          4.5896e-01, 9.6857e-01, 7.8476e-01, 4.4804e-01, 1.2176e-01, 2.9214e-01,
          5.3163e-01, 5.2443e-01, 2.9687e-02, 2.1608e-01, 6.1559e-01, 7.1864e-01,
          6.3347e-01, 2.8179e-01, 4.5725e-01, 4.2837e-02, 3.4421e-01, 5.8743e-01,
          2.1679e-01, 1.8272e-01, 1.9484e-01, 8.8063e-01, 3.1435e-01, 2.3461e-01,
          7.9409e-01, 1.2472e-01, 7.8094e-01, 6.8672e-02, 7.2276e-01, 5.5303e-01,
          5.1648e-01, 2.8736e-01, 1.8416e-01, 2.4727e-01, 1.8933e-01, 9.1345e-01,
          6.8520e-01, 4.4169e-01, 9.3804e-01, 5.5995e-01, 7.0857e-01, 8.6889e-01,
          1.3858e-01, 9.7022e-01, 4.0966e-02, 5.2702e-01, 5.3067e-01, 9.8568e-01,
          1.3031e-01, 7.6696e-01, 4.3392e-01, 3.9288e-01, 4.7292e-01, 9.5555e-01,
          2.2223e-01, 6.2474e-01, 2.4762e-01, 9.7356e-01, 5.2980e-01, 1.8349e-01,
          1.3356e-01, 2.9514e-01, 7.2618e-02, 3.1376e-01, 2.9834e-01, 3.3448e-01,
          4.5164e-01, 1.4113e-01, 3.0853e-02, 2.3807e-01, 6.9196e-01, 8.7289e-01,
          2.0844e-01, 2.6816e-01, 5.7085e-01, 5.5636e-01, 1.9644e-02, 1.8733e-01,
          2.7854e-01, 1.1890e-01, 3.9792e-01, 6.2062e-02, 6.1977e-01, 6.9529e-01,
          6.3978e-02, 9.5725e-02, 8.1335e-01, 2.0878e-01, 5.5215e-01, 8.1974e-01,
          4.6905e-01, 4.7244e-01, 9.1417e-01, 1.2908e-01, 8.2139e-01, 4.3964e-01,
          5.0774e-02, 4.2454e-01, 2.0171e-01, 5.5387e-02, 6.4700e-01, 5.5733e-01,
          9.3077e-01, 3.8415e-01, 4.6993e-01, 6.2185e-01, 8.2636e-01, 1.4217e-01,
          5.6464e-01, 4.8635e-01, 5.7223e-02, 2.7939e-01, 2.0088e-01, 2.2114e-01,
          5.4843e-01, 9.5278e-02, 6.0618e-01, 1.0707e-01, 4.3344e-02, 1.4622e-01,
          3.8450e-01, 5.5555e-01, 1.9274e-01, 8.6633e-01, 4.9804e-02, 1.0685e-02,
          4.3917e-02, 7.4434e-01, 1.2755e-01, 9.4103e-01, 6.5851e-01, 5.8306e-01,
          4.4620e-01, 5.3446e-02, 6.3152e-01, 5.0199e-02, 3.0937e-01, 4.4472e-01,
          2.9813e-01, 4.1765e-01, 3.5968e-01, 1.3808e-01, 7.1118e-01, 3.8625e-01,
          5.0031e-01, 6.2872e-01, 4.1631e-01, 6.0984e-01, 1.2110e-02, 6.8720e-01,
          9.6986e-01, 8.6093e-01, 1.4428e-01, 9.9231e-01, 9.8456e-01, 2.6723e-01,
          9.6078e-01, 5.8336e-01, 2.8123e-02, 7.8642e-01, 2.0578e-01, 5.5014e-01,
          4.7842e-01, 7.6994e-02, 2.9309e-01, 8.2450e-01, 1.5575e-01, 6.8049e-01,
          7.7765e-02, 1.4121e-02, 4.2263e-01, 6.0621e-01, 1.5737e-01, 4.2721e-01,
          1.4188e-01, 8.2326e-01, 2.5639e-01, 7.2983e-01, 4.7657e-01, 4.9735e-01,
          3.3352e-01, 9.8745e-02, 6.8606e-01, 5.8307e-01, 3.5501e-01, 9.6240e-02,
          5.6762e-01, 7.8193e-03, 1.1542e-01, 3.7574e-01, 4.5057e-01, 8.8703e-01,
          8.5080e-01, 4.0875e-01, 1.7885e-01, 6.3066e-01, 1.2342e-01, 9.8923e-01,
          8.7586e-01, 1.9746e-01, 6.3018e-01, 9.8011e-02, 7.1228e-01, 3.9281e-01,
          8.5504e-01, 2.7517e-01, 3.3320e-01, 9.0323e-01, 9.5839e-01, 4.0996e-01,
          5.0806e-01, 7.1636e-01, 9.1463e-01, 9.4709e-01, 5.9514e-01, 4.5230e-01,
          2.7855e-01, 1.9598e-01, 5.6367e-01, 4.6875e-01, 7.2862e-01, 2.1958e-01,
          3.5183e-01, 7.3227e-01, 7.2989e-02, 1.0181e-01, 2.4360e-01, 9.8775e-01,
          9.4699e-01, 7.4495e-01, 6.7069e-01, 2.9204e-01, 8.7885e-01, 2.1214e-01,
          7.5154e-01, 5.2983e-01, 4.3451e-01, 6.1163e-01, 4.0658e-01, 4.2787e-01,
          2.6285e-01, 9.1572e-01, 7.0580e-02, 5.2668e-02, 8.3777e-01, 8.9616e-02,
          5.3810e-01, 4.5461e-01, 1.7230e-01, 5.4549e-01, 7.3666e-02, 8.8637e-01,
          3.8374e-01, 1.0186e-01, 1.4126e-01, 1.8823e-01, 7.3010e-01, 5.0990e-01,
          5.5106e-01, 9.5073e-01]]),
 'be': torch.tensor([[2.3670e-01, 5.4039e-01, 1.8671e-01, 1.6731e-01, 7.9639e-01, 2.6262e-01,
          7.3181e-01, 4.8283e-01, 1.8756e-01, 1.5180e-01, 1.2770e-01, 2.3363e-01,
          7.3854e-02, 5.4801e-01, 8.4638e-01, 5.1149e-01, 5.0831e-01, 1.0324e-01,
          1.6141e-01, 3.4296e-01, 1.6707e-01, 5.6221e-01, 4.1129e-01, 2.5523e-01,
          6.9992e-02, 2.0117e-01, 1.8602e-01, 2.7343e-01, 6.0633e-01, 4.4436e-01,
          1.0580e-01, 2.4610e-01, 4.3793e-03, 2.7839e-01, 6.8102e-01, 5.4801e-01,
          6.2858e-01, 6.9363e-01, 1.1529e-01, 8.8759e-01, 7.2293e-01, 8.6286e-01,
          9.0510e-01, 7.8900e-01, 4.5291e-01, 4.4695e-01, 5.9421e-01, 3.1457e-01,
          3.2365e-01, 3.4519e-01, 3.5492e-01, 4.3075e-02, 1.6437e-01, 7.8469e-01,
          3.9442e-01, 8.5061e-01, 9.9239e-02, 1.2494e-01, 5.6962e-01, 5.9411e-01,
          1.5430e-01, 1.9773e-01, 5.6802e-01, 1.1105e-01, 8.0199e-01, 5.8486e-01,
          5.3171e-01, 3.6655e-01, 1.1478e-01, 1.5841e-01, 6.3663e-01, 5.7232e-01,
          1.0294e-01, 5.5214e-01, 7.6530e-02, 5.3540e-01, 4.6347e-01, 3.8808e-01,
          9.1357e-01, 6.9005e-01, 5.9859e-01, 6.5145e-01, 5.2661e-01, 5.1065e-01,
          5.5542e-01, 9.5399e-01, 3.9415e-01, 4.1897e-01, 2.9082e-01, 8.6932e-01,
          8.7861e-01, 8.0508e-01, 5.2330e-01, 7.1564e-01, 9.6320e-01, 1.5375e-01,
          5.9223e-01, 7.8640e-01, 1.8979e-01, 3.6913e-01, 2.9473e-01, 3.2348e-01,
          9.1928e-03, 7.8671e-01, 7.9390e-01, 9.4709e-01, 6.5192e-02, 6.4159e-01,
          5.2667e-01, 4.9923e-01, 3.8061e-01, 5.5553e-01, 2.9026e-01, 9.0439e-01,
          7.9522e-01, 4.8535e-01, 9.3005e-02, 2.7955e-01, 7.5278e-01, 5.4938e-01,
          3.7337e-01, 6.6962e-01, 3.1800e-01, 9.6579e-01, 7.2812e-01, 8.0604e-01,
          2.8311e-01, 9.7377e-01, 8.6975e-01, 8.9171e-01, 2.8009e-02, 6.8996e-02,
          8.5847e-01, 1.1506e-01, 3.2203e-01, 8.1657e-01, 5.7842e-01, 7.4345e-04,
          7.9115e-01, 4.0778e-01, 1.1942e-01, 7.0644e-01, 5.4568e-02, 9.9701e-01,
          3.6030e-01, 6.6035e-01, 1.9727e-01, 2.9321e-01, 3.0453e-01, 9.7555e-02,
          1.9450e-01, 5.5662e-01, 9.5359e-01, 7.5396e-01, 3.0008e-01, 3.1026e-01,
          3.8540e-01, 7.9894e-01, 3.7984e-01, 4.0913e-01, 1.5873e-01, 2.4505e-01,
          2.1317e-02, 8.1921e-01, 8.7914e-01, 2.1364e-01, 6.5121e-01, 7.8324e-01,
          8.2043e-02, 2.5612e-01, 1.0654e-01, 5.4961e-01, 2.4949e-01, 9.8915e-01,
          1.3948e-01, 5.3412e-01, 5.1260e-01, 3.7399e-01, 8.3139e-01, 4.0651e-01,
          1.0435e-01, 8.9451e-01, 5.1603e-01, 1.3237e-02, 1.7359e-01, 8.3092e-01,
          9.5511e-02, 6.0618e-01, 9.3897e-01, 6.2733e-01, 2.1152e-01, 6.9296e-01,
          2.4786e-02, 4.5604e-01, 9.1465e-01, 7.0758e-01, 9.6912e-01, 8.0761e-01,
          6.5817e-01, 8.6185e-01, 1.2089e-01, 5.9195e-01, 2.1969e-01, 6.7452e-01,
          4.7280e-01, 7.6544e-01, 8.6723e-01, 3.7344e-01, 8.5775e-01, 8.6815e-01,
          4.3084e-02, 4.7895e-01, 3.7985e-01, 3.3923e-01, 6.8319e-01, 1.6385e-01,
          2.6285e-01, 6.4713e-01, 1.3464e-01, 7.2628e-01, 2.7010e-02, 6.1721e-02,
          9.3144e-01, 6.4608e-01, 3.6566e-01, 4.7268e-01, 5.1735e-01, 4.9640e-01,
          9.2912e-01, 2.0492e-01, 3.0646e-01, 6.6053e-02, 1.8480e-01, 1.6831e-01,
          4.2105e-01, 6.8707e-01, 6.6561e-01, 6.0137e-01, 8.9260e-01, 8.4663e-01,
          5.0013e-01, 9.4788e-01, 3.8894e-01, 1.4687e-01, 9.7480e-01, 5.5124e-01,
          9.0240e-01, 8.2076e-01, 8.5825e-02, 1.5270e-01, 6.4390e-01, 7.3355e-02,
          9.3262e-01, 7.6904e-01, 2.6122e-01, 4.1780e-01, 4.7368e-01, 9.2188e-01,
          7.4379e-01, 1.3988e-01, 6.5153e-01, 1.9837e-01, 1.1052e-01, 4.8182e-01,
          3.5069e-01, 6.9691e-01, 7.7636e-01, 1.6274e-01, 3.6227e-02, 4.3465e-01,
          7.8809e-01, 1.8318e-01, 8.1879e-01, 3.9711e-01, 2.1820e-01, 7.3773e-01,
          7.9896e-01, 3.0650e-01, 8.2824e-01, 1.7078e-01, 2.0887e-01, 7.7808e-01,
          7.6628e-01, 5.8273e-01, 8.2975e-01, 2.8141e-01, 3.7932e-01, 2.5913e-01,
          2.2752e-01, 2.3629e-01, 3.4556e-01, 9.2507e-01, 8.1931e-01, 2.0856e-01,
          1.5541e-01, 9.8083e-01, 4.9594e-01, 3.7695e-01, 5.8577e-01, 6.6709e-01,
          6.9031e-01, 8.4231e-01, 6.4237e-01, 5.9297e-01, 7.9153e-01, 7.9495e-01,
          4.3953e-01, 3.3963e-01, 4.6112e-01, 3.0420e-01, 1.0080e-01, 8.1773e-01,
          9.3412e-01, 1.2003e-01, 3.0854e-01, 3.1477e-01, 9.3437e-01, 9.9380e-01,
          2.9469e-01, 8.3573e-01, 7.0677e-02, 6.6623e-01, 5.7149e-01, 5.9718e-01,
          9.1316e-01, 3.6173e-01, 5.9218e-01, 6.5720e-01, 4.2252e-01, 8.8756e-01,
          1.3740e-01, 4.2539e-01, 2.9307e-01, 1.8683e-01, 5.9306e-01, 2.4726e-01,
          2.2762e-01, 3.6923e-01, 2.2558e-01, 2.5332e-01, 9.3432e-02, 7.8535e-01,
          8.7720e-01, 4.4901e-01, 7.6483e-01, 5.2797e-01, 7.4158e-01, 1.7970e-01,
          9.4767e-03, 7.4188e-01, 6.6606e-01, 1.8728e-03, 6.3040e-01, 9.7156e-01,
          5.4220e-01, 8.6222e-01, 1.3260e-01, 5.8290e-01, 2.3067e-01, 1.6388e-01,
          6.1174e-01, 9.3900e-01, 3.6382e-01, 4.2095e-01, 7.6293e-01, 1.7053e-04,
          9.9200e-01, 5.3723e-01, 5.8092e-01, 7.3010e-01, 5.4661e-01, 2.8859e-01,
          1.6088e-01, 7.7254e-01, 6.4357e-01, 8.9507e-01, 5.2335e-01, 2.1550e-01,
          3.3733e-01, 7.0912e-01, 7.4025e-01, 9.6961e-01, 6.5577e-01, 6.7932e-01,
          7.0136e-01, 5.2214e-02, 4.5345e-01, 4.8053e-01, 7.6410e-01, 4.6067e-01,
          2.3452e-01, 2.6659e-01, 7.0379e-01, 2.9437e-01, 2.5682e-01, 6.2237e-02,
          8.2386e-01, 6.1276e-01, 3.4465e-01, 8.3691e-02, 5.3226e-03, 3.7659e-01,
          9.5102e-01, 8.2916e-01, 2.2256e-01, 2.0728e-01, 4.7022e-01, 2.5459e-01,
          5.7350e-01, 8.9763e-01, 2.3528e-01, 3.9866e-01, 1.8435e-01, 9.9828e-01,
          3.3876e-01, 5.7065e-01, 8.6727e-01, 1.5538e-01, 8.2578e-01, 7.0414e-01,
          2.4006e-01, 8.3892e-01, 1.1168e-01, 3.5076e-01, 7.9713e-01, 9.4369e-01,
          9.7720e-01, 8.4135e-01, 7.6388e-01, 2.9216e-01, 4.8565e-01, 7.4745e-01,
          5.8242e-01, 3.7250e-01, 8.8763e-01, 2.3169e-02, 4.4079e-01, 3.4495e-01,
          6.5017e-01, 3.9595e-01, 7.8794e-02, 9.2598e-01, 4.3806e-01, 2.3802e-01,
          5.0534e-01, 9.2207e-01, 2.2440e-01, 3.2228e-01, 5.2705e-01, 8.7181e-01,
          2.8317e-01, 9.8508e-02, 7.9837e-01, 1.3869e-01, 3.4256e-01, 2.1443e-01,
          2.4072e-02, 5.8238e-01, 4.1783e-01, 5.4162e-02, 6.1189e-01, 5.9102e-01,
          8.0321e-01, 8.9016e-01, 2.8153e-02, 9.6427e-01, 6.7950e-01, 3.4274e-01,
          2.5349e-01, 5.8568e-01, 7.4132e-01, 3.1836e-01, 5.2502e-01, 9.5677e-01,
          1.2817e-01, 8.8003e-02, 9.4131e-01, 4.3640e-02, 9.3198e-01, 2.6271e-01,
          3.2884e-01, 1.5940e-01, 1.1473e-02, 2.9703e-01, 6.6450e-01, 1.8533e-01,
          4.2674e-01, 2.6263e-01, 3.5945e-01, 9.2688e-01, 3.9515e-01, 9.9599e-01,
          2.9101e-01, 6.2498e-01, 8.9213e-01, 5.0406e-01, 2.8322e-02, 6.5236e-01,
          6.5454e-01, 4.1009e-01, 3.3242e-01, 4.6601e-01, 9.6356e-01, 9.8911e-01,
          3.3081e-01, 9.7387e-01, 8.0590e-01, 1.9728e-01, 2.1803e-02, 3.7933e-01,
          8.8971e-01, 8.7636e-01]]),
 'better': torch.tensor([[1.8799e-04, 8.8830e-01, 7.1639e-01, 2.0368e-01, 8.3557e-01, 3.6569e-01,
          4.4396e-01, 4.0851e-01, 8.3497e-01, 1.2631e-01, 6.1921e-01, 1.4682e-01,
          6.7209e-01, 1.9778e-01, 8.3026e-02, 8.9255e-01, 9.2253e-02, 8.2438e-01,
          7.3378e-01, 2.2685e-01, 3.8100e-01, 5.5563e-01, 2.9709e-01, 9.1398e-02,
          7.4414e-01, 3.9677e-01, 9.1266e-01, 9.6381e-01, 6.1393e-01, 4.2584e-01,
          3.4397e-01, 8.6522e-01, 1.8279e-01, 9.6557e-01, 2.6892e-01, 9.4602e-01,
          6.7064e-01, 7.3725e-01, 9.3584e-01, 6.2849e-01, 4.1063e-01, 6.0514e-01,
          6.3408e-01, 5.0963e-01, 7.8491e-01, 7.2035e-01, 7.8605e-01, 5.1879e-01,
          4.9117e-01, 9.8313e-01, 1.0941e-01, 6.6574e-01, 1.5103e-01, 7.5916e-01,
          4.2121e-02, 9.8984e-01, 3.9655e-01, 6.0979e-01, 1.7478e-01, 3.1363e-01,
          6.9451e-01, 1.2161e-01, 9.5027e-02, 6.3173e-01, 8.8403e-01, 4.3366e-01,
          4.9093e-01, 4.3212e-01, 8.8397e-01, 6.3854e-01, 1.6027e-02, 2.6789e-02,
          2.7130e-01, 3.4433e-01, 5.6905e-01, 3.2440e-01, 6.4769e-01, 2.9197e-01,
          3.1945e-01, 9.9377e-01, 8.6825e-01, 4.7233e-01, 5.4387e-01, 4.0567e-02,
          7.5986e-01, 8.3091e-01, 5.9239e-01, 2.6495e-01, 7.3758e-01, 6.2833e-01,
          6.5512e-01, 8.9643e-01, 9.9598e-01, 8.4977e-01, 7.8660e-02, 6.2438e-01,
          4.4268e-01, 7.0010e-01, 7.5314e-01, 1.0772e-01, 8.3156e-02, 9.1629e-01,
          6.1441e-01, 3.5631e-01, 3.6253e-01, 6.2755e-01, 5.2898e-01, 4.5735e-01,
          7.4918e-01, 2.8643e-02, 1.1880e-01, 7.3441e-01, 3.1034e-01, 7.5804e-01,
          1.6345e-01, 6.3725e-01, 5.7222e-01, 7.3018e-01, 5.4246e-01, 8.9740e-01,
          7.3208e-01, 5.6601e-01, 1.1988e-01, 9.3616e-01, 4.4160e-01, 4.5548e-01,
          7.3660e-01, 7.6505e-02, 5.4666e-01, 1.4484e-01, 6.8556e-01, 8.9157e-01,
          9.2987e-01, 1.3964e-01, 9.3235e-01, 4.1983e-01, 7.7941e-01, 8.8467e-01,
          6.4536e-01, 7.2976e-01, 1.3020e-01, 9.2981e-01, 5.3447e-01, 9.2528e-01,
          3.4156e-01, 5.1741e-01, 9.2589e-01, 1.0837e-01, 1.5575e-01, 1.4328e-01,
          8.2225e-01, 5.3636e-01, 9.8177e-01, 8.1018e-01, 5.3635e-02, 1.7573e-01,
          6.6304e-01, 5.0216e-01, 7.1925e-01, 1.4152e-01, 3.4714e-01, 7.9097e-01,
          1.5162e-01, 3.4560e-01, 5.5822e-01, 8.6598e-01, 4.7696e-01, 4.4563e-01,
          2.1352e-01, 2.5788e-01, 9.0245e-01, 3.9910e-01, 3.0791e-01, 8.4680e-01,
          4.9957e-01, 9.0592e-01, 3.0672e-01, 8.6502e-01, 6.5758e-01, 3.1964e-01,
          4.4117e-01, 3.0612e-01, 2.8969e-01, 4.5305e-02, 1.0965e-01, 3.1256e-01,
          7.7762e-01, 9.0737e-01, 7.5546e-01, 1.1966e-01, 3.9014e-01, 8.2390e-01,
          6.8710e-01, 9.9723e-01, 7.9436e-02, 7.2092e-01, 1.1885e-01, 7.7970e-01,
          2.5552e-01, 4.4172e-01, 2.3300e-01, 5.4899e-01, 8.4856e-01, 2.4643e-01,
          7.8712e-01, 4.6303e-01, 9.6418e-01, 8.8789e-01, 3.0197e-01, 9.1570e-01,
          7.6953e-01, 4.8789e-01, 8.2756e-01, 6.3126e-01, 6.1179e-01, 1.7929e-01,
          3.6435e-01, 6.3253e-02, 1.2810e-01, 2.9215e-01, 5.1530e-01, 7.4945e-01,
          7.9752e-01, 9.8014e-01, 5.2573e-01, 1.1157e-01, 2.7269e-01, 9.6377e-01,
          7.4296e-01, 8.7736e-01, 7.8000e-01, 1.6383e-01, 5.0924e-01, 3.2282e-01,
          6.6724e-02, 3.9554e-01, 4.5895e-01, 9.1337e-01, 2.3756e-02, 1.1680e-01,
          7.9834e-02, 3.3328e-01, 1.9752e-01, 9.5524e-01, 1.4900e-01, 8.8468e-01,
          2.8686e-01, 1.8723e-01, 5.5425e-01, 3.2350e-01, 2.5904e-01, 6.0857e-01,
          3.8574e-01, 5.4817e-01, 3.9796e-01, 7.9089e-01, 4.2267e-01, 6.6049e-01,
          8.2999e-01, 6.1717e-01, 9.5832e-01, 1.1569e-02, 4.9729e-01, 1.0839e-01,
          7.6477e-01, 5.3476e-01, 5.4000e-01, 6.1244e-01, 6.3417e-01, 7.1225e-01,
          5.7760e-01, 5.8739e-01, 5.8829e-01, 3.1390e-01, 6.3466e-01, 1.1472e-01,
          9.0907e-01, 2.9409e-01, 4.7632e-01, 3.1554e-01, 8.7357e-01, 9.5039e-01,
          3.6984e-02, 2.1353e-01, 4.6423e-01, 1.6565e-01, 9.8433e-01, 4.8427e-01,
          8.9360e-01, 8.6218e-01, 8.7069e-01, 6.5018e-01, 9.4866e-01, 8.3375e-01,
          1.7750e-01, 4.0198e-01, 5.1278e-01, 3.7021e-01, 9.5383e-01, 8.8722e-01,
          7.3088e-01, 2.2775e-01, 4.2940e-01, 1.1942e-01, 6.2834e-01, 2.0634e-01,
          7.9457e-01, 2.3176e-01, 1.4220e-01, 3.8714e-01, 5.0365e-01, 3.4686e-01,
          9.8213e-01, 3.1160e-01, 5.1593e-01, 4.7918e-01, 7.7597e-01, 8.7662e-01,
          7.2025e-01, 7.8843e-02, 3.6143e-01, 2.7620e-01, 7.3832e-01, 1.2199e-01,
          7.6816e-01, 3.3088e-01, 5.3308e-01, 8.3727e-01, 3.4454e-01, 8.3313e-01,
          4.0220e-02, 8.2559e-01, 6.0820e-02, 5.2861e-01, 2.4866e-01, 7.1011e-01,
          9.2070e-01, 4.0587e-01, 6.6516e-01, 2.4116e-01, 5.1024e-01, 1.3055e-01,
          5.1477e-02, 7.1164e-01, 8.1807e-01, 7.5981e-01, 7.5819e-01, 3.2527e-01,
          4.4353e-01, 7.7080e-01, 3.2903e-01, 5.2482e-01, 1.4463e-01, 8.3229e-01,
          8.0723e-01, 8.1806e-01, 5.0002e-01, 4.4697e-01, 1.2796e-01, 3.3480e-01,
          3.1038e-01, 5.5274e-02, 7.5902e-02, 4.2351e-01, 1.9490e-01, 9.2346e-01,
          9.2659e-01, 1.8682e-01, 1.3897e-01, 7.8278e-01, 8.6317e-01, 3.8320e-02,
          5.8644e-01, 9.0572e-01, 8.5458e-01, 4.5462e-01, 8.0315e-01, 8.9879e-02,
          1.0618e-01, 8.0783e-01, 2.6974e-02, 2.3689e-01, 2.9643e-01, 6.7493e-01,
          2.7907e-01, 5.5367e-01, 9.3280e-02, 9.2625e-01, 5.4860e-01, 7.0376e-02,
          5.4315e-01, 2.5089e-01, 2.0467e-01, 7.3139e-01, 6.0116e-01, 8.3706e-01,
          6.6732e-01, 3.7537e-01, 8.3278e-01, 4.8573e-01, 2.2279e-02, 7.8911e-01,
          9.7635e-01, 5.3045e-01, 7.5861e-01, 4.6490e-01, 1.0845e-01, 6.0059e-01,
          8.1958e-01, 5.8407e-01, 2.1209e-01, 4.0547e-01, 5.1498e-01, 4.4649e-01,
          2.6352e-01, 3.8251e-01, 5.5353e-01, 3.9497e-01, 9.0617e-01, 6.7568e-01,
          1.5051e-01, 3.9751e-01, 7.0274e-01, 5.6846e-01, 1.0390e-01, 6.0444e-01,
          7.9628e-01, 4.3004e-01, 8.8855e-01, 7.4114e-01, 5.9640e-01, 6.8028e-01,
          3.7275e-01, 3.0151e-01, 1.7370e-01, 5.6640e-01, 7.2067e-01, 9.6942e-01,
          3.2944e-01, 4.4925e-02, 3.1496e-01, 2.2803e-01, 2.9965e-01, 7.0194e-01,
          3.9314e-01, 4.7080e-01, 8.1630e-02, 1.1783e-01, 2.0786e-01, 2.9285e-01,
          5.3395e-01, 6.4874e-01, 4.6018e-02, 5.3807e-01, 1.9702e-01, 7.9351e-01,
          3.4352e-01, 7.4383e-01, 2.2405e-01, 7.1542e-01, 8.2697e-02, 5.6497e-01,
          8.9829e-01, 3.3166e-01, 2.8397e-01, 2.4034e-01, 7.3679e-01, 1.9978e-01,
          6.1062e-01, 2.4549e-01, 6.0340e-02, 4.7634e-01, 1.3082e-01, 5.0697e-01,
          4.1993e-01, 5.8177e-01, 7.5621e-01, 2.0104e-01, 1.5791e-02, 7.3531e-01,
          8.6762e-01, 9.5891e-01, 6.5430e-01, 8.1056e-02, 1.7849e-03, 6.0112e-01,
          9.5946e-01, 9.7519e-01, 3.6132e-01, 3.2361e-01, 6.7261e-01, 7.9804e-02,
          9.3448e-01, 3.0291e-01, 3.1697e-02, 5.5981e-01, 1.1414e-01, 7.8567e-01,
          3.4255e-01, 7.3328e-01, 6.2983e-01, 2.8096e-02, 1.5039e-01, 5.6391e-01,
          3.0765e-01, 8.9107e-01, 4.5396e-01, 8.6478e-01, 7.8518e-01, 6.2083e-01,
          1.7664e-01, 5.0840e-01]]),
 'the': torch.tensor([[5.9716e-01, 8.8556e-01, 8.5613e-01, 9.4307e-01, 7.1746e-01, 2.5895e-01,
          1.9383e-01, 9.8941e-02, 9.1393e-01, 4.5181e-01, 6.4429e-01, 5.8876e-01,
          7.9625e-01, 5.4962e-01, 4.6825e-02, 3.9037e-01, 8.8964e-01, 3.0795e-01,
          5.3498e-01, 5.4634e-01, 3.6914e-01, 4.4809e-01, 2.3708e-01, 5.3106e-01,
          9.0105e-01, 9.8789e-01, 9.1218e-01, 5.2036e-01, 8.2404e-01, 8.2826e-01,
          1.7016e-01, 6.9585e-01, 1.8131e-01, 6.3051e-01, 5.0824e-01, 3.6289e-01,
          1.0272e-01, 1.5159e-01, 4.4929e-01, 9.6666e-01, 4.5874e-01, 1.1520e-01,
          9.1800e-01, 9.7879e-01, 6.1359e-01, 8.8229e-01, 6.6125e-01, 1.3010e-01,
          8.4326e-01, 9.0826e-01, 1.9370e-01, 6.8556e-01, 3.6316e-01, 4.4351e-01,
          3.9717e-02, 6.8477e-01, 3.5896e-01, 3.7225e-01, 2.6983e-02, 7.9714e-02,
          4.8514e-01, 1.4685e-01, 2.7173e-01, 2.0835e-01, 5.0232e-01, 5.3126e-01,
          8.1204e-01, 4.3673e-01, 1.1383e-01, 6.8204e-02, 8.7348e-01, 2.1681e-01,
          4.0372e-01, 8.0213e-01, 2.8356e-02, 2.1587e-01, 6.0343e-01, 9.8251e-01,
          1.8921e-01, 5.7847e-01, 8.0916e-01, 4.9086e-01, 1.3679e-01, 2.8025e-01,
          8.4532e-01, 1.9179e-01, 8.6524e-01, 4.1704e-01, 6.2315e-01, 7.6275e-01,
          8.5727e-01, 4.7626e-01, 2.3798e-01, 7.3355e-01, 3.0818e-01, 9.6668e-01,
          2.2966e-01, 9.4888e-02, 8.7657e-01, 4.5626e-01, 6.7596e-01, 8.8898e-01,
          8.5081e-02, 8.8662e-01, 6.2149e-01, 3.0379e-01, 5.2365e-01, 9.0404e-01,
          9.9760e-01, 8.3603e-01, 7.0537e-01, 6.5862e-01, 8.6399e-01, 4.2443e-01,
          8.7665e-01, 3.7484e-01, 6.2866e-01, 8.4476e-01, 2.6564e-03, 1.8827e-01,
          1.2825e-01, 9.3827e-01, 1.8516e-01, 2.5345e-01, 8.9917e-01, 4.4333e-01,
          2.3897e-01, 6.4961e-01, 7.9292e-01, 7.4344e-01, 3.0364e-03, 5.4069e-01,
          7.8301e-01, 6.1868e-01, 2.5626e-01, 2.2073e-02, 7.2271e-01, 9.0102e-01,
          7.7735e-01, 8.3903e-01, 1.4693e-01, 9.4049e-01, 9.5331e-01, 8.8624e-01,
          2.5608e-01, 9.8704e-01, 5.1126e-01, 4.0627e-01, 1.4281e-01, 5.9203e-01,
          8.0833e-01, 1.1424e-01, 6.6141e-01, 4.3098e-01, 5.5087e-01, 4.4888e-01,
          1.6036e-01, 7.3204e-01, 9.8181e-01, 4.8673e-01, 4.9816e-01, 9.5695e-01,
          1.6447e-01, 2.6880e-01, 6.0759e-01, 8.1632e-01, 1.0371e-01, 1.1677e-01,
          2.1969e-02, 9.5638e-01, 4.9026e-01, 8.7338e-01, 6.5518e-02, 3.3154e-01,
          6.6842e-01, 5.2022e-01, 3.8877e-01, 6.4411e-01, 8.7390e-01, 8.2009e-02,
          2.3922e-01, 1.2955e-01, 2.4522e-01, 5.8091e-01, 6.7490e-01, 4.8195e-01,
          3.2732e-01, 7.8641e-01, 6.8853e-01, 3.3378e-01, 1.3111e-01, 8.5384e-01,
          4.3768e-01, 2.9535e-01, 5.5984e-01, 2.2774e-02, 8.8002e-01, 1.5850e-01,
          5.7387e-01, 9.0617e-01, 2.4149e-01, 2.8398e-02, 2.9076e-01, 7.0567e-01,
          8.7564e-01, 5.0916e-01, 5.6099e-01, 9.4709e-01, 5.8040e-01, 4.8262e-01,
          9.2032e-01, 8.5568e-01, 8.2705e-01, 2.2565e-01, 8.5345e-01, 5.0968e-01,
          7.5939e-01, 1.2943e-01, 4.6860e-01, 7.9926e-01, 7.5122e-01, 5.7070e-01,
          2.4807e-01, 4.4509e-01, 8.5761e-01, 8.7043e-01, 3.6721e-01, 9.2974e-02,
          4.3005e-01, 1.9694e-01, 6.5944e-01, 3.5507e-01, 7.2222e-01, 7.2747e-01,
          9.3492e-03, 8.0341e-01, 7.6326e-01, 9.2470e-01, 7.0943e-01, 9.9356e-01,
          8.7999e-01, 2.6757e-01, 1.0754e-03, 8.2257e-01, 3.8843e-01, 1.3998e-01,
          4.2243e-01, 2.3335e-01, 1.4262e-02, 3.5172e-01, 5.5245e-02, 9.9674e-01,
          7.6166e-01, 5.3164e-01, 7.4305e-02, 9.9851e-01, 8.6691e-01, 8.4843e-01,
          8.2974e-01, 1.2990e-01, 4.1648e-01, 3.2595e-01, 3.0286e-01, 9.8203e-01,
          8.0960e-01, 8.0323e-01, 9.4784e-01, 5.1274e-01, 7.9072e-01, 7.3096e-01,
          1.5225e-01, 5.6120e-02, 9.9264e-01, 6.7411e-01, 3.4756e-01, 8.5652e-05,
          5.0898e-01, 5.3245e-01, 5.4389e-01, 2.1278e-01, 7.7696e-01, 8.8965e-01,
          2.5698e-01, 8.3106e-01, 8.2892e-01, 1.3490e-01, 5.2489e-01, 9.7364e-01,
          7.4769e-02, 6.7848e-01, 4.8939e-02, 8.5751e-01, 7.4371e-01, 1.8864e-01,
          9.9723e-01, 3.3427e-01, 5.3161e-01, 2.0851e-01, 8.5971e-02, 1.8024e-03,
          1.6875e-01, 9.9456e-02, 6.6265e-01, 7.7105e-01, 5.8859e-01, 3.3375e-01,
          5.1556e-01, 6.0761e-01, 9.7717e-01, 8.6212e-01, 1.4157e-01, 4.2871e-01,
          6.2717e-01, 3.9140e-02, 2.8228e-01, 2.4733e-01, 1.6921e-01, 1.3334e-01,
          3.0075e-01, 1.0022e-01, 8.1735e-01, 9.0056e-01, 6.0633e-01, 1.0343e-01,
          5.4521e-01, 1.4244e-01, 9.5417e-01, 8.4845e-01, 5.6204e-01, 3.3640e-01,
          8.6809e-02, 7.7329e-01, 7.9940e-01, 7.9623e-01, 9.1862e-01, 6.0978e-01,
          2.8501e-01, 2.3244e-01, 6.6288e-01, 2.1600e-01, 9.0017e-01, 4.2196e-01,
          8.7898e-01, 2.3492e-01, 3.0776e-01, 5.3610e-01, 8.1624e-01, 2.4583e-01,
          9.1892e-02, 9.9278e-01, 1.5261e-01, 9.3363e-01, 9.0087e-01, 3.9119e-01,
          3.0562e-01, 4.6124e-01, 9.3669e-01, 1.6582e-01, 3.7153e-01, 8.6663e-01,
          6.7876e-01, 8.1477e-01, 4.7438e-01, 9.9147e-01, 9.4156e-01, 8.6005e-01,
          9.2642e-01, 4.3974e-01, 5.4200e-01, 7.8496e-01, 2.4301e-01, 4.7981e-01,
          9.7571e-01, 9.5016e-01, 6.8761e-01, 6.4159e-01, 2.1872e-01, 4.1056e-01,
          6.1267e-01, 1.1459e-01, 1.8138e-01, 5.7354e-01, 1.9937e-01, 5.8107e-01,
          7.7895e-01, 3.8689e-01, 8.4914e-02, 3.4589e-01, 9.4474e-01, 1.9079e-01,
          8.9324e-01, 1.5946e-01, 2.3353e-01, 6.4780e-01, 3.0497e-01, 8.0430e-01,
          3.3793e-01, 8.5644e-01, 5.3836e-01, 7.7878e-02, 7.1061e-01, 6.8277e-01,
          7.8382e-01, 3.7341e-01, 4.0918e-01, 7.3309e-01, 8.2653e-01, 5.2282e-01,
          6.3758e-01, 5.2945e-01, 1.6552e-02, 2.0541e-02, 5.1810e-01, 9.0847e-01,
          5.2944e-01, 9.3667e-01, 9.0491e-01, 9.3837e-02, 3.3871e-01, 2.4793e-01,
          1.8385e-01, 6.4224e-01, 2.7627e-01, 6.2083e-01, 6.2202e-01, 5.5904e-01,
          7.8870e-01, 3.4068e-01, 2.5864e-01, 4.0688e-01, 3.8846e-02, 7.2750e-01,
          4.8177e-01, 7.4269e-01, 3.4739e-01, 8.8328e-02, 8.8877e-01, 3.5912e-01,
          5.0049e-01, 2.7131e-01, 8.9791e-01, 8.1208e-01, 5.0786e-01, 3.1147e-01,
          6.0630e-01, 6.4723e-01, 6.3352e-01, 8.7271e-01, 2.7027e-01, 3.6926e-01,
          3.9001e-02, 2.1423e-01, 1.2677e-01, 3.4298e-01, 2.2920e-02, 7.3216e-02,
          8.1706e-01, 7.3704e-02, 3.1072e-01, 5.6233e-01, 2.8943e-01, 4.7282e-02,
          1.4712e-01, 6.4526e-01, 7.4838e-01, 4.4316e-01, 3.6862e-01, 4.6678e-02,
          8.5393e-01, 5.7872e-02, 3.6773e-01, 5.6294e-01, 9.9771e-01, 7.0511e-01,
          2.3592e-01, 7.8172e-01, 4.0488e-01, 9.9540e-01, 9.5749e-01, 9.8313e-01,
          2.0249e-01, 8.7962e-01, 1.4029e-01, 1.2557e-01, 4.0738e-01, 9.4714e-01,
          6.0821e-01, 2.5611e-01, 6.3711e-01, 8.1161e-01, 4.8803e-01, 4.6730e-01,
          9.2156e-01, 6.1729e-01, 9.7630e-01, 3.1694e-01, 2.6823e-01, 7.6794e-01,
          3.6108e-02, 4.8462e-01, 7.5794e-01, 4.4048e-01, 4.5925e-01, 6.6520e-01,
          1.6901e-01, 2.7798e-01, 9.5578e-01, 6.2469e-01, 4.8220e-01, 2.6849e-01,
          3.0333e-01, 9.0285e-01]]),
 'best': torch.tensor([[0.8685, 0.4230, 0.9698, 0.0489, 0.7114, 0.5790, 0.2682, 0.8984, 0.6091,
          0.6066, 0.9561, 0.1979, 0.2855, 0.7902, 0.0230, 0.4718, 0.9265, 0.9100,
          0.2152, 0.2065, 0.4270, 0.8716, 0.2135, 0.6221, 0.2553, 0.9074, 0.4845,
          0.1413, 0.6820, 0.2977, 0.2641, 0.9933, 0.0892, 0.8254, 0.1252, 0.3009,
          0.6442, 0.9320, 0.6234, 0.9437, 0.4052, 0.2305, 0.1018, 0.2148, 0.1482,
          0.9086, 0.6899, 0.2354, 0.9773, 0.4356, 0.0720, 0.0151, 0.2598, 0.6253,
          0.0161, 0.7656, 0.4335, 0.8049, 0.8737, 0.8874, 0.8559, 0.3474, 0.2414,
          0.6109, 0.8350, 0.8681, 0.0533, 0.2512, 0.3858, 0.1697, 0.3882, 0.6555,
          0.5758, 0.5914, 0.2158, 0.1183, 0.0652, 0.0065, 0.0726, 0.1239, 0.1758,
          0.3789, 0.6184, 0.7131, 0.2448, 0.9515, 0.4726, 0.7412, 0.2166, 0.6527,
          0.0863, 0.1759, 0.4398, 0.1367, 0.5569, 0.0793, 0.4983, 0.0173, 0.1063,
          0.8217, 0.5270, 0.8404, 0.7655, 0.3585, 0.0162, 0.1701, 0.6311, 0.0893,
          0.9195, 0.4121, 0.0709, 0.2479, 0.5698, 0.9999, 0.7933, 0.8693, 0.2002,
          0.1046, 0.5840, 0.6381, 0.9009, 0.6303, 0.9602, 0.9222, 0.5283, 0.6958,
          0.9112, 0.6953, 0.3188, 0.2032, 0.6394, 0.7330, 0.0345, 0.8332, 0.8431,
          0.6761, 0.6408, 0.6977, 0.4418, 0.4319, 0.4629, 0.9379, 0.8229, 0.5624,
          0.7172, 0.8736, 0.2522, 0.6497, 0.1915, 0.2397, 0.0353, 0.8381, 0.9651,
          0.1687, 0.5526, 0.1649, 0.1516, 0.2064, 0.5192, 0.4784, 0.7016, 0.7354,
          0.0223, 0.4391, 0.9908, 0.2202, 0.6836, 0.8539, 0.4523, 0.2878, 0.1313,
          0.2218, 0.1664, 0.0767, 0.6833, 0.5995, 0.6032, 0.9476, 0.6488, 0.2242,
          0.2173, 0.9145, 0.2873, 0.0816, 0.5510, 0.2363, 0.3002, 0.5767, 0.1635,
          0.8453, 0.3174, 0.4896, 0.4734, 0.5994, 0.8761, 0.5897, 0.6856, 0.3323,
          0.2078, 0.0138, 0.5288, 0.8533, 0.8132, 0.4933, 0.6933, 0.1490, 0.7795,
          0.8469, 0.3599, 0.1281, 0.6431, 0.4225, 0.9772, 0.3264, 0.2638, 0.5438,
          0.6871, 0.8335, 0.3571, 0.1732, 0.8822, 0.4448, 0.7064, 0.0590, 0.6028,
          0.5456, 0.3940, 0.2250, 0.7949, 0.0277, 0.0172, 0.8750, 0.6884, 0.8703,
          0.5337, 0.3156, 0.1942, 0.7022, 0.5961, 0.6566, 0.4708, 0.3963, 0.6845,
          0.2951, 0.1476, 0.2319, 0.5214, 0.5700, 0.6698, 0.8632, 0.9141, 0.2448,
          0.9981, 0.8407, 0.9528, 0.5692, 0.4634, 0.2680, 0.1562, 0.6089, 0.8068,
          0.2088, 0.0471, 0.6172, 0.3302, 0.9984, 0.4486, 0.8361, 0.3444, 0.2501,
          0.0428, 0.5881, 0.2384, 0.6039, 0.3627, 0.0386, 0.2249, 0.5907, 0.2291,
          0.6833, 0.9104, 0.9721, 0.1465, 0.2934, 0.5328, 0.9190, 0.8061, 0.3778,
          0.3771, 0.5723, 0.0276, 0.5642, 0.0617, 0.2984, 0.3237, 0.9144, 0.5736,
          0.9874, 0.5342, 0.3951, 0.8674, 0.7609, 0.5768, 0.0871, 0.1040, 0.3112,
          0.7825, 0.7628, 0.3192, 0.5798, 0.2001, 0.8278, 0.3726, 0.4363, 0.4661,
          0.5691, 0.9202, 0.0041, 0.0537, 0.0243, 0.3539, 0.5285, 0.0940, 0.7569,
          0.9150, 0.4169, 0.3965, 0.9940, 0.9079, 0.0894, 0.9083, 0.4398, 0.5813,
          0.6188, 0.3488, 0.2400, 0.7538, 0.2246, 0.4477, 0.9849, 0.2171, 0.4304,
          0.5834, 0.9875, 0.8313, 0.9005, 0.8072, 0.1989, 0.2741, 0.4732, 0.4005,
          0.1655, 0.9900, 0.5290, 0.9003, 0.0714, 0.0294, 0.9476, 0.1109, 0.5694,
          0.4495, 0.2465, 0.3022, 0.5693, 0.2030, 0.5041, 0.3725, 0.9894, 0.0798,
          0.0654, 0.6592, 0.9158, 0.3915, 0.6521, 0.7879, 0.1784, 0.5302, 0.5140,
          0.6482, 0.8163, 0.7481, 0.9123, 0.9731, 0.6805, 0.1868, 0.7336, 0.7600,
          0.9402, 0.9167, 0.7508, 0.0468, 0.6629, 0.2850, 0.4055, 0.7866, 0.4304,
          0.5055, 0.1628, 0.4506, 0.7686, 0.3156, 0.6145, 0.5558, 0.4981, 0.6653,
          0.3920, 0.1579, 0.4160, 0.3593, 0.7374, 0.5180, 0.7503, 0.9096, 0.8974,
          0.5925, 0.9591, 0.0373, 0.3643, 0.2406, 0.9306, 0.6286, 0.8017, 0.1823,
          0.0828, 0.4735, 0.1840, 0.1269, 0.1872, 0.1588, 0.5363, 0.5896, 0.5634,
          0.5369, 0.9899, 0.4987, 0.2368, 0.7556, 0.8587, 0.3342, 0.3088, 0.5890,
          0.3332, 0.9669, 0.9919, 0.1565, 0.7024, 0.0276, 0.5024, 0.4499, 0.9406,
          0.8990, 0.9619, 0.2295, 0.7416, 0.4053, 0.3703, 0.5169, 0.1264, 0.8330,
          0.3740, 0.6907, 0.1357, 0.6217, 0.1624, 0.1126, 0.3528, 0.8001, 0.5593,
          0.2756, 0.0462, 0.4759, 0.6341, 0.7747, 0.8072, 0.0605, 0.3261, 0.9816,
          0.8293, 0.9134, 0.7604, 0.0455, 0.7839, 0.3231, 0.4388, 0.4917, 0.4682,
          0.5339, 0.6218, 0.6764, 0.7503, 0.4121, 0.8496, 0.2517, 0.3841, 0.3800,
          0.2905, 0.6880, 0.2412, 0.8529, 0.6242, 0.8602, 0.5089, 0.5435, 0.7993,
          0.1419, 0.2879, 0.6521, 0.2606, 0.0457, 0.4149, 0.1444, 0.5578]])}

In [193]:
token_embeddings.keys()

dict_keys(['I', 'love', 'this', 'product', 'This', 'is', 'terrible', 'Could', 'be', 'better', 'the', 'best'])

In [134]:
for key in token_embeddings.keys():
    print(key, token_embeddings[key].shape)

I torch.Size([1, 512])
love torch.Size([1, 512])
this torch.Size([1, 512])
product torch.Size([1, 512])
This torch.Size([1, 512])
is torch.Size([1, 512])
terrible torch.Size([1, 512])
Could torch.Size([1, 512])
be torch.Size([1, 512])
better torch.Size([1, 512])
the torch.Size([1, 512])
best torch.Size([1, 512])


In [118]:
# from transformers import BertModel, BertTokenizer
# import torch

# # Initialize tokenizer and model
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model = BertModel.from_pretrained('bert-base-uncased')

# # List of unique tokens in your data
# tokens = ["I", "love", "this", "product", "This", "is", "terrible", "Could", "be", "better", "the", "best"]

# # Create a dictionary mapping from tokens to their embeddings
# token_embeddings = {}
# for token in tokens:
#     # Tokenize the word to get its input IDs & pass it through the BERT model
#     input_ids = tokenizer.encode(token, add_special_tokens=False)
#     with torch.no_grad():
#         outputs = model(torch.tensor([input_ids]))  # The output has 2 elements: hidden states and the output tensor
#     # Get the embeddings from the final hidden state
#     embeddings = outputs[0]  # (1, seq_length, 768)
#     # Remove the batch dimension and get the embedding of the [CLS] token
#     embeddings = embeddings.squeeze(0)  # (seq_length, 768)
#     token_embeddings[token] = embeddings

# # Now you can print the shape of each token's embedding
# for key in token_embeddings.keys():
#     print(key, token_embeddings[key].shape)

In [141]:
# The token_embeddings are used to convert each token (word) in the sentence into a high-dimensional vector representation that captures its semantic meaning. These embeddings are typically learned from large amounts of text data and are useful for many natural language processing tasks. By starting with these precomputed embeddings, the model can more easily learn to make accurate predictions.

for epoch in range(5):
    # Loop over each sentence and its corresponding label in the training data
    for sentence, label in zip(train_sentences, train_labels):
        # Print the sentence and label for debugging
        print('sentence , label', sentence, label)

        # Split the sentence into individual words (tokens)
        tokens = sentence.split()
        print('tokens', tokens)

        # For each token, retrieve its embedding from the precomputed token_embeddings dictionary.
        # Stack these embeddings together to form a tensor representing the sentence.
        # The dimension 1 is chosen for stacking because we want each row of the tensor to represent a token.
        data = torch.stack([token_embeddings[token] for token in tokens], dim=1)
        print('data.shape', data.shape)

        # Pass the sentence tensor through the model to get predictions
        output = model(data)
        print('output', output)

        # Compute the loss between the model's predictions and the true label
        # The label is wrapped in a tensor for compatibility with PyTorch functions
        loss = criterion(output, torch.tensor([label]))
        print('loss', loss)

        # Before the backward pass, use the optimizer object to zero all of the
        # gradients for the variables it will update (which are the learnable weights of the model)
        optimizer.zero_grad()

        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()

        # Calling the step function on an Optimizer makes an update to its parameters
        optimizer.step()

        # Print the loss for this epoch
        print(f"Epoch {epoch}, Loss: {loss.item()}")



sentence , label I love this product 1
tokens ['I', 'love', 'this', 'product']
data.shape torch.Size([1, 4, 512])
output tensor([[ 0.3219, -0.6821]], grad_fn=<AddmmBackward0>)
loss tensor(1.3162, grad_fn=<NllLossBackward0>)
Epoch 0, Loss: 1.3161700963974
sentence , label This is terrible 0
tokens ['This', 'is', 'terrible']
data.shape torch.Size([1, 3, 512])
output tensor([[-0.0569, -0.0821]], grad_fn=<AddmmBackward0>)
loss tensor(0.6806, grad_fn=<NllLossBackward0>)
Epoch 0, Loss: 0.680583655834198
sentence , label Could be better 0
tokens ['Could', 'be', 'better']
data.shape torch.Size([1, 3, 512])
output tensor([[-0.0678, -0.1236]], grad_fn=<AddmmBackward0>)
loss tensor(0.6656, grad_fn=<NllLossBackward0>)
Epoch 0, Loss: 0.6656351685523987
sentence , label I love this product 1
tokens ['I', 'love', 'this', 'product']
data.shape torch.Size([1, 4, 512])
output tensor([[ 0.1953, -0.5564]], grad_fn=<AddmmBackward0>)
loss tensor(1.1380, grad_fn=<NllLossBackward0>)
Epoch 1, Loss: 1.138014078

In [196]:
def predict(sentence):
    model.eval()
    # Deactivate the gradient computations and get the sentiment prediction.
    with torch.no_grad():
        tokens = sentence.split()
        print('tokens', tokens)
        data = torch.stack([token_embeddings.get(token, torch.rand((1, 512))) for token in tokens], dim=1) # torch.rand, it’s used as a fallback method to handle tokens that are not found in the token_embeddings dictionary
        print('data.shape', data.shape)
        output = model(data)
        print('output', output)
        predicted = torch.argmax(output, dim=1)
        print('predicted', predicted)
        return "Positive" if predicted.item() == 1 else "Negative"

sample_sentence = "This product can be better"
print(f"'{sample_sentence}' is {predict(sample_sentence)}")

tokens ['This', 'product', 'can', 'be', 'better']
data.shape torch.Size([1, 5, 512])
output tensor([[ 0.3276, -0.4645]])
predicted tensor([0])
'This product can be better' is Negative


In [199]:
class RNNWithAttentionModel(nn.Module):
    
    def __init__(self):
        super(RNNWithAttentionModel, self).__init__()
        # Create an embedding layer for the vocabulary
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        # Apply a linear transformation to get the attention scores
        self.attention = nn.Linear(hidden_dim, 1) # linear layer is a simplified implementation of the attention mechanism. It transforms the hidden state of the RNN (hidden_dim) into a single value (hence the output dimension is 1). This value is used as the attention score
        self.fc = nn.Linear(hidden_dim, vocab_size) # This line is defining a linear layer that maps the hidden state of the RNN to a vector of size vocab_size. The objective is to predict the probability of each word being the next word in the sentence, hence the output size of this layer is equal to the number of words in the vocabulary, i.e., vocab_size. But these outputs are logits, not probabilities.
        
    def forward(self, x):
        x = self.embeddings(x)
        out, _ = self.rnn(x) # The RNN processes the sequence of word embeddings one by one and outputs a hidden state for each word. This hidden state is a representation of the word in the context of the sentence
        #  Get the attention weights
        attn_weights = torch.nn.functional.softmax(self.attention(out).squeeze(2), dim=1) # The attention scores are computed by applying a linear transformation (defined by self.attention) to the hidden states output by the RNN. The softmax function is then applied to these scores so that they sum to 1 and can be interpreted as probabilities. This means that each word in the sentence is assigned a weight between 0 and 1, with more important words receiving higher weights
        # Compute the context vector 
        context = torch.sum(attn_weights.unsqueeze(2) * out, dim=1) # The context vector is computed as the weighted sum of the RNN hidden states, where the weights are the attention scores. This results in a single vector that is a summary of the entire sentence, with more emphasis placed on the important words
        out = self.fc(context) # the linear layer fc is used to compute a set of scores for each word in the vocabulary. It outputs logits, and then a softmax function (typically applied in the loss function: loss = criterion(outputs, targets)) turns these scores into probabilities
        return out


In [200]:
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    weight_decay: 0
)

In [166]:
# the below inputs are sequences of word indices. Each number represents a word in your vocabulary
inputs = [torch.tensor([ 3, 12, 11, 16,  3]), 
 torch.tensor([13,  5, 15,  2]),
 torch.tensor([ 6,  5, 10,  9]),
 torch.tensor([4, 5, 3, 7])]

# targets are also word indices. Each target is the word that the model should predict given the corresponding input sequence
targets = torch.tensor([ 1,  8,  0, 14])

In [233]:
vocab_size = sum(inputs[i].shape[0] for i in range(len(inputs)))

embedding_dim = 10
hidden_dim = 16

vocab_size

17

In [209]:
attention_model = RNNWithAttentionModel()
attention_model

RNNWithAttentionModel(
  (embeddings): Embedding(17, 10)
  (rnn): RNN(10, 16, batch_first=True)
  (attention): Linear(in_features=16, out_features=1, bias=True)
  (fc): Linear(in_features=16, out_features=17, bias=True)
)

In [210]:
# The pad_sequences function is used to ensure that all sequences in a batch have the same length so they can be processed together. 

def pad_sequences(batch):     
    max_len = max([len(seq) for seq in batch])
    return torch.stack([torch.cat([seq, torch.zeros(max_len-len(seq)).long()]) for seq in batch])

In [220]:
import torch.nn.functional as F

epochs = 1 # change to 300

for epoch in range(epochs):
    attention_model.train()
    optimizer.zero_grad()
    print('inputs', inputs)
    padded_inputs = pad_sequences(inputs)
    print('padded_inputs', padded_inputs)
    print('padded_inputs.shape', padded_inputs.shape)
    outputs = attention_model(padded_inputs)
    print('targets.shape', targets.shape)
    print('outputs.shape', outputs.shape)
    print('outputs', outputs)
    probabilities = F.softmax(outputs, dim=1)  # Apply softmax to the outputs to get probabilities (just for understanding, because the loss criterion uses the outputs logits direct, instead of the probabilities)
    print('probabilities:', probabilities)  
    loss = criterion(outputs, targets)
    print('loss', loss)
    loss.backward()
    optimizer.step()
    print('----------------------------------------------------------------------------------------------------------------------')


inputs [tensor([ 3, 12, 11, 16,  3]), tensor([13,  5, 15,  2]), tensor([ 6,  5, 10,  9]), tensor([4, 5, 3, 7])]
padded_inputs tensor([[ 3, 12, 11, 16,  3],
        [13,  5, 15,  2,  0],
        [ 6,  5, 10,  9,  0],
        [ 4,  5,  3,  7,  0]])
padded_inputs.shape torch.Size([4, 5])
targets.shape torch.Size([4])
outputs.shape torch.Size([4, 17])
outputs tensor([[-0.3106,  0.5167,  0.1364, -0.0742,  0.0536,  0.0878,  0.3273, -0.0699,
         -0.1763, -0.0348, -0.2325,  0.1983,  0.0194, -0.3633, -0.5338, -0.5047,
         -0.0844],
        [-0.2740,  0.1209, -0.0385, -0.0336, -0.1356,  0.3925, -0.0324, -0.2128,
         -0.3367, -0.0519,  0.1298,  0.2006, -0.1395, -0.0757, -0.0360, -0.2241,
         -0.1544],
        [-0.4180,  0.3230,  0.1359,  0.0261,  0.1709,  0.4095,  0.2102,  0.0374,
         -0.4524, -0.1244,  0.1721, -0.1776, -0.3164,  0.1884,  0.0016, -0.2180,
         -0.2133],
        [-0.3108,  0.3213,  0.0487, -0.0804,  0.0833,  0.2437,  0.1857, -0.0545,
         -0.3139, 

In [307]:
class RNNModel_2(nn.Module):
    
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNNModel_2, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        
    def forward(self, x):
        x = self.embeddings(x)  # Convert input indices to embeddings
        h0 = torch.zeros(1, x.size(0), self.hidden_dim).to(x.device)  # Initial hidden state
        out, _ = self.rnn(x, h0)  # Pass the input and initial hidden state through the RNN
        out = out[:, -1, :]  # Select the last output of each sequence
        out = self.fc(out)  # Pass the last outputs through the fully connected layer
        return out


In [308]:
rnn_model = RNNModel_2(vocab_size, embedding_dim, hidden_dim)
rnn_model

RNNModel_2(
  (embeddings): Embedding(17, 10)
  (rnn): RNN(10, 16, batch_first=True)
  (fc): Linear(in_features=16, out_features=17, bias=True)
)

In [309]:
input_data = [[2, 7, 8, 6, 2], [16, 14, 4, 12], [9, 14, 13, 15], [3, 14, 2, 10]]

target_data = [11, 5, 0, 1]

ix_to_word = {0: 'noisy',
             1: 'mammals',
             2: 'the',
             3: 'whales',
             4: 'very',
             5: 'animals',
             6: 'on',
             7: 'cat',
             8: 'sat',
             9: 'parrots',
             10: 'largest',
             11: 'mat',
             12: 'loyal',
             13: 'colorful',
             14: 'are',
             15: 'and',
             16: 'dogs'}

In [332]:
for input_seq, target in zip(input_data, target_data):
    input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
    print('input_test', input_test)
    print([ix_to_word[index.item()] for index in input_test[0]])
   
    #  Set the RNN model to evaluation mode
    rnn_model.eval()
    # Get the RNN output by passing the appropriate input 
    rnn_output = rnn_model(input_test)
    print('rnn_output', rnn_output)
    rnn_prediction = ix_to_word[torch.argmax(rnn_output).item()]
    print('rnn_prediction', rnn_prediction)

    attention_model.eval()
    attention_output = attention_model(input_test)
    print('attention_output', attention_output)
    # Extract the word with the highest prediction score
    attention_prediction = ix_to_word[torch.argmax(attention_output).item()]
    print('attention_prediction', attention_prediction)

    # print(f"\nInput: {' '.join([ix_to_word[ix] for ix in input_seq])}")
    # print(f"Target: {ix_to_word[target]}")
    # print(f"RNN prediction: {rnn_prediction}")
    # print(f"RNN with Attention prediction: {attention_prediction}")
    
    print('----------------------------------------------------------------------------------------------------')

input_test tensor([[2, 7, 8, 6, 2]])
['the', 'cat', 'sat', 'on', 'the']
rnn_output tensor([[-0.1428,  0.4585, -0.3621, -0.0588,  0.1513,  0.2110,  0.0612,  0.3391,
          0.0640,  0.4440, -0.3169, -0.5467,  0.2409, -0.5939, -0.0526, -0.0015,
          0.6679]], grad_fn=<AddmmBackward0>)
rnn_prediction dogs
attention_output tensor([[-0.2578,  0.1940, -0.2021, -0.2177, -0.3048,  0.1416,  0.0347, -0.0964,
         -0.1777,  0.0350,  0.0942,  0.2312,  0.0515, -0.3250, -0.2956, -0.3767,
         -0.0722]], grad_fn=<AddmmBackward0>)
attention_prediction mat
----------------------------------------------------------------------------------------------------
input_test tensor([[16, 14,  4, 12]])
['dogs', 'are', 'very', 'loyal']
rnn_output tensor([[-0.5535,  0.2841, -0.1895,  0.6107, -0.0878, -0.1122,  0.0431,  0.3670,
         -0.0467, -0.2583,  0.3900, -0.4158,  0.1517, -0.1374, -0.4201,  0.6435,
         -0.0965]], grad_fn=<AddmmBackward0>)
rnn_prediction and
attention_output tensor([[-0.

In [330]:
tens = torch.tensor([[2, 7, 8, 6, 2]])

[i.item() for i in tens[0]]

[2, 7, 8, 6, 2]