In [1]:
import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
from torch import Tensor
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from glob import glob
from torch.utils.data.distributed import DistributedSampler
import random
from conditionDiffusion.unet import Unet
from conditionDiffusion.embedding import ConditionalEmbedding
from conditionDiffusion.utils import get_named_beta_schedule
from conditionDiffusion.diffusion import GaussianDiffusion
from conditionDiffusion.Scheduler import GradualWarmupScheduler
from PIL import Image
# Select the compute device for the rest of the notebook.
# The original run used GPU index 6 on an 8-GPU host. Fall back to cuda:0 (or the CPU)
# when that index does not exist, so the notebook does not crash at the first `.to(device)`.
GPU_INDEX = 6
n_visible_gpus = torch.cuda.device_count()
print(f"GPUs used:\t{n_visible_gpus}")
if GPU_INDEX < n_visible_gpus:
    device = torch.device("cuda", GPU_INDEX)
elif n_visible_gpus > 0:
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
print(f"Device:\t\t{device}")
import pytorch_model_summary as tms
import torchvision



GPUs used:	8
Device:		cuda:6


In [2]:
# Class names for conditional generation. A name's position in the list is its label id.
# The generation cell also uses these names as the per-class output folder names.
class_list = ['유형1', '유형2']

# Notebook configuration. The architecture entries must match the checkpoint restored
# later, because its state dicts are loaded into modules built from these values.
# Some entries (beta1, beta2, batch_size, n_classes, data_path, image_count,
# threshold, ddim) are not read anywhere in this notebook; they are kept for reference.
params = dict(
    # data / training
    image_size=1024,
    lr=1e-5,
    beta1=0.5,
    beta2=0.999,
    batch_size=1,
    epochs=1000,
    n_classes=None,
    data_path='../../data/origin_type/STNT/',
    image_count=5000,
    # U-Net constructor arguments
    inch=3,
    modch=32,
    outch=3,
    chmul=[1, 2, 4, 8, 16, 32, 64],
    numres=2,
    dtype=torch.float32,
    cdim=10,
    useconv=False,
    droprate=0.1,
    # diffusion: number of timesteps and the w / v arguments of GaussianDiffusion
    T=1000,
    w=1.8,
    v=0.3,
    # warm-up scheduler multiplier and remaining flags
    multiplier=2.5,
    threshold=0.1,
    ddim=True,
)


In [3]:
# Build the conditional U-Net from `params`. Its layout must match the checkpoint below.
unet_kwargs = dict(
    in_ch=params['inch'],
    mod_ch=params['modch'],
    out_ch=params['outch'],
    ch_mul=params['chmul'],
    num_res_blocks=params['numres'],
    cdim=params['cdim'],
    use_conv=params['useconv'],
    droprate=params['droprate'],
    dtype=params['dtype'],
)
net = Unet(**unet_kwargs).to(device)

# Embedding for the class labels. It gets one entry per class in `class_list`.
cemblayer = ConditionalEmbedding(len(class_list), params['cdim'], params['cdim']).to(device)

# Beta schedule over T steps, wrapped together with the U-Net in the diffusion helper.
betas = get_named_beta_schedule(num_diffusion_timesteps=params['T'])
diffusion = GaussianDiffusion(
    dtype=params['dtype'],
    model=net,
    betas=betas,
    w=params['w'],
    v=params['v'],
    device=device,
)

# The optimizer and LR schedulers are rebuilt only so their saved state can be restored.
# They need the same parameter order as at save time: U-Net parameters first,
# then the embedding parameters.
trainable_params = [*diffusion.model.parameters(), *cemblayer.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=params['lr'], weight_decay=1e-4)

cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=params['epochs'] / 100,
    eta_min=0,
    last_epoch=-1,
)
warmUpScheduler = GradualWarmupScheduler(
    optimizer=optimizer,
    multiplier=params['multiplier'],
    warm_epoch=params['epochs'] // 10,
    after_scheduler=cosineScheduler,
    last_epoch=0,
)

