In [18]:
import torch
from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler
from DiffusionFreeGuidence.ModelCondition import UNet
from torchvision.utils import save_image
import os
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import ToTensor
from PIL import Image
from tqdm import tqdm

In [22]:
# Settings shared by the cells of this notebook, gathered in one dictionary.
modelConfig = dict(
    # Not read by any cell visible in this notebook (presumably training-time
    # settings kept alongside the sampling ones -- TODO confirm).
    epoch=70,
    # Batch size of one sampling round in the generation cell.
    img_num=100,
    # Forwarded to both UNet and GaussianDiffusionSampler as `T`.
    T=500,
    # UNet constructor arguments (`ch`, `ch_mult`, `num_res_blocks`, `dropout`).
    channel=128,
    channel_mult=[1, 2, 2, 2],
    num_res_blocks=2,
    dropout=0.15,
    # Not read by any visible cell.
    lr=1e-4,
    multiplier=2.5,
    # Forwarded to GaussianDiffusionSampler together with `T`.
    beta_1=1e-4,
    beta_T=0.028,
    # Height and width of the square noise tensors fed to the sampler.
    img_size=32,
    # Not read by any visible cell.
    grad_clip=1.0,
    # Device string handed to torch.device().
    device="cuda:7",
    # Forwarded to GaussianDiffusionSampler as `w`.
    w=1.8,
    # Directory and file name of the checkpoint that gets loaded.
    save_dir="./CheckpointsCondition/",
    test_load_weight="ckpt_69_.pt",
    # Not read by any visible cell.
    nrow=10,
)

In [23]:
device = torch.device(modelConfig["device"])

# Number of sampling rounds; each round produces modelConfig["img_num"] images.
num_batches = 100

