In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm
from glob import glob
import torch.optim as optim
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchinfo import summary
from PIL import Image
from edm2_pytorch.model import EDM2UNet,EMA,sample_with_ema,euler_sampler
from edm2_pytorch.loss import edm2_loss
from edm2_pytorch.util import get_sigmas_karras_with_p


In [2]:
# Class names; a sample's integer label is its index in this list.
class_list = ['Normal', 'Hemorrhagic']

# Hyperparameters and paths for the class-conditional training run.
params = {
    'image_size': 256,             # images are resized to image_size x image_size
    'lr': 1e-4,                    # Adam learning rate
    'batch_size': 32,              # mini-batch size (lower it if memory runs out)
    'epochs': 10000,               # number of training epochs
    'data_path': '../../data/2D_CT/',
    'image_count': 10000,          # maximum number of images loaded per class
    'inch': 1,                     # input channels (images are loaded as grayscale)
    'outch': 1,                    # channels of the generated samples
    'cdim': 64,                    # passed to EDM2UNet as `num_classes` in the model cell
    # NOTE(review): sigma_min / sigma_max / rho are not read by any cell visible here.
    'sigma_min': 0.01,
    'sigma_max': 1.0,
    'rho': 1.0,
    'threshold': 0.0,              # per-batch probability of replacing ~half of the labels with -1
    'save_every': 10,              # sample and checkpoint whenever epoch_index % save_every == 0
    'save_path': '/edm2/CT_fast',  # suffix appended to '../../result' and '../../model'
    'P_mean': -0.4,                # mean of log(sigma) when drawing training noise levels
    'P_std': 1.0,                  # std of log(sigma) when drawing training noise levels
    'seed': 42,                    # seed shared by every RNG used in the notebook
}

# Seed all RNGs the notebook draws from (label masking, shuffling, noise) so
# that Restart & Run All starts from the same random state every time.
# Some GPU kernels may still be non-deterministic.
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


Device: cuda:0


In [3]:

# Preprocessing: PIL image -> float tensor -> normalized with mean 0.5 / std 0.5.
# Mean and std are one-element tuples (one entry per channel); a bare `(0.5)`
# is just the float 0.5, whereas Normalize documents per-channel sequences.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
def transback(x):
    """Map a tensor from the normalized [-1, 1] range back to [0, 1].

    Values outside [-1, 1] are clamped first, so the result always lies in
    [0, 1]. The input tensor is left unmodified; a new tensor is returned.
    """
    clipped = x.clamp(min=-1, max=1)
    return clipped.add(1).mul(0.5)

# Collect up to `image_count` PNG paths per class (sorted for a stable order);
# each image is labelled with the index of its class in `class_list`.
image_paths, image_labels = [], []
for i, cname in enumerate(class_list):
    paths = sorted(glob(os.path.join(params['data_path'], cname, '*.png')))[:params['image_count']]
    image_paths.extend(paths)
    image_labels.extend([i] * len(paths))

N = len(image_paths)
if N == 0:
    # Fail here with an explicit message instead of letting an empty dataset
    # surface later as a less obvious error further down the notebook.
    raise FileNotFoundError(
        f"No PNG images found under {params['data_path']!r} for classes {class_list}"
    )

# Pre-allocate one tensor holding the whole dataset in memory.
C, H, W = params['inch'], params['image_size'], params['image_size']
train_images = torch.zeros((N, C, H, W), dtype=torch.float32)

print("Loading images into tensor...")
for i, path in enumerate(tqdm(image_paths)):
    # The context manager closes the file handle as soon as the pixels have
    # been converted; `convert` returns an independent in-memory image.
    with Image.open(path) as src:
        img = src.convert('L').resize((W, H))
    train_images[i] = trans(img)

train_labels = torch.tensor(image_labels, dtype=torch.long)

