In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm
from glob import glob
import torch.optim as optim
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchinfo import summary
from PIL import Image
from edm2_pytorch.model import SongUNet, DhariwalUNet, VPPrecond, VEPrecond, iDDPMPrecond, EDMPrecond



In [None]:
# 클래스 정의
class_list = ['Normal','Ischemic','Hemorrhagic']

params = {
    # 데이터 설정
    'data_path': '../../data/2D_CT/',
    'image_count': 10000,
    'image_size': 64,
    'inch': 1,
    'outch': 1,

    # 학습 설정
    'lr': 1e-4,
    'batch_size': 32,
    'epochs': 10000,
    'save_every': 10,
    'save_path': '../../result/edm2/CT_fast',

    # EDM 샘플링 관련
    'P_mean': -1.2,
    'P_std': 1.2,
    'rho': 7.0,
    'sigma_min': 0.002,
    'sigma_max': 80.0,
    'sigma_data': 0.5,
    'threshold': 0.0,

    # 모델 구조
    'cdim': 64,                        # base channels
    'channel_mult': [1, 2, 2, 2],      # 채널 증가 비율
    'attn_resolutions': [],           # self-attention이 들어갈 해상도 (예: [16])
    'layers_per_block': 4             # 각 레벨마다 residual block 수
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


Device: cuda:0


In [7]:

# 변환 정의
# trans = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
# ])
trans = transforms.Compose([
    transforms.ToTensor()
])
def transback(x):
    return (x.clamp(-1, 1) + 1) * 0.5

# 이미지 로드
image_paths, image_labels = [], []
for i, cname in enumerate(class_list):
    paths = sorted(glob(os.path.join(params['data_path'], cname, '*.png')))[:params['image_count']]
    image_paths.extend(paths)
    image_labels.extend([i] * len(paths))

N = len(image_paths)
C, H, W = params['inch'], params['image_size'], params['image_size']
train_images = torch.zeros((N, C, H, W), dtype=torch.float32)

print("Loading images into tensor...")
for i, path in enumerate(tqdm(image_paths)):
    img = Image.open(path).convert('L').resize((W, H))
    
    train_images[i] = trans(img)
train_images=train_images*2-1.
train_labels = torch.tensor(image_labels, dtype=torch.long)

# 커스텀 Dataset
class CustomDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, index):
        img = self.images[index]
        lab = self.labels[index]
        # if random.random() > 0.5:
        #     img = transforms.functional.hflip(img)
        # if random.random() > 0.5:
        #     img = transforms.functional.vflip(img)
        return img, lab

    def __len__(self):
        return len(self.images)

# DataLoader
train_dataset = CustomDataset(train_images, train_labels)
dataloader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True,drop_last=True)


Loading images into tensor...


100%|██████████| 2849/2849 [00:15<00:00, 186.13it/s]


In [5]:
# 모델 초기화
model = EDMPrecond(
    img_resolution=params['image_size'],
    img_channels=params['inch'],
    label_dim=len(class_list),
    use_fp16=False,
    sigma_min=params['sigma_min'],
    sigma_max=params['sigma_max'],
    sigma_data=params['sigma_data'],
    model_type='DhariwalUNet',  # 또는 'SongUNet'
    model_channels=params['cdim'],
    channel_mult=params['channel_mult'],
    channel_mult_emb=4,
    num_blocks=params['layers_per_block'],
    attn_resolutions=params['attn_resolutions'],
    dropout=0.1,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=params['lr'])

# 모델 요약
summary(
    model,
    input_data=(
        torch.randn(1, params['inch'], params['image_size'], params['image_size']).to(device),  # noised input
        torch.tensor([params['sigma_data']], device=device),  # sigma
        torch.nn.functional.one_hot(torch.tensor([0]), num_classes=len(class_list)).float().to(device)  # dummy class label
    ),
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
    depth=4,
    verbose=1
)



Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
EDMPrecond                               [1, 1, 64, 64]            [1, 1, 64, 64]            --                        --
├─DhariwalUNet: 1-1                      [1, 1, 64, 64]            [1, 1, 64, 64]            --                        --
│    └─PositionalEmbedding: 2-1          [1]                       [1, 64]                   --                        --
│    └─Linear: 2-2                       [1, 64]                   [1, 256]                  16,640                    --
│    └─Linear: 2-3                       [1, 256]                  [1, 256]                  65,792                    --
│    └─Linear: 2-4                       [1, 3]                    [1, 256]                  768                       --
│    └─ModuleDict: 2-5                   --                        --                        --                        --
│    │    └─Co

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
EDMPrecond                               [1, 1, 64, 64]            [1, 1, 64, 64]            --                        --
├─DhariwalUNet: 1-1                      [1, 1, 64, 64]            [1, 1, 64, 64]            --                        --
│    └─PositionalEmbedding: 2-1          [1]                       [1, 64]                   --                        --
│    └─Linear: 2-2                       [1, 64]                   [1, 256]                  16,640                    --
│    └─Linear: 2-3                       [1, 256]                  [1, 256]                  65,792                    --
│    └─Linear: 2-4                       [1, 3]                    [1, 256]                  768                       --
│    └─ModuleDict: 2-5                   --                        --                        --                        --
│    │    └─Co

