In [124]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from torch import nn
from DDUN.SUNET.unet import Unet
from DDUN.SUNET.sddpm import MarkovDDPM
from DDUN.SUNET.embedding import SinsuoidalPostionalEmbedding
from DDUN.SUNET.utils import *
from tqdm.notebook import tqdm
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [125]:
# Disabled checkpoint restore, kept for reference.
# NOTE(review): `Markovddpm` is not defined in any cell visible in this notebook, and
# `optimizer` is only created in a later cell — fix both before re-enabling.
# Markovddpm.load_state_dict(torch.load("models/best_ddpm_1000T.pt"))
# optimizer.load_state_dict(torch.load("models/optimizer.pt"))

In [126]:
# * Initialization: hyperparameters and device selection shared by the cells below.
N_STEPS = 1000  # diffusion step count; passed to Unet/MarkovDDPM and used as the exclusive upper bound when sampling t
BATCH_SIZE = 32  # 1024
TIME_EMB_DIM = 100  # forwarded to Unet as time_emb_dim
START = 0.0001  # passed to MarkovDDPM together with END — presumably the noise-schedule endpoints; TODO confirm
END = 0.02
lr = 3e-4  # learning rate handed to Adam
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# DEVICE = "cpu"

# * Data
# Forward transform: ToTensor yields values in [0, 1]; the lambda rescales them to [-1, 1].
transform_data = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: (x - 0.5) * 2.0)]
)

