# Libraries Related

In [27]:
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
import random
from tqdm import tqdm

# custom stuff
from data import load_data
from model_parts import Seq2Seq, BottleNeckCNN, AttnDecoder, Encoder
from utils import Config
# from train import train

#### Configuration Class

In [2]:
# Shared configuration object; later cells read settings such as
# `config.device`, `config.lr` and `config.beta` from it.
config = Config()

#### Data Preprocessing

In [3]:
# `load_data` returns three iterables, used below as the training,
# validation and test batch iterators.
train_iter, val_iter, test_iter = load_data(config=config, path='zinc_250k_cat_encoded.csv')

# Training Part

In [29]:
def train(model, data, hidden_dec, criterion,
          optimizer, teacher_forcing=False, beta=1):
  """Run one optimisation step on a single batch.

  Args:
    model: network exposing `encoder`, `bottleneck` and `decoder` submodules
      and a forward call `model(data, hidden_dec)` that returns
      `(out, bottleneck, mu, logvar)`.
    data: batch of token ids indexed as `data[:, t]`; column 0 is the first
      decoder input and columns 1.. are the reconstruction targets.
    hidden_dec: initial hidden state handed to the decoder.
    criterion: reconstruction loss, called as
      `criterion(out.permute(0, 2, 1), target)`.
    optimizer: optimizer that updates the parameters of `model`.
    teacher_forcing: if True, use the model's own forward pass; otherwise
      decode step by step, feeding the greedy prediction back in.
    beta: weight of the KL-divergence term in the total loss.

  Returns:
    Tuple `(loss, rec_loss, kld_loss)`: `loss` is the total loss as a tensor
    detached from the autograd graph; `rec_loss` and `kld_loss` are Python
    floats in both modes.

  Note:
    Gradient values are clipped with the module-level `config.grad_clip`.
  """
  def reconstruct_without_teachforce():
    # Step-by-step decoding where each input token is the model's own
    # previous greedy prediction. Returns (recon_loss, mu, logvar).
    enc_out, _ = model.encoder(data, None)
    # Use the `model` argument (not a global) so the function works for any
    # model instance passed in.
    enc_out, _, mu, logvar = model.bottleneck(enc_out.permute(0, 2, 1))
    enc_out = enc_out.permute(0, 2, 1)
    current_token = data[:, 0].unsqueeze(1)
    h = hidden_dec
    # Tensor accumulator so the result is a tensor even when there are no
    # target columns to decode.
    recon = torch.zeros((), device=data.device)
    for i in range(data.shape[1] - 1):
      out, h = model.decoder(current_token, h, enc_out)
      recon = recon + criterion(out.permute(0, 2, 1),
                                data[:, i + 1].unsqueeze(1))
      # Greedy pick of the next input; detached so no gradient flows
      # through the discrete choice.
      current_token = out.topk(1, dim=2)[1].squeeze(-1).detach()
    return recon, mu, logvar

  if teacher_forcing:
    out, _, mu, logvar = model(data, hidden_dec)
    rec_loss = criterion(out.permute(0, 2, 1), data[:, 1:])
  else:
    rec_loss, mu, logvar = reconstruct_without_teachforce()

  # KL divergence between N(mu, exp(logvar)) and the standard normal.
  kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  loss = rec_loss + beta * kld_loss

  optimizer.zero_grad()
  loss.backward()
  torch.nn.utils.clip_grad_value_(model.parameters(), config.grad_clip)
  optimizer.step()

  # Detach / convert so callers that accumulate these values do not keep
  # autograd history alive.
  return loss.detach(), rec_loss.item(), kld_loss.item()

In [25]:
# Instantiate the model and move it to the configured device.
seq2seq = Seq2Seq(config).to(config.device)
# Summed negative log-likelihood; targets equal to 0 are ignored
# (presumably 0 is the padding index -- TODO confirm against the data loader).
criterion = torch.nn.NLLLoss(ignore_index=0, reduction="sum")
# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(seq2seq.parameters(), lr=config.lr)    

In [None]:
# Training loop: each batch randomly uses teacher forcing or free-running
# decoding, and the progress bar shows per-sample losses averaged over the
# last `print_after` batches.
seq2seq.train()
num_epochs = 15
print_after = 10  # refresh the progress-bar description every N batches
running_loss = [0.0, 0.0, 0.0]  # total, reconstruction, KLD
for _ in range(num_epochs):
    # Wrap the iterator itself (not `enumerate`) so tqdm can read its length.
    data_iterable = tqdm(train_iter)
    for i, data in enumerate(data_iterable):
        data = data.to(config.device)
        batch_size = data.shape[0]
        hidden_dec = seq2seq.decoder.init_hidden(batch_size, config.device)
        # Use teacher forcing with probability `config.teachforce_ratio`.
        do_teach_force = random.random() < config.teachforce_ratio
        losses = train(seq2seq, data, hidden_dec,
                       criterion, optimizer, do_teach_force, config.beta)
        for n, value in enumerate(losses):
            # `float(...)` drops any autograd history before accumulating;
            # normalise by the real batch size (the last batch may be smaller).
            running_loss[n] += float(value) / batch_size
        if (i + 1) % print_after == 0:
            loss = running_loss[0] / print_after
            recon = running_loss[1] / print_after
            kld = running_loss[2] / print_after
            data_iterable.set_description(f"Loss: {loss:.2f} Recon: {recon:.2f} KLD: {kld:.2f}")
            running_loss = [0.0, 0.0, 0.0]


