In [1]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import spacy
from torchtext.data.metrics import bleu_score
from torchtext.data import Field, BucketIterator,TabularDataset
from torchtext.data.utils import get_tokenizer
import pandas as pd
from sklearn.model_selection import train_test_split
# CUDA=False
CUDA=torch.cuda.is_available()
device=torch.device("cuda:7" if CUDA else "cpu")

In [2]:
class _transfomer(nn.Module):
    def __init__(self,embed_size,src_vocab_size,trg_vocab_size,src_pad_idx,
                 num_heads,n_enc_layer,n_dec_layer,forward_expansion,
                 dropout,max_len,device):
        super(_transfomer, self).__init__()
        self.src_embedding=nn.Embedding(src_vocab_size,embed_size) 
        self.src_postion_embedding=nn.Embedding(max_len,embed_size) 
        self.trg_embedding=nn.Embedding(trg_vocab_size,embed_size) 
        self.trg_postion_embedding=nn.Embedding(max_len,embed_size)
        self.device=device 
        self.tranformer=nn.Transformer(embed_size,
                                       num_heads,
                                       n_enc_layer,
                                       n_dec_layer,
                                       forward_expansion,
                                       dropout,
                                       )
        self.fc_out=nn.Linear(embed_size,trg_vocab_size)
        self.dropout=nn.Dropout(dropout)
        self.src_pad_idx=src_pad_idx
        
    def _make_src_mask(self,src):
        src_mask=src.transpose(0,1)==self.src_pad_idx
        return src_mask
    
    def forward(self,src,trg):
        src_seq_len,N=src.shape
        trg_seq_len,N=trg.shape
        
        src_position=(
            torch.arange(0,src_seq_len).unsqueeze(1).expand(src_seq_len,N).to(self.device)
        )
        trg_position=(
            torch.arange(0,trg_seq_len).unsqueeze(1).expand(trg_seq_len,N).to(self.device)
        )
        
        embed_src=self.dropout(
            (self.src_embedding(src) + self.src_postion_embedding(src_position))
            )
        
        embed_trg=self.dropout(
            (self.trg_embedding(trg) + self.trg_postion_embedding(trg_position))
        )
        _src_mask_pad=self._make_src_mask(src)
        trg_mask=self.tranformer.generate_square_subsequent_mask(trg_seq_len).to(self.device)

        out=self.tranformer(
            embed_src,
            embed_trg,
            src_key_padding_mask=_src_mask_pad,
            tgt_mask=trg_mask
        )
        out=self.fc_out(out)
        
        return out
        
    

In [3]:
vocab=torch.load("/mnt/disk1/Gulshan/rnn/translation/vocab_tamil_english.pth.tar")
tamil=vocab['tamil_voc']
english=vocab['english_voc']

In [4]:
eng_vocab_size=len(vocab['tamil_voc']['itos'])
tam_vocab_size=len(vocab['english_voc']['itos'])

In [5]:
epochs=50
lr=20**-4
bs=32
src_vocab_size=eng_vocab_size
trg_vocab_size=tam_vocab_size
# embed_size=512
# num_head=8
embed_size=768
num_head=12
n_ecode_layer=8 #6
n_decode_layer=8
dropout=0.2 #2
max_len=100
forward_expansion=4
src_pad_idx=english['stoi']["<pad>"]
writer=SummaryWriter("runs/loss_plot")
pad_idx=english['stoi']["<pad>"]

In [6]:
_model = _transfomer(embed_size,src_vocab_size,trg_vocab_size,
                  src_pad_idx,num_head,n_ecode_layer,n_decode_layer,
                  forward_expansion,dropout,max_len,device).to(device)
_load=torch.load('/mnt/disk1/Gulshan/rnn/translation/checkpoints/6.25e-06_768_12_100.pth.tar')
_load=_load['model']
_model.load_state_dict(_load)
# _model



<All keys matched successfully>

In [7]:
english['stoi']["<end>"],english['stoi']["<start>"]

(3, 2)

In [8]:
def translate_sentence(model, sentence, tamil, english, device, max_length=50):
    model.eval()
    # Load german tokenizer
    # spacy_ger = spacy.load("de")

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    if type(sentence) == str:
        tokens = [token.lower() for token in sentence.split(' ')]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, "<start>")
    tokens.append("<end>")

    # Go through each german token and convert to an index
    text_to_indices = [english['stoi'][token] for token in tokens]
    print(text_to_indices)
    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)
    # print(tamil['stoi']["<start>"])
    outputs = [tamil['stoi']["<start>"]]
    # print(outputs)
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(1).to(device)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[-1, :].item()
        print(best_guess)
        # print(output[0])
        outputs.append(best_guess)
        if best_guess == tamil['stoi']["<end>"]:
            break

    translated_sentence = [tamil['itos'][idx] for idx in outputs]
    # remove start token
    print(translated_sentence[1:-1])
    output=' '.join([(i) for i in translated_sentence[1:-1]])
    return output

In [11]:
sen="come"
x=translate_sentence(_model,sen,tamil,english,device)

[2, 122, 3]
0
0
4
3
['<unk>', '<unk>', '.']


In [10]:
print(x.shape)
print(x.shape,x[-1,:].shape)
print(x.argmax(2)[-1,:])

AttributeError: 'str' object has no attribute 'shape'