In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import lightning.pytorch as pl

# import other libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
from typing import *
import time
import math
import random
import wandb
# Authenticate with Weights & Biases at import time (may prompt for credentials).
wandb.login()

# Global torch device used by the helpers and models below: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmantra7[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# define Lang
class Lang:
	"""Character-level vocabulary built from an iterable of words.

	Two symbols are reserved at construction: 'A' at index 0 and 'Z' at
	index 1. 'Z' is also the fallback for characters or indices that are not
	in the vocabulary.

	Attributes:
		char2index: dict mapping a character to its integer index.
		index2char: dict mapping an integer index back to its character.
		char2count: dict mapping a character to the number of times it was added.
		n_chars: number of distinct symbols, including the two reserved ones.
	"""

	def __init__(self, wordList):
		"""Build the vocabulary from every character of every word in `wordList`."""
		self.char2index = {'A': 0, 'Z': 1}
		self.char2count = {}
		self.index2char = {0: 'A', 1: 'Z'}
		self.n_chars = 2

		for word in wordList:
			self.addWord(word)

	def addWord(self, word):
		"""Add every character of `word` to the vocabulary."""
		for char in word:
			self.addChar(char)

	def addChar(self, char):
		"""Record one occurrence of `char`, assigning it a new index if unseen."""
		if char not in self.char2index:
			self.char2index[char] = self.n_chars
			self.index2char[self.n_chars] = char
			self.n_chars += 1
		# The reserved 'A'/'Z' symbols are in char2index but start without a
		# count entry, so a bare `+= 1` raised KeyError for words containing them.
		self.char2count[char] = self.char2count.get(char, 0) + 1

	def encode(self, word):
		"""Return the list of indices for `word`; unknown characters map to the index of 'Z'."""
		unknown_index = self.char2index['Z']
		return [self.char2index.get(char, unknown_index) for char in word]

	def one_hot_encode(self, word):
		"""Return a (len(word), n_chars) one-hot float tensor on the global `device`."""
		one_hot = torch.zeros(len(word), self.n_chars, device=device)
		for position, index in enumerate(self.encode(word)):
			one_hot[position][index] = 1
		return one_hot

	def decode(self, word):
		"""Turn a sequence of score vectors into a string.

		Each element of `word` must support `.argmax().item()`; the arg-max
		index is looked up in `index2char`, falling back to 'Z' when absent.
		"""
		chars = []
		for i in range(len(word)):
			index = word[i].argmax().item()
			chars.append(self.index2char.get(index, 'Z'))
		return ''.join(chars)

	def decode_one_hot(self, word):
		"""Decode a sequence of one-hot rows; identical to `decode`."""
		return self.decode(word)

In [4]:
def tensorFromWord(lang : Lang, word : str):
    """Encode `word` with `lang` and return it as a column tensor.

    Index 1 (the 'Z' symbol in `Lang`) is appended after the encoded
    characters as the end-of-word marker. The result is a long tensor of
    shape (len(word) + 1, 1) placed on the global `device`.
    """
    token_ids = lang.encode(word) + [1]
    flat = torch.tensor(token_ids, dtype=torch.long, device=device)
    return flat.view(-1, 1)

def tensorsFromPair(pair, inp_lang : Lang, out_lang : Lang):
    """Convert a (source_word, target_word) pair into index tensors.

    Returns a tuple `(input_tensor, target_tensor)`: both come from
    `tensorFromWord`; only the input tensor gets an extra dimension
    inserted at position 1.
    """
    source_word, target_word = pair[0], pair[1]
    source_tensor = tensorFromWord(inp_lang, source_word)
    target_tensor = tensorFromWord(out_lang, target_word)
    source_tensor = source_tensor.unsqueeze(1)
    return (source_tensor, target_tensor)

