In [3]:
import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from torch.utils.data.distributed import DistributedSampler
import random
from conditionDiffusion.unet import Unet
from conditionDiffusion.utils import get_named_beta_schedule
# Note: using channel-concat IHC conditioning (IHC image -> condition); no label-based ConditionalEmbedding used here.
from conditionDiffusion.diffusion import GaussianDiffusion
from conditionDiffusion.Scheduler import GradualWarmupScheduler
from PIL import Image
import torchvision
import torch.nn as nn
print(f"GPUs used:\t{torch.cuda.device_count()}")
device = torch.device("cuda",0)
print(f"Device:\t\t{device}")
topilimage = torchvision.transforms.ToPILImage()
def createDirectory(directory):
    """_summary_
        create Directory
    Args:
        directory (string): file_path
    """
    try:
        if not os.path.exists(directory):
            os.makedirs(directory)
    except OSError:
        print("Error: Failed to create the directory.")


GPUs used:	2
Device:		cuda:0


In [2]:
params={'image_size':512,
        'lr':2e-5,
        'beta1':0.5,
        'beta2':0.999,
        'batch_size':1,
        'epochs':1000,
        'n_classes':None,
        'image_count':20000,
        'inch':3,
        'mask_ch': 3,  # number of channels in conditioning IHC images (default same as image channels)
        'modch':128,
        'outch':3,
        'chmul':[1,2,4,8],
        'numres':1,
        'dtype':torch.float32,
        'cdim':10,
        'useconv':False,
        'droprate':0.1,
        'T':1000,
        'w':1.8,
        'v':0.3,
        'multiplier':1,
        'threshold':0.1,
        'ddim':True,
        'gen_n':8,  # number of IHC exemplars to sample for generation
        }

