In [1]:
import torch
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

In [122]:
from transformers import AutoTokenizer, BartForConditionalGeneration
import torch

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base", forced_bos_token_id=0)
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

TXT = "I wish that <mask> was true."
input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
logits = model(input_ids).logits
positions = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero()
mask_pos = [mask.item() for mask in positions]

with torch.no_grad():
    output = model(input_ids)

last_hidden_state = output[0].squeeze()

for mask_ind in mask_pos:
    mask_hidden_state = last_hidden_state[mask_ind]
    tok_ids = torch.topk(mask_hidden_state, k=10, dim=0)[1]
    print(tok_ids[0])

    words = [tokenizer.decode(i.item()).strip() for i in tok_ids]

    TXT = TXT.replace("<mask>", words[0], 1)
    print(TXT, words)

print(last_hidden_state.shape)
print(model.model.encoder.embed_tokens(input_ids).shape)

tensor(24)
I wish that it was true. ['it', 'was', 'the', 'story', 'this', 'all', 'that', 'my', 'I', 'what']
torch.Size([9, 50265])
torch.Size([1, 9, 768])


In [169]:

TXT = "I really like <mask> cupcakes"
input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
positions = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero()
mask_pos = [mask.item() for mask in positions]

state = model.model.encoder(input_ids)
#state.last_hidden_state += torch.randn_like(state.last_hidden_state) * .1
outputs = model.model.decoder(input_ids = shift_tokens_right(input_ids, model.model.config.pad_token_id, model.model.config.decoder_start_token_id), encoder_hidden_states=state[0])
lm_logits = model.lm_head(outputs[0])
lm_logits = lm_logits + model.final_logits_bias.to(lm_logits.device)
tokenizer.decode(lm_logits[0][mask_pos[0]].argmax())

' these'

In [90]:
tokenizer.decode(model.lm_head(model.model.decoder(inputs_embeds=model.model.encoder(inputs_embeds=outputs[0])[0])[0])[0].argmax(dim=-1))

'voisisisvovovois'

In [174]:
target_distance = torch.tensor(.1)
threshold = .01
distance = target_distance + 1
passed = False
iterations_passed = 0
iterations_max = 100

distance_metric = torch.nn.CosineSimilarity(dim=1, eps=1e-08)


In [175]:
torch.autograd.set_detect_anomaly(True)

embedding2 = model.model.encoder(input_ids).last_hidden_state

embedding = torch.tensor(embedding2, requires_grad=True)
optimizer = torch.optim.Adam([embedding], lr=.005)

outputs = model.model.decoder(input_ids = shift_tokens_right(input_ids, model.model.config.pad_token_id, model.model.config.decoder_start_token_id), encoder_hidden_states=embedding)
base = torch.tensor(model.model.encoder(inputs_embeds=outputs.last_hidden_state).last_hidden_state, requires_grad=True)

lm_logits = model.lm_head(outputs[0])
prior_token = tokenizer.decode(lm_logits[0][mask_pos[0]].argmax())
passed = False
counter = 0
distance = threshold + 1


while (distance > threshold or not passed) or tokenizer.decode(lm_logits[0][mask_pos[0]].argmax()) == prior_token:
    if distance < threshold and counter > 10:
        passed = True

    if passed == True:
        iterations_passed += 1
    
    decoded_output = model.model.decoder(input_ids = shift_tokens_right(input_ids, model.model.config.pad_token_id, model.model.config.decoder_start_token_id), encoder_hidden_states=embedding)
    encoded_output = model.model.encoder(inputs_embeds=decoded_output.last_hidden_state)
    lm_logits = model.lm_head(decoded_output.last_hidden_state)

    if passed:
        distance = torch.abs(target_distance - torch.mean(torch.abs(distance_metric(base, encoded_output.last_hidden_state)))) * max(0, (iterations_max - iterations_passed) / iterations_max) - torch.mean(distance_metric(base, encoded_output.last_hidden_state)) * (iterations_passed / iterations_max)
    else:
        distance = torch.abs(target_distance - torch.mean(torch.abs(distance_metric(base, encoded_output.last_hidden_state))))

    loss = distance + torch.sum(torch.abs(base - embedding)) * .00001

    print(distance, passed, tokenizer.decode(lm_logits[0][mask_pos[0]].argmax()))
    loss.backward(retain_graph=True)
    optimizer.step()
    optimizer.zero_grad()

    counter += 1


  embedding = torch.tensor(embedding2, requires_grad=True)
  base = torch.tensor(model.model.encoder(inputs_embeds=outputs.last_hidden_state).last_hidden_state, requires_grad=True)


tensor(0.9000, grad_fn=<AbsBackward0>) False  these
tensor(0.8919, grad_fn=<AbsBackward0>) False  these
tensor(0.8046, grad_fn=<AbsBackward0>) False  these
tensor(0.7802, grad_fn=<AbsBackward0>) False  these
tensor(0.7736, grad_fn=<AbsBackward0>) False  these
tensor(0.7519, grad_fn=<AbsBackward0>) False  these
tensor(0.7521, grad_fn=<AbsBackward0>) False  these
tensor(0.7352, grad_fn=<AbsBackward0>) False  these
tensor(0.7231, grad_fn=<AbsBackward0>) False  these
tensor(0.7036, grad_fn=<AbsBackward0>) False  these
tensor(0.6888, grad_fn=<AbsBackward0>) False  these
tensor(0.6503, grad_fn=<AbsBackward0>) False  these
tensor(0.6144, grad_fn=<AbsBackward0>) False  these
tensor(0.5837, grad_fn=<AbsBackward0>) False  these
tensor(0.5729, grad_fn=<AbsBackward0>) False  these
tensor(0.5530, grad_fn=<AbsBackward0>) False  these
tensor(0.5281, grad_fn=<AbsBackward0>) False  these
tensor(0.5095, grad_fn=<AbsBackward0>) False  these
tensor(0.4879, grad_fn=<AbsBackward0>) False  these
tensor(0.471