In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path as osp
import urllib
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoFeatureExtractor
from frozen.experiment import Experiment
from frozen.data import COCODataset, CC3MDataset, IMAGE_TOKEN, SPECIAL_TOKEN_DICT

In [3]:
!ls /teamspace/studios/this_studio/logs/run-sample-CCM/1.0.0/checkpoints/epoch=0-step=2.ckpt

'/teamspace/studios/this_studio/logs/run-sample-CCM/1.0.0/checkpoints/epoch=0-step=2.ckpt'


In [3]:
ckpt_path = "/teamspace/studios/this_studio/Frozen-Interleaved/frozen/wandb/run-20241212_080250-2c5rec8q/files/checkpoints/epoch=0-step=250.ckpt"
# ckpt_path = osp.abspath(osp.expanduser(ckpt_path))
device = 'cuda:0'
experiment = Experiment.load_from_checkpoint(ckpt_path).half().to(device)

Some weights of the model checkpoint at microsoft/resnet-50 were not used when initializing ResNetModel: ['classifier.1.weight', 'classifier.1.bias']
- This IS expected if you are initializing ResNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ResNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
image_encoder = experiment.model.config['image_encoder']
text_encoder = experiment.model.config['text_encoder']
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder)
tokenizer = AutoTokenizer.from_pretrained(text_encoder)
num_image_tokens = experiment.model.config['num_image_tokens']
if not IMAGE_TOKEN in tokenizer.all_special_tokens:
    tokenizer.add_special_tokens(SPECIAL_TOKEN_DICT)

In [6]:
data = DCCDataset(
    path='/teamspace/studios/this_studio/DialogCC/DialogCC',
    split='train'
    image_transform=feature_extractor,
    tokenizer=tokenizer,
    num_image_tokens=num_image_tokens,
)
# )/teamspace/studios/this_studio/Frozen-Interleaved/frozen/datasets/conceptualcaptions/script/DownloadConceptualCaptions/sample_validation.tsv



In [7]:
# data.image_token_id
data[0]

