In [1]:
import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from torch.utils.data.distributed import DistributedSampler
import random
from conditionDiffusion.unet import Unet
from conditionDiffusion.embedding import ConditionalEmbedding
from conditionDiffusion.utils import get_named_beta_schedule
from conditionDiffusion.diffusion import GaussianDiffusion
from conditionDiffusion.Scheduler import GradualWarmupScheduler
from PIL import Image
# Report how many CUDA devices are visible, then pin training to a fixed one.
# Note: the "GPUs used" label actually reports the number of *visible* GPUs;
# only `device` below is used for training.
n_visible_gpus = torch.cuda.device_count()
print(f"GPUs used:\t{n_visible_gpus}")
# Hard-coded to CUDA device index 6; edit here to train on another GPU.
device = torch.device("cuda", 6)
print(f"Device:\t\t{device}")
import pytorch_model_summary as tms

GPUs used:	8
Device:		cuda:6


In [2]:
# Class sub-directory names under params['data_path'] ('유형' means "type").
# The data-loading cell uses each name's position in this list as its integer label.
class_list = ['유형1', '유형2']

# Experiment configuration (insertion order of keys is preserved).
params = dict(
    image_size=512,      # images are resized to image_size x image_size on load
    lr=1e-5,             # AdamW learning rate (base LR for the schedulers)
    beta1=0.5,           # NOTE(review): not passed to the optimizer in the visible cells
    beta2=0.999,         # NOTE(review): not passed to the optimizer in the visible cells
    batch_size=4,        # DataLoader batch size
    epochs=1000,         # training epochs; also sets cosine T_max and warm-up length (epochs // 10)
    n_classes=None,      # not referenced in the visible cells; class count comes from len(class_list)
    data_path='../../data/origin_type/BRNT/',  # root holding one sub-directory per class (keep trailing '/')
    image_count=5000,    # not referenced in the visible cells
    inch=3,              # U-Net input channels / channels of the loaded images
    modch=64,            # U-Net base channel width
    outch=3,             # U-Net output channels
    chmul=[1, 2, 4, 8],  # U-Net per-level channel multipliers
    numres=2,            # U-Net residual blocks per level
    dtype=torch.float32, # dtype handed to the U-Net and GaussianDiffusion
    cdim=10,             # condition-embedding width (U-Net cdim and ConditionalEmbedding sizes)
    useconv=False,       # forwarded to Unet(use_conv=...)
    droprate=0.1,        # forwarded to Unet(droprate=...)
    T=1000,              # number of diffusion timesteps for the beta schedule
    w=1.8,               # forwarded to GaussianDiffusion(w=...); presumably the guidance weight — confirm
    v=0.3,               # forwarded to GaussianDiffusion(v=...); semantics defined in conditionDiffusion — confirm
    multiplier=2.5,      # GradualWarmupScheduler multiplier
    threshold=0.1,       # per-sample probability of zeroing the condition embedding during training
    ddim=True,           # sample with DDIM (True) or the full sampler (False)
)


In [3]:
# Preprocessing for loaded PIL images: convert to a tensor, then normalise each of
# the three channels with mean = std = 0.5 (undone by `transback`).
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

def transback(data: Tensor) -> Tensor:
    """Undo the mean=0.5 / std=0.5 normalisation applied by `trans`.

    Computes ``data * 0.5 + 0.5``, i.e. maps values in [-1, 1] back to [0, 1].
    """
    return data * 0.5 + 0.5

