In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm
from glob import glob
import torch.optim as optim
import random
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchinfo import summary
from PIL import Image
from edm2_pytorch.model import EDM2UNet
from edm2_pytorch.loss import edm2_loss
from edm2_pytorch.sampler import edm2_sample
from edm2_pytorch.util import get_sigmas_karras


In [None]:


# Class definitions: the position of a name in this list is its integer label.
class_list = ['Normal', 'Hemorrhagic']

# Hyperparameter settings shared by every later cell.
params = dict(
    # data
    image_size=64,
    lr=2e-5,
    batch_size=64,
    epochs=20000,
    data_path='../../data/2D_CT/',
    image_count=5000,
    # model channels / conditioning
    inch=1,
    outch=1,
    cdim=len(class_list),
    # noise schedule
    sigma_min=0.01,
    sigma_max=80.0,
    rho=3,
    # probability that a training batch gets label dropout applied
    threshold=0.0,
    # sampling / checkpoint cadence (in epochs) and output sub-directory
    save_every=5,
    save_path='/edm2/CT',
)

# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device: {device}")


In [None]:

# Transform definition: convert the image to a tensor, then normalise the single
# channel with mean 0.5 and std 0.5 (i.e. (x - 0.5) / 0.5).
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

def transback(data: torch.Tensor, mean: float = 0.5, std: float = 0.5) -> torch.Tensor:
    """Undo a ``Normalize(mean, std)`` step by computing ``data * std + mean``.

    With the default arguments this inverts the ``Normalize((0.5,), (0.5,))``
    used by ``trans``; pass other values if the normalisation constants change.

    Args:
        data: Normalised tensor (any shape).
        mean: Mean that was subtracted during normalisation.
        std: Standard deviation that the data was divided by.

    Returns:
        A new tensor with the normalisation reverted. No clamping is applied.
    """
    return data * std + mean

# Load images
# Collect up to params['image_count'] PNG paths per class (sorted, so the selection
# is deterministic); an image's label is the index of its class in class_list.
image_paths, image_labels = [], []
for i, cname in enumerate(class_list):
    class_dir = os.path.join(params['data_path'], cname)
    paths = sorted(glob(os.path.join(class_dir, '*.png')))[:params['image_count']]
    if not paths:
        # An empty class would otherwise silently shrink the dataset to fewer classes.
        print(f"Warning: no PNG images found in {class_dir}")
    image_paths.extend(paths)
    image_labels.extend([i] * len(paths))

N = len(image_paths)
if N == 0:
    # Fail here with a clear message instead of later, when the DataLoader is built.
    raise FileNotFoundError(
        f"No PNG images found under {params['data_path']} for classes {class_list}"
    )

C, H, W = params['inch'], params['image_size'], params['image_size']
# The whole training set is kept in one pre-allocated (N, C, H, W) float tensor.
train_images = torch.zeros((N, C, H, W))

print("Loading images into tensor...")
for i, path in enumerate(tqdm(image_paths)):
    # The context manager releases the file handle as soon as the pixels are read.
    with Image.open(path) as src:
        # Convert to 8-bit grayscale and resize to the training resolution.
        img = src.convert('L').resize((W, H))
    train_images[i] = trans(img)

train_labels = torch.tensor(image_labels, dtype=torch.long)

