In [170]:
import re
import random
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.utils.data import Dataset
from collections import Counter 
from torch.nn.utils.rnn import pad_sequence
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [171]:
with open('tam.txt') as fp:
    data = fp.readlines()    

In [172]:
english = []
tamil = []

for i in data:
    english.append(i.split('\t')[0])
    tamil.append(i.split('\t')[1])

In [173]:
english[0] , tamil[0]

('I slept.', 'நான் தூங்கினேன்.')

In [174]:
class Seq2SeqDataset(Dataset):
    def __init__(self, source_sentences, target_sentences, src_vocab, tgt_vocab):
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.source_sentences)

    def __getitem__(self, idx):
        src_sentence = [self.src_vocab[token] for token in self.source_sentences[idx].split()]
        tgt_sentence = [self.tgt_vocab[token] for token in self.target_sentences[idx].split()]
                
        return torch.tensor(( src_sentence )), torch.tensor((tgt_sentence))
    
    def collate_fn(batch):
        src_batch, tgt_batch = zip(*batch)

        src_lengths = [len(seq) for seq in src_batch]
        tgt_lengths = [len(seq) for seq in tgt_batch]
        
        src_padded = pad_sequence([torch.tensor(seq) if isinstance(seq, list) else seq for seq in src_batch], padding_value=500)
        tgt_padded = pad_sequence([torch.tensor(seq) if isinstance(seq, list) else seq for seq in tgt_batch], padding_value=500)
        
        return src_padded, tgt_padded


In [175]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, input_dim, output_dim, d_model, nhead, num_layers, dim_feedforward, dropout):
        super(Seq2SeqTransformer, self).__init__()

        self.encoder_embedding = nn.Embedding(input_dim, d_model)
        self.decoder_embedding = nn.Embedding(output_dim, d_model)

        self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                                                     dim_feedforward=dim_feedforward,
                                                                                     dropout=dropout), num_layers=num_layers)
        self.generator = nn.Linear(d_model, output_dim)

    def forward(self, src, tgt):
        src = self.encoder_embedding(src)
        tgt = self.decoder_embedding(tgt)

        src = src.permute(1, 0, 2)  
        tgt = tgt.permute(1, 0, 2)

        memory = self.transformer_encoder(src)
        output = self.generator(tgt)

        return output.permute(1, 0, 2)  


In [176]:
src_tokens = [token for sentence in english for token in sentence.split()]
tgt_tokens = [token for sentence in tamil for token in sentence.split()]

src_vocab = {token: idx + 1 for idx, (token, count) in enumerate(Counter(src_tokens).items())}
tgt_vocab = {token: idx + 1 for idx, (token, count) in enumerate(Counter(tgt_tokens).items())}

src_vocab['<pad>'] = 0
tgt_vocab['<pad>'] = 0


dataset = Seq2SeqDataset(english, tamil, src_vocab, tgt_vocab)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)


In [168]:
for i in data_loader:
    print(i)

[tensor([[ 58, 313,  90, 314, 101, 315]]), tensor([[  1,  86, 377, 378, 379]])]
[tensor([[ 58, 342, 264,  52, 358, 359]]), tensor([[427, 428, 429,   6,  55]])]
[tensor([[144,  55,   1,  28, 239]]), tensor([[281, 282, 283]])]
[tensor([[211, 212, 150, 101, 213]]), tensor([[251, 252, 253]])]
[tensor([[ 68, 116, 101, 117]]), tensor([[  1,  86, 120, 121]])]
[tensor([[ 1, 92, 14, 93]]), tensor([[ 1, 89, 90, 91]])]
[tensor([[11, 12]]), tensor([[10, 11]])]
[tensor([[ 44, 178, 101, 179]]), tensor([[208, 209]])]
[tensor([[ 58, 229,  85, 101,  52, 230]]), tensor([[  6, 269, 270]])]
[tensor([[259, 260,  14, 261]]), tensor([[279, 187, 319]])]
[tensor([[ 11,   8, 347, 203, 101, 351]]), tensor([[ 10, 416, 417]])]
[tensor([[208, 289, 290, 291]]), tensor([[348, 162, 349]])]
[tensor([[142, 143, 114,  14, 115]]), tensor([[  1,  68, 157,  82, 158, 159]])]
[tensor([[47, 37, 14, 48]]), tensor([[12, 41]])]
[tensor([[155,  34, 156,  14, 157]]), tensor([[170, 128, 171, 172]])]
[tensor([[ 58,   8, 203, 101, 249