In [None]:
from training.main import *

In [None]:
# Local input locations (placeholders) — edit these before running the notebook.
CHECKPOINT_PATH, SUGARCREPE_PATH, COCO_PATH = (
    "/path/to/checkpoint.pt",   # checkpoint file read with pt_load when loading weights
    "/path/to/sugar-crepe",     # SugarCrepe root; the JSON annotations are under data/
    "/path/to/coco/val2017/",   # image directory that annotation filenames are joined onto
)

In [None]:
# One model is built and loaded for every entry in this list.
pretrained_paths = [CHECKPOINT_PATH]
# Passed as `override_config` when the model is created.
model_config = dict(
    L=64,
    V=64,
    reduce_depth=1,
    sparo_type="cont:const",
    share_kv=True,
)
model_type = "ViT-B-16-SPARO"
# Side length of the image patch grid; the visualization reshapes the non-CLS image
# attention into a num_edge_patches x num_edge_patches map.
# Presumably 224-px input / 16-px patches -> 14 — TODO confirm against the model config.
num_edge_patches = 14

In [None]:
# Numerics and reproducibility.
# - TF32 matmuls trade a little precision for speed.
# - cuDNN autotuning (`benchmark=True`) may select different algorithms from run to run,
#   which defeats `deterministic=True` and the fixed seeds below, so it is disabled.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

random_seed(0)
torch.manual_seed(0)
np.random.seed(0)

# Build one model per checkpoint path. No pretrained tag is passed (""), so checkpoint
# weights must be loaded separately. `preprocess_val` from the last iteration is the
# eval transform used on images later on.
models = []
for _ in pretrained_paths:
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        model_type,
        "",
        precision="amp_bfloat16",
        device="cuda",
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=None,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        aug_cfg={},
        output_dict=True,
        override_config=model_config,
    )
    model.eval()
    models.append(model)

In [None]:
# Load each checkpoint's "state_dict" into the model at the same index in `models`.
_MODULE_PREFIX = 'module.'
for i, pretrained_path in enumerate(pretrained_paths):
    checkpoint = pt_load(pretrained_path, map_location='cpu')
    sd = checkpoint["state_dict"]
    # Keys may carry a "module." prefix (as saved from a DistributedDataParallel-wrapped
    # model). Strip exactly that prefix, and only from keys that have it, so the keys
    # match the bare model. This also avoids StopIteration on an empty state dict.
    if any(k.startswith(_MODULE_PREFIX) for k in sd):
        sd = {
            (k[len(_MODULE_PREFIX):] if k.startswith(_MODULE_PREFIX) else k): v
            for k, v in sd.items()
        }
    models[i].load_state_dict(sd)
    # The weights are now copied into the model; drop the CPU-side copy.
    del checkpoint, sd

In [None]:
import json
from PIL import Image

# SugarCrepe split name -> JSON annotation file under the SugarCrepe data/ directory.
data_dict = {
    name: f'{SUGARCREPE_PATH}/data/{name}.json'
    for name in (
        'add_obj',
        'add_att',
        'replace_obj',
        'replace_att',
        'replace_rel',
        'swap_obj',
        'swap_att',
    )
}
# Split name -> parsed JSON contents, in the same order as data_dict.
dataset = {}
for c, data_path in data_dict.items():
    # `with` closes each file deterministically; a bare open() leaked the handle
    # until garbage collection.
    with open(data_path, 'r', encoding='utf-8') as json_file:
        dataset[c] = json.load(json_file)


class TextRetrievalDataset(torch.utils.data.Dataset):
    """Flattened dataset over several annotation splits.

    Each entry is a dict with 'filename' (joined onto `image_root`), 'caption' and
    'negative_caption'. Items are returned as (PIL image, caption, negative_caption).

    Args:
        image_root: directory that each entry's 'filename' is joined onto.
        data_dict: mapping of split name -> mapping whose values are the entry dicts.
            Splits are concatenated in the mapping's iteration order.
    """

    def __init__(self, image_root, data_dict):
        self.image_root = image_root

        # One list of entries per split (split order and entry order preserved).
        self.datasets = [list(split.values()) for split in data_dict.values()]
        self.lengths = [len(split) for split in self.datasets]

        # Flatten in a single pass; `sum(lists, [])` re-copied the accumulator for
        # every split (quadratic in the total size).
        self.dataset = [entry for split in self.datasets for entry in split]
        self.length = len(self.dataset)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data = self.dataset[idx]
        image_path = os.path.join(self.image_root, data['filename'])
        image = Image.open(image_path)
        # Decode eagerly inside the worker. For single-frame images opened by path,
        # Pillow closes the file after load(), so no lazy file handle is kept open.
        image.load()
        return image, data['caption'], data['negative_caption']


