In [1]:
import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from torch.utils.data.distributed import DistributedSampler
import random
from conditionDiffusion.unet import Unet
from conditionDiffusion.embedding import ConditionalEmbedding, MaskEmbedding
from conditionDiffusion.utils import get_named_beta_schedule
from conditionDiffusion.diffusion import GaussianDiffusion
from conditionDiffusion.Scheduler import GradualWarmupScheduler
from PIL import Image
import torchvision
import torch.nn as nn
# Report how many GPUs are visible; every `.to(device)` below targets the first CUDA device only.
print(f"GPUs used:\t{torch.cuda.device_count()}")
device = torch.device("cuda:0")
print(f"Device:\t\t{device}")
# Shared tensor -> PIL converter used when saving sample previews.
topilimage = torchvision.transforms.ToPILImage()
def createDirectory(directory):
    """Create ``directory`` (and any missing parents) if it does not already exist.

    Best-effort: an OSError (e.g. permission denied, or a regular file already
    occupying that path) is reported on stdout instead of being raised.

    Args:
        directory (str): path of the directory to create.
    """
    try:
        # exist_ok=True removes the check-then-create race of a separate
        # os.path.exists() test and makes repeated calls a no-op.
        os.makedirs(directory, exist_ok=True)
    except OSError as err:
        print(f"Error: Failed to create the directory {directory!r}: {err}")


GPUs used:	2
Device:		cuda:0


In [2]:
# Hyper-parameters shared by the data-loading, model and training cells.
# NOTE(review): beta1, beta2, batch_size and n_classes are not read by any visible cell
# (the AdamW optimizer is built without `betas`, so it uses its defaults).
params = dict(
    image_size=512,
    lr=2e-5,
    beta1=0.5,
    beta2=0.999,
    batch_size=1,
    epochs=1000,
    n_classes=None,
    image_count=20000,
    inch=3,
    mask_ch=3,  # set to number of mask channels (change to len(mask_list)+1 if using one-hot masks)
    modch=128,
    outch=3,
    chmul=[1, 2, 4, 8],
    numres=1,
    dtype=torch.float32,
    cdim=10,
    useconv=False,
    droprate=0.1,
    T=1000,
    w=1.8,
    v=0.3,
    multiplier=1,
    threshold=0.1,
    ddim=True,
)

