In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sympy.printing.pytorch import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from torchvision.transforms import Compose, Resize, ToTensor

from tqdm import tqdm

from utils import *
from config import *
from CProGAN import *

  from .autonotebook import tqdm as notebook_tqdm
<All keys matched successfully>


In [2]:
def train_fn(
        critic,
        gen,
        loader,
        dataset,
        step,
        alpha,
        opt_critic,
        opt_gen,
        tensorboard_step,
        writer,
        scaler_gen,
        scaler_critic,
):
    """
    Train the critic and the generator for one pass over ``loader``.

    For every batch the critic is updated first, then the generator, both under
    CUDA autocast with their own gradient scaler. ``alpha`` is increased after
    every batch and capped at 1.

    Args:
        critic: Critic network, called as ``critic(images, embed, alpha, step)``.
        gen: Generator network, called as ``gen(noise, embed, alpha, step)``.
        loader: Iterable of batches; each batch is indexed with the keys
            ``'pix'`` (real images) and ``'emb'`` (conditioning embeddings).
        dataset: Dataset behind ``loader``; only ``len(dataset)`` is used, to
            pace the ``alpha`` schedule.
        step: Index of the current stage; forwarded to both networks and used
            to index ``config.PROGRESSIVE_EPOCHS``.
        alpha: Current blending coefficient forwarded to both networks
            (presumably the progressive-growing fade-in factor -- confirm
            against the model definition).
        opt_critic: Optimizer for the critic parameters.
        opt_gen: Optimizer for the generator parameters.
        tensorboard_step: Running counter passed to ``plot_to_tensorboard``.
        writer: TensorBoard writer passed to ``plot_to_tensorboard``.
        scaler_gen: Gradient scaler used for the generator update.
        scaler_critic: Gradient scaler used for the critic update.

    Returns:
        tuple: ``(tensorboard_step, alpha)`` as they stand after the last batch,
        so the caller can carry them into the next epoch.
    """
    drift_weight = 0.001  # weight of the squared-real-score penalty in the critic loss
    log_every = 500       # log losses and sample images every this many batches

    loop = tqdm(loader, leave=True)
    for batch_idx, data in enumerate(loop):
        real = data['pix'].to(config.DEVICE)
        embed = data['emb'].to(config.DEVICE)
        cur_batch_size = real.shape[0]

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # <-> min -E[critic(real)] + E[critic(fake)]
        # Sample the latent directly on the target device (no host->device copy).
        noise = torch.randn(cur_batch_size, config.Z_DIM, 1, 1, device=config.DEVICE)

        # torch.cuda.amp.autocast() is deprecated; torch.amp.autocast("cuda")
        # is the documented replacement with the same behavior.
        with torch.amp.autocast("cuda"):
            fake = gen(noise, embed, alpha, step)
            critic_real = critic(real, embed, alpha, step)
            # detach(): the critic update must not backpropagate into the generator.
            critic_fake = critic(fake.detach(), embed, alpha, step)
            gp = gradient_penalty(critic, embed, real, fake, alpha, step, device=config.DEVICE)
            loss_critic = (
                    -(torch.mean(critic_real) - torch.mean(critic_fake))
                    + config.LAMBDA_GP * gp
                    + (drift_weight * torch.mean(critic_real ** 2))
            )

        opt_critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        with torch.amp.autocast("cuda"):
            gen_fake = critic(fake, embed, alpha, step)
            loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Advance alpha in proportion to the samples seen; the divisor makes it
        # reach 1 after half of the samples scheduled for this step
        # (PROGRESSIVE_EPOCHS[step] * len(dataset)). Then cap it at 1.
        alpha += cur_batch_size / (
                (config.PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
        )
        alpha = min(alpha, 1)

        if batch_idx % log_every == 0:
            with torch.no_grad():
                # NOTE(review): FIXED_NOISE is paired with the current batch's
                # embeddings; this assumes their batch dimensions are
                # compatible -- confirm, especially for a short final batch.
                # "* 0.5 + 0.5" presumably maps [-1, 1] outputs to [0, 1].
                fixed_fakes = gen(config.FIXED_NOISE, embed, alpha, step) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

    return tensorboard_step, alpha

In [3]:
def main():
    """
    Build the generator/critic pair and train them stage by stage.

    Starting at the stage that corresponds to
    ``config.START_TRAIN_AT_IMG_SIZE``, each stage trains for the number of
    epochs listed in ``config.PROGRESSIVE_EPOCHS`` at image size
    ``4 * 2 ** step``; ``alpha`` is reset to a near-zero value at the start of
    every stage. Checkpoints are written after each epoch when
    ``config.SAVE_MODEL`` is set. The TensorBoard writer is always closed on
    exit, including when training is interrupted.
    """
    # Note: the discriminator is called "critic" following the WGAN paper,
    # since it no longer outputs values in [0, 1].
    gen = Generator(
        config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)
    critic = Discriminator(
        config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG
    ).to(config.DEVICE)

    # Optimizers and gradient scalers for mixed-precision training.
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(
        critic.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.99)
    )
    # torch.cuda.amp.GradScaler() is deprecated; torch.amp.GradScaler("cuda")
    # is the documented replacement.
    scaler_critic = torch.amp.GradScaler("cuda")
    scaler_gen = torch.amp.GradScaler("cuda")

    # For TensorBoard plotting.
    writer = SummaryWriter(f"logs/gan{config.ITER}")

    try:
        if config.LOAD_MODEL:
            load_checkpoint(
                config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
            )
            load_checkpoint(
                config.CHECKPOINT_CRITIC, critic, opt_critic, config.LEARNING_RATE,
            )

        gen.train()
        critic.train()

        tensorboard_step = 0
        # Start at the step that corresponds to the image size set in config.
        step = int(log2(config.START_TRAIN_AT_IMG_SIZE / 4))
        for num_epochs in config.PROGRESSIVE_EPOCHS[step:]:
            alpha = 1e-5  # start each stage with a very low alpha
            loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
            print(f"Current image size: {4 * 2 ** step}")

            for epoch in range(num_epochs):
                print(f"Epoch [{epoch+1}/{num_epochs}]")
                tensorboard_step, alpha = train_fn(
                    critic,
                    gen,
                    loader,
                    dataset,
                    step,
                    alpha,
                    opt_critic,
                    opt_gen,
                    tensorboard_step,
                    writer,
                    scaler_gen,
                    scaler_critic,
                )

                if config.SAVE_MODEL:
                    save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                    save_checkpoint(critic, opt_critic, filename=config.CHECKPOINT_CRITIC)

            step += 1  # progress to the next img size
    finally:
        # Flush and release the event file even if training is interrupted
        # (e.g. KeyboardInterrupt), so logged scalars/images are not lost.
        writer.close()

