In [1]:
import os
import random
from math import log2

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm


In [3]:
def seed_everything(seed=42):
    """Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU plus all
    CUDA devices), and switches cuDNN to deterministic, non-benchmarked kernels.

    Note: ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting it
    here only affects child processes launched afterwards, not the running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything()

In [4]:
# Root directory for torchvision's ImageFolder (it expects one sub-folder per class).
# Override without editing the notebook by setting the PROGAN_DATASET environment variable.
# NOTE(review): the default below points at an .mp4 file, but datasets.ImageFolder needs a
# directory of images — confirm the intended path.
DATASET                 = os.environ.get(
    "PROGAN_DATASET",
    r"D:\gokul\FINAL CAPSTONE\DyslexiEase\Data\Processsed\test.mp4",
)
START_TRAIN_AT_IMG_SIZE = 32     # first resolution trained; must be 4 * 2**k
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE           = 1e-3
# Batch size per resolution step: entry k is used for 4 * 2**k px images.
# NOTE(review): batch 16 at 512x512 ran out of memory on a 12 GiB GPU in the logged run.
BATCH_SIZES             = [32, 32, 32, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
image_size              = 512    # NOTE(review): no cell shown reads this global
CHANNELS_IMG            = 3
Z_DIM                   = 256    # latent vector length
IN_CHANNELS             = 256    # channel width at the lowest resolutions
LAMBDA_GP               = 10     # weight of the WGAN-GP gradient penalty
PROGRESSIVE_EPOCHS      = [30] * len(BATCH_SIZES)  # epochs per resolution step

In [5]:
def get_loader(image_size):
    """Build an ImageFolder DataLoader yielding square images of side ``image_size``.

    Each image is resized, converted to a tensor, mirrored horizontally with
    probability 0.5 and normalized per channel with mean = std = 0.5. The batch
    size comes from ``BATCH_SIZES`` at index ``log2(image_size / 4)``, i.e. the
    resolution step (4 -> 0, 8 -> 1, 16 -> 2, ...).

    Returns:
        ``(loader, dataset)``.
    """
    half = [0.5] * CHANNELS_IMG
    pipeline = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(half, half),
    ])
    batch_size = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=pipeline)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader, dataset

In [5]:
def check_loader(image_size=128, rows=3, cols=3):
    """Show a grid of real training images as a sanity check of the data pipeline.

    Draws one batch from ``get_loader(image_size)`` and displays up to
    ``rows * cols`` images. ``(x + 1) / 2`` undoes the ``Normalize(0.5, 0.5)``
    step so pixel values are back in [0, 1] for ``imshow``.
    """
    loader, _ = get_loader(image_size)
    images, _ = next(iter(loader))
    fig, axes = plt.subplots(rows, cols, figsize=(8, 8), squeeze=False)
    fig.suptitle('Some real samples', fontsize=15, fontweight='bold')
    for ind, ax in enumerate(axes.flat):
        ax.axis('off')
        # A batch can hold fewer than rows * cols images; leave the extra cells blank.
        if ind < images.shape[0]:
            ax.imshow(((images[ind].permute(1, 2, 0) + 1) / 2).clamp(0, 1))
    plt.show()
check_loader()

: 

In [6]:
# Channel multiplier (relative to IN_CHANNELS) for the resolution 4 * 2**k at index k:
# full width up to 32x32, then halved at every further doubling up to 1024x1024.
factors = [1] * 4 + [1 / 2 ** k for k in range(1, 6)]

In [8]:
class WSConv2d(nn.Module):
    """Conv2d with equalized learning rate (weight scaling), as in ProGAN.

    Weights are initialised from N(0, 1), and the input is multiplied at run time
    by ``sqrt(2 / fan_in)`` with ``fan_in = in_channels * kernel_size**2``. The bias
    lives on this module (not on the inner conv) and is added after the convolution.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        fan_in = in_channels * kernel_size * kernel_size
        self.scale = (2 / fan_in) ** 0.5
        # Move the conv's bias parameter onto this module and strip it from the conv,
        # so the convolution is bias-free and forward() adds the bias explicitly.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        scaled_input = x * self.scale
        return self.conv(scaled_input) + self.bias[None, :, None, None]

