In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm
from glob import glob
import torch.optim as optim
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchinfo import summary
from PIL import Image
from edm2_pytorch.model import UNetEDM2,EMA
from edm2_pytorch.sampler import EDM2Schedule, EDM2Sampler
from edm2_pytorch.loss import edm2_loss
from edm2_pytorch.util import get_sigmas_karras_with_p


In [2]:
# Class names; the list index is used as the integer class label.
class_list = ['Normal', 'Hemorrhagic']

# Hyperparameters (settings for a quick convergence / sanity-check run).
params = {
    'image_size': 256,             # images are resized to image_size x image_size
    'lr': 1e-4,                    # Adam learning rate (too large can make training unstable)
    'batch_size': 32,              # adjust to available GPU memory
    'epochs': 10000,               # upper bound only; the run is meant to be stopped manually
    'data_path': '../../data/2D_CT/',
    'image_count': 10000,          # maximum number of images loaded PER CLASS
    'inch': 1,                     # input channels (grayscale)
    'outch': 1,                    # output channels
    'cdim': 64,                    # UNet base channel count
    'rho': 2.0,                    # passed to EDM2Schedule as `rho`
    'threshold': 0.0,              # NOTE(review): not referenced in the cells shown here
    'save_every': 10,              # preview/checkpoint interval in epochs
    'save_path': '/edm2/CT_fast',  # suffix appended to ../../result and ../../model
    'P_mean': -1.2,                # mean of log(sigma) for the training noise levels
    'P_std': 1.2,                  # std of log(sigma) for the training noise levels
    'seed': 0,                     # seed for Python `random`, NumPy and torch RNGs
}

# Seed every RNG the notebook uses so a run can be reproduced.
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


Device: cuda:0


In [3]:

# Input transform: PIL image -> float tensor in [0, 1] (ToTensor) -> [-1, 1] via (x - 0.5) / 0.5.
# Mean/std are single-channel 1-tuples, matching the grayscale ('L') images loaded in this cell.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
def transback(x):
    """Map a tensor from the normalized [-1, 1] range back to [0, 1] for saving/viewing.

    Values outside [-1, 1] are clamped before the affine map.
    """
    clipped = torch.clamp(x, min=-1, max=1)
    return clipped.add(1).mul(0.5)

# Load images. Each class directory under data_path contributes up to
# params['image_count'] .png files (sorted by filename); the class index is the label.
image_paths, image_labels = [], []
for class_idx, cname in enumerate(class_list):
    class_dir = os.path.join(params['data_path'], cname)
    paths = sorted(glob(os.path.join(class_dir, '*.png')))[:params['image_count']]
    if not paths:
        # Fail loudly: an empty class (typically a wrong data_path) would otherwise give an
        # empty or one-class dataset and the training cell would silently run zero steps.
        raise FileNotFoundError(f"No .png files found for class '{cname}' in {class_dir}")
    image_paths.extend(paths)
    image_labels.extend([class_idx] * len(paths))

N = len(image_paths)
C, H, W = params['inch'], params['image_size'], params['image_size']
train_images = torch.zeros((N, C, H, W), dtype=torch.float32)

print("Loading images into tensor...")
for i, path in enumerate(tqdm(image_paths)):
    # The context manager closes the file handle deterministically.
    # convert('L') yields a single channel, so this assumes params['inch'] == 1.
    with Image.open(path) as pil_file:
        pil_img = pil_file.convert('L').resize((W, H))
    train_images[i] = trans(pil_img)

train_labels = torch.tensor(image_labels, dtype=torch.long)

# Custom Dataset over pre-loaded, already-normalized tensors.
class CustomDataset(Dataset):
    """Map-style dataset pairing pre-loaded images with integer labels.

    Args:
        images: indexable collection of images (in this notebook, a float tensor
            of shape (N, C, H, W) produced by `trans`).
        labels: indexable collection of labels; must have the same length as `images`.

    Raises:
        ValueError: if `images` and `labels` differ in length.
    """

    def __init__(self, images, labels):
        # Catch mismatches up front instead of an IndexError deep inside a DataLoader worker.
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels

    def __getitem__(self, index):
        # Items are returned as stored; no augmentation is applied.
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)

# Reshuffle every epoch and drop the trailing partial batch so every step sees a full batch.
train_dataset = CustomDataset(images=train_images, labels=train_labels)
dataloader = DataLoader(
    train_dataset,
    batch_size=params['batch_size'],
    shuffle=True,
    drop_last=True,
)


