In [1]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("Torch version:", torch.__version__)


CUDA available: True
CUDA version: 11.8
Torch version: 2.2.0+cu118


In [2]:
import torchtext
print("Torchtext version:", torchtext.__version__)

  from .autonotebook import tqdm as notebook_tqdm


Torchtext version: 0.17.0+cpu


In [3]:
# Import the necessary functions
# !pip install torchtext

import nltk
import torch
from torchtext.data.utils import get_tokenizer
from nltk.probability import FreqDist
from nltk.stem import PorterStemmer
from nltk.corpus import stopwords


text = "In the city of Dataville, a data analyst named Alex explores hidden insights within vast data. With determination, Alex uncovers patterns, cleanses the data, and unlocks innovation. Join this adventure to unleash the power of data-driven decisions."

# Initialize the tokenizer and tokenize the text
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(text)

tokens

['in',
 'the',
 'city',
 'of',
 'dataville',
 ',',
 'a',
 'data',
 'analyst',
 'named',
 'alex',
 'explores',
 'hidden',
 'insights',
 'within',
 'vast',
 'data',
 '.',
 'with',
 'determination',
 ',',
 'alex',
 'uncovers',
 'patterns',
 ',',
 'cleanses',
 'the',
 'data',
 ',',
 'and',
 'unlocks',
 'innovation',
 '.',
 'join',
 'this',
 'adventure',
 'to',
 'unleash',
 'the',
 'power',
 'of',
 'data-driven',
 'decisions',
 '.']

In [4]:
threshold = 1
# Remove rare words and print common tokens
freq_dist = FreqDist(tokens)
common_tokens = [token for token in tokens if freq_dist[token] > threshold]
print(common_tokens)

['the', 'of', ',', 'data', 'alex', 'data', '.', ',', 'alex', ',', 'the', 'data', ',', '.', 'the', 'of', '.']


In [5]:
text = 'The moor is very sparsely inhabited, and those who live near each other are thrown very much together. For this reason I saw a good deal of Sir Charles Baskerville. With the exception of Mr. Frankland, of Lafter Hall, and Mr. Stapleton, the naturalist, there are no other men of education within many miles. Sir Charles was a retiring man, but the chance of his illness brought us together, and a community of interests in science kept us so. He had brought back much scientific information from South Africa, and many a charming evening we have spent together discussing the comparative anatomy of the Bushman and the Hottentot.'

In [6]:
# Initialize and tokenize the text
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(text)

tokens[:10]

['the',
 'moor',
 'is',
 'very',
 'sparsely',
 'inhabited',
 ',',
 'and',
 'those',
 'who']

In [7]:
# Remove any stopwords
stop_words = set(stopwords.words("english"))
filtered_tokens = [token for token in tokens if token.lower() not in stop_words]

filtered_tokens

['moor',
 'sparsely',
 'inhabited',
 ',',
 'live',
 'near',
 'thrown',
 'much',
 'together',
 '.',
 'reason',
 'saw',
 'good',
 'deal',
 'sir',
 'charles',
 'baskerville',
 '.',
 'exception',
 'mr',
 '.',
 'frankland',
 ',',
 'lafter',
 'hall',
 ',',
 'mr',
 '.',
 'stapleton',
 ',',
 'naturalist',
 ',',
 'men',
 'education',
 'within',
 'many',
 'miles',
 '.',
 'sir',
 'charles',
 'retiring',
 'man',
 ',',
 'chance',
 'illness',
 'brought',
 'us',
 'together',
 ',',
 'community',
 'interests',
 'science',
 'kept',
 'us',
 '.',
 'brought',
 'back',
 'much',
 'scientific',
 'information',
 'south',
 'africa',
 ',',
 'many',
 'charming',
 'evening',
 'spent',
 'together',
 'discussing',
 'comparative',
 'anatomy',
 'bushman',
 'hottentot',
 '.']

In [8]:
# Perform stemming on the filtered tokens
stemmer = PorterStemmer()
stemmed_tokens = [stemmer.stem(token) for token in filtered_tokens]
print(stemmed_tokens)

['moor', 'spars', 'inhabit', ',', 'live', 'near', 'thrown', 'much', 'togeth', '.', 'reason', 'saw', 'good', 'deal', 'sir', 'charl', 'baskervil', '.', 'except', 'mr', '.', 'frankland', ',', 'lafter', 'hall', ',', 'mr', '.', 'stapleton', ',', 'naturalist', ',', 'men', 'educ', 'within', 'mani', 'mile', '.', 'sir', 'charl', 'retir', 'man', ',', 'chanc', 'ill', 'brought', 'us', 'togeth', ',', 'commun', 'interest', 'scienc', 'kept', 'us', '.', 'brought', 'back', 'much', 'scientif', 'inform', 'south', 'africa', ',', 'mani', 'charm', 'even', 'spent', 'togeth', 'discuss', 'compar', 'anatomi', 'bushman', 'hottentot', '.']


In [9]:
genres = ['Fiction','Non-fiction','Biography', 'Children','Mystery']

# Define the size of the vocabulary
vocab_size = len(genres)

# Create one-hot vectors
one_hot_vectors = torch.eye(vocab_size)

# Create a dictionary mapping genres to their one-hot vectors
one_hot_dict = {genre: one_hot_vectors[i] for i, genre in enumerate(genres)}

one_hot_dict

{'Fiction': tensor([1., 0., 0., 0., 0.]),
 'Non-fiction': tensor([0., 1., 0., 0., 0.]),
 'Biography': tensor([0., 0., 1., 0., 0.]),
 'Children': tensor([0., 0., 0., 1., 0.]),
 'Mystery': tensor([0., 0., 0., 0., 1.])}

In [10]:
for genre, vector in one_hot_dict.items():
    print(f'{genre}: {vector.numpy()}')

Fiction: [1. 0. 0. 0. 0.]
Non-fiction: [0. 1. 0. 0. 0.]
Biography: [0. 0. 1. 0. 0.]
Children: [0. 0. 0. 1. 0.]
Mystery: [0. 0. 0. 0. 1.]


In [11]:
# Import from sklearn
from sklearn.feature_extraction.text import CountVectorizer

titles = ['The Great Gatsby','To Kill a Mockingbird','1984','The Catcher in the Rye','The Hobbit', 'Great Expectations']

# Initialize Bag-of-words with the list of book titles
vectorizer = CountVectorizer()
bow_encoded_titles = vectorizer.fit_transform(titles)

bow_encoded_titles

<6x12 sparse matrix of type '<class 'numpy.int64'>'
	with 15 stored elements in Compressed Sparse Row format>

In [12]:
# # Extract and print the first five features
# print(vectorizer.get_feature_names_out()[:5])
# print(bow_encoded_titles.toarray()[0, :5])

In [13]:
# features and occurrencies
print(vectorizer.get_feature_names_out())
print(bow_encoded_titles.toarray())

['1984' 'catcher' 'expectations' 'gatsby' 'great' 'hobbit' 'in' 'kill'
 'mockingbird' 'rye' 'the' 'to']
[[0 0 0 1 1 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 1 1 0 0 1]
 [1 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 1 0 0 1 2 0]
 [0 0 0 0 0 1 0 0 0 0 1 0]
 [0 0 1 0 1 0 0 0 0 0 0 0]]


In [14]:
descriptions = ['A portrait of the Jazz Age in all of its decadence and excess.',
 'A gripping, heart-wrenching, and wholly remarkable tale of coming-of-age in a South poisoned by virulent prejudice.',
 'A startling and haunting vision of the world.',
 'A story of lost innocence.',
 'A timeless adventure story.']

In [15]:
# Importing TF-IDF from sklearn
from sklearn.feature_extraction.text import TfidfVectorizer

# Initialize TF-IDF encoding vectorizer
vectorizer = TfidfVectorizer()
tfidf_encoded_descriptions = vectorizer.fit_transform(descriptions)

tfidf_encoded_descriptions

<5x32 sparse matrix of type '<class 'numpy.float64'>'
	with 41 stored elements in Compressed Sparse Row format>

In [16]:
# # Extract and print the first five features
# print(vectorizer.get_feature_names_out()[:5])
# print(tfidf_encoded_descriptions.toarray()[0, :5])

In [17]:
print(vectorizer.get_feature_names_out())
print(tfidf_encoded_descriptions.toarray())

['adventure' 'age' 'all' 'and' 'by' 'coming' 'decadence' 'excess'
 'gripping' 'haunting' 'heart' 'in' 'innocence' 'its' 'jazz' 'lost' 'of'
 'poisoned' 'portrait' 'prejudice' 'remarkable' 'south' 'startling'
 'story' 'tale' 'the' 'timeless' 'virulent' 'vision' 'wholly' 'world'
 'wrenching']
