In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from glob import glob
import random
from conditionDiffusion.unet import ImprovedUnet
from conditionDiffusion.utils import get_named_beta_schedule
from conditionDiffusion.diffusion import GaussianDiffusion
from conditionDiffusion.Scheduler import GradualWarmupScheduler
from PIL import Image
import torchvision

# Report how many CUDA devices are visible, then select the first one for
# every tensor and module created in this notebook.
print(f"GPUs used:\t{torch.cuda.device_count()}")
device = torch.device("cuda:0")
print(f"Device:\t\t{device}")

# Tensor -> PIL converter, used when the sample grids are written to disk.
topilimage = transforms.ToPILImage()

def createDirectory(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    The call is best-effort: an ``OSError`` is reported on stdout rather than
    raised, so the caller keeps running even if the directory is unavailable.

    Args:
        directory: Path of the directory to create.

    Returns:
        None.
    """
    try:
        # exist_ok=True makes the call idempotent without a separate
        # os.path.exists() check, which would race with any other process
        # creating the same directory between the check and the mkdir.
        os.makedirs(directory, exist_ok=True)
    except OSError as err:
        # Include the path and the OS reason so the failure is diagnosable.
        print(f"Error: Failed to create the directory. ({directory}: {err})")

In [None]:
# Central configuration for data loading, model construction, training and
# sampling. The whole dict is also stored inside every checkpoint.
params = {
    'image_size': 512,
    'lr': 1e-5,           # NOTE(review): the earlier note here read "2e-5 -> 1e-4 (stable when used with AMP)", which does not match the value 1e-5 - confirm the intended learning rate
    'beta1': 0.9,
    'beta2': 0.999,
    'batch_size': 1,
    'epochs': 1000,
    'n_classes': None,
    'image_count': 100,
    'inch': 3,
    'mask_ch': 3,
    'modch': 64,
    'outch': 3,
    'chmul': [1, 2, 4, 8],
    'numres': 2,
    'dtype': torch.float32,
    'cdim': 10,
    'useconv': True,
    'droprate': 0.1,
    'T': 1000,
    'w': 1.8,
    'v': 0.3,
    'multiplier': 1.0,
    'threshold': 0.1,
    'ddim': True,
    'gen_n': 8,
    'use_checkpoint': True,
    'num_heads': 4,
    'ema_decay': 0.9999,
    'grad_clip': 1.0,      # max gradient norm passed to clip_grad_norm_ in the training loop
    'warmup_epochs': 100,  # explicit length of the LR warm-up phase, in epochs
}

In [None]:
# Per-channel normalisation (x - 0.5) / 0.5 for 3-channel image tensors.
trans = transforms.Compose([
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

def transback(data):
    """Invert the ``(x - 0.5) / 0.5`` normalisation, i.e. return ``data / 2 + 0.5``."""
    halved = data / 2
    return halved + 0.5

class CustomDataset(Dataset):
    """Paired dataset returning ``(image, mask)`` items with joint random flips.

    Args:
        params: Configuration object; stored on ``self.args`` and not read
            anywhere else in this class.
        images: Indexable collection of images.
        mask: Indexable collection aligned index-by-index with ``images``.
    """

    def __init__(self, params, images, mask):
        self.args = params
        self.images = images
        self.masks = mask

    def __len__(self):
        """Return the number of items, taken from ``images``."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the augmented ``(image, mask)`` pair stored at ``index``."""
        return self.trans(self.images[index], self.masks[index])

    @staticmethod
    def _flip_pair(flip, image, mask):
        """Apply the same ``flip`` transform to the image first, then the mask."""
        return flip(image), flip(mask)

    def trans(self, image, mask):
        """Randomly flip both inputs together.

        A horizontal and a vertical flip are each decided by their own
        ``random.random() > 0.5`` draw. The flip transforms are built with
        probability 1, so once chosen they are applied to both inputs and the
        pair stays spatially aligned.
        """
        if random.random() > 0.5:
            image, mask = self._flip_pair(
                transforms.RandomHorizontalFlip(1), image, mask)

        if random.random() > 0.5:
            image, mask = self._flip_pair(
                transforms.RandomVerticalFlip(1), image, mask)

        return image, mask

# Load data
# Collect the IHC tiles of both stains. glob() returns paths in arbitrary,
# filesystem-dependent order, so each listing is sorted to make the
# `image_count` subset chosen below reproducible across runs and machines.
# NOTE(review): without recursive=True the `**` component matches exactly one
# directory level - confirm that this matches the dataset layout.
ihc_image_list = sorted(glob('../../data/IHC4BC_Compressed/**/HER2/IHC/*.jpg'))
image_temp_list = sorted(glob('../../data/IHC4BC_Compressed/**/Ki67/IHC/*.jpg'))
ihc_image_list.extend(image_temp_list)

if not ihc_image_list:
    # Fail here with a clear message instead of much later inside the
    # training / sampling code, which cannot work with an empty dataset.
    raise FileNotFoundError(
        "No IHC images found under '../../data/IHC4BC_Compressed/'.")

# Keep at most `image_count` images (slicing is a no-op for shorter lists).
ihc_image_list = ihc_image_list[:params['image_count']]
# The paired HE tile lives next to the IHC tile, in an 'HE' directory.
he_image_list = [p.replace('/IHC/', '/HE/') for p in ihc_image_list]

# Built once instead of once per image.
to_tensor = transforms.ToTensor()


def load_image_tensor(path):
    """Load ``path`` as RGB, resize it to the configured square size and normalise it.

    The file handle is closed as soon as the pixel data has been read.
    """
    size = (params['image_size'], params['image_size'])
    with Image.open(path) as img:
        return trans(to_tensor(img.convert('RGB').resize(size)))


# Pre-allocated storage for the whole (small) training set.
train_ihc_image = torch.zeros((len(ihc_image_list), params['inch'],
                               params['image_size'], params['image_size']))
train_he_image = torch.zeros((len(ihc_image_list), params['inch'],
                              params['image_size'], params['image_size']))

for i, (ihc_path, he_path) in enumerate(
        tqdm(zip(ihc_image_list, he_image_list), total=len(ihc_image_list))):
    train_ihc_image[i] = load_image_tensor(ihc_path)
    train_he_image[i] = load_image_tensor(he_path)

# The dataset's "image" slot holds the IHC tiles and its "mask" slot the HE tiles.
train_dataset = CustomDataset(params, train_ihc_image, train_he_image)
# shuffle=True: without it every epoch visits the samples in the same fixed
# order (all HER2 tiles, then all Ki67 tiles).
train_dataloader = DataLoader(train_dataset, batch_size=params['batch_size'],
                              shuffle=True, num_workers=4, drop_last=True,
                              pin_memory=True)

In [None]:
# Network: its input channel count is the sum of `inch` and `mask_ch`.
unet_config = dict(
    in_ch=params['inch'] + params['mask_ch'],
    mod_ch=params['modch'],
    out_ch=params['outch'],
    ch_mul=params['chmul'],
    num_res_blocks=params['numres'],
    cdim=params['cdim'],
    use_conv=params['useconv'],
    droprate=params['droprate'],
    num_heads=params['num_heads'],
    use_checkpoint=params['use_checkpoint'],
    dtype=params['dtype'],
)
net = ImprovedUnet(**unet_config).to(device)

# Diffusion wrapper around the network, driven by a T-step beta schedule.
betas = get_named_beta_schedule(num_diffusion_timesteps=params['T'])
diffusion = GaussianDiffusion(
    dtype=params['dtype'],
    model=net,
    betas=betas,
    w=params['w'],
    v=params['v'],
    device=device,
)

# AdamW over the parameters of the model held by the diffusion wrapper.
adam_betas = (params['beta1'], params['beta2'])
optimizer = optim.AdamW(
    diffusion.model.parameters(),
    lr=params['lr'],
    betas=adam_betas,
    weight_decay=2e-5,
)

# Cosine annealing spans only the epochs that remain after warm-up and never
# goes below eta_min.
cosine_epochs = params['epochs'] - params['warmup_epochs']
cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=cosine_epochs,
    eta_min=1e-6,
)

