In [21]:
import os

DATA_DIR = "data" # This may need to be changed on different machines

# Make sure we're in the correct directory and make sure the data directory exists
if not os.path.exists(DATA_DIR):
    os.chdir("../..") # Move up two directories because we're in src/nb and the data directory/path should be in/start at the root directory 
    assert os.path.exists(DATA_DIR), f"ERROR: DATA_DIR={DATA_DIR} not found"  # If we still can't see the data directory something is wrong

from tqdm.notebook import tqdm
import numpy as np

import torch
from torch import nn
# get Dataset class
from src.lib.decoder import Decoder
from src.lib.paraphrase_model import Paraphraser
from src.lib.style_classifier import StyleEncoder
from src.lib.style_transfer import StyleTransferer
from src.lib.util import to_device
from transformers import GPT2LMHeadModel, AdamW, GPT2Tokenizer

In [2]:

def load_decoder(state_dict_path):
    state_dict = torch.load(state_dict_path)
    for key in state_dict:
        state_dict[key] = state_dict[key].cpu()
    decoder = Decoder()
    decoder.load_state_dict(state_dict)
    return decoder
    

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = "training_results/decoder_0_0.0979/model.pth"
decoder = load_decoder(path)

Some weights of the model checkpoint at models/gpt2_large were not used when initializing GPT2LMHeadModel: ['transformer.extra_embedding_project.weight', 'transformer.extra_embedding_project.bias']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Explore the built in generate method to see if our model has brain damage or if if there's a problem with our generate method

In [15]:
gpt2 = decoder.gpt2.to(device)

In [16]:
input_sentence = "paraphrase this you insufferable diva<bos>"
# tokenize
tokenized = decoder.tokenizer(input_sentence, return_tensors="pt")
input_ids = tokenized["input_ids"].to(device)
attn_mask = tokenized["attention_mask"].to(device)
generated_ids = gpt2.generate(input_ids, attention_mask=attn_mask)

# decode the generated ids
generated_sentence = decoder.tokenizer.decode(generated_ids[0])
generated_sentence

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


"paraphrase this you insufferable diva <bos> you're a fool, you fool, you"

In [14]:
paraphrase = Paraphraser()
decoder.gpt2 = paraphrase.model

Some weights of the model checkpoint at models/gpt2_large were not used when initializing GPT2LMHeadModel: ['transformer.extra_embedding_project.weight', 'transformer.extra_embedding_project.bias']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
style_encoder = StyleEncoder()

In [None]:
# style_encoding = style_encoder.get_style_vector([style_sentence]).squeeze(0)
# style_encoding.shape

In [None]:
# positional_embeds = decoder.gpt2.transformer.wpe.weight
# token_embeds = decoder.gpt2.transformer.wte.weight

In [18]:
# def build_input(semantic_sentence, style_sentence, style_encoder, decoder):
#     positional_embeds = decoder.gpt2.transformer.wpe.weight
#     token_embeds = decoder.gpt2.transformer.wte.weight

#     style_encoding = style_encoder.get_style_vector([style_sentence]).squeeze(0) # (768)

#     max_length = 50
#     tokenized = decoder.tokenizer([semantic_sentence], return_tensors="pt", truncation=True, max_length=max_length)
#     para_ids = tokenized["input_ids"] # (1, para_length)
#     para_attn = tokenized["attention_mask"].squeeze(0) # (para_length)

#     para_embeds = token_embeds[para_ids].squeeze(0).detach() # (para_length, 1280)
#     para_pos = positional_embeds[np.arange(0, len(para_embeds))].detach() # (para_length, 1280)

#     target_embeds = token_embeds[[decoder.tokenizer.bos_token_id, decoder.tokenizer.pad_token_id]].detach() # (2, 1280)
#     target_pos = positional_embeds[[len(para_ids), len(para_ids)+1]].detach() # (2, 1280)

#     target_attn = torch.tensor([1, 0])

#     attn_mask = torch.ones(2 + len(para_attn) + len(target_attn)) # (2 + para_length + target_length)
#     attn_mask[1:len(para_attn)+1] = para_attn
#     attn_mask[len(para_attn)+2:] = target_attn # just one for the BOS token
#     attn_mask = attn_mask

#     bos_pos = positional_embeds[len(para_ids) + 1].detach()

#     style_encoding = style_encoding.unsqueeze(0)
#     para_embeds = para_embeds.unsqueeze(0)
#     para_pos = para_pos.unsqueeze(0)
#     bos_pos = bos_pos.unsqueeze(0)
#     target_embeds = target_embeds.unsqueeze(0)
#     target_pos = target_pos.unsqueeze(0)
#     attn_mask = attn_mask.unsqueeze(0)

#     x = (
#         style_encoding,
#         (para_embeds, para_pos),
#         bos_pos,
#         (target_embeds, target_pos),
#         attn_mask.unsqueeze(0)
#     )

#     return x


# def generate(x, decoder, truncate=False, max_length=50, device=None):
#     if device is None:
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     x = to_device(x, device)
#     decoder = decoder.to(device)
    
#     # print sum of attn_mask
    