[[0.         0.25943581 0.321564   0.21535516 0.         0.
  0.321564   0.321564   0.         0.         0.         0.25943581
  0.         0.321564   0.321564   0.         0.36232709 0.
  0.321564   0.         0.         0.         0.         0.
  0.         0.25943581 0.         0.         0.         0.
  0.         0.        ]
 [0.         0.20817488 0.         0.17280396 0.2580274  0.2580274
  0.         0.         0.2580274  0.         0.2580274  0.20817488
  0.         0.         0.         0.         0.29073627 0.2580274
  0.         0.2580274  0.2580274  0.2580274  0.         0.
  0.2580274  0.         0.         0.2580274  0.         0.2580274
  0.         0.2580274 ]
 [0.         0.       

In [18]:
# Create a list of stopwords
stop_words = set(stopwords.words("english"))

# Initialize the tokenizer and stemmer
tokenizer = get_tokenizer("basic_english")
stemmer = PorterStemmer() 

# Complete the function to preprocess sentences
def preprocess_sentences(sentences):
    processed_sentences = []
    for sentence in sentences:
        sentence = sentence.lower()
        tokens = tokenizer(sentence)
        tokens = [token for token in tokens if token not in stop_words]
        tokens = [stemmer.stem(token) for token in tokens]
        processed_sentences.append(' '.join(tokens))
    return processed_sentences


In [19]:
with open('shakespeare.txt', 'r', encoding='utf-8') as file:
    shakespeare = file.read().split('.')


In [20]:
shakespeare[:5]

['The Project Gutenberg eBook of The Complete Works of William Shakespeare, by William Shakespeare\n\nThis eBook is for the use of anyone anywhere in the United States and\nmost other parts of the world at no cost and with almost no restrictions\nwhatsoever',
 ' You may copy it, give it away or re-use it under the terms\nof the Project Gutenberg License included with this eBook or online at\nwww',
 'gutenberg',
 'org',
 ' If you are not located in the United States, you\nwill have to check the laws of the country where you are located before\nusing this eBook']

In [21]:
processed_shakespeare = preprocess_sentences(shakespeare)
print(processed_shakespeare[:5]) 

['project gutenberg ebook complet work william shakespear , william shakespear ebook use anyon anywher unit state part world cost almost restrict whatsoev', 'may copi , give away re-us term project gutenberg licens includ ebook onlin www', 'gutenberg', 'org', 'locat unit state , check law countri locat use ebook']


In [22]:
# Import libraries
from torch.utils.data import Dataset, DataLoader

# Define your Dataset class
class ShakespeareDataset(Dataset):
    # it is mandatory to define these three methods when you extend the Dataset class in PyTorch.
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]


In [23]:
# threshold to remove all tokens that appear less than or equal to threshold times in the sentence

filter_threshod = 1

In [24]:
def preprocess_sentences(sentences):
    processed_sentences = []
    for sentence in sentences:
        sentence = sentence.lower()
        tokens = tokenizer(sentence)
        tokens = [token for token in tokens if token not in stop_words]
        tokens = [stemmer.stem(token) for token in tokens]
        freq_dist = FreqDist(tokens)
        threshold = filter_threshod
        tokens = [token for token in tokens if freq_dist[token] > threshold]
        processed_sentences.append(' '.join(tokens))
    return processed_sentences


In [25]:
def encode_sentences(sentences):
    vectorizer = CountVectorizer()
    X = vectorizer.fit_transform(sentences)
    encoded_sentences = X.toarray()
    return encoded_sentences, vectorizer


In [26]:
def extract_sentences(data):
    sentences = re.findall(r'[A-Z][^.!?]*[.!?]', data)
    return sentences


In [27]:
# Complete the text processing pipeline
def text_processing_pipeline(sentences):
    processed_sentences = preprocess_sentences(sentences)
    encoded_sentences, vectorizer = encode_sentences(processed_sentences)
    dataset = ShakespeareDataset(encoded_sentences)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    return dataloader, vectorizer


In [28]:
dataloader, vectorizer = text_processing_pipeline(processed_shakespeare)

# Print the vectorizer's feature names and the first 100 components of the first item
print(vectorizer.get_feature_names_out()[:100]) 
next(iter(dataloader))[0, :100]


['_benedictu' '_ergo_' '_hic' '_le' '_molli' '_mulier_' '_solus_' '_the'
 '_to' 'aaron' 'abat' 'abbot' 'abhor' 'abhorson' 'abl' 'aboard' 'abu'
 'access' 'accommod' 'accost' 'account' 'accu' 'accur' 'achil' 'acquaint'
 'act' 'action' 'actor' 'ad' 'adam' 'adder' 'adieu' 'admir' 'adoni'
 'adramadio' 'adulteri' 'advanc' 'advantag' 'advic' 'aemiliu' 'aenea'
 'aer_' 'afterward' 'agamemnon' 'age' 'agreement' 'agrippa' 'ah' 'aid'
 'aim' 'air' 'ajax' 'ala' 'alack' 'alban' 'alcibiad' 'alençon' 'alexa'
 'alexand' 'alia' 'alik' 'aliv' 'all' 'allegi' 'allon' 'allow' 'almost'
 'alon' 'along' 'alow' 'altar' 'alter' 'although' 'altogeth' 'amaz'
 'amber' 'ambiti' 'amen' 'amend' 'amiss' 'amiti' 'among' 'amurath'
 'anchor' 'ancient' 'andrew' 'andronicu' 'angel' 'angelo' 'angl' 'angri'
 'ann' 'anon' 'anoth' 'another' 'answer' 'antigonu' 'antipholu' 'antoni'
 'antonio']


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0])

In [29]:
# Map a unique index to each word
words = ["This", "book", "was", "fantastic", "I", "really", "love", "science", "fiction", "but", "the", "protagonist", "was", "rude", "sometimes"]

word_to_idx = {word: i for i, word in enumerate(words)}

word_to_idx

{'This': 0,
 'book': 1,
 'was': 12,
 'fantastic': 3,
 'I': 4,
 'really': 5,
 'love': 6,
 'science': 7,
 'fiction': 8,
 'but': 9,
 'the': 10,
 'protagonist': 11,
 'rude': 13,
 'sometimes': 14}

In [30]:
# Convert word_to_idx to a tensor
inputs = torch.LongTensor([word_to_idx[w] for w in words])

inputs