def collate_fn(data):
    """Transpose a batch of (image, caption, negative_caption) samples into three tuples.

    Nothing is stacked or converted; every item is passed through untouched.
    Every sample is expected to have exactly three fields.

    Returns:
        (images, captions, neg_captions), each a tuple with one entry per sample.
    """
    columns = zip(*data)
    images, captions, neg_captions = columns
    return (images, captions, neg_captions)


tokenizer = get_tokenizer(model_type)
dset = TextRetrievalDataset(COCO_PATH, dataset)
# Shuffled batches of raw (image, caption, negative caption) tuples. Images are
# preprocessed and captions tokenized later, on the main process.
loader = torch.utils.data.DataLoader(
    dset,
    batch_size=512,
    shuffle=True,
    drop_last=False,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
)
data_iterator = iter(loader)

In [None]:
import torch.nn.functional as F
from open_clip.tokenizer import decode
from PIL import Image
from torch.distributions.categorical import Categorical
from matplotlib import pyplot as plt

# Minimum per-slot image/text cosine similarity for a slot to count as aligned.
ALIGNMENT_THRESHOLD = 0.75

# Draw one batch. NOTE: re-running this cell pulls the NEXT batch from `data_iterator`,
# so every re-run analyses different samples.
raw_images, raw_texts, _ = next(data_iterator)
images = torch.stack([preprocess_val(raw_image) for raw_image in raw_images]).cuda()
texts = tokenizer(raw_texts).cuda()

# Encode with every model, keeping the per-slot embeddings ("sparos") and attention maps.
image_sparos, image_attns = [], []
text_sparos, text_attns = [], []
with torch.no_grad():
    for model in models:
        _image_sparos, _image_attns = model.encode_image(images, normalize=True, return_sparo=True, return_attn=True)
        _text_sparos, _text_attns = model.encode_text(texts, normalize=True, return_sparo=True, return_attn=True)
        _image_attns = _image_attns.squeeze(-2)
        _text_attns = _text_attns.squeeze(-2)
        image_sparos.append(_image_sparos)
        image_attns.append(_image_attns)
        text_sparos.append(_text_sparos)
        text_attns.append(_text_attns)
# torch.stack puts the model axis first; transpose so samples lead and models come second.
image_sparos = torch.stack(image_sparos).transpose(0, 1)
image_attns = torch.stack(image_attns).transpose(0, 1)
text_sparos = torch.stack(text_sparos).transpose(0, 1)
text_attns = torch.stack(text_attns).transpose(0, 1)

# Per-slot cosine similarity between image and text slot embeddings of the LAST model,
# computed for all samples in one call (one row per sample). The previous per-sample
# loop computed this for every model and then kept only the last.
all_alignments = (
    F.normalize(image_sparos[:, -1], dim=-1) * F.normalize(text_sparos[:, -1], dim=-1)
).sum(-1)

# A slot is "bad" for a sample when any of the following holds:
#  * its text attention peaks at the argmax-token position (presumably EOT, assuming the
#    tokenizer gives EOT the largest id — TODO confirm),
#  * its text attention peaks at position 0 (presumably the start-of-text token),
#  * its image/text alignment is below ALIGNMENT_THRESHOLD.
bad_positions = texts.argmax(dim=-1, keepdim=True).unsqueeze(-2) == text_attns.argmax(dim=-1)
bad_positions |= text_attns.argmax(dim=-1) == 0
bad_positions |= all_alignments.unsqueeze(-2) < ALIGNMENT_THRESHOLD

# Drop samples in which every slot is bad.
bad_samples = bad_positions.squeeze(-2).all(dim=-1)
keep = ~bad_samples
assert keep.any(), (
    "every sample in this batch was filtered out; re-run to draw another batch "
    "or lower ALIGNMENT_THRESHOLD"
)
# One device->host transfer instead of a GPU sync per element when filtering lists.
keep_list = keep.tolist()
raw_images = [raw_image for raw_image, k in zip(raw_images, keep_list) if k]
raw_texts = [raw_text for raw_text, k in zip(raw_texts, keep_list) if k]
images = images[keep]
texts = texts[keep]
image_sparos = image_sparos[keep]
image_attns = image_attns[keep]
text_sparos = text_sparos[keep]
text_attns = text_attns[keep]
bad_positions = bad_positions[keep]
all_alignments = all_alignments[keep]

# Sanity check: for every sample, no slot's text attention is nonzero on more positions
# than there are tokens up to and including the argmax-token position.
assert not ((text_attns > 0).sum(dim=-1).squeeze(-2).max(dim=-1).values > texts.argmax(dim=-1)+1).any()

