In [1]:
import re
import numpy as np
import sys
from peft import LoraConfig, get_peft_model,PeftModel, PeftConfig
sys.path.append("/home/tsuchida/KLab_MultiModalModel/tsuchida_workdir/..")
from PIL import Image
from typing import List
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

##### 動作確認

In [3]:
import torch
from torchvision import transforms
from PIL import Image,PngImagePlugin

# Decompressed Data Too Largeになることを防ぐ
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)

class DatasetLoader(torch.utils.data.Dataset):
    def __init__(self, resize=256):
        self.images, self.tgt_texts, self.src_texts = [], [], []
        self.src_transforms = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.tgt_transforms = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        image, src_text, tgt_text = self.images[idx], self.src_texts[idx], self.tgt_texts[idx]
        image = Image.open(image).convert('RGB')
        src_image = self.src_transforms(image)
        tgt_image = self.src_transforms(image)

        return src_image, tgt_image, src_text, tgt_text
    
    def __len__(self):
        return len(self.images)


In [4]:
import json
import os
from PIL import Image
import torch
from torchvision.transforms import ToTensor

class CC3MDatasetLoader(DatasetLoader):
    def __init__(self,data_dir="/data01/cc3m",phase="train",imagesize=(256,256)):
        super().__init__()
        
        with open(os.path.join(data_dir,f"{phase}.tsv"),"r") as f:
            items = f.read()

        items = items.split("\n")
        items = [item.split("\t") for item in items]
        num = int(len(items)/2)
        # データセット半分のみ使用
        items = items[1:num]

        self.tgt_texts = [item[1] for item in items]
        self.src_texts = ["What does the image describe?"]*len(items)
        self.images = [os.path.join(data_dir,phase,item[0]) for item in items]

In [5]:
train_dataset = CC3MDatasetLoader(data_dir="/data01/cc3m", phase="train")
val_dataset = CC3MDatasetLoader(data_dir="/data01/cc3m", phase="val")

In [6]:
from transformers import AutoTokenizer
from torchvision import transforms
# src_tokenizer = AutoTokenizer.from_pretrained(args.language_model_name, model_max_length=256)
src_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", model_max_length=256, use_fast=True)

# tgt_tokenizer = AutoTokenizer.from_pretrained(args.language_model_name, model_max_length=256, use_fast=True, extra_ids=0, additional_special_tokens =[f"<extra_id_{i}>" for i in range(100)] + [f"<loc_{i}>" for i in range(1000)] + [f"<img_{i}>" for i in range(args.image_vocab_size)])
resize=256
src_transforms = transforms.Compose([
    transforms.Resize((resize, resize)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225]),
])
tgt_transforms = transforms.Compose([
    transforms.Resize((resize, resize)),
    transforms.ToTensor(),
])

In [7]:
import numpy as np
def custom_to_pil(x):
  x = x.detach().cpu().numpy()
  mean = np.array([0.485, 0.456, 0.406])
  std = np.array([0.229, 0.224, 0.225])
  x = (((x.transpose(1, 2, 0) * std) + mean) * 255.).astype(np.uint8)

  x = Image.fromarray(x)
  if not x.mode == "RGB":
    x = x.convert("RGB")
  return x

##### 学習したものを可視化

In [8]:
import argparse
args = argparse.Namespace(
    # Model setting
    image_model_name="microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft",
    image_model_train=False,
    language_model_name="google/flan-t5-small",
    transformer_model_name="google/flan-t5-base",
    ffn=True,
    phase = "train",
    transformer_d_model=768,
    transformer_d_ff=3072,
    # transformer_d_model=1024,
    # transformer_d_ff=4096,
    transformer_d_kv=64,
    transformer_num_heads=12,
    transformer_num_layers=2,
    transformer_num_decoder_layers=12,
    image_vocab_size=16384,
    loc_vocab_size=1600,
    vae_ckpt_path="checkpoints/vqgan.pt",
    max_source_length=256,
    max_target_length=256,
    # Train setting
    pretrain="train", 
    # Dir setting
    root_dir="/data01/",
    result_dir="results/",
    loss = "CrossEntropy",
    loc_learn = "lora",
    float_type = 'bfloat16',
    lora_r = 4,
    lora_alpha = 4,
    lora_dropout = 0.1,
    lora_bias = "none"
)



In [9]:

from models.model import MyModel
model = MyModel(args).to(device)
print(model)
# path = "/home/tsuchida/KLab_MultiModalModel/pth/caption/epoch_50.pth"
# model.load(result_name=path)
tgt_tokenizer = AutoTokenizer.from_pretrained(args.language_model_name, model_max_length=args.max_target_length, use_fast=True, extra_ids=0, additional_special_tokens =[f"<loc_{i}>" for i in range(args.loc_vocab_size)])
model.transformer.resize_token_embeddings(len(tgt_tokenizer))

