In [None]:
import sys
import os

import torch
from PIL import Image
from torchvision import transforms as T
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import trange
import matplotlib.pyplot as plt

sys.path.append("../")
from CLIP.clip import load, tokenize
from CLIP.clip.simple_tokenizer import SimpleTokenizer

In [None]:
# Prefer the GPU when one is available, but fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing
# at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the "ViT-B/32" checkpoint (non-JIT) onto `device`; `load` returns the
# model together with its image preprocessing transform.
model, transform = load("ViT-B/32", jit=False, device=device)
# Put the model in eval mode and cast its floating-point weights to float32.
model = model.eval().float()

In [None]:
# Tokenizer instance used by `decode_text` and the top-k cell to turn token ids
# back into text.
tokenizer = SimpleTokenizer()

In [None]:
def decode_text(model, embedding, text_tokenizer=None):
    """Map token embeddings back to text via nearest-neighbour lookup.

    Every row of ``embedding`` is replaced by the id of the closest row
    (Euclidean distance) of ``model.token_embedding.weight``; the resulting
    ids are filtered and decoded into a string.

    Args:
        model: Object exposing ``token_embedding.weight`` as a
            ``(vocab, dim)`` tensor.
        embedding: Tensor of shape ``(seq, dim)``. If it has more than two
            dimensions, a leading singleton dimension is squeezed away first.
        text_tokenizer: Object with a ``decode(ids)`` method. Defaults to the
            notebook-level ``tokenizer``. Pass it explicitly when that global
            may have been rebound (later cells assign other tokenizers to the
            same name).

    Returns:
        The string produced by ``text_tokenizer.decode``.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    # from embed weights to tokens
    embed_weights = model.token_embedding.weight.data
    if embedding.ndim > 2:
        embedding = embedding.squeeze(0)
    # One pairwise-distance call replaces the per-position Python loop, which
    # allocated a (vocab, dim) temporary for every position. The lookup is
    # not differentiable, so autograd tracking is switched off. The exact
    # (non-matmul) distance mode keeps the argmin numerically faithful to the
    # direct ``norm(weights - e)`` formulation.
    with torch.no_grad():
        distances = torch.cdist(
            embedding.detach().to(embed_weights.dtype),
            embed_weights,
            compute_mode="donot_use_mm_for_euclid_dist",
        )
        decoded_tokens = distances.argmin(dim=-1).cpu().numpy()
    # ignore empty stops, start, and end token
    # NOTE(review): every id 0 is dropped, including positions that genuinely
    # decode to vocabulary entry 0 -- confirm that is intended.
    decoded_tokens = decoded_tokens[decoded_tokens != 0][1:-1]
    return text_tokenizer.decode(decoded_tokens)

In [None]:
# load image
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make the choice of the "last" image reproducible across runs and machines.
base_images = sorted(os.listdir("base_images"))
print(base_images)
if not base_images:
    raise FileNotFoundError("no files found in 'base_images'")
img_name = base_images[-1]
img = Image.open(os.path.join("base_images", img_name)).convert("RGB")
# Last step of the model's preprocessing pipeline; reused below after
# ToTensor (presumably the normalisation step -- confirm against `transform`).
norm = transform.transforms[-1]
plt.imshow(img)

In [None]:
# create start text
# Positions [start_token, max_idx) of the embedded prompt become the free
# parameters `opt`; everything before and after stays fixed.
start_token = 1  # first optimised position (position 0 is presumably the start-of-text token)
text = "A picture of a cat."
tokens = tokenize(text)
print(tokens)

# Embedded prompt, detached from autograd, plus its global statistics.
embed = model.embed_text(tokens.to(device)).detach()
embed_std = embed.std()
embed_mean = embed.mean()

# Index of the largest token id; used as the exclusive end of the optimised
# span (presumably the end-of-text token -- confirm against the tokenizer).
# argmax without `dim` works on the flattened tensor, so this assumes a
# single prompt.
max_idx = tokens.argmax()
# Trainable copy of the prompt's word embeddings.
opt = embed[0, start_token: max_idx].clone().detach().requires_grad_(True)
# Reference statistics of the initial embeddings (averaged over positions).
opt_norm = opt.detach().norm(dim=-1).mean().item()
opt_mean = opt.detach().mean(dim=-1).mean().item()
opt_std = opt.detach().std(dim=-1).mean().item()

# Zero the optimised span in `embed`; the training loop adds `opt` back in.
embed[0, start_token: max_idx] = 0
print(opt.shape)
print(opt_norm)
print(opt.mean())
print(opt.std())

In [None]:
# Random augmentation applied to the PIL image on every optimisation step:
# a random resized crop to 224x224, conversion to a tensor, then `norm`.
# The commented-out entries are alternative augmentations that are disabled.
t = T.Compose([#T.Resize((224, 224)),
               T.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.9, 1.1)),
               
    
               #T.RandomAffine([0, 10], 
               #     translate=(0, 0.3),
               #     scale=(0.8, 1.0), 
               #     shear=(0.5, 0.8),
               #     fillcolor=255),
              #T.RandomGrayscale(p=0.2),
              #T.RandomPerspective(distortion_scale=0.3,
              #                    p=0.3,
              #                    fill=255),
               T.ToTensor(),
               norm,
            ])

# Un-augmented reference: the whole image resized to 224x224 (aspect ratio is
# not preserved), converted to a tensor, passed through `norm`, and encoded
# once without gradients.
norm_img = norm(T.ToTensor()(img.resize((224, 224)))).unsqueeze(0).to(device)
with torch.no_grad():
    norm_img_feats = model.encode_image(norm_img)

# demonstrate transform
#T.ToPILImage()(t(img))


In [None]:
# Fixed embeddings before and after the optimised span; `augment_text`
# concatenates them around (a slice of) `opt`.
prefix = embed[0, :start_token]
suffix = embed[0, max_idx:]

In [None]:
import random


def augment_text(opt, prefix, suffix, n_gram=1, context_length=77):
    """Build a full-length embedding sequence from a random n-gram of ``opt``.

    A contiguous window of ``n_gram`` rows is picked uniformly at random from
    ``opt`` and placed between ``prefix`` and ``suffix``. If the result is
    shorter than ``context_length``, it is padded by repeating the last row
    of ``suffix``.

    Args:
        opt: Tensor of optimised embeddings, one row per position.
        prefix: Tensor of rows placed before the n-gram.
        suffix: Tensor of rows placed after the n-gram; its last row is also
            the padding value.
        n_gram: Number of consecutive rows of ``opt`` to keep.
        context_length: Target sequence length (default 77, the value that
            was previously hard-coded).

    Returns:
        Tensor of concatenated rows. Sequences already longer than
        ``context_length`` are returned unpadded and untruncated.

    Raises:
        ValueError: If ``n_gram`` exceeds the number of rows in ``opt``.
    """
    if n_gram > len(opt):
        raise ValueError(
            "n_gram (%d) must not exceed the number of optimised positions (%d)"
            % (n_gram, len(opt))
        )
    pos = random.randint(0, len(opt) - n_gram)
    opt_part = opt[pos: pos + n_gram]

    new_emb = torch.cat([prefix, opt_part, suffix])
    missing = context_length - len(new_emb)
    if missing > 0:
        # Repeat the final suffix row as a view; torch.cat copies the data
        # anyway, so no per-row clones are needed.
        padding = suffix[-1:].expand(missing, *suffix.shape[1:])
        new_emb = torch.cat([new_emb, padding])

    return new_emb

In [None]:
# Optimisation hyper-parameters.
lr = 0.01
steps = 100
bs = 16  # number of randomly augmented image crops per step

# Only the free token embeddings `opt` are handed to the optimiser.
optimizer = torch.optim.Adam([opt], lr=lr)#, weight_decay=0.2)

# Best decoded prompt seen so far and the per-step loss histories.
best_text = None
best_text_loss = 100
aug_losses = []
losses = []
text_losses = []
reg_losses = []

embed_weights = model.token_embedding.weight.data

# Iterating the bar directly advances and closes it automatically.
pbar = trange(steps)
for step in pbar:
    # Full prompt with the current `opt` inserted; used for the un-augmented
    # evaluation and for decoding further down.
    embedding = embed.clone()
    embedding[0, start_token:max_idx] += opt

    # Differentiable text branch (uses every optimised position).
    embedding_batch = augment_text(opt, prefix, suffix, n_gram=len(opt)).unsqueeze(0)
    text_feats = model.encode_text(tokens, embedding=embedding_batch).to(device)

    # Image branch: `bs` random crops, encoded without gradients.
    img_batch = torch.stack([t(img) for _ in range(bs)])
    with torch.no_grad():
        img_feats = model.encode_image(img_batch.to(device))

    # Regulariser: mean p=3 distance from every optimised embedding to its
    # nearest vocabulary embedding, weighted by 0.02. The previous
    # norm/mean/std statistics losses were computed but never added to the
    # loss, so they are no longer evaluated.
    reg_loss = torch.topk(torch.stack([torch.norm(o - embed_weights, dim=-1, p=3) for o in opt]), 1, largest=False).values.mean() * 0.02

    reg_losses.append(reg_loss.item())
    sim_loss = -1 * (torch.nn.functional.cosine_similarity(text_feats, img_feats)).mean()

    loss = sim_loss + reg_loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # move opt back to orig stats - does not work
    #with torch.no_grad():
    #    opt = (opt - opt.mean(dim=-1, keepdim=True)) / opt.std(dim=-1, keepdim=True)
    #    opt = (opt * opt_std) + opt_mean

    # Everything below is evaluation only, so no autograd graph is built.
    # NOTE: `embedding` was assembled before optimizer.step(), so these
    # numbers describe the pre-update embeddings of this step.
    with torch.no_grad():
        # calc non-augmented loss
        non_aug_text_feats = model.encode_text(tokens, embedding=embedding)
        non_aug_loss = -1 * (torch.nn.functional.cosine_similarity(non_aug_text_feats, norm_img_feats)).mean()
        # decode text
        current_text = decode_text(model, embedding)
        # calc loss based on decoded text
        decode_loss = -1 * (torch.nn.functional.cosine_similarity(model.encode_text(tokenize(current_text).to(device)), norm_img_feats)).item()

    # log losses
    aug_losses.append(sim_loss.item())
    losses.append(non_aug_loss.item())
    text_losses.append(decode_loss)

    if decode_loss < best_text_loss:
        best_text = current_text
        best_text_loss = decode_loss

    pbar.set_description(current_text + " - aug loss " + str(round(loss.item(), 2)) + " loss " + str(round(non_aug_loss.item(), 2)) + " decode loss " + str(round(decode_loss, 2)))

In [None]:
# Histogram (100 bins) of the values in the first row of `opt` only.
p = plt.hist(opt[0].cpu().detach().flatten().numpy(), bins=100)

In [None]:
# Recompute the nearest-vocabulary distance regulariser from the training
# loop, here without the 0.02 weight, to inspect its raw value.
reg_loss = torch.topk(torch.stack([torch.norm(o - embed_weights, dim=-1, p=3) for o in opt]), 1, largest=False).values.mean()
reg_loss

In [None]:
#'ó charismatic herr sunsetgrayson' - for me
#'" landscapes voyage schelthur' - for autumn
#'patient �📝: rito grows ' - for ouzi
#'roomie stru<|startoftext|>saharan collie ' for ouzi
#'photoshopped yikes .- hotdog watches ' - for hot-dog

In [None]:
# Decoded prompt with the lowest decode loss seen during optimisation.
best_text

In [None]:
# Cosine similarity between the encoded `best_text` and the reference image features.
torch.nn.functional.cosine_similarity(model.encode_text(tokenize(best_text).to(device)), norm_img_feats).item()

In [None]:
# Same similarity for a hand-written caption, as a baseline for comparison.
torch.nn.functional.cosine_similarity(model.encode_text(tokenize("A picture of a landscape").to(device)), norm_img_feats).item()

In [None]:
# Overlay every tracked loss history on one set of axes.
for curve, curve_label in (
    (aug_losses, "aug loss"),
    (losses, "loss"),
    (text_losses, "text loss"),
    (reg_losses, "reg loss"),
):
    plt.plot(curve, label=curve_label)
plt.legend()

In [None]:
# Decode whatever `embedding` currently holds (after the training cell, the
# tensor assembled in its last iteration).
decode_text(model, embedding)

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Vocabulary sweep: for every row of the token-embedding matrix, write that
# single embedding into all optimised positions of the prompt and record the
# resulting text features.
# NOTE(review): this runs one text-encoder forward pass per vocabulary entry;
# the commented-out DataLoader lines suggest batching was considered.
#token_losses = []

#ds = torch.utils.data.TensorDataset(embed_weights)
#dl = torch.utils.data.DataLoader(ds, batch_size=64)

all_text_feats = []

embed_weights = model.token_embedding.weight.data


for token_emb in tqdm(embed_weights):    
    # `embed` holds zeros in the optimised span, so the in-place add
    # broadcasts the same token embedding into every one of those positions.
    embedding = embed.clone()
    embedding[0, start_token:max_idx] += token_emb.unsqueeze(0)
    
    with torch.no_grad():
        text_feats = model.encode_text(tokens, embedding=embedding)
        
    all_text_feats.append(text_feats)
    
    #loss = -1 * (torch.nn.functional.cosine_similarity(text_feats, norm_img_feats)).detach()
    #token_losses.append(loss)
    

In [None]:
# Stack the per-token feature tensors and drop all singleton dimensions.
feats = torch.stack(all_text_feats).squeeze()

In [None]:
#torch.save(feats.cpu(), "feats.pt")
#feats = torch.load("feats.pt")

In [None]:
# Cosine similarity of every swept token's features to the reference image.
# NOTE(review): this rebinds `losses` (a list of per-step losses in the
# training cell) to a tensor of similarities, where higher is better.
losses = torch.nn.functional.cosine_similarity(feats, norm_img_feats)

In [None]:
#decode_text(model, feats[torch.argmax(losses).unsqueeze(0)])

In [None]:
# Show the k vocabulary entries whose features are most similar to the image,
# together with their similarity scores. The top-k selection is done once and
# both of its fields are reused.
k = 200
top_scores, top_ids = torch.topk(losses, k)
best_tokens = top_ids.cpu().numpy()
decoded_text = tokenizer.decode(best_tokens)
print(top_scores.cpu().numpy())
print(decoded_text)

In [None]:
# Debug: shape of the optimised embeddings.
opt.shape

In [None]:
# Debug: shape of the most recently assembled prompt embedding.
embedding.shape

In [None]:
# Debug: shape of the token-embedding matrix.
embed_weights.shape

In [None]:
# Probe: pass a single vocabulary embedding (one row with a leading dimension
# added) as `embedding`.
# NOTE(review): other calls pass a full prompt-length embedding -- confirm
# that encode_text accepts this shape.
model.encode_text(tokens, embedding=embed_weights[0].unsqueeze(0))

In [None]:
# Load the GPT-Neo 1.3B causal language model and its tokenizer.
# NOTE(review): this rebinds `model` and `tokenizer`, which earlier cells use
# for the CLIP model and its SimpleTokenizer; re-running those cells after
# this one will pick up the GPT-Neo objects instead.
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

In [None]:
# Tokenize the prompt into a PyTorch tensor of input ids.
input_ids = tokenizer("Let me tell you something about this girl named Emilia Wiehe.", return_tensors="pt").input_ids

In [None]:
# Sample a continuation of the prompt (temperature 0.9, max_length 100).
gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100,)

In [None]:
# Decode the first (only) generated sequence back to text.
gen_text = tokenizer.batch_decode(gen_tokens)[0]

In [None]:
# Display the generated text.
gen_text

In [None]:
# Display the generated text again (duplicate of the previous cell).
gen_text

In [None]:
# Experiment: generate with no input ids, seeding only `past_key_values`.
# NOTE(review): a plain (5, 2048) tensor is passed here, whereas a later cell
# iterates a model-produced `past_key_values` as nested sequences of tensors
# -- confirm that this call is accepted at all.
gen_tokens = model.generate(past_key_values=torch.zeros(5, 2048), do_sample=True, temperature=0.9, max_length=100, use_cache=True)

In [None]:
# Decode the first generated sequence back to text.
gen_text = tokenizer.batch_decode(gen_tokens)[0]

In [None]:
# Display the generated text.
gen_text

In [None]:
# Display the raw generated token ids.
gen_tokens

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Load the small GPT-2 language model and its tokenizer.
# NOTE(review): `model` and `tokenizer` are rebound once more; every cell
# below operates on GPT-2, not on the models loaded earlier.
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')

In [None]:
# Greedy decoding with a key/value cache: after the first step only the newly
# chosen token is fed to the model, and `past` carries the cached state.
generated = tokenizer.encode("My feet")
context = torch.tensor([generated])
past = None

# Pure inference: without no_grad every step would extend an autograd graph
# that the cached keys/values keep alive for the whole loop.
with torch.no_grad():
    for i in tqdm(range(100)):
        output = model(context, past_key_values=past, use_cache=True)
        logits = output["logits"]
        past = output["past_key_values"]
        # Most likely next token given the last position's logits.
        token = torch.argmax(logits[..., -1, :])

        generated.append(token.item())
        context = token.unsqueeze(0)

sequence = tokenizer.decode(generated)

print(sequence)

In [None]:
# Mark every tensor inside the cached `past` as requiring gradients
# (presumably in preparation for optimising the cache directly -- confirm).
for vec in past:
    for ten in vec:
        ten.requires_grad_(True)

In [None]:
# Experiment: sample a continuation from the cached `past` alone.
# NOTE(review): no input ids are passed -- confirm that generate accepts this.
gen_tokens = model.generate(past_key_values=past, do_sample=True, temperature=0.9, max_length=100, use_cache=True)

In [None]:
# Decode the first generated sequence and display it.
gen_text = tokenizer.batch_decode(gen_tokens)[0]
gen_text

In [None]:
# NOTE(review): `out` is not defined anywhere in the visible notebook, so this
# cell fails on a fresh kernel (it presumably held a model output -- confirm).
len(out)

In [None]:
# Debug: shape of the first element of `out` (`out` is not defined in the
# visible notebook).
out[0].shape

In [None]:
# Debug: length of the second element of `out` (`out` is not defined in the
# visible notebook).
len(out[1])

In [None]:
# Debug: length of the first entry inside `out[1]` (`out` is not defined in
# the visible notebook).
len(out[1][0])

In [None]:
# Debug: shape of the innermost tensor `out[1][0][0]` (`out` is not defined in
# the visible notebook).
out[1][0][0].shape

In [None]:
# Unfinished sketch of an optimisation loop over `opt_tokens`.
# NOTE(review): as written this cannot run top-to-bottom from the visible
# notebook:
#   - `opt_steps` and `opt_tokens` are never defined;
#   - `opt` is the embedding tensor created earlier, not an optimizer, yet
#     `.step()` / `.zero_grad()` are called on it;
#   - `model` was last rebound to GPT-2, while `encode_text` is only used on
#     the CLIP model elsewhere in this notebook;
#   - the loss is not reduced before `backward()`, which requires a scalar
#     unless an explicit gradient is supplied.
for step in range(opt_steps):
    text_feats = model.encode_text(opt_tokens.cuda())
    
    loss = -1 * torch.nn.functional.cosine_similarity(text_feats, img_feats)
    
    loss.backward()
    opt.step()
    opt.zero_grad()
    
    print(opt_tokens)