In [10]:
for epoch in range(1, params['epochs'] + 1):
    model.train()
    total_loss = 0.0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{params['epochs']}")
    for step, (imgs, labels) in enumerate(pbar, start=1):
        imgs, labels = imgs.to(device), labels.to(device)

        # EDM 논문 공식 σ 샘플링 방식 (log-normal)
        rnd_normal = torch.randn([imgs.shape[0]], device=imgs.device)
        sigmas = (params['sigma_data'] ** 2 + (rnd_normal * params['P_std'] + params['P_mean']).exp() ** 2).sqrt()

        # 노이즈 추가
        noise = torch.randn_like(imgs)
        noised = imgs + sigmas.view(-1, 1, 1, 1) * noise

        # 클래스 one-hot encoding
        class_onehot = torch.nn.functional.one_hot(labels, num_classes=len(class_list)).float()

        # 모델 forward 및 손실 계산
        denoised = model(noised, sigmas, class_labels=class_onehot)
        target = imgs
        loss = ((denoised - target) ** 2).mean()

        # 역전파 및 업데이트
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        avg_loss = total_loss / step
        pbar.set_postfix(loss=f"{avg_loss:.4f}")

    # 주기적으로 샘플 저장
    if epoch % params['save_every'] == 0:
        model.eval()
        with torch.no_grad():
            z = torch.randn(params['batch_size'], params['inch'], params['image_size'], params['image_size']).to(device)
            rand_labels = torch.randint(0, len(class_list), (params['batch_size'],), device=device)
            class_onehot = torch.nn.functional.one_hot(rand_labels, num_classes=len(class_list)).float()

            # 샘플링 루프: Euler-style
            sigma = torch.full((params['batch_size'], 1, 1, 1), params['sigma_max'], device=device)
            for _ in range(18):
                denoised = model(z, sigma, class_labels=class_onehot)
                d = (z - denoised) / sigma
                dt = -0.9 * sigma
                z = z + d * dt
                sigma = sigma + dt
                sigma = sigma.clamp(min=params['sigma_min'])

            samples = transback(z)
            save_image(samples, os.path.join(params['save_path'], f'sample_epoch_{epoch}.png'), nrow=8)
            torch.save(model.state_dict(), os.path.join(params['save_path'], f'model_epoch_{epoch}.pt'))


Epoch 1/10000:  44%|████▍     | 39/89 [00:03<00:04, 10.99it/s, loss=0.0185]


KeyboardInterrupt: 

In [None]:
img.min()

In [None]:
model.eval()
with torch.no_grad():
    for s in [sigmas[0], sigmas[4], sigmas[8], sigmas[12]]:
        sigma = torch.full((4,), s, device=device)
        clean_img = img[0:4].to(device)  # shape: [4, C, H, W]
        noise = torch.randn_like(clean_img)
        noised_img = clean_img + sigma[:, None, None, None] * noise

        pred_noise = model(noised_img, sigma,lab[0:4])
        mse = ((pred_noise - noise) ** 2).mean()
        print(f"sigma={s}, MSE={mse.item()}")

sigma=3.3201162815093994, MSE=0.011015418916940689
sigma=1.3921059370040894, MSE=0.010887052863836288
sigma=0.5159232020378113, MSE=0.02971632033586502
sigma=0.16222049295902252, MSE=0.08270673453807831


In [None]:
model.eval()
with torch.no_grad():
    sigma = torch.ones(4, device=device)
    clean = img[:4]
    noise = torch.randn_like(clean)
    noised = clean + sigma[:, None, None, None] * noise
    pred = model(noised, sigma)

    print("Diff to noise:", ((pred - noise) ** 2).mean().item())
    print("Diff to x:", ((pred - clean) ** 2).mean().item())

Diff to noise: 0.03787504509091377
Diff to x: 1.5307466983795166


In [None]:
with torch.no_grad():
    sigma = torch.ones(4, device=device)*sigmas[-1]
    clean = img[:4]
    noise = torch.randn_like(clean)
    noised = clean + sigma[:, None, None, None] * noise
    pred = model(noised, sigma)

    save_image(torch.cat([transback(clean), transback(noised), transback(pred), noised - sigma[:, None, None, None] * pred], dim=0),
               'denoise_debug.png', nrow=4)

In [None]:
x = torch.zeros(4, 1, 256, 256).to('cuda')
a = torch.randn_like(x)
b = torch.randn(4, 1, 256, 256).to('cuda')

print("a stats:", a.mean().item(), a.std().item())
print("b stats:", b.mean().item(), b.std().item())

a stats: -0.0005130228237248957 0.9990500807762146
b stats: -0.0020944587886333466 0.9990535974502563
