In [1]:
# implement a rnn encoder decoder network for seq 2 seq
import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


### DATA

In [2]:
# to keep thing simple lets use one hot encoding representations for our sequential data
# we are going to split the data X, Y
data = [["i like to eat fish", "toi thit an ca"],
		["have you ate yet", "co an com chua"],
		["we are going to church tomorrow", "ngay may minh di le"]]

# also, most tokenizers add special tokens such as <SOS>, <EOS>, <SEP>, <MASK>, etc.
# padding done as at this level as well
data = [["<SOS> "+sentence+" <EOS>" for sentence in entry] for entry in data]

MAX_SEQ_LEN = 10
def pad(text):
	while len(text.split(" ")) < MAX_SEQ_LEN:
		text += " <PAD>"
	return text
data = [[pad(sentence) for sentence in entry] for entry in data]

print(data)

[['<SOS> i like to eat fish <EOS> <PAD> <PAD> <PAD>', '<SOS> toi thit an ca <EOS> <PAD> <PAD> <PAD> <PAD>'], ['<SOS> have you ate yet <EOS> <PAD> <PAD> <PAD> <PAD>', '<SOS> co an com chua <EOS> <PAD> <PAD> <PAD> <PAD>'], ['<SOS> we are going to church tomorrow <EOS> <PAD> <PAD>', '<SOS> ngay may minh di le <EOS> <PAD> <PAD> <PAD>']]


In [3]:
english_dict = {}
idx = 1
for entry in data:
	tokens = entry[0].split(" ")
	for token in tokens:

		if token not in english_dict:
			english_dict[token] = idx
			idx += 1

idx = 1
vietnamese_dict = {}
idx = 0
for entry in data:
	tokens = entry[1].split(" ")
	for token in tokens:

		if token not in vietnamese_dict:
			vietnamese_dict[token] = idx
			idx += 1

EMBEDDING_SIZE = 25


english = []
for entry in data:
	sentence = entry[0]
	tokens = sentence.split(" ")

	sequence = []
	for token in tokens:
		ohe = [0]*EMBEDDING_SIZE
		ohe[english_dict[token]] = 1
		sequence.append(ohe)

	english.append(sequence)

vietnamese = []
for entry in data:
	sentence = entry[1]
	tokens = sentence.split(" ")

	sequence = []
	for token in tokens:
		ohe = [0]*EMBEDDING_SIZE
		ohe[vietnamese_dict[token]] = 1
		sequence.append(ohe)

	vietnamese.append(sequence)

english = torch.Tensor(english)
vietnamese = torch.Tensor(vietnamese)

print(english.size(), vietnamese.size()) # batch size, sequence length, embedding size

torch.Size([3, 10, 25]) torch.Size([3, 10, 25])


In [4]:
# using pytorch dataset and dataloader can be helpful
class EN_VN_dataset(Dataset):
	def __init__(self, lang1, lang2):
		self.lang1 = lang1
		self.lang2 = lang2

	def __len__(self):
		return self.lang1.size(0)

	def __getitem__(self, index):
		sentence1 = self.lang1[index]
		sentence2 = self.lang2[index]
		return sentence1, sentence2

### MODEL

In [5]:
class Encoder(nn.Module):
	def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=False, batch_first=True):
		super(Encoder, self).__init__()
		self.rnn = nn.RNN(input_size=input_size,
					hidden_size=hidden_size,
					num_layers=num_layers,
					bidirectional=bidirectional,
					batch_first=batch_first)

	def forward(self, x):
		output, hn = self.rnn(x)
		return output, hn

# check image on read me
# if bidirectional = True
# 	according to torch LSTM doc, hn contains the final hidden states of forward and backward
#	output contains the forward output and backward output at time step t

