In [1]:
import json

import torch
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

from instruct_tri2tri.tsr.system import InstructTri2Tri, TSR

class InstructTri2TriDataset(Dataset):
    """Dataset of (image, instruct_image, instruct) triples listed in a JSON file.

    The JSON file must contain a list of records; each record is read through
    its ``'image'``, ``'instruct_image'`` and ``'instruct'`` keys. The two
    image entries are paths relative to ``data_root``.

    Args:
        data_path: Path to the JSON index file.
        max_samples: Keep only the first ``max_samples`` records. ``None``
            keeps every record. Defaults to 200.
        data_root: Directory prefix joined (with ``/``) in front of each image
            path. Defaults to ``'data'``.
    """

    def __init__(self,
                 data_path,
                 max_samples=200,
                 data_root='data'):
        super().__init__()
        # Use a context manager so the file handle is closed as soon as the
        # index has been parsed instead of lingering until garbage collection.
        with open(data_path, encoding='utf-8') as index_file:
            records = json.load(index_file)
        self.data_root = data_root
        # Slicing with None keeps the full list, so max_samples=None disables the cap.
        self.datas = records[:max_samples]

    def __len__(self):
        """Return the number of records kept after the ``max_samples`` cap."""
        return len(self.datas)

    def __getitem__(self, index):
        """Return ``(image, instruct_image, instruct)`` for record ``index``.

        Both images are opened with ``PIL.Image.open``; ``instruct`` is
        returned exactly as stored in the record.
        """
        data = self.datas[index]
        image_name = data['image']
        instruct_image_name = data['instruct_image']
        image = Image.open(f'{self.data_root}/{image_name}')
        instruct_image = Image.open(f'{self.data_root}/{instruct_image_name}')
        instruct = data['instruct']
        return image, instruct_image, instruct


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Where the pretrained config/checkpoint live and which JSON index to load.
pretrained_dir = 'instruct_tri2tri/tsr/instruct_tri2tri_config'
dataset_index = 'data/objaverse/cap3d_automated_objaverse_highquality_instruct_550k.json'

# Build the model from the config and checkpoint files found in pretrained_dir.
model = InstructTri2Tri.from_pretrained(pretrained_dir, config_name="config.yaml", weight_name="model.ckpt")

# Build the dataset from the JSON index.
dataset = InstructTri2TriDataset(dataset_index)

In [3]:
# Turn gradients off for the whole model, then explicitly again for each
# listed submodule, so that only the instruction converter is re-enabled below.
model.requires_grad_(False)
for frozen_name in (
    'image_tokenizer',
    'tokenizer',
    'text_encoder',
    'backbone',
    'post_processor',
    'decoder',
    'renderer',
):
    getattr(model, frozen_name).requires_grad_(False)

# The instruction converter is the one part left with gradients enabled.
model.instruction_converter.requires_grad_(True)

# Move the model to GPU index 3; as the last expression, the returned module
# is what the cell displays.
model.cuda(3)

