In [1]:
import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from torch.utils.data.distributed import DistributedSampler
import random
from conditionDiffusion.unet import Unet
from conditionDiffusion.embedding import ConditionalEmbedding
from conditionDiffusion.utils import get_named_beta_schedule
from conditionDiffusion.diffusion import GaussianDiffusion
from conditionDiffusion.Scheduler import GradualWarmupScheduler
from PIL import Image
import torchvision
print(f"GPUs used:\t{torch.cuda.device_count()}")
device = torch.device("cuda",1)
print(f"Device:\t\t{device}")
tf=transforms.ToTensor()
topilimage = torchvision.transforms.ToPILImage()
import pytorch_model_summary as tms


GPUs used:	8
Device:		cuda:1


In [2]:
class_list=['유형1','유형2','유형3','유형4','유형5','유형6','유형7','유형8','유형9','유형10','유형11','유형12','유형13','유형14','유형15']
params={'image_size':1024,
        'lr':2e-5,
        'beta1':0.5,
        'beta2':0.999,
        'batch_size':1,
        'epochs':1000,
        'n_classes':None,
        'data_path':'../../data/normalization_type/BR_mask/**/',
        'image_count':5000,
        'inch':6,
        'modch': 128,
        'outch': 6,
        'chmul': [1, 2,4,4],
        'numres':2,
        'dtype':torch.float32,
        'cdim':100,
        'useconv':False,
        'droprate':0.1,
        'T':1000,
        'w':1.8,
        'v':0.3,
        'multiplier':1,
        'threshold':0.1,
        'ddim':True,
        }


In [3]:
trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

def transback(data:Tensor) -> Tensor:
    return data / 2 + 0.5

class CustomDataset(Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""

    def __init__(self, parmas, images, label):

        self.images = images
        self.args = parmas
        self.label = label

    def trans(self, image):
        if random.random() > 0.5:
            transform = transforms.RandomHorizontalFlip(1)
            image = transform(image)

        if random.random() > 0.5:
            transform = transforms.RandomVerticalFlip(1)
            image = transform(image)

        return image

    def __getitem__(self, index):
        image = self.images[index]
        label = self.label[index]
        image = self.trans(image)
        return image, label

    def __len__(self):
        return len(self.images)





image_label = []
image_path = []
for i in tqdm(range(len(class_list))):
    image_list = glob(params['data_path']+class_list[i]+'/*.npy')
    for j in range(len(image_list)):
        image_path.append(image_list[j])
        image_label.append(i)
train_label=torch.zeros((len(image_path),params['inch'],params['image_size'],params['image_size']))      
for i in tqdm(range(10)):
    npy_label = np.load(image_path[i].replace(
        'image', 'mask').replace('jpeg', 'npy'))
    for j in range(params['inch']):
        train_label[i, j] = torch.tensor(npy_label == j).float()*2-1
train_dataset=CustomDataset(params,train_label,image_label)
dataloader=DataLoader(train_dataset,batch_size=params['batch_size'],shuffle=True)

100%|██████████| 15/15 [00:01<00:00, 14.96it/s]
100%|██████████| 10/10 [00:00<00:00, 20.90it/s]


In [4]:
net = Unet(in_ch = params['inch'],
            mod_ch = params['modch'],
            out_ch = params['outch'],
            ch_mul = params['chmul'],
            num_res_blocks = params['numres'],
            cdim = params['cdim'],
            use_conv = params['useconv'],
            droprate = params['droprate'],
            dtype = params['dtype']
            ).to(device)
betas = get_named_beta_schedule(num_diffusion_timesteps = params['T'])
cemblayer = ConditionalEmbedding(
    len(class_list), params['cdim'], params['cdim']).to(device)
diffusion = GaussianDiffusion(
                    dtype = params['dtype'],
                    model = net,
                    betas = betas,
                    w = params['w'],
                    v = params['v'],
                    device = device
                )
optimizer = torch.optim.AdamW(
                itertools.chain(
                    diffusion.model.parameters(),
                    cemblayer.parameters()
                ),
                lr = params['lr'],
                weight_decay = 1e-6
            )


cosineScheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
warmUpScheduler = GradualWarmupScheduler(
                        optimizer = optimizer,
                        multiplier = params['multiplier'],
                        warm_epoch = 50,
                        after_scheduler = cosineScheduler,
                        last_epoch = 0
                    )
# checkpoint=torch.load(f'../../model/conditionDiff/BR/ckpt_35_checkpoint.pt',map_location=device)
# diffusion.model.load_state_dict(checkpoint['net'])

checkpoint=0

In [5]:
scaler = torch.cuda.amp.GradScaler()

for epc in range(params['epochs']):
    diffusion.model.train()
    total_loss=0
    steps=0
    
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for img, lab in tqdmDataLoader:
            b = img.shape[0]

            x_0 = img.to(device)
            lab = lab.to(device)
            cemb = cemblayer(lab)
            loss = diffusion.trainloss(x_0, cemb=cemb)
            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            steps += 1
            total_loss += loss.item()
            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epc + 1,
                    "loss: ": total_loss/steps,
                    "batch per device: ": x_0.shape[0],
                    "img shape: ": x_0.shape[1:],
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                }
            )
    warmUpScheduler.step()

    diffusion.model.eval()
    cemblayer.eval()
    # generating samples
    # The model generate 80 pictures(8 per row) each time
    # pictures of same row belong to the same class
    all_samples = []
    each_device_batch =len(class_list)
    count=1
    with torch.no_grad():
        for i, (img, lab1) in enumerate(dataloader):
            if i==0:
                lab = lab1.to(device)
            elif i < count:
                lab = torch.cat((lab,lab1.to(device)),0).to(device)
            else:
                break
        cemb =cemblayer(lab)
        genshape = (count, 6, params['image_size'], params['image_size'])
        if params['ddim']:
            generated = diffusion.ddim_sample(
                genshape, 20, 0.5, 'quadratic', cemb=cemb)
        else:
            generated = diffusion.sample(genshape, cemb = cemb)
        generated=transback(generated)
        generated=generated.cpu()
        for i in range(len(lab)):
            img_tensor=torch.zeros((3,params['image_size'],params['image_size']))
            img_tensor[0]=torch.where(generated[i,1]>0,1,0)
            img_tensor[1]=torch.where(generated[i,2]>0,1,0)
            img_tensor[2]=torch.where(generated[i,3]>0,1,0)
            img_tensor[0]=torch.where(generated[i,4]>0,1,0)
            img_tensor[1]=torch.where(generated[i,4]>0,1,0)
            img_tensor[1]=torch.where(generated[i,5]>0,1,0)
            img_tensor[2]=torch.where(generated[i,5]>0,1,0)
            
            img_pil = topilimage(img_tensor)
            img_pil.save(f'../../result/mask_synth/BR_mask/{epc}_{i}.png')

        # save checkpoints
        checkpoint = {
            'net': diffusion.model.state_dict(),
            'cemblayer': cemblayer.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': warmUpScheduler.state_dict()
        }
    torch.save(checkpoint, f'../../model/mask_synth/BR_mask/ckpt_{epc+1}_checkpoint.pt')
    torch.cuda.empty_cache()
    

  0%|          | 0/9179 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU  has a total capacity of 79.25 GiB of which 634.62 MiB is free. Process 303021 has 14.32 GiB memory in use. Including non-PyTorch memory, this process has 64.28 GiB memory in use. Of the allocated memory 63.29 GiB is allocated by PyTorch, and 505.17 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [15]:
