## Loading data

In [1]:
with open('data/aksharantar_sampled/tam/tam_train.csv') as f:
    data_pairs = f.readlines()
data_given = [pair.split(',')[0].strip().lower() for pair in data_pairs]
data_target = [pair.split(',')[1].strip('\n').strip() for pair in data_pairs]
len(data_given), len(data_target)

(51200, 51200)

## Building the alphabet

In [2]:
class Alphabet():
    def __init__(self) -> None:
        self.letter_to_index = {}
        self.index_to_letter = ['SOW', 'EOW', 'UNK']
        self.letter_count = 3
    
    def addLetter(self, letter: str) -> None:
        if letter not in self.letter_to_index:
            self.letter_to_index[letter] = self.letter_count
            self.index_to_letter.append(letter)
            self.letter_count += 1

In [3]:
eng_alphabet, tam_alphabet = Alphabet(), Alphabet()
for word in data_given:
    for letter in word:
        eng_alphabet.addLetter(letter)
for word in data_target:
    for letter in word:
        tam_alphabet.addLetter(letter)
print(eng_alphabet.letter_count, tam_alphabet.letter_count)

29 49


## Seq2Seq model

In [47]:
import torch
from torch import nn
from torch.functional import F

class Encoder(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        cell_type: nn.Module = nn.RNN,
        num_layers: int = 1
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.is_lstm = (cell_type == nn.LSTM)
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=hidden_size)
        self.encoder = cell_type(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers) 
    
    def forward(self, x, hidden):
        output = self.embedding(x).reshape(1, 1, -1) 
        output, hidden = self.encoder(output, hidden)
        return output, hidden
 
    def initHidden(self):
        return torch.zeros(self.num_layers, 1, self.hidden_size)

