In [1]:
from birb_dataset import get_bird_dataset
from denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer
from birb_dataset import get_stl10_dataset
from birb_dataset import get_animals_10n

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
IMAGE_SIZE = 64                # passed as the Unet size and GaussianDiffusion image_size
CONDITION_DIM = 256            # passed to both the dataset builder and the Unet
# NOTE(review): meaning of the two weights is defined by this fork's
# GaussianDiffusion(conditioning_weights=...) -- confirm there.
CONDITIONING_WEIGHTS = (1, 4)
BATCH_SIZE = 16                # used as both the train batch size and num_samples
# Dataset root. './data/birds' is the other dataset this notebook has been
# pointed at; swap it in here to train on birds instead.
ROOT = './data/animals'
RESULTS_FOLDER = './results/animals_cond'

# classes.txt holds one class name per line. Read it explicitly as UTF-8 so
# decoding does not depend on the platform locale. Skip blank/whitespace-only
# lines so a trailing empty line cannot introduce an empty '' class.
with open(f'{ROOT}/classes.txt', 'r', encoding='utf-8') as f:
    CLASSES = [line.strip() for line in f if line.strip()]

assert CLASSES, f'no class names found in {ROOT}/classes.txt'


In [3]:
# Build the training dataset and the class-name -> integer-label mapping that
# the model is conditioned on. get_bird_dataset is used for whatever dataset
# ROOT points at (currently the animals dataset), not only birds.
ds, class_to_idx = get_bird_dataset(ROOT, IMAGE_SIZE, CLASSES, CONDITION_DIM)


# Alternative (presumably STL-10): replace the line above with
# `ds, class_to_idx = get_stl10_dataset(IMAGE_SIZE)`; get_stl10_dataset is
# already imported in the first cell.

In [4]:
# Inspect the class-name -> label mapping; its length sets the Unet's
# condition_vocab_size in the next cell.
class_to_idx

{'cane': 0,
 'cavallo': 1,
 'elefante': 2,
 'farfalla': 3,
 'gallina': 4,
 'gatto': 5,
 'mucca': 6,
 'pecora': 7,
 'ragno': 8,
 'scoiattolo': 9}

In [5]:
# Training hyperparameters, named here so this cell is self-describing.
DIFFUSION_TIMESTEPS = 1000        # diffusion steps; the sampling loop runs this many
SAVE_AND_SAMPLE_EVERY = 1000      # training steps between save/sample rounds
# NOTE(review): 1e-8 is several orders of magnitude below common DDPM learning
# rates (roughly 1e-5 to 2e-4). Confirm this is an intentional fine-tuning
# value and not a typo -- at this rate the weights will likely barely move.
LEARNING_RATE = 1e-8
TRAIN_NUM_STEPS = 10_000_000_000  # effectively "train until interrupted"

# Class-conditional denoiser; its condition vocabulary is sized to the number
# of classes in the dataset.
unet = Unet(
    IMAGE_SIZE,
    condition_dim=CONDITION_DIM,
    condition_vocab_size=len(class_to_idx),
)

diff = GaussianDiffusion(
    unet,
    image_size=IMAGE_SIZE,
    timesteps=DIFFUSION_TIMESTEPS,
    conditioning_weights=CONDITIONING_WEIGHTS,
)

# BATCH_SIZE also sets how many images are drawn in each sampling round.
trainer = Trainer(
    diff,
    ds=ds,
    results_folder=RESULTS_FOLDER,
    conditioning=True,
    train_batch_size=BATCH_SIZE,
    class_to_idx=class_to_idx,
    save_and_sample_every=SAVE_AND_SAMPLE_EVERY,
    train_lr=LEARNING_RATE,
    num_samples=BATCH_SIZE,
    train_num_steps=TRAIN_NUM_STEPS,
)

In [6]:
# Checkpoint (milestone) number to resume from before continuing training.
# Set to None to train from the freshly initialised weights instead, e.g. when
# RESULTS_FOLDER does not yet contain a saved milestone.
RESUME_MILESTONE = 101

if RESUME_MILESTONE is not None:
    trainer.load(RESUME_MILESTONE)

# Runs the training loop; with the step budget set above this effectively runs
# until the kernel is interrupted, printing loss and periodic sampling bars.
trainer.train()

sampling loop time step: 100%|██████████| 1000/1000 [00:36<00:00, 27.27it/s]it/s] 
sampling loop time step: 100%|██████████| 1000/1000 [00:36<00:00, 27.12it/s]it/s]  
sampling loop time step: 100%|██████████| 1000/1000 [00:37<00:00, 26.80it/s]it/s]  
sampling loop time step: 100%|██████████| 1000/1000 [00:37<00:00, 26.89it/s]it/s]  
sampling loop time step: 100%|██████████| 1000/1000 [00:37<00:00, 26.84it/s]it/s]  
sampling loop time step: 100%|██████████| 1000/1000 [00:38<00:00, 25.81it/s]it/s]  
sampling loop time step: 100%|██████████| 1000/1000 [00:36<00:00, 27.17it/s]it/s]  
sampling loop time step: 100%|██████████| 1000/1000 [00:37<00:00, 26.49it/s]it/s]  
loss: 0.7177:   0%|          | 109898/10000000000 [45:14<749992:37:48,  3.70it/s]  