In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules (e.g. the local `frozen` package) are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [7]:
import os
import os.path as osp
import urllib
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoFeatureExtractor
from frozen.experiment import Experiment
from frozen.data import COCODataset, CC3MDataset, IMAGE_TOKEN, SPECIAL_TOKEN_DICT

In [3]:
# Sanity check via the shell: list the checkpoint file to confirm it exists on disk.
!ls /teamspace/studios/this_studio/logs/run-sample-CCM/1.0.0/checkpoints/epoch=0-step=2.ckpt

'/teamspace/studios/this_studio/logs/run-sample-CCM/1.0.0/checkpoints/epoch=0-step=2.ckpt'


In [4]:
# Checkpoint to restore. The path is already absolute, so it is used exactly
# as written (no expanduser/abspath normalisation is applied).
ckpt_path = "/teamspace/studios/this_studio/logs/run-sample-CCM/1.0.0/checkpoints/epoch=0-step=2.ckpt"

# Device that the model and all inference inputs are moved to.
device = 'cuda:0'

# Restore the experiment from the checkpoint, then cast it to half precision
# and move it onto `device`.
experiment = Experiment.load_from_checkpoint(ckpt_path)
experiment = experiment.half().to(device)

Some weights of the model checkpoint at microsoft/resnet-50 were not used when initializing ResNetModel: ['classifier.1.weight', 'classifier.1.bias']
- This IS expected if you are initializing ResNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ResNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Read the settings needed for preprocessing from the restored model's config:
# the two pretrained identifiers and the `num_image_tokens` value.
model_cfg = experiment.model.config
image_encoder = model_cfg['image_encoder']
text_encoder = model_cfg['text_encoder']
num_image_tokens = model_cfg['num_image_tokens']

# Build the image feature extractor and the text tokenizer from those identifiers.
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder)
tokenizer = AutoTokenizer.from_pretrained(text_encoder)

# Register the project's special tokens only if the image placeholder token
# is not already known to the tokenizer.
if IMAGE_TOKEN not in tokenizer.all_special_tokens:
    tokenizer.add_special_tokens(SPECIAL_TOKEN_DICT)

In [12]:
# Dataset root, given relative to the notebook's working directory.
# NOTE(review): a leftover comment in this cell pointed at
# /teamspace/studios/this_studio/Frozen-Interleaved/frozen/datasets/conceptualcaptions/script/DownloadConceptualCaptions/sample_validation.tsv
# -- presumably the sample TSV this dataset reads; confirm against CC3MDataset.
cc3m_root = 'Frozen-Interleaved/frozen/datasets/conceptualcaptions/'

# Build the dataset with the same image transform, tokenizer and
# image-token count that were set up for the restored model.
data = CC3MDataset(
    path=cc3m_root,
    tokenizer=tokenizer,
    image_transform=feature_extractor,
    num_image_tokens=num_image_tokens,
)



In [17]:
# Display the first dataset sample as this cell's output.
# (Alternative probe kept from earlier debugging: data.image_token_id)
first_sample = data[0]
first_sample