tensor([ 0,  1, 12,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

In [31]:
import torch
import torch.nn as nn

# Define the embedding layer with the number of embeddings and the size of each embedding
embedding = nn.Embedding(15, 10) # creates an embedding of 10 values (embedding dimension) for each of the 15 words (number of embeddings)

# Pass the tensor to the embedding layer 
output = embedding(inputs)
print(output)

tensor([[ 2.9155e-01, -4.2846e-02,  7.4626e-01, -7.2732e-01,  4.1755e-01,
          5.8797e-01,  2.9430e-01,  1.0333e+00, -3.7366e-02, -3.5645e-01],
        [ 4.4465e-04,  2.6341e+00, -6.0215e-01, -6.1211e-01, -1.1978e+00,
         -1.1391e+00, -1.6949e+00,  8.8291e-01,  1.6486e+00, -8.7134e-01],
        [-3.5535e-01, -6.2705e-01, -1.0797e+00,  1.1430e+00,  1.0941e+00,
         -8.2658e-01, -1.0178e-01, -5.2812e-01,  2.9619e-01, -1.5035e+00],
        [ 1.2312e+00, -9.7742e-01, -1.3229e-01,  7.1796e-01,  1.7844e+00,
         -1.0645e+00,  7.5272e-01,  4.2596e-01, -7.8369e-01,  3.3566e-01],
        [ 1.5683e+00,  1.6907e-02,  1.3010e+00, -1.8566e+00, -4.3566e-01,
          2.0464e-01, -3.7057e-01,  1.3641e+00, -3.5580e-01, -1.4706e+00],
        [ 2.8564e-01, -4.7004e-01, -1.5656e+00, -5.8641e-01, -7.5532e-01,
          1.2489e+00, -1.1991e+00,  5.0140e-01,  2.3359e-01, -1.3675e-01],
        [-1.1918e+00, -1.7489e+00, -1.5809e+00, -2.8796e-01,  5.5131e-01,
          1.3875e-01,  1.0814e+0

In [32]:
# data labeled positively 1 or negatively 0

data = [(['I', 'love', 'this', 'first','book'], 1),
     (['This', 'is', 'an', 'amazing', 'novel'], 1),
     (['I', 'really', 'like', 'this', 'story'], 1),
     (['I', 'do', 'not', 'like', 'this', 'book'], 0),
     (['I', 'hate', 'this', 'novel'], 0),
     (['This', 'is', 'a', 'terrible', 'story'], 0)]

In [33]:
# unique_words = []

# for sub_list, _ in data:
#     for word in sub_list:
#         unique_words.append(word)

# list(set(unique_words))

unique_words = list(set(word for sub_list, _ in data for word in sub_list))

vocab_size = len(unique_words)

vocab_size, unique_words

(18,
 ['love',
  'story',
  'I',
  'really',
  'this',
  'amazing',
  'like',
  'is',
  'an',
  'hate',
  'terrible',
  'not',
  'do',
  'a',
  'novel',
  'first',
  'book',
  'This'])

In [35]:
import torch.nn.functional as F

class TextClassificationCNN(nn.Module):
    
    def __init__(self, vocab_size, embed_dim):
        super(TextClassificationCNN, self).__init__() # initializes the base class nn.Module
        # Initialize the embedding layer 
        self.embedding = nn.Embedding(vocab_size, embed_dim) # Transforms input words into dense vectors of a fixed size (embed_dim).
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=1, padding=1) # Applies a 1D convolution over the embedding vectors to detect patterns.
        self.fc = nn.Linear(embed_dim, 2) # binary classification
        
    def forward(self, text):
        embedded = self.embedding(text).permute(0, 2, 1) # reshaping the tensor to match the expected input format of the convolutional layer, inverting position 2 with position 1, changing from [batch_size, sequence_length, embed_dim] to [batch_size, embed_dim, sequence_length].
        # Pass the embedded text through the convolutional layer and apply a ReLU
        conved = F.relu(self.conv(embedded)) # Passes the embeddings through the convolution layer and applies the ReLU activation function to extract features.
        conved = conved.mean(dim=2) # Mean Pooling: Reduces the dimensions by averaging across one dimension.
        return self.fc(conved) # Outputs the logits for the two classes.

In [34]:
import torch.nn as nn

embed_dim = 10

# Initialize embedding layer with input and output dimensions
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

embedding

Embedding(18, 10)

In [36]:
word_to_ix = {word: i for i, word in enumerate(unique_words)}

word_to_ix

{'love': 0,
 'story': 1,
 'I': 2,
 'really': 3,
 'this': 4,
 'amazing': 5,
 'like': 6,
 'is': 7,
 'an': 8,
 'hate': 9,
 'terrible': 10,
 'not': 11,
 'do': 12,
 'a': 13,
 'novel': 14,
 'first': 15,
 'book': 16,
 'This': 17}

In [37]:
# each word w in the sentence is replaced by its index in word_to_ix. If a word is not found in word_to_ix, the get method returns 0 as a default value. The result is a list of indices that represent the sentence (applied later in the model training)

torch.LongTensor([word_to_ix.get(w, 0) for sentence, _ in data for w in sentence]).unsqueeze(0) # gets the indexed in the word_to_ix vocabulary of the words in the data

tensor([[ 2,  0,  4, 15, 16, 17,  7,  8,  5, 14,  2,  3,  6,  4,  1,  2, 12, 11,
          6,  4, 16,  2,  9,  4, 14, 17,  7, 13, 10,  1]])

In [38]:
model = TextClassificationCNN(vocab_size, embed_dim)

model

TextClassificationCNN(
  (embedding): Embedding(18, 10)
  (conv): Conv1d(10, 10, kernel_size=(3,), stride=(1,), padding=(1,))
  (fc): Linear(in_features=10, out_features=2, bias=True)
)

In [39]:
import torch.optim as optim

# Define the loss function
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10):
    for sentence, label in data:     
        # Clear the gradients
        model.zero_grad()
        sentence = torch.LongTensor([word_to_ix.get(w, 0) for w in sentence]).unsqueeze(0) # provides the indexed in the word_to_ix vocabulary of the sentences in data (sentences encoding). Necessary before applying the self.embedding to the text in the model
        label = torch.LongTensor([int(label)])
        outputs = model(sentence)
        loss = criterion(outputs, label)
        loss.backward()
        # print(loss)
        # Update the parameters
        optimizer.step()
        print(sentence, outputs, label) # just last one (outputs are logits)
print('Training complete!')

tensor([[ 2,  0,  4, 15, 16]]) tensor([[-0.3550,  0.2004]], grad_fn=<AddmmBackward0>) tensor([1])
tensor([[17,  7,  8,  5, 14]]) tensor([[-0.3968,  0.0740]], grad_fn=<AddmmBackward0>) tensor([1])
tensor([[2, 3, 6, 4, 1]]) tensor([[-0.3561,  0.3760]], grad_fn=<AddmmBackward0>) tensor([1])
tensor([[ 2, 12, 11,  6,  4, 16]]) tensor([[-0.4681,  0.3257]], grad_fn=<AddmmBackward0>) tensor([0])
tensor([[ 2,  9,  4, 14]]) tensor([[-0.2896,  0.1875]], grad_fn=<AddmmBackward0>) tensor([0])
tensor([[17,  7, 13, 10,  1]]) tensor([[-0.1506, -0.0615]], grad_fn=<AddmmBackward0>) tensor([0])
tensor([[ 2,  0,  4, 15, 16]]) tensor([[-0.2319,  0.0779]], grad_fn=<AddmmBackward0>) tensor([1])
tensor([[17,  7,  8,  5, 14]]) tensor([[-0.3221, -0.0128]], grad_fn=<AddmmBackward0>) tensor([1])
tensor([[2, 3, 6, 4, 1]]) tensor([[-0.2911,  0.3110]], grad_fn=<AddmmBackward0>) tensor([1])
tensor([[ 2, 12, 11,  6,  4, 16]]) tensor([[-0.3139,  0.1654]], grad_fn=<AddmmBackward0>) tensor([0])
tensor([[ 2,  9,  4, 14]])

In [40]:
book_reviews = [
    "I love this book".split(),
    "I do not like this book".split()
]

book_reviews

[['I', 'love', 'this', 'book'], ['I', 'do', 'not', 'like', 'this', 'book']]

In [41]:
for review in book_reviews:
    # Convert the review words into tensor form
    input_tensor = torch.tensor([word_to_ix[w] for w in review], dtype=torch.long).unsqueeze(0) 
    print('input_tensor',input_tensor)
    
    # Get the model's output
    outputs = model(input_tensor)
    print('outputs',outputs)
    # Find the index of the most likely sentiment category
    _, predicted_label = torch.max(outputs.data, 1)
    print('predicted_label',predicted_label)
    
    # Convert the predicted label into a sentiment string
    sentiment = "Positive" if predicted_label.item() == 1 else "Negative"
    print(f"Book Review: {' '.join(review)}")
    print(f"Sentiment: {sentiment}\n")

input_tensor tensor([[ 2,  0,  4, 16]])
outputs tensor([[-0.9109,  0.9031]], grad_fn=<AddmmBackward0>)
predicted_label tensor([1])
Book Review: I love this book
Sentiment: Positive

input_tensor tensor([[ 2, 12, 11,  6,  4, 16]])
outputs tensor([[ 0.7609, -0.8040]], grad_fn=<AddmmBackward0>)
predicted_label tensor([0])
Book Review: I do not like this book
Sentiment: Negative



In [42]:
# Complete the RNN class
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size # Sets the size of the hidden state in the RNN.
        self.num_layers = num_layers # Sets the number of stacked RNN layers.
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True) # When batch_first=True is set in an RNN, it means that the input and output tensors will have the shape [batch_size, sequence_length, input_size]. By default, PyTorch expects the input shape to be [sequence_length, batch_size, input_size]
        self.fc = nn.Linear(hidden_size, num_classes)        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) # Initializes the hidden state with zeros. Its shape is [num_layers, batch_size, hidden_size]. If x has a shape of [batch_size, sequence_length, input_size], then: x.size(0) returns the batch_size, which is the number of sequences in the batch.
        out, _ = self.rnn(x, h0) # The RNN outputs out (all hidden states for all time steps) and _ (the final hidden state, which isn't used here).
        out = out[:, -1, :] # Selects the hidden state at the last time step for each sequence with out[:, -1, :]
        out = self.fc(out)
        return out


In [43]:
input_size = 6
hidden_size = 32
num_layers = 2
num_classes = 3

# Initialize the model
rnn_model = RNNModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_model.parameters(), lr=0.01)

rnn_model

RNNModel(
  (rnn): RNN(6, 32, num_layers=2, batch_first=True)
  (fc): Linear(in_features=32, out_features=3, bias=True)
)

In [44]:
X_train_seq = torch.tensor([[[0., 0., 0., 1., 0., 1.]],

        [[0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]]])

X_train_seq.shape

torch.Size([5, 1, 6])

In [45]:
y_train_seq = torch.tensor([2, 2, 2, 0, 2])

In [46]:
# Train the model for ten epochs and zero the gradients
for epoch in range(10): 
    optimizer.zero_grad()
    outputs = rnn_model(X_train_seq)
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 1.1248911619186401
Epoch: 2, Loss: 0.973309338092804
Epoch: 3, Loss: 0.8396531939506531
Epoch: 4, Loss: 0.7184687256813049
Epoch: 5, Loss: 0.6190042495727539
Epoch: 6, Loss: 0.5575698614120483
Epoch: 7, Loss: 0.5377610921859741
Epoch: 8, Loss: 0.5386293530464172
Epoch: 9, Loss: 0.5369535684585571
Epoch: 10, Loss: 0.5235108137130737


In [47]:
# Initialize the LSTM and the output layer with parameters
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)        
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
        out, _ = self.lstm(x, (h0, c0))
        out = out[:, -1, :] 
        out = self.fc(out)
        return out



In [48]:
# Initialize model with required parameters
lstm_model = LSTMModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lstm_model.parameters(), lr=0.01)



In [49]:
# Train the model by passing the correct parameters and zeroing the gradient
for epoch in range(10): 
    optimizer.zero_grad()
    outputs = lstm_model(X_train_seq)
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 1.0723886489868164
Epoch: 2, Loss: 1.0383002758026123
Epoch: 3, Loss: 1.0030786991119385
Epoch: 4, Loss: 0.9654237031936646
Epoch: 5, Loss: 0.9242860674858093
Epoch: 6, Loss: 0.8789087533950806
Epoch: 7, Loss: 0.8290559649467468
Epoch: 8, Loss: 0.7754516005516052
Epoch: 9, Loss: 0.7201984524726868
Epoch: 10, Loss: 0.6669596433639526


