<a href="https://colab.research.google.com/github/akhilgakhar/Diffusion_Model101/blob/main/Intro_to_Diffusion_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install -q kaggle

In [None]:
from google.colab import files

In [None]:
files.upload()

In [None]:
# -p makes this cell safe to re-run: no error if ~/.kaggle already exists.
! mkdir -p ~/.kaggle

In [None]:
! cp kaggle.json ~/.kaggle/

In [None]:
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
! kaggle datasets download -d ebrahimelgazar/pixel-art

In [None]:
! unzip pixel-art.zip

In [None]:
# imports
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot as plt
from diffusers import UNet2DModel

# Pick the compute device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

# pixel art dataset loader.
class PixelArtDataset(Dataset):
    """Dataset of 16x16 RGB images with 5-element label rows, read from ``.npy`` files.

    Both arrays are loaded fully into memory and validated once at
    construction time.

    Args:
        images_path: Path to a ``.npy`` file holding a ``uint8`` array of
            shape ``(N, 16, 16, 3)``.
        labels_path: Path to a ``.npy`` file holding a ``float64`` array of
            shape ``(N, 5)``.
        transform: Optional callable applied to each image tensor after it
            has been scaled to ``[0, 1]`` and permuted to CHW layout.

    Raises:
        ValueError: If the image or label shapes are wrong, or the number of
            images and labels differ.
        TypeError: If the images are not ``uint8`` or the labels are not
            ``float64``.
    """

    def __init__(self, images_path, labels_path, transform=None):
        self.images = np.load(images_path)
        self.labels = np.load(labels_path)
        self.transform = transform

        # Validate shapes first, then dtypes, so a bad file fails here rather
        # than somewhere inside the training loop.
        if self.images.shape[1:] != (16, 16, 3):
            raise ValueError(f"Images must be 16x16x3, but are {self.images.shape}")
        if self.labels.ndim != 2 or self.labels.shape[1] != 5:
            raise ValueError(f"Labels must be a 2D array with 5 elements per row, but are {self.labels.shape}")
        if self.images.shape[0] != self.labels.shape[0]:
            raise ValueError("Number of images and labels must match")
        if self.images.dtype != np.uint8:
            raise TypeError(f"Images must be uint8, but are {self.images.dtype}. Convert if necessary")
        # The check requires float64; the message must name the same dtype.
        if self.labels.dtype != np.float64:
            raise TypeError(f"Labels must be float64, but are {self.labels.dtype}. Convert if necessary")

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx`` as float32 tensors.

        The image is scaled from ``[0, 255]`` to ``[0, 1]`` and permuted from
        HxWx3 to 3xHxW before the optional transform is applied.
        """
        image = self.images[idx]
        label = self.labels[idx]

        # uint8 -> float32 in [0, 1]
        image = torch.from_numpy(image).float() / 255.0

        # HWC -> CHW
        image = image.permute(2, 0, 1)

        if self.transform is not None:
            image = self.transform(image)

        # float64 on disk -> float32 tensor
        label = torch.from_numpy(label).float()

        return image, label

# NumPy files holding the images and their label rows.
images_path, labels_path = "sprites.npy", "sprites_labels.npy"

# Empty transform pipeline; add torchvision transforms to the list if needed.
transform = transforms.Compose([])

dataset = PixelArtDataset(
    images_path=images_path,
    labels_path=labels_path,
    transform=transform,
)

# Shuffled loader with batches of 8.
# NOTE(review): the training loop builds its own `train_dataloader`; this
# loader is not referenced elsewhere in the visible code.
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)

def corrupt(x, amount):
    """Blend `x` with uniform random noise, weighted per sample by `amount`.

    `amount` is a tensor that is reshaped to (-1, 1, 1, 1) so that one value
    per batch item broadcasts over the remaining three dimensions of `x`.
    An amount of 0 leaves the sample as it is; an amount of 1 replaces it
    with noise drawn by `torch.rand_like`.
    """
    per_sample = amount.view(-1, 1, 1, 1)
    uniform_noise = torch.rand_like(x)
    clean_part = x * (1 - per_sample)
    noisy_part = uniform_noise * per_sample
    return clean_part + noisy_part

# Denoising model: a diffusers UNet2DModel with three resolution levels.
# Block types are listed outermost level first for the down path and
# innermost level first for the up path.
_down_blocks = (
    "DownBlock2D",      # regular ResNet downsampling block
    "AttnDownBlock2D",  # ResNet downsampling block with spatial self-attention
    "AttnDownBlock2D",
)
_up_blocks = (
    "AttnUpBlock2D",    # ResNet upsampling block with spatial self-attention
    "AttnUpBlock2D",
    "UpBlock2D",        # regular ResNet upsampling block
)

net = UNet2DModel(
    sample_size=16,                     # target image resolution
    in_channels=3,                      # RGB input
    out_channels=3,                     # RGB output
    layers_per_block=2,                 # ResNet layers per UNet block
    block_out_channels=(32, 64, 64),    # channel width at each level
    down_block_types=_down_blocks,
    up_block_types=_up_blocks,
)
net.to(device)

# Training: corrupt each batch with a random amount of noise and train the
# network to reconstruct the clean images (MSE against the original batch).
from tqdm import tqdm

n_epochs = 10  # Try experimenting
batch_size = 2  # Try experimenting
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []  # per-batch loss values, accumulated across all epochs

# Keep a handle on the bar so the epoch-average loss can be shown on it.
progress = tqdm(range(n_epochs))
for epoch in progress:

    # Labels are not used by this unconditional model.
    for x, _labels in train_dataloader:

        # Move the batch to the device and build its corrupted version,
        # with one random noise amount per sample (created on the device).
        x = x.to(device)
        noise_amount = torch.rand(x.shape[0], device=device)
        noisy_x = corrupt(x, noise_amount)

        # Model prediction; the timestep argument is always 0 here, and the
        # tensor is read from the `.sample` attribute of the model output.
        pred = net(noisy_x, 0).sample

        # How close is the output to the true 'clean' x?
        loss = loss_fn(pred, x)

        # Backprop and update the params
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Average over this epoch's batches (the last len(train_dataloader)
    # entries of `losses`) and report it on the progress bar.
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    progress.set_postfix(avg_loss=f"{avg_loss:.5f}")

torch.save(net.state_dict(), "model_state.pth")

# Sampling: start from uniform random noise and, over n_steps iterations,
# move x toward the model's prediction.
n_steps = 40
x = torch.rand(56, 3, 16, 16, device=device)

# No gradients are needed for sampling, so the whole loop runs under no_grad.
with torch.no_grad():
    for i in range(n_steps):
        pred = net(x, 0).sample
        # The mix factor grows from 1/n_steps to 1, so the final step
        # replaces x with the prediction entirely.
        mix_factor = 1 / (n_steps - i)
        x = x * (1 - mix_factor) + pred * mix_factor

# Plot the generated batch as one image grid and write it to disk.
fig, axs = plt.subplots(1, 1, figsize=(8, 8))

# Tile the batch into a single CHW grid (10 images per row, values
# normalized by make_grid), then reorder to HWC for imshow.
samples_cpu = x.detach().cpu()
grid = torchvision.utils.make_grid(samples_cpu, nrow=10, normalize=True)
grid_img = grid.permute(1, 2, 0).numpy()

# RGB data, so no colormap is passed.
axs.imshow(grid_img)
axs.set_title('Generated Samples')
axs.axis("off")

plt.savefig('output_image.png')