In [4]:
# Entry point: start training when this code is executed directly. In a
# notebook kernel __name__ is "__main__", so running this cell launches main().
if __name__ == "__main__":
    main()

  scaler_critic = torch.cuda.amp.GradScaler()
  scaler_gen = torch.cuda.amp.GradScaler()
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Current image size: 8
Epoch [1/30]


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 625/625 [00:16<00:00, 38.21it/s, gp=0.0121, loss_critic=1.38]   


=> Saving checkpoint
=> Saving checkpoint
Epoch [2/30]


100%|██████████| 625/625 [00:23<00:00, 26.54it/s, gp=0.00982, loss_critic=1.07] 


=> Saving checkpoint
=> Saving checkpoint
Epoch [3/30]


100%|██████████| 625/625 [00:16<00:00, 36.96it/s, gp=0.0064, loss_critic=1.52]  


=> Saving checkpoint
=> Saving checkpoint
Epoch [4/30]


100%|██████████| 625/625 [00:16<00:00, 38.84it/s, gp=0.00666, loss_critic=0.605]  


=> Saving checkpoint
=> Saving checkpoint
Epoch [5/30]


100%|██████████| 625/625 [00:18<00:00, 33.91it/s, gp=0.00937, loss_critic=0.84]    


=> Saving checkpoint
=> Saving checkpoint
Epoch [6/30]


100%|██████████| 625/625 [00:19<00:00, 31.47it/s, gp=0.00305, loss_critic=0.581]  


=> Saving checkpoint
=> Saving checkpoint


KeyboardInterrupt: 