In [3]:
trans = transforms.Compose([
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

def transback(data:Tensor) -> Tensor:
    return data / 2 + 0.5

class CustomDataset(Dataset):
    """IHC->HE paired dataset.

    Returns (ihc_image, he_image, label) for each index.
    """
    def __init__(self,parmas, images,mask,label):
        
        self.images = images
        self.masks = mask
        self.args=parmas
        self.label=label
        
    def trans(self,image,mask):
        if random.random() > 0.5:
            transform = transforms.RandomHorizontalFlip(1)
            image = transform(image)
            mask = transform(mask)
            
        if random.random() > 0.5:
            transform = transforms.RandomVerticalFlip(1)
            image = transform(image)
            mask = transform(mask)
        return image,mask
    
    def __getitem__(self, index):
        image=self.images[index]
        label=self.label[index]
        mask=self.masks[index]
        image,mask = self.trans(image,mask)
        return image,mask,label
    
    def __len__(self):
        return len(self.images)


image_label=[]
image_path=[]
ihc_image_list=glob('../../data/IHC4BC_Compressed/**/HER2/IHC/*.jpg')
image_temp_list=glob('../../data/IHC4BC_Compressed/**/Ki67/IHC/*.jpg')
ihc_image_list.extend(image_temp_list)
he_image_list=[p.replace('/IHC/','/HE/') for p in ihc_image_list]
if len(ihc_image_list)>params['image_count']:
    ihc_image_list=ihc_image_list[:params['image_count']]
    he_image_list=he_image_list[:params['image_count']]

train_ihc_image=torch.zeros((len(ihc_image_list),params['inch'],params['image_size'],params['image_size']))
train_he_image=torch.zeros((len(ihc_image_list),params['inch'],params['image_size'],params['image_size']))

for i in tqdm(range(len(ihc_image_list))):
    train_ihc_image[i]=trans(transforms.ToTensor()(Image.open(ihc_image_list[i]).convert('RGB').resize((params['image_size'],params['image_size']))))
    train_he_image[i]=trans(transforms.ToTensor()(Image.open(he_image_list[i]).convert('RGB').resize((params['image_size'],params['image_size']))))
    
train_dataset=CustomDataset(params,train_ihc_image,train_he_image,image_label)
train_dataloader=DataLoader(train_dataset,batch_size=params['batch_size'],num_workers=4,drop_last=True)

KeyboardInterrupt: 

In [None]:
net = Unet(in_ch = params['inch'] + params['mask_ch'],
            mod_ch = params['modch'],
            out_ch = params['outch'],
            ch_mul = params['chmul'],
            num_res_blocks = params['numres'],
            cdim = params['cdim'],
            use_conv = params['useconv'],
            droprate = params['droprate'],
            dtype = params['dtype']
            ).to(device)
# We will sample conditioning IHC images from `train_ihc_image` at generation time (no persistent cond tensor list needed).
betas = get_named_beta_schedule(num_diffusion_timesteps = params['T'])
diffusion = GaussianDiffusion(
                    dtype = params['dtype'],
                    model = net,
                    betas = betas,
                    w = params['w'],
                    v = params['v'],
                    device = device
                )
optimizer = torch.optim.AdamW(
                diffusion.model.parameters(),
                lr = params['lr'],
                weight_decay = 1e-4
            )


cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
                            optimizer = optimizer,
                            T_max = params['epochs']/100,
                            eta_min = 0,
                            last_epoch = -1
                        )
warmUpScheduler = GradualWarmupScheduler(
                        optimizer = optimizer,
                        multiplier = params['multiplier'],
                        warm_epoch = params['epochs'] // 10,
                        after_scheduler = cosineScheduler,
                        last_epoch = 0
                    )



from pytorch_model_summary import summary
print(summary(net, torch.ones(1, params['inch'] + params['mask_ch'], params['image_size'], params['image_size']), max_depth=3))

TypeError: summary() got an unexpected keyword argument 'inputs'

In [None]:
scaler = torch.cuda.amp.GradScaler()
for epc in range(params['epochs']):
    diffusion.model.train()
    total_loss = 0
    steps = 0
    with tqdm(train_dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for img, mask, lab in tqdmDataLoader:
            b = img.shape[0]

            # For image-translation: img = IHC (condition), mask = HE (target)
            cond_ihc = img.to(device)
            x_0 = mask.to(device)
            lab = lab.to(device)

            # Channel-concat conditioning: prepare cond and zero global cemb
            cemb = torch.zeros((b, params['cdim']), device=device)
            # classifier-free guidance by dropping conditioning images
            cond_input = cond_ihc.clone()
            drop_idx = (torch.rand(b, device=device) < params['threshold'])
            if drop_idx.any():
                cond_input[drop_idx] = 0

            # AMP를 사용한 손실 계산 및 역전파
            with torch.cuda.amp.autocast():
                loss = diffusion.trainloss(x_0, mask=cond_input, cemb=cemb)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            steps += 1
            total_loss += loss.item()
            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epc + 1,
                    "loss": total_loss / steps,
                    "batch per device": x_0.shape[0],
                    "img shape": x_0.shape[1:],
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                }
            )

    warmUpScheduler.step()
    diffusion.model.eval()
    all_samples = []

    with torch.no_grad():
        # Select conditioning exemplars from train_ihc_image (first params['gen_n'] or whole if smaller)
        try:
            gen_n = min(params['gen_n'], train_ihc_image.shape[0])
            cond_for_gen = train_ihc_image[:gen_n].to(device)
        except NameError:
            # If train_ihc_image isn't available, fallback to a random noise cond of appropriate channels
            cond_for_gen = torch.randn((params['gen_n'], params['mask_ch'], params['image_size'], params['image_size']), device=device)

        # Define generation shape for the image batches
        genshape = (cond_for_gen.shape[0], 3, params['image_size'], params['image_size'])
        # Sample images using the chosen method (DDIM or standard sampling)
        if params['ddim']:
            generated = diffusion.ddim_sample(genshape, 50, 0.5, 'quadratic', mask=cond_for_gen, cemb=cemb)
        else:
            generated = diffusion.sample(genshape, mask=cond_for_gen, cemb=cemb)

        # Convert the generated tensors to images and save them (NO Generator postprocessing)
        generated = transback(generated.to(device))
        for i in range(cond_for_gen.shape[0]):
            # show conditioning IHC (left) and generated HE (right)
            img_pil = topilimage(torch.concat([cond_for_gen[i].cpu(), generated[i].cpu()], dim=2))
            createDirectory(f'../../results/IHC2HE/condition_diffusion/membrain')
            img_pil.save(f'../../results/IHC2HE/condition_diffusion/membrain/epc_{epc}_idx_{i}.png')

        # Save model checkpoint (no cemblayer state needed)
        checkpoint = {
            'net': diffusion.model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': warmUpScheduler.state_dict()
        }
        createDirectory(f'../../model/IHC2HE/condition_diffusion/membrain/')
        torch.save(checkpoint, f'../../model/IHC2HE/condition_diffusion/membrain/ckpt_{epc+1}_checkpoint.pt')