# Basic Tools

### Score Calculator

In [None]:
import os
# GPU selection via environment variables; these assignments appear before
# torch is imported further down in this cell.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # enumerate GPUs in PCI bus order
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # expose only GPU 0 to this process
from torch.utils.data import Dataset, DataLoader
import clip
from glob import glob
from tqdm import tqdm
import torch
from PIL import Image

# Device used for the CLIP model and for every batch scored below.
device = "cuda"
# `model` is the CLIP network; `preprocess` is the image transform returned
# alongside it and is passed to CLIPDataset as `transform` further down.
model, preprocess = clip.load("ViT-L/14@336px", device=device)

class CLIPDataset(Dataset):
    """Generated images paired with their prompts and one style reference image.

    Every ``*.png`` file directly under ``img_dir`` is one sample. A sample's
    prompt is its file name up to the first ``.``. A single style reference
    image, ``<data_dir>/<style_ref>.jpg``, is loaded once in the constructor
    and returned with every sample.

    Args:
        img_dir: Directory containing the generated ``*.png`` images.
        transform: Optional callable applied to every loaded image, including
            the style reference.
        tokenizer: Optional callable applied to every prompt string; its
            result is ``.squeeze()``-d before being returned.
        style_ref: Name (without extension) of the style reference image.
            When ``None`` it is taken from the third ``/``-separated component
            of the first image path, e.g. ``image_03_05`` for
            ``./results/image_03_05/eval_samples/...``.
        data_dir: Directory holding the style reference ``.jpg`` files.

    Raises:
        FileNotFoundError: If ``img_dir`` contains no ``*.png`` files.
    """

    def __init__(self, img_dir, transform=None, tokenizer=None,
                 style_ref=None, data_dir="./data/"):
        super().__init__()
        self.transform = transform
        self.tokenizer = tokenizer
        self.img_list = glob(os.path.join(img_dir, "*.png"))
        if not self.img_list:
            # Fail early with a readable message instead of an IndexError on
            # img_list[0] below.
            raise FileNotFoundError(
                "No '*.png' images found in {!r}".format(img_dir)
            )
        if style_ref is None:
            # Default layout: "./results/<style_ref>/...", so the style name
            # is the third path component.
            style_ref = self.img_list[0].split("/")[2]
        style_ref_path = os.path.join(data_dir, style_ref + ".jpg")
        self.ref_img = self._load_img(style_ref_path)

    def __len__(self):
        """Return the number of ``*.png`` images found in ``img_dir``."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Return a dict with keys ``image``, ``text`` and ``style`` for one sample."""
        img_path = self.img_list[index]
        # The prompt is encoded in the file name (everything before the first ".").
        prompt = os.path.basename(img_path).split(".")[0]
        img = self._load_img(img_path)
        text = self._load_txt(prompt)
        return dict(image=img, text=text, style=self.ref_img)

    def _load_img(self, path):
        """Open the image at ``path`` and apply ``transform`` when one is set."""
        img = Image.open(path)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def _load_txt(self, data):
        """Tokenize ``data`` (and squeeze the result) when a tokenizer is set."""
        if self.tokenizer is not None:
            data = self.tokenizer(data).squeeze()
        return data
    
@torch.no_grad()
def calculate_clip_score(dataloader, model, device):
    """Average CLIP image-text and image-style scores over a dataloader.

    Each batch is read through the keys ``'image'``, ``'text'`` and
    ``'style'``. Images and style references go through
    ``model.encode_image`` and texts through ``model.encode_text``; all
    features are normalized with ``clip_normalize``. Per batch, the
    element-wise product of image features with text (resp. style) features
    is summed and multiplied by ``model.logit_scale.exp()``.

    Args:
        dataloader: Iterable of batches as described above.
        model: Object exposing ``logit_scale``, ``encode_image`` and
            ``encode_text``.
        device: Device every batch tensor is moved to before encoding.

    Returns:
        ``(text_score, style_score)``: the accumulated scores divided by the
        total number of images seen.
    """
    scale = model.logit_scale.exp()
    text_total = 0.
    style_total = 0.
    seen = 0.

    for batch in tqdm(dataloader):
        images = batch['image']

        # Encode in the order image -> text -> style, then normalize each.
        image_emb = model.encode_image(images.to(device))
        text_emb = model.encode_text(batch['text'].to(device))
        style_emb = model.encode_image(batch['style'].to(device))

        image_emb = clip_normalize(image_emb)
        text_emb = clip_normalize(text_emb)
        style_emb = clip_normalize(style_emb)

        # Sum of per-sample dot products for the whole batch.
        text_total += scale * (image_emb * text_emb).sum()
        style_total += scale * (image_emb * style_emb).sum()
        seen += images.shape[0]

    return text_total / seen, style_total / seen