In [6]:
class Decoder(nn.Module):
	def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=False, batch_first=True):
		super(Decoder, self).__init__()
		self.rnn = nn.RNN(input_size=input_size,
					hidden_size=hidden_size,
					num_layers=num_layers,
					bidirectional=bidirectional,
					batch_first=batch_first) # common to have decoder architecture similar to encoder
		
		self.linear = nn.Linear(hidden_size, EMBEDDING_SIZE) # classificatin head
		self.softmax = nn.Softmax(dim=-1)

	def forward(self, encoder_hidden):
		decoder_hidden = encoder_hidden

		# we are going to unraveling the translation backwards
		# begin with the first input as <EOS>
		decoder_input = []
		ohe = [0]*EMBEDDING_SIZE
		ohe[vietnamese_dict["<EOS>"]] = 1
		decoder_input.append([ohe]) # sequence length is one as we will step one at a time to change the input after each step
		decoder_input = decoder_input*encoder_hidden.size(1)
		decoder_input = torch.Tensor(decoder_input)

		decoder_outputs = []
		decoder_outputs.append(decoder_input.to(device))
		for step in range(MAX_SEQ_LEN-1):
			decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
			decoder_outputs.append(decoder_output)
			# use the output from the decoder as the input now
			_, topidx = decoder_output.topk(1) # the index of the highest value
			topidx = topidx.squeeze(-1)
			decoder_input = [] # create single sequence ohe for batch
			for entry in topidx:
				ohe = [0]*EMBEDDING_SIZE 
				ohe[entry.item()] = 1
				decoder_input.append([ohe])
			decoder_input = torch.Tensor(decoder_input).detach()
		
		decoder_outputs = torch.stack(decoder_outputs, dim=1).squeeze(2) # reshaping data
		decoder_outputs = self.softmax(decoder_outputs)
		return decoder_outputs, decoder_hidden

	def forward_step(self, x, hn):
		# function to run though rnn one step
		x, hn = self.rnn(x.to(device), hn.to(device))
		x = self.linear(x)
		return x, hn
		

### TRAINING

In [11]:
encoder = Encoder(EMBEDDING_SIZE, 128).to(device)
decoder = Decoder(EMBEDDING_SIZE, 128).to(device)
param_groups = [
	{'params': encoder.parameters(), 'lr': 0.0001},
	{'params': decoder.parameters(), 'lr': 0.0001}
]

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(param_groups)

dataset_loader = DataLoader(EN_VN_dataset(english, vietnamese), shuffle=True)

In [12]:
for epoch in range(1000):

	running_loss = 0.0
	for i, data in enumerate(dataset_loader):
		inputs, labels = data
		inputs, labels = inputs.to(device), labels.to(device)

		optimizer.zero_grad()

		encoder_output, encoder_hn = encoder(inputs)
		decoder_output, decoder_hn = decoder(encoder_hn)

		loss = loss_fn(decoder_output, labels)
		loss.backward()
		optimizer.step()

		running_loss += loss.item()
	
	print(running_loss/len(dataset_loader))

print('Finished Training')

0.9214334686597189
0.9213189681371053
0.9211345911026001
0.9211928248405457
0.92110143105189
0.9210033814112345
0.92095414797465
0.9208553234736124
0.9206564029057821
0.9205314517021179
0.9203871091206869
0.9202152887980143
0.9200607538223267
0.9198159774144491
0.9195836186408997
0.919513463973999
0.919240415096283
0.9189004898071289
0.9184925556182861
0.918007493019104
0.9174188772837321
0.916743000348409
0.9160358111063639
0.9149985512097677
0.9138615727424622
0.9124477108319601
0.9109162290891012
0.9095138510068258
0.9075016776720682
0.9051607449849447
0.9022326469421387
0.8981791933377584
0.8944084048271179
0.8897102872530619
0.8842517336209615
0.8785704175631205
0.8727697134017944
0.8674850463867188
0.8626009225845337
0.8582887848218282
0.8550631999969482
0.8521216710408529
0.8500313957532247
0.8479772806167603
0.8465405305226644
0.8450056314468384
0.8439117272694906
0.8427092234293619
0.8417660593986511
0.8409524957338969
0.8399601777394613
0.83924533923467
0.8381751775741577
0.8

### EVALUATE

In [13]:
_, input_idxs = inputs.topk(1, dim=2)
english_reverse = {}
for k, v in english_dict.items():
	english_reverse[v] = k
for idx in input_idxs.squeeze().tolist():
	print(english_reverse[idx], end=' ')

print()

_, output_idxs = decoder_output.topk(1, dim=2)
vietnamese_reverse = {}
for k,v in vietnamese_dict.items():
	vietnamese_reverse[v] = k
for idx in output_idxs.squeeze().tolist():
	print(vietnamese_reverse[idx], end=' ')

<SOS> we are going to church tomorrow <EOS> <PAD> <PAD> 
<EOS> ngay may minh di le <EOS> <PAD> <PAD> <PAD> 

In [10]:
_, labels_idxs = labels.topk(1, dim=2)
labels_idxs = torch.flip(labels_idxs, dims=[1])
for idx in labels_idxs.squeeze().tolist():
	print(vietnamese_reverse[idx], end=' ')

<PAD> <PAD> <PAD> <PAD> <EOS> ca an thit toi <SOS> 

### NOTES

At the word level, add <PAD> or any other special tokens
make sure the data has the same order, structure as how you are planning to train the network
- for example, decoder translates backwards so data should be backwards

To help understand input to the decoder, remind yourself, how is the hidden and input x shaped/processed to be input to the RNN node. (can think as concatenated or seperate input)