In [50]:
# Complete the GRU model
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)       
    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) 
        out, _ = self.gru(x, h0)
        out = out[:, -1, :] 
        out = self.fc(out)
        return out



In [51]:
# Initialize the model
gru_model = GRUModel(input_size, hidden_size, num_layers, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(gru_model.parameters(), lr=0.01)



In [52]:
# Train the model and backpropagate the loss after initialization
for epoch in range(15): 
    optimizer.zero_grad()
    outputs = gru_model(X_train_seq)
    loss = criterion(outputs, y_train_seq)
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

Epoch: 1, Loss: 1.0797009468078613
Epoch: 2, Loss: 1.0256575345993042
Epoch: 3, Loss: 0.9725996255874634
Epoch: 4, Loss: 0.9161521792411804
Epoch: 5, Loss: 0.8538802862167358
Epoch: 6, Loss: 0.7856529951095581
Epoch: 7, Loss: 0.7140219807624817
Epoch: 8, Loss: 0.644810140132904
Epoch: 9, Loss: 0.586836040019989
Epoch: 10, Loss: 0.549041748046875
Epoch: 11, Loss: 0.5341728329658508
Epoch: 12, Loss: 0.5343982577323914
Epoch: 13, Loss: 0.5372315049171448
Epoch: 14, Loss: 0.534077525138855
Epoch: 15, Loss: 0.5223044157028198


In [53]:
from torchmetrics import Accuracy, Precision, Recall, F1Score

# Create an instance of the metrics
accuracy = Accuracy(task="multiclass", num_classes=3)
precision = Precision(task="multiclass", num_classes=3)
recall = Recall(task="multiclass", num_classes=3)
f1 = F1Score(task="multiclass", num_classes=3)



In [54]:
X_test_seq = torch.tensor([[[0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 1., 1., 0.]],

        [[0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.]]])

X_test_seq.shape

torch.Size([6, 1, 6])

In [55]:
# Generate the predictions
outputs = rnn_model(X_test_seq)
_, predicted = torch.max(outputs, 1)

outputs,predicted

(tensor([[-0.2006, -2.3813,  1.7712],
         [-0.4732, -2.4382,  1.9004],
         [-0.2006, -2.3813,  1.7712],
         [-0.2006, -2.3813,  1.7712],
         [-0.4813, -2.6254,  2.0863],
         [-0.2006, -2.3813,  1.7712]], grad_fn=<AddmmBackward0>),
 tensor([2, 2, 2, 2, 2, 2]))

In [56]:
y_test_seq = torch.tensor([2, 1, 0, 1, 2, 2])

In [57]:
# Calculate the metrics
accuracy_score = accuracy(predicted, y_test_seq)
precision_score = precision(predicted, y_test_seq)
recall_score = recall(predicted, y_test_seq)
f1_score = f1(predicted, y_test_seq)
print("RNN Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_score, precision_score, recall_score, f1_score))

RNN Model - Accuracy: 0.5, Precision: 0.5, Recall: 0.5, F1 Score: 0.5


In [58]:
outputs = lstm_model(X_test_seq)
_, y_pred_lstm = torch.max(outputs, 1)

In [59]:

# Calculate metrics for the LSTM model
accuracy_1 = accuracy(y_pred_lstm, y_test_seq)
precision_1 = precision(y_pred_lstm, y_test_seq)
recall_1 = recall(y_pred_lstm, y_test_seq)
f1_1 = f1(y_pred_lstm, y_test_seq)
print("LSTM Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_1, precision_1, recall_1, f1_1))


LSTM Model - Accuracy: 0.5, Precision: 0.5, Recall: 0.5, F1 Score: 0.5


In [60]:
outputs = gru_model(X_test_seq)
_, y_pred_gru = torch.max(outputs, 1)

In [61]:
# Calculate metrics for the GRU model
accuracy_2 = accuracy(y_pred_gru, y_test_seq)
precision_2 = precision(y_pred_gru, y_test_seq)
recall_2 = recall(y_pred_gru, y_test_seq)
f1_2 = f1(y_pred_gru, y_test_seq)
print("GRU Model - Accuracy: {}, Precision: {}, Recall: {}, F1 Score: {}".format(accuracy_2, precision_2, recall_2, f1_2))

GRU Model - Accuracy: 0.5, Precision: 0.5, Recall: 0.5, F1 Score: 0.5


In [62]:
data = 'The rabbit-hole went straight on like a tunnel for some way, and then dipped suddenly down, so suddenly that Alice had not a moment to think about stopping herself before she found herself falling down a very deep well.'

In [63]:
# Include an RNN layer and linear layer in RNNmodel class

class RNNmodel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNmodel, self).__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), self.hidden_size)
        out, _ = self.rnn(x, h0)  
        out = self.fc(out[:, -1, :])  # out[:, -1, :] means “for each sequence in the batch, take the last hidden state” (In the context of RNNs, the last hidden state often contains information about the entire sequence because it has seen all the previous inputs in the sequence. This is why it’s commonly used for sequence classification tasks.)
        return out


In [64]:
chars = ['y', ',','b','w','a','d','e','p','o','k','n','v','s','T','l','m','.','g','c','h','u','f','-','A',' ','r','t','i']
len(chars)

28

In [65]:
# Instantiate the RNN model
model = RNNmodel(len(chars), 16, len(chars))
model

RNNmodel(
  (rnn): RNN(28, 16, batch_first=True)
  (fc): Linear(in_features=16, out_features=28, bias=True)
)

In [66]:
char_to_ix = {chars[i]:i for i in range(len(chars))}
char_to_ix

{'y': 0,
 ',': 1,
 'b': 2,
 'w': 3,
 'a': 4,
 'd': 5,
 'e': 6,
 'p': 7,
 'o': 8,
 'k': 9,
 'n': 10,
 'v': 11,
 's': 12,
 'T': 13,
 'l': 14,
 'm': 15,
 '.': 16,
 'g': 17,
 'c': 18,
 'h': 19,
 'u': 20,
 'f': 21,
 '-': 22,
 'A': 23,
 ' ': 24,
 'r': 25,
 't': 26,
 'i': 27}

In [67]:
# Instantiate the loss function
criterion = nn.CrossEntropyLoss()

# Instantiate the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    weight_decay: 0
)

In [68]:
inputs = torch.tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]],

        [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])


In [69]:
inputs.shape # batch_size is 218, sequence_length is 1, and number_of_features (characters) is 28

torch.Size([218, 1, 28])

In [70]:
targets = torch.tensor([19,  6, 24, 25,  4,  2,  2, 27, 26, 22, 19,  8, 14,  6, 24,  3,  6, 10,
        26, 24, 12, 26, 25,  4, 27, 17, 19, 26, 24,  8, 10, 24, 14, 27,  9,  6,
        24,  4, 24, 26, 20, 10, 10,  6, 14, 24, 21,  8, 25, 24, 12,  8, 15,  6,
        24,  3,  4,  0,  1, 24,  4, 10,  5, 24, 26, 19,  6, 10, 24,  5, 27,  7,
         7,  6,  5, 24, 12, 20,  5,  5,  6, 10, 14,  0, 24,  5,  8,  3, 10,  1,
        24, 12,  8, 24, 12, 20,  5,  5,  6, 10, 14,  0, 24, 26, 19,  4, 26, 24,
        23, 14, 27, 18,  6, 24, 19,  4,  5, 24, 10,  8, 26, 24,  4, 24, 15,  8,
        15,  6, 10, 26, 24, 26,  8, 24, 26, 19, 27, 10,  9, 24,  4,  2,  8, 20,
        26, 24, 12, 26,  8,  7,  7, 27, 10, 17, 24, 19,  6, 25, 12,  6, 14, 21,
        24,  2,  6, 21,  8, 25,  6, 24, 12, 19,  6, 24, 21,  8, 20, 10,  5, 24,
        19,  6, 25, 12,  6, 14, 21, 24, 21,  4, 14, 14, 27, 10, 17, 24,  5,  8,
         3, 10, 24,  4, 24, 11,  6, 25,  0, 24,  5,  6,  6,  7, 24,  3,  6, 14,
        14, 16])

targets.shape

torch.Size([218])

In [71]:
# Train the model
for epoch in range(100):
    model.train()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch+1) % 10 == 0:
        print(f'Epoch {epoch+1}/100, Loss: {loss.item()}')


Epoch 10/100, Loss: 2.976134777069092
Epoch 20/100, Loss: 2.677084445953369
Epoch 30/100, Loss: 2.4725914001464844
Epoch 40/100, Loss: 2.2726950645446777
Epoch 50/100, Loss: 2.1078569889068604
Epoch 60/100, Loss: 1.9847338199615479
Epoch 70/100, Loss: 1.9022361040115356
Epoch 80/100, Loss: 1.8479633331298828
Epoch 90/100, Loss: 1.8115605115890503
Epoch 100/100, Loss: 1.7871347665786743


In [72]:
print(outputs.shape) # outputs the logits of all possible characters

outputs

torch.Size([218, 28])