# Restore the trained weights and training state.
# NOTE(review): depending on the PyTorch version, torch.load may unpickle arbitrary
# objects, so only load trusted checkpoints. The saved scheduler state may hold
# non-tensor objects, which weights_only=True could reject. Verify before switching.
CKPT_PATH = '../../model/conditionDiff/details/STNT/ckpt_138_checkpoint.pt'
checkpoint = torch.load(CKPT_PATH, map_location=device)
for _target, _ckpt_key in (
    (diffusion.model, 'net'),
    (cemblayer, 'cemblayer'),
    (optimizer, 'optimizer'),
    (warmUpScheduler, 'scheduler'),
):
    _target.load_state_dict(checkpoint[_ckpt_key])
def transback(data: Tensor) -> Tensor:
    """Map a tensor from the [-1, 1] range back to the [0, 1] image range.

    The result is clamped to [0, 1]. ``ToPILImage`` converts float tensors with
    ``mul(255).byte()`` and does not clamp, so samples that overshoot [-1, 1]
    would otherwise wrap around to wrong pixel values.

    Args:
        data: tensor expected to be in [-1, 1]. Values outside that range are clipped.

    Returns:
        A tensor of the same shape with values in [0, 1].
    """
    return (data / 2 + 0.5).clamp(0.0, 1.0)

  checkpoint=torch.load(f'../../model/conditionDiff/details/STNT/ckpt_138_checkpoint.pt',map_location=device)


In [4]:
# Generate class-conditional samples with DDIM and save them as sequentially numbered
# PNGs, one folder per class.
OUTPUT_DIR = '../../result/Detail/STNT/Generator_image'
SAMPLES_PER_CLASS = 5      # images per class in each DDIM batch
N_ROUNDS = 100             # number of DDIM batches to generate
DDIM_STEPS = 100           # 2nd positional argument of ddim_sample
DDIM_ETA = 0.5             # 3rd positional argument of ddim_sample; presumably eta. TODO confirm
DDIM_SELECT = 'quadratic'  # 4th positional argument; presumably the timestep-selection scheme. TODO confirm

diffusion.model.eval()
cemblayer.eval()
class_count = {key: 0 for key in class_list}
each_device_batch = len(class_list) * SAMPLES_PER_CLASS
topilimage = torchvision.transforms.ToPILImage()

# PIL's Image.save does not create missing directories, so create them up front.
for class_name in class_list:
    os.makedirs(os.path.join(OUTPUT_DIR, class_name), exist_ok=True)

with torch.no_grad():
    # Labels [0]*k + [1]*k + ... with k = SAMPLES_PER_CLASS (consecutive block per class).
    labels_cpu = torch.arange(len(class_list)).repeat_interleave(SAMPLES_PER_CLASS)
    lab = labels_cpu.to(device)
    # Plain-Python class names. Indexing a list with CUDA tensor scalars would force
    # a device->host sync for every saved image.
    lab_names = [class_list[c] for c in labels_cpu.tolist()]
    cemb = cemblayer(lab)
    genshape = (each_device_batch, 3, params['image_size'], params['image_size'])
    for _ in range(N_ROUNDS):
        generated = diffusion.ddim_sample(genshape, DDIM_STEPS, DDIM_ETA, DDIM_SELECT, cemb=cemb)
        # Move the whole batch to the host once instead of image by image.
        generated = transback(generated).cpu()
        for i, class_name in enumerate(lab_names):
            img_pil = topilimage(generated[i])
            img_pil.save(os.path.join(OUTPUT_DIR, class_name, f'{class_count[class_name]}.png'))
            class_count[class_name] += 1
torch.cuda.empty_cache()

Start generating(ddim)...


100%|██████████| 100/100 [08:02<00:00,  4.83s/it]


ending sampling process(ddim)...
Start generating(ddim)...


 25%|██▌       | 25/100 [01:57<05:51,  4.69s/it]


KeyboardInterrupt: 

In [None]:
# Number of images saved per class by the generation cell. This uses only notebook-level
# state, not loop variables leaked from that cell.
class_count