In [5]:
# create dataset
class AksharantarDataset(Dataset):
	"""Dataset yielding one-hot encoded (input, target) word pairs.

	`data` must support `data['input_seq'][idx]` and `data['target_seq'][idx]`;
	`inp_lang` / `out_lang` must provide `one_hot_encode(word)`.
	"""

	def __init__(self, data, inp_lang, out_lang):
		"""Store the table of word pairs and the two vocabularies."""
		self.data, self.inp_lang, self.out_lang = data, inp_lang, out_lang

	def __len__(self):
		"""Number of entries reported by `len(data)`."""
		return len(self.data)

	def __getitem__(self, idx):
		"""Return {'input_seq': ..., 'target_seq': ...} for row `idx`.

		Tensor indices are converted with `.tolist()` first. Each encoded word
		gets an extra dimension inserted at position 1.
		"""
		row = idx.tolist() if torch.is_tensor(idx) else idx

		source_word = self.data['input_seq'][row]
		target_word = self.data['target_seq'][row]

		return {
			'input_seq': self.inp_lang.one_hot_encode(source_word).unsqueeze(1),
			'target_seq': self.out_lang.one_hot_encode(target_word).unsqueeze(1),
		}

In [6]:
def DataLoader(lang : str):
	"""Read the train/test/valid CSV files for `lang` and name their columns.

	Files are read from `aksharantar_sampled/<lang>/<lang>_{train,test,valid}.csv`
	relative to the working directory. Every frame gets the column names
	['input_seq', 'target_seq'].

	Returns:
		(train_data, test_data, valid_data) as pandas DataFrames, in that order.

	NOTE(review): this function reuses the name `DataLoader` that the import
	cell brings in from torch.utils.data, so the torch class is shadowed after
	this cell runs.
	NOTE(review): read_csv runs with its default header handling, so the first
	line of each file is consumed as a header before the rename. Confirm the
	CSV files really start with a header row.
	"""
	frames = (
		pd.read_csv(f'aksharantar_sampled/{lang}/{lang}_train.csv'),
		pd.read_csv(f'aksharantar_sampled/{lang}/{lang}_test.csv'),
		pd.read_csv(f'aksharantar_sampled/{lang}/{lang}_valid.csv'),
	)

	for frame in frames:
		frame.columns = ['input_seq', 'target_seq']

	train_data, test_data, valid_data = frames
	return train_data, test_data, valid_data

def asMinutes(s):
    """Format a duration of `s` seconds as '<minutes>m <seconds>s'."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)

def timeSince(since, percent):
    """Return '<elapsed> (- <remaining>)' for a task started at `since`.

    `since` is a `time.time()` timestamp and `percent` is the completed
    fraction (must be non-zero). The remaining time is extrapolated
    linearly from the elapsed time.
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [7]:
# NOTE(review): these variables are named guj_* but the language code loaded here is 'tam'.
# DataLoader returns (train, test, valid); index 0 below is the training frame.
guj_data = DataLoader('tam')

# Build the source and target character vocabularies from the training split only.
guj_inp_lang = Lang(guj_data[0]['input_seq'])
guj_out_lang = Lang(guj_data[0]['target_seq'])
# Displayed as the cell output: target vocabulary size, including the two reserved symbols.
guj_out_lang.n_chars

48

In [8]:
def get_cell(str):
	"""Map a cell-type name to its torch.nn recurrent layer class.

	Accepts 'lstm', 'gru' or 'rnn' (exact match) and returns nn.LSTM, nn.GRU
	or nn.RNN respectively. Raises ValueError for anything else.
	"""
	known_cells = (
		('lstm', nn.LSTM),
		('gru', nn.GRU),
		('rnn', nn.RNN),
	)
	for cell_name, cell_class in known_cells:
		if str == cell_name:
			return cell_class
	raise ValueError('Invalid cell type')

