# Generate metadata for training data

In [1]:
import os
import json
from PIL import Image
import numpy as np
from transformers import Blip2Processor, Blip2ForConditionalGeneration,BitsAndBytesConfig
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset category to annotate; selects the training/<img_type>/train/ directory.
img_type = "river"
# Dataset root, relative to the notebook's working directory.
train_path = f"training/{img_type}/train/"
# Absolute form of train_path (with trailing slash). It is written into
# metadata.jsonl, so it is derived here rather than hardcoded to one machine.
prefix_path = os.path.abspath(train_path) + "/"
img_path = f"{train_path}images/"
conditioning_img_path = f"{train_path}conditioning_images/"
# Sorted so that images and conditioning images pair up by filename order.
img_list = sorted(os.listdir(img_path))
conditioning_img_list = sorted(os.listdir(conditioning_img_path))

In [3]:
# Use the GPU when PyTorch can see one; otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [4]:
# BLIP-2 captioning model (OPT-6.7B backbone, COCO checkpoint) and its processor.
base_model = "Salesforce/blip2-opt-6.7b-coco"
processor = Blip2Processor.from_pretrained(base_model)
# Load weights in 8-bit to reduce GPU memory. Modules placed on CPU by
# device_map="auto" are kept in fp32 instead of being quantized.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)
model = Blip2ForConditionalGeneration.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:36<00:00,  9.12s/it]


## Generate prompts and write them to metadata

### For river
區分俯瞰和河岸 (distinguish aerial views from riverbank/shore views)

In [5]:
def river_view_label(mask):
    """Classify a river conditioning mask as a shore view or an aerial view.

    Scans channel 0 of ``mask`` from top to bottom and stops at the first row
    that contains a pixel brighter than 128 (treated as river). If the river in
    that row reaches the left or right image border, the view is "shore";
    otherwise it is "aerial". Returns None when no row contains a river pixel.

    Args:
        mask: array produced by ``np.array(PIL image)``. Assumes a
            (H, W, C) layout with C >= 1, as the ``[..., 0]`` indexing requires.
    """
    for row in mask[:, :, 0]:
        if row.max() > 128:
            # Only the first river-containing row decides the label.
            return "shore" if row[0] > 128 or row[-1] > 128 else "aerial"
    return None


def caption_keywords(image):
    """Caption ``image`` with BLIP-2 and return which of "bridge"/"field" occur.

    Uses the ``processor``, ``model`` and ``device`` globals from earlier cells.
    The text "river" is passed to the processor as the text prompt. Keywords are
    returned in the fixed order ("bridge", "field").
    """
    inputs = processor(images=image, text="river", return_tensors="pt").to(
        device, torch.float16
    )
    generated_ids = model.generate(**inputs)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return [keyword for keyword in ("bridge", "field") if keyword in caption]


if img_type == "river":
    # zip() silently truncates to the shorter list, which would drop or
    # misalign image/conditioning pairs; fail loudly instead.
    assert len(img_list) == len(conditioning_img_list), (
        f"{len(img_list)} images vs {len(conditioning_img_list)} conditioning images"
    )
    metadata_path = os.path.join(train_path, "metadata.jsonl")
    with open(metadata_path, "w") as f:
        for img, conditioning_img in tqdm(
            zip(img_list, conditioning_img_list), total=len(img_list)
        ):
            # copy() loads the pixels so the file handle can be closed
            # immediately; the copy stays usable for captioning below.
            with Image.open(conditioning_img_path + conditioning_img) as src:
                mask_pil = src.copy()

            prompt = "river, "
            view = river_view_label(np.array(mask_pil))
            if view == "shore":
                prompt += "lush shore"
            elif view == "aerial":
                prompt += "rural ,aerial view"
                # The caption is generated from the conditioning image itself,
                # not from the photo.
                prompt += "".join(f", {kw}" for kw in caption_keywords(mask_pil))

            # Each pair is written twice: once keyed by the photo and once by
            # the conditioning image. Both records share the paths and prompt.
            for file_name in (f"images/{img}", f"conditioning_images/{conditioning_img}"):
                line = {
                    "file_name": file_name,
                    "image": f"{prefix_path}images/{img}",
                    "conditioning_image": f"{prefix_path}conditioning_images/{conditioning_img}",
                    "text": prompt,
                }
                f.write(json.dumps(line) + "\n")

1/2160 0.05%
2/2160 0.09%
3/2160 0.14%
4/2160 0.19%
5/2160 0.23%
6/2160 0.28%
7/2160 0.32%
8/2160 0.37%
9/2160 0.42%
10/2160 0.46%
11/2160 0.51%
12/2160 0.56%
13/2160 0.60%
14/2160 0.65%
15/2160 0.69%
16/2160 0.74%
17/2160 0.79%
18/2160 0.83%
19/2160 0.88%
20/2160 0.93%
21/2160 0.97%
22/2160 1.02%
23/2160 1.06%
24/2160 1.11%
25/2160 1.16%
26/2160 1.20%
27/2160 1.25%
28/2160 1.30%
29/2160 1.34%
30/2160 1.39%
31/2160 1.44%
32/2160 1.48%
33/2160 1.53%
34/2160 1.57%
35/2160 1.62%
36/2160 1.67%
37/2160 1.71%
38/2160 1.76%
39/2160 1.81%
40/2160 1.85%
41/2160 1.90%
42/2160 1.94%
43/2160 1.99%
44/2160 2.04%
45/2160 2.08%
46/2160 2.13%
47/2160 2.18%
48/2160 2.22%
49/2160 2.27%
50/2160 2.31%
51/2160 2.36%
52/2160 2.41%
53/2160 2.45%
54/2160 2.50%
55/2160 2.55%
56/2160 2.59%
57/2160 2.64%
58/2160 2.69%
59/2160 2.73%
60/2160 2.78%
61/2160 2.82%
62/2160 2.87%
63/2160 2.92%
64/2160 2.96%
65/2160 3.01%
66/2160 3.06%
67/2160 3.10%
68/2160 3.15%
69/2160 3.19%
70/2160 3.24%
71/2160 3.29%
72/2160 3.33%
7

### For road
翻轉顛倒的圖片 (flip upside-down images)