## Importing Library

In [27]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from torchtext.datasets import IMDB
import time
import random
import math
import warnings
def warn(*args,**kwargs):
    pass

warnings.warn = warn
warnings.filterwarnings('ignore')

## Dataset

In [2]:
## Loading dataset
train_iter,test_iter = IMDB()
train_iter,test_iter

(ShardingFilterIterDataPipe, ShardingFilterIterDataPipe)

In [3]:
# Quantitative and qualitative information about the training data
label,text_list = [],[]
test_label,test_text_list = [],[]
start_time = time.time()
for _,text in train_iter:
    label.append(_)
    text_list.append(text)

print(len(label))
print(list(set(label)))
print(f"Total sample:{len(text_list)}")


for _,text in test_iter:
    test_label.append(_)
    test_text_list.append(text)

end_time = time.time()
duration = end_time-start_time
print(len(test_label))
print(list(set(test_label)))
print(f"Total sample in testdataset:{len(test_text_list)}")
print(f"Time required: {duration:.2f} seconds")

12500
[1]
Total sample:12500
25000
[1, 2]
Total sample in testdataset:25000
Time required: 0.92 seconds


The training dataset has 12500 text block and the test dataset has 25000 text block

In [4]:
_,text = next(iter(train_iter))
_test,text_test = next(iter(test_iter))

print(f"First Train Text example: {text}\n")
print(f"First Test Text example: {text_test}")

First Train Text example: I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes ar

## Data Processing

In [5]:
## Create tokens
tokenizer = get_tokenizer('basic_english')

UNK_IDX, PAD_IDX, EOS_IDX = 0,1,2

special_symbols = ['<unk>','<pad>','<|endoftext|>']

vocab = build_vocab_from_iterator(map(tokenizer,[text for _,text in train_iter]), specials = special_symbols,special_first = True)


In [6]:
len(vocab)

68813

In [7]:
vocab['drink']

2435

In [8]:
## Sample accumulation
"""In decoder you have to define at how many blocks the model will look at once, This could be mentioned as context."""

def get_sample(context_size, text):
    """Context_size: The number of tokens the model will look at once
    text: the whole text_tokens, From where the model will retrieve the tokens and assign to the source and target"""
    sample_len = len(text)

    src,trg = [],[]
    if sample_len>context_size:
        start = torch.randint(low = 0,high = sample_len-context_size, size = (1,)).item()
        end = start+context_size

        src = text[start:end]
        trg = text[start+1:end+1]


    elif sample_len< context_size:
        start = torch.randint(0,sample_len,size = (1,)).item()
        end = start+context_size

        src = text[start:end]
        trg = text[start+1:end]

        trg.append('<|endoftext|>')

    return src, trg
        

In [9]:
BATCH_SIZE = 2
CONTEXT_SIZE = 20
for _ in range(BATCH_SIZE):
    _,text = next(iter(train_iter))

    src,trg = get_sample(context_size = CONTEXT_SIZE, text = tokenizer(text))

    print(f"Sample: {_}")
    print(f"Source: {src}")
    print(f"Target: {trg}")
    

Sample: 1
Source: ['and', 'ordinary', 'denizens', 'of', 'stockholm', 'about', 'their', 'opinions', 'on', 'politics', ',', 'she', 'has', 'sex', 'with', 'her', 'drama', 'teacher', ',', 'classmates']
Target: ['ordinary', 'denizens', 'of', 'stockholm', 'about', 'their', 'opinions', 'on', 'politics', ',', 'she', 'has', 'sex', 'with', 'her', 'drama', 'teacher', ',', 'classmates', ',']
Sample: 1
Source: ['made', 'porno', '.', 'while', 'my', 'countrymen', 'mind', 'find', 'it', 'shocking', ',', 'in', 'reality', 'sex', 'and', 'nudity', 'are', 'a', 'major', 'staple']
Target: ['porno', '.', 'while', 'my', 'countrymen', 'mind', 'find', 'it', 'shocking', ',', 'in', 'reality', 'sex', 'and', 'nudity', 'are', 'a', 'major', 'staple', 'in']


