In [1]:
import json
import torch
import json
import os
from LayoutDM import CLDM
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler
from PIL import Image
import numpy as np
from dataset import ImageLayout
import gc
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Load Model with Seperate Param dict

In [2]:
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda:5


In [4]:
diffusion = DDPMScheduler(num_train_timesteps=250, prediction_type='sample', clip_sample=True)

In [3]:
def seg_loss(predicted_box, sample, zero_count):
    src = sample['sr']
    src_list = []
    
    for i in src:
        src_ = Image.open(i)
        src_list.append(src_)
    
    box = predicted_box.cpu().numpy()
    box = (box + 1) / 2
    
    match_list = []

    for i in range(box.shape[0]):
        img = src_list[i]
        width, height = img.size
        cx, cy, w, h = box[i]
        x = int((cx - w / 2) * width)
        y = int((cy - h / 2) * height)
        x2 = int((cx + w / 2) * width)
        y2 = int((cy + h / 2) * height)
        boxes = (x, y, x2, y2)
        
        # 이미지 크롭
        crop = img.crop(boxes)
        crop = np.array(crop)


        if crop.size == 0:
            print(f"Warning: Crop size is zero for box {boxes}.")
            match_list.append(0)  
            zero_count+=1
            continue
        
        blue_channel = crop[:, :, 2]
        blue_channel_flatten = blue_channel.flatten()
        

        match_pixel_size = np.sum(blue_channel_flatten == 128) / blue_channel_flatten.size
        match_list.append(match_pixel_size)
    
    # NaN 방지를 위해 match_list가 비어 있지 않은지 확인
    if len(match_list) == 0:
        print("Error: All crops have zero size. Returning NaN.")
        return float('nan')
    
    value = sum(match_list) / len(match_list)
    print(value)
    return value, zero_count

## Load Model with Seperate Param Dict

In [5]:
val = ImageLayout(type='val')
dataset = DataLoader(val, batch_size=256)

In [6]:
epoch =[50,100,150,200,250,300,350,400,450,500]

In [7]:
src={}
for index,value in enumerate(epoch):
    save_path = f"/workspace/joonsm/City_Layout/log_dir/FPN[18]_freeze/checkpoints/checkpoint-{value}/pytorch_model.bin"
    model = CLDM(use_temp=False,backbone_name='resnet18')
    model.load_state_dict(torch.load(save_path, map_location=device))
    model = model.to(device)
    model.eval()
    zero_count = 0
    batch_value = []
    step = 0
    with torch.no_grad():
        for step,batch in tqdm(enumerate(dataset), total=len(dataset)):
            shape = batch['box'].shape
            noisy_batch = {'image':batch['image'].to(device),
                    'box': torch.rand(*shape, dtype=torch.float32, device=device)}
            for i in range(250)[::-1]:
                t = torch.tensor([i]*shape[0], device=device)
                with torch.no_grad():
                    noise_pred = model(noisy_batch, timesteps=t)
                    bbox_pred = diffusion.step(noise_pred, t[0].detach().item(),  noisy_batch['box'], return_dict=True)
                    noisy_batch['box'] = bbox_pred.prev_sample
            predicted = bbox_pred.prev_sample
            value,zero_count = seg_loss(predicted, batch,zero_count)
            print(zero_count)
            batch_value.append(value)
            step +=1 

        final = sum(batch_value)/step
        src[index]={'score': final, 'zero_count':zero_count}
    del model
    gc.collect()

  0%|          | 0/9 [00:00<?, ?it/s]



 11%|█         | 1/9 [01:42<13:42, 102.81s/it]

0.3104280299220361
3


 22%|██▏       | 2/9 [03:27<12:05, 103.69s/it]

0.24534482550987977
5


 33%|███▎      | 3/9 [07:00<15:21, 153.60s/it]

0.3090861652759628
5


 44%|████▍     | 4/9 [10:25<14:31, 174.22s/it]

0.271743584698855
5


 56%|█████▌    | 5/9 [15:38<14:56, 224.11s/it]

0.3137387383284165
7


 67%|██████▋   | 6/9 [20:47<12:38, 252.99s/it]

0.33119956261887834
7
0.2659713160260648


 78%|███████▊  | 7/9 [26:02<09:06, 273.23s/it]

8


 89%|████████▉ | 8/9 [31:11<04:44, 284.68s/it]

0.3457719582384782
9


100%|██████████| 9/9 [33:40<00:00, 224.52s/it]

0.3912896419885556
9



 11%|█         | 1/9 [05:14<41:54, 314.30s/it]

0.47509861788060953
0


 22%|██▏       | 2/9 [10:29<36:44, 314.93s/it]

0.4476966377906927
0


 33%|███▎      | 3/9 [15:41<31:20, 313.35s/it]