In [9]:
class EncoderRNN(nn.Module):
	"""Embedding layer followed by a recurrent cell (GRU, LSTM or vanilla RNN).

	Args:
		input_size: number of symbols in the input vocabulary.
		embed_size: dimension of the character embeddings.
		hidden_size: dimension of the recurrent hidden state.
		n_layers: number of stacked recurrent layers.
		type: cell name understood by `get_cell` ('gru', 'lstm' or 'rnn').
		dropout: dropout between stacked recurrent layers.
	"""

	def __init__(self, input_size, embed_size, hidden_size, n_layers=1, type='gru', dropout=0.2):
		super(EncoderRNN, self).__init__()
		self.hidden_size = hidden_size
		self.n_layers = n_layers

		self.embedding = nn.Embedding(input_size, embed_size)
		# PyTorch applies recurrent dropout only between stacked layers and
		# warns when it is non-zero with a single layer, so pass 0 in that case.
		rnn_dropout = dropout if n_layers > 1 else 0.0
		self.cell = get_cell(type)(embed_size, hidden_size, n_layers, dropout=rnn_dropout)

	def forward(self, input, hidden):
		"""Embed `input` and run it through the recurrent cell.

		Returns the cell's `(output, hidden)` pair unchanged.
		"""
		embedded = self.embedding(input)
		output, hidden = self.cell(embedded, hidden)
		return output, hidden

	def initHidden(self):
		"""Return a zero initial state for a batch of one on the global `device`.

		GRU/RNN cells get a single (n_layers, 1, hidden_size) tensor. An LSTM
		needs an (h_0, c_0) tuple; returning a bare tensor made it fail on the
		first forward call.
		"""
		h0 = torch.zeros(self.n_layers, 1, self.hidden_size, device=device)
		if isinstance(self.cell, nn.LSTM):
			return (h0, torch.zeros_like(h0))
		return h0
	
class DecoderRNN(nn.Module):
	"""Single-step decoder: embedding -> ReLU -> recurrent cell -> linear -> log-softmax.

	Args:
		hidden_size: dimension of both the embeddings and the recurrent state.
		output_size: number of symbols in the output vocabulary.
		n_layers: number of stacked recurrent layers.
		type: cell name understood by `get_cell` ('gru', 'lstm' or 'rnn').
		dropout: dropout between stacked recurrent layers.
	"""

	def __init__(self, hidden_size, output_size, n_layers=1, type='gru', dropout=0.2):
		super(DecoderRNN, self).__init__()
		self.hidden_size = hidden_size
		self.n_layers = n_layers

		self.embedding = nn.Embedding(output_size, hidden_size)
		# PyTorch applies recurrent dropout only between stacked layers and
		# warns when it is non-zero with a single layer, so pass 0 in that case.
		rnn_dropout = dropout if n_layers > 1 else 0.0
		self.cell = get_cell(type)(hidden_size, hidden_size, n_layers, dropout=rnn_dropout)
		self.out = nn.Linear(hidden_size, output_size)
		self.softmax = nn.LogSoftmax(dim=1)

	def forward(self, input, hidden):
		"""Decode one symbol.

		`input` holds a single symbol index; its embedding is reshaped to
		(1, 1, hidden_size) before the recurrent cell. Returns
		`(log_probs, hidden)` where `log_probs` has shape (1, output_size).
		"""
		embedded = self.embedding(input).view(1, 1, -1)
		activated = F.relu(embedded)
		output, hidden = self.cell(activated, hidden)
		log_probs = self.softmax(self.out(output[0]))
		return log_probs, hidden

	def initHidden(self):
		"""Return a zero initial state for a batch of one on the global `device`.

		GRU/RNN cells get a single (n_layers, 1, hidden_size) tensor. An LSTM
		needs an (h_0, c_0) tuple; returning a bare tensor made it fail on the
		first forward call.
		"""
		h0 = torch.zeros(self.n_layers, 1, self.hidden_size, device=device)
		if isinstance(self.cell, nn.LSTM):
			return (h0, torch.zeros_like(h0))
		return h0