## INDEX TO ENGLISH & ENGLISH TO INDEX

In [10]:
idx_to_eng = lambda seq : " ".join([vocab.get_itos()[idx] for idx in seq])
eng_to_idx = lambda text: [vocab[token] for token in tokenizer(text)]

In [11]:
BATCH_SIZE  = 5
CONTEXT_SIZE = 20
src_batch,trg_batch = [],[]

for i in range(BATCH_SIZE):
    _,text = next(iter(train_iter)) # Take the first text sample of the training data
    src,trg = get_sample(context_size = CONTEXT_SIZE,text = tokenizer(text))
    src_vocab,trg_vocab = vocab(src),vocab(trg)
    src_tensors,trg_tensors = torch.tensor(src_vocab,dtype = torch.int64),torch.tensor(trg_vocab,dtype = torch.int64)

    src_batch.append(src_tensors),trg_batch.append(trg_tensors)

    print(f"sample: {i}")
    print(f"Source: {src_batch}")
    print(f"Target: {trg_batch}")

sample: 0
Source: [tensor([   10,    33,   693,    14,  7458,  2013,    14,   949,     3,    13,
          230, 24141,    11,     6,    61,    25,    20,   248,  1798,    10])]
Target: [tensor([   33,   693,    14,  7458,  2013,    14,   949,     3,    13,   230,
        24141,    11,     6,    61,    25,    20,   248,  1798,    10,  2307])]
sample: 1
Source: [tensor([   10,    33,   693,    14,  7458,  2013,    14,   949,     3,    13,
          230, 24141,    11,     6,    61,    25,    20,   248,  1798,    10]), tensor([   14,   259,  1743,  7457,     7,  2318, 29828,     9, 16111,    52,
           80,  4554,    28,  2407,     5,    68,    60,   338,    22,    57])]
Target: [tensor([   33,   693,    14,  7458,  2013,    14,   949,     3,    13,   230,
        24141,    11,     6,    61,    25,    20,   248,  1798,    10,  2307]), tensor([  259,  1743,  7457,     7,  2318, 29828,     9, 16111,    52,    80,
         4554,    28,  2407,     5,    68,    60,   338,    22,    57,   615

## Create Custom Collate Function

In [12]:
def collate_function(batch):
    """The collate_batch function prepares batches of source and target sequences for training by processing each text sample in a given batch. 
    It generates source and target sequences using the get_sample function with a specified block size, converts these sequences to indices using a vocabulary, and transforms them into PyTorch tensors. 
    The sequences are then padded to ensure uniform length across the batch. Finally, it returns the padded source and target batches, ready for training on the specified device (DEVICE)."""

    src_batch,trg_batch = [],[]
    for _,text in batch:

        token_text = tokenizer(text)

        src,trg = get_sample(CONTEXT_SIZE,token_text)

        src_indices, trg_indices = vocab(src),vocab(trg)

        src_seq,trg_seq = torch.tensor(src_indices,dtype = torch.int64),torch.tensor(trg_indices,dtype = torch.int64)

        
        src_batch.append(src_seq)
        trg_batch.append(trg_seq)



    src_batch = pad_sequence(src_batch,padding_value = PAD_IDX, batch_first = False)
    trg_batch = pad_sequence(trg_batch,padding_value = PAD_IDX, batch_first = False)
    return src_batch,trg_batch

## Create DataLoaders

In [13]:
train_dataloader = DataLoader(dataset = train_iter,
                             batch_size = 1,
                             shuffle = True,
                             collate_fn = collate_function)
test_dataloader = DataLoader(dataset = test_iter,
                            batch_size = 1,
                            shuffle = True,
                            collate_fn = collate_function)

## Iterating through Data sample

In [14]:
batch = iter(train_dataloader)

for sample in range(10):
    src,trg = next(batch)

    print(f"sample: {sample}")
    print(f"Source: {idx_to_eng(src)}")
    print(f"Target: {idx_to_eng(trg)}")

sample: 0
Source: his film-lenghts art works - but , unfortunately , he masters at a way too limited level the specifically cinematographic
Target: film-lenghts art works - but , unfortunately , he masters at a way too limited level the specifically cinematographic means
sample: 1
Source: , the fighting is rather poor . bone manages to take out well established tough-man street fighters in single punches
Target: the fighting is rather poor . bone manages to take out well established tough-man street fighters in single punches (
sample: 2
Source: what a shame ) , and other various creep-azoids to pretend to spoof way too may things has nothing going
Target: a shame ) , and other various creep-azoids to pretend to spoof way too may things has nothing going for
sample: 3
Source: many of the scenes are more dull and lifeless than staring at a wall for two hours . stroker ace
Target: of the scenes are more dull and lifeless than staring at a wall for two hours . stroker ace is
sample: 4
Sou

## Device Agnostic Code

In [15]:
DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'
DEVICE

'mps'

## Positional Embedding

In [16]:
class PositionalEmbedding(nn.Module):

    def __init__(self,
                emb_dim: int,
                dropout: float,
                max_len = 5000):

        super().__init__()

        den = torch.exp(-torch.arange(0,emb_dim,2)*math.log(10000)/emb_dim)
        pos  = torch.arange(0,max_len).reshape(max_len,1)

        pos_embedding = torch.zeros(size = (max_len,emb_dim))


        pos_embedding[:,0::2] = torch.sin(pos*den)
        pos_embedding[:,1::2] = torch.cos(pos*den)


        # Pos_embedding_shape = [seq_len,emb_dim]

        # add the batch_size to the sequence
        pos_embedding = pos_embedding.unsqueeze(dim = -2)

        self.dropout = nn.Dropout(dropout)

        # Positional embedding is a non-learnable parameter. It won't be updated with time
        
        self.register_buffer('pos_embedding',pos_embedding)


    def forward(self,token_embedding):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0)])