0.4744161594759754
0


 44%|████▍     | 4/9 [20:58<26:14, 315.00s/it]

0.4348650918194513
0


 56%|█████▌    | 5/9 [26:10<20:55, 313.79s/it]

0.4733996354687169
0


 67%|██████▋   | 6/9 [31:24<15:42, 314.09s/it]

0.447535119338585
0


 78%|███████▊  | 7/9 [36:39<10:28, 314.31s/it]

0.4682645970257384
0


 89%|████████▉ | 8/9 [41:54<05:14, 314.57s/it]

0.3662558725231278
0


100%|██████████| 9/9 [44:21<00:00, 295.73s/it]

0.18601917521562333
0



 11%|█         | 1/9 [05:17<42:19, 317.46s/it]

0.42099564011685836
0


 22%|██▏       | 2/9 [10:34<37:01, 317.31s/it]

0.3585371541498864
0


 33%|███▎      | 3/9 [15:52<31:44, 317.48s/it]

0.4034783932790318
0


 44%|████▍     | 4/9 [21:11<26:30, 318.16s/it]

0.39744907762938875
0


 56%|█████▌    | 5/9 [26:31<21:14, 318.63s/it]

0.4218897083505425
0


 67%|██████▋   | 6/9 [31:47<15:53, 317.83s/it]

0.3490201139827277
0


 78%|███████▊  | 7/9 [37:10<10:38, 319.43s/it]

0.4168542910307083
0


 89%|████████▉ | 8/9 [42:26<05:18, 318.47s/it]

0.3622822011771834
0


100%|██████████| 9/9 [44:59<00:00, 299.89s/it]

0.3133703441199627
0



  0%|          | 0/9 [00:00<?, ?it/s]



 11%|█         | 1/9 [05:23<43:05, 323.20s/it]

0.3859369586407576
4


 22%|██▏       | 2/9 [10:46<37:43, 323.34s/it]

0.48320322958961553
4


 33%|███▎      | 3/9 [16:06<32:11, 321.89s/it]

0.46389580528580415
5


 44%|████▍     | 4/9 [21:29<26:51, 322.33s/it]

0.5073851197168262
6


 56%|█████▌    | 5/9 [26:52<21:30, 322.62s/it]

0.49603939562323995
6


 67%|██████▋   | 6/9 [32:16<16:08, 323.00s/it]

0.4645767183370223
8


 78%|███████▊  | 7/9 [37:34<10:42, 321.46s/it]

0.46284015576484855
10


 89%|████████▉ | 8/9 [43:01<05:22, 322.95s/it]

0.31324421904979827
11


100%|██████████| 9/9 [45:30<00:00, 303.35s/it]

0.23435695245324473
11



 11%|█         | 1/9 [05:16<42:14, 316.86s/it]

0.18298570152593407
0


 22%|██▏       | 2/9 [10:36<37:09, 318.51s/it]

0.13639918742912657
1


 33%|███▎      | 3/9 [15:55<31:52, 318.80s/it]

0.16445437214985276
1


 44%|████▍     | 4/9 [21:12<26:29, 317.88s/it]

0.15868652663198837
1


 56%|█████▌    | 5/9 [26:34<21:17, 319.48s/it]

0.2161375219316468
3


 67%|██████▋   | 6/9 [31:50<15:55, 318.47s/it]

0.1185071370896285
3


 78%|███████▊  | 7/9 [37:10<10:37, 318.73s/it]

0.20258022529719075
3


 89%|████████▉ | 8/9 [42:29<05:18, 318.96s/it]

0.38898653758107515
3


100%|██████████| 9/9 [45:02<00:00, 300.25s/it]

0.4177192871575171
3



  0%|          | 0/9 [00:00<?, ?it/s]



 11%|█         | 1/9 [05:11<41:33, 311.69s/it]

0.28771420691571287
2


 22%|██▏       | 2/9 [10:26<36:33, 313.39s/it]

0.2655054502370157
7


 33%|███▎      | 3/9 [15:40<31:23, 313.93s/it]

0.2851135696052294
10


 44%|████▍     | 4/9 [20:56<26:12, 314.42s/it]

0.24878294287235192
12


 56%|█████▌    | 5/9 [26:05<20:50, 312.55s/it]

0.29809780448547024
14


 67%|██████▋   | 6/9 [31:22<15:42, 314.24s/it]

0.3244971314248172
15


 78%|███████▊  | 7/9 [36:33<10:26, 313.20s/it]

0.2794524018440588
15


 89%|████████▉ | 8/9 [41:47<05:13, 313.41s/it]

0.44602016452964927
16


100%|██████████| 9/9 [44:15<00:00, 295.02s/it]

