In [33]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import spacy
from utils import build_phrase_vocab
from preprocessing import preprocess_with_phrases

In [34]:
import pandas as pd
import csv

# Parallel English-Hindi corpus; the 'english' and 'hindi' columns are used below.
df = pd.read_csv('hindi_english_parallel.csv')

# Small random subset for a quick experiment. Seeded so that Restart & Run All
# reproduces the same sample (and therefore the same vocab and training data).
N_SAMPLES = 20
SAMPLE_SEED = 42
df_ = df.sample(N_SAMPLES, random_state=SAMPLE_SEED)

en_texts = df_['english'].to_list()
hi_texts = df_['hindi'].to_list()

In [37]:
sequence_len = 128  # fixed pad/truncate length for every encoded sequence
min_len = 5         # forwarded to preprocess_with_phrases; presumably a minimum sentence length - TODO confirm

# Keep only pairs where BOTH sides are strings (pandas yields float NaN for
# missing cells). Filtering each list on its own would shift one list relative
# to the other whenever a single side is missing, silently mis-pairing every
# subsequent English/Hindi sentence.
text_pairs = [
    (en, hi)
    for en, hi in zip(en_texts, hi_texts)
    if isinstance(en, str) and isinstance(hi, str)
]
en_texts_filtered = [en for en, _ in text_pairs]
hi_texts_filtered = [hi for _, hi in text_pairs]

en_proc, hi_proc, phrase_tags = preprocess_with_phrases(en_texts_filtered, hi_texts_filtered, min_len, sequence_len)
phrase2idx = build_phrase_vocab()

In [38]:
# Inspect the phrase-tag vocabulary returned by build_phrase_vocab().
phrase2idx

{'O': 0, 'NP': 1, 'VP': 2, 'PP': 3, 'ADJP': 4, 'ADVP': 5, 'CONJP': 6, 'QP': 7}

In [39]:
class CharPhraseDataset(Dataset):
    """Character-level parallel dataset with phrase tags.

    Each item is a triple of ``torch.long`` tensors ``(x, y, p)``, each of
    length ``sequence_len``:

    * ``x`` -- source string encoded character-by-character via ``ch2i``
    * ``y`` -- target string encoded character-by-character via ``ch2i``
    * ``p`` -- the sample's phrase tags encoded via ``phrase2idx``, padded to
      the same length as ``x`` (presumably one tag per source position -
      TODO confirm against preprocess_with_phrases)

    Sequences longer than ``sequence_len`` are truncated; shorter ones are
    right-padded with 0 (the ``'<pad>'`` index in this notebook's vocab).

    Args:
        x: Source strings.
        y: Target strings, aligned index-by-index with ``x``.
        phrases: Per-sample sequences of phrase-tag strings, aligned with ``x``.
        sequence_len: Fixed length of every returned tensor.
        ch2i: Character -> index mapping.
        phrase2idx: Phrase tag -> index mapping; unknown tags map to 0.
        unk_idx: Index used for characters missing from ``ch2i``. Defaults to
            0, which coincides with padding; pass a dedicated ``<unk>`` index
            to keep out-of-vocabulary characters distinguishable from padding.

    Raises:
        ValueError: if ``x``, ``y`` and ``phrases`` differ in length.
    """

    def __init__(self, x, y, phrases, sequence_len, ch2i, phrase2idx, unk_idx=0):
        # Fail fast: misaligned inputs would otherwise silently drop samples
        # (y/phrases longer than x) or raise IndexError mid-epoch (shorter).
        if not (len(x) == len(y) == len(phrases)):
            raise ValueError(
                "x, y and phrases must be aligned; got lengths "
                f"{len(x)}, {len(y)}, {len(phrases)}"
            )
        self.x, self.y, self.phrases = x, y, phrases
        self.sequence_len = sequence_len
        self.ch2i = ch2i
        self.phrase2idx = phrase2idx
        self.unk_idx = unk_idx

    def __len__(self):
        return len(self.x)

    def _fit(self, ids):
        """Truncate or right-pad ``ids`` with 0 to exactly ``sequence_len`` items."""
        return ids[:self.sequence_len] + [0] * (self.sequence_len - len(ids))

    def __getitem__(self, idx):
        x = [self.ch2i.get(c, self.unk_idx) for c in self.x[idx]]
        y = [self.ch2i.get(c, self.unk_idx) for c in self.y[idx]]
        p = [self.phrase2idx.get(tag, 0) for tag in self.phrases[idx]]
        return (
            torch.tensor(self._fit(x), dtype=torch.long),
            torch.tensor(self._fit(y), dtype=torch.long),
            torch.tensor(self._fit(p), dtype=torch.long),
        )

In [40]:
# Shared character vocabulary for source and target text. Index 0 is reserved
# for '<pad>'; the remaining characters follow in sorted order so the mapping
# is deterministic for a given corpus.
chars = set(''.join(en_proc + hi_proc))
ch2i = {'<pad>': 0}
ch2i.update({ch: idx for idx, ch in enumerate(sorted(chars), start=1)})