# Custom Dataset
class CustomDataset(Dataset):
    """In-memory image/label dataset that applies random flips on every access."""

    def __init__(self, images, labels):
        # Both containers are indexed with the same integer in __getitem__.
        self.images, self.labels = images, labels

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    @staticmethod
    def _heads():
        """Draw once from Python's ``random`` module; True when the draw exceeds 0.5."""
        return random.random() > 0.5

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` after the optional flips.

        Two independent draws are made, in this order: one for a horizontal
        flip and one for a vertical flip.
        """
        sample = self.images[index]
        target = self.labels[index]
        if self._heads():
            sample = transforms.functional.hflip(sample)
        if self._heads():
            sample = transforms.functional.vflip(sample)
        return sample, target

# DataLoader: wrap the in-memory tensors in the augmenting dataset and serve them
# as shuffled mini-batches of params['batch_size'] items.
train_dataset = CustomDataset(images=train_images, labels=train_labels)
dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=params['batch_size'],
    shuffle=True,
)


In [None]:

# Model initialisation
model = EDM2UNet(
    in_ch=params['inch'],
    base=64,
    cond_dim=256,
    num_classes=params['cdim']
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=1e-4)

# Model summary
# Dummy batch of 4 items, used only to trace the model for the summary table.
# The channel count follows params['inch'] so it stays in sync with the model's in_ch.
image_input = torch.randn(4, params['inch'], params['image_size'], params['image_size']).to(device)
sigma_input = torch.ones(4).to(device) * 10.0
class_input = torch.randint(0, len(class_list), (4,)).to(device)
model.eval()
# verbose=1 forces the table to be printed: torchinfo documents a default of 0
# (silent) inside Jupyter/Colab, and this call is not the cell's last expression,
# so without it the summary would never be shown.
summary(
    model,
    input_data=(image_input, sigma_input, class_input),
    col_names=["input_size", "output_size", "num_params"],
    verbose=1,
)

# Noise schedule: 50 discrete sigma levels between sigma_min and sigma_max.
sigmas = get_sigmas_karras(n=50, sigma_min=params['sigma_min'], sigma_max=params['sigma_max'], rho=params['rho']).to(device)



In [None]:
# Training loop
# Each epoch makes one pass over `dataloader`. On every `save_every`-th epoch, and
# on the last epoch, a batch of images is sampled and a checkpoint is written.
for epc in range(params['epochs']):
    model.train()
    total_loss = 0
    steps = 0

    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for img, lab in tqdmDataLoader:
            img, lab = img.to(device), lab.to(device)
            optimizer.zero_grad()

            # One noise level per sample, drawn uniformly from the discrete schedule.
            sigma_idx = torch.randint(0, len(sigmas), (img.shape[0],), device=device)
            sigma = sigmas[sigma_idx]

            # Label dropout: with probability params['threshold'] per batch, about
            # half of the labels in the batch are replaced by -1.
            # NOTE(review): -1 is presumably the "no class" label for classifier-free
            # guidance — confirm against edm2_loss / EDM2UNet. With threshold == 0.0
            # this branch never runs although sampling below uses cfg_scale=2.0.
            if random.random() < params['threshold']:
                mask = torch.rand(lab.shape[0], device=device) < 0.5
                lab[mask] = -1

            loss = edm2_loss(model, img, sigma, lab)
            loss.backward()
            optimizer.step()

            # Progress bar shows the running mean loss of the current epoch.
            total_loss += loss.item()
            steps += 1
            tqdmDataLoader.set_postfix({
                'epoch': epc + 1,
                'loss': total_loss / steps,
                'lr': optimizer.param_groups[0]['lr']
            })

    # Sample & Save
    # `epc % save_every == 0` alone never fires on the final epoch unless
    # (epochs - 1) is a multiple of save_every, which would leave the fully trained
    # weights unsaved; the last epoch is therefore always included.
    is_last_epoch = epc == params['epochs'] - 1
    if epc % params['save_every'] == 0 or is_last_epoch:
        model.eval()
        with torch.no_grad():
            # Same number of samples per class; labels cycle 0, 1, ..., 0, 1, ...
            each_device_batch = params['batch_size'] // len(class_list)
            lab = torch.arange(len(class_list)).repeat(each_device_batch).to(device)
            genshape = (len(lab), params['outch'], params['image_size'], params['image_size'])

            samples = edm2_sample(
                model=model,
                shape=genshape,
                sigmas=sigmas,
                class_label=lab,
                cfg_scale=2.0,
                device=device
            )
            samples = transback(samples)

        result_path = '../../result' + params['save_path']
        model_path = '../../model' + params['save_path']
        os.makedirs(result_path, exist_ok=True)
        os.makedirs(model_path, exist_ok=True)

        # nrow is the number of images per grid row.
        # NOTE(review): because the labels alternate, each row mixes classes —
        # confirm whether one row per class was intended.
        save_image(samples, f'{result_path}/generated_{epc+1}_pict.png', nrow=each_device_batch)
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, f'{model_path}/ckpt_{epc+1}.pt')
        torch.cuda.empty_cache()

In [None]:
# Summary statistics of the most recently generated batch (`samples`).
# The original "mean:" line printed the standard deviation; both are reported now.
print("Sample stats:")
print("min:", samples.min().item())
print("max:", samples.max().item())
print("mean:", samples.mean().item())
print("std:", samples.std().item())

In [None]:
# Turn the first item of `img` (the batch left over from the training loop) into an
# 8-bit array: undo the normalisation, move channels last, scale by 255.
first_image = transback(img[0])
channels_last = first_image.permute(1, 2, 0).cpu().numpy()
np_img = (channels_last * 255).astype('uint8')

# Convert channel 0 to a PIL image; as the cell's last expression it is displayed.
Image.fromarray(np_img[:, :, 0], mode='L')

In [None]:
# Show the dimensions of the uint8 array built in the previous cell.
np_img.shape