class CustomDataset(Dataset):
    """In-memory dataset pairing pre-loaded image tensors with integer labels.

    Each item is randomly flipped horizontally and/or vertically (independently,
    each with probability 0.5) before being returned.

    Args:
        parmas: configuration object; stored as ``self.args`` (not otherwise used here).
        images: indexable collection of images (a tensor in this notebook).
        label: indexable collection of labels, aligned with ``images``.
    """

    def __init__(self, parmas, images, label):
        self.args = parmas
        self.images = images
        self.label = label

    def trans(self, image):
        """Return ``image`` after the random horizontal / vertical flip augmentation."""
        # Draw both coin flips up front (same order as applying them: horizontal first).
        flip_h, flip_v = random.random() > 0.5, random.random() > 0.5
        if flip_h:
            image = transforms.RandomHorizontalFlip(1)(image)
        if flip_v:
            image = transforms.RandomVerticalFlip(1)(image)
        return image

    def __getitem__(self, index):
        """Return ``(augmented_image, label)`` for position ``index``."""
        sample, target = self.images[index], self.label[index]
        return self.trans(sample), target

    def __len__(self):
        """Number of stored images."""
        return len(self.images)


# Gather every .jpeg under data_path/<class name>/; the class's index in
# `class_list` becomes its integer label.
image_label = []
image_path = []
for class_idx, class_name in enumerate(tqdm(class_list)):
    # glob() returns files in an arbitrary, filesystem-dependent order; sorting makes
    # the sample-index -> file mapping reproducible across runs and machines.
    class_files = sorted(glob(os.path.join(params['data_path'], class_name, '*.jpeg')))
    image_path.extend(class_files)
    image_label.extend([class_idx] * len(class_files))

if not image_path:
    # Fail loudly instead of building an empty DataLoader that silently trains on nothing.
    raise FileNotFoundError(
        f"No '*.jpeg' files found under {params['data_path']!r} for classes {class_list}"
    )

# Decode, resize and normalise every image up front into one default-dtype tensor
# (float32 unless the default was changed: 3 x 512 x 512 x 4 bytes = 3 MiB per image
# at the current config), so RAM use grows linearly with the number of files.
side = params['image_size']
train_images = torch.zeros((len(image_path), params['inch'], side, side))
for i, path in enumerate(tqdm(image_path)):
    # The context manager releases the file handle as soon as the pixels are read.
    with Image.open(path) as pil_img:
        train_images[i] = trans(pil_img.convert('RGB').resize((side, side)))

train_dataset = CustomDataset(params, train_images, image_label)
dataloader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)

100%|██████████| 2/2 [00:00<00:00, 16.47it/s]
100%|██████████| 3722/3722 [02:11<00:00, 28.41it/s]


In [4]:
# Conditional U-Net denoiser; every architectural setting is read from `params`.
unet_kwargs = dict(
    in_ch=params['inch'],
    mod_ch=params['modch'],
    out_ch=params['outch'],
    ch_mul=params['chmul'],
    num_res_blocks=params['numres'],
    cdim=params['cdim'],
    use_conv=params['useconv'],
    droprate=params['droprate'],
    dtype=params['dtype'],
)
net = Unet(**unet_kwargs).to(device)

# Label-conditioning module, sized from the number of classes and params['cdim']
# (argument meanings are defined in conditionDiffusion.embedding — confirm there).
cemblayer = ConditionalEmbedding(len(class_list), params['cdim'], params['cdim']).to(device)

# Beta schedule over params['T'] diffusion steps, wrapped with the U-Net in the diffusion process.
betas = get_named_beta_schedule(num_diffusion_timesteps=params['T'])
diffusion = GaussianDiffusion(
    dtype=params['dtype'],
    model=net,
    betas=betas,
    w=params['w'],
    v=params['v'],
    device=device,
)

# One AdamW instance optimises the U-Net and the label embedding jointly.
# NOTE(review): params['beta1'] / params['beta2'] are not passed here, so AdamW's
# own default betas are in effect — confirm that is intended.
optimizer = torch.optim.AdamW(
    itertools.chain(diffusion.model.parameters(), cemblayer.parameters()),
    lr=params['lr'],
    weight_decay=1e-4,
)

# Warm up for the first epochs // 10 epochs, then hand over to cosine annealing.
# NOTE(review): the training log shows LR climbing to ~2.3e-5 by epoch 88 but
# reported as 1e-5 at epoch 102, i.e. it drops back to the base LR after warm-up
# instead of decaying from lr * multiplier — check how GradualWarmupScheduler
# drives CosineAnnealingLR.
cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=params['epochs'],
    eta_min=0,
    last_epoch=-1,
)
warmUpScheduler = GradualWarmupScheduler(
    optimizer=optimizer,
    multiplier=params['multiplier'],
    warm_epoch=params['epochs'] // 10,
    after_scheduler=cosineScheduler,
    last_epoch=0,
)