In [10]:
class Seq2Seq(nn.Module):
	"""Encoder-decoder model that processes one sequence (batch of one) at a time.

	Index 0 is fed to the decoder as the start symbol and index 1 is treated
	as the end-of-sequence symbol.

	Args:
		input_size: number of symbols in the input vocabulary.
		hidden_size: hidden dimension shared by encoder and decoder.
		embed_size: dimension of the encoder embeddings.
		output_size: number of symbols in the output vocabulary.
		n_layers: number of stacked recurrent layers.
		type: cell name understood by `get_cell`.
		dropout: dropout between stacked recurrent layers.
	"""

	def __init__(self, input_size, hidden_size, embed_size, output_size, n_layers=1, type='gru', dropout=0.2):
		super(Seq2Seq, self).__init__()
		self.input_size = input_size
		self.hidden_size = hidden_size
		self.output_size = output_size
		self.n_layers = n_layers

		self.encoder = EncoderRNN(input_size, embed_size, hidden_size, n_layers, type, dropout).to(device)
		self.decoder = DecoderRNN(hidden_size, output_size, n_layers, type, dropout).to(device)

	def _encode(self, input_tensor):
		"""Feed `input_tensor` to the encoder one step at a time; return the final hidden state.

		The per-step encoder outputs are discarded because nothing downstream
		reads them. The fixed-size buffer that used to collect them raised
		IndexError for inputs longer than `max_length`.
		"""
		encoder_hidden = self.encoder.initHidden()
		for ei in range(input_tensor.size(0)):
			_, encoder_hidden = self.encoder(input_tensor[ei], encoder_hidden)
		return encoder_hidden

	def forward(self, input_tensor, target_tensor, max_length=50, teacher_forcing_ratio=0.5):
		"""Run one training pass and return the list of per-step decoder outputs.

		Args:
			input_tensor: source indices, one step per index along dim 0.
			target_tensor: target indices, one step per index along dim 0.
			max_length: kept for interface compatibility; decoding length is
				bounded by the target length.
			teacher_forcing_ratio: probability that this call feeds the ground
				truth symbol (instead of the model's own prediction) as the
				next decoder input.

		Returns:
			List of (1, output_size) log-probability tensors. Without teacher
			forcing the list ends right after the step that predicts index 1,
			so it can be shorter than the target.
		"""
		decoder_hidden = self._encode(input_tensor)
		target_length = target_tensor.size(0)

		decoder_input = torch.tensor([[0]], device=device)  # SOS
		use_teacher_forcing = random.random() < teacher_forcing_ratio

		decoder_outputs = []
		for di in range(target_length):
			decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
			decoder_outputs.append(decoder_output)

			if use_teacher_forcing:
				# Teacher forcing: feed the target as the next input.
				decoder_input = target_tensor[di]
			else:
				# Otherwise feed back the model's own best guess, detached from history.
				_, topi = decoder_output.topk(1)
				decoder_input = topi.squeeze().detach()
				if decoder_input.item() == 1:
					break

		return decoder_outputs

	def predict(self, input_tensor, max_length = 50):
		"""Greedy-decode `input_tensor` for at most `max_length` steps.

		Runs in eval mode (so dropout between stacked layers is off) and
		without autograd, then restores the previous training mode.

		Returns:
			List of (1, output_size) log-probability tensors, one per emitted
			symbol; the step that predicts index 1 (end of sequence) is not
			included.
		"""
		was_training = self.training
		self.eval()
		try:
			with torch.no_grad():
				decoder_hidden = self._encode(input_tensor)
				decoder_input = torch.tensor([[0]], device=device)  # SOS

				decoder_outputs = []
				for di in range(max_length):
					decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
					_, topi = decoder_output.topk(1)
					if topi.item() == 1:
						break
					decoder_outputs.append(decoder_output)
					decoder_input = topi.squeeze().detach()
		finally:
			self.train(was_training)

		return decoder_outputs

