In [5]:
import torch
import torch.nn as nn

In [6]:
def save_some_examples(gen, val_loader, epoch, folder):
    """Save generator outputs and their inputs for one validation batch.

    The first batch of ``val_loader`` is moved to ``config.DEVICE`` and passed
    through ``gen`` in eval mode with gradient tracking disabled. Two image
    files are written via ``save_image``:

    * ``{folder}/y_gen_{epoch}.png`` -- the generated batch
    * ``{folder}/input_{epoch}.png`` -- the input batch

    Both batches are rescaled with ``* 0.5 + 0.5`` before saving
    (presumably the data is normalized to [-1, 1] -- confirm against the
    dataset transforms).

    Args:
        gen: Generator; called as ``gen(x)`` and must support
            ``eval()`` / ``train()``.
        val_loader: Iterable yielding ``(x, y)`` pairs; only the first pair
            is consumed and only ``x`` is used.
        epoch: Value interpolated into the output file names.
        folder: Directory prefix for the output files (not created here).
    """
    x, _ = next(iter(val_loader))
    x = x.to(config.DEVICE)
    gen.eval()
    try:
        with torch.no_grad():
            # Undo the normalization on the generated batch itself; the
            # previous code stored the result in an unused variable and
            # saved the still-normalized tensor.
            y_fake = gen(x) * 0.5 + 0.5
            save_image(y_fake, folder + f"/y_gen_{epoch}.png")
            save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
    finally:
        # Always hand the generator back in training mode, even if saving fails.
        gen.train()

In [7]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The saved object is a dict with two keys: ``"state_dict"`` (from
    ``model.state_dict()``) and ``"optimizer"`` (from
    ``optimizer.state_dict()``). A status line is printed before saving.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        filename: Destination passed straight to ``torch.save``.
    """
    print("=> Saving checkpoint")
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    torch.save(dict(state_dict=model_state, optimizer=optimizer_state), filename)

In [8]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint, then reset the LR.

    The checkpoint is expected to be a dict with the keys ``"state_dict"``
    and ``"optimizer"``. Tensors are mapped onto ``config.DEVICE`` while
    loading. A status line is printed before loading.

    Args:
        checkpoint_file: Path or file-like object passed to ``torch.load``.
        model: Module whose weights are restored from ``"state_dict"``.
        optimizer: Optimizer whose state is restored from ``"optimizer"``.
        lr: Learning rate written into every parameter group after loading.

    Raises:
        KeyError: If the checkpoint lacks ``"state_dict"`` or ``"optimizer"``.
    """
    print("=> Loading checkpoint")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    # Model weights live under "state_dict"; "optimizer" holds optimizer state.
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # Loading the optimizer state also restores the learning rate stored in
    # the checkpoint. Overwrite it so training continues with the requested
    # `lr`; otherwise the old value silently stays in effect.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [9]:
def gradient_penalty(critic, real, fake, device="cpu"):
    """Compute the gradient penalty on random interpolations of real and fake.

    For every sample a coefficient in [0, 1) is drawn, ``real`` and ``fake``
    are blended with it, the critic is evaluated on the blend, and the
    penalty is the batch mean of ``(||grad||_2 - 1) ** 2`` where ``grad`` is
    the gradient of the critic scores with respect to the blended images.

    Args:
        critic: Callable invoked as ``critic(images)``.
        real: 4-D tensor ``(batch, C, H, W)``.
        fake: Tensor combined elementwise with ``real``.
        device: Device the random coefficients are moved to; it has to match
            the device of ``real`` / ``fake``.

    Returns:
        Scalar tensor holding the penalty. The gradient is computed with
        ``create_graph=True`` so the penalty itself can be backpropagated.
    """
    batch_size, channels, height, width = real.shape
    # One coefficient per sample, repeated over channels and spatial dims.
    epsilon = (
        torch.rand((batch_size, 1, 1, 1))
        .repeat(1, channels, height, width)
        .to(device)
    )
    interpolated_images = real * epsilon + fake * (1 - epsilon)
    # autograd.grad needs an input that requires grad. When neither `real`
    # nor `fake` tracks gradients (e.g. a detached fake batch) the blend is a
    # plain leaf tensor, so switch tracking on explicitly.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)

    # Critic scores for the blended images.
    mixed_scores = critic(interpolated_images)

    # Gradient of the scores with respect to the blended images.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) so a non-contiguous gradient is handled as well.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

In [10]:
def generate_examples(gen, steps, n = 100):
    """Generate ``n`` images from random noise and save them as PNG files.

    The generator is put in eval mode, called ``n`` times as
    ``gen(noise, alpha, steps)`` with ``alpha = 1.0`` and a fresh
    ``torch.randn`` noise tensor of shape ``(1, config.Z_DIM, 1, 1)`` on
    ``config.DEVICE``, and each result is written (after ``* 0.5 + 0.5``
    rescaling) to ``saved_examples/img_{i}.png``. Afterwards the generator is
    switched back to train mode.

    Args:
        gen: Generator; must support ``eval()`` / ``train()``.
        steps: Forwarded unchanged as the third argument of ``gen``.
        n: Number of images to generate.
    """
    gen.eval()
    alpha = 1.0
    # Iterate over sample indices; the index doubles as the file name suffix.
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1, config.Z_DIM, 1, 1).to(config.DEVICE)
            img = gen(noise, alpha, steps)
            save_image(img * 0.5 + 0.5, f"saved_examples/img_{i}.png")

    # Restore training mode once, after all samples have been generated.
    gen.train()

In [11]:
def plot_to_tensorboard(writer, loss_critic, loss_gen, real, fake, tensorboard_step, max_images=8):
    """Log the critic loss and image grids of real and fake samples.

    Writes the scalar ``"Loss Critic"`` and two image grids, tagged
    ``"Real"`` and ``"Fake"``, all at ``global_step=tensorboard_step``.

    Args:
        writer: Object exposing ``add_scalar`` and ``add_image``.
        loss_critic: Value logged under ``"Loss Critic"``.
        loss_gen: Accepted for interface compatibility; not logged here.
        real: Batch of real images, indexable along the first dimension.
        fake: Batch of generated images, indexable along the first dimension.
        tensorboard_step: Global step attached to every logged item.
        max_images: Upper bound on the number of leading samples placed in
            each grid.
    """
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)

    with torch.no_grad():
        # Take the first (up to) `max_images` examples of each batch.
        img_grid_real = torchvision.utils.make_grid(real[:max_images], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:max_images], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        # Separate tag so the fake grid does not collide with the real one.
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)