In [5]:
# Training loop: one pass over `dataloader` per epoch, then periodically generate a
# grid of class-conditional samples and save a checkpoint.
SAMPLE_EVERY = 20   # epochs between sample grids / checkpoints (final epoch is always saved)
SAMPLE_BATCH = 10   # total images per sample grid, split evenly across classes
DDIM_STEPS = 50     # first positional arg of ddim_sample; the sampling progress bars show 50 iterations
SAVE_DIR = '../../model/conditionDiff/details/BRNT'
os.makedirs(SAVE_DIR, exist_ok=True)

n_classes = len(class_list)
samples_per_class = max(1, SAMPLE_BATCH // n_classes)  # one grid row per class
img_size = params['image_size']
img_ch = params['inch']

for epc in range(params['epochs']):
    diffusion.model.train()
    cemblayer.train()
    total_loss = 0.0
    steps = 0
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for img, lab in tqdmDataLoader:
            b = img.shape[0]
            optimizer.zero_grad()
            x_0 = img.to(device)
            lab = lab.to(device)
            cemb = cemblayer(lab)
            # Zero the condition embedding for each sample with probability
            # params['threshold'] so the model also sees unconditional inputs
            # (presumably for classifier-free guidance at sampling time).
            cemb[np.where(np.random.rand(b) < params['threshold'])] = 0
            loss = diffusion.trainloss(x_0, cemb=cemb)
            loss.backward()
            optimizer.step()
            steps += 1
            total_loss += loss.item()
            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epc + 1,
                    "loss: ": total_loss / steps,
                    "batch per device: ": x_0.shape[0],
                    "img shape: ": x_0.shape[1:],
                    # Read the live param group directly; optimizer.state_dict()
                    # would repack the whole optimizer state on every step.
                    "LR": optimizer.param_groups[0]["lr"],
                }
            )
    warmUpScheduler.step()

    # Sample + checkpoint every SAMPLE_EVERY epochs and always after the last epoch,
    # so the final weights are never lost (999 % 20 != 0 for the default 1000 epochs).
    if epc % SAMPLE_EVERY == 0 or epc == params['epochs'] - 1:
        diffusion.model.eval()
        cemblayer.eval()
        with torch.no_grad():
            # Labels [0]*k + [1]*k + ...: each class gets k consecutive samples,
            # so each row of the saved grid (nrow=k) shows a single class.
            sample_lab = torch.arange(n_classes, device=device).repeat_interleave(samples_per_class)
            sample_cemb = cemblayer(sample_lab)
            # Batch size derived from the labels so shape and conditioning always agree.
            genshape = (sample_lab.shape[0], img_ch, img_size, img_size)
            if params['ddim']:
                # NOTE(review): meaning of the `0` and 'linear' arguments is defined in
                # conditionDiffusion.diffusion.ddim_sample — kept as in the original.
                generated = diffusion.ddim_sample(genshape, DDIM_STEPS, 0, 'linear', cemb=sample_cemb)
            else:
                generated = diffusion.sample(genshape, cemb=sample_cemb)
            samples = transback(generated)

        save_image(
            samples,
            os.path.join(SAVE_DIR, f'generated{img_size}_{epc + 1}_pict.png'),
            nrow=samples_per_class,
        )
        checkpoint = {
            'net': diffusion.model.state_dict(),
            'cemblayer': cemblayer.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': warmUpScheduler.state_dict(),
        }
        torch.save(checkpoint, os.path.join(SAVE_DIR, f'ckpt{img_size}_{epc + 1}_checkpoint.pt'))
    torch.cuda.empty_cache()
    

100%|██████████| 931/931 [12:03<00:00,  1.29it/s, epoch=1, loss: =0.213, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1e-5]


Start generating(ddim)...


