In [2]:
import numpy as np
import torch
from torch import nn
from transformers import AlbertModel, AlbertTokenizer
from nltk.tokenize import word_tokenize
from datasets import load_dataset, DatasetDict
from tqdm import tqdm
from torch.utils.data import DataLoader

In [3]:
# Run on the GPU when one is available; later cells move the model and its outputs to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Data

In [4]:
# Location and dimensionality of the pretrained English GloVe vectors
input_dim = 50
glove_path = 'glove.6B/glove.6B.50d.txt'

# Read the three CSV files; each one is requested through split='train',
# so every call returns a single dataset rather than a dict of splits.
train_dataset, validation_dataset, test_dataset = (
    load_dataset('csv', data_files=csv_file, split='train')
    for csv_file in ('team16_ta_train.csv', 'team16_ta_valid.csv', 'team16_ta_test.csv')
)

# Bundle the splits so they can be addressed by name
dataset = DatasetDict({"train": train_dataset, "validation": validation_dataset, "test": test_dataset})

## Input Embeddings

In [5]:
# Load GloVe embeddings
def load_glove_embeddings(file_path):
    """Read a GloVe text file into a ``{word: vector}`` dictionary.

    Every non-empty line is expected to hold a word followed by its
    whitespace-separated float coefficients.

    Args:
        file_path: Path of the UTF-8 encoded GloVe ``.txt`` file.

    Returns:
        dict mapping each word (str) to a 1-D ``float32`` numpy array.
    """
    embeddings_index = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            # A blank or whitespace-only line has no word to index; skip it
            # instead of failing with an IndexError on values[0].
            if not values:
                continue
            word, *coefs = values
            embeddings_index[word] = np.asarray(coefs, dtype='float32')
    return embeddings_index

# Word -> float32 vector lookup parsed from the file at glove_path
glove_embeddings = load_glove_embeddings(glove_path)

In [6]:
# Dedicated, seeded generator: the random special-token vectors are then the same
# after every kernel restart, so runs are reproducible.
_special_token_rng = np.random.default_rng(0)

# Vectors for tokens that are not ordinary words. They are stored as float32 so
# they share a dtype with the vectors parsed by load_glove_embeddings.
special_tokens = {
    '<pad>': np.zeros(input_dim, dtype=np.float32),
    '<sos>': _special_token_rng.normal(size=(input_dim,)).astype(np.float32),
    '<eos>': _special_token_rng.normal(size=(input_dim,)).astype(np.float32),
    '<unk>': _special_token_rng.normal(size=(input_dim,)).astype(np.float32)
}

# Overwrites any entry the GloVe file already had under these keys
glove_embeddings.update(special_tokens)

In [7]:
def tokenize_en(text_list):
    """Tokenize English text and pad every sentence to the batch's longest one.

    Args:
        text_list: A single string or a list of strings.

    Returns:
        A list with one token list per input sentence. Shorter sentences are
        right-padded with '<pad>' so all token lists have equal length. An
        empty input list yields an empty list.
    """
    if isinstance(text_list, str):
        text_list = [text_list]
    batch_tokens = [word_tokenize(text) for text in text_list]
    # default=0 keeps an empty batch from raising in max()
    max_length = max((len(tokens) for tokens in batch_tokens), default=0)
    for tokens in batch_tokens:
        tokens.extend(['<pad>'] * (max_length - len(tokens)))
    return batch_tokens

def embedding_en(token_list):
    """Turn English tokens into a batch of GloVe embeddings.

    Args:
        token_list: Either one list of tokens or a list of equally long token
            lists (as produced by tokenize_en).

    Returns:
        float32 tensor of shape (batch, seq_len, embedding_dim); a single
        token list is treated as a batch of one.
    """
    unk_vector = glove_embeddings['<unk>']

    def lookup(token):
        # Exact match first (this also covers the special tokens). The glove.6B
        # vectors are uncased, so fall back to the lowercased token before
        # giving up and using '<unk>'.
        vector = glove_embeddings.get(token)
        if vector is None:
            vector = glove_embeddings.get(token.lower(), unk_vector)
        return vector

    def compute_embedding(tokens):
        # Stack into one ndarray first: building a tensor from a Python list
        # of separate arrays is much slower.
        vectors = np.asarray([lookup(token) for token in tokens])
        return torch.tensor(vectors, dtype=torch.float32)

    if not isinstance(token_list[0], list):
        token_list = [token_list]
    return torch.stack([compute_embedding(tokens) for tokens in token_list])

