In [2]:
import os
import math
import argparse
import numpy as np
import json
from tqdm import tqdm
from PIL import Image
import torch
from transformers import (
    CLIPModel,
    CLIPProcessor,
    AutoModel,
    AutoImageProcessor,
)
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance
# (改) 改回下面的注释
# from .GraphAdapter import PadToSquare
# from .infer_pipeline import InstructG2IPipeline
from instructg2i.GraphAdapter import PadToSquare
from instructg2i.infer_pipeline import InstructG2IPipeline
import wandb

def load_config(config_path):
    with open(config_path, 'r') as f:
        config = json.load(f)
    return config

parser = argparse.ArgumentParser(description="Run inference with InstructG2I.")
# (改) 改下面路径
# parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
parser.add_argument("--config", 
                    type=str, 
                    default='./config/test_goodreaders.json',
                    help="Path to the config file.")
args = parser.parse_args()

config = load_config(args.config)
args = argparse.Namespace(**config)

def read_data(test_dir):
    data = []
    with open(os.path.join(test_dir, 'metadata.jsonl')) as f:
        readin = f.readlines()
        for line in tqdm(readin):
            tmp = json.loads(line)
            data.append({
                'text': tmp['text'],
                'center_image': Image.open(os.path.join(test_dir, tmp['center'])).convert("RGB"),
                'neighbor_image': [Image.open(os.path.join(test_dir, fname)).convert("RGB") for fname in tmp[args.neighbor_key]]
            })
    return data

# Evaluator
# clip_id = "openai/clip-vit-large-patch14"
# dino_id = "facebook/dinov2-large"
clip_id = os.path.join(args.cache_dir,"clip-vit-large-patch14")
dino_id = os.path.join(args.cache_dir,"dinov2-large")
clip_model = CLIPModel.from_pretrained(clip_id, cache_dir=args.cache_dir).to(args.device)
clip_processor = CLIPProcessor.from_pretrained(clip_id, cache_dir=args.cache_dir)
dino_model = AutoModel.from_pretrained(dino_id, cache_dir=args.cache_dir).to(args.device)
dino_processor = AutoImageProcessor.from_pretrained(dino_id, cache_dir=args.cache_dir)

usage: ipykernel_launcher.py [-h] [--config CONFIG]
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/ai/.local/share/jupyter/runtime/kernel-v390f4240f7745dbf9ccb0e3488434a0bd483b2361.json


SystemExit: 2

In [3]:
%tb

SystemExit: 2

In [None]:
print(args.neighbor_key)

# read the data
print('Reading data...')
dataset = read_data(args.test_dir)

# image transformation function
neighbor_transforms = transforms.Compose(
            [
                PadToSquare(fill=(args.resolution, args.resolution, args.resolution), padding_mode='constant'),
                transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(args.resolution)
            ]
    )

def neighbor_transform_func(neighbor_images, gt_image):
    neighbor_image = [neighbor_transforms(n_img) for n_img in neighbor_images]
    neighbor_image += [neighbor_transforms(Image.fromarray(np.uint8(np.zeros_like(np.array(gt_image)))).convert('RGB'))] * (args.neighbor_num - len(neighbor_image))
    return neighbor_image

def neighbor_mask_func(neighbor_images):
    neighbor_mask = [1] * len(neighbor_images)
    neighbor_mask += [0] * (args.neighbor_num - len(neighbor_mask))
    return neighbor_mask

# init the pipeline(加载模型！最重要一步)
print('Loading diffusion model...')
pipe_graph2img = InstructG2IPipeline.from_pretrained(args.model_dir, args.neighbor_num, device=args.device)

# run inference
print('Scoring...')
img_clip_scores = []
dinov2_scores = []

print(f'Total testing data:{len(dataset)}, max index: {args.max_index}')
assert args.max_index <= len(dataset)
num_diff_iter = math.ceil(args.max_index / args.diffusion_infer_batch_size)
num_score_iter = math.ceil(args.max_index / args.score_batch_size)

# diffusion model inference
gt_images = []
gen_images = []

In [None]:
idx=0
start = idx * args.diffusion_infer_batch_size
end = min(args.max_index, (idx + 1) * args.diffusion_infer_batch_size)

# get current batch data
texts = [dataset[idd]['text'] for idd in range(start, end)]
neighbor_images = [neighbor_transform_func(dataset[idd]["neighbor_image"][:args.neighbor_num], dataset[idd]["center_image"]) for idd in range(start, end)]
# 其实并不是mask，没有遮盖住实际存在的邻居，而是确定每一个位置是否有邻居，为存在邻居标记1，不存在标记0，总长度为args.neighbor_num
neighbor_masks = [neighbor_mask_func(dataset[idd]["neighbor_image"][:args.neighbor_num]) for idd in range(start, end)]
        
gen_image = pipe_graph2img(prompt=texts, neighbor_image=neighbor_images, neighbor_mask=torch.LongTensor(neighbor_masks), num_inference_steps=args.num_inference_steps).images

In [None]:
gen_image[0].show()