100%|██████████| 50/50 [01:57<00:00,  2.35s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [12:01<00:00,  1.29it/s, epoch=2, loss: =0.0792, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.01e-5]
100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=3, loss: =0.0664, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.03e-5]
100%|██████████| 931/931 [12:08<00:00,  1.28it/s, epoch=4, loss: =0.0559, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.04e-5]
100%|██████████| 931/931 [12:04<00:00,  1.29it/s, epoch=5, loss: =0.0537, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.06e-5]
100%|██████████| 931/931 [12:01<00:00,  1.29it/s, epoch=6, loss: =0.0462, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.08e-5]
100%|██████████| 931/931 [12:04<00:00,  1.29it/s, epoch=7, loss: =0.0454, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.09e-5]
100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=8, loss: =0.0444, batch per device: =2, img shape: =torch.Size

Start generating(ddim)...


100%|██████████| 50/50 [01:58<00:00,  2.37s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=22, loss: =0.0362, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.32e-5]
100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=23, loss: =0.0373, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.33e-5]
100%|██████████| 931/931 [12:06<00:00,  1.28it/s, epoch=24, loss: =0.0377, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.35e-5]
100%|██████████| 931/931 [12:04<00:00,  1.29it/s, epoch=25, loss: =0.0352, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.36e-5]
100%|██████████| 931/931 [12:02<00:00,  1.29it/s, epoch=26, loss: =0.0341, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.38e-5]
100%|██████████| 931/931 [12:04<00:00,  1.28it/s, epoch=27, loss: =0.0379, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.39e-5] 
100%|██████████| 931/931 [12:06<00:00,  1.28it/s, epoch=28, loss: =0.0372, batch per device: =2, img shape: =to

Start generating(ddim)...


100%|██████████| 50/50 [02:21<00:00,  2.84s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [12:02<00:00,  1.29it/s, epoch=42, loss: =0.0339, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.62e-5]
100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=43, loss: =0.033, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.63e-5] 
100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=44, loss: =0.0362, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.65e-5]
100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=45, loss: =0.0343, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.66e-5]
100%|██████████| 931/931 [12:05<00:00,  1.28it/s, epoch=46, loss: =0.0335, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.68e-5] 
100%|██████████| 931/931 [12:03<00:00,  1.29it/s, epoch=47, loss: =0.0312, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.69e-5]
100%|██████████| 931/931 [12:06<00:00,  1.28it/s, epoch=48, loss: =0.0338, batch per device: =2, img shape: =to

Start generating(ddim)...


100%|██████████| 50/50 [02:22<00:00,  2.85s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [11:58<00:00,  1.30it/s, epoch=62, loss: =0.0329, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.92e-5]
100%|██████████| 931/931 [12:01<00:00,  1.29it/s, epoch=63, loss: =0.033, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.93e-5] 
100%|██████████| 931/931 [12:02<00:00,  1.29it/s, epoch=64, loss: =0.0322, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.94e-5]
100%|██████████| 931/931 [12:03<00:00,  1.29it/s, epoch=65, loss: =0.0337, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.96e-5]
100%|██████████| 931/931 [12:04<00:00,  1.28it/s, epoch=66, loss: =0.0338, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.98e-5]
100%|██████████| 931/931 [12:03<00:00,  1.29it/s, epoch=67, loss: =0.0317, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1.99e-5] 
100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=68, loss: =0.0315, batch per device: =2, img shape: =to

Start generating(ddim)...


100%|██████████| 50/50 [02:09<00:00,  2.58s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [12:01<00:00,  1.29it/s, epoch=82, loss: =0.0343, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=2.21e-5]
100%|██████████| 931/931 [12:07<00:00,  1.28it/s, epoch=83, loss: =0.0313, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=2.23e-5]
100%|██████████| 931/931 [12:07<00:00,  1.28it/s, epoch=84, loss: =0.032, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=2.25e-5] 
100%|██████████| 931/931 [12:11<00:00,  1.27it/s, epoch=85, loss: =0.0321, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=2.26e-5]
100%|██████████| 931/931 [12:11<00:00,  1.27it/s, epoch=86, loss: =0.0335, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=2.28e-5]
100%|██████████| 931/931 [12:12<00:00,  1.27it/s, epoch=87, loss: =0.0331, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=2.29e-5]
100%|██████████| 931/931 [12:08<00:00,  1.28it/s, epoch=88, loss: =0.0311, batch per device: =2, img shape: =tor

Start generating(ddim)...


100%|██████████| 50/50 [00:58<00:00,  1.18s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [12:01<00:00,  1.29it/s, epoch=102, loss: =0.0306, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1e-5]
100%|██████████| 931/931 [12:02<00:00,  1.29it/s, epoch=103, loss: =0.0308, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1e-5]
100%|██████████| 931/931 [12:01<00:00,  1.29it/s, epoch=104, loss: =0.0319, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1e-5]
100%|██████████| 931/931 [12:02<00:00,  1.29it/s, epoch=105, loss: =0.0298, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1e-5]
100%|██████████| 931/931 [12:02<00:00,  1.29it/s, epoch=106, loss: =0.0345, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1e-5]
100%|██████████| 931/931 [12:01<00:00,  1.29it/s, epoch=107, loss: =0.032, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=1e-5] 
100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=108, loss: =0.0308, batch per device: =2, img shape: =torch.Size([3,

Start generating(ddim)...


100%|██████████| 50/50 [00:59<00:00,  1.19s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=122, loss: =0.0316, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.99e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=123, loss: =0.0312, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.99e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=124, loss: =0.0322, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.99e-6]
100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=125, loss: =0.0313, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.99e-6]
100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=126, loss: =0.0312, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.99e-6]
100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=127, loss: =0.0305, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.98e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=128, loss: =0.0316, batch per device: =2, img shap

Start generating(ddim)...


100%|██████████| 50/50 [00:59<00:00,  1.19s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=142, loss: =0.0302, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.96e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=143, loss: =0.0315, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.96e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=144, loss: =0.0313, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.96e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=145, loss: =0.0297, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.95e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=146, loss: =0.0307, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.95e-6]
100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=147, loss: =0.0312, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.95e-6]
100%|██████████| 931/931 [11:58<00:00,  1.30it/s, epoch=148, loss: =0.0316, batch per device: =2, img shap

Start generating(ddim)...


100%|██████████| 50/50 [00:56<00:00,  1.14s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=162, loss: =0.0304, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.91e-6]
100%|██████████| 931/931 [11:57<00:00,  1.30it/s, epoch=163, loss: =0.0305, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.91e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=164, loss: =0.0302, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.91e-6]
100%|██████████| 931/931 [12:00<00:00,  1.29it/s, epoch=165, loss: =0.0303, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.9e-6]
100%|██████████| 931/931 [12:02<00:00,  1.29it/s, epoch=166, loss: =0.03, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.9e-6]  
100%|██████████| 931/931 [11:58<00:00,  1.29it/s, epoch=167, loss: =0.0318, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.9e-6]
100%|██████████| 931/931 [11:59<00:00,  1.29it/s, epoch=168, loss: =0.0316, batch per device: =2, img shape: 

Start generating(ddim)...


100%|██████████| 50/50 [01:03<00:00,  1.26s/it]


ending sampling process(ddim)...


100%|██████████| 931/931 [12:06<00:00,  1.28it/s, epoch=182, loss: =0.0311, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.84e-6]
100%|██████████| 931/931 [12:09<00:00,  1.28it/s, epoch=183, loss: =0.0325, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.84e-6]
100%|██████████| 931/931 [12:20<00:00,  1.26it/s, epoch=184, loss: =0.0321, batch per device: =2, img shape: =torch.Size([3, 512, 512]), LR=9.84e-6]
 59%|█████▉    | 549/931 [07:05<04:56,  1.29it/s, epoch=185, loss: =0.0299, batch per device: =4, img shape: =torch.Size([3, 512, 512]), LR=9.83e-6]


KeyboardInterrupt: 