In [None]:
# Score each slot by its mean alignment over samples, counting it as 0 for samples where
# it was flagged bad. Keep the highest-scoring slots. k is clamped so topk cannot fail
# when fewer slots exist.
slot_scores = torch.where(~bad_positions.squeeze(-2), all_alignments, 0).mean(dim=0)
top_slots = slot_scores.topk(min(32, slot_scores.numel())).indices
# Samples whose mean alignment over those top slots is highest. k is clamped because the
# filtered batch may hold fewer than 16 samples.
sample_scores = all_alignments[:, top_slots].mean(dim=-1)
top_samples = sample_scores.topk(min(16, sample_scores.numel())).indices

In [None]:
# Display knobs: attention is raised to SHARPNESS. The image map is then rescaled so its
# peak equals MAGNITUDE, which becomes the overlay alpha (clipped to [0, 1]).
SHARPNESS = 1.0
MAGNITUDE = 0.75

# Solid red RGB canvas with one pixel per image patch; patch attention becomes its alpha.
red = np.concatenate((np.ones((num_edge_patches,num_edge_patches,1)),
                      np.zeros((num_edge_patches,num_edge_patches,1)),
                      np.zeros((num_edge_patches,num_edge_patches,1))), axis=-1)


def _decode_tokens(text):
    """Decode token ids one at a time, stopping after the first "<end_of_text>"."""
    decoded_tokens = []
    for token in text:
        decoded_token = decode(token.unsqueeze(0))
        decoded_tokens.append(decoded_token)
        if decoded_token == "<end_of_text>":
            break
    return decoded_tokens


def _attention_overlay(raw_image, patch_weights):
    """Return `raw_image` in grayscale with `patch_weights` painted on as red alpha."""
    mask = np.concatenate((red, patch_weights.float().cpu().numpy()[..., None]), axis=-1)
    # Clip so a MAGNITUDE > 1 cannot wrap around when cast to uint8.
    mask = np.clip(mask, 0.0, 1.0)
    background = raw_image.convert("L").convert("RGBA")
    # An (H, W, 4) uint8 array is inferred as RGBA; the `mode` argument is deprecated.
    mask = Image.fromarray(np.uint8(mask * 255)).resize(background.size, resample=Image.Resampling.BICUBIC)  # BICUBIC, NEAREST
    background.paste(mask, (0, 0), mask)
    return background


def _star_marker(weight):
    """Fixed-width star prefix for a normalized token weight (more stars = higher)."""
    for threshold, marker in ((1.0, "****"), (0.75, "*** "), (0.5, "**  "), (0.25, "*   ")):
        if weight >= threshold:
            return marker
    return "    "


def _print_token_weights(decoded_tokens, tx_att):
    """Print each decoded token with its weight, stopping at "<end_of_text>"."""
    for decoded_token, weight in zip(decoded_tokens, tx_att.tolist()):
        # Tokens not ending in a space (or ">") continue into the next token.
        if not decoded_token.endswith((" ", ">")):
            decoded_token = decoded_token + "-"
        if decoded_token == "<end_of_text>":
            print(f"     {decoded_token} {weight:.2f}")
            break
        print(f"{_star_marker(weight)} {decoded_token} {weight:.2f}")


for ind in top_slots.tolist():
    print(f" === SLOT {ind} === ")
    slot_alignments = all_alignments[:, ind]
    # Clamp k: the filtered batch may hold fewer than 16 samples.
    top_samples = slot_alignments.topk(min(16, slot_alignments.numel())).indices
    for sample_idx in top_samples.tolist():
        slot_alignment = all_alignments[sample_idx, ind].item()
        if slot_alignment < 0.0:
            continue
        raw_image = raw_images[sample_idx]
        decoded_tokens = _decode_tokens(texts[sample_idx])
        # NOTE(review): attention comes from model 0 while `all_alignments` comes from
        # the last model; the two coincide when only one checkpoint is loaded.
        image_attn = image_attns[sample_idx][0]
        text_attn = text_attns[sample_idx][0]
        print(f"Slot image-text alignment: {slot_alignment:.3f}")
        im_att = image_attn[ind] ** SHARPNESS
        im_att = im_att / (im_att.max() / MAGNITUDE)
        tx_att = text_attn[ind] ** SHARPNESS
        tx_att = tx_att / tx_att.max()
        # Index 0 is the CLS position; the rest form the square patch grid.
        cls_weight = im_att[0].item()
        patch_weights = im_att[1:].view(num_edge_patches, num_edge_patches)
        display(_attention_overlay(raw_image, patch_weights))
        print(f"CLS weight: {cls_weight:.2f}\n")
        _print_token_weights(decoded_tokens, tx_att)

    print("\n==========================\n")