In [1]:
import torch
import bpe_tokenizer as D
import string


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset

In [3]:
ds = load_dataset("cfilt/iitb-english-hindi")

english_characters = list(string.ascii_lowercase) + list(string.ascii_uppercase)

punctuation_list = list(string.punctuation)

char_to_keep = english_characters + punctuation_list + [' ']

def custom_filter(example):

    for word in example['translation']['en']:
        if word not in char_to_keep:
            return False
        

    for word in example['translation']['hi']:
        if not ((ord(u'\u0900') <= ord(word) <= ord(u'\u097F') ) or (word in list(string.punctuation)) or (word == ' ')):
            return False
        
    # removed sentences greater than 90th percentile     
    if len(example['translation']['en']) > 161:
        return False
    
    if len(example['translation']['hi']) > 115:
        return False

    return True


ds_filtered = ds.filter(custom_filter)

# corpus = ds_filtered['train']['translation']

In [4]:
max_tokens = 200

In [5]:
x_en = ds_filtered['train'][0]['translation']['en']
x_hi = ds_filtered['train'][0]['translation']['hi']

x_en = [x_en]
x_hi = [x_hi]

In [6]:
all_tokens = D.bpe_en_obj.base_vocab + ['<unk>', '<pad>']
word2idx = {}
for ind, ele in enumerate(all_tokens):
    word2idx[ele] = ind

In [7]:
def tokenize(x):

    res = D.bpe_en_obj.tokenize(x)
    while len(res) < max_tokens:
        res.append('<pad>')

    
    return torch.tensor([word2idx[ele] for ele in res])


enc_input = torch.stack([tokenize(x) for x in x_en], dim = 0)

# print([tokenize(x) for x in x_en])

In [8]:
all_tokens = D.bpe_hin_obj.base_vocab + ['<unk>', '<pad>', '<eos>', '<start>']
word2idx = {}
for ind, ele in enumerate(all_tokens):
    word2idx[ele] = ind

In [9]:

def tokenize(x):

    res = D.bpe_hin_obj.tokenize(x)
    key = 0
    
    while len(res) < max_tokens:

        if not key:
            res.insert(0, '<start>')
            res.append('<eos>')
            key = 1
            continue

        res.append('<pad>')
    
    return torch.tensor([word2idx[ele] for ele in res])

dec_input = torch.stack([tokenize(x) for x in x_hi], dim = 0)

In [35]:
def make_target_output(x):
    temp = []

    for ele in x:
        shifted_tensor = torch.roll(ele, shifts=-1, dims=-1)
        shifted_tensor[-1] = 201
        temp.append(shifted_tensor)

    return torch.stack(temp, dim = 0)
        

target = make_target_output(dec_input)

In [11]:
enc_input.shape

torch.Size([1, 200])

In [53]:
# enc_input
# dec_input
# target

from decoder import decoder_stack
from encoder import encoder_stack

loss = torch.nn.CrossEntropyLoss()

enc = encoder_stack(4, 4, 512)

enc_output = enc(enc_input)

dec = decoder_stack(4, 4, 512, enc_output)

output = dec(dec_input)

output = output.reshape(-1, 204)
target = target.reshape(-1)

loss(output, target)
# add the cross entropy loss function

# do the backward pass

# add batching logic

# visualize training loss and see if its converging !


tensor(5.3200, grad_fn=<NllLossBackward0>)

In [38]:
# for ele in dec(dec_input)[0]:
#     print(sum(ele))


print(output.shape)
print(target.shape)

torch.Size([1, 200, 204])
torch.Size([1, 200])


In [37]:
# output, 
print(output)


tensor(1.9571, grad_fn=<NllLossBackward0>)


In [29]:
torch.nn.functional.one_hot(target, num_classes= 204 ).shape

torch.Size([1, 200, 204])

In [None]:
target

In [30]:
loss = torch.nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)

In [41]:
print(torch.empty(3, dtype=torch.long).random_(5))

tensor([2, 4, 1])


In [44]:
torch.squeeze(output).shape

torch.Size([200, 204])

In [47]:
torch.squeeze(target)

tensor([  5,  42, 167, 160,   5,  40,  65, 182,  47,  75,  23, 160, 181, 160,
         42,  57,  65,   2,  26,  40,  64,  47, 172, 160,  53, 180,  62, 166,
         46, 160, 169, 160, 193,  45, 160,  38, 165, 202, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
        201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 2

In [48]:
loss(torch.squeeze(output), torch.squeeze(target))

tensor(5.3190, grad_fn=<NllLossBackward0>)

In [52]:
temp = torch.randn(size = (2,2))

print(temp)
print(temp.reshape(-1))

tensor([[-0.7054,  0.8946],
        [ 0.2116, -2.9205]])
tensor([-0.7054,  0.8946,  0.2116, -2.9205])