## Masking

In [17]:
def create_mask(src_seq_len):

    mask = (torch.triu(torch.ones(size = (src_seq_len,src_seq_len)))==1).transpose(0,1)
    mask = mask.float().masked_fill(mask==1,float(0.0)).masked_fill(mask ==0,float('-inf'))

    return mask

In [18]:
# Experiment
triu = torch.triu(torch.ones(size = (3,3))==1).transpose(0,1)
mask = triu.float().masked_fill(triu==1,float(0.0)).masked_fill(triu ==0,float('-inf'))
mask

tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])

In [19]:
def generate_mask(src):
    src_seq_len = src.shape[0]
    src_mask = create_mask(src_seq_len)
    src_padding_mask = (src==PAD_IDX).transpose(0,1)

    return src_mask, src_padding_mask

In [20]:
## Test
src_t = torch.rand(5,5)
m =generate_mask(src_t)
m

(tensor([[0., -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf],
         [0., 0., 0., 0., -inf],
         [0., 0., 0., 0., 0.]]),
 tensor([[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]]))

## Custom GPT Model Architecture

In [59]:
class CustomGPTModel(nn.Module):

    def __init__(self,
                emb_dim: int,
                vocab_size: int,
                num_head: int,
                num_layers: int,
                max_seq_len = 500,
                dropout = 0.1):
        super().__init__()

        self.init_weights()

        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.positional_encoding = PositionalEmbedding(emb_dim,dropout,max_len = max_seq_len)
        self.emb_dim  = emb_dim

        # Encoder layers
        encoder_layers = nn.TransformerEncoderLayer(d_model = emb_dim,nhead = num_head,dropout = dropout)

        self.transformer_encoder = nn.TransformerEncoder(encoder_layer = encoder_layers,num_layers = num_layers)

        self.lm_head = nn.Linear(emb_dim,vocab_size) # produce the final outputs, the final logits over vocabulary


    def init_weights(self):

        for p in self.parameters():
            if p.dim>1:
                nn.init_xavier_uniform_(p)


    def create_mask(self,source):
        src_seq_len = source.shape[0]

        src_mask = nn.Transformer.generate_square_subsequent_mask(src_seq_len)

        src_padding_mask = (src==PAD_IDX).transpose(0,1)

        return src_mask, src_padding_mask

    def decoder(self,x,src_mask):
        

        x = x.to(DEVICE)

        seq_len = x.shape[0]


        # Add positional encodding to the input embedding
        x = self.embedding(x)*math.sqrt(self.emb_dim)

        x = self.positional_encoding(x)

        if src_mask is None:
            src_mask, src_padding_mask = generate_mask(x)


        output = self.transformer_encoder(x,src_mask)

        logits = self.lm_head(x)

        return logits



    def forward(self,x,src_mask= None, key_padding_mask = None):

        x.to(DEVICE)

        seq_length = x.shape[0]

        # add positional embeddings to the embedding
        x = self.embedding(x)*math.sqrt(self.emb_dim)
        x = self.positional_encoding(x)

        if src_mask is None:
            src_mask, src_padding_mask = generate_mask(x)


        output = self.transformer_encoder(x,src_mask, key_padding_mask)

        x = self.lm_head(x)

        return x
        

        

