# Generate metadata for training data

In [1]:
import json
import os
import re

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import Blip2Processor, Blip2ForConditionalGeneration,BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Image category handled by the cells that follow; every path is derived from it.
img_type = "river"

# Absolute dataset root; the metadata cell embeds it in the "image" and
# "conditioning_image" fields of every record.
prefix_path = f"/home/swc/AICUP_generative/training/{img_type}/train/"

# Locations relative to the notebook's working directory, used for file access.
train_path = f"training/{img_type}/train/"
img_path = f"training/{img_type}/train/images/"
conditioning_img_path = f"training/{img_type}/train/conditioning_images/"

# Alphabetically sorted directory listings. Later cells zip the two lists, so
# image/conditioning pairs are matched purely by sort position — presumably the
# two folders hold identically named files (TODO confirm).
img_list = sorted(os.listdir(img_path))
conditioning_img_list = sorted(os.listdir(conditioning_img_path))

In [29]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [30]:
# Identifier of the BLIP-2 checkpoint used for captioning.
base_model = "Salesforce/blip2-opt-6.7b-coco"

# 8-bit loading through bitsandbytes, with fp32 CPU offload enabled for any
# modules that device_map="auto" does not place on the GPU.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# The processor prepares image/text inputs and decodes generated token ids.
processor = Blip2Processor.from_pretrained(base_model)

# Keyword arguments for loading the model, gathered in one place for readability.
model_load_options = {
    "quantization_config": quantization_config,
    "torch_dtype": torch.float16,
    "device_map": "auto",
}
model = Blip2ForConditionalGeneration.from_pretrained(base_model, **model_load_options)

Loading checkpoint shards: 100%|██████████| 4/4 [00:35<00:00,  8.85s/it]


## Generate prompt to metadata

### For river

圖像翻轉，生成資料

In [3]:
"a.png".split(".")

['a', 'png']

In [4]:
def _save_flipped_copy(directory, file_name):
    """Write a left-right mirrored copy of ``directory + file_name``.

    The copy is stored next to the source as ``<stem>_flip<ext>``.

    * ``os.path.splitext`` is used so that names containing several dots (or
      none at all) keep their full stem and real extension.
    * Files whose stem already ends in ``_flip`` are skipped, so re-running the
      cell after the directory listings were refreshed does not create
      ``*_flip_flip`` duplicates of the original images.
    * The source image is opened in a ``with`` block so its file handle is
      released as soon as the mirrored copy has been produced.
    """
    stem, ext = os.path.splitext(file_name)
    if stem.endswith("_flip"):
        return
    with Image.open(directory + file_name) as source:
        flipped = source.transpose(Image.FLIP_LEFT_RIGHT)
    flipped.save(f"{directory}{stem}_flip{ext}")


# Mirror every training image and its conditioning image to augment the data set.
for img, conditioning_img in tqdm(
    zip(img_list, conditioning_img_list), total=len(img_list)
):
    _save_flipped_copy(img_path, img)
    _save_flipped_copy(conditioning_img_path, conditioning_img)


100%|██████████| 2160/2160 [00:06<00:00, 322.91it/s]


區分俯瞰和河岸

In [7]:
def _river_prompt(mask):
    """Build the text prompt for one river conditioning image.

    ``mask`` is the conditioning image as a NumPy array. Values above 128 in
    channel 0 are treated as river pixels (presumably the conditioning image is
    a river segmentation mask — TODO confirm).

    The topmost pixel row that contains any river pixel decides the label:

    * river touches the left or right border of that row -> ``"lush shore"``
    * otherwise                                          -> ``"rural ,aerial view"``

    When the mask contains no river pixel at all, the bare ``"river, "`` prefix
    is returned.

    NOTE(review): the previous loop left its ``break`` at ``for`` level, so only
    pixel row 0 was ever inspected and most images could not be labelled. The
    search now continues down to the first row that really contains river.
    """
    prompt = "river, "
    # Accept both (H, W, C) and plain (H, W) arrays.
    channel = mask[:, :, 0] if mask.ndim == 3 else mask
    river = channel > 128
    rows_with_river = np.flatnonzero(river.any(axis=1))
    if rows_with_river.size == 0:
        return prompt
    first_river_row = river[rows_with_river[0]]
    if first_river_row[0] or first_river_row[-1]:
        prompt += "lush shore"
    else:
        prompt += "rural ,aerial view"
    return prompt


if img_type == "river":
    # zip() would silently drop the tail of the longer list and mis-pair files.
    if len(img_list) != len(conditioning_img_list):
        raise ValueError(
            f"{len(img_list)} images but {len(conditioning_img_list)} conditioning "
            "images; the two folders must contain matching files"
        )

    with open(os.path.join(train_path, "metadata.jsonl"), "w", encoding="utf-8") as f:
        for img, conditioning_img in tqdm(
            zip(img_list, conditioning_img_list), total=len(img_list)
        ):
            # Close the mask file as soon as its pixels have been copied out.
            with Image.open(conditioning_img_path + conditioning_img) as img_mat_pil:
                prompt = _river_prompt(np.array(img_mat_pil))

            # Two records per pair: one keyed by the image file, one keyed by
            # the conditioning file. All other fields are identical.
            for file_name in (f"images/{img}", f"conditioning_images/{conditioning_img}"):
                line = {
                    "file_name": file_name,
                    "image": f"{prefix_path}images/{img}",
                    "conditioning_image": f"{prefix_path}conditioning_images/{conditioning_img}",
                    "text": prompt,
                }
                f.write(json.dumps(line) + "\n")