In [None]:
# Rebuild a generator and its Adam optimizer with the same constructor
# arguments and hyperparameters that main() uses, ready to receive a checkpoint.
gen = Generator(config.Z_DIM, config.IN_CHANNELS, img_channels=config.CHANNELS_IMG)
gen = gen.to(config.DEVICE)
opt_gen = optim.Adam(
    gen.parameters(),
    lr=config.LEARNING_RATE,
    betas=(0.0, 0.99),
)

In [None]:
# Load the saved generator checkpoint into `gen` / `opt_gen`.
# (load_checkpoint is defined outside this notebook; presumably it restores
# model and optimizer state and applies the given learning rate -- confirm.)
load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)

In [None]:
# The cells below only run inference, so switch the generator to evaluation
# mode: any train-only layers (dropout, batch-norm statistics updates) are
# disabled. This is a no-op for models without such layers.
gen.eval()

In [None]:
# Text encoder used to turn a prompt into the conditioning embedding.
# NOTE(review): trust_remote_code=True allows code shipped with the model
# repository to run locally; consider pinning a specific revision.
embedder = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)

In [None]:
# Encode the prompt, wrap the resulting NumPy array as a tensor, add a leading
# batch dimension and move it to the target device.
# NOTE(review): confirm the prompt format (e.g. any task prefix the encoder
# expects) matches how the dataset's 'emb' values were produced.
embeding = (
    torch.from_numpy(embedder.encode("tree"))
    .unsqueeze(0)
    .to(DEVICE)
)

In [None]:
# Single latent vector of shape (1, Z_DIM, 1, 1), drawn from a standard normal
# and moved to the target device.
x = torch.randn((1, Z_DIM, 1, 1)).to(device=DEVICE)

In [None]:
# Generate one image for the prompt embedding. Gradients are not needed for
# inference, so skip building the autograd graph.
# NOTE(review): alpha=1e-5 is the value training *starts* each stage with;
# a generator that finished a stage was last trained with alpha near 1 --
# confirm which value matches the loaded checkpoint.
# NOTE(review): `steps` is derived from START_TRAIN_AT_IMG_SIZE, i.e. the
# first training resolution -- confirm it matches the resolution you want.
with torch.no_grad():
    out = gen(
        x,
        embeding,
        alpha=1e-5,
        steps=int(log2(config.START_TRAIN_AT_IMG_SIZE / 4)),
    )

In [None]:
# Inspect the shape of the generated tensor (displayed as the cell output).
out.shape

In [None]:
import torch
import matplotlib.pyplot as plt

def display_image_from_tensor(tensor: torch.Tensor) -> None:
    """
    Displays a single RGB image from a PyTorch tensor of shape [1, 3, H, W].

    Any spatial size is accepted, so outputs from every resolution stage
    (8x8, 16x16, ..., 128x128) can be shown. The tensor values are assumed to
    be in the range [0, 1] or [-1, 1]: if any value is negative the image is
    treated as [-1, 1] and rescaled to [0, 1]; values are then clipped to
    [0, 1] and the image is shown with matplotlib.

    Args:
        tensor (torch.Tensor): The input tensor representing the image.

    Raises:
        ValueError: If the tensor is not 4-D with batch size 1 and 3 channels,
            or if it contains no pixels.
    """
    if tensor.dim() != 4 or tensor.shape[0] != 1 or tensor.shape[1] != 3:
        raise ValueError(f"Expected tensor shape [1, 3, H, W], but got {tensor.shape}")
    if tensor.numel() == 0:
        raise ValueError(f"Expected a non-empty image, but got shape {tensor.shape}")

    # Drop the batch dimension and reorder to [H, W, C] for matplotlib.
    # float(): half-precision outputs (e.g. produced under autocast) are
    # promoted so the NumPy conversion and arithmetic below are well-defined.
    image = tensor.squeeze(0).permute(1, 2, 0).detach().float().cpu().numpy()

    # Heuristic: a negative value means the data is in [-1, 1]; map to [0, 1].
    if image.min() < 0:
        image = (image + 1) / 2

    # Clip values to [0, 1]
    image = image.clip(0, 1)

    # Display the image
    plt.imshow(image)
    plt.axis('off')
    plt.show()

In [None]:
# Render the generated image held in `out`.
display_image_from_tensor(out)