In [11]:
class Translator:
	"""Bundle data, vocabularies and a Seq2Seq model for one language.

	Args:
		lang: language code passed to the file-level `DataLoader` function.
		embed_size: encoder embedding dimension.
		hidden_size: recurrent hidden dimension.
		n_layers: number of stacked recurrent layers.
		max_length: maximum number of symbols produced when decoding.
		type: cell name understood by `get_cell`.
		dropout: dropout between stacked recurrent layers.

	Attributes set in __init__: train_data, test_data, valid_data, inp_lang,
	out_lang, model, criterion, max_length, pairs. `train()` additionally
	creates encoder_optim and decoder_optim.
	"""

	def __init__(self, lang, embed_size=10, hidden_size=10, n_layers=1, max_length=50, type='gru', dropout=0.2):
		self.train_data, self.test_data, self.valid_data = DataLoader(lang)

		# Vocabularies are built from the training split only.
		self.inp_lang = Lang(self.train_data['input_seq'])
		self.out_lang = Lang(self.train_data['target_seq'])

		self.model = Seq2Seq(self.inp_lang.n_chars, hidden_size, embed_size, self.out_lang.n_chars, n_layers, type, dropout)
		self.criterion = nn.NLLLoss()
		self.max_length = max_length

		# Pre-tensorise every training pair once so epochs do not repeat the encoding.
		self.pairs = [tensorsFromPair((source, target), self.inp_lang, self.out_lang)
								for source, target in zip(self.train_data['input_seq'], self.train_data['target_seq'])]

	def trainOne(self, input_tensor, target_tensor):
		"""Run one optimisation step on a single pair.

		Requires `train()` to have created `encoder_optim` / `decoder_optim`.
		Returns the summed step loss divided by the target length.
		"""
		self.encoder_optim.zero_grad()
		self.decoder_optim.zero_grad()

		# Call the module rather than .forward() so nn.Module hooks are honoured.
		decoder_outputs = self.model(input_tensor, target_tensor, self.max_length)

		loss = 0
		for di in range(len(decoder_outputs)):
			loss += self.criterion(decoder_outputs[di], target_tensor[di])
		loss.backward()

		self.encoder_optim.step()
		self.decoder_optim.step()

		return loss.item() / target_tensor.size(0)

	def train(self,epoch=1, n_iters=10000, print_every=1000, plot_every=100, learning_rate=0.01, rand=False, dumpName='model'):
		"""Train with SGD and pickle the whole Translator after every epoch.

		Args:
			epoch: number of epochs.
			n_iters: pairs sampled per epoch when `rand` is True.
			print_every: print the running loss every this many steps.
			plot_every: unused; kept for interface compatibility.
			learning_rate: SGD learning rate for encoder and decoder.
			rand: sample `n_iters` random pairs per epoch instead of one
				in-order pass over all training pairs.
			dumpName: checkpoint prefix; epoch i is written to
				`<dumpName><i>.pkl`.

		Returns:
			(train_loss, train_acc, valid_loss, valid_acc), each a list with
			one entry per epoch.
		"""
		self.encoder_optim = optim.SGD(self.model.encoder.parameters(), lr=learning_rate)
		self.decoder_optim = optim.SGD(self.model.decoder.parameters(), lr=learning_rate)

		start = time.time()
		train_loss = []
		train_acc = []
		valid_loss = []
		valid_acc = []

		for epoch_idx in range(epoch):
			print_loss_total = 0
			tot_loss = 0
			print("Epoch: ", epoch_idx)
			if rand:
				training_pairs = [random.choice(self.pairs) for _ in range(n_iters)]
			else:
				training_pairs = self.pairs
			# Progress must be measured against the pairs actually visited;
			# n_iters is only meaningful when rand=True.
			n_steps = len(training_pairs)

			for step in tqdm(range(1, n_steps + 1)):
				input_tensor, target_tensor = training_pairs[step - 1]

				loss = self.trainOne(input_tensor, target_tensor)
				print_loss_total += loss
				tot_loss += loss

				if step % print_every == 0:
					print_loss_avg = print_loss_total / print_every
					print_loss_total = 0
					# `start` marks the beginning of the whole run, so the ETA
					# uses progress across all epochs.
					overall_progress = (epoch_idx * n_steps + step) / (epoch * n_steps)
					print('%s (%d %d%%) %.4f' % (timeSince(start, overall_progress),
												step, step / n_steps * 100, print_loss_avg))
			train_loss.append(tot_loss / max(n_steps, 1))
			# Training accuracy is measured on the training split; it used to be
			# computed on the validation split, duplicating valid_acc.
			train_acc.append(self.accuracy(self.train_data))
			valid_stats = self.calculate_stats(self.valid_data)
			valid_loss.append(valid_stats[0])
			valid_acc.append(valid_stats[1])
			# Context manager so the checkpoint file is flushed and closed every epoch.
			with open(dumpName + str(epoch_idx) + '.pkl', 'wb') as checkpoint_file:
				pickle.dump(self, checkpoint_file)
		return train_loss, train_acc, valid_loss, valid_acc

	def _predict(self, input_tensor):
		"""Greedy-decode in eval mode without autograd, restoring the model's previous mode."""
		was_training = self.model.training
		self.model.eval()
		try:
			with torch.no_grad():
				return self.model.predict(input_tensor, self.max_length)
		finally:
			self.model.train(was_training)

	def accuracy(self, data):
		"""Fraction of rows whose translation equals the target exactly.

		Rows are read by iterating the 'input_seq' and 'target_seq' columns
		together, so the result does not depend on the frame's index labels.
		"""
		matches = [self.translate(source) == target
					for source, target in zip(data['input_seq'], data['target_seq'])]
		return np.sum(matches) / len(data)

	def translate(self, word):
		"""Transliterate a single word and return the decoded string."""
		tensor = tensorFromWord(self.inp_lang, word).unsqueeze(1)
		outs = self._predict(tensor)
		return self.out_lang.decode(outs)

	def calculate_stats(self, data):
		"""Return (mean loss, exact-match accuracy) over `data` as Python floats.

		The loss of one row is the NLL averaged over the positions that the
		prediction and the target (with its end marker) have in common; rows
		with an empty prediction contribute zero loss. Returns (0.0, 0.0) for
		empty `data`.
		"""
		n_samples = len(data)
		if n_samples == 0:
			return 0.0, 0.0

		total_loss = 0.0
		n_correct = 0
		for source, target_word in zip(data['input_seq'], data['target_seq']):
			tensor = tensorFromWord(self.inp_lang, source).unsqueeze(1)
			output = self._predict(tensor)
			n_correct += int(self.out_lang.decode(output) == target_word)

			target = tensorFromWord(self.out_lang, target_word)
			# Only the overlapping prefix is compared. The padding loops that
			# used to follow could never run (lengths are never below their own
			# minimum) and called a Lang method that does not exist.
			n_compared = min(len(output), len(target))
			for di in range(n_compared):
				total_loss += self.criterion(output[di], target[di]).item() / n_compared

		return total_loss / n_samples, n_correct / n_samples

