# Setup

In [None]:
!pip install torchmultimodal-nightly

In [1]:
import torch
import torchvision
import torchvision.transforms.functional as F

from torch import nn
from tqdm import tqdm
from torchmultimodal.diffusion_labs.modules.adapters.cfguidance import CFGuidance
from torchmultimodal.diffusion_labs.modules.losses.diffusion_hybrid_loss import DiffusionHybridLoss
from torchmultimodal.diffusion_labs.samplers.ddpm import DDPModule
from torchmultimodal.diffusion_labs.predictors.noise_predictor import NoisePredictor
from torchmultimodal.diffusion_labs.schedules.discrete_gaussian_schedule import linear_beta_schedule, DiscreteGaussianSchedule
from torchmultimodal.diffusion_labs.transforms.diffusion_transform import RandomDiffusionSteps
from torchmultimodal.diffusion_labs.utils.common import DiffusionOutput

# Schedule

In [2]:
# Noise schedule shared by the predictor, the DDPM module, the data transform
# and the loss below. It is built from linear_beta_schedule(1000); the
# evaluation steps defined later (0..999) assume this length.
schedule = DiscreteGaussianSchedule(linear_beta_schedule(1000))

# Predictor

In [3]:
# The callable handed to NoisePredictor clips its input to [-1, 1], the same
# range that `scale` maps the training images into.
predictor = NoisePredictor(schedule, lambda x: x.clamp(-1, 1))

# U-Net

In [4]:
class DownBlock(nn.Module):
    """Conditioned double 3x3 convolution followed by 2x average pooling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        cond_channels: Channels of the conditioning tensor that is
            concatenated to the input before the first convolution.
    """

    def __init__(self, in_channels, out_channels, cond_channels):
        super().__init__()
        # The first conv sees the feature map and the broadcast condition
        # stacked on the channel axis.
        stacked_channels = in_channels + cond_channels
        layers = [
            nn.Conv2d(stacked_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*layers)
        self.pooling = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, c):
        """Return ``(features, pooled_features)`` for input ``x`` and condition ``c``.

        ``c`` is expanded along its last two dimensions to the spatial size
        of ``x`` (its first two dimensions are left untouched), then
        concatenated to ``x`` on dim 1.
        """
        _, _, height, width = x.shape
        cond_map = c.expand(-1, -1, height, width)
        features = self.block(torch.cat((x, cond_map), dim=1))
        return features, self.pooling(features)