tensor([[-0.8395, -0.5993,  0.9741,  ...,  1.5718, -0.7799, -0.6828],
        [-1.5890, -3.1118, -0.4486,  ..., -1.5477,  2.6892,  2.7753],
        [-1.8929, -2.5851, -3.3455,  ...,  3.0360,  0.2332, -1.0040],
        ...,
        [-1.8929, -2.5851, -3.3455,  ...,  3.0360,  0.2332, -1.0040],
        [ 3.0976, -1.5643, -0.1192,  ..., -0.1872, -1.4737,  3.5292],
        [ 3.0976, -1.5643, -0.1192,  ..., -0.1872, -1.4737,  3.5292]],
       grad_fn=<AddmmBackward0>)

In [73]:
ix_to_char = {i:chars[i] for i in range(len(chars))}
ix_to_char

{0: 'y',
 1: ',',
 2: 'b',
 3: 'w',
 4: 'a',
 5: 'd',
 6: 'e',
 7: 'p',
 8: 'o',
 9: 'k',
 10: 'n',
 11: 'v',
 12: 's',
 13: 'T',
 14: 'l',
 15: 'm',
 16: '.',
 17: 'g',
 18: 'c',
 19: 'h',
 20: 'u',
 21: 'f',
 22: '-',
 23: 'A',
 24: ' ',
 25: 'r',
 26: 't',
 27: 'i'}

In [74]:
# Test the model
model.eval()
test_input = char_to_ix['r'] # 25
test_input = nn.functional.one_hot(torch.tensor(test_input).view(-1, 1), num_classes=len(chars)).float()
print(test_input.shape)
test_input

torch.Size([1, 1, 28])


tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]])

In [75]:
predicted_output = model(test_input)
predicted_char_ix = torch.argmax(predicted_output, 1).item()
print(f"Test Input: 'r', Predicted Output: '{ix_to_char[predicted_char_ix]}'")
predicted_output, predicted_char_ix

Test Input: 'r', Predicted Output: 'a'


(tensor([[ 2.4213e+00, -1.5242e+00, -4.7428e-02, -1.0359e+00,  3.3683e+00,
           2.7207e-01,  1.3397e+00, -2.2774e+00, -1.5252e-03, -3.5113e+00,
           3.7883e-01, -1.6907e+00,  3.0331e+00, -2.5844e+00, -5.7443e-01,
          -2.2052e+00, -1.1477e+00, -3.6155e+00, -2.9645e+00, -1.0199e+00,
          -2.8242e+00,  4.3883e-01, -1.2937e+00, -1.3535e+00,  2.6196e+00,
           1.5258e-01,  7.7540e-02,  6.4389e-01]], grad_fn=<AddmmBackward0>),
 4)

In [76]:
import torch
import torch.nn as nn

# Define the generator class
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(seq_length,seq_length), nn.Sigmoid()) # Linear Layer: Takes input of size seq_length and outputs the same size (seq_length serves as both the size of the input noise vector and the size of the output generated data.).Sigmoid Activation: Squashes the output values to be between 0 and 1, that is particularly useful when the generated data needs to be in a specific range, such as pixel values for images (which range from 0 to 1)
    def forward(self, x):
        return self.model(x)

# Define the discriminator networks
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(seq_length, 1), nn.Sigmoid()) # Linear Layer: Takes input of size seq_length and outputs a single value. Sigmoid Activation: Outputs a value between 0 and 1, representing the probability that the input is real.
    def forward(self, x):
        return self.model(x)

In [77]:
seq_length = 5 #: Length of each synthetic data sequence
num_sequences = 100 #: Total number of sequences generated
num_epochs = 50 #: Number of complete passes through the dataset
print_every = 10 #: Output display frequency, showing results every 10 epochs

In [78]:
generator = Generator()
generator

Generator(
  (model): Sequential(
    (0): Linear(in_features=5, out_features=5, bias=True)
    (1): Sigmoid()
  )
)

In [79]:
discriminator = Discriminator()
discriminator