Loading images into tensor...


100%|██████████| 1649/1649 [00:11<00:00, 144.54it/s]


In [4]:
# Model initialization: class-conditional UNet, its EMA copy, and the optimizer.
model = UNetEDM2(
    in_channels=params['inch'],
    out_channels=params['outch'],
    base_channels=params['cdim'],
    class_embed_dim=128,
    num_classes=len(class_list),
).to(device)

ema = EMA(model)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=params['lr'],
    betas=(0.9, 0.999),
    weight_decay=0.0,
)

# Dummy (image, sigma, class label) batch for the model summary, allocated directly on
# `device` instead of being built on the CPU and copied over.
summary_batch = 4
image_input = torch.randn(summary_batch, params['inch'], params['image_size'], params['image_size'], device=device)
sigma_input = torch.full((summary_batch,), 10.0, device=device)
class_input = torch.randint(0, len(class_list), (summary_batch,), device=device)

# 80-level noise grid; the training cell uses len(sigmas) as the sampler's step count.
sigmas = get_sigmas_karras_with_p(n=80, P_mean=params['P_mean'], P_std=params['P_std'], device=device).to(device)

model.eval()
summary(model, input_data=(image_input, sigma_input, class_input), col_names=["input_size", "output_size", "num_params"])





Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
UNetEDM2                                 [4, 1, 256, 256]          [4, 1, 256, 256]          --
├─Sequential: 1-1                        [4, 1]                    [4, 128]                  --
│    └─Linear: 2-1                       [4, 1]                    [4, 128]                  256
│    └─SiLU: 2-2                         [4, 128]                  [4, 128]                  --
│    └─Linear: 2-3                       [4, 128]                  [4, 128]                  16,512
├─Embedding: 1-2                         [4]                       [4, 128]                  256
├─Conv2d: 1-3                            [4, 1, 256, 256]          [4, 64, 256, 256]         640
├─ResidualBlock: 1-4                     [4, 64, 256, 256]         [4, 64, 256, 256]         --
│    └─GroupNorm: 2-4                    [4, 64, 256, 256]         [4, 64, 256, 256]         128
│    └─SiLU: 2-5           

In [5]:
# Training loop. After the first epoch and then every `save_every` epochs, an EMA preview
# grid and a checkpoint are written; file names use the 1-based epoch (1, 11, 21, ...).
result_path = '../../result' + params['save_path']
model_path = '../../model' + params['save_path']

# Preview labels cycle through the classes. Kept separate from the training `lab` so that
# `img`/`lab` left over from the last batch always belong together.
num_preview = 4
preview_labels = torch.tensor([i % len(class_list) for i in range(num_preview)], device=device)
preview_shape = (num_preview, params['outch'], params['image_size'], params['image_size'])

