In [None]:
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
sys.path.append('..')
from omegaconf import OmegaConf
from types import MethodType

import torch
import lightning as L

from src.pal_rm_b_t2i.lightningmodule import LearnerWrapLightning


In [12]:
# init pms (through argparser)
dname = 'pickapicv2'
k = 1
conf_learner = f'../config/prefLearner_config/t2i-b-dim1024-k{k}-general-embeddings-mlp2.yaml'
conf_ds = f'../config/ds_config/{dname}_cliph_original_ds.yaml'
cache = f'../necessary_cache/{dname}-dataset-tables'
ckpt_path = f'../../pickapic_checkpts/'

# load configs
conf_learner = OmegaConf.load(conf_learner)
conf_ds = OmegaConf.load(conf_ds)

conf_learner.preference_learner_params.k = k
conf_learner.max_epochs_new_pair = 50
conf_ds.batch_size = 16384

user_ids = torch.load(os.path.join(cache,"user_ids.pt"), weights_only=True)

learner = LearnerWrapLightning(**conf_learner)
learner.preference_learner.user_learner.init_weight(user_ids)

the upper bound is 4


In [None]:
state_dict = torch.load(
    f'{ckpt_path}/palb-{dname}-cliph-k{k}/seen-pickapicv2-cliph-modelB-angle-logistic-k{k}_trial0-1-16384-ca559ce4-epoch=05.ckpt',
    map_location='cpu',
    weights_only=True
)['state_dict']

def _tmp_state_dict_converter(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if '.m.' in k:
            new_state_dict[k.replace('.m.', '.mlp.')] = v
        else:
            new_state_dict[k] = v
    return new_state_dict

learner.load_state_dict(_tmp_state_dict_converter(state_dict))

<All keys matched successfully>

In [36]:
model = torch.load(f'{ckpt_path}/palb-{dname}-cliph-k{k}/seen-pickapicv2-cliph-modelB-angle-logistic-k{k}_trial0-1-16384-ca559ce4-epoch=05.ckpt',
                   map_location='cpu',
                   weights_only=True)

print(model.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers'])


In [16]:
def wrap_mix_forward_b(pal: torch.nn.Module, mix_weight: torch.tensor):
    # override the forward pass of the original model to be a standard reward model
    # after the modification, the PAL reward model will output:
    # PAL-A: the reward difference given a prompt
    # PAL-B: the reward logits given a prompt

    def mix_forward_userlearner(self, latent_prompts):
        prompt_logits = self.infer_gk(latent_prompts)   # (bs, k, dims)
        bs = prompt_logits.size(0)
        w = self.softmax(mix_weight.repeat(bs, 1))  # (bs, k)
        w = w.unsqueeze(1)  # (bs, 1, k)
        y_hat = torch.bmm(w, prompt_logits) # (bs, 1, dims)
        y_hat = y_hat.squeeze(1)    # (bs, dims)
        self.tmp_store_user_ideal_points = y_hat
        return y_hat
    
    def mix_forward_itemlearner(self, items):
        x = self.connector_x(items)
        if self.learner_type in ['dist','dist_normalization','angle','dist_logistic','angle_hinge']:  # |f(x)-f(u)|_2 or <f(x), f(u)>
            return self.projector(x)
        elif self.learner_type == 'norm':   # # |f(x-u)|_2 (do the self.projector() part in the PreferenceLearner)
            return x
        else:
            raise ValueError(f"Unknown learner_type={self.learner_type}.")
        
    def mix_map_preflearner(self, x):
        prompt, items = x
        items_prime = self.item_learner(items)
        prompt_prime = self.user_learner(prompt)
        return items_prime, prompt_prime
    
    def mix_forward_preflearner(self, x):
        items, prompt = x
        items_prime, prompt_prime = self.map_to_pref_embedding_space((prompt, items))
        print(f"{items_prime[0]=}")
        print(f"{prompt_prime[0]=}")
        print(f"{items_prime.shape=}")
        print(f"{prompt_prime.shape=}")
        if self.pref_learner_type == 'angle':
            items_prime = items_prime / torch.norm(items_prime, dim=-1, keepdim=True)
            prompt_prime = prompt_prime / torch.norm(prompt_prime, dim=-1, keepdim=True)
            prompt_prime = prompt_prime.unsqueeze(1)
            logit_scale = self.logit_scale.exp()
            clamped_logit_scale = torch.clamp(logit_scale, max=100)
            # print(clamped_logit_scale)
            # print((prompt_prime * items_prime).sum(dim=-1))
            sim_score = (prompt_prime * items_prime).sum(dim=-1) * clamped_logit_scale   # (bs, max_token_length)
            return sim_score
        else:
            raise NotImplementedError
        
    def forward(self, batch):
        y_hat = self.preference_learner(batch)
        return y_hat
    
    pal.preference_learner.user_learner.forward = MethodType(mix_forward_userlearner, pal.preference_learner.user_learner)
    pal.preference_learner.item_learner.forward = MethodType(mix_forward_itemlearner, pal.preference_learner.item_learner)
    pal.preference_learner.map_to_pref_embedding_space = MethodType(mix_map_preflearner, pal.preference_learner)
    pal.preference_learner.forward = MethodType(mix_forward_preflearner, pal.preference_learner)
    pal.forward = MethodType(forward, pal)
    
    return pal

In [39]:
pal = wrap_mix_forward_b(learner, torch.tensor([0.5]))

In [53]:
# Example images from HPS v2 interactive dataset tool: https://tgxs002.github.io/hpd_test_vis/
img_0_url = 'https://tgxs002.github.io/hpd_test_vis/static/assets/03079.jpg'
img_1_url = 'https://tgxs002.github.io/hpd_test_vis/static/assets/03071.jpg'
prompt = 'A person holding a very small slice on pizza between their fingers.'

from IPython.display import Image, display
import requests

img_0 = requests.get(img_0_url).content
img_1 = requests.get(img_1_url).content

image0 = Image(img_0)
image1 = Image(img_1)

In [56]:
tmp_prompt_latent = torch.randn(1, 1024)
tmp_item_latent = torch.randn(1, 1024)
pal([tmp_prompt_latent, tmp_item_latent])

items_prime[0]=tensor([-1.7085,  1.2021,  1.6133,  ...,  0.0625,  0.1212, -0.2356],
       grad_fn=<SelectBackward0>)
prompt_prime[0]=tensor([ 1.1747,  0.1154, -1.5194,  ...,  1.6674, -1.8732, -0.7937],
       grad_fn=<SelectBackward0>)
items_prime.shape=torch.Size([1, 1024])
prompt_prime.shape=torch.Size([1, 1024])


tensor([[-2.6417]], grad_fn=<MulBackward0>)