def clip_normalize(features):
    """Divide each row of ``features`` by its norm along ``dim=1``.

    The norm is cast to ``torch.float32`` before the division.
    """
    row_norm = features.norm(dim=1, keepdim=True)
    return features / row_norm.to(torch.float32)

# Evaluation driver: score the sample directories of one style with CLIP.
style_name = "image_03_05"
test_dir = "./results/{}/eval_samples/*".format(style_name)
test_img_list = glob(test_dir)
# Sort by the last path component with its first 7 characters dropped.
test_img_list = sorted(test_img_list, key=lambda x: x.split('/')[-1][7:])
# Skip the first 10 sorted entries and keep only paths containing "286000".
# NOTE(review): "286000" presumably identifies a training step / checkpoint - confirm.
final_img_dir_list = [x for x in test_img_list[10:] if "286000" in x]

for idx, img_dir in enumerate(final_img_dir_list):
    # Single-directory alternative: img_dir = "./results/oriental_budda/it_data/"
    batch_size = 30
    dataset = CLIPDataset(img_dir = img_dir, transform=preprocess, tokenizer=clip.tokenize)
    dataloader = DataLoader(dataset, batch_size, pin_memory=True)

    text_score, style_score = calculate_clip_score(dataloader, model, device)
    # Convert the returned scores to plain Python floats for printing.
    text_score = text_score.cpu().item()
    style_score = style_score.cpu().item()
    # NOTE(review): the reported mistake count is idx % 6, which assumes the
    # filtered directories come in the matching order - confirm.
    print('With {} mistakes out of 10, CLIP Text Score: {}, CLIP Style Score: {}'.format(idx%6, text_score, style_score))
    # break

### Image Generation

In [None]:
import os
import gradio as gr
import open_clip
import torch
import taming.models.vqgan
import ml_collections
import einops
import random
# Model
from libs.muse import MUSE
import utils
import numpy as np
from glob import glob
from configs.custom_IT import get_config
from PIL import Image
import json

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def cfg_nnet(x, context, scale=None, lambdaA=None, lambdaB=None):
    """Guided prediction built from several passes through the module-level ``nnet_ema``.

    Without ``lambdaA`` the result is plain classifier-free guidance::

        cond + scale * (cond - uncond)

    With ``lambdaA`` the adapter branch is used instead::

        cond_adapter + lambdaA * (cond_adapter - cond) + lambdaB * (cond - uncond)

    Each formula is evaluated only on its own path, so ``scale`` may stay
    ``None`` when ``lambdaA``/``lambdaB`` are given, and the adapter forward
    pass is skipped when ``lambdaA`` is ``None``.

    Args:
        x: Network input; ``x.size(0)`` is used as the batch size.
        context: Conditioning passed to ``nnet_ema`` as ``context``.
        scale: Guidance weight for the plain path (required when ``lambdaA``
            is ``None``).
        lambdaA: Weight of the adapter-vs-conditional difference; selects the
            adapter path when not ``None``.
        lambdaB: Weight of the conditional-vs-unconditional difference on the
            adapter path.

    Returns:
        The guided network output.
    """
    use_adapter = lambdaA is not None

    _cond = nnet_ema(x, context=context)
    if use_adapter:
        _cond_w_adapter = nnet_ema(x, context=context, use_adapter=True)

    # Unconditional pass: the stored empty context repeated for every batch item.
    _empty_context = torch.tensor(empty_context, device=device)
    _empty_context = einops.repeat(_empty_context, 'L D -> B L D', B=x.size(0))
    _uncond = nnet_ema(x, context=_empty_context)

    if use_adapter:
        return _cond_w_adapter + lambdaA*(_cond_w_adapter - _cond) + lambdaB*(_cond - _uncond)
    return _cond + scale * (_cond - _uncond)

def decode(_batch):
    """Decode ``_batch`` with the module-level ``vq_model.decode_code``."""
    decoded = vq_model.decode_code(_batch)
    return decoded