for epc in range(params['epochs']):
    model.train()
    total_loss = 0.0
    steps = 0

    with tqdm(dataloader, dynamic_ncols=True) as pbar:
        for img, lab in pbar:
            img, lab = img.to(device), lab.to(device)
            optimizer.zero_grad()

            # Per-sample noise level: log(sigma) ~ N(P_mean, P_std^2).
            sigma = (torch.randn(img.shape[0], device=img.device) * params['P_std'] + params['P_mean']).exp()

            # NOTE(review): sigma_data=1.0 while inputs are normalized to [-1, 1]; confirm this
            # matches edm2_loss's preconditioning (EDM commonly uses 0.5 for such data).
            loss = edm2_loss(model, img, sigma, lab, cfg_drop_prob=0.1, sigma_data=1.0)
            loss.backward()
            optimizer.step()
            ema.update(model)

            total_loss += loss.item()
            steps += 1
            pbar.set_postfix({
                'epoch': epc + 1,
                'loss': total_loss / steps,
                'lr': optimizer.param_groups[0]['lr'],
            })

    if steps == 0:
        # With drop_last=True a dataset smaller than batch_size yields no batches at all.
        raise RuntimeError("dataloader yielded no batches; check data_path / batch_size")

    if epc % params['save_every'] == 0:
        ema.ema_model.eval()
        with torch.no_grad():
            samples = EDM2Sampler(
                model=ema.ema_model,
                schedule=EDM2Schedule(steps=len(sigmas), P_mean=params['P_mean'], P_std=params['P_std'], rho=params['rho'], device=device),
                sampler_type='euler',
                cfg_scale=3.0,
                device=device,
            ).sample(preview_shape, class_labels=preview_labels)

            # Assumes samples come out in the training-data range ([-1, 1] after `trans`), so they
            # get the same transback as the two real reference images; save_image clips values
            # outside [0, 1], which would otherwise black out the negative half of the samples.
            preview = torch.cat([transback(samples), transback(img[:2])], dim=0)

        os.makedirs(result_path, exist_ok=True)
        os.makedirs(model_path, exist_ok=True)

        save_image(preview, f'{result_path}/generated_{epc+1}_pict.png', nrow=max(1, num_preview // len(class_list)))
        torch.save({
            'epoch': epc + 1,
            'model': model.state_dict(),
            'ema': ema.ema_model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, f'{model_path}/ckpt_{epc+1}.pt')
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

100%|██████████| 51/51 [00:20<00:00,  2.50it/s, epoch=1, loss=0.405, lr=0.0001]
Sampling (euler): 100%|██████████| 79/79 [00:03<00:00, 22.15it/s]
100%|██████████| 51/51 [00:20<00:00,  2.43it/s, epoch=2, loss=0.159, lr=0.0001]
100%|██████████| 51/51 [00:20<00:00,  2.53it/s, epoch=3, loss=0.127, lr=0.0001]
100%|██████████| 51/51 [00:20<00:00,  2.48it/s, epoch=4, loss=0.11, lr=0.0001] 
100%|██████████| 51/51 [00:19<00:00,  2.56it/s, epoch=5, loss=0.103, lr=0.0001] 
100%|██████████| 51/51 [00:20<00:00,  2.55it/s, epoch=6, loss=0.0996, lr=0.0001]
 10%|▉         | 5/51 [00:02<00:21,  2.13it/s, epoch=7, loss=0.0999, lr=0.0001]


KeyboardInterrupt: 

In [None]:
# Sanity check at sigma = 1: MSE between the model output and the injected noise.
# Draws its own batch from `dataloader` rather than reusing `img`/`lab` left over from the
# training loop, so it works on a fresh kernel and images/labels always match.
model.eval()
with torch.no_grad():
    batch_img, batch_lab = next(iter(dataloader))
    clean_img = batch_img[:4].to(device)
    labels = batch_lab[:4].to(device)
    sigma = torch.full((clean_img.shape[0],), 1.0, device=device)
    noise = torch.randn_like(clean_img)
    noised_img = clean_img + sigma[:, None, None, None] * noise

    pred_noise = model(noised_img, sigma, labels)
    mse = ((pred_noise - noise) ** 2).mean()
    print("Noise prediction MSE:", mse.item())

Noise prediction MSE: 0.035668447613716125


In [None]:
# At sigma = 1, check whether the model output is closer to the injected noise or to the
# clean image. The model is called without class labels — TODO confirm this hits the
# intended unconditional path. Uses a fresh batch instead of `img` leaked from training.
model.eval()
with torch.no_grad():
    batch_img, _ = next(iter(dataloader))
    clean = batch_img[:4].to(device)
    sigma = torch.ones(clean.shape[0], device=device)
    noise = torch.randn_like(clean)
    noised = clean + sigma[:, None, None, None] * noise
    pred = model(noised, sigma)

    print("Diff to noise:", ((pred - noise) ** 2).mean().item())
    print("Diff to x:", ((pred - clean) ** 2).mean().item())

Diff to noise: 0.02724508009850979
Diff to x: 1.5636131763458252


In [None]:
# Visual denoising check at sigma = 1, saved as a 4x4 grid. Rows: clean, noised, model output,
# and the reconstruction noised - sigma * pred (which assumes the output is a noise estimate).
model.eval()  # set explicitly instead of relying on an earlier cell having done it
with torch.no_grad():
    batch_img, _ = next(iter(dataloader))
    clean = batch_img[:4].to(device)
    sigma = torch.ones(clean.shape[0], device=device)
    noise = torch.randn_like(clean)
    noised = clean + sigma[:, None, None, None] * noise
    pred = model(noised, sigma)
    reconstruction = noised - sigma[:, None, None, None] * pred

    # All four rows share the same [-1, 1] -> [0, 1] mapping; save_image clips values outside
    # [0, 1], so an unmapped reconstruction would lose its negative half.
    save_image(
        torch.cat([transback(clean), transback(noised), transback(pred), transback(reconstruction)], dim=0),
        'denoise_debug.png',
        nrow=4,
    )