class UpBlock(nn.Module):
    """Upsample a low-resolution map, merge it with a skip map, and convolve.

    Args:
        inp: Channel count of each of the two maps being merged (the first
            convolution therefore takes ``2 * inp`` channels).
        out: Channel count produced by both convolutions.
    """

    def __init__(self, inp, out):
        super().__init__()
        merged_channels = inp * 2
        self.block = nn.Sequential(
            nn.Conv2d(merged_channels, out, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out, out, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, x, x_small):
        """Upsample ``x_small`` by 2, concatenate it in front of ``x`` on dim 1, and convolve."""
        merged = torch.cat((self.upsample(x_small), x), dim=1)
        return self.block(merged)

class UNet(nn.Module):
    """Small conditional U-Net for single-channel images.

    The network applies a stem convolution, two ``DownBlock`` stages, a
    bottleneck ``DownBlock`` (whose pooled output is discarded), and two
    ``UpBlock`` stages that consume the saved skip maps in reverse order.
    Two 3x3 convolution heads then produce one-channel ``prediction`` and
    ``variance`` maps.

    Args:
        time_size: Width of the learned timestep embedding.
        digit_size: Width of the conditioning vector supplied under
            ``conditional_inputs["context"]``.
        steps: Number of rows in the timestep embedding table, i.e. the
            number of distinct timestep indices that can be looked up.
    """

    def __init__(self, time_size=32, digit_size=32, steps=1000):
        super().__init__()
        # Every down/bottleneck block receives the timestep embedding and the
        # context vector concatenated on the channel axis.
        cond_size = time_size + digit_size
        self.conv = nn.Conv2d(1, 128, kernel_size=3, padding=1)
        self.down = nn.ModuleList([DownBlock(128, 256, cond_size), DownBlock(256, 512, cond_size)])
        self.bottleneck = DownBlock(512, 512, cond_size)
        self.up = nn.ModuleList([UpBlock(512, 256), UpBlock(256, 128)])

        self.variance = nn.Conv2d(128, 1, kernel_size=3, padding=1)
        self.prediction = nn.Conv2d(128, 1, kernel_size=3, padding=1)
        self.time_projection = nn.Embedding(steps, time_size)

    def forward(self, x, t, conditional_inputs):
        """Run the network on ``x`` at timestep indices ``t``.

        Args:
            x: Input batch; its first dimension is the batch size.
            t: Integer timestep indices looked up in ``time_projection``.
            conditional_inputs: Mapping that must contain a ``"context"``
                tensor with one conditioning vector per batch element.

        Returns:
            ``DiffusionOutput`` built from the prediction head and the
            variance head, in that order.
        """
        batch = x.shape[0]
        # reshape (rather than view) so a non-contiguous context tensor from
        # the caller is still accepted.
        timestep = self.time_projection(t).reshape(batch, -1, 1, 1)
        context = conditional_inputs["context"].reshape(batch, -1, 1, 1)
        condition = torch.cat([timestep, context], dim=1)

        x = self.conv(x)
        # Skip maps are kept in a local list so the module carries no
        # per-call state between (or during concurrent) forward passes.
        skips = []
        for block in self.down:
            skip, x = block(x, condition)
            skips.append(skip)
        # Only the un-pooled bottleneck features are used.
        x, _ = self.bottleneck(x, condition)
        for block in self.up:
            x = block(skips.pop(), x)
        v = self.variance(x)
        p = self.prediction(x)
        return DiffusionOutput(p, v)

# Diffusion Model

In [5]:
# Denoising network: a 32-wide timestep embedding plus a 32-wide context vector
# form the 64-channel condition used inside UNet.
unet = UNet(time_size=32, digit_size=32)
# Wrap the network with CFGuidance, declaring the "context" input as size 32
# (matching digit_size) and a guidance value of 2.0.
# NOTE(review): presumably this adds classifier-free guidance around the
# "context" input — confirm against the torchmultimodal CFGuidance docs.
unet = CFGuidance(unet, {"context": 32}, guidance=2.0)

In [6]:
# 250 integer timesteps spread between 0 and 999, i.e. over the range of the
# 1000-step schedule defined above.
# NOTE(review): presumably the subset of steps visited when sampling in eval
# mode — confirm against DDPModule.
eval_steps = torch.linspace(0, 999, 250, dtype=torch.long)
# Bundle the guided network, schedule and predictor into the module that is
# trained and later called for generation.
model = DDPModule(unet, schedule, predictor, eval_steps)

In [7]:
# Learnable lookup table with one 32-dimensional vector for each of the 10
# class labels; 32 matches the "context" size declared for CFGuidance.
encoder = nn.Embedding(10, 32)

# Data

In [8]:
from torchvision.transforms import Compose, Resize, ToTensor, Lambda

# Per-sample transform built from the shared schedule (batched=False).
# NOTE(review): judging by the training loop, its output is a dict with "x",
# "xt", "noise" and "t" entries — confirm against RandomDiffusionSteps.
diffusion_transform = RandomDiffusionSteps(schedule, batched=False)
# The pipeline below uses the named functions `scale` and
# `apply_diffusion_transform` in place of the equivalent inline lambdas
# (x -> 2*x - 1 and x -> diffusion_transform({"x": x})).
# NOTE(review): presumably so the transform can be pickled for DataLoader
# worker processes — confirm.

def scale(x):
    """Apply the affine map x -> 2x - 1 (sends 0 to -1 and 1 to 1)."""
    doubled = x * 2
    return doubled - 1

def apply_diffusion_transform(x):
    """Pass ``x`` to ``diffusion_transform`` under the "x" key and return its result."""
    wrapped = {"x": x}
    return diffusion_transform(wrapped)

# Image pipeline applied to every dataset sample, in order: resize to 32,
# convert to a tensor, rescale with `scale`, then apply the diffusion transform.
transform = Compose(
    [
        Resize(32),
        ToTensor(),
        Lambda(scale),
        Lambda(apply_diffusion_transform),
    ]
)

In [9]:
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader

# Training split of FashionMNIST, downloaded into ./fashion_mnist on first use;
# every image goes through `transform` when it is loaded.
train_dataset = FashionMNIST(
    "fashion_mnist",
    train=True,
    download=True,
    transform=transform,
)
# Shuffled batches of 192 samples, loaded by two worker processes into pinned memory.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=192,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to fashion_mnist\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:42<00:00, 618642.89it/s] 


Extracting fashion_mnist\FashionMNIST\raw\train-images-idx3-ubyte.gz to fashion_mnist\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to fashion_mnist\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 149192.52it/s]