In [41]:
# Inspect the character -> index vocabulary (index 0 is '<pad>').
ch2i

{'<pad>': 0,
 ' ': 1,
 '(': 2,
 ')': 3,
 ',': 4,
 '.': 5,
 '0': 6,
 '1': 7,
 '2': 8,
 '3': 9,
 '5': 10,
 'A': 11,
 'C': 12,
 'D': 13,
 'H': 14,
 'J': 15,
 'M': 16,
 'N': 17,
 'P': 18,
 'R': 19,
 'T': 20,
 'W': 21,
 '_': 22,
 'a': 23,
 'b': 24,
 'c': 25,
 'd': 26,
 'e': 27,
 'f': 28,
 'g': 29,
 'h': 30,
 'i': 31,
 'k': 32,
 'l': 33,
 'm': 34,
 'n': 35,
 'o': 36,
 'p': 37,
 'r': 38,
 's': 39,
 't': 40,
 'u': 41,
 'v': 42,
 'w': 43,
 'y': 44,
 'ं': 45,
 'अ': 46,
 'आ': 47,
 'इ': 48,
 'ई': 49,
 'उ': 50,
 'ऊ': 51,
 'ए': 52,
 'ओ': 53,
 'औ': 54,
 'क': 55,
 'ग': 56,
 'च': 57,
 'छ': 58,
 'ज': 59,
 'ञ': 60,
 'ट': 61,
 'ठ': 62,
 'ड': 63,
 'ढ': 64,
 'ण': 65,
 'त': 66,
 'थ': 67,
 'द': 68,
 'ध': 69,
 'न': 70,
 'प': 71,
 'फ': 72,
 'ब': 73,
 'भ': 74,
 'म': 75,
 'य': 76,
 'र': 77,
 'ल': 78,
 'व': 79,
 'श': 80,
 'ष': 81,
 'स': 82,
 'ह': 83,
 '़': 84,
 'ा': 85,
 'ि': 86,
 'ी': 87,
 'ु': 88,
 'ू': 89,
 'ृ': 90,
 'े': 91,
 'ै': 92,
 'ो': 93,
 '्': 94,
 '।': 95}

In [42]:
# Wrap the preprocessed parallel text and phrase tags as a fixed-length dataset.
dataset = CharPhraseDataset(
    x=en_proc,
    y=hi_proc,
    phrases=phrase_tags,
    sequence_len=sequence_len,
    ch2i=ch2i,
    phrase2idx=phrase2idx,
)

In [43]:
from transformer import TransformerWithPhrase, TransformerConfig

# Architecture hyperparameters. vocab_size and sequence_len are taken from the
# character vocabulary and the dataset length setting defined above.
model_hparams = dict(
    vocab_size=len(ch2i),
    sequence_len=sequence_len,
    nblock=4,
    nhead=8,
    embed_dim=256,
    phrase_emb_dim=16,
)
mconfig = TransformerConfig(**model_hparams)
model = TransformerWithPhrase(mconfig, phrase_vocab_size=len(phrase2idx))

In [44]:
from trainer import Trainer, TrainerConfig

# Train on GPU when one is visible, otherwise fall back to CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer_config = TrainerConfig(
    max_epochs=100,
    batch_size=64,
    learning_rate=3e-4,
    device=device_name,
)
trainer = Trainer(model, dataset, trainer_config)
trainer.train()

Epoch 1, Loss: 4.4590
Epoch 2, Loss: 4.2227
Epoch 2, Loss: 4.2227
Epoch 3, Loss: 3.7234
Epoch 3, Loss: 3.7234
Epoch 4, Loss: 3.7387
Epoch 4, Loss: 3.7387
Epoch 5, Loss: 3.6666
Epoch 5, Loss: 3.6666
Epoch 6, Loss: 3.5823
Epoch 6, Loss: 3.5823
Epoch 7, Loss: 3.5240
Epoch 7, Loss: 3.5240
Epoch 8, Loss: 3.5213
Epoch 8, Loss: 3.5213
Epoch 9, Loss: 3.5201
Epoch 9, Loss: 3.5201
Epoch 10, Loss: 3.5057
Epoch 10, Loss: 3.5057
Epoch 11, Loss: 3.4820
Epoch 11, Loss: 3.4820
Epoch 12, Loss: 3.4682
Epoch 12, Loss: 3.4682
Epoch 13, Loss: 3.4597
Epoch 13, Loss: 3.4597
Epoch 14, Loss: 3.4608
Epoch 14, Loss: 3.4608
Epoch 15, Loss: 3.4675
Epoch 15, Loss: 3.4675
Epoch 16, Loss: 3.4528
Epoch 16, Loss: 3.4528
Epoch 17, Loss: 3.4421
Epoch 17, Loss: 3.4421
Epoch 18, Loss: 3.4324
Epoch 18, Loss: 3.4324
Epoch 19, Loss: 3.4363
Epoch 19, Loss: 3.4363
Epoch 20, Loss: 3.4288
Epoch 20, Loss: 3.4288
Epoch 21, Loss: 3.4304
Epoch 21, Loss: 3.4304
Epoch 22, Loss: 3.4256
Epoch 22, Loss: 3.4256
Epoch 23, Loss: 3.4232
Epoch