# Model setup for image generation: config, prompt encoder, VQ model, MUSE
# and the EMA network restored from a checkpoint.
config = get_config()
# Generation device; falls back to CPU when CUDA is unavailable.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
# NOTE(review): the prompt encoder device is hard-coded with no CPU fallback,
# and the score-calculator cell sets CUDA_VISIBLE_DEVICES="0" - if both cells
# run in one process, "cuda:2"/"cuda:3" may not be visible. Confirm.
prompt_device = "cuda:2"

# Text encoder used to embed prompts, and its matching tokenizer.
prompt_model,_,_ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k')
prompt_model.to(prompt_device)
prompt_model.eval()
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')

# VQGAN used by decode(); put in eval mode with gradients disabled.
vq_model = taming.models.vqgan.get_model('vq-f16-jax.yaml')
vq_model.eval()
vq_model.requires_grad_(False)
vq_model.to(device)

## config
# The codebook size is taken from the VQ model; other options come from config.muse.
muse = MUSE(codebook_size=vq_model.n_embed, device=device, **config.muse)

# Restore the training state from config.resume_root and keep its nnet_ema
# network, in eval mode with gradients disabled.
train_state = utils.initialize_train_state(config, device)
train_state.resume(ckpt_root=config.resume_root)
nnet_ema = train_state.nnet_ema
nnet_ema.eval()
nnet_ema.requires_grad_(False)
nnet_ema.to(device)

# Stored context that cfg_nnet() uses for its unconditional pass.
empty_context = np.load("assets/contexts/empty_context.npy")



In [None]:
# Generation cell: load a style adapter, encode "<content> <style>" as the
# prompt and sample `num_samples` images with MUSE.
num_samples = 1
lambdaA = 2.0
lambdaB = 5.0
# seed == -1 draws a random seed below; set a fixed value (e.g. 1234) to
# reproduce a specific result.
seed = -1
sample_steps = 36
style_ref = "oriental_egret"
mistake_count = 0

# Adapter checkpoint of the chosen style; the checkpoint directory carries a
# "_<mistake_count>" suffix when mistake_count is non-zero.
adapter_path = "./results/{}/ckpts_II{}/286000.ckpt/adapter.pth".format(style_ref, "_{}".format(mistake_count) if mistake_count != 0 else "")

config.sample.lambdaA = lambdaA
config.sample.lambdaB = lambdaB
config.sample.sample_steps = sample_steps

if adapter_path is not None:
    # map_location restores the checkpoint tensors onto the generation device
    # instead of whichever device they were saved from.
    nnet_ema.adapter.load_state_dict(torch.load(adapter_path, map_location=device))
else:
    # Without an adapter, disable the adapter guidance weights.
    config.sample.lambdaA=None
    config.sample.lambdaB=None
print("load adapter Done!")

# The style description is the last element of the first value in the
# style's JSON file. Note that `style_ref` is rebound from the style name to
# the parsed JSON content here.
style_dir = "./data/{}.json".format(style_ref)
with open(style_dir, "r") as f:
    style_ref = json.load(f)
style = list(style_ref.values())[0][-1]
content = "A man"

# Encode prompt
prompt = content + " " + style
text_tokens = tokenizer(prompt).to(prompt_device)
# Inference only: prompt_model's parameters still require grad, so encode
# without building an autograd graph.
with torch.no_grad():
    text_embedding = prompt_model.encode_text(text_tokens)
text_embedding = text_embedding.repeat(num_samples, 1, 1) # B 77 1280
print(text_embedding.shape)

print(f"lambdaA: {lambdaA}, lambdaB: {lambdaB}, sample_steps: {sample_steps}")
if seed==-1:
    seed = random.randint(0,65535)
config.seed = seed
print(f"seed: {seed}")
set_seed(config.seed)
res = muse.generate(config,num_samples,cfg_nnet,decode,is_eval=True,context=text_embedding)
print(res.shape)
# Scale to 0..255, move channels last and convert to uint8 arrays for PIL.
res = (res*255+0.5).clamp_(0,255).permute(0,2,3,1).to('cpu',torch.uint8).numpy()
im = [res[i] for i in range(num_samples)]

print(prompt)
# Last expression of the cell: shows the first generated image in the notebook.
Image.fromarray(im[0])