0.388015679349415
18



  0%|          | 0/9 [00:00<?, ?it/s]



 11%|█         | 1/9 [05:24<43:13, 324.21s/it]

0.4204906595829496
3


 22%|██▏       | 2/9 [10:48<37:50, 324.38s/it]

0.3765165324248293
12


 33%|███▎      | 3/9 [16:13<32:27, 324.53s/it]

0.4585153059918615
18


 44%|████▍     | 4/9 [21:35<26:57, 323.50s/it]

0.44973617897015267
21


 56%|█████▌    | 5/9 [26:59<21:35, 323.85s/it]

0.4756014007110805
23


 67%|██████▋   | 6/9 [32:24<16:12, 324.01s/it]

0.4649200763500673
29


 78%|███████▊  | 7/9 [37:48<10:48, 324.26s/it]

0.421651826612726
32


 89%|████████▉ | 8/9 [43:13<05:24, 324.43s/it]

0.4251353102799296
35


100%|██████████| 9/9 [45:46<00:00, 305.13s/it]

0.3910933793835847
35



 11%|█         | 1/9 [05:16<42:12, 316.52s/it]

0.5219199729026822
0


 22%|██▏       | 2/9 [10:33<36:58, 316.88s/it]

0.5295901497805284
0


 33%|███▎      | 3/9 [15:45<31:27, 314.51s/it]

0.5627610308450554
1


 44%|████▍     | 4/9 [21:04<26:22, 316.46s/it]

0.5775253091101801
3


 56%|█████▌    | 5/9 [26:14<20:56, 314.08s/it]

0.5877303670991939
5


 67%|██████▋   | 6/9 [31:34<15:47, 315.93s/it]

0.5353524513562563
5


 78%|███████▊  | 7/9 [36:48<10:30, 315.36s/it]

0.5440748562687371
6


 89%|████████▉ | 8/9 [42:05<05:15, 316.00s/it]

0.5622864804485738
10


100%|██████████| 9/9 [44:37<00:00, 297.46s/it]

0.4835903203651347
10



 11%|█         | 1/9 [05:21<42:54, 321.82s/it]

0.5232510667356967
0


 22%|██▏       | 2/9 [10:40<37:18, 319.72s/it]

0.4277935906191279
0


 33%|███▎      | 3/9 [16:01<32:01, 320.29s/it]

0.5391247885767612
0


 44%|████▍     | 4/9 [21:22<26:43, 320.60s/it]

0.5397740815463319
0


 56%|█████▌    | 5/9 [26:43<21:23, 320.76s/it]

0.48344349320085317
0


 67%|██████▋   | 6/9 [32:04<16:02, 320.89s/it]

0.47961641706701225
0


 78%|███████▊  | 7/9 [37:25<10:41, 320.84s/it]

0.49687659344760376
0


 89%|████████▉ | 8/9 [42:46<05:20, 320.99s/it]

0.5024776363846384
0


100%|██████████| 9/9 [45:17<00:00, 301.95s/it]

0.28755630633211793
0



 11%|█         | 1/9 [05:08<41:08, 308.54s/it]

0.5758247227186912
0


 22%|██▏       | 2/9 [10:25<36:34, 313.53s/it]

0.4959273199092962
0


 33%|███▎      | 3/9 [15:37<31:15, 312.61s/it]

0.541704970975854
0


 44%|████▍     | 4/9 [20:49<26:03, 312.69s/it]

0.5012515095734907
0


 56%|█████▌    | 5/9 [26:01<20:49, 312.44s/it]

0.4892898392575362
0


 67%|██████▋   | 6/9 [31:16<15:39, 313.24s/it]

0.5636140246548162
0


 78%|███████▊  | 7/9 [36:31<10:27, 313.69s/it]

0.5099135689968932
0


 89%|████████▉ | 8/9 [41:46<05:14, 314.15s/it]

0.5458432240683982
2


100%|██████████| 9/9 [44:14<00:00, 294.92s/it]

0.2582441341849833
2





In [8]:
src

{0: {'score': 0.3093970914007919, 'zero_count': 9},
 1: {'score': 0.4192834340598356, 'zero_count': 0},
 2: {'score': 0.38265299153736554, 'zero_count': 0},
 3: {'score': 0.42349761716235085, 'zero_count': 11},
 4: {'score': 0.2207173885326622, 'zero_count': 3},
 5: {'score': 0.31368881680708, 'zero_count': 18},
 6: {'score': 0.4315178522563535, 'zero_count': 35},
 7: {'score': 0.544981215352927, 'zero_count': 10},
 8: {'score': 0.47554599710112705, 'zero_count': 0},
 9: {'score': 0.4979570349266622, 'zero_count': 2}}