{'pixel_values': tensor([[[[-1.9467, -1.9638, -1.9809,  ..., -2.0152, -2.0152, -2.0152],
           [-1.9638, -1.9467, -1.9467,  ..., -2.0152, -1.9980, -1.9809],
           [-1.9467, -1.9295, -1.9295,  ..., -1.9980, -1.9809, -1.9809],
           ...,
           [-1.8782, -1.8953, -1.9124,  ..., -2.0665, -2.0665, -2.0494],
           [-1.9467, -1.9467, -1.9638,  ..., -2.0665, -2.0665, -2.0494],
           [-1.9809, -1.9638, -1.9809,  ..., -2.0665, -2.0665, -2.0665]],
 
          [[-1.8606, -1.8782, -1.8957,  ..., -1.9307, -1.9307, -1.9307],
           [-1.8782, -1.8606, -1.8606,  ..., -1.9307, -1.9132, -1.8957],
           [-1.8606, -1.8431, -1.8431,  ..., -1.9132, -1.8957, -1.8957],
           ...,
           [-1.7906, -1.8081, -1.8256,  ..., -1.9832, -1.9832, -1.9657],
           [-1.8606, -1.8606, -1.8782,  ..., -1.9832, -1.9832, -1.9657],
           [-1.8957, -1.8782, -1.8957,  ..., -1.9832, -1.9832, -1.9832]],
 
          [[-1.6302, -1.6476, -1.6650,  ..., -1.6999, -1.6999, -1.6999

In [55]:
# Decode a single raw vocabulary id to see which text it stands for.
probe_token_id = 50118
print(tokenizer.decode(probe_token_id))





In [69]:
# Few-shot prompt: a captioned example image followed by a query image.
# Each image is represented in the text by `num_image_tokens` placeholder
# tokens joined with single spaces.
image_slots = ' '.join([IMAGE_TOKEN] * num_image_tokens)

item0 = data[0]  # example image, paired with the hard-coded caption below
item = data[1]   # query image
prompt = (
    image_slots + 'Image of life in photography with cinematic tone'
    + image_slots + 'What is this image?'
)
print(prompt)

# Tokenise the prompt and prepend a batch dimension of size one.
tokens = data.tokenizer(prompt)
input_ids = torch.tensor(tokens['input_ids'])[None]

# Debug view: every input id next to the text it decodes to.
decoded_pairs = []
for token_id in input_ids[0]:
    decoded_pairs.append((token_id, data.tokenizer.decode([token_id])))
print(decoded_pairs)

attention_mask = torch.tensor(tokens['attention_mask'])[None]

# Boolean mask that is True wherever the prompt holds an image placeholder.
image_token_id = data.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
print(image_token_id)
image_token_mask = input_ids.eq(image_token_id)
print(image_token_mask)

# Generation inputs. The two images are joined along dim 0 in the same order
# as their placeholders appear in the prompt, cast to fp16 and moved to `device`.
kwargs = dict(
    pixel_values=torch.cat(
        [sample['pixel_values'].half().to(device) for sample in (item0, item)],
        dim=0,
    ),
    input_ids=input_ids.to(device),
    attention_mask=attention_mask.to(device),
    image_token_mask=image_token_mask.long().to(device),
    num_beams=5,
)

# Inference only: evaluation mode, no gradient tracking.
experiment.model.eval()
with torch.no_grad():
    output = experiment.model.generate(**kwargs)

print(output)

# Turn the generated id sequences back into text, dropping special tokens.
decoded = tokenizer.batch_decode(
    output.sequences,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=True,
)

# To also show the query image (assumes the sample has a 'raw_image' entry -- TODO confirm):
# display(item['raw_image'])
print(decoded[0])

<image> <image>Image of life in photography with cinematic tone<image> <image>What is this image?
[(tensor(2), '</s>'), (tensor(50265), '<image>'), (tensor(50265), '<image>'), (tensor(8532), 'Image'), (tensor(9), ' of'), (tensor(301), ' life'), (tensor(11), ' in'), (tensor(11075), ' photography'), (tensor(19), ' with'), (tensor(25306), ' cinematic'), (tensor(6328), ' tone'), (tensor(50265), '<image>'), (tensor(50265), '<image>'), (tensor(2264), 'What'), (tensor(16), ' is'), (tensor(42), ' this'), (tensor(2274), ' image'), (tensor(116), '?')]
50265
tensor([[False,  True,  True, False, False, False, False, False, False, False,
         False,  True,  True, False, False, False, False, False]])
torch.Size([2, 3, 224, 224])
tensor([[-0.0971, -0.0236,  0.0179,  ...,  0.1426, -0.0822,  0.0392],
        [-0.0383,  0.0132,  0.0381,  ...,  0.0201,  0.0230,  0.0150]],
       device='cuda:0', dtype=torch.float16) torch.Size([2, 1024])
BeamSearchDecoderOnlyOutput(sequences=tensor([[    2, 50265, 50

In [62]:
# Shape of the pixel tensor stored in `item`.
item['pixel_values'].size()

torch.Size([1, 3, 224, 224])

In [63]:
# Check the shape obtained by joining the same fp16 pixel tensor with itself along dim 0.
torch.cat([item['pixel_values'].half().to(device)] * 2, dim=0).shape

torch.Size([2, 3, 224, 224])