In [12]:
# Positional arguments follow Translator.__init__: lang='tam', embed_size=16, hidden_size=128,
# n_layers=1, max_length=50, type='gru', dropout=0.2.
tam_trans = Translator('tam', 16, 128, 1, 50, 'gru', 0.2)
# Training is disabled here; uncomment to train for 10 epochs (one 'tam_model<epoch>.pkl' is written per epoch).
# tam_trans.train(epoch=10, print_every=1000, plot_every=100, learning_rate=0.001, rand=False, dumpName='tam_model')



In [13]:
# SECURITY: pickle.load can execute arbitrary code from the file; only load checkpoints you created yourself.
# The checkpoint stores a whole Translator object, so the classes defined above must exist before unpickling.
with open('trans_128_64_1_gru_0.3_9.pkl', 'rb') as checkpoint_file:
    trans = pickle.load(checkpoint_file)

In [14]:
# Exact-match accuracy of the loaded model on its validation split.
# Words are decoded one at a time, so this can take a while (the recorded run was interrupted).
trans.accuracy(trans.valid_data)

KeyboardInterrupt: 

In [20]:
# Spot-check: transliterate a single word with the loaded model.
# NOTE(review): the execution count jumps from 14 to 20 here; re-run the notebook top to bottom to confirm this output.
trans.translate('pani')

'પાની'