In [45]:
from preprocessing import extract_7_phrases

test_sents = ["python", "girlfriend"]
test_phrases = [extract_7_phrases(s) for s in test_sents]


def fit_to_len(ids, length=sequence_len):
    """Truncate or right-pad ``ids`` with 0 so it has exactly ``length`` items.

    Mirrors the pad/truncate step in ``CharPhraseDataset.__getitem__``. Padding
    alone (without truncation) would make any input longer than ``length``
    produce an over-long row, or a ragged batch that ``torch.tensor`` rejects.
    """
    return ids[:length] + [0] * (length - len(ids))


# NOTE(review): training text went through preprocess_with_phrases, but these
# raw strings do not; confirm whether the same normalisation should apply here.
# Characters missing from ch2i map to 0, the same index as padding.
test_x = [fit_to_len([ch2i.get(c, 0) for c in s]) for s in test_sents]
test_p = [fit_to_len([phrase2idx.get(tag, 0) for tag in tags]) for tags in test_phrases]

device = trainer_config.device
# Ensure the model is on the same device as the inputs and in eval mode.
model.to(device)
model.eval()

test_x = torch.tensor(test_x, dtype=torch.long, device=device)
test_p = torch.tensor(test_p, dtype=torch.long, device=device)

with torch.no_grad():
    translations = model.generate(test_x, test_p)

# Inverse vocabulary: index -> character.
i2ch = {idx: ch for ch, idx in ch2i.items()}


def decode(indices):
    """Turn a sequence of token ids back into text, skipping padding (id 0).

    Ids that are not in the vocabulary contribute an empty string.
    """
    pieces = []
    for idx in indices:
        token_id = int(idx)
        if token_id == 0:
            continue
        pieces.append(i2ch.get(token_id, ''))
    return ''.join(pieces)


for row in translations.cpu().numpy():
    print(decode(row))
print(translations)

   मा r मे अक काय उपाधाद पा पा
   मा r मे अक काय उपाधाद पा पा
tensor([[ 1,  1,  1, 75, 85,  1, 38,  1, 75, 91,  1, 46, 55,  1, 55, 85, 76,  1,
         50, 71, 85, 69, 85, 68,  1, 71, 85,  1, 71, 85],
        [ 1,  1,  1, 75, 85,  1, 38,  1, 75, 91,  1, 46, 55,  1, 55, 85, 76,  1,
         50, 71, 85, 69, 85, 68,  1, 71, 85,  1, 71, 85]])


In [46]:
# Inspect the inverse vocabulary used by decode() (index -> character).
i2ch

{0: '<pad>',
 1: ' ',
 2: '(',
 3: ')',
 4: ',',
 5: '.',
 6: '0',
 7: '1',
 8: '2',
 9: '3',
 10: '5',
 11: 'A',
 12: 'C',
 13: 'D',
 14: 'H',
 15: 'J',
 16: 'M',
 17: 'N',
 18: 'P',
 19: 'R',
 20: 'T',
 21: 'W',
 22: '_',
 23: 'a',
 24: 'b',
 25: 'c',
 26: 'd',
 27: 'e',
 28: 'f',
 29: 'g',
 30: 'h',
 31: 'i',
 32: 'k',
 33: 'l',
 34: 'm',
 35: 'n',
 36: 'o',
 37: 'p',
 38: 'r',
 39: 's',
 40: 't',
 41: 'u',
 42: 'v',
 43: 'w',
 44: 'y',
 45: 'ं',
 46: 'अ',
 47: 'आ',
 48: 'इ',
 49: 'ई',
 50: 'उ',
 51: 'ऊ',
 52: 'ए',
 53: 'ओ',
 54: 'औ',
 55: 'क',
 56: 'ग',
 57: 'च',
 58: 'छ',
 59: 'ज',
 60: 'ञ',
 61: 'ट',
 62: 'ठ',
 63: 'ड',
 64: 'ढ',
 65: 'ण',
 66: 'त',
 67: 'थ',
 68: 'द',
 69: 'ध',
 70: 'न',
 71: 'प',
 72: 'फ',
 73: 'ब',
 74: 'भ',
 75: 'म',
 76: 'य',
 77: 'र',
 78: 'ल',
 79: 'व',
 80: 'श',
 81: 'ष',
 82: 'स',
 83: 'ह',
 84: '़',
 85: 'ा',
 86: 'ि',
 87: 'ी',
 88: 'ु',
 89: 'ू',
 90: 'ृ',
 91: 'े',
 92: 'ै',
 93: 'ो',
 94: '्',
 95: '।'}