In [1]:
import torch 
import pandas as pd
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import tqdm
import sys
sys.path.append('./data_prep')
from sentence_dataset_class import ProcessedSentences
from sentence_processing import build_vocab,sentence_processing
sys.path.append('./transformer_testing')
from tomislav_transformer import Seq2SeqTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def stringify_series(df):
    """Cast the ``input_data`` and ``output_data`` columns to pandas' ``string`` dtype.

    The frame is modified in place and the same object is returned, so the call
    can be chained.

    Args:
        df: DataFrame containing ``input_data`` and ``output_data`` columns.

    Returns:
        The same DataFrame object with both columns converted.
    """
    for column in ('input_data', 'output_data'):
        df[column] = df[column].astype('string')
    return df

def generate_square_subsequent_mask(sz):
    """Build an additive causal (look-ahead) attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` when ``j > i``, so
    positions strictly above the diagonal are blocked. The mask is float32 and
    is allocated on the notebook-global ``device``.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32, device=device)
    # triu with diagonal=1 keeps -inf strictly above the diagonal and zeroes the rest.
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt):
    """Build the four masks passed to the Seq2Seq transformer's forward call.

    Args:
        src: source token-id batch whose dimension 0 is the sequence length
            (``(src_len, batch)`` as produced by ``pad_sequence`` in ``collate_fn``).
        tgt: target token-id batch laid out the same way.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
        ``src_mask`` is an all-False ``(src_len, src_len)`` bool tensor,
        ``tgt_mask`` is ``generate_square_subsequent_mask(tgt_len)``, and each
        padding mask is True where the token equals ``PAD_IDX``, with the first
        two dimensions swapped (``(batch, len)``).
    """
    src_len = src.shape[0]
    tgt_len = tgt.shape[0]

    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    # Swap dims 0 and 1 so the padding masks are batch-first.
    src_padding_mask = src.eq(PAD_IDX).transpose(1, 0)
    tgt_padding_mask = tgt.eq(PAD_IDX).transpose(1, 0)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Load the train and test splits (JSON files with input_data / output_data columns).
df_train, df_test = (pd.read_json(path) for path in ('data/train_data.json', 'data/test_data.json'))

In [5]:
# torchtext's built-in 'basic_english' tokenizer, shared by vocab building and sentence processing.
token_transform = get_tokenizer(tokenizer='basic_english')

In [6]:
# Special tokens; each *_IDX constant equals that symbol's position in this list.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = range(len(special_symbols))

In [7]:
# One vocabulary per side (source, target), both built from the training split only.
train_input_vocab, train_output_vocab = (
    build_vocab(df_train[column], token_transform, special_symbols)
    for column in ('input_data', 'output_data')
)

In [8]:
# Wrap the raw training sentence pairs in the project's Dataset class.
train_dataset = ProcessedSentences(
    **{column: df_train[column].values for column in ('input_data', 'output_data')}
)

In [9]:
# Wrap the raw test sentence pairs in the project's Dataset class.
test_dataset = ProcessedSentences(
    **{column: df_test[column].values for column in ('input_data', 'output_data')}
)

In [10]:
def collate_fn(batch):
    """Turn a list of (source sentence, target sentence) pairs into padded id tensors.

    Each sentence is numericalized by ``sentence_processing`` with the matching
    training vocabulary and the BOS/EOS ids. Each side is then padded with
    ``PAD_IDX`` via ``pad_sequence`` (default ``batch_first=False``).

    Args:
        batch: iterable of ``(input_sentence, output_sentence)`` pairs, as
            fetched by the DataLoader from ``ProcessedSentences``.

    Returns:
        ``(src_batch, tgt_batch)``; assuming ``sentence_processing`` returns 1-D
        id tensors, each is shaped ``(max_len, batch_size)``.
    """
    src_tensors = []
    tgt_tensors = []
    # BOS_IDX/EOS_IDX are the same ids as special_symbols.index('<bos>'/'<eos>');
    # using the constants avoids a list search per sentence and matches translate().
    # Loop names are chosen so the built-in `input` is not shadowed.
    for src_sentence, tgt_sentence in batch:
        src_tensors.append(
            sentence_processing(src_sentence, train_input_vocab, token_transform, BOS_IDX, EOS_IDX)
        )
        tgt_tensors.append(
            sentence_processing(tgt_sentence, train_output_vocab, token_transform, BOS_IDX, EOS_IDX)
        )
    src_batch = pad_sequence(src_tensors, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_tensors, padding_value=PAD_IDX)
    return src_batch, tgt_batch

In [11]:
# Small-batch loader used for the quick batch inspection below
# (train_epoch builds its own loader with `batch_size`).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

In [36]:
# Pull one shuffled batch to sanity-check preprocessing. next(iter(...)) avoids a
# for/break loop that left module globals named `input`/`output` shadowing builtins.
_, sample_tgt_batch = next(iter(train_dataloader))
# Drop the last time step (the decoder-input view), then take the sentence in batch column 1.
x = sample_tgt_batch[:-1, :][:, 1]

In [37]:
# Token ids of the sampled sentence as a Python list (elements are NumPy integer scalars).
real_sent = [token_id for token_id in x.numpy()]

In [39]:
# Map the ids back to tokens to eyeball the sample (padding appears as <pad>).
sample_tokens = train_output_vocab.lookup_tokens(real_sent)
' '.join(sample_tokens)

"<bos> if i can get out of here at a decent time i ' ll finish it tonight . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>"

In [None]:
# Seed torch's global RNG before the model is constructed in the next cell.
torch.manual_seed(0)

input_vocab_size, output_vocab_size = len(train_input_vocab), len(train_output_vocab)

# Model hyper-parameters, passed positionally to Seq2SeqTransformer below.
num_encoder_layers = num_decoder_layers = 1
emb_size = 256
n_head = 2
ffn_hid_dim = 512

# Batch size used by train_epoch/evaluate (the inspection loader above uses 32).
batch_size = 128

In [None]:
# Positional order must match Seq2SeqTransformer's signature (defined in tomislav_transformer).
transformer = Seq2SeqTransformer(num_encoder_layers, num_decoder_layers, emb_size, n_head,
                                 input_vocab_size, output_vocab_size, ffn_hid_dim)

In [None]:
# Module.to moves parameters/buffers in place and returns the module itself.
transformer = transformer.to(device=device)

In [None]:
# Target positions equal to PAD_IDX are ignored by the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)

In [None]:
def _batch_loss(model, input_sent, output_sent):
    """Run one teacher-forced forward pass on a batch and return the loss tensor.

    The target batch is split into the decoder input (every step but the last)
    and the prediction target (every step but the first, i.e. shifted by one).
    Uses the notebook globals ``device``, ``create_mask`` and ``loss_fn``.
    """
    input_sent = input_sent.to(device)
    output_sent = output_sent.to(device)

    output_input = output_sent[:-1, :]
    output_out = output_sent[1:, :]

    input_mask, output_mask, input_padding_mask, output_padding_mask = create_mask(input_sent, output_input)
    logits = model(
        input_sent,
        output_input,
        input_mask,
        output_mask,
        input_padding_mask,
        output_padding_mask,
        # The source padding mask is passed again as the last argument
        # (presumably the encoder-memory padding mask; confirm in Seq2SeqTransformer).
        input_padding_mask,
    )
    return loss_fn(logits.reshape(-1, logits.shape[-1]), output_out.reshape(-1))


def train_epoch(model, optimizer):
    """Train ``model`` for one shuffled pass over ``train_dataset``.

    Batches of size ``batch_size`` are built with ``collate_fn``.

    Returns:
        The mean of the per-batch losses for this epoch.
    """
    model.train()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    total_loss = 0.0
    for input_sent, output_sent in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        loss = _batch_loss(model, input_sent, output_sent)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)


def evaluate(model):
    """Compute the mean per-batch loss of ``model`` over ``test_dataset``.

    Runs in eval mode without gradient tracking. The batch order is fixed, so
    the reported loss is reproducible from run to run.
    """
    model.eval()
    # shuffle=False: batch composition (and thus the mean of batch losses) stays deterministic.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    total_loss = 0.0
    # Validation needs no gradients; skipping autograd bookkeeping saves memory and time.
    with torch.no_grad():
        for input_sent, output_sent in tqdm.tqdm(test_loader):
            total_loss += _batch_loss(model, input_sent, output_sent).item()
    return total_loss / len(test_loader)
    
    

In [None]:
# Train for a fixed number of epochs, reporting mean train/validation loss after each one.
num_epochs = 2
loss_history = []  # (epoch, train_loss, val_loss) tuples, kept for later inspection/plotting

for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(transformer, optimizer)
    val_loss = evaluate(transformer)
    loss_history.append((epoch, train_loss, val_loss))
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}")

100%|██████████| 1661/1661 [03:03<00:00,  9.04it/s]
100%|██████████| 202/202 [00:09<00:00, 22.22it/s]


Epoch: 1, Train loss: 5.485, Val loss: 4.843


100%|██████████| 1661/1661 [03:00<00:00,  9.20it/s]
100%|██████████| 202/202 [00:09<00:00, 21.96it/s]

Epoch: 2, Train loss: 4.742, Val loss: 4.477





In [None]:
@torch.no_grad()
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Greedily decode one target sequence for a single source sentence.

    Args:
        model: seq2seq model exposing ``encode``, ``decode`` and ``generator``.
        src: source token ids; ``translate`` passes them shaped ``(src_len, 1)``.
        src_mask: ``(src_len, src_len)`` bool source attention mask.
        max_len: maximum length of the returned sequence, including ``start_symbol``.
        start_symbol: token id used to seed the decoder.

    Returns:
        Tensor of shape ``(out_len, 1)`` on ``device``. It holds ``start_symbol``
        followed by the argmax token of each step. Decoding stops right after
        ``EOS_IDX`` is produced or once ``max_len`` tokens are reached.
    """
    src = src.to(device)
    src_mask = src_mask.to(device)

    # The encoder output does not change across steps: compute it once.
    memory = model.encode(src, src_mask)
    ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)
    for _ in range(max_len - 1):
        # Causal mask over the tokens generated so far; its -inf entries become True.
        tgt_mask = generate_square_subsequent_mask(ys.size(0)).bool()
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        next_word = prob.argmax(dim=1).item()

        next_token = torch.full((1, 1), next_word, dtype=src.dtype, device=src.device)
        ys = torch.cat([ys, next_token], dim=0)
        if next_word == EOS_IDX:
            break
    return ys
  # actual function to translate input sentence into target language