In [12]:
class PixelNorm(nn.Module):
    """Pixelwise feature-vector normalization from ProGAN.

    At every spatial location the vector along dim 1 is divided by the square
    root of the mean of its squared entries (plus a small epsilon for stability).
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()

In [11]:
class ConvBlock(nn.Module):
    """Two 3x3 weight-scaled convolutions, each followed by LeakyReLU(0.2) and,
    when ``use_pixelnorm`` is true, by PixelNorm."""

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x

In [14]:
class Generator(nn.Module):
    """Progressively-growing ProGAN generator.

    ``forward(x, alpha, steps)`` maps latent noise ``x`` to an image of size
    ``4 * 2**steps``. ``initial`` starts with ConvTranspose2d(kernel 4, stride 1,
    padding 0), so ``x`` is expected as (batch, z_dim, 1, 1) and becomes a 4x4 map.
    Channel widths per resolution are ``in_channels * factors[k]`` using the
    module-level ``factors`` list.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super(Generator, self).__init__()

        # initial takes 1x1 -> 4x4
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        # 1x1 "to RGB" conv for the 4x4 output; rgb_layers[k] serves resolution step k.
        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        # One growth block (plus its to-RGB layer) per doubling of resolution;
        # len(factors) - 1 of them because each reads factors[i] and factors[i + 1].
        for i in range(len(factors) - 1):
            conv_in_c  = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the new block's RGB output with the upscaled previous one, then tanh.

        ``alpha`` should be a scalar in [0, 1]; ``upscaled.shape == generated.shape``.
        """
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        """Generate images of side ``4 * 2**steps``.

        Args:
            x: latent noise, presumably (batch, z_dim, 1, 1) — see class docstring.
            alpha: fade-in weight in [0, 1] for the newest growth block.
            steps: resolution step, ``0 <= steps <= len(self.prog_blocks)``.

        Raises:
            ValueError: if ``steps`` is outside the supported range.
        """
        n_blocks = len(self.prog_blocks)
        if not 0 <= steps <= n_blocks:
            raise ValueError(f"steps must be in [0, {n_blocks}], got {steps}")

        out = self.initial(x)

        if steps == 0:
            # Same tanh as fade_in, so every resolution emits values in [-1, 1] like
            # the Normalize(0.5, 0.5)-scaled real images the critic compares against.
            return torch.tanh(self.initial_rgb(out))

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # `upscaled` keeps the previous resolution's channel count while `out` has the
        # new block's, so each goes through its own to-RGB layer (steps - 1 and steps).
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

In [16]:
class Discriminator(nn.Module):
    """Progressively-growing ProGAN critic that mirrors ``Generator``.

    ``forward(x, alpha, steps)`` scores images of side ``4 * 2**steps`` and
    returns a tensor of shape (batch, 1).
    """

    def __init__(self, in_channels, img_channels=3):
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Walk `factors` backwards because the critic mirrors the generator: the first
        # prog_block / rgb layer appended handles the largest input
        # (4 * 2**(len(factors) - 1) px, i.e. 1024x1024 with the current factors),
        # the next one half that size, and so on down to 8x8.
        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        # "from RGB" layer for 4x4 inputs; named initial_rgb to mirror the generator.
        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # 2x downsampling between resolutions

        # Head applied once the features are 4x4.
        self.final_block = nn.Sequential(
            # +1 input channel for the minibatch-std feature map
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(
                in_channels, 1, kernel_size=1, padding=0, stride=1
            ),  # 1x1 conv used instead of a linear layer
        )

    def fade_in(self, alpha, downscaled, out):
        """Blend the avg-pool-downscaled path with the output of the newest block.

        ``alpha`` should be a scalar in [0, 1]; ``downscaled.shape == out.shape``.
        """
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one feature map holding the batch-wide mean standard deviation.

        The std is taken across the batch dimension at every (channel, y, x)
        position, averaged to a single scalar and broadcast as an extra channel,
        giving the critic a signal about how varied the batch is.
        """
        # The unbiased std of a single sample is NaN and would poison the loss (e.g. a
        # trailing batch of one); use the population std there, which is 0.
        unbiased = x.shape[0] > 1
        batch_statistics = (
            torch.std(x, dim=0, unbiased=unbiased)
            .mean()
            .repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score images of side ``4 * 2**steps``.

        Raises:
            ValueError: if ``steps`` is outside ``[0, len(self.prog_blocks)]``.
        """
        n_blocks = len(self.prog_blocks)
        if not 0 <= steps <= n_blocks:
            raise ValueError(f"steps must be in [0, {n_blocks}], got {steps}")

        # prog_blocks / rgb_layers are ordered largest resolution first and the last rgb
        # layer is for 4x4, so a step-`steps` input starts at index n_blocks - steps
        # (steps=1 -> 8x8 -> second to last; steps=0 -> only the 4x4 head).
        cur_step = n_blocks - steps

        # Each input resolution has its own "from RGB" layer.
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:  # i.e. the image is 4x4
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # prog_blocks may change the channel count, so the downscaled path uses the rgb
        # layer of the next-smaller resolution (index cur_step + 1).
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))

        # Fade in right after the first block (the generator fades at its output instead).
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, n_blocks):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [17]:
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """WGAN-GP gradient penalty on random interpolations of real and fake images.

    Args:
        critic: callable ``critic(images, alpha, train_step)`` returning scores.
        real: batch of real images, shape (B, C, H, W).
        fake: generated images of the same shape; detached here so the penalty does
            not backpropagate into the generator.
        alpha: fade-in coefficient forwarded to the critic.
        train_step: resolution step forwarded to the critic.
        device: kept for backward compatibility only. The interpolation weights are
            placed on ``real.device``, so a value that disagrees with the data's
            device can no longer break the call.

    Returns:
        Scalar tensor ``mean((||d critic(x_hat) / d x_hat||_2 - 1) ** 2)``, built with
        ``create_graph=True`` so it can be backpropagated into the critic.
    """
    batch_size = real.shape[0]
    # One interpolation weight per sample, broadcast over (C, H, W). It is drawn from the
    # CPU generator and then moved next to the data.
    beta = torch.rand((batch_size, 1, 1, 1)).to(real.device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Gradient of the scores with respect to the interpolated images.
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [18]:
def generate_examples(gen, steps, n=100, out_dir="saved_examples"):
    """Save ``n`` generator samples at resolution step ``steps`` as PNG files.

    Files are written to ``{out_dir}/step{steps}/img_{i}.png``. Outputs are mapped
    with ``x * 0.5 + 0.5``, i.e. from the assumed [-1, 1] range to [0, 1]. ``alpha``
    is fixed at 1.0, so only the newest block's output is used. The generator's
    previous train/eval mode is restored afterwards, even if sampling raises.
    """
    step_dir = os.path.join(out_dir, f"step{steps}")
    os.makedirs(step_dir, exist_ok=True)  # create once, not on every iteration

    was_training = gen.training
    gen.eval()
    alpha = 1.0
    try:
        with torch.no_grad():
            for i in range(n):
                noise = torch.randn(1, Z_DIM, 1, 1).to(DEVICE)
                img = gen(noise, alpha, steps)
                save_image(img * 0.5 + 0.5, os.path.join(step_dir, f"img_{i}.png"))
    finally:
        gen.train(was_training)

In [19]:
# Let cuDNN benchmark convolution algorithms and keep the fastest for each input shape.
# The flag is `benchmark` (singular); assigning `benchmarks` just creates an unused attribute.
# NOTE: this overrides `benchmark = False` from seed_everything, trading bit-exact GPU
# reproducibility for speed.
torch.backends.cudnn.benchmark = True


In [20]:
def train_fn(
    critic,
    gen,
    loader,
    dataset,
    step,
    alpha,
    opt_critic,
    opt_gen,
):
    """Run one epoch of WGAN-GP training at resolution step ``step``.

    Per batch, the critic takes one optimizer step on
    ``-(E[D(real)] - E[D(fake)]) + LAMBDA_GP * GP + 0.001 * E[D(real)**2]``,
    then the generator takes one step on ``-E[D(G(z))]``. ``alpha`` grows by
    ``batch / (0.5 * PROGRESSIVE_EPOCHS[step] * len(dataset))`` per batch, so the
    fade-in completes after half of this resolution's epochs, and is capped at 1.

    Returns:
        The updated fade-in coefficient ``alpha``.
    """
    progress = tqdm(loader, leave=True)
    for real, _labels in progress:
        real = real.to(DEVICE)
        batch_size = real.shape[0]

        # ---- critic update -------------------------------------------------------
        z = torch.randn(batch_size, Z_DIM, 1, 1).to(DEVICE)
        fake = gen(z, alpha, step)
        score_real = critic(real, alpha, step)
        score_fake = critic(fake.detach(), alpha, step)
        gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
        wasserstein = torch.mean(score_real) - torch.mean(score_fake)
        drift = 0.001 * torch.mean(score_real ** 2)  # penalises large real scores
        loss_critic = -wasserstein + LAMBDA_GP * gp + drift

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        # ---- generator update ----------------------------------------------------
        loss_gen = -torch.mean(critic(fake, alpha, step))

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Advance the fade-in, never beyond 1.
        alpha = min(
            alpha + batch_size / ((PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)),
            1,
        )

        progress.set_postfix(gp=gp.item(), loss_critic=loss_critic.item())

    return alpha


In [21]:
# Build the generator and the critic (the WGAN name for the discriminator, since it
# outputs an unbounded score rather than a probability in [0, 1]).
gen = Generator(
    Z_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
).to(DEVICE)
critic = Discriminator(
    IN_CHANNELS, img_channels=CHANNELS_IMG
).to(DEVICE)

# Adam with betas (0.0, 0.99) for both networks.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
    critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)

gen.train()
critic.train()

# Step k trains 4 * 2**k px images. The networks only have len(gen.prog_blocks) growth
# blocks, so stop at that step (or when PROGRESSIVE_EPOCHS runs out) rather than looping
# over every PROGRESSIVE_EPOCHS entry, which would index past the last block.
start_step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
last_step = min(len(gen.prog_blocks), len(PROGRESSIVE_EPOCHS) - 1)
for step in range(start_step, last_step + 1):
    # Index by step so the epoch count matches train_fn's alpha schedule.
    num_epochs = PROGRESSIVE_EPOCHS[step]
    alpha = 1e-5  # start with very low alpha, you can start with alpha=0
    loader, dataset = get_loader(4 * 2 ** step)  # 4->0, 8->1, 16->2, 32->3, 64 -> 4
    print(f"Current image size: {4 * 2 ** step}")

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        alpha = train_fn(
            critic,
            gen,
            loader,
            dataset,
            step,
            alpha,
            opt_critic,
            opt_gen,
        )
    generate_examples(gen, step, n=100)

Current image size: 32
Epoch [1/30]


100%|██████████| 19/19 [00:06<00:00,  2.73it/s, gp=0.0611, loss_critic=-15.4]


Epoch [2/30]


100%|██████████| 19/19 [00:04<00:00,  3.95it/s, gp=0.735, loss_critic=-18]  


Epoch [3/30]


100%|██████████| 19/19 [00:05<00:00,  3.47it/s, gp=0.217, loss_critic=-9.59]


Epoch [4/30]


100%|██████████| 19/19 [00:05<00:00,  3.37it/s, gp=0.279, loss_critic=-7]   


Epoch [5/30]


100%|██████████| 19/19 [00:04<00:00,  3.88it/s, gp=0.162, loss_critic=-6.85] 


Epoch [6/30]


100%|██████████| 19/19 [00:04<00:00,  3.87it/s, gp=0.175, loss_critic=-5.66] 


Epoch [7/30]


100%|██████████| 19/19 [00:05<00:00,  3.40it/s, gp=0.0313, loss_critic=-3.41]


Epoch [8/30]


100%|██████████| 19/19 [00:05<00:00,  3.39it/s, gp=0.188, loss_critic=-7.25] 


Epoch [9/30]


100%|██████████| 19/19 [00:05<00:00,  3.45it/s, gp=0.1, loss_critic=-6.77]   


Epoch [10/30]


100%|██████████| 19/19 [00:05<00:00,  3.47it/s, gp=0.0286, loss_critic=-5.32]


Epoch [11/30]


100%|██████████| 19/19 [00:05<00:00,  3.45it/s, gp=0.0824, loss_critic=-7.59]


Epoch [12/30]


100%|██████████| 19/19 [00:05<00:00,  3.42it/s, gp=0.0602, loss_critic=-5.56]


Epoch [13/30]


100%|██████████| 19/19 [00:04<00:00,  3.87it/s, gp=0.0226, loss_critic=-5.22]


Epoch [14/30]


100%|██████████| 19/19 [00:05<00:00,  3.50it/s, gp=0.057, loss_critic=-4.5]  


Epoch [15/30]


100%|██████████| 19/19 [00:05<00:00,  3.49it/s, gp=0.0812, loss_critic=-6.99]


Epoch [16/30]


100%|██████████| 19/19 [00:05<00:00,  3.52it/s, gp=0.0538, loss_critic=-7.47]


Epoch [17/30]


100%|██████████| 19/19 [00:05<00:00,  3.48it/s, gp=0.0585, loss_critic=-5.1] 


Epoch [18/30]


100%|██████████| 19/19 [00:05<00:00,  3.46it/s, gp=0.0663, loss_critic=-4.72]


Epoch [19/30]


100%|██████████| 19/19 [00:05<00:00,  3.49it/s, gp=0.121, loss_critic=-6.12] 


Epoch [20/30]


100%|██████████| 19/19 [00:05<00:00,  3.44it/s, gp=0.0195, loss_critic=-4.9] 


Epoch [21/30]


100%|██████████| 19/19 [00:05<00:00,  3.38it/s, gp=0.0648, loss_critic=-5.65]


Epoch [22/30]


100%|██████████| 19/19 [00:05<00:00,  3.34it/s, gp=0.0762, loss_critic=-4.89]


Epoch [23/30]


100%|██████████| 19/19 [00:05<00:00,  3.34it/s, gp=0.0636, loss_critic=-6.08]


Epoch [24/30]


100%|██████████| 19/19 [00:05<00:00,  3.39it/s, gp=0.0796, loss_critic=-5.36]


Epoch [25/30]


100%|██████████| 19/19 [00:05<00:00,  3.39it/s, gp=0.0828, loss_critic=-4.64]


Epoch [26/30]


100%|██████████| 19/19 [00:05<00:00,  3.39it/s, gp=0.0245, loss_critic=-4.6] 


Epoch [27/30]


100%|██████████| 19/19 [00:05<00:00,  3.41it/s, gp=0.0662, loss_critic=-5.41]


Epoch [28/30]


100%|██████████| 19/19 [00:05<00:00,  3.34it/s, gp=0.0565, loss_critic=-4.49]


Epoch [29/30]


100%|██████████| 19/19 [00:05<00:00,  3.35it/s, gp=0.024, loss_critic=-4.49] 


Epoch [30/30]


100%|██████████| 19/19 [00:05<00:00,  3.47it/s, gp=0.0663, loss_critic=-5.15]


Current image size: 64
Epoch [1/30]


100%|██████████| 19/19 [00:10<00:00,  1.89it/s, gp=0.23, loss_critic=-10.8]  


Epoch [2/30]


100%|██████████| 19/19 [00:10<00:00,  1.88it/s, gp=0.177, loss_critic=-13.3] 


Epoch [3/30]


100%|██████████| 19/19 [00:10<00:00,  1.88it/s, gp=0.0705, loss_critic=-10.2]


Epoch [4/30]


100%|██████████| 19/19 [00:10<00:00,  1.86it/s, gp=0.326, loss_critic=-12.6]


Epoch [5/30]


100%|██████████| 19/19 [00:10<00:00,  1.87it/s, gp=0.147, loss_critic=-10.7] 


Epoch [6/30]


100%|██████████| 19/19 [00:10<00:00,  1.84it/s, gp=0.245, loss_critic=-13.4] 


Epoch [7/30]


100%|██████████| 19/19 [00:10<00:00,  1.88it/s, gp=0.353, loss_critic=-13.6]


Epoch [8/30]


100%|██████████| 19/19 [00:10<00:00,  1.85it/s, gp=0.25, loss_critic=-12.4] 


Epoch [9/30]


100%|██████████| 19/19 [00:10<00:00,  1.89it/s, gp=0.225, loss_critic=-11.8]


Epoch [10/30]


100%|██████████| 19/19 [00:09<00:00,  2.00it/s, gp=0.29, loss_critic=-12.5] 


Epoch [11/30]


100%|██████████| 19/19 [00:09<00:00,  1.99it/s, gp=0.28, loss_critic=-10.2] 


Epoch [12/30]


100%|██████████| 19/19 [00:09<00:00,  1.99it/s, gp=0.166, loss_critic=-10.8]


Epoch [13/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.167, loss_critic=-9.68]


Epoch [14/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.324, loss_critic=-13.8]


Epoch [15/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.354, loss_critic=-9.01]


Epoch [16/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.256, loss_critic=-12.3]


Epoch [17/30]


100%|██████████| 19/19 [00:09<00:00,  1.96it/s, gp=0.198, loss_critic=-12.6]


Epoch [18/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.208, loss_critic=-10.9]


Epoch [19/30]


100%|██████████| 19/19 [00:09<00:00,  1.96it/s, gp=0.204, loss_critic=-12]  


Epoch [20/30]


100%|██████████| 19/19 [00:09<00:00,  1.96it/s, gp=0.191, loss_critic=-10.3]


Epoch [21/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.171, loss_critic=-11.1]


Epoch [22/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.209, loss_critic=-10.2]


Epoch [23/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.171, loss_critic=-10.4]


Epoch [24/30]


100%|██████████| 19/19 [00:09<00:00,  1.96it/s, gp=0.196, loss_critic=-9.16]


Epoch [25/30]


100%|██████████| 19/19 [00:09<00:00,  1.96it/s, gp=0.162, loss_critic=-8.53]


Epoch [26/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.201, loss_critic=-10.3] 


Epoch [27/30]


100%|██████████| 19/19 [00:09<00:00,  1.97it/s, gp=0.128, loss_critic=-10]  


Epoch [28/30]


100%|██████████| 19/19 [00:09<00:00,  1.90it/s, gp=0.16, loss_critic=-7.42] 


Epoch [29/30]


100%|██████████| 19/19 [00:10<00:00,  1.83it/s, gp=0.147, loss_critic=-4.63] 


Epoch [30/30]


100%|██████████| 19/19 [00:10<00:00,  1.84it/s, gp=0.0949, loss_critic=-6.72]


Current image size: 128
Epoch [1/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=1.01, loss_critic=-10.5] 


Epoch [2/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.428, loss_critic=-15.2]


Epoch [3/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.341, loss_critic=-11]  


Epoch [4/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.118, loss_critic=-11.7] 


Epoch [5/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.225, loss_critic=-14.4]


Epoch [6/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.124, loss_critic=-11.1] 


Epoch [7/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.337, loss_critic=-19.3] 


Epoch [8/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.11, loss_critic=-12.3]  


Epoch [9/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.394, loss_critic=-14]   


Epoch [10/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.217, loss_critic=-11.9]


Epoch [11/30]


100%|██████████| 19/19 [00:18<00:00,  1.04it/s, gp=0.247, loss_critic=-13.4]


Epoch [12/30]


100%|██████████| 19/19 [00:18<00:00,  1.05it/s, gp=0.239, loss_critic=-11.9]


Epoch [13/30]


100%|██████████| 19/19 [00:17<00:00,  1.06it/s, gp=0.243, loss_critic=-13.9]


Epoch [14/30]


100%|██████████| 19/19 [00:17<00:00,  1.06it/s, gp=0.342, loss_critic=-12.2]


Epoch [15/30]


100%|██████████| 19/19 [00:18<00:00,  1.05it/s, gp=0.282, loss_critic=-11.2]


Epoch [16/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.297, loss_critic=-10.2]


Epoch [17/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.219, loss_critic=-10.5]


Epoch [18/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.182, loss_critic=-10.5]


Epoch [19/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.227, loss_critic=-12.5]


Epoch [20/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.223, loss_critic=-11.6]


Epoch [21/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.202, loss_critic=-10.4]


Epoch [22/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.273, loss_critic=-10.2]


Epoch [23/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.213, loss_critic=-10.1]


Epoch [24/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.232, loss_critic=-10.5]


Epoch [25/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.198, loss_critic=-11.3]


Epoch [26/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.203, loss_critic=-10.2]


Epoch [27/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.242, loss_critic=-11.7]


Epoch [28/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.217, loss_critic=-12.1]


Epoch [29/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.147, loss_critic=-5.33]


Epoch [30/30]


100%|██████████| 19/19 [00:17<00:00,  1.07it/s, gp=0.212, loss_critic=-7.53]


Current image size: 256
Epoch [1/30]


100%|██████████| 19/19 [02:01<00:00,  6.39s/it, gp=0.4, loss_critic=-24.3]    


Epoch [2/30]


100%|██████████| 19/19 [01:56<00:00,  6.11s/it, gp=0.617, loss_critic=-26.1]


Epoch [3/30]


100%|██████████| 19/19 [01:56<00:00,  6.13s/it, gp=0.496, loss_critic=-24.4]


Epoch [4/30]


100%|██████████| 19/19 [01:56<00:00,  6.15s/it, gp=0.616, loss_critic=-28.3]


Epoch [5/30]


100%|██████████| 19/19 [01:59<00:00,  6.28s/it, gp=0.692, loss_critic=-25.7]


Epoch [6/30]


100%|██████████| 19/19 [01:59<00:00,  6.27s/it, gp=0.763, loss_critic=-22]  


Epoch [7/30]


100%|██████████| 19/19 [01:58<00:00,  6.26s/it, gp=0.653, loss_critic=-33]  


Epoch [8/30]


100%|██████████| 19/19 [01:57<00:00,  6.17s/it, gp=0.369, loss_critic=-24.2]


Epoch [9/30]


100%|██████████| 19/19 [01:56<00:00,  6.12s/it, gp=1.24, loss_critic=-32.8] 


Epoch [10/30]


100%|██████████| 19/19 [01:58<00:00,  6.24s/it, gp=0.841, loss_critic=-29.1]


Epoch [11/30]


100%|██████████| 19/19 [01:57<00:00,  6.20s/it, gp=0.631, loss_critic=-31.5]


Epoch [12/30]


100%|██████████| 19/19 [01:56<00:00,  6.11s/it, gp=0.555, loss_critic=-26.4]


Epoch [13/30]


100%|██████████| 19/19 [01:56<00:00,  6.13s/it, gp=0.925, loss_critic=-28.1]


Epoch [14/30]


100%|██████████| 19/19 [01:56<00:00,  6.13s/it, gp=0.895, loss_critic=-47.1]


Epoch [15/30]


100%|██████████| 19/19 [01:57<00:00,  6.20s/it, gp=0.948, loss_critic=-33.8]


Epoch [16/30]


100%|██████████| 19/19 [01:56<00:00,  6.14s/it, gp=0.472, loss_critic=-18.4]


Epoch [17/30]


100%|██████████| 19/19 [01:57<00:00,  6.18s/it, gp=0.658, loss_critic=-23]  


Epoch [18/30]


100%|██████████| 19/19 [01:57<00:00,  6.20s/it, gp=0.698, loss_critic=-31.2]


Epoch [19/30]


100%|██████████| 19/19 [01:56<00:00,  6.12s/it, gp=0.471, loss_critic=-20.9]


Epoch [20/30]


100%|██████████| 19/19 [01:57<00:00,  6.19s/it, gp=0.907, loss_critic=-32.5]


Epoch [21/30]


100%|██████████| 19/19 [01:56<00:00,  6.15s/it, gp=0.815, loss_critic=-23.6]


Epoch [22/30]


100%|██████████| 19/19 [01:56<00:00,  6.14s/it, gp=0.654, loss_critic=-25.8]


Epoch [23/30]


100%|██████████| 19/19 [01:58<00:00,  6.23s/it, gp=0.508, loss_critic=-24.8]


Epoch [24/30]


100%|██████████| 19/19 [01:59<00:00,  6.28s/it, gp=0.684, loss_critic=-25.1]


Epoch [25/30]


100%|██████████| 19/19 [01:59<00:00,  6.27s/it, gp=0.724, loss_critic=-23.2]


Epoch [26/30]


100%|██████████| 19/19 [01:59<00:00,  6.27s/it, gp=0.614, loss_critic=-23.9]


Epoch [27/30]


100%|██████████| 19/19 [01:59<00:00,  6.28s/it, gp=0.582, loss_critic=-27.1]


Epoch [28/30]


100%|██████████| 19/19 [01:57<00:00,  6.18s/it, gp=0.628, loss_critic=-25.8]


Epoch [29/30]


100%|██████████| 19/19 [01:56<00:00,  6.12s/it, gp=0.675, loss_critic=-27.2]


Epoch [30/30]


100%|██████████| 19/19 [01:57<00:00,  6.20s/it, gp=0.677, loss_critic=-25.6]


Current image size: 512
Epoch [1/30]


  0%|          | 0/19 [00:15<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 12.00 GiB total capacity; 17.12 GiB already allocated; 0 bytes free; 17.49 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF