## Loading data

In [1]:
# Read the training pairs: one comma-separated line per word, first field is
# the input word (lower-cased here), second field is the target word.
# The encoding is given explicitly because the targets are Tamil script and
# the platform default encoding is not guaranteed to be UTF-8.
with open('data/aksharantar_sampled/tam/tam_train.csv', encoding='utf-8') as f:
    # Drop blank lines so that a stray empty line cannot raise IndexError
    # when the second field is taken below.
    data_pairs = [line for line in f if line.strip()]
data_given = [pair.split(',')[0].strip().lower() for pair in data_pairs]
data_target = [pair.split(',')[1].strip() for pair in data_pairs]
len(data_given), len(data_target)

(51200, 51200)

## Building the alphabet

In [2]:
class Alphabet:
    """Bidirectional letter/index table with three reserved leading slots.

    Indices 0, 1 and 2 of ``index_to_letter`` hold the markers ``'SOW'``,
    ``'EOW'`` and ``'UNK'``. The markers are not entered in
    ``letter_to_index``; real letters are numbered from 3 upwards in the
    order they are first added.
    """

    def __init__(self) -> None:
        self.letter_to_index = {}
        self.index_to_letter = ['SOW', 'EOW', 'UNK']
        # Starts at the number of reserved markers above.
        self.letter_count = 3

    def addLetter(self, letter: str) -> None:
        """Register ``letter`` under the next free index; no-op if known."""
        if letter in self.letter_to_index:
            return
        new_index = self.letter_count
        self.index_to_letter.append(letter)
        self.letter_to_index[letter] = new_index
        self.letter_count = new_index + 1

In [3]:
# One alphabet per side of the data: input words and target words.
eng_alphabet = Alphabet()
tam_alphabet = Alphabet()
# Walk every character of every word so each distinct letter gets an index,
# input side first and then the target side.
for alphabet, words in ((eng_alphabet, data_given), (tam_alphabet, data_target)):
    for word in words:
        for letter in word:
            alphabet.addLetter(letter)
print(eng_alphabet.letter_count, tam_alphabet.letter_count)

29 49


## Seq2Seq model

In [14]:
import torch
from torch import nn
from torch.functional import F

class Encoder(nn.Module):
    """Recurrent encoder that consumes one symbol index per call.

    Args:
        input_size: Number of distinct input symbols (embedding rows).
        hidden_size: Width of both the embedding and the recurrent state.
        cell_type: Recurrent layer class to instantiate, e.g. ``nn.RNN``,
            ``nn.GRU`` or ``nn.LSTM``.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        cell_type: nn.Module = nn.RNN,
        num_layers: int = 1
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # Remembered so initHidden() can build a state of the matching depth.
        self.num_layers = num_layers
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=hidden_size)
        self.encoder = cell_type(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers)

    def forward(self, x, hidden):
        """Encode a single symbol.

        Args:
            x: Tensor holding one symbol index.
            hidden: Recurrent state from the previous step (or initHidden()).

        Returns:
            ``(output, hidden)`` exactly as returned by the recurrent layer.
        """
        # The recurrent layer is fed a (seq_len=1, batch=1, hidden_size) input.
        embedded = self.embedding(x).reshape(1, 1, -1)
        output, hidden = self.encoder(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """Return an all-zero initial state for a batch of one.

        The state has shape ``(num_layers, 1, hidden_size)`` and lives on the
        same device as the module's parameters. For an ``nn.LSTM`` cell the
        result is an ``(h_0, c_0)`` tuple, which is what that layer expects.
        """
        state = torch.zeros(
            self.num_layers, 1, self.hidden_size,
            device=self.embedding.weight.device,
        )
        if isinstance(self.encoder, nn.LSTM):
            return state, torch.zeros_like(state)
        return state

In [15]:
class Decoder(nn.Module):
    """Recurrent decoder that emits log-probabilities for one output symbol per call.

    Args:
        hidden_size: Width of both the embedding and the recurrent state.
        output_size: Number of distinct output symbols.
        cell_type: Recurrent layer class to instantiate, e.g. ``nn.RNN``,
            ``nn.GRU`` or ``nn.LSTM``.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(
        self,
        hidden_size: int,
        output_size: int,
        cell_type: nn.Module = nn.RNN,
        num_layers: int = 1
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # Remembered so initHidden() can build a state of the matching depth.
        self.num_layers = num_layers
        self.embedding = nn.Embedding(num_embeddings=output_size, embedding_dim=hidden_size)
        self.decoder = cell_type(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers)
        self.out = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x, hidden):
        """Decode a single step.

        Args:
            x: Tensor holding one symbol index (the previous output symbol).
            hidden: Recurrent state from the previous step.

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` has shape
            ``(1, output_size)`` and ``hidden`` is the updated state.
        """
        # The recurrent layer is fed a (seq_len=1, batch=1, hidden_size) input.
        embedded = self.embedding(x).reshape(1, 1, -1)
        embedded = F.relu(embedded)
        output, hidden = self.decoder(embedded, hidden)
        # output[0] drops the length-1 sequence axis before the projection.
        log_probs = self.softmax(self.out(output[0]))
        return log_probs, hidden

    def initHidden(self):
        """Return an all-zero initial state for a batch of one.

        The state has shape ``(num_layers, 1, hidden_size)`` and lives on the
        same device as the module's parameters. For an ``nn.LSTM`` cell the
        result is an ``(h_0, c_0)`` tuple, which is what that layer expects.
        """
        state = torch.zeros(
            self.num_layers, 1, self.hidden_size,
            device=self.embedding.weight.device,
        )
        if isinstance(self.decoder, nn.LSTM):
            return state, torch.zeros_like(state)
        return state


