In [2]:
import json
import torch
import json
import os
from LayoutDM import CLDM
from dataset import ImageLayout_Val
from torch.utils.data import Dataset, DataLoader
from diffusers import DDPMScheduler
from PIL import Image
import numpy as np
import gc
from tqdm import tqdm

In [3]:
diffusion = DDPMScheduler(num_train_timesteps= 1000)

In [4]:
device = f"cuda:{6}" if torch.cuda.is_available() else "cpu"

In [5]:
def seg_loss(predicted_box, sample, zero_count):
    src = sample['sr']
    src_list = []
    
    for i in src:
        src_ = Image.open(i)
        src_list.append(src_)
    
    box = predicted_box.cpu().numpy()
    box = (box + 1) / 2
    
    match_list = []

    for i in range(box.shape[0]):
        img = src_list[i]
        width, height = img.size
        cx, cy, w, h = box[i]
        x = int((cx - w / 2) * width)
        y = int((cy - h / 2) * height)
        x2 = int((cx + w / 2) * width)
        y2 = int((cy + h / 2) * height)
        boxes = (x, y, x2, y2)
        
        # 이미지 크롭
        crop = img.crop(boxes)
        crop = np.array(crop)


        if crop.size == 0:
            print(f"Warning: Crop size is zero for box {boxes}.")
            match_list.append(0)  
            zero_count+=1
            continue
        
        blue_channel = crop[:, :, 2]
        blue_channel_flatten = blue_channel.flatten()
        

        match_pixel_size = np.sum(blue_channel_flatten == 128) / blue_channel_flatten.size
        match_list.append(match_pixel_size)
    
    # NaN 방지를 위해 match_list가 비어 있지 않은지 확인
    if len(match_list) == 0:
        print("Error: All crops have zero size. Returning NaN.")
        return float('nan')
    
    value = sum(match_list) / len(match_list)
    print(value)
    return value, zero_count

## Load Model with Seperate Param dict

In [6]:
val_path = "/nas2/lait/1000_Members/jjoonvely/carla_new/pre_seg_combined_val.json" 

In [7]:
val = ImageLayout_Val(val_path)
dataset = DataLoader(val, batch_size=256)

In [8]:
epoch = [100,150,200,250,300]

In [11]:
src={0: {'score': 0.5814610461090901, 'zero_count': 67}}
for index,value in enumerate(epoch):
    save_path = f"/workspace/joonsm/City_Layout/log_dir/baseline_512/checkpoints/checkpoint-{value}/pytorch_model.bin"
    model = CLDM(use_temp=False)
    model.load_state_dict(torch.load(save_path, map_location=device))
    model = model.to(device)
    model.eval()
    zero_count = 0
    batch_value = []
    step = 0
    with torch.no_grad():
        for step,batch in tqdm(enumerate(dataset), total=len(dataset)):
            shape = batch['box'].shape
            noisy_batch = {'image':batch['image'].to(device),
                    'box': torch.rand(*shape, dtype=torch.float32, device=device)}
            for i in range(250)[::-1]:
                t = torch.tensor([i]*shape[0], device=device)
                with torch.no_grad():
                    noise_pred = model(noisy_batch, timesteps=t)
                    bbox_pred = diffusion.step(noise_pred, t[0].detach().item(),  noisy_batch['box'], return_dict=True)
                    noisy_batch['box'] = bbox_pred.prev_sample
            predicted = bbox_pred.prev_sample
            value,zero_count = seg_loss(predicted, batch,zero_count)
            print(zero_count)
            batch_value.append(value)
            step +=1 

        final = sum(batch_value)/step
        src[index]={'score': final, 'zero_count':zero_count}
    del model
    gc.collect()


  0%|          | 0/9 [00:00<?, ?it/s]



 11%|█         | 1/9 [03:53<31:11, 233.90s/it]

0.6671114893710176
3


 22%|██▏       | 2/9 [07:16<25:08, 215.44s/it]

0.6979320155629881
3


 33%|███▎      | 3/9 [10:33<20:42, 207.00s/it]

0.680714820132087
3


 44%|████▍     | 4/9 [14:07<17:30, 210.01s/it]

0.7080360125993962
5


 56%|█████▌    | 5/9 [18:13<14:51, 222.98s/it]

0.6916217882870705
6


 67%|██████▋   | 6/9 [22:04<11:16, 225.58s/it]

0.6976600310917832
6


 78%|███████▊  | 7/9 [25:45<07:28, 224.01s/it]