def translate(model: torch.nn.Module, input_sentence: str):
    """Translate ``input_sentence`` with greedy decoding and return the output text.

    The sentence is numericalized with the training input vocabulary (wrapped
    in BOS/EOS). It is then decoded greedily for at most ``len(source ids) + 5``
    tokens and mapped back through the training output vocabulary. BOS/EOS ids
    are dropped and the remaining tokens are joined with single spaces, leaving
    no leading or trailing blanks.
    """
    model.eval()
    src = sentence_processing(input_sentence,
                              train_input_vocab,
                              token_transform,
                              BOS_IDX,
                              EOS_IDX).view(-1, 1)
    num_tokens = src.shape[0]
    # All-False mask: no source position is hidden from the encoder.
    src_mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        tgt_tokens = greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    # Filter special ids before joining instead of string-replacing "<bos>"/"<eos>",
    # which left stray spaces in the result.
    token_ids = [idx for idx in tgt_tokens.tolist() if idx not in (BOS_IDX, EOS_IDX)]
    return " ".join(train_output_vocab.lookup_tokens(token_ids))

In [None]:
# Quick qualitative check of the trained model on one sentence.
sample_sentence = "good drinks , and good company ."
print(translate(transformer, sample_sentence))

 <unk> , and new york and new york and new york and new