## Training

In [16]:
# Device the encoder and decoder modules are moved to.
device = "cpu"

In [17]:
def word_to_tensor(alphabet: "Alphabet", word: str) -> torch.Tensor:
    """Encode ``word`` as a column of letter indices terminated by end-of-word.

    Args:
        alphabet: Object whose ``letter_to_index`` maps letters to indices.
        word: The word to encode, one index per character.

    Returns:
        A ``torch.long`` tensor of shape ``(len(word) + 1, 1)``; the final
        entry is the end-of-word index.
    """
    # Positions of the 'EOW' and 'UNK' markers in Alphabet.index_to_letter.
    eow_index, unk_index = 1, 2
    # Letters missing from the alphabet map to 'UNK' instead of raising
    # KeyError, so words with unseen characters can still be encoded.
    indices = [alphabet.letter_to_index.get(letter, unk_index) for letter in word]
    indices.append(eow_index)
    return torch.tensor(indices, dtype=torch.long).reshape(-1, 1)

In [24]:
# Two-layer GRU encoder and decoder with a 64-wide hidden state, sized from
# the input and target alphabets respectively.
encoder = Encoder(
    eng_alphabet.letter_count,
    64,
    cell_type=nn.GRU,
    num_layers=2,
).to(device)
decoder = Decoder(
    64,
    tam_alphabet.letter_count,
    cell_type=nn.GRU,
    num_layers=2,
).to(device)
# Each network gets its own plain SGD optimizer with the same learning rate.
optimizer_enc = torch.optim.SGD(encoder.parameters(), lr=1e-3)
optimizer_dec = torch.optim.SGD(decoder.parameters(), lr=1e-3)
# Negative log-likelihood over the decoder's log-softmax output.
loss_fn = nn.NLLLoss()

In [25]:
# Probability that a word is decoded with teacher forcing, i.e. the decoder
# is fed the true previous target symbol rather than its own prediction.
teacher_forcing_ratio = 0.5