# Inverse transform: [-1, 1] -> [0, 1] -> [0, 255] -> uint8 ndarray -> PIL image.
# NOTE(review): ToPILImage documents ndarray input as H x W x C. With the permute below
# disabled, a C x H x W tensor would reach it unpermuted — confirm the layout of the
# tensors this is applied to (it is not called in the cells visible here).
# NOTE(review): t.numpy() only works on CPU tensors; move the tensor to CPU first if needed.
reverse_transform = transforms.Compose(
    [
        transforms.Lambda(lambda t: (t + 1) / 2),
        # transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.0),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ]
)
# FashionMNIST training split, downloaded into ./data on first use and normalized by transform_data.
dataset  = datasets.FashionMNIST("data", train=True, download=True, transform=transform_data)
# Shuffled mini-batches of (images, labels) for the training loop.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [127]:
# Display the dataset's class names; their count is used as num_classes when building the model.
dataset.classes

['T-shirt/top',
 'Trouser',
 'Pullover',
 'Dress',
 'Coat',
 'Sandal',
 'Shirt',
 'Sneaker',
 'Bag',
 'Ankle boot']

In [128]:
# Build the Unet with the step count, time-embedding size and number of classes,
# then move its parameters to DEVICE.
model = Unet(
    noise_steps=N_STEPS,
    time_emb_dim=TIME_EMB_DIM,
    num_classes = len(dataset.classes),
    device=DEVICE,
).to(DEVICE)

# Diffusion helper used below for noising images (training) and for generation.
# NOTE(review): the positional args 1 and 28 are presumably the image channel count and
# side length (generation below uses shape (N, 1, 28, 28)) — confirm against MarkovDDPM.__init__.
diffuser = MarkovDDPM(N_STEPS, START, END, 1, 28, len(dataset.classes), DEVICE)
optimizer = Adam(model.parameters(), lr=lr)

mse = nn.MSELoss()  # compares predicted noise with the noise returned by diffuser.noise_images
# pbar = tqdm(loader)
EPOCHS = 20  # number of passes over the loader in the training cell
BEST_LOSS = float("inf")  # lowest mean epoch loss seen so far; updated by the training cell

In [129]:
# Disabled checkpoint restore, kept for reference.
# NOTE(review): if re-enabled, use "/" or a raw string in the model path — "\S" is an
# invalid escape sequence. The training cell writes "models/SDDPM_CFG.pth", which is a
# different file from the one named here; confirm which checkpoint is intended.
#model.load_state_dict(torch.load("models\SDDPM_model.pth"))
# optimizer.load_state_dict(torch.load("models/optimizer.pt"))

In [130]:
# Train the noise-prediction model and keep the weights of the epoch with the
# lowest mean loss in "models/SDDPM_CFG.pth".
#
# NOTE(review): the output saved with this cell shows a TypeError raised from
# torch's empty() on the first batch; the traceback is truncated, so its origin
# (model or diffuser internals) cannot be determined from this notebook alone.

# Probability of training a batch without labels (classifier-free guidance).
CFG_DROP_PROB = 0.1

for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for images, labels in tqdm(loader):
        images = images.to(DEVICE)
        # Keep the labels on the same device as the images and the model.
        labels = labels.to(DEVICE)
        # One timestep per image, drawn uniformly from [0, N_STEPS).
        t = torch.randint(0, N_STEPS, (images.shape[0],), device=DEVICE).long()

        # Drop the labels *before* the forward pass so the model is actually
        # trained unconditionally on a fraction of the batches. Previously the
        # labels were cleared after the forward pass, which had no effect.
        # Assumes model(x, t, None) is the unconditional path — confirm in Unet.forward.
        if np.random.random() < CFG_DROP_PROB:
            labels = None

        x_t, real_noise = diffuser.noise_images(images, t)
        predicted_noise = model(x_t, t, labels)
        loss = mse(predicted_noise, real_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # pbar.set_postfix(MSE = loss.item())
        epoch_loss += loss.item()

    # * avg loss per epoch
    epoch_loss /= len(loader)

    if epoch_loss < BEST_LOSS:
        # New best epoch: remember the loss and overwrite the saved weights.
        BEST_LOSS = epoch_loss
        torch.save(model.state_dict(), "models/SDDPM_CFG.pth")
        print("+++++++++++++++++++++++++++++++++++++++")
        print(f"at epoch {epoch} - loss improved to {BEST_LOSS:.6f}")
        print("+++++++++++++++++++++++++++++++++++++++")
    else:
        print(f"Epoch {epoch} loss: {epoch_loss:.6f}")

  0%|          | 0/1875 [00:00<?, ?it/s]

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)


In [135]:
# Generate one 1x28x28 sample with the trained model.
# NOTE(review): `labels` receives the integer len(dataset.classes) (10 for the classes
# listed above), not a tensor of class indices; the listed classes occupy indices 0-9,
# so 10 may be out of range if generate() looks it up in an embedding — confirm what
# generate() expects for `labels`.
# NOTE(review): cfg_scale is negative here — confirm this is intentional.
val = diffuser.generate(
    model = model,
    x_shape=(1, 1, 28, 28),
    labels = len(dataset.classes),
    cfg_scale=-0.1,
)

In [None]:
# Scratch check: construct an embedding table with 10 entries of dimension 1000 and
# display its repr. The module is not assigned to a name, so nothing else uses it.
nn.Embedding(10, 1000)

Embedding(10, 1000)

In [None]:
# Generate with x_shape (16, 1, 28, 28) and save_gen_hist=True, unpacking two results;
# `gen_hit` is later passed as the generation history to the GIF helpers.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so `gen` and
# `gen_hit` may be undefined unless the cell is re-run to completion.
gen, gen_hit = diffuser.generate(model, (16, 1, 28, 28), save_gen_hist=True)


KeyboardInterrupt: 

In [None]:
# Call the diffuser's own save_gen_into_gif with the generation history and the name "sdf".
# Presumably this writes "sdf.gif", as the notebook-local function of the same name does — TODO confirm.
diffuser.save_gen_into_gif(gen_hit, "sdf")

In [None]:

def save_gen_into_gif(self=None, gen_hist=None, gif_name=None):
    """Render a generation history as an animated GIF and show its last frame.

    Only the trailing ``2 * int(len(gen_hist) / 3)`` entries of ``gen_hist`` are
    considered (the whole history when that count is 0), and every 9th of those
    becomes a frame. Each kept batch is min-max scaled per image to 0..255,
    tiled into a grid, resized to 1024x1024 and stored as a uint8 array. The
    final frame is appended 18 more times before the GIF is written.

    Args:
        self: Unused. Present so the signature matches the method form of this
            function.
        gen_hist: Sequence of 4-D tensors laid out as (batch, channel, height,
            width). The channel dimension must be 1 (it is squeezed away), and
            the batch size must be divisible by ``int(batch ** 0.5)`` so the
            images can be tiled into a grid.
        gif_name: Output file name without the ".gif" extension. Defaults to
            "SDDPM_results".

    Raises:
        ValueError: If ``gen_hist`` is None or empty.
    """
    if gen_hist is None or len(gen_hist) == 0:
        raise ValueError("gen_hist must contain at least one batch of images")

    tail_len = 2 * int(len(gen_hist) / 3)
    frames = []
    for idx, tensor in enumerate(gen_hist[-tail_len:]):
        if idx % 9 != 0:
            continue

        normalized = tensor.clone()
        for i in range(len(normalized)):
            normalized[i] -= torch.min(normalized[i])
            peak = torch.max(normalized[i])
            # A constant image has peak 0 after the shift; skip the scaling
            # instead of dividing by zero and filling the frame with NaN/inf.
            if peak > 0:
                normalized[i] *= 255 / peak

        # * reshaping the batch into one grid image
        frame = einops.rearrange(normalized, "(b1 b2) c h w -> (b1 h) (b2 w) c", b1=int(tensor.shape[0]** 0.5))
        frame = frame.cpu().numpy().astype(np.uint8)
        frame = np.squeeze(frame, axis=2)
        # * converting to PIL image
        frame = Image.fromarray(frame)
        frame = frame.resize((1024, 1024))
        frame = np.array(frame)
        frames.append(frame)

    # Repeat the final frame so the GIF lingers on the finished samples.
    frames.extend([frames[-1]] * 18)

    if gif_name is None:
        gif_name = "SDDPM_results"
    imageio.mimsave(f'{gif_name}.gif', frames, format = 'GIF-PIL', fps =  100000 ) #type: ignore
    print(f'gif with {len(frames)} frames saved')
    plt.imshow(frames[-1], cmap='gray')
    plt.axis('off')
# Run the notebook-local save_gen_into_gif on the generation history; with no gif_name
# given it writes "SDDPM_results.gif" and shows the last frame.
save_gen_into_gif(gen_hist=gen_hit)