In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cpu'

MAXLEN = 10

In [2]:
class Encoder(nn.Module):
    def __init__(self, input_dim ,hidden_dim,p=0.5) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p)
        self.embedding = nn.Embedding(input_dim,hidden_dim)
        self.lstm = nn.GRU(hidden_dim,hidden_dim,num_layers=1,batch_first=True)
    def forward(self,x):
        emb = self.dropout(self.embedding(x))
        output,hidden = self.lstm(emb)
        return output , hidden
class AttentionLayer(nn.Module):
    def __init__(self, hidden_dim) -> None:
        super().__init__()
        self.l1 = nn.Linear(hidden_dim,hidden_dim)
        self.l2 = nn.Linear(hidden_dim,hidden_dim)
        self.l3 = nn.Linear(hidden_dim,1)
    def forward(self,encoder_output , prev_hidden):
        score = self.l3(torch.tanh(self.l1(encoder_output)+self.l2(prev_hidden)))
        weights = torch.softmax(score,dim=1)
        weights = weights.permute(0,2,1)
        context = torch.bmm(weights,encoder_output)
        return context , weights
class Decoder(nn.Module):
    def __init__(self, output_dim , hidden_dim , p=0.5) -> None:
        super().__init__()
        self.embedding = nn.Embedding(output_dim,hidden_dim)
        self.attention = AttentionLayer(hidden_dim)
        self.lstm = nn.GRU(2*hidden_dim,hidden_dim,batch_first = True)
        self.out = nn.Linear(hidden_dim,output_dim)
        self.dropout = nn.Dropout(p)
    def forward(self,encoder_output , encoder_hidden,target=None,teacher_forcing_ratio=0.5):
        batch_size = encoder_output.shape[0]
        max_len = target.shape[1] if target is not None else MAXLEN
        decoder_input = torch.zeros(batch_size,1,dtype=torch.long).to(device)
        decoder_hidden = encoder_hidden
        outputs = []
        use_teacher_forcing = True if torch.rand(1).item() < teacher_forcing_ratio and target is not None else False

        if use_teacher_forcing:
            for i in range(max_len):
                output, decoder_hidden, weights = self.model(decoder_input, decoder_hidden, encoder_output)
                outputs.append(output)
                decoder_input = target[:, i].unsqueeze(1)  # Use target input at current time step
        else:
            for i in range(max_len):
                output, decoder_hidden, weights = self.model(decoder_input, decoder_hidden, encoder_output)
                outputs.append(output)
                decoder_input = output.argmax(-1)  # Use model's output as input at current time step

        outputs = torch.cat(outputs,dim=1)
        outputs = F.log_softmax(outputs,dim=2)
        return outputs,decoder_hidden
    def model(self,input,hidden,encoder_outputs):
        embedding = self.dropout(self.embedding(input))
        attention_hidden = hidden.permute(1,0,2)
        context,weights = self.attention(encoder_outputs,attention_hidden)
        input_gru = torch.cat((embedding,context),dim=2)
        output , hidden = self.lstm(input_gru,hidden)
        prediction = self.out(output)
        return prediction , hidden , weights

In [5]:
SOS_TOKEN = 0
EOS_TOKEN = 1
class Lang:
    def __init__(self,name) -> None:
        self.name = name
        self.index2word = {0:SOS_TOKEN,1:EOS_TOKEN}
        self.word2index = {}
        self.word2count = {}
        self.n_words = 2
    def addWord(self,word):
        if word not in self.word2index:
            self.word2index[word]=self.n_words
            self.word2count[word]=1
            self.index2word[self.n_words]=word
            self.n_words+=1
    def addSentense(self,sentence):
        for word in sentence.split():
            self.addWord(word)

In [6]:
langs = torch.load('lang.pth')
input_lang = langs['input_lang']
output_lang = langs['output_lang']
encoder = Encoder(input_lang.n_words,hidden_dim=256).to(device)
decoder = Decoder(output_lang.n_words,256).to(device)
checkpoint = torch.load('checkpoint/checkpoint.pth',map_location=torch.device(device))
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])

<All keys matched successfully>

In [7]:
def indexFromSentence(lang , sent):
        return  [lang.word2index[word] for word in str(sent).split()]
def tensorFromSentence(lang, sentence):
    indexes = indexFromSentence(lang, sentence)
    indexes.append(EOS_TOKEN)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)
def anuvadkaro(sent):
  with torch.no_grad():
          inp = tensorFromSentence(input_lang,sent)
          # print(f'indexes={inp}')
          enc_out , enc_hidden = encoder(inp)
          dec_out , dec_hid = decoder(enc_out , enc_hidden)
          # print(dec_out.shape)
          dec_out = dec_out.argmax(-1)
          decoded_ids = dec_out.squeeze()
          decoded_words = []
          # print(f'decoded_ids={decoded_ids}')
          for idx in decoded_ids:
              if idx.item() == EOS_TOKEN:
                  decoded_words.append('<EOS>')
                  break
              else:
                  decoded_words.append(output_lang.index2word[idx.item()])
  return decoded_words[:-1]

In [9]:
print(anuvadkaro('highlight duration'))
print(anuvadkaro('perform action'))
print(anuvadkaro('too many selectable children')) #बहुत अधिक चयनीय शिशु हैं

['हाइलाइट', 'अवधिः', 'हाइलाइट', 'रकें']
['कार्रवाई', 'संपन्न', 'करें']
['बहुत', 'अधिक', 'चयनीय', 'शिशु', 'हैं']