In [48]:
class Decoder(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        output_size: int,
        cell_type: nn.Module = nn.RNN,
        num_layers: int = 1
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.is_lstm = (cell_type == nn.LSTM)
        self.embedding = nn.Embedding(num_embeddings=output_size, embedding_dim=hidden_size)
        self.decoder = cell_type(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers) 
        self.out = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.softmax = nn.LogSoftmax(dim=1)
    
    def forward(self, x, hidden):
        output = self.embedding(x).reshape(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.decoder(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(self.num_layers, 1, self.hidden_size)


## Training

In [18]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

In [7]:
def word_to_tensor(alphabet: Alphabet, word: str) -> torch.Tensor:
    chars = list(alphabet.letter_to_index[letter] for letter in word) + [1]
    return torch.tensor(chars, dtype=torch.long).reshape(-1, 1)

In [54]:
encoder = Encoder(eng_alphabet.letter_count, 64, num_layers=2, cell_type=nn.RNN).to(device)
decoder = Decoder(64, tam_alphabet.letter_count, num_layers=1, cell_type=nn.RNN).to(device)
optimizer_enc = torch.optim.Adam(encoder.parameters(), lr=1e-4)
optimizer_dec = torch.optim.Adam(decoder.parameters(), lr=1e-4)
loss_fn = torch.nn.NLLLoss()

In [55]:
teacher_forcing_ratio = 0.5

for epoch in range(10):
    print(f'Epoch {epoch+1}/10')
    avg_loss = 0.0
    for train_index in range(len(data_given)):
    # for train_index in range(1000):
        hidden_enc = encoder.initHidden().to(device)
        cell_enc = encoder.initHidden().to(device)

        optimizer_enc.zero_grad()
        optimizer_dec.zero_grad()

        input_vector = word_to_tensor(eng_alphabet, data_given[train_index])
        input_len = len(input_vector)

        loss = 0.0

        for char in input_vector:
            if encoder.is_lstm:
                output_enc, (hidden_enc, cell_enc) = encoder(char.to(device), (hidden_enc, cell_enc))
            else:
                output_enc, hidden_enc = encoder(char.to(device), hidden_enc)
        
        input_dec = torch.tensor([[0]])
        target_vector = word_to_tensor(tam_alphabet, data_target[train_index])
        target_len = len(target_vector)

        if decoder.is_lstm:
            cell_dec = torch.cat([cell_enc[-1].reshape(1, 1, -1)]*decoder.num_layers).to(device)
        hidden_dec = torch.cat([hidden_enc[-1].reshape(1, 1, -1)]*decoder.num_layers).to(device)

        use_teacher_forcing = True if torch.rand(1) < teacher_forcing_ratio else False

        if use_teacher_forcing:
            for di in range(target_len):
                if decoder.is_lstm:
                    output_dec, (hidden_dec, cell_dec) = decoder(input_dec.to(device), (hidden_dec, cell_dec))
                else:
                    output_dec, hidden_dec = decoder(input_dec.to(device), hidden_dec)
                loss += loss_fn(output_dec, target_vector[di].to(device))
                input_dec = target_vector[di]
        else:
            for di in range(target_len):
                if decoder.is_lstm:
                    output_dec, (hidden_dec, cell_dec) = decoder(input_dec.to(device), (hidden_dec, cell_dec))
                else:
                    output_dec, hidden_dec = decoder(input_dec.to(device), hidden_dec)
                input_dec = output_dec.argmax(dim=1).squeeze().detach()
                loss += loss_fn(output_dec, target_vector[di].to(device))
                if input_dec.item() == 1:
                    break
        
        loss.backward()
        avg_loss += loss.detach() / target_len

        optimizer_enc.step()
        optimizer_dec.step()

        if (train_index+1) % 5000 == 0:
            avg_loss /= 5000
            print(f'[Batch {train_index+1}/{len(data_given)}] ==> {avg_loss}')
            avg_loss = 0.

Epoch 1/10
[Batch 5000/51200] ==> 2.8134870529174805
[Batch 10000/51200] ==> 2.5608909130096436
[Batch 15000/51200] ==> 2.5125820636749268
[Batch 20000/51200] ==> 2.5000059604644775
[Batch 25000/51200] ==> 2.493377923965454
[Batch 30000/51200] ==> 2.476017951965332
[Batch 35000/51200] ==> 2.465585470199585
[Batch 40000/51200] ==> 2.448047637939453
[Batch 45000/51200] ==> 2.433969020843506
[Batch 50000/51200] ==> 2.4359521865844727
Epoch 2/10
[Batch 5000/51200] ==> 2.4315733909606934
[Batch 10000/51200] ==> 2.4206769466400146
[Batch 15000/51200] ==> 2.4196364879608154
[Batch 20000/51200] ==> 2.40851092338562
[Batch 25000/51200] ==> 2.3925869464874268
[Batch 30000/51200] ==> 2.394974946975708
[Batch 35000/51200] ==> 2.399857759475708
[Batch 40000/51200] ==> 2.382577657699585
[Batch 45000/51200] ==> 2.3814778327941895
[Batch 50000/51200] ==> 2.3724300861358643
Epoch 3/10
[Batch 5000/51200] ==> 2.372814178466797
[Batch 10000/51200] ==> 2.3663806915283203
[Batch 15000/51200] ==> 2.372725009

KeyboardInterrupt: 

In [28]:
chosen_index = 0

with torch.inference_mode():
    input_vector = word_to_tensor(eng_alphabet, data_given[chosen_index])
    input_len = len(input_vector)
    hidden_enc = encoder.initHidden().to(device)

    for char in input_vector:
        output_enc, hidden_enc = encoder(char.to(device), hidden_enc)

    input_dec = torch.tensor([[0]])
    hidden_dec = hidden_enc.to(device)

    translit_chars = []

    max_length = 100
    while max_length > 0:
        output_dec, hidden_dec = decoder(input_dec.to(device), hidden_dec)
        pred_char_index = output_dec.data.argmax()
        if pred_char_index.item() == 1:
            translit_chars.append('EOW')
            break
        else:
            translit_chars.append(tam_alphabet.index_to_letter[pred_char_index])
        input_dec = pred_char_index.squeeze().detach()
        max_length -= 1

In [29]:
''.join(translit_chars[:-1]), data_target[chosen_index], data_given[chosen_index]

('தொட்டர்யார', 'தொட்டாச்சார்ய', 'thottacharya')

In [21]:
torch.cuda.empty_cache()

In [30]:
with open('data/aksharantar_sampled/tam/tam_valid.csv') as f:
    val_data_pairs = f.readlines()
val_data_given = [pair.split(',')[0].strip().lower() for pair in val_data_pairs]
val_data_target = [pair.split(',')[1].strip('\n').strip() for pair in val_data_pairs]
len(val_data_given), len(val_data_target)

(4096, 4096)

In [49]:
with torch.inference_mode():
    acc = 0.0
    for chosen_index in range(len(val_data_given)):
        input_vector = word_to_tensor(eng_alphabet, val_data_given[chosen_index])
        input_len = len(input_vector)
        hidden_enc = encoder.initHidden().to(device)

        for char in input_vector:
            output_enc, hidden_enc = encoder(char.to(device), hidden_enc)

        input_dec = torch.tensor([[0]])
        hidden_dec = hidden_enc.to(device)

        translit_chars = []

        max_length = 100
        while max_length > 0:
            output_dec, hidden_dec = decoder(input_dec.to(device), hidden_dec)
            pred_char_index = output_dec.data.argmax()
            if pred_char_index.item() == 1:
                translit_chars.append('EOW')
                break
            else:
                translit_chars.append(tam_alphabet.index_to_letter[pred_char_index])
            input_dec = pred_char_index.squeeze().detach()
            max_length -= 1
        if ''.join(translit_chars) == val_data_target[chosen_index]+'EOW':
            acc += 1.
    acc /= len(val_data_given)

KeyboardInterrupt: 

In [50]:
acc

0.0