In [3]:
# Map ToTensor output from [0, 1] to [-1, 1]: (x - 0.5) / 0.5 on each of the 3 channels.
trans = transforms.Compose([
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

def transback(data:Tensor) -> Tensor:
    """Invert the Normalize(mean=0.5, std=0.5) mapping, sending [-1, 1] back to [0, 1]."""
    return data * 0.5 + 0.5

class CustomDataset(Dataset):
    """Paired image / mask dataset with joint random flips.

    Item ``index`` is ``(image, mask, label)`` taken from the same position of the
    three sequences given to the constructor. ``image`` and ``mask`` receive the
    same random horizontal / vertical flips so they stay spatially aligned.

    Args:
        parmas: hyper-parameter dict, stored as ``self.args`` (the misspelled
            name is kept for backward compatibility with keyword callers).
        images: indexable sequence of images.
        mask: indexable sequence of conditioning masks, same length as ``images``.
        label: indexable sequence of labels, same length as ``images``.

    Raises:
        ValueError: if ``images``, ``mask`` and ``label`` differ in length.
    """
    def __init__(self,parmas, images,mask,label):
        n_images = len(images)
        # Fail fast: mismatched lengths would otherwise surface later as an
        # IndexError in __getitem__ or as silently ignored trailing items.
        if len(mask) != n_images or len(label) != n_images:
            raise ValueError(
                f"images, mask and label must have the same length; "
                f"got {n_images}, {len(mask)} and {len(label)}"
            )
        self.images = images
        self.masks = mask
        # Not read anywhere in this class; kept for callers that inspect it.
        self.args=parmas
        self.label=label

    def trans(self,image,mask):
        """Apply the same random flips to ``image`` and ``mask``.

        A horizontal flip, then a vertical flip, are each applied when a fresh
        ``random.random()`` draw exceeds 0.5 (two draws per call).
        """
        if random.random() > 0.5:
            transform = transforms.RandomHorizontalFlip(1)
            image = transform(image)
            mask = transform(mask)

        if random.random() > 0.5:
            transform = transforms.RandomVerticalFlip(1)
            image = transform(image)
            mask = transform(mask)
        return image,mask

    def __getitem__(self, index):
        """Return the ``(image, mask, label)`` triple at ``index`` after joint flipping."""
        image=self.images[index]
        label=self.label[index]
        mask=self.masks[index]
        image,mask = self.trans(image,mask)
        return image,mask,label

    def __len__(self):
        """Number of pairs (equal to ``len(images)``)."""
        return len(self.images)


def _load_normalized_rgb(path: str, size: int) -> Tensor:
    """Open ``path`` as RGB, resize to ``size`` x ``size`` and normalize it with ``trans``."""
    # The context manager closes the file handle once the pixels have been read.
    with Image.open(path) as img:
        rgb = img.convert('RGB').resize((size, size))
    return trans(transforms.ToTensor()(rgb))


# IHC/HE pairs: each IHC tile's HE counterpart is found by swapping the '/IHC/'
# path component for '/HE/'.
# NOTE(review): image_label / image_path stay empty in this cell; labels for the
# dataset must be built elsewhere.
image_label=[]
image_path=[]
# sorted(): glob's result order depends on the filesystem, so sorting makes the
# image_count truncation below reproducible across machines and runs.
# NOTE: without recursive=True, '**' behaves like '*' (exactly one directory level).
ihc_image_list=sorted(glob('../../data/IHC4BC_Compressed/**/HER2/IHC/*.jpg'))
image_temp_list=sorted(glob('../../data/IHC4BC_Compressed/**/Ki67/IHC/*.jpg'))
ihc_image_list.extend(image_temp_list)
he_image_list=[p.replace('/IHC/','/HE/') for p in ihc_image_list]
# Slicing is a no-op when the lists are already short enough.
ihc_image_list=ihc_image_list[:params['image_count']]
he_image_list=he_image_list[:params['image_count']]

# NOTE(review): both tensors are fully materialized in RAM as float32. With the current
# params (image_count=20000, image_size=512) that is 20000*3*512*512*4 bytes ~ 63 GB
# each; the recorded run of this cell was interrupted. Consider loading lazily in the Dataset.
train_ihc_image=torch.zeros((len(ihc_image_list),params['inch'],params['image_size'],params['image_size']))
train_he_image=torch.zeros((len(ihc_image_list),params['inch'],params['image_size'],params['image_size']))

for i in tqdm(range(len(ihc_image_list))):
    train_ihc_image[i]=_load_normalized_rgb(ihc_image_list[i], params['image_size'])
    train_he_image[i]=_load_normalized_rgb(he_image_list[i], params['image_size'])

KeyboardInterrupt: 

In [4]:
# Conditioning (per the original note): x and the conditioning mask image are concatenated
# along the channel axis and fed to the U-Net, hence in_ch = inch + mask_ch.
# If you use one-hot masks (mask_list), set params['mask_ch'] accordingly; otherwise keep the default.


net = Unet(in_ch = params['inch'] + params['mask_ch'],
            mod_ch = params['modch'],
            out_ch = params['outch'],
            ch_mul = params['chmul'],
            num_res_blocks = params['numres'],
            cdim = params['cdim'],
            use_conv = params['useconv'],
            droprate = params['droprate'],
            dtype = params['dtype']
            ).to(device)
# Image-based mask embedding (mask image -> global conditional vector).
# It is fed the masks, so its input width must be the mask channel count, the same
# params['mask_ch'] the U-Net input above assumes. Using params['inch'] here only
# matched coincidentally while both are 3.
cemblayer = MaskEmbedding(in_ch = params['mask_ch'], cdim = params['cdim']).to(device)
betas = get_named_beta_schedule(num_diffusion_timesteps = params['T'])
diffusion = GaussianDiffusion(
                    dtype = params['dtype'],
                    model = net,
                    betas = betas,
                    w = params['w'],
                    v = params['v'],
                    device = device
                )
# One AdamW over both trainable modules.
# NOTE(review): params['beta1'] / params['beta2'] are not passed, so AdamW's default betas apply.
optimizer = torch.optim.AdamW(
                itertools.chain(
                    diffusion.model.parameters(),
                    cemblayer.parameters()
                ),
                lr = params['lr'],
                weight_decay = 1e-4
            )


# Cosine schedule (T_max = epochs / 100 with the current params) handed to the project's
# GradualWarmupScheduler as after_scheduler; warm_epoch = epochs // 10.
cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
                            optimizer = optimizer,
                            T_max = params['epochs']/100,
                            eta_min = 0,
                            last_epoch = -1
                        )
warmUpScheduler = GradualWarmupScheduler(
                        optimizer = optimizer,
                        multiplier = params['multiplier'],
                        warm_epoch = params['epochs'] // 10,
                        after_scheduler = cosineScheduler,
                        last_epoch = 0
                    )


class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    F is reflect-pad(1) -> 3x3 conv -> InstanceNorm -> ReLU -> reflect-pad(1)
    -> 3x3 conv -> InstanceNorm. The padding keeps H and W unchanged and the
    convolutions keep the channel count, so the skip addition is shape-safe.

    NOTE(review): not referenced by any other cell in the visible notebook.
    """

    def __init__(self, in_features):
        super().__init__()
        # Keep this exact submodule order: it fixes the state_dict keys
        # (conv_block.1.*, conv_block.5.*) and the parameter-initialization order.
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual



# checkpoint=torch.load(f'../../model/conditionDiff/details/STNT/ckpt_231_checkpoint.pt',map_location=device)
# diffusion.model.load_state_dict(checkpoint['net'])
# cemblayer.load_state_dict(checkpoint['cemblayer'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# warmUpScheduler.load_state_dict(checkpoint['scheduler'])


In [None]:
# Training loop: mask-conditioned diffusion with classifier-free guidance, followed by a
# per-epoch sampling preview and a checkpoint.
# NOTE(review): `dataloader` and `class_list` are not defined in any visible cell; build them
# (e.g. a DataLoader over CustomDataset) before running this cell.

# One stored conditioning mask per class, refreshed during training and reused for sampling.
# Its channel count follows params['mask_ch'], the width cemblayer and the U-Net consume.
mask_tensor_list=torch.zeros((len(class_list),params['mask_ch'],params['image_size'],params['image_size']), device=device)
scaler = torch.cuda.amp.GradScaler()
for epc in range(params['epochs']):
    diffusion.model.train()
    cemblayer.train()
    total_loss = 0
    steps = 0
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for img, mask, lab in tqdmDataLoader:
            b = img.shape[0]

            x_0 = img.to(device)
            mask_0 = mask.to(device)
            lab = lab.to(device)

            # Conditioning embedding: mask (image) -> cemb.
            cemb = cemblayer(mask_0)
            # Classifier-free guidance: zero the embedding for a random subset of samples
            # (probability params['threshold']). Out-of-place so autograd never sees an
            # in-place write into cemblayer's output.
            drop_idx = torch.rand(b, device=device) < params['threshold']
            drop_mask = drop_idx.view(b, *([1] * (cemb.dim() - 1)))
            cemb = torch.where(drop_mask, torch.zeros_like(cemb), cemb)

            # Remember the latest mask for each sample's class, one sample at a time
            # (a flat argmax over the whole batch is only correct for batch size 1).
            # Assumes labels are one-hot / score vectors per sample — TODO confirm.
            for k, cls in enumerate(lab.reshape(b, -1).argmax(dim=1).tolist()):
                mask_tensor_list[cls] = mask_0[k].detach()

            # Loss under AMP autocast; backward / step go through the GradScaler.
            with torch.cuda.amp.autocast():
                loss = diffusion.trainloss(x_0, mask=mask_0, cemb=cemb)

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            steps += 1
            total_loss += loss.item()
            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epc + 1,
                    "loss": total_loss / steps,
                    "batch per device": x_0.shape[0],
                    "img shape": x_0.shape[1:],
                    "LR": optimizer.param_groups[0]["lr"]
                }
            )

    warmUpScheduler.step()
    diffusion.model.eval()
    cemblayer.eval()
    # Sampling needs no gradients (this `with` line was missing, which left the block
    # below mis-indented).
    with torch.no_grad():
        # Use stored masks per class to compute conditional embeddings (image -> cemb)
        mask_for_gen = mask_tensor_list.to(device)
        cemb = cemblayer(mask_for_gen)

        # One sample per stored class mask.
        genshape = (mask_for_gen.shape[0], params['outch'], params['image_size'], params['image_size'])
        if params['ddim']:
            generated = diffusion.ddim_sample(genshape, 50, 0.5, 'quadratic', mask=mask_for_gen, cemb=cemb)
        else:
            generated = diffusion.sample(genshape, mask=mask_for_gen, cemb=cemb)

        # NOTE(review): samples used to be passed through `Generator(...)`, which is not defined
        # anywhere in this notebook; re-insert it here once that model is defined/loaded.
        # Clamp to [0, 1]: ToPILImage converts float tensors via mul(255).byte(), so
        # out-of-range values would wrap around instead of saturating.
        generated = transback(generated.to(device)).clamp(0, 1)

        # Mask panel = class-index map scaled to [0, 1] (the same range as `generated`).
        mask_levels = max(params['mask_ch'] - 1, 1)
        for i in range(len(class_list)):
            mask_vis = (mask_for_gen[i].cpu().argmax(dim=0) / mask_levels).unsqueeze(0).repeat(3, 1, 1)
            img_pil = topilimage(torch.concat([generated[i].cpu(), mask_vis], dim=2))
            out_dir = f'../../result/mask_Diffusion/BRIL/{class_list[i]}/'
            createDirectory(out_dir)
            img_pil.save(f'{out_dir}{epc}.png')

        # Save model checkpoints (one file per epoch).
        checkpoint = {
            'net': diffusion.model.state_dict(),
            'cemblayer': cemblayer.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': warmUpScheduler.state_dict()
        }
        createDirectory('../../model/mask_Diffusion/BRIL/')
        torch.save(checkpoint, f'../../model/mask_Diffusion/BRIL/ckpt_{epc+1}_checkpoint.pt')

In [None]:
# Preview: the first generated sample next to the class-index map of its own conditioning mask.
# The mask map is scaled to [0, 1] (not [-1, 1]) to match `generated`, which the training
# cell passes through transback(); negative floats would wrap when ToPILImage casts to uint8.
topilimage(torch.concat([generated[0].cpu(),(mask_tensor_list[0].cpu().argmax(dim=0)/len(mask_list)).unsqueeze(0).repeat(3, 1, 1)],dim=2))

In [None]:
# Distinct class indices present in the first stored conditioning mask (sorted).
mask_tensor_list[0].cpu().argmax(dim=0).unique()

In [None]:
# Shape check for the 3-channel mask panel used in the sample previews.
(mask_tensor_list[0].cpu().argmax(dim=0) / len(mask_list) * 2 - 1)[None].repeat(3, 1, 1).shape