In [1]:
import torch
import torch.nn as nn
import random
random.seed(420)
from torch.utils.tensorboard import SummaryWriter
from dataset import german_vocab, english_vocab, train_iterator, valid_iterator, test_iterator, en_pad_idx
from utils import translate_sentence, bleu, save_checkpoint, load_checkpoint
from tqdm import tqdm



Vocabulary size: German=8000, English=6000
GERMAN | Index of <pad>: 0, <sos>: 1, <eos>: 2, <unk>: 3
ENGLISH | Index of <pad>: 0, <sos>: 1, <eos>: 2, <unk>: 3


In [2]:
class Encoder(nn.Module):
	def __init__(self, vocab_size, embed_size, hidden_size, num_stacked_layers, embed_dropout, rnn_dropout):
		super(Encoder, self).__init__()
		self.hidden_size = hidden_size
		self.num_stacked_layers = num_stacked_layers
		self.embed_layer = nn.Embedding(vocab_size, embed_size)
		self.lstm = nn.LSTM(embed_size, hidden_size, num_stacked_layers, batch_first=True, dropout=rnn_dropout)
		self.dropout = nn.Dropout(embed_dropout)

	def forward(self, x):
		# x -> batch_size, sequences
        
		# embed -> batch_size, sequences, embed_size
		embed = self.dropout(self.embed_layer(x))

		b_size = x.shape[0]
		h0 = torch.zeros([self.num_stacked_layers, b_size, self.hidden_size])
		c0 = torch.zeros([self.num_stacked_layers, b_size, self.hidden_size])
		# out - Hidden state at each timestep: (b_size, timesteps, hidden_size)
        # _ : Final hidden state and cell state (for each LSTM layer)
		out, (hidden, cell) = self.lstm(embed, (h0, c0))
		return hidden, cell

class Decoder(nn.Module):
	def __init__(self, vocab_size, embed_size, hidden_size, num_stacked_layers, embed_dropout, rnn_dropout):
		super(Decoder, self).__init__()
		self.hidden_size = hidden_size
		self.num_stacked_layers = num_stacked_layers
		self.embed_layer = nn.Embedding(vocab_size, embed_size)
		self.lstm = nn.LSTM(embed_size, hidden_size, num_stacked_layers, batch_first=True, dropout=rnn_dropout)
		self.dropout = nn.Dropout(embed_dropout)
		self.fc = nn.Linear(hidden_size, vocab_size)

	def forward(self, x, hidden, cell):
		# x -> (batch_size, 1)
        
		# embed -> (batch_size, 1, embed_size)
		embed = self.dropout(self.embed_layer(x))

		# hidden -> (num_stacked_layers, batch_size, hidden_state)
		out, (hidden, cell) = self.lstm(embed, (hidden, cell))

		# out -> (batch_size, sequences, hidden_state) 
		# x -> (batch_size, output_vocab_size)
		x = self.fc(out).squeeze()
		return x, hidden, cell

class Seq2Seq(nn.Module):
	def __init__(self, encoder, decoder):
		super(Seq2Seq, self).__init__()
		self.encoder = encoder
		self.decoder = decoder
		
	def forward(self, source, target, teacher_force_ratio=0.5):
		# source -> (batch_size, sequences)
		b_size = source.shape[0]
		target_len = target.shape[1]
		target_vocab_size = len(english_vocab)

		outputs = torch.zeros(target_len, b_size, target_vocab_size).to(device)

		# Encode the source sentence
		hidden, cell = self.encoder(source)

		# First token
		x = target[:, 0]

		for t in range(1, target_len):
			output, hidden, cell = self.decoder(x.unsqueeze(1), hidden, cell)
			outputs[t] = output

			# predicted_tokens (indices) -> (batch_size,)
			predicted_tokens = output.argmax(1)
			x = target[:, t] if random.random() < teacher_force_ratio else predicted_tokens
		return outputs

In [3]:
epochs = 25
lr = 0.001
load_model = False

# Model hyperparameters
load_model = False
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# device = torch.device("mps")
input_encoder_size = len(german_vocab)
input_decoder_size = len(english_vocab)
encoder_embedding_size = 300
decoder_embedding_size = 300
hidden_size = 1024
lstm_stacked_layers = 2
embed_dropout = 0.5
lstm_dropout = 0.0

In [4]:
# Encoder
encoder_net = Encoder(
    input_encoder_size, 
    encoder_embedding_size, 
    hidden_size, 
    lstm_stacked_layers, 
    embed_dropout, 
    lstm_dropout
).to(device)

# Decoder
decoder_net = Decoder(
    input_decoder_size,
    decoder_embedding_size,
    hidden_size,
    lstm_stacked_layers,
    embed_dropout,
    lstm_dropout
).to(device)

# Wrapper model
model = Seq2Seq(encoder_net, decoder_net).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_function = nn.CrossEntropyLoss(ignore_index=en_pad_idx)

if load_model:
    load_checkpoint(torch.load("my_checkpoint_9.pth.tar"), model, optimizer)

In [5]:
total_params = sum(p.numel() for p in encoder_net.parameters())
print("Total parameters (ENCODER):", total_params)

total_params = sum(p.numel() for p in decoder_net.parameters())
print("Total parameters (DECODER):", total_params)

total_params = sum(p.numel() for p in model.parameters())
print("Total parameters (COMPLETE):", total_params)

Total parameters (ENCODER): 16228096
Total parameters (DECODER): 21778096
Total parameters (COMPLETE): 38006192


In [6]:
# Tensorboard
writer = SummaryWriter("runs/plot_loss")
step = 0

In [None]:
model.train(True)
for epoch in range(epochs):
    print(f"Epoch: {epoch}/{epochs}")
    running_loss = 0.
    for src, trg, _, _ in tqdm(train_iterator):
        inp_data = src.permute(1, 0) # (batch_size, src_len)
        target = trg.permute(1, 0) # (batch_size, trg_len)
        
        output = model(inp_data.to(device), target.to(device))  
        optimizer.zero_grad()
        
        output = output.reshape(-1, output.shape[2]) # (trg_len * batch_size, output_vocab_size)
        target = target.reshape(-1) # (trg_len * batch_size)
        loss = loss_function(output, target)
        running_loss += loss.item()
        loss.backward()
        
        # Clipping gradients
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()

        # Plot to tensorboard
        writer.add_scalar("Training loss", loss, global_step=step)
        step += 1
         
    print("Loss: ", running_loss/len(train_iterator))        
    # Save model
    checkpoint = {"state_dict": model.state_dict(), "optim": optimizer.state_dict()}
    save_checkpoint(checkpoint, f"my_checkpoint_{epoch}.pth.tar")
    print("Checkpoint saved")

Epoch: 0/25


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1813/1813 [8:18:44<00:00, 16.51s/it]


Loss:  5.5367731829736355
=> Saving checkpoint
Checkpoint saved
Epoch: 1/25


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1813/1813 [2:29:55<00:00,  4.96s/it]


Loss:  5.493858135410652
=> Saving checkpoint
Checkpoint saved
Epoch: 2/25


 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                          | 1500/1813 [28:52<06:33,  1.26s/it]