0.6985024807333452
9


 89%|████████▉ | 8/9 [29:49<03:50, 230.43s/it]

0.5393047652732719
10


100%|██████████| 9/9 [31:50<00:00, 212.33s/it]

0.4319607455510853
10



 11%|█         | 1/9 [03:54<31:13, 234.17s/it]

0.7807626370025701
0


 22%|██▏       | 2/9 [07:37<26:33, 227.58s/it]

0.822963909222528
1


 33%|███▎      | 3/9 [11:22<22:40, 226.77s/it]

0.8227911538353345
2


 44%|████▍     | 4/9 [15:06<18:46, 225.36s/it]

0.8237969517698219
3


 56%|█████▌    | 5/9 [18:47<14:55, 223.85s/it]

0.7763608246670627
3


 67%|██████▋   | 6/9 [22:40<11:21, 227.04s/it]

0.8062254853459676
3


 78%|███████▊  | 7/9 [26:30<07:35, 227.87s/it]

0.7874479116092563
4


 89%|████████▉ | 8/9 [30:12<03:46, 226.11s/it]

0.6987819423734535
6


100%|██████████| 9/9 [32:16<00:00, 215.20s/it]

0.6282182984460464
6



 11%|█         | 1/9 [03:54<31:15, 234.47s/it]

0.7631597852310826
0


 22%|██▏       | 2/9 [07:25<25:45, 220.72s/it]

0.8021768269353181
0


 33%|███▎      | 3/9 [11:28<23:04, 230.72s/it]

0.8357307153256627
0


 44%|████▍     | 4/9 [15:13<19:02, 228.42s/it]

0.8087874543675672
0


 56%|█████▌    | 5/9 [18:44<14:49, 222.28s/it]

0.8169837780649737
0


 67%|██████▋   | 6/9 [22:38<11:18, 226.28s/it]

0.803213278230208
0


 78%|███████▊  | 7/9 [26:27<07:34, 227.32s/it]

0.793596257860088
0


 89%|████████▉ | 8/9 [30:23<03:49, 229.83s/it]

0.7189074356853353
0


100%|██████████| 9/9 [32:20<00:00, 215.59s/it]

0.5624968278457235
0



 11%|█         | 1/9 [03:41<29:28, 221.08s/it]

0.7516591225201111
0


 22%|██▏       | 2/9 [07:19<25:38, 219.75s/it]

0.7620499306694175
0


 33%|███▎      | 3/9 [11:25<23:09, 231.61s/it]

0.7686184826685093
0


 44%|████▍     | 4/9 [15:13<19:10, 230.16s/it]

0.8101438627912981
0


 56%|█████▌    | 5/9 [18:59<15:14, 228.60s/it]

0.7700208553959342
0


 67%|██████▋   | 6/9 [22:34<11:12, 224.07s/it]

0.7783066706833531
0


 78%|███████▊  | 7/9 [26:29<07:34, 227.49s/it]

0.7992962476899308
0


 89%|████████▉ | 8/9 [30:28<03:51, 231.37s/it]

0.6564657910550509
0


100%|██████████| 9/9 [32:32<00:00, 216.90s/it]

0.5618794930759113
0



 11%|█         | 1/9 [03:58<31:46, 238.27s/it]

0.8184171286232694
0


 22%|██▏       | 2/9 [07:45<27:03, 231.88s/it]

0.8720270949410808
0


 33%|███▎      | 3/9 [11:27<22:44, 227.47s/it]

0.8555487813278698
0


 44%|████▍     | 4/9 [15:03<18:34, 222.83s/it]

0.8876349727279031
0


 56%|█████▌    | 5/9 [19:04<15:18, 229.52s/it]

0.882640431692447
0


 67%|██████▋   | 6/9 [22:51<11:25, 228.49s/it]

0.8936254381489235
0


 78%|███████▊  | 7/9 [26:26<07:28, 224.18s/it]

0.8907291202407682
0


 89%|████████▉ | 8/9 [30:19<03:47, 227.00s/it]

0.7485442818691056
0


100%|██████████| 9/9 [32:33<00:00, 217.05s/it]

0.675669929832084
0





In [12]:
src

{0: {'score': 0.6458715720668938, 'zero_count': 10},
 1: {'score': 0.77192767936356, 'zero_count': 6},
 2: {'score': 0.767228039949551, 'zero_count': 0},
 3: {'score': 0.7398267173943907, 'zero_count': 0},
 4: {'score': 0.8360930199337169, 'zero_count': 0}}