## Output Embeddings

In [8]:
# The tokenizer and the encoder are loaded from the same pretrained checkpoint
indic_bert_checkpoint = 'ai4bharat/indic-bert'
tokenizer = AlbertTokenizer.from_pretrained(indic_bert_checkpoint)
model = AlbertModel.from_pretrained(indic_bert_checkpoint)

def tokenize_ta(text):
    """Run the module-level tokenizer on `text` with padding enabled.

    Returns whatever the tokenizer produces when asked for PyTorch tensors
    (return_tensors='pt'); callers in this notebook read its 'input_ids' entry.
    """
    encoding = tokenizer(text, return_tensors='pt', padding=True)
    return encoding

def embedding_ta(tokens:dict|torch.Tensor, model=model):
    """Return the model's last hidden state for already-tokenized input.

    Args:
        tokens: Either a tensor, passed to the model as its single positional
            argument, or a mapping that is unpacked into keyword arguments.
        model: Model to run; defaults to the module-level `model`.

    Returns:
        The `last_hidden_state` attribute of the model output, computed with
        gradient tracking disabled.
    """
    with torch.no_grad():
        output = model(tokens) if isinstance(tokens, torch.Tensor) else model(**tokens)
    return output.last_hidden_state

  return self.fget.__get__(instance, owner)()


## Building Vocabulary for Target Language

In [81]:
def build_vocab_ta():
    """Collect the target-side token ids that occur more than once in training.

    Every sentence in dataset['train']['target'] is tokenized with tokenize_ta
    and its ids are counted. Id 0 starts with a count of 3, so it is always
    kept regardless of the data.

    Returns:
        1-D tensor of the kept token ids in ascending order.
    """
    counts = {0: 3}
    for sentence in tqdm(dataset['train']['target']):
        for token in tokenize_ta(sentence)['input_ids'][0]:
            token_id = token.item()
            counts[token_id] = counts.get(token_id, 0) + 1

    # Ids seen only once are dropped from the vocabulary
    frequent_ids = {token_id for token_id, count in counts.items() if count > 1}
    return torch.tensor(np.sort(list(frequent_ids)))

# Sorted 1-D tensor of the target token ids kept by build_vocab_ta()
vocab_ta = build_vocab_ta()

100%|██████████| 70000/70000 [00:38<00:00, 1825.75it/s]
8074it [00:00, 1345924.66it/s]


In [100]:

# Maps each kept tokenizer id to its position in vocab_ta; that position is the
# class index used for the one-hot training targets. It is built with a
# comprehension so no module-level loop variable is left shadowing the builtin `id`.
token_ta = {token_id: position for position, token_id in enumerate(vocab_ta.tolist())}

def token_to_id(target):
    """Translate tokenizer ids into vocabulary positions, keeping the tensor shape.

    Args:
        target: Tensor of tokenizer ids with any number of dimensions.

    Returns:
        Tensor of the same shape where each id is replaced by its entry in
        token_ta; ids missing from token_ta become 2.
    """
    # Recurse along the leading dimension until a scalar is reached
    if target.dim() > 0:
        return torch.stack([token_to_id(sub_target) for sub_target in target])
    return torch.tensor(token_ta.get(target.item(), 2))

0it [00:00, ?it/s]

8074it [00:00, 532398.61it/s]


In [11]:
def build_map(vocab):
    """Precompute one model embedding per vocabulary token.

    Each token id is fed to the module-level `model` on its own, as a (1, 1)
    input with token type 0 and attention mask 1, with gradients disabled.

    Args:
        vocab: Iterable of scalar tensors holding token ids.

    Returns:
        dict mapping each token id (int) to the model's `last_hidden_state`
        for that single-token input.
    """
    single_token_type = torch.tensor(0).reshape(1, 1)
    single_attention = torch.tensor(1).reshape(1, 1)

    vocab_map = {}
    with torch.no_grad():
        for token in tqdm(vocab):
            token_id = token.item()
            model_output = model(
                input_ids=torch.tensor(token_id).reshape(1, 1),
                token_type_ids=single_token_type,
                attention_mask=single_attention,
            )
            vocab_map[token_id] = model_output.last_hidden_state
    return vocab_map

# dict with each token as key and respective embedding as value
map_ta = build_map(vocab_ta)
# Replace the embedding stored for token id 0 with an all-zero (1, 1, 768) tensor
# (presumably because id 0 is the padding token — confirm against the tokenizer).
map_ta[0] = torch.zeros(1, 1, 768)