# One optimisation step per (input word, target word) pair.
for train_index, (given_word, target_word) in enumerate(zip(data_given, data_target)):
    hidden_enc = encoder.initHidden()

    optimizer_enc.zero_grad()
    optimizer_dec.zero_grad()

    # Feed the input word through the encoder one symbol at a time; only the
    # final hidden state is carried over to the decoder.
    input_vector = word_to_tensor(eng_alphabet, given_word)
    for char in input_vector:
        output_enc, hidden_enc = encoder(char, hidden_enc)

    # Index 0 is the 'SOW' marker: decoding starts from it.
    input_dec = torch.tensor([[0]])
    target_vector = word_to_tensor(tam_alphabet, target_word)
    target_len = len(target_vector)

    hidden_dec = hidden_enc

    use_teacher_forcing = torch.rand(1).item() < teacher_forcing_ratio

    loss = 0.0
    # Number of decoder steps that contributed to the loss. This can be less
    # than target_len when free-running decoding stops early.
    steps_decoded = 0

    for di in range(target_len):
        output_dec, hidden_dec = decoder(input_dec, hidden_dec)
        loss += loss_fn(output_dec, target_vector[di])
        steps_decoded += 1
        if use_teacher_forcing:
            input_dec = target_vector[di]
        else:
            # Feed back the model's own best guess, detached so gradients do
            # not flow through the sampled index.
            input_dec = output_dec.argmax(dim=1).squeeze().detach()
            # Index 1 is 'EOW': the model decided the word is finished.
            if input_dec.item() == 1:
                break

    loss.backward()

    optimizer_enc.step()
    optimizer_dec.step()

    # Report the mean per-step loss over the steps actually decoded, so an
    # early stop does not make the figure look artificially small.
    print(f'[Batch {train_index+1}/{len(data_given)}] ==> {loss.item() / steps_decoded}')

[Batch 1/51200] ==> 3.8922669546944753
[Batch 2/51200] ==> 3.9019040194424717
[Batch 3/51200] ==> 3.9124839305877686
[Batch 4/51200] ==> 3.893746270073785
[Batch 5/51200] ==> 3.918499310811361
[Batch 6/51200] ==> 3.8995184647409538
[Batch 7/51200] ==> 3.814893510606554
[Batch 8/51200] ==> 3.908292917104868
[Batch 9/51200] ==> 3.9072017669677734
[Batch 10/51200] ==> 3.8632586552546573
[Batch 11/51200] ==> 3.943169275919596
[Batch 12/51200] ==> 3.932149887084961
[Batch 13/51200] ==> 3.869964175754123
[Batch 14/51200] ==> 3.8370844523111978
[Batch 15/51200] ==> 3.8664524371807394
[Batch 16/51200] ==> 3.9024386088053387
[Batch 17/51200] ==> 3.8220576236122534
[Batch 18/51200] ==> 3.8607406616210938
[Batch 19/51200] ==> 3.8458255767822265
[Batch 20/51200] ==> 3.8517586203182446
[Batch 21/51200] ==> 3.839854876200358
[Batch 22/51200] ==> 3.7801915486653646
[Batch 23/51200] ==> 3.8682943490835338
[Batch 24/51200] ==> 3.8552936553955077
[Batch 25/51200] ==> 3.878503163655599
[Batch 26/51200] =

In [29]:
# Greedy decoding of a single training word with the current model.
with torch.inference_mode():
    input_vector = word_to_tensor(eng_alphabet, data_given[51196])
    hidden_enc = encoder.initHidden()

    for char in input_vector:
        output_enc, hidden_enc = encoder(char, hidden_enc)

    # Index 0 is the 'SOW' marker: decoding starts from it.
    input_dec = torch.tensor([[0]])
    hidden_dec = hidden_enc

    translit_chars = []

    # Safety cap on the number of decoded symbols (arbitrary but generous):
    # without it the loop never ends if the model never predicts 'EOW'.
    max_decode_steps = 50

    for _ in range(max_decode_steps):
        output_dec, hidden_dec = decoder(input_dec, hidden_dec)
        pred_char_index = output_dec.argmax()
        # Index 1 is 'EOW': stop decoding.
        if pred_char_index.item() == 1:
            break
        # .item() gives a plain int for the list lookup.
        translit_chars.append(tam_alphabet.index_to_letter[pred_char_index.item()])
        input_dec = pred_char_index

    # Always finish with the 'EOW' marker, whether it was predicted or the
    # cap was hit, so the last element is a marker, not a real letter.
    translit_chars.append('EOW')

In [34]:
# Drop the final element of translit_chars and concatenate the rest.
"".join(translit_chars[:-1])

'பட்டப்பட்டத்்்'