In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sympy.printing.pytorch import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from generator import Generator, initialize_weights
from discriminator import Discriminator
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from torchvision.transforms import Compose, Resize, ToTensor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset('poloclub/diffusiondb', '2m_random_10k')
model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Downloading data: 100%|██████████| 644M/644M [00:50<00:00, 12.9MB/s]    
Downloading data: 100%|██████████| 651M/651M [00:50<00:00, 12.9MB/s]    
Downloading data: 100%|██████████| 662M/662M [00:51<00:00, 12.8MB/s]    
Downloading data: 100%|██████████| 599M/599M [00:46<00:00, 12.9MB/s]    
Downloading data: 100%|██████████| 611M/611M [00:47<00:00, 12.9MB/s]    
Downloading data: 100%|██████████| 571M/571M [00:44<00:00, 12.9MB/s]    
Downloading data: 100%|██████████| 602M/602M [00:47<00:00, 12.7MB/s]    
Downloading data: 100%|██████████| 571M/571M [00:44<00:00, 12.9MB/s]    
Downloading data: 100%|██████████| 578M/578M [00:44<00:00, 12.9MB/s]    
Downloading data: 100%|██████████| 666M/666M [00:51<00:00, 12.9MB/s]    
Generating train split: 10000 examples [00:16, 622.50 examples/s]
<A

In [3]:
device = "mps" if torch.mps.is_available() else "cpu"
LEARNING_RATE = 2e-4
BATCH_SIZE = 10
IMAGE_SIZE = 64
CHANNELS = 3 #3
EMBEDDING_SIZE = 10
GEN_EMBEDDING_SIZE = 100
Z_DIM = 456
NUM_EPOCHS = 10
FEATURE_DISC = 64
FEATURE_GEN = 64
CRITIC_ITER = 10
WEIGHT_CLIP = 0.01
LAMBDA_GP = 10

In [4]:
jitter = Compose(
    [
        Resize((IMAGE_SIZE, IMAGE_SIZE)),
        ToTensor(),
    ]
)

def transforms(examples):
    examples["image"] = [jitter(image.convert("RGB")) for image in examples["image"]]
    examples["prompt"] = [torch.from_numpy(model.encode(examples["prompt"]))]

    return examples

In [5]:
dataset.set_transform(transforms)

In [6]:
def gradient_penalty(critic, labels, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, labels)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [7]:
train_dataset = dataset["train"].select_columns(['image', 'prompt'])

In [8]:
len(train_dataset)

10000

In [9]:
gen = Generator(Z_DIM, CHANNELS, FEATURE_GEN, IMAGE_SIZE).to(device)
critic = Discriminator(CHANNELS, FEATURE_DISC, IMAGE_SIZE, 768).to(device)
initialize_weights(gen)
initialize_weights(critic)

In [10]:
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

In [11]:
fixed_noise = torch.randn(1, Z_DIM, 1, 1).to(device)


In [12]:
writer_f = SummaryWriter(f"logs/fake")
writer_r = SummaryWriter(f"logs/real")
step = 0

In [13]:
gen.train()
critic.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
  (emb): Linear(in_feat

In [14]:
import torchvision
from torchinfo import summary

In [15]:
summary(gen)

Layer (type:depth-idx)                   Param #
Generator                                --
├─Sequential: 1-1                        --
│    └─ConvTranspose2d: 2-1              20,055,040
│    └─Sequential: 2-2                   --
│    │    └─ConvTranspose2d: 3-1         8,388,608
│    │    └─BatchNorm2d: 3-2             1,024
│    │    └─ReLU: 3-3                    --
│    └─Sequential: 2-3                   --
│    │    └─ConvTranspose2d: 3-4         2,097,152
│    │    └─BatchNorm2d: 3-5             512
│    │    └─ReLU: 3-6                    --
│    └─Sequential: 2-4                   --
│    │    └─ConvTranspose2d: 3-7         524,288
│    │    └─BatchNorm2d: 3-8             256
│    │    └─ReLU: 3-9                    --
│    └─ConvTranspose2d: 2-5              6,147
│    └─Tanh: 2-6                         --
Total params: 31,073,027
Trainable params: 31,073,027
Non-trainable params: 0

In [16]:
summary(critic)

Layer (type:depth-idx)                   Param #
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       4,160
│    └─LeakyReLU: 2-2                    --
│    └─Sequential: 2-3                   --
│    │    └─Conv2d: 3-1                  131,072
│    │    └─InstanceNorm2d: 3-2          256
│    │    └─LeakyReLU: 3-3               --
│    └─Sequential: 2-4                   --
│    │    └─Conv2d: 3-4                  524,288
│    │    └─InstanceNorm2d: 3-5          512
│    │    └─LeakyReLU: 3-6               --
│    └─Sequential: 2-5                   --
│    │    └─Conv2d: 3-7                  2,097,152
│    │    └─InstanceNorm2d: 3-8          1,024
│    │    └─LeakyReLU: 3-9               --
│    └─Conv2d: 2-6                       8,193
├─Linear: 1-2                            3,149,824
Total params: 5,916,481
Trainable params: 5,916,481
Non-trainable params: 0

In [17]:
for epoch in range(NUM_EPOCHS):
    for inx, data in enumerate(train_dataset):
        real = data['image'].unsqueeze(0).to(device)
        labels = data['prompt'].to(device)
        current_batch_size = real.size(0)  # Get actual batch size

        for _ in range(CRITIC_ITER):
            noise = torch.randn(current_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (-(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp)
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

        # Train generator
        out = critic(fake, labels).reshape(-1)
        lossG = -torch.mean(out)
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if inx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {inx}/{len(train_dataset)} "
                f"Loss D: {loss_critic:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(noise, labels)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_r.add_image("Real", img_grid_real, global_step=step)
                writer_f.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/10] Batch 0/10000 Loss D: -20.2760, loss G: 9.1220
Epoch [0/10] Batch 100/10000 Loss D: -20.9305, loss G: 90.1726
Epoch [0/10] Batch 200/10000 Loss D: -12.2908, loss G: 66.3133
Epoch [0/10] Batch 300/10000 Loss D: -37.9530, loss G: -14.9834
Epoch [0/10] Batch 400/10000 Loss D: -40.9974, loss G: 6.2963
Epoch [0/10] Batch 500/10000 Loss D: -47.9526, loss G: 29.1011
Epoch [0/10] Batch 600/10000 Loss D: -59.8531, loss G: 57.6297
Epoch [0/10] Batch 700/10000 Loss D: -31.2061, loss G: 21.3148
Epoch [0/10] Batch 800/10000 Loss D: -40.3443, loss G: 40.9904
Epoch [0/10] Batch 900/10000 Loss D: -41.9761, loss G: 38.4272
Epoch [0/10] Batch 1000/10000 Loss D: -43.9387, loss G: 61.3847
Epoch [0/10] Batch 1100/10000 Loss D: -40.7754, loss G: 9.8716
Epoch [0/10] Batch 1200/10000 Loss D: -35.6619, loss G: 42.1238
Epoch [0/10] Batch 1300/10000 Loss D: -42.2953, loss G: 56.5671
Epoch [0/10] Batch 1400/10000 Loss D: -35.0694, loss G: 6.3435
Epoch [0/10] Batch 1500/10000 Loss D: -31.9651, loss G:

KeyboardInterrupt: 