In [2]:
import json
import torch
import json
import os
from cgllike import CLDM
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler
from PIL import Image
import numpy as np
from dataset_eval import ImageLayout
import gc
from tqdm import tqdm
from data_utils import norm_bbox

In [2]:
# Run on GPU index 7 when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:7"
else:
    device = "cpu"

In [3]:
# DDPM noise schedule used for sampling: 250 training timesteps, the model
# output is interpreted as the predicted sample (not the noise), and the
# predicted sample is clipped by the scheduler.
diffusion = DDPMScheduler(
    num_train_timesteps=250,
    prediction_type='sample',
    clip_sample=True,
)

In [4]:
def seg_loss(predicted_box, sample, zero_count,cxcy,wh):
    """Score a batch of predicted boxes against their source images.

    For every predicted box the matching image is cropped to the box and two
    quantities are collected:

    * match ratio: fraction of crop pixels whose third channel equals 128
      (presumably the target segmentation colour; confirm against the data).
    * size error: ``abs(gt_w * gt_h - w * h)`` where the ground-truth entry is
      the row of ``cxcy``/``wh`` whose centre is nearest (Euclidean distance)
      to the predicted centre expressed in pixels.

    A box whose crop is empty contributes a match ratio of 0 and a size error
    of 1, and increments ``zero_count``.

    Args:
        predicted_box: tensor exposing ``.cpu().numpy()`` whose rows unpack to
            ``(cx, cy, w, h)``; values are mapped with ``(x + 1) / 2`` before
            use, i.e. they are expected in ``[-1, 1]``.
        sample: mapping whose ``'sr'`` entry is an indexable sequence of image
            paths, one per predicted box.
        zero_count: running count of empty crops seen so far.
        cxcy: array of reference centres, one ``(x, y)`` row per reference
            box, compared in pixel coordinates.
        wh: array of reference ``(w, h)`` rows aligned with ``cxcy``; their
            product is compared with the normalised predicted area.

    Returns:
        ``(mean_match_ratio, zero_count, mean_size_error)``. When the batch
        contains no boxes both means are NaN and ``zero_count`` is returned
        unchanged, so callers can always unpack three values.
    """
    src = sample['sr']

    box = predicted_box.cpu().numpy()
    box = (box + 1) / 2

    match_list = []
    size_list = []
    for i in range(box.shape[0]):
        cx, cy, w, h = box[i]

        # Open each image in a context manager so the file handle is released
        # as soon as the crop has been materialised as an array.
        with Image.open(src[i]) as img:
            width, height = img.size
            x = int((cx - w / 2) * width)
            y = int((cy - h / 2) * height)
            x2 = int((cx + w / 2) * width)
            y2 = int((cy + h / 2) * height)
            boxes = (x, y, x2, y2)

            # Crop the image to the predicted box.
            crop = np.array(img.crop(boxes))

        if crop.size == 0:
            print(f"Warning: Crop size is zero for box {boxes}.")
            # Degenerate box: worst match ratio and the fixed size penalty of 1.
            # Both lists are appended so the two means share a denominator.
            match_list.append(0)
            size_list.append(1)
            zero_count += 1
            continue

        blue_channel_flatten = crop[:, :, 2].flatten()
        match_pixel_size = np.sum(blue_channel_flatten == 128) / blue_channel_flatten.size
        match_list.append(match_pixel_size)

        # Find the reference box nearest to the predicted centre (in pixels)
        # and compare normalised areas.
        center_point = np.array([cx * width, cy * height])
        normalized_area = w * h
        distances = np.linalg.norm(cxcy - center_point, axis=1)
        min_index = np.argmin(distances)
        gtw, gth = wh[min_index]
        size_list.append(abs(gtw * gth - normalized_area))

    # Guard against NaN/ZeroDivisionError: an empty batch has nothing to average.
    if len(match_list) == 0:
        print("Error: All crops have zero size. Returning NaN.")
        return float('nan'), zero_count, float('nan')

    value = sum(match_list) / len(match_list)
    size = sum(size_list) / len(size_list)

    print(value)
    print('size here')
    print(size_list)
    return value, zero_count, size

## Load Model with Separate Param dict

In [3]:
# Build the evaluation dataset and wrap it in a loader yielding batches of up to 256 items.
val = ImageLayout(type='92.158.10')
dataset = DataLoader(dataset=val, batch_size=256)

In [5]:
# Number of items in the evaluation dataset.
len(val)

277

In [18]:
# Reference box centres (first two fields of every entry in val.box_list).
cxcy = np.array([[entry[0], entry[1]] for entry in val.box_list])
# Reference box sizes (third and fourth fields), scaled by 800 and 600 respectively.
wh = np.array([[entry[2] / 800, entry[3] / 600] for entry in val.box_list])