{'pixel_values': tensor([[[[-1.9467, -1.9638, -1.9809,  ..., -2.0152, -2.0152, -2.0152],
           [-1.9638, -1.9467, -1.9467,  ..., -2.0152, -1.9980, -1.9809],
           [-1.9467, -1.9295, -1.9295,  ..., -1.9980, -1.9809, -1.9809],
           ...,
           [-1.8782, -1.8953, -1.9124,  ..., -2.0665, -2.0665, -2.0494],
           [-1.9467, -1.9467, -1.9638,  ..., -2.0665, -2.0665, -2.0494],
           [-1.9809, -1.9638, -1.9809,  ..., -2.0665, -2.0665, -2.0665]],
 
          [[-1.8606, -1.8782, -1.8957,  ..., -1.9307, -1.9307, -1.9307],
           [-1.8782, -1.8606, -1.8606,  ..., -1.9307, -1.9132, -1.8957],
           [-1.8606, -1.8431, -1.8431,  ..., -1.9132, -1.8957, -1.8957],
           ...,
           [-1.7906, -1.8081, -1.8256,  ..., -1.9832, -1.9832, -1.9657],
           [-1.8606, -1.8606, -1.8782,  ..., -1.9832, -1.9832, -1.9657],
           [-1.8957, -1.8782, -1.8957,  ..., -1.9832, -1.9832, -1.9832]],
 
          [[-1.6302, -1.6476, -1.6650,  ..., -1.6999, -1.6999, -1.6999

In [8]:
print(tokenizer.decode(50118))





In [9]:
# Few shot
item0 = data[0]
prompt = ' '.join([IMAGE_TOKEN for i in range(num_image_tokens)]) + 'Image of life in photography with cinematic tone'
item = data[1] # image
prompt += ' '.join([IMAGE_TOKEN for i in range(num_image_tokens)]) + 'What is this image?'
print(prompt)
tokens = data.tokenizer(prompt)
input_ids = torch.tensor(tokens['input_ids']).unsqueeze(0)
decoded_pairs = [(token_id, data.tokenizer.decode([token_id])) for token_id in input_ids[0]]
print(decoded_pairs)
attention_mask = torch.tensor(tokens['attention_mask']).unsqueeze(0)
# attention_mask = torch.tensor([1, 1, 1]).unsqueeze(0)
# inputs = data.tokenizer(prompt, return_tensors='pt')
image_token_id = data.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
print(image_token_id)
image_token_mask = input_ids == image_token_id
print(image_token_mask)

kwargs = {
    'pixel_values': torch.concat([item0['pixel_values'].half().to(device), item['pixel_values'].half().to(device)], axis=0),
    'input_ids': input_ids.to(device),
    'attention_mask': attention_mask.to(device),
    'image_token_mask': image_token_mask.long().to(device),
    'num_beams': 5,
}
with torch.no_grad():
    experiment.model.eval()
    output = experiment.model.generate(**kwargs)

print(output)

decoded = tokenizer.batch_decode(
    output.sequences,  # filter out None values
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)

# display(item['raw_image'])
print(decoded[0])

<image> <image>Image of life in photography with cinematic tone<image> <image>What is this image?
[(tensor(2), '</s>'), (tensor(50265), '<image>'), (tensor(50265), '<image>'), (tensor(8532), 'Image'), (tensor(9), ' of'), (tensor(301), ' life'), (tensor(11), ' in'), (tensor(11075), ' photography'), (tensor(19), ' with'), (tensor(25306), ' cinematic'), (tensor(6328), ' tone'), (tensor(50265), '<image>'), (tensor(50265), '<image>'), (tensor(2264), 'What'), (tensor(16), ' is'), (tensor(42), ' this'), (tensor(2274), ' image'), (tensor(116), '?')]
50265
tensor([[False,  True,  True, False, False, False, False, False, False, False,
         False,  True,  True, False, False, False, False, False]])


BeamSearchDecoderOnlyOutput(sequences=tensor([[    2, 50265, 50265,  8532,     9,   301,    11, 11075,    19, 25306,
          6328, 50265, 50265,  2264,    16,    42,  2274,   116, 50118, 50118]],
       device='cuda:0'), sequences_scores=None, scores=None, beam_indices=None, attentions=None, hidden_states=None)
Image of life in photography with cinematic toneWhat is this image?




In [11]:
item

{'pixel_values': tensor([[[[ 0.9303,  0.8104,  0.7591,  ...,  1.0331,  1.0331,  1.0331],
           [ 0.9817,  0.9988,  0.9646,  ...,  1.0673,  1.0502,  1.0331],
           [ 0.9303,  0.9474,  1.0159,  ...,  1.0673,  1.0331,  1.0159],
           ...,
           [ 1.2214,  1.2728,  1.2214,  ...,  1.3413,  1.2214,  1.2043],
           [ 1.2728,  1.2728,  1.3070,  ...,  1.2899,  1.3927,  1.3755],
           [ 1.3070,  1.2899,  1.2728,  ...,  1.2214,  1.3242,  1.3927]],
 
          [[ 0.7479,  0.6254,  0.5553,  ...,  1.0980,  1.0980,  1.0980],
           [ 0.8179,  0.8354,  0.8004,  ...,  1.1331,  1.1155,  1.0980],
           [ 0.7829,  0.8004,  0.8704,  ...,  1.1331,  1.0980,  1.0805],
           ...,
           [ 0.2752,  0.3277,  0.2752,  ...,  0.6078,  0.5028,  0.4853],
           [ 0.3277,  0.3277,  0.3627,  ...,  0.5203,  0.6429,  0.6429],
           [ 0.3803,  0.3452,  0.3277,  ...,  0.4503,  0.5553,  0.6254]],
 
          [[ 0.8448,  0.7228,  0.6182,  ...,  1.2631,  1.2631,  1.2631

In [63]:
torch.concat([item['pixel_values'].half().to(device), item['pixel_values'].half().to(device)], axis=0).shape

torch.Size([2, 3, 224, 224])