100%|██████████| 8073/8073 [03:31<00:00, 38.24it/s]


## Seq2Seq Model

In [31]:
class Encoder(nn.Module):
    """Single-layer LSTM that condenses a source sequence into its final state."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.rnn = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            dtype=torch.float32,
        )

    def forward(self, input_seq):
        """Encode source embeddings.

        Args:
            input_seq: Source sequence embeddings, batch-first:
                (N, seq_len, em_size).

        Returns:
            The LSTM's final (hidden, cell) state tuple; the per-step outputs
            are discarded.
        """
        return self.rnn(input_seq)[1]
    
class Decoder(nn.Module):
    """Single-layer LSTM followed by a linear projection onto the output classes."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.rnn = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            dtype=torch.float32,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim, dtype=torch.float32)

    def forward(self, y, prev_state):
        """Advance the decoder from a previous state.

        Args:
            y: Batch-first input embeddings; Seq2SeqModel passes one step at a
                time as (N, 1, em_size).
            prev_state: (hidden, cell) tuple to start the LSTM from.

        Returns:
            (prediction, curr_state): the projected LSTM outputs and the
            updated (hidden, cell) state.
        """
        rnn_output, curr_state = self.rnn(y, prev_state)
        return self.fc(rnn_output), curr_state

In [108]:
class Seq2SeqModel(nn.Module):
    """Encoder-decoder translation model decoding one target token per step.

    Decoder inputs are looked up in the module-level `map_ta` (token id ->
    (1, 1, tgt_em_dim) embedding). The embedding stored under id 2 serves both
    as the start-of-sequence input and as the fallback for ids missing from
    `map_ta` (presumably the tokenizer's [CLS] id — confirm).

    Args:
        src_em_dim: Size of the source embeddings fed to the encoder.
        hidden_dim: Hidden size shared by the encoder and decoder LSTMs.
        tgt_em_dim: Size of the target embeddings fed to the decoder.
        tgt_dim: Number of output classes (target vocabulary size).
    """

    def __init__(self, src_em_dim, hidden_dim, tgt_em_dim, tgt_dim):
        super().__init__()
        self.src_em_dim = src_em_dim
        self.hidden_dim = hidden_dim
        self.tgt_em_dim = tgt_em_dim
        self.tgt_dim = tgt_dim
        self.encoder = Encoder(src_em_dim, hidden_dim)
        self.decoder = Decoder(tgt_em_dim, hidden_dim, tgt_dim)

    def forward(self, source, target=None, max_len=1000):
        """Translate a batch of source embeddings.

        Args:
            source: Source token embeddings, (N, seq_len, src_em_dim).
            target: Optional target token ids, (N, tgt_len). When given, it
                fixes the number of decoding steps to tgt_len - 1 and, in
                training mode, supplies the next decoder input at every step.
            max_len: Length bound used when `target` is None; decoding then
                runs max_len - 1 steps.

        Returns:
            Tensor of shape (N, steps, tgt_dim) with the decoder output of
            every step, on the same device as the model's parameters.
        """
        # Use the device the model actually lives on instead of a notebook
        # global, so the module works wherever it has been moved.
        run_device = next(self.parameters()).device
        source = source.to(run_device)

        batch_size = source.shape[0]
        target_len = max_len if target is None else target.shape[1]
        # Teacher forcing needs a target; without one, feed back predictions.
        teacher_forcing = self.training and target is not None

        prev_state = self.encoder(source)

        # (N, 1, em_size) start-of-sequence input; map_ta tensors are moved to
        # the model's device before being fed to the decoder.
        fallback_embedding = map_ta[2]
        decoder_input = torch.tile(fallback_embedding.to(run_device), (batch_size, 1, 1))

        outputs = []
        for t in range(1, target_len):
            decoder_output, prev_state = self.decoder(decoder_input, prev_state)
            outputs.append(decoder_output)

            if teacher_forcing:
                next_ids = target[:, t].tolist()
            else:
                # Greedy choice: class index -> tokenizer id via vocab_ta
                best_classes = decoder_output.argmax(dim=-1).reshape(-1).cpu()
                next_ids = vocab_ta[best_classes].tolist()

            decoder_input = torch.concat(
                [map_ta.get(token_id, fallback_embedding) for token_id in next_ids]
            ).to(run_device)

        if not outputs:
            # No decoding step was requested (target_len <= 1)
            return torch.empty((batch_size, 0, self.tgt_dim), device=run_device)
        return torch.concat(outputs, dim=1)

## Training

