In [1]:
import os
import torch
import numpy as np
from torchvision.utils import save_image
from torchvision import transforms
from torch.nn.functional import one_hot
from edm2_pytorch.model import EDMPrecond  # 필요 시 경로 수정
from tqdm import tqdm

# Class names; a class's index in this list is the integer label used when sampling.
class_list = [
    'Normal',
    'Ischemic',
    'Hemorrhagic',
]

In [2]:
# Configuration values (must be identical to the ones used for training)
params = dict(
    # Input / output locations and data shape
    save_path='../../result/edm2/CT',
    model_path='../../model/edm2/CT/model_epoch_671.pt',  # adjust as needed
    image_size=256,
    inch=1,
    outch=1,
    batch_size=16,

    # EDM-related parameters
    sigma_min=0.002,
    sigma_max=80.0,
    sigma_data=0.5,
    rho=7.0,

    # Model architecture
    cdim=64,
    channel_mult=[1, 2, 4, 8, 8],
    attn_resolutions=[16],
    layers_per_block=2,
)

# Prefer the second GPU (cuda:1) as before, but only when it actually exists.
# torch.cuda.is_available() is also true on single-GPU machines, where "cuda:1"
# is an invalid device ordinal and fails on first use; fall back to cuda:0 there,
# and to the CPU when CUDA is not available at all.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda:1


In [3]:
# Load the model: rebuild the network from `params`, then restore the saved weights.
model = EDMPrecond(
    # Image shape and class conditioning
    img_resolution=params['image_size'],
    img_channels=params['inch'],
    label_dim=len(class_list),
    # Noise-level settings
    sigma_min=params['sigma_min'],
    sigma_max=params['sigma_max'],
    sigma_data=params['sigma_data'],
    # Backbone architecture
    model_type='DhariwalUNet',
    model_channels=params['cdim'],
    channel_mult=params['channel_mult'],
    channel_mult_emb=4,
    num_blocks=params['layers_per_block'],
    attn_resolutions=params['attn_resolutions'],
    dropout=0.1,
    use_fp16=False,
)
model = model.to(device)

# NOTE(review): torch.load unpickles the checkpoint file, so only load files
# from a trusted source.
model.load_state_dict(
    torch.load(params['model_path'], map_location=device)
)
# Switch to inference mode; as the cell's last expression this also displays the module.
model.eval()

EDMPrecond(
  (model): DhariwalUNet(
    (map_noise): PositionalEmbedding()
    (map_layer0): Linear()
    (map_layer1): Linear()
    (map_label): Linear()
    (enc): ModuleDict(
      (256x256_conv): Conv2d()
      (256x256_block0): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (256x256_block1): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
      )
      (128x128_down): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
        (skip): Conv2d()
      )
      (128x128_block0): UNetBlock(
        (norm0): GroupNorm()
        (conv0): Conv2d()
        (affine): Linear()
        (norm1): GroupNorm()
        (conv1): Conv2d()
        (skip): Conv2d()
      )
      (128x128_block1)

In [7]:
# Sampling helpers
def transback(x):
    """Clamp `x` to [-1, 1] and rescale it linearly to [0, 1]."""
    clipped = x.clamp(min=-1, max=1)
    shifted = clipped + 1
    return shifted * 0.5

def _karras_sigmas(num_steps, sigma_min, sigma_max, rho):
    """Return the EDM (Karras et al., 2022) noise levels for `num_steps` solver steps.

    The returned list has `num_steps + 1` plain floats: it starts at
    `sigma_max`, decreases to `sigma_min` along the rho-warped schedule, and
    ends with a trailing 0.0 so the final solver step lands on a noise-free sample.
    """
    inv_rho = 1.0 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    denom = max(num_steps - 1, 1)  # avoids 0/0 when a single step is requested
    sigmas = [(hi + (i / denom) * (lo - hi)) ** rho for i in range(num_steps)]
    sigmas.append(0.0)
    return sigmas


@torch.no_grad()
def sample_images(num_steps=18):
    """Sample a class-balanced batch from `model` and save it as a PNG grid.

    Uses `params['batch_size'] // len(class_list)` images per class (one grid
    row per class) and integrates the EDM probability-flow ODE with Euler
    steps on the Karras noise schedule defined by `params['sigma_min']`,
    `params['sigma_max']` and `params['rho']`.

    Args:
        num_steps: Number of Euler steps, i.e. model evaluations (default 18).

    Raises:
        ValueError: If `num_steps < 1` or the batch size is smaller than the
            number of classes.

    Side effects:
        Creates `params['save_path']` if needed and writes
        `sample_loaded_model.png` into it.
    """
    num_classes = len(class_list)
    num_per_class = params['batch_size'] // num_classes
    if num_per_class < 1:
        raise ValueError("params['batch_size'] must be at least len(class_list)")
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")

    # Same number of samples for every class, grouped class by class.
    label_list = [cls for cls in range(num_classes) for _ in range(num_per_class)]
    label_tensor = torch.tensor(label_list, device=device)
    class_onehot = one_hot(label_tensor, num_classes=num_classes).float()

    sigmas = _karras_sigmas(num_steps, params['sigma_min'], params['sigma_max'], params['rho'])

    # The starting latent must carry noise of std sigma_max, because that is the
    # noise level the denoiser is told about on the first step.
    z = torch.randn(len(label_tensor), params['inch'], params['image_size'], params['image_size']).to(device)
    z = z * sigmas[0]

    # Euler integration from sigma_max down to 0.
    for sigma_cur, sigma_next in tqdm(list(zip(sigmas[:-1], sigmas[1:])), desc="Sampling"):
        sigma = torch.full((z.shape[0], 1, 1, 1), sigma_cur, device=device)
        denoised = model(z, sigma, class_labels=class_onehot)
        d = (z - denoised) / sigma_cur
        # On the last step sigma_next is 0, so z becomes exactly `denoised`.
        z = z + d * (sigma_next - sigma_cur)

    samples = transback(z)
    out_path = os.path.join(params['save_path'], 'sample_loaded_model.png')
    os.makedirs(params['save_path'], exist_ok=True)
    save_image(samples, out_path, nrow=num_per_class)
    print(f"샘플 이미지 저장 완료: {out_path}")

# Run: generate the samples and write the image grid to disk
sample_images()

Sampling: 100%|██████████| 18/18 [00:02<00:00,  6.63it/s]


샘플 이미지 저장 완료: ../../result/edm2/CT/sample_loaded_model.png
