In [89]:
import os, sys
from uuid import uuid4
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
sys.path.append('..')
from tqdm import tqdm
import argparse
from omegaconf import OmegaConf
from types import MethodType

import torch
import lightning as L

from src.pal_rm_b_t2i.lightningmodule import LearnerWrapLightning
from src.utils import create_filename


In [90]:
# init pms (through argparser)
conf_learner = '../config/prefLearner_config/t2i-b-dim1024-k1-general-embeddings-mlp2.yaml'
conf_ds = '../config/ds_config/pickapicv2_cliph_original_ds.yaml'
cache = '../necessary_cache/pickapicv2-dataset-tables'
k = 1
name = 'pickapicv2'

# load configs
conf_learner = OmegaConf.load(conf_learner)
conf_ds = OmegaConf.load(conf_ds)
conf_wandb = OmegaConf.create({'project': 'pal_t2i', 'name': 'pickapicv2', 'save_dir': './logs_wandb/'})
user_ids = torch.load(os.path.join(cache,"user_ids.pt"))
# modify configs
conf_learner.preference_learner_params.k = k
conf_learner.max_epochs_new_pair = 50
conf_ds.batch_size = 16384
# init save path
random_uuid4 = str(uuid4())[:8]  # This takes only the first 8 characters.
filename = create_filename(name, str(k), str(conf_ds.batch_size))
filename += f"-{random_uuid4}"
folder_path = os.path.join("./figs", filename)
print("ckpt store path:", filename)

learner = LearnerWrapLightning(**conf_learner)
learner.preference_learner.user_learner.init_weight(user_ids)

ckpt store path: pickapicv2-1-16384-1b9b3397
the upper bound is 4


In [91]:
state_dict = torch.load(
    '../ckpts/t2i-ckpts/seen-pickapicv2-cliph-modelB-angle-logistic-k1_trial0-1-16384-ca559ce4-epoch=05.ckpt',
    map_location='cpu'
)['state_dict']

import re
# def _tmp_state_dict_converter(state_dict):
#     new_state_dict = {}
#     for k, v in state_dict.items():
#         if 'projector_f.mlp.' in k:
#             new_state_dict[k.replace('projector_f.m.', 'projector_f.mlp')] = v
#         elif 'item_learner.projector.mlp' in k:
#             new_state_dict[k.replace('item_learner.projector.m', 'item_learner.projector.mlp')] = v
#         elif re.match(r'user_learner\.projectors\.\d+\.mlp', k):
#             new_state_dict[k.replace('user_learner.projectors.', 'user_learner.projectors.').replace('.m', '.mlp')] = v
#         else:
#             new_state_dict[k] = v
#     return new_state_dict

def _tmp_state_dict_converter(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if '.m.' in k:
            new_state_dict[k.replace('.m.', '.mlp.')] = v
        else:
            new_state_dict[k] = v
    return new_state_dict

In [92]:
# learner.load_state_dict(state_dict)
learner.load_state_dict(_tmp_state_dict_converter(state_dict))

<All keys matched successfully>

In [101]:
def wrap_mix_forward_b(pal: torch.nn.Module, mix_weight: torch.tensor):
    # override the forward function of the pal to be a standard reward model
    # after the modification, the pal will be able to output:
    # model a: the reward diff given a prompt
    # model b: the reward logits given a prompt
    def mix_forward_userlearner(self, latent_prompts):
        prompt_logits = self.infer_gk(latent_prompts)   # (bs, k, dims)
        bs = prompt_logits.size(0)
        w = self.softmax(mix_weight.repeat(bs, 1))  # (bs, k)
        w = w.unsqueeze(1)  # (bs, 1, k)
        y_hat = torch.bmm(w, prompt_logits) # (bs, 1, dims)
        y_hat = y_hat.squeeze(1)    # (bs, dims)
        self.tmp_store_user_ideal_points = y_hat
        return y_hat
    def mix_forward_itemlearner(self, items):
        x = self.connector_x(items)
        if self.learner_type in ['dist','dist_normalization','angle','dist_logistic','angle_hinge']:  # |f(x)-f(u)|_2 or <f(x), f(u)>
            return self.projector(x)
        elif self.learner_type == 'norm':   # # |f(x-u)|_2 (do the self.projector() part in the PreferenceLearner)
            return x
        else:
            raise ValueError(f"Unknown learner_type={self.learner_type}.")
    def mix_map_preflearner(self, x):
        # ({
        # 'input_ids': prompt_input_ids,\
        # 'attention_mask': prompt_attention_mask,
        # },\
        # {
        # 'input_ids': eval_input_ids,\
        # 'attention_mask': eval_attention_mask,\
        # })
        prompt, items = x
        items_prime = self.item_learner(items)
        prompt_prime = self.user_learner(prompt)
        return items_prime, prompt_prime
    def mix_forward_preflearner(self, x):
        items, prompt = x
        items_prime, prompt_prime = self.map_to_pref_embedding_space((prompt, items))
        print(f"{items_prime[0]=}")
        print(f"{prompt_prime[0]=}")
        print(f"{items_prime.shape=}")
        print(f"{prompt_prime.shape=}")
        if self.pref_learner_type == 'angle':
            items_prime = items_prime / torch.norm(items_prime, dim=-1, keepdim=True)
            prompt_prime = prompt_prime / torch.norm(prompt_prime, dim=-1, keepdim=True)
            prompt_prime = prompt_prime.unsqueeze(1)
            logit_scale = self.logit_scale.exp()
            clamped_logit_scale = torch.clamp(logit_scale, max=100)
            # print(clamped_logit_scale)
            # print((prompt_prime * items_prime).sum(dim=-1))
            sim_score = (prompt_prime * items_prime).sum(dim=-1) * clamped_logit_scale   # (bs, max_token_length)
            return sim_score
        else:
            raise NotImplementedError
    def forwad(self, batch):
        y_hat = self.preference_learner(batch)
        return y_hat
    pal.preference_learner.user_learner.forward = MethodType(mix_forward_userlearner, pal.preference_learner.user_learner)
    pal.preference_learner.item_learner.forward = MethodType(mix_forward_itemlearner, pal.preference_learner.item_learner)
    pal.preference_learner.map_to_pref_embedding_space = MethodType(mix_map_preflearner, pal.preference_learner)
    pal.preference_learner.forward = MethodType(mix_forward_preflearner, pal.preference_learner)
    pal.forward = MethodType(forwad, pal)
    return pal


In [102]:
pal = wrap_mix_forward_b(learner, torch.tensor([0.5]))

In [103]:
tmp_prompt_latent = torch.randn(1, 1024)
tmp_item_latent = torch.randn(1, 1024)
pal([tmp_prompt_latent, tmp_item_latent])

items_prime[0]=tensor([-0.8441,  0.8796,  0.6785,  ...,  1.2911,  1.8600,  1.2384],
       grad_fn=<SelectBackward0>)
prompt_prime[0]=tensor([ 1.5011, -1.0693, -1.5910,  ..., -1.7996,  0.4741, -1.2623],
       grad_fn=<SelectBackward0>)
items_prime.shape=torch.Size([1, 1024])
prompt_prime.shape=torch.Size([1, 1024])


tensor([[-2.8308]], grad_fn=<MulBackward0>)