100%|██████████| 4320/4320 [00:02<00:00, 1881.80it/s]


### For road
翻轉顛倒的圖片

In [2]:
# Switch the working data set to the road images; every path is derived from it.
img_type = "road"

# Absolute dataset root; the metadata cell embeds it in the "image" and
# "conditioning_image" fields of every record.
prefix_path = f"/home/swc/AICUP_generative/training/{img_type}/train/"

# Locations relative to the notebook's working directory, used for file access.
train_path = f"training/{img_type}/train/"
img_path = f"training/{img_type}/train/images/"
conditioning_img_path = f"training/{img_type}/train/conditioning_images/"

# Alphabetically sorted directory listings. Later cells zip the two lists, so
# image/conditioning pairs are matched purely by sort position — presumably the
# two folders hold identically named files (TODO confirm).
img_list = sorted(os.listdir(img_path))
conditioning_img_list = sorted(os.listdir(conditioning_img_path))

- 翻轉

In [28]:
def _road_is_upside_down(mask):
    """Return True when the road in ``mask`` is wider at its top than at its bottom.

    ``mask`` is the conditioning image as a NumPy array; values above 128 are
    treated as road pixels (presumably a road segmentation mask — TODO confirm).
    The first and last pixel rows containing any road pixel (in any channel) are
    located, and the number of road pixels in channel 0 of those two rows is
    compared.

    A mask without any road pixel returns False. Previously that case either
    raised ``NameError`` (first image) or silently reused the row indices found
    for the preceding image.
    """
    bright = mask > 128
    # Collapse every axis except the row axis: True where a row has road pixels.
    row_has_road = bright.any(axis=tuple(range(1, bright.ndim)))
    rows_with_road = np.flatnonzero(row_has_road)
    if rows_with_road.size == 0:
        return False
    top_of_road, bottom_of_road = rows_with_road[0], rows_with_road[-1]
    channel = bright[:, :, 0] if bright.ndim == 3 else bright
    return channel[top_of_road].sum() > channel[bottom_of_road].sum()


def _rotate_in_place(path):
    """Rotate the image stored at ``path`` by 180 degrees and overwrite it.

    The source handle is closed before the file is rewritten.
    """
    with Image.open(path) as source:
        rotated = source.rotate(180)
    rotated.save(path)


if img_type == "road":
    for img, conditioning_img in tqdm(
        zip(img_list, conditioning_img_list), total=len(img_list)
    ):
        mask_file = conditioning_img_path + conditioning_img
        with Image.open(mask_file) as img_mat_pil:
            img_mat = np.array(img_mat_pil)

        # A road that is wider at the top is taken to be upside-down: turn both
        # the conditioning image and the photo so they stay aligned.
        if _road_is_upside_down(img_mat):
            _rotate_in_place(mask_file)
            _rotate_in_place(img_path + img)

100%|██████████| 2160/2160 [00:03<00:00, 719.15it/s] 


- 生成metadata

In [41]:
# Whole-word matchers for the vehicle tags. A plain substring test such as
# `"car" in text` also fires on "carrying", "cart", "scarf", ...
_MOTORCYCLE_WORD = re.compile(r"\bmotorcycles?\b")
_CAR_WORD = re.compile(r"\bcars?\b")


def _road_prompt(caption):
    """Turn a generated caption into the training prompt for a road image.

    The prompt always starts with ``"road"``; ``", motorcycle"`` and ``", car"``
    are appended (in that order) when the caption mentions the respective
    vehicle as a whole word, singular or plural.
    """
    prompt = "road"
    if _MOTORCYCLE_WORD.search(caption):
        prompt += ", motorcycle"
    if _CAR_WORD.search(caption):
        prompt += ", car"
    return prompt


if img_type == "road":
    # zip() would silently drop the tail of the longer list and mis-pair files.
    if len(img_list) != len(conditioning_img_list):
        raise ValueError(
            f"{len(img_list)} images but {len(conditioning_img_list)} conditioning "
            "images; the two folders must contain matching files"
        )

    with open(os.path.join(train_path, "metadata.jsonl"), "w", encoding="utf-8") as f:
        for img, conditioning_img in tqdm(
            zip(img_list, conditioning_img_list), total=len(img_list)
        ):
            # Close the image file once the processor has read its pixels.
            with Image.open(img_path + img) as img_pil:
                # NOTE(review): the caption prompt is "river" although this cell
                # handles road images — looks like a copy-paste leftover; confirm
                # the intended prompt before regenerating the metadata.
                inputs = processor(
                    images=img_pil, text="river", return_tensors="pt"
                ).to(device, torch.float16)

            generated_ids = model.generate(**inputs)
            generated_text = processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0].strip()
            prompt = _road_prompt(generated_text)

            # Two records per pair: one keyed by the image file, one keyed by
            # the conditioning file. All other fields are identical.
            for file_name in (f"images/{img}", f"conditioning_images/{conditioning_img}"):
                line = {
                    "file_name": file_name,
                    "image": f"{prefix_path}images/{img}",
                    "conditioning_image": f"{prefix_path}conditioning_images/{conditioning_img}",
                    "text": prompt,
                }
                f.write(json.dumps(line) + "\n")

100%|██████████| 2160/2160 [1:26:54<00:00,  2.41s/it]