Extracting fashion_mnist\FashionMNIST\raw\train-labels-idx1-ubyte.gz to fashion_mnist\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to fashion_mnist\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:04<00:00, 990598.22it/s] 


Extracting fashion_mnist\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to fashion_mnist\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to fashion_mnist\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5163146.10it/s]

Extracting fashion_mnist\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to fashion_mnist\FashionMNIST\raw






# Train

In [None]:
# Training cell: jointly optimises the label embedding (`encoder`) and the
# diffusion model (`model`) on `train_dataloader`. It also defines `device`,
# which the generation cells further down rely on.
import torch
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of full passes over train_dataloader.
epochs = 5

# Move both trainable modules to the chosen device before building the optimizer.
encoder.to(device)
model.to(device)

# One AdamW optimizer updates the embedding and the model together, with the
# same learning rate for both parameter groups.
optimizer = torch.optim.AdamW(
    [{"params": encoder.parameters()}, {"params": model.parameters()}], lr=0.0001
)
# Training loss, constructed from the same schedule the model uses.
h_loss = DiffusionHybridLoss(schedule)

encoder.train()
model.train()
for e in range(epochs):
    # The walrus keeps a handle on the tqdm bar so its description can be updated per batch.
    for sample in (pbar := tqdm(train_dataloader)):
        # Each batch is a pair: (dict produced by the dataset transform, class labels).
        x, c = sample
        # Pull the dict entries and the labels onto the training device.
        # NOTE(review): presumably "x" is the clean image, "xt" its noised
        # version, "noise" the noise that was added and "t" the sampled
        # timestep — confirm against RandomDiffusionSteps.
        x0, xt, noise, t, c = x["x"].to(device), x["xt"].to(device), x["noise"].to(device), x["t"].to(device), c.to(device)
        optimizer.zero_grad()

        # Replace the integer labels with their embeddings, which the model reads under "context".
        c = encoder(c)
        out = model(xt, t, {"context": c})
        loss = h_loss(out.prediction, noise, out.mean, out.log_variance, x0, xt, t)

        loss.backward()
        optimizer.step()

        # Show the 1-based epoch number and the latest batch loss on the progress bar.
        pbar.set_description(f'{e+1}| Loss: {loss.item()}')


# Generate

In [11]:
def fashion_encoder(name, num=1):
    """Return ``num`` copies of the learned embedding for a FashionMNIST class.

    Args:
        name: Short class name; must be a key of the table below.
        num: Number of identical embedding rows to return. ``0`` yields an
            empty batch.

    Returns:
        The output of the module-level ``encoder`` for ``num`` copies of the
        class index, computed on ``device`` without gradient tracking.

    Raises:
        KeyError: If ``name`` is not one of the known class names.
    """
    fashion_dict = {"t-shirt": 0, "pants": 1, "sweater": 2, "dress": 3, "coat": 4,
                    "sandal": 5, "shirt": 6, "sneaker": 7, "purse": 8, "boot": 9}
    # Build the index tensor directly on the target device with an explicit
    # integer dtype. Going through an empty Python list (num=0) would produce
    # a float tensor, which nn.Embedding rejects.
    idx = torch.full((num,), fashion_dict[name], dtype=torch.long, device=device)

    encoder.eval()
    with torch.no_grad():
        embed = encoder(idx)
    return embed

In [15]:
model.eval()

# Number of samples to draw; they are shown below as a grid with 3 images per row.
num_images = 9
c = fashion_encoder("boot", num_images)
# Starting noise, one 1x32x32 image per requested sample.
noise = torch.randn(size=(num_images, 1, 32, 32)).to(device)

with torch.no_grad():
    imgs = model(noise, conditional_inputs={"context": c})

img_grid = torchvision.utils.make_grid(imgs, 3)
# Map from [-1, 1] back to [0, 1]. to_pil_image scales float tensors by 255 and
# casts to uint8, so clamp first to keep slightly out-of-range values from
# wrapping around.
img = F.to_pil_image(((img_grid + 1) / 2).clamp(0, 1))
img.resize((288, 288))