path = "/home/tsuchida/KLab_MultiModalModel/results/loc/lora/bf16/qkv/1e-5/openimage/1e-5lambda/enc2_dec12/epoch_40.pth"
# path = "/home/tsuchida/KLab_MultiModalModel/results/loc/bf16/scratch/openimage/enc2_dec12/epoch_40.pth"
path = "/home/tsuchida/KLab_MultiModalModel/results/loc/bf16/scratch/openimage/enc2_dec12/best.pth"
path = "/home/tsuchida/KLab_MultiModalModel/results/loc/bf16/lora/openimage/enc2_dec12/epoch_30.pth"
path = "/home/tsuchida/KLab_MultiModalModel/results/loc/bf16/lora/1e-4/openimage/enc2_dec12/bestLora"
path = "/home/tsuchida/KLab_MultiModalModel/results/loc/lora/bf16/scratch/1e-4/openimage/enc2_dec12/best.pth"
path = "/home/tsuchida/KLab_MultiModalModel/results/loc/bf16/scratch/1e-4/openimage/enc0_dec12/bestLora"
path = "/home/tsuchida/KLab_MultiModalModel/results/loc/bf16/scratch/1e-4/openimage/enc0_dec12/epoch_10.pth"
path = "/home/tsuchida/KLab_MultiModalModel/results/loc/sample/openimage/1e-5lambda/enc0_dec12/bestLora"
path = "/home/tsuchida/KLab_MultiModalModel/results/lora/vg/visual_genome_refexp/enc2_dec24/epoch_4"
path = "/home/tsuchida/KLab_MultiModalModel/results/scratch/visual_genome_refexp/enc2_dec24/epoch_2.pth"

# train
path = "/home/tsuchida/KLab_MultiModalModel/results/scratch/base/visual_genome_refexp/enc2_dec12/epoch_50.pth"
# lora
path = "/home/tsuchida/KLab_MultiModalModel/results/lora/base/vg/visual_genome_refexp/enc2_dec12/epoch_50"

path = "/home/tsuchida/KLab_MultiModalModel/results/1201/lora/2e-4/openimage_loc/enc2_dec12/epoch_50"
# path = "/home/tsuchida/KLab_MultiModalModel/results/1207/openimage_loc/enc2_dec12/epoch_2"


if (args.loc_learn == "lora"):
    model = PeftModel.from_pretrained(model, path)
else:
    model.load(result_name=path)
# model = set_peft_model_state_dict(model, path)
model = PeftModel.from_pretrained(model, path)
print(model)
model.eval()

T5Config {
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "max_length": 256,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 2,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "task_specific_params": {
    "summarization": {
      "early_stopping": true,
      "length_penalty": 2.0,
      "max_length": 200,
      "min_length": 30,
      "no_repeat_ngram_size": 3,
      "num_beams": 4,
      "prefix": "summarize: "
    },
    "translation_en_to_de": {
      "early_stopping": true,
      "max_length": 300,
    

PeftModel(
  (base_model): LoraModel(
    (model): PeftModel(
      (base_model): LoraModel(
        (model): MyModel(
          (language_model): T5EncoderModel(
            (shared): Embedding(32128, 512)
            (encoder): T5Stack(
              (embed_tokens): Embedding(32128, 512)
              (block): ModuleList(
                (0): T5Block(
                  (layer): ModuleList(
                    (0): T5LayerSelfAttention(
                      (SelfAttention): T5Attention(
                        (q): Linear(in_features=512, out_features=384, bias=False)
                        (k): Linear(in_features=512, out_features=384, bias=False)
                        (v): Linear(in_features=512, out_features=384, bias=False)
                        (o): Linear(in_features=384, out_features=512, bias=False)
                        (relative_attention_bias): Embedding(32, 6)
                      )
                      (layer_norm): T5LayerNorm()
                      (dropout):

In [10]:
def calculate_iou(boxA: List[float], boxB: List[float]) -> float:
    # Compute the intersection area
    interArea = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0])) * max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    if interArea == 0:
        print("interArea: ", interArea)
        return 0.0
    # Compute the area of both the prediction and ground-truth rectangles
    boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    # Compute the intersection over union
    iou = interArea / float(boxAArea + boxBArea - interArea)
    print("iou: ", iou)
    return iou

In [11]:
def show_result(dataset, idx=10):
    src_image, tgt_image, src_text, tgt_text = dataset[idx]
    with torch.no_grad():
        src_image = src_image.unsqueeze(0).to(device)
        tgt_text2 = tgt_text
        print('src_text:', src_text)
        print('tgt_text:', tgt_text)
        src_text = src_tokenizer(src_text, padding="longest", max_length=args.max_source_length, return_tensors='pt')['input_ids'].to(device) # ['pt', 'tf', 'np', 'jax']
        tgt_text = tgt_tokenizer(tgt_text, padding="longest", max_length=args.max_target_length, return_tensors='pt')['input_ids'].to(device) # ['pt', 'tf', 'np', 'jax']
        # print(src_text, tgt_text)

        # display(custom_to_pil(src_image[0]))
        src_attention_masks = torch.ones_like(src_text, device=device, dtype=torch.bool)
        src_attention_masks[src_text == 0] = 0
        tgt_attention_masks = torch.ones_like(tgt_text, device=device, dtype=torch.bool)
        tgt_attention_masks[tgt_text == 0] = 0
        preds= model(src_image, src_text, src_attention_masks, tgt_text, tgt_attention_masks,return_loss=False)

        # print(loss)
        preds = tgt_tokenizer.batch_decode(preds[:,1:-1])
        print('pred:', preds[0])
        print(type(preds))
        img = custom_to_pil(src_image[0])
        


In [12]:
show_result(train_dataset, idx=0)

src_text: What does the image describe?
tgt_text: old cars by building function
pred: eignen choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir c

In [13]:
show_result(train_dataset, idx=0)
show_result(train_dataset, idx=1)
show_result(train_dataset, idx=2)
show_result(train_dataset, idx=6)
show_result(train_dataset, idx=7)
show_result(val_dataset, idx=0)
show_result(val_dataset, idx=1)
show_result(val_dataset, idx=5)
show_result(val_dataset, idx=6)

src_text: What does the image describe?
tgt_text: old cars by building function
pred: eignen choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir choir c

In [14]:
src_tokenizer.encode('Whale')

[30300, 1]