InstructTri2Tri(
  (image_tokenizer): DINOSingleImageTokenizer(
    (model): ViTModel(
      (embeddings): ViTEmbeddings(
        (patch_embeddings): ViTPatchEmbeddings(
          (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        )
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): ViTEncoder(
        (layer): ModuleList(
          (0-11): 12 x ViTLayer(
            (attention): ViTAttention(
              (attention): ViTSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
 

In [4]:
# Device string handed to the model calls in the cells that follow.
device = 'cuda:3'

# Element-wise squared-error criterion with PyTorch's default settings.
loss_fn = torch.nn.MSELoss()

# The optimizer only receives the instruction converter's parameters.
optimizer = torch.optim.AdamW(
    params=model.instruction_converter.parameters(),
    lr=2e-5,
)

In [5]:
from PIL import Image

# Single-sample check on the first record: build the target from the
# instructed image via model.forward_tsr, build the prediction from the source
# image plus the instruction via model(...), and compare them with loss_fn.
image, instruct_image, instruct = dataset[0]
with torch.no_grad():
    target_tokens = model.forward_tsr([instruct_image], device, True)
    # Release cached GPU memory from the first pass before running the second.
    torch.cuda.empty_cache()
    pred_tokens = model([image], [instruct], device)
    # Flatten on the last axis instead of a hard-coded axis index: shape[-1]
    # exists for any tensor rank, and with MSELoss's default mean reduction the
    # flattening width does not change the loss value.
    # NOTE(review): assumes pred_tokens and target_tokens hold the same number
    # of elements; confirm against the model code.
    dim = target_tokens.shape[-1]
    loss = loss_fn(pred_tokens.reshape(-1, dim), target_tokens.reshape(-1, dim))

In [7]:
# Same single-sample check as for record 0, run on the second record.
# NOTE(review): the recorded run of this cell ended with
# "TypeError: forward() takes 4 positional arguments but 5 were given"; which
# call raised it cannot be determined from the notebook alone. Check the
# signatures of model.forward_tsr / model.forward.
image, instruct_image, instruct = dataset[1]
with torch.no_grad():
    target_tokens = model.forward_tsr([instruct_image], device, True)
    # Release cached GPU memory from the first pass before running the second.
    torch.cuda.empty_cache()
    pred_tokens = model([image], [instruct], device)
    # Flatten on the last axis instead of a hard-coded axis index: shape[-1]
    # exists for any tensor rank, and with MSELoss's default mean reduction the
    # flattening width does not change the loss value.
    dim = target_tokens.shape[-1]
    loss = loss_fn(pred_tokens.reshape(-1, dim), target_tokens.reshape(-1, dim))

TypeError: forward() takes 4 positional arguments but 5 were given

In [8]:
from PIL import Image

# Forward-only sweep over the dataset with a progress bar. Both passes run
# under no_grad and their results are discarded on the next iteration.
progress = tqdm(dataset)
for sample in progress:
    image, instruct_image, instruct = sample
    with torch.no_grad():
        target_tokens = model.forward_tsr([instruct_image], device, True)
        pred_tokens = model([image], [instruct], device)
    # Training step is intentionally left disabled. Note that pred_tokens is
    # produced inside no_grad above, so it would have to be computed outside
    # that context before loss.backward() could be enabled.
    # dim = target_tokens.shape[2]
    # loss = loss_fn(pred_tokens.reshape(-1, dim), target_tokens.reshape(-1, dim))
    # loss.backward()
    # optimizer.step()
    # optimizer.zero_grad()

  6%|▌         | 12/200 [00:07<01:53,  1.65it/s]


KeyboardInterrupt: 

In [None]:
from transformers import CLIPTextModelWithProjection, AutoTokenizer

# Both the tokenizer and the text model are loaded from the same local checkpoint directory.
clip_checkpoint = 'ckpts/clip-vit-large-patch14'

tokenizer = AutoTokenizer.from_pretrained(clip_checkpoint)
# .cuda() without an index places the model on the current default CUDA device.
clip = CLIPTextModelWithProjection.from_pretrained(clip_checkpoint).cuda()

In [None]:
# Two throwaway strings used only to probe the text encoder's outputs.
probe_texts = ['asdasdasdasd', 'sadasdasd']
inputs = tokenizer(probe_texts, padding=True, return_tensors='pt')

# Move every tokenizer output tensor to the GPU in place before the forward pass.
for key in list(inputs.keys()):
    inputs[key] = inputs[key].cuda()

out = clip(**inputs)

In [None]:
# Display the names of the fields present in the CLIP text model output.
out.keys()

odict_keys(['text_embeds', 'last_hidden_state'])

In [None]:
# Display the shape of the `text_embeds` output field for the probe strings.
out.text_embeds.shape

torch.Size([2, 768])