#     generated_ids = []
#     generated_logits = []
#     with torch.no_grad():
#         i = 0
#         while i < max_length or not truncate:
#             i += 1

#             style_encoding, para, bos_pos, target, attn_mask = x
#             target_embeds, target_pos = target

#             print(attn_mask.sum())

#             output = decoder(x)

#             logits = output.logits[0, -1, :]
#             token_id = logits.argmax()

#             generated_logits.append(logits)


#             # check if token_id is eos
#             if token_id == decoder.tokenizer.eos_token_id:
#                 break

#             # add generated id 
#             generated_ids.append(token_id.item())
#             next_embedding = decoder.gpt2.transformer.wte.weight[token_id]
#             next_pos_embed = decoder.gpt2.transformer.wpe.weight[target_pos.shape[1] + 1]

#             # update the target embedding
#             target_embeds = torch.cat([target_embeds, next_embedding.unsqueeze(0).unsqueeze(0)], dim=1)
#             # update the target position
#             target_pos = torch.cat([target_pos, next_pos_embed.unsqueeze(0).unsqueeze(0)], dim=1)
#             # update the attention mask
#             attn_mask = torch.cat([attn_mask, torch.ones(1, 1, 1).to(device)], dim=2)

#             # repackage x
#             x = (
#                 style_encoding,
#                 para,
#                 bos_pos,
#                 (target_embeds, target_pos),
#                 attn_mask
#             )
        
#     return generated_ids, generated_logits



In [19]:
semantic_sentence = "Hello, how are you?"
style_sentence = "The all-seeing sun Ne'er saw her match since first the world begun."

In [23]:
st = StyleTransferer(style_encoder, decoder, device)
st.transfer_style(semantic_sentence, style_sentence, truncate=True)


"you're a little, you're a little bit, you're a little bit, you're a little bit, you're a little bit, you're a little bit, you're a little bit, you're a little bit, you're a"

In [20]:
input_sequence = build_input(semantic_sentence, style_sentence, style_encoder, decoder)
tokens, logits = generate(input_sequence, decoder, device=device, truncate=True)

tensor(9., device='cuda:0')


IndexError: tensors used as indices must be long, byte or bool tensors

In [9]:
print([logits[0].argmax()])
print([logits[1].argmax()])
print([logits[2].argmax()])
print([logits[3].argmax()])
print([logits[4].argmax()])

[tensor(15365, device='cuda:0')]
[tensor(15365, device='cuda:0')]
[tensor(15365, device='cuda:0')]
[tensor(15365, device='cuda:0')]
[tensor(15365, device='cuda:0')]


In [10]:
# decode the tokens
text = decoder.tokenizer.decode(tokens)
text

'ENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTSENTS'

In [None]:
max_length = 50
tokenized = decoder.tokenizer([semantic_sentence], return_tensors="pt", truncation=True, max_length=max_length)
para_ids = tokenized["input_ids"]
para_attn = tokenized["attention_mask"].squeeze(0)
print(para_ids.shape)
print(para_attn.shape)

In [None]:
para_embeds = token_embeds[para_ids].squeeze(0).detach()
para_pos = positional_embeds[np.arange(0, len(para_embeds))].detach()
# para_embeds += para_pos
print(para_embeds.shape)
print(para_pos.shape)

In [None]:
target_embeds = token_embeds[decoder.tokenizer.bos_token_id].unsqueeze(0).detach()
target_pos = positional_embeds[len(para_ids)].unsqueeze(0).detach()
# target_embeds += target_pos
print(target_pos.shape)
print(target_embeds.shape)

In [None]:
target_attn = torch.tensor([1])

attn_mask = torch.ones(2 + len(para_attn) + len(target_attn))
attn_mask[1:len(para_attn)+1] = para_attn
attn_mask[len(para_attn)+2:] = target_attn # just one for the BOS token
attn_mask = attn_mask
print(attn_mask.shape)


In [None]:
bos_pos = positional_embeds[len(para_ids) + 1].detach()
print(bos_pos.shape)

In [None]:
print("style_encoding", style_encoding.shape)
print("para_embeds", para_embeds.shape)
print("para_pos", para_pos.shape)
print("bos_pos", bos_pos.shape)
print("target_embeds", target_embeds.shape)
print("target_pos", target_pos.shape)
print("attn_mask", attn_mask.shape)

# unsqueeze to add batch dimension
style_encoding = style_encoding.unsqueeze(0)
para_embeds = para_embeds.unsqueeze(0)
para_pos = para_pos.unsqueeze(0)
bos_pos = bos_pos.unsqueeze(0)
target_embeds = target_embeds.unsqueeze(0)
target_pos = target_pos.unsqueeze(0)
attn_mask = attn_mask.unsqueeze(0)

In [None]:
x = (
    style_encoding,
    (para_embeds, para_pos),
    bos_pos,
    (target_embeds, target_pos),
    attn_mask.unsqueeze(0)
)


In [None]:
x = to_device(x, "cuda")
decoder = decoder.to("cuda")

In [None]:
output = decoder(x)

In [None]:
token_id = output.logits.argmax().item()