Loss: 115.32 Recon: 115.32 KLD: 2112.03: : 47it [00:18,  2.62it/s]

In [None]:
def validate(model, data, hidden_dec, criterion):
  """Return the batch loss of `model` on `data` without tracking gradients.

  The model is called as `model(data, hidden_dec)` and only the first of its
  four outputs is scored, against `data[:, 1:]`. The loss is a mean-reduced
  negative log-likelihood that ignores target index 0. `criterion` is
  accepted for signature compatibility but is not used.

  Returns:
    The loss as a Python float.
  """
  targets = data[:, 1:]
  with torch.no_grad():
    scores, _, _, _ = model(data, hidden_dec)
    # The loss expects the class dimension second, hence the permute.
    batch_loss = torch.nn.functional.nll_loss(
        scores.permute(0, 2, 1), targets, ignore_index=0)
  return batch_loss.item()

In [None]:
# Validation pass using the model's full forward call (the same call as the
# teacher-forcing branch of `train`); the last expression is the mean
# per-batch loss.
seq2seq.eval()
running_loss = 0
for data in val_iter:
    data = data.to(config.device)
    n_rows = data.shape[0]
    hidden_dec = seq2seq.decoder.init_hidden(n_rows, config.device)
    batch_loss = validate(seq2seq, data, hidden_dec, criterion)
    running_loss = running_loss + batch_loss

running_loss / len(val_iter)

In [None]:
def validate_infer(model, data, hidden_dec, criterion):
  """Greedy-decode a batch without teacher forcing and score the result.

  The batch is encoded and passed through the bottleneck, then decoded one
  token at a time, feeding each greedy prediction back as the next input.
  Decoding runs for `data.shape[1] - 1` steps so the collected scores line
  up with the targets `data[:, 1:]`.

  Args:
    model: network exposing `encoder`, `bottleneck` and `decoder`.
    data: batch of token ids indexed as `data[:, t]`.
    hidden_dec: initial decoder hidden state; when None, a fresh one is
      requested from `model.decoder.init_hidden`.
    criterion: accepted for signature compatibility; not used. The loss is a
      mean-reduced NLL that ignores target index 0.

  Returns:
    Tuple `(loss, output_seq)`: the loss as a Python float and the decoded
    token ids, one row per input sequence.
  """
  batch_size = data.shape[0]
  with torch.no_grad():
    # Use the `model` argument rather than a notebook-level global.
    enc_out, _ = model.encoder(data, None)
    enc_out, _, _, _ = model.bottleneck(enc_out.permute(0, 2, 1))
    enc_out = enc_out.permute(0, 2, 1)
    # Decoding starts from token id 1 for every sequence
    # (presumably the start-of-sequence id -- TODO confirm).
    current_token = torch.ones(
        (batch_size, 1), dtype=torch.long, device=data.device
        )
    if hidden_dec is not None:
      h = hidden_dec
    else:
      h = model.decoder.init_hidden(batch_size, data.device)
    output_seq = []
    output_scores = []
    for _ in range(data.shape[1] - 1):
      out, h = model.decoder(current_token, h, enc_out)
      output_scores.append(out)
      # Greedy choice: index of the highest-scoring token.
      current_token = out.topk(1, dim=2)[1].squeeze(-1)
      output_seq.append(current_token)
    output_seq = torch.cat(output_seq, dim=1)
    output_scores = torch.cat(output_scores, dim=1)
    loss = torch.nn.NLLLoss(ignore_index=0)(
        output_scores.permute(0, 2, 1), data[:, 1:])
  return loss.item(), output_seq

In [None]:
# Validation pass with free-running (greedy) decoding; the last expression is
# the mean per-batch loss. `data` and `output_seq` keep the values from the
# final batch after the loop.
seq2seq.eval()
running_loss = 0
for data in val_iter:
    data = data.to(config.device)
    n_rows = data.shape[0]
    hidden_dec = seq2seq.decoder.init_hidden(n_rows, config.device)
    loss, output_seq = validate_infer(seq2seq, data, hidden_dec, criterion)
    running_loss = running_loss + loss
running_loss / len(val_iter)

In [None]:
# Decoded token ids for the row at index 1 of the last validation batch.
output_seq[1]

In [None]:
# Input tokens of the same row with the first token dropped, for comparison
# with the decoded sequence.
data[1][1:]

In [None]:
# Count the distinct decoded sequences in the last batch by turning each row
# into a space-separated string of token ids.
lols = {
  " ".join(str(tok) for tok in row.cpu().numpy())
  for row in output_seq
}
len(lols)