count=1
with torch.no_grad():
    for i, (img, lab1) in enumerate(dataloader):
        if i==0:
            lab = lab1.to(device)
        elif i < count:
            lab = torch.cat((lab,lab1.to(device)),0).to(device)
        else:
            break
    cemb =cemblayer(lab)
    genshape = (count, 6, params['image_size'], params['image_size'])
    if params['ddim']:
        generated = diffusion.ddim_sample(
            genshape, 20, 0.5, 'quadratic', cemb=cemb)
    else:
        generated = diffusion.sample(genshape, cemb = cemb)
    generated=transback(generated)
    generated=generated.cpu()
    for i in range(len(lab)):
        img_tensor=torch.zeros((3,params['image_size'],params['image_size']))
        img_tensor[0]=torch.where(generated[i,1]>0,1,0)
        img_tensor[1]=torch.where(generated[i,2]>0,1,0)
        img_tensor[2]=torch.where(generated[i,3]>0,1,0)
        img_tensor[0]=torch.where(generated[i,4]>0,1,0)
        img_tensor[1]=torch.where(generated[i,4]>0,1,0)
        img_tensor[1]=torch.where(generated[i,5]>0,1,0)
        img_tensor[2]=torch.where(generated[i,5]>0,1,0)
        
        img_pil = topilimage(img_tensor)
        img_pil.save(f'../../result/mask_synth/BR_mask/{epc}_{i}.png')

    # save checkpoints
    checkpoint = {
        'net': diffusion.model.state_dict(),
        'cemblayer': cemblayer.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': warmUpScheduler.state_dict()
    }

Start generating(ddim)...


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:26<00:00,  1.33s/it]


ending sampling process(ddim)...