In [120]:
# --- Optimisation hyperparameters ---
batch_size = 64
learning_rate = 0.001
num_epochs = 1

# --- Model dimensions ---
en_vocab_size = 0  # not needed: no English vocabulary is constructed
en_embedding_size = input_dim
ta_embedding_size = 768
ta_vocab_size = len(vocab_ta)
hidden_dim = 2 * en_embedding_size

# One .to() call on the parent module also moves the encoder and decoder submodules
machine = Seq2SeqModel(
    src_em_dim=en_embedding_size,
    hidden_dim=hidden_dim,
    tgt_em_dim=ta_embedding_size,
    tgt_dim=ta_vocab_size,
).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(machine.parameters(), lr=learning_rate)

# Expose only the 'source' and 'target' columns of each split, in torch format
train_dataset.set_format(type='torch', columns=['source', 'target'])
validation_dataset.set_format(type='torch', columns=['source', 'target'])
test_dataset.set_format(type='torch', columns=['source', 'target'])

def collate_fn(example_list: list):
    """Assemble one training batch from raw dataset examples.

    Args:
        example_list: Examples, each providing 'source' and 'target' entries.

    Returns:
        (source_tensor, target_tensor): the sources run through tokenize_en
        and embedding_en, and the 'input_ids' that tokenize_ta yields for the
        targets.
    """
    sources, targets = [], []
    for example in example_list:
        sources.append(example['source'])
        targets.append(example['target'])

    english_embeddings = embedding_en(tokenize_en(sources))
    tamil_token_ids = tokenize_ta(targets)['input_ids']
    return english_embeddings, tamil_token_ids


In [121]:
# Make sure teacher forcing is active even if this cell is re-run after the
# evaluation cell has switched the model to eval mode.
machine.train()

# Device the model's parameters live on; batches are moved there before use.
train_device = next(machine.parameters()).device

# The loader is loop-invariant; with shuffle=True it still reshuffles every epoch.
train_iterator = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

for epoch in range(num_epochs):
    for source, target in tqdm(train_iterator):
        source = source.to(train_device)

        optimizer.zero_grad()

        outputs = machine(source, target)

        # The first target token is dropped: the model emits one prediction per
        # remaining position. Targets are built on the same device as the outputs.
        target_ids = token_to_id(target[:, 1:]).to(outputs.device)
        target_one_hot = nn.functional.one_hot(target_ids, num_classes=machine.tgt_dim).float()
        # NOTE(review): this compares raw decoder outputs with one-hot vectors through
        # `criterion` (MSE as configured); a cross-entropy loss may suit this better.
        loss = criterion(outputs, target_one_hot)

        loss.backward()

        optimizer.step()

    print(f'Epoch [{epoch+1}] ----------------------------------------------')

100%|██████████| 1094/1094 [34:03<00:00,  1.87s/it]

Epoch [1] ----------------------------------------------





## Testing

In [139]:
def decode_sentences(outputs):
    """Convert decoder scores into decoded strings.

    Args:
        outputs: Scores shaped (seq_len, vocab_size) for a single sentence or
            (N, seq_len, vocab_size) for a batch.

    Returns:
        List with one decoded string per sentence. For each step the highest
        scoring class is mapped to a token id through vocab_ta; decoding of a
        sentence stops at the first token id 0, which is not included.
    """
    # Promote a single sentence to a batch of one
    if outputs.dim() == 2:
        outputs = outputs.unsqueeze(dim=0)

    decoded = []
    for sentence_scores in outputs:
        kept_tokens = []
        for step_scores in sentence_scores:
            token = vocab_ta[torch.argmax(step_scores)].item()
            if token == 0:
                break
            kept_tokens.append(token)
        decoded.append(tokenizer.decode(kept_tokens))
    return decoded

In [143]:
idx = 5

en_sentence = embedding_en(tokenize_en(dataset['train']['source'][idx]))
ta_sentence = dataset['train']['target'][idx]

# eval mode makes the model feed its own predictions back into the decoder
machine.eval()
# Inference only: without no_grad() autograd would record every decoding step
# (up to 999 when no target is given), wasting memory.
with torch.no_grad():
    outputs = machine(en_sentence.to(next(machine.parameters()).device))

# Reference sentence next to the decoded prediction; decoding indexes the
# CPU-resident vocab_ta, so the scores are brought back to the CPU first.
ta_sentence, decode_sentences(outputs.cpu())

('எனக்கு ஒன்றும் தோன்ற வில்லை.', ['கக'])