# load model and evaluate
with torch.no_grad():
    # The network, its checkpoint and the sampler are identical for every round,
    # so they are built once here instead of being re-created (and the checkpoint
    # re-read from disk) on each iteration.
    model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                 num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    ckpt = torch.load(os.path.join(
        modelConfig["save_dir"], modelConfig["test_load_weight"]), map_location=device)
    model.load_state_dict(ckpt)
    # Switch to inference mode: without this, dropout layers stay active while
    # sampling and randomly zero activations at every denoising step.
    model.eval()
    sampler = GaussianDiffusionSampler(
        model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)

    # Conditioning labels, also identical for every round: `step` consecutive
    # copies of each class index 0..9, then shifted by +1 so the values passed
    # to the sampler are 1..10.
    step = int(modelConfig["img_num"] // 10)
    labels = torch.arange(10).repeat_interleave(step).long().to(device) + 1
    assert labels.shape[0] == modelConfig["img_num"], \
        "img_num must be a multiple of 10 so that every image gets a label"

    img_list = []
    for j in range(num_batches):
        noisyImage = torch.randn(
            size=[modelConfig["img_num"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
        # Not used by the cells visible here; kept as a notebook global in case
        # a later cell displays the starting noise.
        saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        sampledImg, sampledImg_list = sampler(noisyImage, labels, diffusion_process=True)
        sampledImg = sampledImg * 0.5 + 0.5  # [0 ~ 1]
        img_list.append(sampledImg)
        print("{} images generated".format((j+1)*modelConfig["img_num"]))

100%|██████████| 500/500 [02:01<00:00,  4.10it/s]


100 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


200 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


300 images generated


100%|██████████| 500/500 [02:01<00:00,  4.10it/s]


400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


500 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.04it/s]


700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


900 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


1000 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


1100 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


1200 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


1300 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


1400 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


1500 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


1600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


1700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


1800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


1900 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


2000 images generated


100%|██████████| 500/500 [02:02<00:00,  4.09it/s]


2100 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


2200 images generated


100%|██████████| 500/500 [02:02<00:00,  4.09it/s]


2300 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


2400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


2500 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


2600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


2700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


2800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


2900 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


3000 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


3100 images generated


100%|██████████| 500/500 [02:03<00:00,  4.04it/s]


3200 images generated


100%|██████████| 500/500 [02:04<00:00,  4.03it/s]


3300 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


3400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.04it/s]


3500 images generated


100%|██████████| 500/500 [02:02<00:00,  4.08it/s]


3600 images generated


100%|██████████| 500/500 [02:02<00:00,  4.09it/s]


3700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


3800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


3900 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


4000 images generated


100%|██████████| 500/500 [02:02<00:00,  4.08it/s]


4100 images generated


100%|██████████| 500/500 [02:02<00:00,  4.09it/s]


4200 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


4300 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


4400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


4500 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


4600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


4700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


4800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


4900 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


5000 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


5100 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


5200 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


5300 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


5400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


5500 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


5600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


5700 images generated


100%|██████████| 500/500 [02:02<00:00,  4.09it/s]


5800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


5900 images generated


100%|██████████| 500/500 [02:01<00:00,  4.11it/s]


6000 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


6100 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


6200 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


6300 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


6400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


6500 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


6600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


6700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


6800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


6900 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


7000 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


7100 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


7200 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


7300 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


7400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


7500 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


7600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


7700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


7800 images generated


100%|██████████| 500/500 [01:59<00:00,  4.18it/s]


7900 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


8000 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


8100 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


8200 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


8300 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


8400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.04it/s]


8500 images generated


100%|██████████| 500/500 [02:04<00:00,  4.03it/s]


8600 images generated


100%|██████████| 500/500 [02:03<00:00,  4.04it/s]


8700 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


8800 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


8900 images generated


100%|██████████| 500/500 [02:02<00:00,  4.08it/s]


9000 images generated


100%|██████████| 500/500 [02:02<00:00,  4.08it/s]


9100 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


9200 images generated


100%|██████████| 500/500 [02:03<00:00,  4.05it/s]


9300 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


9400 images generated


100%|██████████| 500/500 [02:03<00:00,  4.06it/s]


9500 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


9600 images generated


100%|██████████| 500/500 [02:02<00:00,  4.08it/s]


9700 images generated


100%|██████████| 500/500 [02:00<00:00,  4.13it/s]


9800 images generated


100%|██████████| 500/500 [02:02<00:00,  4.07it/s]


9900 images generated


100%|██████████| 500/500 [02:02<00:00,  4.08it/s]

10000 images generated





In [25]:
# Join the per-round batches along dimension 0 into one tensor on `device`.
images = torch.cat(tuple(img_list)).to(device)

In [26]:
# Display the dimensions of the combined image tensor.
images.size()

torch.Size([10000, 3, 32, 32])

In [27]:
class CustomDataset(Dataset):
    """Map-style dataset over two parallel, indexable collections.

    Args:
        images: Indexable collection of samples; ``images[idx]`` is returned
            unchanged by ``__getitem__``.
        labels: Indexable collection of labels with the same length as
            ``images``. Entries may be tensors or plain Python numbers.
        transform: Stored on the instance as ``self.transform`` but NOT applied
            in ``__getitem__``; samples are returned exactly as stored.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """

    def __init__(self, images, labels, transform=None):
        # Fail early on a mismatch instead of raising an IndexError (or
        # silently ignoring surplus labels) much later during iteration.
        if len(images) != len(labels):
            raise ValueError(
                "images and labels must have the same length, got {} and {}".format(
                    len(images), len(labels)))
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx``; the label is always a tensor."""
        image = self.images[idx]
        label = self.labels[idx]

        # torch.tensor() on an existing tensor emits a copy-construct
        # UserWarning for every item; detach().clone() yields the same
        # independent copy without the warning.
        if isinstance(label, torch.Tensor):
            label = label.detach().clone()
        else:
            label = torch.tensor(label)

        return image, label

In [28]:
# Build the dataset from ALL generated images (`images`), not from `sampledImg`,
# which only holds the final round of the generation loop.
# Every round used the same `labels` tensor, so tiling it once per round lines
# the labels up with the concatenated images.
num_rounds = images.shape[0] // labels.shape[0]
all_labels = labels.repeat(num_rounds)
assert all_labels.shape[0] == images.shape[0], "every image needs exactly one label"

# NOTE(review): the labels keep the +1 offset applied before sampling (values
# 1..10); a consumer expecting 0-based classes must subtract 1 -- TODO confirm.
# Tensors are moved to the CPU so the saved file can be loaded on a machine
# that does not have the GPU index used for sampling.
# `transform` is still passed for interface parity; CustomDataset stores it but
# does not apply it.
dataset = CustomDataset(images.cpu(), all_labels.cpu(), transform=ToTensor())
torch.save(dataset, 'dataset.pth')