# Warm-up scheduler with an explicit warm-up length; the cosine schedule is
# handed over as its `after_scheduler`.
warmUpScheduler = GradualWarmupScheduler(
    optimizer=optimizer,
    multiplier=params['multiplier'],
    warm_epoch=params['warmup_epochs'],
    after_scheduler=cosineScheduler,
    last_epoch=0,
)

from copy import deepcopy

# Independent copy of the network that receives the exponential moving
# average (EMA) of the trained weights via update_ema(); it is put in eval
# mode here and is swapped in for sampling in the training loop.
ema_model = deepcopy(diffusion.model)
ema_model.eval()

def update_ema(ema_model, model, decay=0.9999):
    """Blend the current weights of ``model`` into ``ema_model`` in place.

    Every parameter is updated as ``ema = decay * ema + (1 - decay) * param``.
    Buffers (for example normalisation running statistics) are not averaged;
    they are copied verbatim so the EMA copy never keeps stale initial values.

    Args:
        ema_model: Module holding the averaged weights; modified in place.
        model: Module currently being trained; only read.
        decay: Weight given to the previous EMA value, between 0 and 1.

    Returns:
        None.
    """
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(decay).add_(param.detach(), alpha=1 - decay)
        # Parameters alone do not cover a module's full state.
        for ema_buf, buf in zip(ema_model.buffers(), model.buffers()):
            ema_buf.copy_(buf)

In [None]:
# Loss scaler for mixed-precision training.
scaler = torch.cuda.amp.GradScaler()

# Gradient-norm statistics: one entry per optimisation step.
grad_norm_history = []