In [60]:
emb_dim = 500
num_head = 2
num_layers = 2
vocab_size = len(vocab)

model = CustomGPTModel(emb_dim = emb_dim,
                       vocab_size = vocab_size,
                       num_head = num_head,
                       num_layers = num_layers).to(DEVICE)

## Prompting

#### PARAMETERS

In [64]:
BLOCK_SIZE = 20


In [41]:

def prompting(prompt=None,block_size=20):

    while prompt is None:
        prompt = input("Prompt can't be empty. Please enter a valid prompt")

    prompt_tokens = tokenizer(prompt)

    if len(prompt_tokens)>block_size:
        prompt_tokens = prompt_tokens[-block_size:]


    prompt_indices = vocab(prompt_tokens)
    prompt_tensors = torch.tensor(prompt_indices,dtype = torch.int64).reshape(-1,1).to(DEVICE)

    return prompt_tensors
    

In [43]:
prompt = prompting()
prompt

Prompt can't be empty. Please enter a valid prompt my name is


tensor([[ 72],
        [373],
        [ 11]], device='mps:0')

In [35]:
prompt_tensor = prompting('The sky is')
prompt_tensor

tensor([[   4],
        [2290],
        [  11]])

## Output Check

In [48]:
def output(prompt):
    prompt_logit = prompting(prompt) # prompt_shape : [seq_len,batch_size]
    logits = model(prompt_logit)

    print(f"output_shape:{logits.shape}")

    logits = logits.transpose(0,1)

    logit_prediction = logits[:,-1]
    print(f"logit prediction dimension:{logit_prediction.shape}")

    next_token_encoded = torch.argmax(logit_prediction,dim = -1).reshape(-1,1)

    return next_token_encoded


result = output("The sky is")
predicted_word = idx_to_eng(result)
predicted_word

output_shape:torch.Size([3, 1, 68813])
logit prediction dimension:torch.Size([1, 68813])


'cumming'

## Autoregressive Text Generation

In [58]:
# Declaring prompt 
prompt = "The sky is"
prompt_tokens = prompting(prompt)
## By using model output
max_new_tokens = 20 # how many words you may allow your model to generate

for i in range(max_new_tokens):
    

    logit = model(prompt_tokens)
    logit_reshape = logit.transpose(0,1)

    logit_prediction = logit_reshape[:,-1]

    next_token_encoded = torch.argmax(logit_prediction, dim = -1).reshape(-1,1)

    prompt_tokens = torch.cat((prompt_tokens,next_token_encoded),dim =0)

    print(f"input_prompt_shape: {prompt_tokens.shape}")
    print(f"logit shape: {logit.shape}")
    print(f"next_token_shape: {next_token_encoded.shape}")
    print(f"output:{" ".join([idx_to_eng(tokens) for tokens in prompt_tokens])}")