# Custom in-memory Dataset
class CustomDataset(Dataset):
    """Dataset over pre-loaded image and label containers held in memory.

    Args:
        images: indexable, sized container of samples (a stacked tensor in
            this notebook).
        labels: indexable, sized container with one label per sample.

    Raises:
        ValueError: if `images` and `labels` do not have the same length.
    """

    def __init__(self, images, labels):
        # A length mismatch would otherwise only show up as an IndexError
        # mid-epoch, or as silently ignored labels.
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels

    def __getitem__(self, index):
        """Return the `(image, label)` pair stored at `index`, unaugmented."""
        return self.images[index], self.labels[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

# Wrap the in-memory tensors and serve them as shuffled mini-batches.
train_dataset = CustomDataset(train_images, train_labels)
dataloader = DataLoader(
    train_dataset,
    batch_size=params['batch_size'],
    shuffle=True,
    drop_last=True,  # skip the final incomplete batch so every step sees a full batch
)


Loading images into tensor...


100%|██████████| 1649/1649 [00:11<00:00, 148.15it/s]


In [None]:

# Build the network, then move it to the training device.
# NOTE(review): `num_classes` is fed from params['cdim'] (64) although
# `class_list` has only 2 entries — confirm this is intended.
model = EDM2UNet(
    in_ch=params['inch'],
    base=64,
    cond_dim=256,
    num_classes=params['cdim'],
)
model = model.to(device)

# EMA wrapper around the model; it is the one used for sampling and it is
# checkpointed alongside the raw weights in the training cell.
ema = EMA(model)

optimizer = optim.Adam(
    model.parameters(),
    lr=params['lr'],
    betas=(0.9, 0.999),
    weight_decay=0.0,
)

# Dummy inputs (batch of 4) used only to print the model summary below.
side = params['image_size']
image_input = torch.randn(4, params['inch'], side, side).to(device)
sigma_input = torch.ones(4).to(device).mul(10.0)
class_input = torch.randint(len(class_list), (4,)).to(device)

# Noise schedule handed to the sampler in the training cell.
sigmas = get_sigmas_karras_with_p(
    n=15,
    P_mean=params['P_mean'],
    P_std=params['P_std'],
    device=device,
).to(device)

# Model summary: layer-by-layer input/output shapes and parameter counts.
model.eval()
summary(
    model,
    input_data=(image_input, sigma_input, class_input),
    col_names=["input_size", "output_size", "num_params"],
)





Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
EDM2UNet                                 [4, 1, 256, 256]          [4, 1, 256, 256]          1
├─Embedding: 1-1                         [4]                       [4, 256]                  16,384
├─Sequential: 1-2                        [4, 256]                  [4, 256]                  --
│    └─Linear: 2-1                       [4, 256]                  [4, 256]                  65,792
│    └─SiLU: 2-2                         [4, 256]                  [4, 256]                  --
│    └─Linear: 2-3                       [4, 256]                  [4, 256]                  65,792
├─Conv2d: 1-3                            [4, 1, 256, 256]          [4, 64, 256, 256]         640
├─DownBlock: 1-4                         [4, 64, 256, 256]         [4, 128, 128, 128]        --
│    └─ResBlock: 2-4                     [4, 64, 256, 256]         [4, 128, 256, 256]        --
│    │    └─GroupNorm: 

In [5]:
# Training loop
# Output locations do not change between epochs, so resolve and create them
# once instead of on every checkpoint.
result_path = '../../result' + params['save_path']
model_path = '../../model' + params['save_path']
os.makedirs(result_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)

for epc in range(params['epochs']):
    model.train()
    total_loss = 0
    steps = 0

    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for img, lab in tqdmDataLoader:
            img, lab = img.to(device), lab.to(device)
            optimizer.zero_grad()

            # One noise level per sample, log-normally distributed:
            # log(sigma) ~ N(P_mean, P_std^2).
            rnd_normal = torch.randn([img.shape[0], 1, 1, 1], device=img.device)
            sigma = (rnd_normal * params['P_std'] + params['P_mean']).exp()[:, 0, 0, 0]

            # With probability `threshold`, overwrite roughly half of this
            # batch's labels with -1.
            # NOTE(review): -1 presumably marks "no label" for edm2_loss — confirm.
            if random.random() < params['threshold']:
                mask = torch.rand(lab.shape[0], device=device) < 0.5
                lab[mask] = -1

            loss = edm2_loss(model, img, sigma, lab, cfg_drop_prob=0.1, sigma_data=0.5)
            loss.backward()
            optimizer.step()
            ema.update()  # keep the EMA weights in step with the optimizer

            total_loss += loss.item()
            steps += 1
            # Progress bar shows the running mean loss of the current epoch.
            tqdmDataLoader.set_postfix({
                'epoch': epc + 1,
                'loss': total_loss / steps,
                'lr': optimizer.param_groups[0]['lr']
            })

    # Sample and checkpoint after epochs 1, save_every + 1, 2 * save_every + 1, ...
    if epc % params['save_every'] == 0:
        ema.ema_model.eval()  # sample from the EMA weights
        with torch.no_grad():
            each_device_batch = 4 // len(class_list)
            # Use a dedicated name so the sampling labels do not overwrite the
            # training-batch variable `lab`.
            sample_labels = torch.arange(len(class_list)).repeat(each_device_batch).to(device)
            genshape = (len(sample_labels), params['outch'], params['image_size'], params['image_size'])

            samples = sample_with_ema(
                model=ema,
                shape=genshape,
                sigmas=sigmas,
                class_label=sample_labels,
                cfg_scale=1.0,
                device=device
            )
            # Append the first two real images of the last training batch so the
            # saved grid shows generated and real data together.
            samples = torch.cat([samples, img[:2]], dim=0)
            samples1 = transback(samples)

        save_image(samples1, f'{result_path}/generated_{epc+1}_pict.png', nrow=each_device_batch)
        torch.save({
            'model': model.state_dict(),
            'ema': ema.state_dict(),  # EMA weights are checkpointed as well
            'optimizer': optimizer.state_dict(),
        }, f'{model_path}/ckpt_{epc+1}.pt')
        torch.cuda.empty_cache()

100%|██████████| 51/51 [00:28<00:00,  1.78it/s, epoch=1, loss=16.1, lr=0.0001]
100%|██████████| 14/14 [00:01<00:00,  7.26it/s]
100%|██████████| 51/51 [00:28<00:00,  1.77it/s, epoch=2, loss=9.21, lr=0.0001]
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=3, loss=6.3, lr=0.0001] 
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=4, loss=9.3, lr=0.0001] 
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=5, loss=5.09, lr=0.0001]
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=6, loss=4.8, lr=0.0001] 
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=7, loss=4.09, lr=0.0001]
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=8, loss=13.5, lr=0.0001]
100%|██████████| 51/51 [00:28<00:00,  1.79it/s, epoch=9, loss=10.2, lr=0.0001]
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=10, loss=4.41, lr=0.0001]
100%|██████████| 51/51 [00:28<00:00,  1.80it/s, epoch=11, loss=4.22, lr=0.0001]
100%|██████████| 14/14 [00:01<00:00,  7.27it/s]
100%|██████████| 51/51 [00:28<00:

KeyboardInterrupt: 

In [6]:
# Debug inspection: display the `lab` tensor currently held in the kernel namespace.
lab

tensor([0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0,
        1, 1, 1, 0, 1, 1, 1, 1], device='cuda:0')