In [19]:
# Sanity check: number of reference boxes collected in wh.
wh.shape[0]

277

In [25]:
# Checkpoint indices to evaluate: every 30th value from 0 up to (not including) 300.
epoch = list(range(0, 300, 30))

In [28]:
# Evaluate every checkpoint listed in `epoch`: load the weights, run the full
# reverse-diffusion loop on each batch of the loader, score the resulting
# boxes with seg_loss, and store the per-checkpoint averages in `src`.
src = {}
cxcy = np.array([[item[0], item[1]] for item in (val.box_list)])

# Take the number of reverse steps from the scheduler's own config so the
# sampling loop cannot drift from the schedule it was built with.
num_steps = diffusion.config.num_train_timesteps

for index, value in enumerate(epoch):
    save_path = f"/data1/joonsm/City_Layout/log_dir/CGL[FPN50]/checkpoints/checkpoint-{value}/pytorch_model.bin"
    model = CLDM(use_temp=False, backbone_name='resnet50')
    model.load_state_dict(torch.load(save_path, map_location=device))
    model = model.to(device)
    model.eval()

    zero_count = 0
    batch_value = []  # mean match ratio of each batch
    batch_size = []   # mean size error of each batch (not the loader batch size)

    # A single no_grad context covers the whole sampling loop.
    with torch.no_grad():
        for batch in tqdm(dataset, total=len(dataset)):
            shape = batch['box'].shape
            # NOTE(review): the starting boxes are drawn from torch.rand
            # (uniform in [0, 1)), not Gaussian noise -- confirm this matches
            # how the model was trained.
            noisy_batch = {
                'image': batch['image'].to(device),
                'box': torch.rand(*shape, dtype=torch.float32, device=device),
            }
            for i in reversed(range(num_steps)):
                t = torch.full((shape[0],), i, dtype=torch.long, device=device)
                noise_pred = model(noisy_batch, timesteps=t)
                # Pass the Python int directly; reading it back from the
                # tensor would force a device sync on every step.
                bbox_pred = diffusion.step(noise_pred, i, noisy_batch['box'], return_dict=True)
                noisy_batch['box'] = bbox_pred.prev_sample
            predicted = bbox_pred.prev_sample

            # Distinct names so the checkpoint number in `value` is not shadowed.
            score, zero_count, size = seg_loss(predicted, batch, zero_count, cxcy, wh)
            print(zero_count)
            batch_value.append(score)
            batch_size.append(size)

    # Average over the batches actually processed; an empty loader yields NaN
    # instead of a ZeroDivisionError.
    step = len(batch_value)
    if step == 0:
        final = float('nan')
        final_size = float('nan')
    else:
        final = sum(batch_value) / step
        final_size = sum(batch_size) / step
    src[index] = {'score': final, 'zero_count': zero_count, 'size': final_size}

    # Drop the model before loading the next checkpoint.
    del model
    gc.collect()

[32m2024-09-19 23:31:12.564[0m [1;30mINFO    [0m [34mlogger_set:26[0m -> Loading the resnet50 encoder
  0%|          | 1/542 [00:05<53:27,  5.93s/it]

0.32808908045977014
size here
[1.10442973673343e-05, 0.00035742091620340937, 0.006044052730780095, 0.004717467524499322]
0


  0%|          | 2/542 [00:09<42:01,  4.67s/it]

0.6990740740740741
size here
[0.004091740624854962, 0.0006930742425533634, 0.005858363249339163, 0.003964810081074636]
0


  1%|          | 3/542 [00:14<44:25,  4.95s/it]

0.09230769230769231
size here
[0.004043336057352524, 0.008262468109931797, 0.009676968113208811, 0.004897548348565275]
0


  1%|          | 4/542 [00:20<46:36,  5.20s/it]

0.579926575330987
size here
[0.0002087788433302194, 0.0033470623119423787, 0.00013034660397097473, 7.347458552879588e-05]
0


  1%|          | 5/542 [00:25<45:57,  5.14s/it]

0.7473544973544973
size here
[0.0005614581932779402, 0.004168895665199185, 0.007353968440927565, 0.0010816358519795662]
0


  1%|          | 6/542 [00:30<46:21,  5.19s/it]

0.9408022533022533
size here
[0.0007914381897232185, 4.539594566449522e-05, 0.003518863133651515, 0.00015649160300381488]
0


  1%|          | 6/542 [00:35<52:21,  5.86s/it]


KeyboardInterrupt: 