Discriminator(
  (model): Sequential(
    (0): Linear(in_features=5, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

In [80]:
# Define the loss function and optimizer
criterion = nn.BCELoss()

# Since Generator and discriminator inherit from nn.Module, they also inherits the parameters() method.
optimizer_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_disc = torch.optim.Adam(discriminator.parameters(), lr=0.001)

In [81]:
data = torch.tensor([[0., 1., 0., 0., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 0., 1., 1., 0.],
                    [0., 1., 0., 0., 1.],
                    [0., 1., 1., 1., 0.],
                    [0., 0., 1., 0., 1.],
                    [0., 0., 0., 1., 1.],
                    [0., 1., 0., 1., 1.],
                    [1., 1., 1., 0., 0.],
                    [0., 1., 1., 1., 0.],
                    [1., 0., 0., 0., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 1., 1., 0., 1.],
                    [1., 0., 1., 1., 0.],
                    [0., 0., 1., 0., 0.],
                    [1., 0., 1., 0., 1.],
                    [0., 0., 0., 1., 0.],
                    [1., 1., 0., 1., 1.],
                    [1., 1., 1., 0., 0.],
                    [1., 1., 0., 0., 0.],
                    [0., 1., 0., 1., 1.],
                    [1., 1., 1., 0., 1.],
                    [1., 1., 0., 1., 0.],
                    [0., 1., 1., 1., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [1., 0., 0., 0., 0.],
                    [0., 1., 1., 0., 0.],
                    [1., 0., 0., 1., 0.],
                    [0., 1., 1., 1., 1.],
                    [1., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0.],
                    [1., 1., 0., 1., 0.],
                    [1., 0., 0., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [0., 0., 1., 0., 0.],
                    [0., 1., 0., 1., 0.],
                    [0., 1., 0., 0., 1.],
                    [1., 0., 1., 0., 1.],
                    [1., 0., 1., 1., 0.],
                    [1., 1., 1., 0., 0.],
                    [0., 0., 0., 0., 1.],
                    [1., 1., 1., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [1., 1., 0., 1., 1.],
                    [0., 1., 1., 0., 1.],
                    [0., 0., 1., 0., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 0., 1., 0., 1.],
                    [0., 0., 1., 0., 1.],
                    [1., 0., 1., 0., 1.],
                    [0., 0., 1., 1., 0.],
                    [1., 0., 0., 0., 1.],
                    [0., 1., 0., 1., 1.],
                    [1., 0., 1., 1., 0.],
                    [1., 1., 1., 0., 1.],
                    [0., 0., 0., 0., 1.],
                    [1., 1., 0., 1., 0.],
                    [1., 1., 1., 0., 0.],
                    [0., 1., 0., 1., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 1., 1., 1., 1.],
                    [1., 0., 0., 1., 1.],
                    [1., 0., 1., 0., 0.],
                    [0., 0., 0., 0., 1.],
                    [0., 1., 0., 1., 1.],
                    [1., 0., 1., 1., 0.],
                    [0., 0., 0., 1., 1.],
                    [0., 1., 1., 1., 1.],
                    [1., 0., 1., 1., 1.],
                    [0., 1., 0., 1., 0.],
                    [1., 0., 1., 0., 1.],
                    [1., 1., 0., 1., 0.],
                    [0., 0., 0., 1., 1.],
                    [0., 1., 1., 1., 0.],
                    [0., 0., 0., 0., 0.],
                    [0., 0., 1., 1., 1.],
                    [1., 1., 1., 0., 1.],
                    [1., 0., 0., 0., 0.],
                    [1., 1., 1., 1., 1.],
                    [0., 0., 1., 0., 1.],
                    [0., 1., 0., 1., 0.],
                    [1., 1., 1., 1., 1.],
                    [0., 0., 0., 0., 0.],
                    [1., 0., 0., 1., 0.],
                    [1., 0., 0., 1., 1.],
                    [0., 0., 0., 0., 0.],
                    [0., 1., 0., 1., 1.],
                    [0., 0., 0., 0., 0.],
                    [1., 1., 1., 1., 0.],
                    [1., 1., 0., 1., 0.],
                    [1., 1., 1., 0., 1.],
                    [1., 0., 0., 0., 1.],
                    [0., 1., 0., 1., 1.],
                    [0., 1., 0., 1., 0.],
                    [1., 0., 1., 0., 1.],
                    [1., 0., 1., 1., 0.],
                    [1., 0., 1., 1., 1.],
                    [0., 1., 1., 0., 1.]])

data.shape

torch.Size([100, 5])

In [82]:
test_data = torch.tensor([1., 0., 1., 1., 0.])
test_real = test_data.unsqueeze(0)
test_real

tensor([[1., 0., 1., 1., 0.]])

In [83]:
test_noise = torch.rand((1, seq_length))
print(test_noise)
test_noise = generator(test_noise)
test_noise.detach()

tensor([[0.1546, 0.5740, 0.6748, 0.7829, 0.6753]])


tensor([[0.3982, 0.5545, 0.3936, 0.2869, 0.4244]])

In [84]:
disc_test_real = discriminator(test_real)
disc_test_noise = discriminator(test_noise.detach())
disc_test_real, disc_test_noise # a output close to 1 represents a real image, while an output closer to 0 represents a fake image

(tensor([[0.3595]], grad_fn=<SigmoidBackward0>),
 tensor([[0.4690]], grad_fn=<SigmoidBackward0>))

In [85]:
torch.ones_like(disc_test_real), torch.zeros_like(disc_test_noise) # expected outputs (1:real, 0:fake)

(tensor([[1.]]), tensor([[0.]]))

In [86]:
# Start of the training process
for epoch in range(num_epochs):
    # For each epoch, the model goes through the entire dataset
    for real_data in data:
        # The real data is unsqueezed to add an extra dimension: unsqueeze(0) is used to create a batch with size 1 from the single sample of real_data. This ensures that the data can be correctly processed by the discriminator and generator, which are designed to handle batches of data.
        real_data = real_data.unsqueeze(0)
        # A random noise vector is generated
        noise = torch.rand((1, seq_length))
        # This noise vector is then passed through the generator to create the fake data
        fake_data = generator(noise)
        # The discriminator is then used to classify the real and fake data (outputs '1' if considers the input real, and outputs '0' if considers the input fake: ideally, disc_real should be always classified to 1, and disc_fake should be always classified to 0)
        disc_real = discriminator(real_data)
        disc_fake = discriminator(fake_data.detach()) # we want to use the output of the generator 'fake_data' to feed into the discriminator but don’t want these operations to influence the gradients of the generator (because here we're training the discriminator, and not the generator)
        # The loss for the discriminator is calculated as the sum of the losses for the real and fake data
        loss_disc = criterion(disc_real, torch.ones_like(disc_real)) + criterion(disc_fake, torch.zeros_like(disc_fake))
        # The gradients are then backpropagated through the discriminator
        optimizer_disc.zero_grad()
        loss_disc.backward()
        # The discriminator's weights are updated
        optimizer_disc.step()

        # The generator is trained here
        disc_fake = discriminator(fake_data) # In this case (when we train the generator), we need to calculate gradients for the generator because we want to update its parameters to minimize this loss. So, we don’t use detach() when passing the fake data through the discriminator.
        # The loss for the generator is calculated
        loss_gen = criterion(disc_fake, torch.ones_like(disc_fake)) # the generator loss is small if it's able to fool the discriminator (so if the outputs of the discriminator for the fake data is 1)
        # The gradients are backpropagated through the generator
        optimizer_gen.zero_grad()
        loss_gen.backward()
        # The generator's weights are updated
        optimizer_gen.step()

    # The losses for the generator and the discriminator are printed every 'print_every' epochs
    if (epoch+1) % print_every == 0:
        print(f"Epoch {epoch+1}/{num_epochs}:\t Generator loss: {loss_gen.item()}\t Discriminator loss: {loss_disc.item()}")


Epoch 10/50:	 Generator loss: 0.6854485273361206	 Discriminator loss: 1.424321174621582
Epoch 20/50:	 Generator loss: 0.6872369050979614	 Discriminator loss: 1.4275445938110352
Epoch 30/50:	 Generator loss: 0.6769758462905884	 Discriminator loss: 1.4055067300796509
Epoch 40/50:	 Generator loss: 0.6922941207885742	 Discriminator loss: 1.4285491704940796
Epoch 50/50:	 Generator loss: 0.6925503611564636	 Discriminator loss: 1.4118452072143555


In [87]:
# print real data
print("\nReal data: ")
print(data[:5])

# print generated fake data, by the trained generator
print("\nGenerated data: ")
for _ in range(5):
    noise = torch.rand((1, seq_length))
    generated_data = generator(noise)
    # The generated data is detached from its computation graph and rounded before printing
    print(noise)
    print(torch.round(generated_data).detach())



Real data: 
tensor([[0., 1., 0., 0., 1.],
        [0., 0., 1., 0., 1.],
        [0., 0., 1., 1., 0.],
        [0., 1., 0., 0., 1.],
        [0., 1., 1., 1., 0.]])

Generated data: 
tensor([[0.9740, 0.5491, 0.9054, 0.0567, 0.3241]])
tensor([[0., 0., 1., 1., 1.]])
tensor([[0.3442, 0.5166, 0.1226, 0.4327, 0.9537]])
tensor([[0., 0., 1., 1., 1.]])
tensor([[0.0865, 0.1589, 0.0019, 0.7451, 0.0364]])
tensor([[0., 0., 1., 1., 0.]])
tensor([[0.7152, 0.3166, 0.8678, 0.7232, 0.9533]])
tensor([[0., 0., 1., 1., 1.]])
tensor([[0.8748, 0.9489, 0.2699, 0.1841, 0.7344]])
tensor([[0., 0., 1., 1., 1.]])


In [88]:
from transformers import GPT2Tokenizer,GPT2LMHeadModel 

# Initialize the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Initialize the pre-trained model
model = GPT2LMHeadModel.from_pretrained('gpt2')

seed_text = "I am usually called"

# Encode the seed text to get input tensors
input_ids = tokenizer.encode(seed_text, return_tensors='pt')

# Generate text from the model
output = model.generate(input_ids, max_length=100, temperature=0.7, no_repeat_ngram_size=2, pad_token_id=tokenizer.eos_token_id) 

generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)



I am usually called a "good guy" by my friends and family. I am a good guy, and I'm not a bad guy.

I'm a very good person. And I don't want to be a jerk. But I do want people to know that I have a lot of respect for them. That I care about them, that they care, because I know they're going to love me. They're not going away.


In [89]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Initalize tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_prompt = "translate English to French: 'Hello, give me the pen'"

# Encode the input prompt using the tokenizer
input_ids = tokenizer.encode(input_prompt, return_tensors="pt")

# Generate the translated ouput
output = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated text:",generated_text)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Generated text: "Jo, donnez-moi le stylo"


In [90]:
from torchmetrics.text import BLEUScore, ROUGEScore

reference_text = "Once upon a time, there was a little girl who lived in a village near the forest."
generated_text = "Once upon a time, the world was a place of great beauty and great danger. The world of the gods was the place where the great gods were born, and where they were to live."

# Initialize BLEU and ROUGE scorers
bleu = BLEUScore()
rouge = ROUGEScore()

# Calculate the BLEU and ROUGE scores
bleu_score = bleu([generated_text], [[reference_text]])
rouge_score = rouge([generated_text], [[reference_text]])

# Print the BLEU and ROUGE scores
print("BLEU Score:", bleu_score.item())
print("ROUGE Score:", rouge_score)

BLEU Score: 0.08170417696237564
ROUGE Score: {'rouge1_fmeasure': tensor(0.2692), 'rouge1_precision': tensor(0.2000), 'rouge1_recall': tensor(0.4118), 'rouge2_fmeasure': tensor(0.1600), 'rouge2_precision': tensor(0.1176), 'rouge2_recall': tensor(0.2500), 'rougeL_fmeasure': tensor(0.2692), 'rougeL_precision': tensor(0.2000), 'rougeL_recall': tensor(0.4118), 'rougeLsum_fmeasure': tensor(0.2692), 'rougeLsum_precision': tensor(0.2000), 'rougeLsum_recall': tensor(0.4118)}


In [91]:
from transformers import BertTokenizer, BertForSequenceClassification

# Load the BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

model

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [92]:
import torch

# Setup the optimizer using model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001)
optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 1e-05
    maximize: False
    weight_decay: 0.01
)

In [93]:
texts = ['I love this!',
 'This is terrible.',
 'Amazing experience!',
 'Not my cup of tea.']

labels = [1, 0, 1, 0]

In [94]:
# Tokenize your data and return PyTorch tensors
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=32)
inputs["labels"] = torch.tensor(labels)

inputs # The tokenizer also adds these special tokens ([CLS] and [SEP]) to each sentence, which is why the maximum number of tokens is 8 (in the sencence 'Not my cup of tea.' --> [ 101, 2025, 2026, 2452, 1997, 5572, 1012,  102]). 

{'input_ids': tensor([[ 101, 1045, 2293, 2023,  999,  102,    0,    0],
        [ 101, 2023, 2003, 6659, 1012,  102,    0,    0],
        [ 101, 6429, 3325,  999,  102,    0,    0,    0],
        [ 101, 2025, 2026, 2452, 1997, 5572, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([1, 0, 1, 0])}

In [95]:
model.train()
for epoch in range(3):
    outputs = model(**inputs) # during the training, the model takes the inputs, including the labels
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}, Outputs: {outputs}")

Epoch: 1, Loss: 0.6837122440338135, Outputs: SequenceClassifierOutput(loss=tensor(0.6837, grad_fn=<NllLossBackward0>), logits=tensor([[-0.3516, -0.3675],
        [-0.4386, -0.2354],
        [-0.6778, -0.0728],
        [-0.4461, -0.2462]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
Epoch: 2, Loss: 0.571395993232727, Outputs: SequenceClassifierOutput(loss=tensor(0.5714, grad_fn=<NllLossBackward0>), logits=tensor([[-0.3580, -0.2193],
        [-0.3531, -0.3552],
        [-0.6342,  0.0609],
        [-0.2216, -0.5026]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
Epoch: 3, Loss: 0.6573758125305176, Outputs: SequenceClassifierOutput(loss=tensor(0.6574, grad_fn=<NllLossBackward0>), logits=tensor([[-0.2009, -0.1965],
        [-0.2347, -0.1050],
        [-0.2827,  0.0024],
        [-0.0879, -0.2450]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [96]:
text = "I had a bad day!"

# Tokenize the text and return PyTorch tensors
input_eval = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=32)
print('input_eval',input_eval)
outputs_eval = model(**input_eval)
print('outputs_eval',outputs_eval)


input_eval {'input_ids': tensor([[ 101, 1045, 2018, 1037, 2919, 2154,  999,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
outputs_eval SequenceClassifierOutput(loss=None, logits=tensor([[-0.3731, -0.0702]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


In [97]:
outputs_eval.logits

tensor([[-0.3731, -0.0702]], grad_fn=<AddmmBackward0>)

In [98]:
# Convert the output logits to probabilities (with sigmoid, not summing up to 1)
predictions = torch.nn.functional.sigmoid(outputs_eval.logits)

predictions

tensor([[0.4078, 0.4824]], grad_fn=<SigmoidBackward0>)

In [99]:
# Convert the output logits to probabilities (with softmax, summing up to 1)
predictions = torch.nn.functional.softmax(outputs_eval.logits, dim=-1)

predictions

tensor([[0.4248, 0.5752]], grad_fn=<SoftmaxBackward0>)

In [100]:
# Display the sentiments
predicted_label = 'positive' if torch.argmax(predictions) > 0 else 'negative'
print(f"Text: {text}\nSentiment: {predicted_label}")

Text: I had a bad day!
Sentiment: positive


In [101]:
import torch.nn as nn

class TransformerEncoder(nn.Module):
    
    def __init__(self, embed_size, heads, num_layers, dropout): # heads is the number of parallel attention layers (or “heads”) in the multi-head attention mechanism. Each head learns a different type of attention and then the model combines the results from all heads.
        super(TransformerEncoder, self).__init__()
        # Initialize the encoder 
        self.encoder = nn.TransformerEncoder( # nn.TransformerEncoder Stacks num_layers of the below TransformerEncoderLayer to form the complete Transformer encoder
            nn.TransformerEncoderLayer(d_model=embed_size, nhead=heads), # Creates a single layer of the Transformer encoder, which uses multi-head self-attention with embed_size dimensions and heads parallel attention layers.
            num_layers=num_layers)
        
        # Define the fully connected layer
        self.fc = nn.Linear(embed_size, 2)

    def forward(self, x):
        # Pass the input through the transformer encoder 
        x = self.encoder(x)
        x = x.mean(dim=1) # mean often calculated when you want to convert a sequence output (for example, a sequence of word embeddings) into a single vector that can be used for classification. The mean operation effectively captures the ‘average’ representation of all the words in the sequence.
        return self.fc(x)



In [102]:
model = TransformerEncoder(embed_size=512, heads=8, num_layers=3, dropout=0.3)

model



TransformerEncoder(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=512, out_features=2, bias=True)
)

In [103]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.05
    maximize: False
    weight_decay: 0
)

In [104]:
import random

# Base positive and negative phrases
positive_bases = [
    "I love", "This is amazing", "Fantastic", "Highly recommend", 
    "Best", "Absolutely wonderful", "Exceeded my expectations", "So happy with",
    "Incredible", "Very satisfied", "Couldn't be happier", "Top-notch", "Great value",
    "Awesome", "Just perfect", "Five stars", "This is great", "So delighted", "Worth every penny",
    "Pleasantly surprised"
]

negative_bases = [
    "This is terrible", "Worst", "Not worth", "Very disappointed", 
    "Absolutely awful", "I hate", "So dissatisfied", "Would not recommend", 
    "Terrible", "Regret buying", "Awful", "Such a waste", "One star",
    "Horrible", "Not what I expected", "Highly disappointing", "Unhappy with",
    "Doesn't work as expected", "Extremely unsatisfied", "Money wasted"
]

# Add some variations to these base phrases
suffixes = ["product", "experience", "purchase", "service", "quality"]

# Generate dataset with variations
train_sentences = []
train_labels = []

for _ in range(40):
    positive_sentence = random.choice(positive_bases) + " " + random.choice(suffixes)
    negative_sentence = random.choice(negative_bases) + " " + random.choice(suffixes)
    train_sentences.append(positive_sentence)
    train_labels.append(1)
    train_sentences.append(negative_sentence)
    train_labels.append(0)

# Shuffle the dataset
combined = list(zip(train_sentences, train_labels))
random.shuffle(combined)
train_sentences[:], train_labels[:] = zip(*combined)

# Print first 10 examples to check
for i in range(10):
    print(f"Sentence: {train_sentences[i]}, Label: {train_labels[i]}")


Sentence: Regret buying product, Label: 0
Sentence: So happy with quality, Label: 1
Sentence: Extremely unsatisfied quality, Label: 0
Sentence: This is great service, Label: 1
Sentence: This is amazing experience, Label: 1
Sentence: Money wasted service, Label: 0
Sentence: So dissatisfied service, Label: 0
Sentence: Five stars service, Label: 1
Sentence: This is amazing quality, Label: 1
Sentence: One star purchase, Label: 0


In [105]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = TransformerEncoder(embed_size=512, heads=8, num_layers=3, dropout=0.3).to(device)
model

cuda


TransformerEncoder(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=512, out_features=2, bias=True)
)

In [106]:
# The token_embeddings are used to convert each token (word) in the sentence into a high-dimensional vector representation that captures its semantic meaning. These embeddings are typically learned from large amounts of text data and are useful for many natural language processing tasks. By starting with these precomputed embeddings, the model can more easily learn to make accurate predictions.

def tokens_encoded(word, encod_model):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    word_tokenized = tokenizer(word, return_tensors='pt', truncation=True, padding='max_length', max_length=512)['input_ids'].float().to(device)
    return encod_model.encoder(word_tokenized)  # Get the embeddings

# tokens_encoded('example', model)

In [107]:
from IPython.display import clear_output

for epoch in range(1): # increase
    
    losses_0 = 0
    losses_1 = 0
    labels_0 = 0
    labels_1 = 0
    
    print('epoch:', epoch)
    for i, (sentence, label) in enumerate(zip(train_sentences, train_labels)):        
        tokens = sentence.split()

        # Get token embeddings and ensure proper shape
        data = torch.stack([tokens_encoded(token, model) for token in tokens]).view(1, len(tokens), 512).to(device)
        label_tensor = torch.tensor([label]).to(device)

        # Generate predictions and compute loss
        output = model(data)
        loss = criterion(output, label_tensor)
        
        if label == 0:
            losses_0 += loss
            labels_0 += 1
        else:
            losses_1 += loss
            labels_1 += 1

        # Backpropagation and optimizer steps
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, aver_loss_1: {losses_1/labels_1}, aver_loss_0: {losses_0/labels_0}")
    

epoch: 0


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 0, aver_loss_1: 0.39282071590423584, aver_loss_0: 1.0661193132400513


In [108]:
def predict(sentence):
    model.eval()
    # Deactivate the gradient computations and get the sentiment prediction.
    with torch.no_grad():
        tokens = sentence.split()
        print('tokens', tokens)
        data = torch.stack([tokens_encoded(token, model) for token in tokens]).view(1, len(tokens), 512)
        print('data.shape', data.shape)
        output = model(data)
        print('output', output)
        predicted = torch.argmax(output, dim=1)
        print('predicted', predicted)
        return "Positive" if predicted.item() == 1 else "Negative"

sample_sentence = "This is good"
print(f"'{sample_sentence}' is {predict(sample_sentence)}")

tokens ['This', 'is', 'good']
data.shape torch.Size([1, 3, 512])
output tensor([[-1.1916, -0.4394]], device='cuda:0')
predicted tensor([1], device='cuda:0')
'This is good' is Positive


In [109]:
train_input_data = [[2, 7, 8, 6, 2], [16, 14, 4, 12], [9, 14, 13, 15], [3, 14, 2, 10]]
train_target_data = [11, 5, 0, 1]

ix_to_word = {0: 'noisy',
             1: 'mammals',
             2: 'the',
             3: 'whales',
             4: 'very',
             5: 'animals',
             6: 'on',
             7: 'cat',
             8: 'sat',
             9: 'parrots',
             10: 'largest',
             11: 'mat',
             12: 'loyal',
             13: 'colorful',
             14: 'are',
             15: 'and',
             16: 'dogs'}

# The pad_sequences function is used to ensure that all sequences in a batch have the same length so they can be processed together. 

def pad_sequences(batch):
    max_len = max([len(seq) for seq in batch])
    return torch.stack([torch.cat([torch.tensor(seq), torch.zeros(max_len-len(seq)).long()]) for seq in batch])

# Pad sequences and move to device
train_input_data = pad_sequences(train_input_data).to(device)
train_target_data = torch.tensor(train_target_data).to(device)
train_input_data, train_target_data

(tensor([[ 2,  7,  8,  6,  2],
         [16, 14,  4, 12,  0],
         [ 9, 14, 13, 15,  0],
         [ 3, 14,  2, 10,  0]], device='cuda:0'),
 tensor([11,  5,  0,  1], device='cuda:0'))

In [113]:
vocab_size = len(ix_to_word)

embedding_dim = 10
hidden_dim = 16

vocab_size

17

In [114]:
class RNNWithAttentionModel(nn.Module):
    def __init__(self):
        super(RNNWithAttentionModel, self).__init__()
        # Create an embedding layer for the vocabulary
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        
        # Linear layer for attention mechanism
        # The purpose of the attention mechanism is to allow the model to focus on different parts of the input sequence when making predictions.
        # Here, the linear layer takes the hidden state from the RNN and transforms it to produce attention scores for each time step in the input sequence.
        # The dimension of the output is set to `hidden_dim` to ensure that the attention scores align with the hidden states' dimension.
        self.attention = nn.Linear(hidden_dim, hidden_dim)
        
        # Linear layer for classification
        # This layer maps the context vector (which is a weighted sum of hidden states) to the output vocabulary.
        # This is used to predict the next word in the sequence.
        self.fc = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, x):
        x = self.embeddings(x)  # Convert word indices to embeddings
        out, _ = self.rnn(x)  # Process embeddings with RNN to get hidden states for each time step
        
        # Attention mechanism: 
        # Step 1: Compute attention scores
        # The attention layer takes the hidden states from the RNN and produces attention scores.
        # These scores represent the importance of each time step in the input sequence for the current prediction.
        attn_scores = self.attention(out)
        
        # Step 2: Apply softmax to get attention weights
        # Softmax is applied to the attention scores to normalize them to a probability distribution.
        # This means the attention weights sum to 1 and can be interpreted as the model's confidence in each time step's importance.
        attn_weights = torch.nn.functional.softmax(attn_scores, dim=1)
        
        # Step 3: Compute context vector
        # The context vector is a weighted sum of the hidden states, where the weights are the attention scores.
        # This effectively summarizes the entire input sequence, focusing more on the important parts as determined by the attention mechanism.
        context = torch.sum(attn_weights * out, dim=1)
        
        # Step 4: Classification
        # The context vector is passed through a fully connected layer to produce the final output.
        # This output is used to predict the next word in the sequence.
        out = self.fc(context)
        
        return out



In [115]:
attention_model = RNNWithAttentionModel().to(device)
attention_model

RNNWithAttentionModel(
  (embeddings): Embedding(17, 10)
  (rnn): RNN(10, 16, batch_first=True)
  (attention): Linear(in_features=16, out_features=16, bias=True)
  (fc): Linear(in_features=16, out_features=17, bias=True)
)

In [116]:
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    weight_decay: 0
)

In [117]:
# TRAIN RNN ATT

epochs = 10 # change to 300

for epoch in range(epochs):
    attention_model.train()
    optimizer.zero_grad()
    
    print('train_input_data', train_input_data)
    # print('train_input_data.shape', train_input_data.shape)
    print('train_target_data', train_target_data)
    # print('train_target_data.shape', train_target_data.shape)
    
    train_outputs = attention_model(train_input_data)
    # print('train_outputs', train_outputs)
    # print('train_outputs.shape', train_outputs.shape)
    
    probabilities = F.softmax(train_outputs, dim=1)  # Apply softmax to the outputs to get probabilities (just for understanding, because the loss criterion uses the outputs logits direct, instead of the probabilities)
    # print('probabilities:', probabilities)  
    print('training prediction:', probabilities.argmax(dim=1))
    
    loss = criterion(train_outputs, train_target_data)
    print('loss', loss)
    loss.backward()
    optimizer.step()
    print('----------------------------------------------------------------------------------------------------------------------')


train_input_data tensor([[ 2,  7,  8,  6,  2],
        [16, 14,  4, 12,  0],
        [ 9, 14, 13, 15,  0],
        [ 3, 14,  2, 10,  0]], device='cuda:0')
train_target_data tensor([11,  5,  0,  1], device='cuda:0')
training prediction: tensor([ 2, 10, 16,  2], device='cuda:0')
loss tensor(2.7768, device='cuda:0', grad_fn=<NllLossBackward0>)
----------------------------------------------------------------------------------------------------------------------
train_input_data tensor([[ 2,  7,  8,  6,  2],
        [16, 14,  4, 12,  0],
        [ 9, 14, 13, 15,  0],
        [ 3, 14,  2, 10,  0]], device='cuda:0')
train_target_data tensor([11,  5,  0,  1], device='cuda:0')
training prediction: tensor([ 1,  5, 16,  1], device='cuda:0')
loss tensor(2.6042, device='cuda:0', grad_fn=<NllLossBackward0>)
----------------------------------------------------------------------------------------------------------------------
train_input_data tensor([[ 2,  7,  8,  6,  2],
        [16, 14,  4, 12,  0],

In [118]:
# Sample input and target data
test_input_data = [[2, 7, 8, 6, 2], [16, 14, 4, 12], [9, 14, 13, 15], [3, 14, 2, 10], [2, 10, 1, 14, 2]]
test_target_data = [11, 5, 0, 1, 3]

# ix_to_word = {0: 'noisy',
#              1: 'mammals',
#              2: 'the',
#              3: 'whales',
#              4: 'very',
#              5: 'animals',
#              6: 'on',
#              7: 'cat',
#              8: 'sat',
#              9: 'parrots',
#              10: 'largest',
#              11: 'mat',
#              12: 'loyal',
#              13: 'colorful',
#              14: 'are',
#              15: 'and',
#              16: 'dogs'}

In [119]:
# TEST RNN ATT

# Process each input sequence and target
for input_seq, target in zip(test_input_data, test_target_data):
    input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)
    print('input_test', input_test)
    print([ix_to_word[index.item()] for index in input_test[0]])

    # Set the attention model to evaluation mode
    attention_model.eval()

    # Get the attention output by passing the appropriate input
    attention_output = attention_model(input_test)

    # Get the target word
    target_output = ix_to_word[target]
    print('real output', target_output)

    # Get the predicted word
    attention_prediction = ix_to_word[torch.argmax(attention_output).item()]
    print('attention_prediction', attention_prediction)

input_test tensor([[2, 7, 8, 6, 2]], device='cuda:0')
['the', 'cat', 'sat', 'on', 'the']
real output mat
attention_prediction mat
input_test tensor([[16, 14,  4, 12]], device='cuda:0')
['dogs', 'are', 'very', 'loyal']
real output animals
attention_prediction animals
input_test tensor([[ 9, 14, 13, 15]], device='cuda:0')
['parrots', 'are', 'colorful', 'and']
real output noisy
attention_prediction noisy
input_test tensor([[ 3, 14,  2, 10]], device='cuda:0')
['whales', 'are', 'the', 'largest']
real output mammals
attention_prediction mammals
input_test tensor([[ 2, 10,  1, 14,  2]], device='cuda:0')
['the', 'largest', 'mammals', 'are', 'the']
real output whales
attention_prediction mammals


In [120]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class RNNModel_2(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(RNNModel_2, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.hidden_dim = hidden_dim
        
    def forward(self, x):
        x = self.embeddings(x)  # Convert input indices to embeddings
        h0 = torch.zeros(1, x.size(0), self.hidden_dim).to(x.device)  # Initial hidden state
        out, _ = self.rnn(x, h0)  # Pass the input and initial hidden state through the RNN
        out = out[:, -1, :]  # Select the last output of each sequence
        out = self.fc(out)  # Pass the last outputs through the fully connected layer
        return out


In [121]:
# Initialize model, optimizer, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rnn_model = RNNModel_2(vocab_size, embedding_dim, hidden_dim).to(device)
optimizer = optim.Adam(rnn_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [122]:
# TRAN RNN SIMP

# Training loop
for epoch in range(1000):
    rnn_model.train()
    optimizer.zero_grad()

    train_outputs = rnn_model(train_input_data)
    loss = criterion(train_outputs, train_target_data)
    
    # print('training prediction:', train_outputs.argmax(dim=1))
    # print('loss', loss.item())

    loss.backward()
    optimizer.step()
    
    if epoch%100 == 0:
        print('Epoch:', epoch)

Epoch: 0
Epoch: 100
Epoch: 200
Epoch: 300
Epoch: 400
Epoch: 500
Epoch: 600
Epoch: 700
Epoch: 800
Epoch: 900


In [123]:
# TEST RNN SIMP

# Process each input sequence and target
for input_seq, target in zip(test_input_data, test_target_data):
    input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)
    print('input_test', input_test)
    print([ix_to_word[index.item()] for index in input_test[0]])

    # Set the model to evaluation mode
    rnn_model.eval()

    # Get the model output
    rnn_model_output = rnn_model(input_test)

    # Convert target tensor to Python integer
    target_output = ix_to_word[target]
    print('real output', target_output)

    # Get the predicted word
    rnn_model_prediction = ix_to_word[rnn_model_output.argmax(dim=1).item()]
    print('predicted output', rnn_model_prediction)
    print('-------------------')

input_test tensor([[2, 7, 8, 6, 2]], device='cuda:0')
['the', 'cat', 'sat', 'on', 'the']
real output mat
predicted output mat
-------------------
input_test tensor([[16, 14,  4, 12]], device='cuda:0')
['dogs', 'are', 'very', 'loyal']
real output animals
predicted output sat
-------------------
input_test tensor([[ 9, 14, 13, 15]], device='cuda:0')
['parrots', 'are', 'colorful', 'and']
real output noisy
predicted output noisy
-------------------
input_test tensor([[ 3, 14,  2, 10]], device='cuda:0')
['whales', 'are', 'the', 'largest']
real output mammals
predicted output mammals
-------------------
input_test tensor([[ 2, 10,  1, 14,  2]], device='cuda:0')
['the', 'largest', 'mammals', 'are', 'the']
real output whales
predicted output mat
-------------------