# Number of EMA updates performed so far; drives the decay warm-up below.
ema_updates = 0

for epc in range(params['epochs']):
    diffusion.model.train()
    total_loss = 0
    steps = 0

    with tqdm(train_dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for img, mask in tqdmDataLoader:
            b = img.shape[0]

            # The dataset yields (IHC, HE): IHC is the condition and HE is the
            # image passed to the diffusion loss as x_0.
            cond_ihc = img.to(device)
            x_0 = mask.to(device)

            # cemb passed for classifier-free guidance (all zeros here).
            cemb = torch.zeros((b, params['cdim']), device=device)

            # Conditioning-image dropout: with probability `threshold` the
            # condition of a sample is replaced by zeros.
            cond_input = cond_ihc.clone()
            drop_idx = (torch.rand(b, device=device) < params['threshold'])
            if drop_idx.any():
                cond_input[drop_idx] = 0

            optimizer.zero_grad(set_to_none=True)

            # Forward pass under autocast (AMP).
            with torch.cuda.amp.autocast():
                loss = diffusion.trainloss(x_0, mask=cond_input, cemb=cemb)

            scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping so the threshold is
            # applied to their true magnitude.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                diffusion.model.parameters(),
                params['grad_clip']
            )
            # Read the norm back once and reuse it for history and display.
            grad_norm_value = grad_norm.item()
            grad_norm_history.append(grad_norm_value)

            scaler.step(optimizer)
            scaler.update()

            # EMA update on every step, with a warm-up on the decay.
            # Updating only every 10th step with decay 0.9999 gives about
            # 10 updates per epoch for 100 images at batch size 1, so after
            # 1000 epochs 0.9999 ** 10000 ~= 0.37 of the EMA weights would
            # still be the random initialisation. The warm-up lets the EMA
            # follow the model early on and converges to `ema_decay`.
            ema_updates += 1
            ema_decay = min(params['ema_decay'],
                            (1 + ema_updates) / (10 + ema_updates))
            update_ema(ema_model, diffusion.model, ema_decay)

            steps += 1
            total_loss += loss.item()

            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epc + 1,
                    "loss": f"{total_loss / steps:.4f}",
                    "grad_norm": f"{grad_norm_value:.4f}",
                    # Read the LR directly; optimizer.state_dict() would
                    # rebuild the full optimizer state on every step.
                    "LR": f"{optimizer.param_groups[0]['lr']:.2e}",
                    "scale": f"{scaler.get_scale():.0f}"
                }
            )

    warmUpScheduler.step()

    # Sample with the EMA weights. try/finally guarantees the trained model
    # is put back even if sampling or saving fails; otherwise a re-run of
    # this cell would silently keep training on the EMA copy.
    original_model = diffusion.model
    diffusion.model = ema_model
    diffusion.model.eval()
    try:
        with torch.no_grad():
            gen_n = min(params['gen_n'], train_ihc_image.shape[0])
            cond_for_gen = train_ihc_image[:gen_n].to(device)

            # Same all-zero cemb as during training.
            cemb_gen = torch.zeros((gen_n, params['cdim']), device=device)

            genshape = (gen_n, 3, params['image_size'], params['image_size'])

            if params['ddim']:
                generated = diffusion.ddim_sample(
                    genshape, 50, 0.5, 'quadratic',
                    mask=cond_for_gen, cemb=cemb_gen
                )
            else:
                generated = diffusion.sample(
                    genshape,
                    mask=cond_for_gen, cemb=cemb_gen
                )

            # Save images. Clamp to [0, 1] first: ToPILImage multiplies float
            # tensors by 255 and casts to uint8, so out-of-range values would
            # wrap around instead of saturating.
            generated = transback(generated).clamp(0, 1)
            cond_for_gen = transback(cond_for_gen).clamp(0, 1)

            # One row per sample: condition on the left, generation on the right.
            concatenated_images = torch.cat([
                torch.cat([cond_for_gen[i].cpu(), generated[i].cpu()], dim=2)
                for i in range(gen_n)
            ], dim=1)

            img_pil = topilimage(concatenated_images)
            createDirectory('../../results/IHC2HE/condition_diffusion/membrane')
            img_pil.save(
                f'../../results/IHC2HE/condition_diffusion/membrane/epc_{epc+1}_samples.png'
            )
    finally:
        # Restore the model that is being trained.
        diffusion.model = original_model

    checkpoint = {
        'epoch': epc + 1,
        'net': diffusion.model.state_dict(),
        'ema_net': ema_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': warmUpScheduler.state_dict(),
        'scaler': scaler.state_dict(),
        'params': params
    }
    createDirectory('../../model/IHC2HE/condition_diffusion/membrane/')
    torch.save(
        checkpoint,
        f'../../model/IHC2HE/condition_diffusion/membrane/ckpt_{epc+1}.pt'
    )
    print(f"\nüíæ Checkpoint saved at epoch {epc+1}")

print("‚úÖ Training completed!")