input_prompt_shape: torch.Size([4, 1])
logit shape: torch.Size([3, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is plastic
input_prompt_shape: torch.Size([5, 1])
logit shape: torch.Size([4, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is plastic trying
input_prompt_shape: torch.Size([6, 1])
logit shape: torch.Size([5, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is plastic trying women/girls
input_prompt_shape: torch.Size([7, 1])
logit shape: torch.Size([6, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is plastic trying women/girls miiko
input_prompt_shape: torch.Size([8, 1])
logit shape: torch.Size([7, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is plastic trying women/girls miiko loooooong
input_prompt_shape: torch.Size([9, 1])
logit shape: torch.Size([8, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is plastic trying women/girls miiko loooooong smart-mouthed
input_prompt_shape:

In [62]:
# By using model decoder
# Declaring prompt 
prompt = "The sky is"
prompt_tokens = prompting(prompt)
## By using model output
max_new_tokens = 20 # how many words you may allow your model to generate

for i in range(max_new_tokens):
    

    logit = model.decoder(prompt_tokens,src_mask = None)
    logit_reshape = logit.transpose(0,1)

    logit_prediction = logit_reshape[:,-1]

    next_token_encoded = torch.argmax(logit_prediction, dim = -1).reshape(-1,1)

    prompt_tokens = torch.cat((prompt_tokens,next_token_encoded),dim =0)

    print(f"input_prompt_shape: {prompt_tokens.shape}")
    print(f"logit shape: {logit.shape}")
    print(f"next_token_shape: {next_token_encoded.shape}")
    print(f"output:{" ".join([idx_to_eng(tokens) for tokens in prompt_tokens])}")

input_prompt_shape: torch.Size([4, 1])
logit shape: torch.Size([3, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is give
input_prompt_shape: torch.Size([5, 1])
logit shape: torch.Size([4, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is give anti-human
input_prompt_shape: torch.Size([6, 1])
logit shape: torch.Size([5, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is give anti-human drunk/stoned
input_prompt_shape: torch.Size([7, 1])
logit shape: torch.Size([6, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is give anti-human drunk/stoned reimagined
input_prompt_shape: torch.Size([8, 1])
logit shape: torch.Size([7, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is give anti-human drunk/stoned reimagined communicable
input_prompt_shape: torch.Size([9, 1])
logit shape: torch.Size([8, 1, 68813])
next_token_shape: torch.Size([1, 1])
output:the sky is give anti-human drunk/stoned reimagined communicable land--

# Generation function

In [102]:

def generate(model, prompt=None, max_new_tokens = 15, block_size = BLOCK_SIZE, vocab= vocab, tokenizer=tokenizer):

    model.to(DEVICE)

    prompt_encoded = prompting(prompt).to(DEVICE)
    

    for i in range(max_new_tokens):

        logits = model(prompt_encoded).transpose(0,1)

        logit_prediction = logits[:,-1]

        encoded_logit = torch.argmax(logit_prediction, dim = -1).reshape(-1,1)


        # if the next token is end of text, then stop the generation
        if encoded_logit == EOS_IDX:
            break

        prompt_encoded = torch.cat((prompt_encoded,encoded_logit),dim = 0)[-block_size:]


        print(f"prompt_encoded_shape: {prompt_encoded.shape}")

        

    tokens=[idx_to_eng(tokens) for  tokens in prompt_encoded]

    return " ".join(tokens)

In [103]:
prompt_output = generate(model,prompt = "My love")

prompt_encoded_shape: torch.Size([3, 1])
prompt_encoded_shape: torch.Size([4, 1])
prompt_encoded_shape: torch.Size([5, 1])
prompt_encoded_shape: torch.Size([6, 1])
prompt_encoded_shape: torch.Size([7, 1])
prompt_encoded_shape: torch.Size([8, 1])
prompt_encoded_shape: torch.Size([9, 1])
prompt_encoded_shape: torch.Size([10, 1])
prompt_encoded_shape: torch.Size([11, 1])
prompt_encoded_shape: torch.Size([12, 1])
prompt_encoded_shape: torch.Size([13, 1])
prompt_encoded_shape: torch.Size([14, 1])
prompt_encoded_shape: torch.Size([15, 1])
prompt_encoded_shape: torch.Size([16, 1])
prompt_encoded_shape: torch.Size([17, 1])


In [104]:
prompt_output

'my love squirmed fifthly tolkiens angeline landover malte kleenex resolute ethnocentrism rychard janis laverne immaturity uncovers henry-freakin'

In [98]:
main_lis =[]
tokens = torch.randint(0,20,size = (20,1))
main_lis.append([idx_to_eng(tok) for tok in tokens])
main_lis

[['of',
  'this',
  'was',
  'is',
  ',',
  'this',
  'i',
  '<|endoftext|>',
  'was',
  "'",
  'the',
  'of',
  'this',
  '<pad>',
  'of',
  'in',
  'is',
  'is',
  'that',
  'i']]

## Training & Testing

The main difference in training and inferencing lies in the inputs to the decoder.
during training, the decoder has the access 0f the ground truth (receiving the exact target sequence tokens incrementally through a technique known as `teacher forcing`)

In [115]:
src,trg = next(iter(train_dataloader))


In [116]:
src,src.shape,tgr.shape,tgr

(tensor([[  197],
         [    8],
         [   24],
         [ 2721],
         [   61],
         [    5],
         [  209],
         [    5],
         [   23],
         [   15],
         [ 1439],
         [35725],
         [ 3947],
         [   60],
         [    6],
         [66903],
         [ 5258],
         [   16],
         [ 5190],
         [  287]]),
 torch.Size([20, 1]),
 torch.Size([20, 1]),
 tensor([[533],
         [ 11],
         [191],
         [ 20],
         [  4],
         [367],
         [ 19],
         [  3],
         [  4],
         [274],
         [658],
         [  6],
         [149],
         [ 83],
         [  3],
         [ 15],
         [ 19],
         [ 91],
         [ 31],
         [658]]))

In [117]:
mask,padding_mask =  generate_mask(src)

In [118]:
mask,padding_mask

(tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0.

In [119]:
logit = model(src.to(DEVICE),src_mask = mask.to(DEVICE),key_padding_mask = padding_mask.to(DEVICE))
print(logit.shape)

torch.Size([20, 1, 68813])


In [120]:
print(f"output shape{logit.shape}")
print(f"source shape {src.shape}")

output shapetorch.Size([20, 1, 68813])
source shape torch.Size([20, 1])


In [121]:
print(f"Target shape: {trg.shape}")

Target shape: torch.Size([20, 1])


In [122]:
src,trg

(tensor([[  197],
         [    8],
         [   24],
         [ 2721],
         [   61],
         [    5],
         [  209],
         [    5],
         [   23],
         [   15],
         [ 1439],
         [35725],
         [ 3947],
         [   60],
         [    6],
         [66903],
         [ 5258],
         [   16],
         [ 5190],
         [  287]]),
 tensor([[    8],
         [   24],
         [ 2721],
         [   61],
         [    5],
         [  209],
         [    5],
         [   23],
         [   15],
         [ 1439],
         [35725],
         [ 3947],
         [   60],
         [    6],
         [66903],
         [ 5258],
         [   16],
         [ 5190],
         [  287],
         [   14]]))

In [124]:
print(logit.reshape(-1,logit.shape[-1]).shape)
print(tgr.reshape(-1).shape)

torch.Size([20, 68813])
torch.Size([20])


## Loss Function

In [128]:
from torch.nn import CrossEntropyLoss
loss_fn = CrossEntropyLoss(ignore_index=PAD_IDX).to(DEVICE)

In [130]:
loss = loss_fn(logit.reshape(-1,logit.shape[-1]).to(DEVICE),trg.reshape(-1).to(DEVICE))
print(loss.item())

61.2148323059082


In [None]:
def evaluate(logit,target):

    logit = logit.reshape(-1,logit.shape[-1])

In [None]:
b = torch.