In [1]:
import torch
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import time
import torch.nn
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2

# --- Checkpointing ---
# When True, main() restores model and optimizer state from the CHECKPOINT_*
# files before the first epoch.
LOAD_MODEL = True
# When True, main() writes "gen.pth" and "disc.pth" after every epoch.
SAVE_MODEL = True
CHECKPOINT_GEN = "/kaggle/input/saved-models/gen.pth"
CHECKPOINT_DISC = "/kaggle/input/saved-models/disc.pth"
# --- Runtime ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# --- Optimisation ---
# Learning rate shared by both Adam optimizers; load_checkpoint() also forces
# this value onto restored optimizers.
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100
BATCH_SIZE = 16
# Weight of the gradient-penalty term in the critic loss (see train_fn).
LAMBDA_GP = 10
# Number of DataLoader worker processes used by main().
NUM_WORKERS = 4
# --- Image geometry ---
# Side length of the random high-resolution training crop.
HIGH_RES = 128
# Side length of the down-scaled generator input (4x smaller than HIGH_RES).
LOW_RES = HIGH_RES // 4
# NOTE(review): not referenced anywhere else in the visible notebook.
IMG_CHANNELS = 3

# Albumentations pipelines shared by the dataset and the inference helpers.
# With mean 0 and std 1, A.Normalize only divides by its max_pixel_value
# (255 by default), so uint8 images end up in the [0, 1] range.

# Applied to the high-resolution crop: rescale and convert to a CHW tensor.
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Applied to the same crop to build the generator input: shrink to
# LOW_RES x LOW_RES, rescale and convert to a CHW tensor.
lowres_transform = A.Compose(
    [
        # Albumentations hands `interpolation` to OpenCV, so it must be a cv2
        # flag. PIL's Image.BICUBIC has the integer value 3, which OpenCV
        # interprets as INTER_AREA; cv2.INTER_CUBIC is the real bicubic flag.
        # NOTE(review): checkpoints trained with the old flag saw area-averaged
        # inputs; expect a small input-distribution shift when resuming.
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Geometric augmentation applied once, before the low/high-res split, so both
# versions show exactly the same crop, flip and rotation.
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Inference-time preprocessing: no resizing or augmentation, only rescaling
# and tensor conversion.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)




In [2]:

class MyImageFolder(Dataset):
    """Paired low-/high-resolution crops read from a folder-per-class image tree.

    ``root_dir`` is expected to contain one sub-directory per class, each
    holding image files. Every item is a ``(low_res, high_res)`` tensor pair
    produced from the same augmented crop; the class index is only used to
    locate the file and is not returned.

    Args:
        root_dir: Directory whose immediate sub-directories contain the images.
    """

    def __init__(self, root_dir):
        super().__init__()
        self.root_dir = root_dir
        # Sort so the index -> file mapping does not depend on filesystem
        # listing order, and ignore stray files sitting next to the class
        # directories (os.listdir on them would raise NotADirectoryError).
        self.class_names = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )

        # Flat list of (file name, class index) pairs.
        self.data = []
        for index, name in enumerate(self.class_names):
            files = sorted(os.listdir(os.path.join(root_dir, name)))
            self.data.extend((file_name, index) for file_name in files)

    def __len__(self):
        """Return the total number of image files found under ``root_dir``."""
        return len(self.data)

    def __getitem__(self, index):
        """Load one image and return its ``(low_res, high_res)`` tensor pair.

        Raises:
            OSError: If OpenCV cannot read or decode the image file.
        """
        img_file, label = self.data[index]
        img_path = os.path.join(self.root_dir, self.class_names[label], img_file)

        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        image = cv2.imread(img_path)
        if image is None:
            raise OSError(f"cv2.imread could not read image file: {img_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Crop/flip/rotate once so both resolutions show the same content.
        both_transform = both_transforms(image=image)["image"]
        low_res = lowres_transform(image=both_transform)["image"]
        high_res = highres_transform(image=both_transform)["image"]
        return low_res, high_res


def test():
    """Smoke test: batch the DIV2K training folder and print every batch's shapes."""
    div2k = MyImageFolder(root_dir="/kaggle/input/div2k-dataset/DIV2K_train_HR")
    batches = DataLoader(div2k, batch_size=8)

    for lr_batch, hr_batch in batches:
        # Low-resolution shape first, then the matching high-resolution shape.
        for batch in (lr_batch, hr_batch):
            print(batch.shape)


#if __name__ == "__main__":
    #test()


In [3]:
from torchvision.models import vgg19
import torch.nn as nn


class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG-19 feature maps of two images.

    The feature extractor is the first 35 modules of torchvision's pretrained
    VGG-19 ``features`` stack, kept in eval mode on ``DEVICE`` with gradients
    disabled for all of its parameters.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True).features[:35]
        self.vgg = backbone.eval().to(DEVICE)
        # Freeze the extractor; it is only a fixed feature space for the loss.
        self.vgg.requires_grad_(False)
        self.loss = nn.MSELoss()

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``."""
        input_features, target_features = self.vgg(input), self.vgg(target)
        return self.loss(input_features, target_features)


In [4]:

class ConvBlock(nn.Module):
    """A biased ``Conv2d`` optionally followed by an in-place LeakyReLU(0.2).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        use_act: If true, apply LeakyReLU after the convolution; otherwise the
            convolution output is returned unchanged.
        **kwargs: Forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, use_act, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=True, **kwargs)
        if use_act:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Convolve ``x`` and apply the configured activation."""
        features = self.cnn(x)
        return self.act(features)


class UpsampleBlock(nn.Module):
    """Nearest-neighbour upsampling followed by a 3x3 conv and LeakyReLU(0.2).

    Args:
        in_c: Channel count; the convolution keeps it unchanged.
        scale_factor: Spatial upscaling factor passed to ``nn.Upsample``.
    """

    def __init__(self, in_c, scale_factor=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.conv = nn.Conv2d(
            in_c, in_c, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` enlarged by ``scale_factor`` and refined by the conv."""
        enlarged = self.upsample(x)
        refined = self.conv(enlarged)
        return self.act(refined)


class DenseResidualBlock(nn.Module):
    """Five densely connected 3x3 conv layers with a scaled residual connection.

    Layer ``i`` receives the block input concatenated with the outputs of all
    previous layers. The first four layers emit ``channels`` feature maps with
    LeakyReLU; the fifth maps back to ``in_channels`` without an activation.

    Args:
        in_channels: Channels of the block input (and output).
        channels: Growth width added by each of the first four layers.
        residual_beta: Scale applied to the block output before adding the input.
    """

    def __init__(self, in_channels, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta

        num_layers = 5
        layers = []
        for depth in range(num_layers):
            is_last = depth == num_layers - 1
            layers.append(
                ConvBlock(
                    in_channels + channels * depth,
                    in_channels if is_last else channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    use_act=not is_last,
                )
            )
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Run the dense stack and return ``x + residual_beta * last_output``."""
        stacked = x
        for layer in self.blocks:
            out = layer(stacked)
            # Grow the running input with this layer's output along channels.
            stacked = torch.cat((stacked, out), dim=1)
        return x + self.residual_beta * out


class RRDB(nn.Module):
    """Residual-in-residual block: three DenseResidualBlocks plus a scaled skip.

    Args:
        in_channels: Channel count of the input, preserved by the block.
        residual_beta: Scale applied to the stacked output before adding ``x``.
    """

    def __init__(self, in_channels, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        dense_blocks = [DenseResidualBlock(in_channels) for _ in range(3)]
        self.rrdb = nn.Sequential(*dense_blocks)

    def forward(self, x):
        """Return ``rrdb(x) * residual_beta + x``."""
        refined = self.rrdb(x)
        return refined * self.residual_beta + x


class Generator(nn.Module):
    """RRDB-based super-resolution generator with two 2x upsampling stages.

    Pipeline: 3x3 stem conv -> ``num_blocks`` RRDBs -> 3x3 conv with a long skip
    from the stem -> two ``UpsampleBlock``s (default scale 2 each, i.e. 4x
    overall) -> two-conv head mapping back to ``in_channels``.

    Args:
        in_channels: Channels of the input and output images.
        num_channels: Width of the internal feature maps.
        num_blocks: Number of RRDB blocks in the trunk.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=23):
        super().__init__()
        # Every convolution in this module is a shape-preserving 3x3.
        conv3x3 = dict(kernel_size=3, stride=1, padding=1)

        self.initial = nn.Conv2d(in_channels, num_channels, bias=True, **conv3x3)
        trunk = [RRDB(num_channels) for _ in range(num_blocks)]
        self.residuals = nn.Sequential(*trunk)
        self.conv = nn.Conv2d(num_channels, num_channels, **conv3x3)
        self.upsamples = nn.Sequential(
            UpsampleBlock(num_channels),
            UpsampleBlock(num_channels),
        )
        self.final = nn.Sequential(
            nn.Conv2d(num_channels, num_channels, bias=True, **conv3x3),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_channels, in_channels, bias=True, **conv3x3),
        )

    def forward(self, x):
        """Upscale a batch of low-resolution images."""
        shallow = self.initial(x)
        deep = self.conv(self.residuals(shallow))
        # Long skip connection around the whole RRDB trunk.
        upscaled = self.upsamples(deep + shallow)
        return self.final(upscaled)


class Discriminator(nn.Module):
    """Convolutional critic that maps an image batch to one raw score per image.

    The feature extractor is a chain of 3x3 ``ConvBlock``s with LeakyReLU; the
    blocks at odd positions use stride 2 and therefore halve the resolution.
    The head pools to 6x6, flattens, and applies two linear layers. No final
    sigmoid is applied.

    Args:
        in_channels: Channels of the input images.
        features: Output width of each convolutional block, in order.
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    # stride 1 for even idx, stride 2 for odd idx
                    stride=1 + idx % 2,
                    padding=1,
                    use_act=True,
                ),
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # `in_channels` now holds the width of the last conv block, so the
            # head matches any `features` instead of assuming a final 512.
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of critic scores for ``x``."""
        x = self.blocks(x)
        return self.classifier(x)

def initialize_weights(model, scale=0.1):
    """Re-initialise all Conv2d and Linear weights with scaled Kaiming-normal values.

    Args:
        model: Module whose sub-modules are visited via ``model.modules()``.
        scale: Factor multiplied into each freshly initialised weight tensor.

    Biases and every other module type are left untouched.
    """
    scaled_types = (nn.Conv2d, nn.Linear)
    for module in model.modules():
        if not isinstance(module, scaled_types):
            continue
        nn.init.kaiming_normal_(module.weight.data)
        module.weight.data *= scale


def test():
    """Smoke test: push random 24x24 inputs through both networks and print shapes."""
    low_res = 24
    generator, critic = Generator(), Discriminator()

    sample = torch.randn((5, 3, low_res, low_res))
    upscaled = generator(sample)
    scores = critic(upscaled)

    # Generator output shape first, then the critic's score shape.
    for tensor in (upscaled, scores):
        print(tensor.shape)

#if __name__ == "__main__":
#   test()


In [5]:
from torchvision.utils import save_image


def gradient_penalty(critic, real, fake, device):
    """Return the mean squared deviation of the critic's gradient norm from 1.

    The gradient is taken with respect to random per-sample interpolations
    between ``real`` and a detached copy of ``fake``.

    Args:
        critic: Callable mapping an image batch to critic scores.
        real: Batch of real images, indexed as (batch, channels, height, width).
        fake: Batch of generated images with the same shape as ``real``.
        device: Device the random mixing coefficients are moved to.

    Returns:
        Scalar tensor; the graph is kept so the penalty can be back-propagated.
    """
    batch, channels, height, width = real.shape
    # One mixing coefficient per sample, copied across that sample's pixels.
    alpha = torch.rand((batch, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    mixed = real * alpha + fake.detach() * (1 - alpha)
    mixed.requires_grad_(True)

    scores = critic(mixed)

    # d(scores)/d(mixed); create_graph keeps the result differentiable.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialise the model and optimizer state dicts to ``filename``.

    The saved object is a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``, which is the layout ``load_checkpoint`` reads.
    """
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a file written by ``save_checkpoint``.

    Args:
        checkpoint_file: Path of the checkpoint to read.
        model: Module that receives ``checkpoint["state_dict"]``.
        optimizer: Optimizer that receives ``checkpoint["optimizer"]``.
        lr: Learning rate forced onto every parameter group after loading.
    """
    print("=> Loading checkpoint")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    restored = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(restored["state_dict"])
    optimizer.load_state_dict(restored["optimizer"])

    # The restored optimizer carries the learning rate it was saved with;
    # overwrite it so the currently configured value takes effect.
    for group in optimizer.param_groups:
        group["lr"] = lr


def plot_examples(low_res_folder, gen, output_folder="/kaggle/output/"):
    """Super-resolve every image in ``low_res_folder`` and save the results.

    Each input file is read from ``low_res_folder`` (previously a hard-coded
    "test_images/" path was used regardless of the argument) and its upscaled
    version is written to ``output_folder`` under the same file name
    (previously every result was written to the directory path itself, which
    is not a valid image file target).

    Args:
        low_res_folder: Directory containing the low-resolution input images.
        gen: Generator module; switched to eval mode while processing and put
            back into train mode afterwards, even if an image fails.
        output_folder: Directory that receives the upscaled images; created if
            missing. NOTE(review): the default keeps the original destination;
            confirm it is writable in the target environment.
    """
    os.makedirs(output_folder, exist_ok=True)

    gen.eval()
    try:
        for file in sorted(os.listdir(low_res_folder)):
            # Force three channels so RGBA, palette or greyscale files match
            # the 3-channel normalisation and the generator's input layer.
            with Image.open(os.path.join(low_res_folder, file)) as handle:
                image = np.asarray(handle.convert("RGB"))

            with torch.no_grad():
                upscaled_img = gen(
                    test_transform(image=image)["image"]
                    .unsqueeze(0)
                    .to(DEVICE)
                )
            save_image(upscaled_img, os.path.join(output_folder, file))
    finally:
        # Callers invoke this in the middle of training; never leave the
        # generator stuck in eval mode.
        gen.train()


In [6]:

from torch import optim
from torch.utils.tensorboard import SummaryWriter

# Enable cuDNN's convolution autotuner. Training crops are always
# HIGH_RES x HIGH_RES (see both_transforms), so training input shapes stay fixed.
torch.backends.cudnn.benchmark = True

def train_fn(
    loader,
    disc,
    gen,
    opt_gen,
    opt_disc,
    l1,
    vgg_loss,
    g_scaler,
    d_scaler,
    writer,
    tb_step,
):
    """Run one pass over ``loader``, alternating a critic and a generator step.

    Args:
        loader: Iterable yielding ``(low_res, high_res)`` batches.
        disc: Critic network.
        gen: Generator network.
        opt_gen: Optimizer for ``gen``.
        opt_disc: Optimizer for ``disc``.
        l1: Pixel-wise loss callable applied to ``(fake, high_res)``.
        vgg_loss: Perceptual loss callable applied to ``(fake, high_res)``.
        g_scaler: AMP gradient scaler used for the generator step.
        d_scaler: AMP gradient scaler used for the critic step.
        writer: TensorBoard writer that receives the critic loss.
        tb_step: Global step to log the first batch at.

    Returns:
        ``tb_step`` advanced by the number of batches processed.
    """
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(DEVICE)
        low_res = low_res.to(DEVICE)

        # --- Critic step: Wasserstein loss with gradient penalty ---
        with torch.cuda.amp.autocast():
            fake = gen(low_res)
            critic_real = disc(high_res)
            # Detached so the critic loss does not back-propagate into the generator.
            critic_fake = disc(fake.detach())
            gp = gradient_penalty(disc, high_res, fake, device=DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
            )

        opt_disc.zero_grad()
        d_scaler.scale(loss_critic).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # --- Generator step: weighted L1 + VGG perceptual + adversarial terms ---
        # The adversarial term is -mean(D(G(x))): the generator is rewarded for
        # raising the critic's score on its outputs. `fake` is reused from the
        # critic step, this time without detaching.
        with torch.cuda.amp.autocast():
            l1_loss = 1e-2 * l1(fake, high_res)
            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
            loss_for_vgg = vgg_loss(fake, high_res)
            gen_loss = l1_loss + loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        writer.add_scalar("Critic loss", loss_critic.item(), global_step=tb_step)
        tb_step += 1

        # Periodic visual check; only fires at batch indices 100, 200, ...,
        # so it never runs when an epoch has 100 batches or fewer.
        if idx % 100 == 0 and idx > 0:
            plot_examples("test_images/", gen)

        loop.set_postfix(
            gp=gp.item(),
            critic=loss_critic.item(),
            l1=l1_loss.item(),
            vgg=loss_for_vgg.item(),
            adversarial=adversarial_loss.item(),
        )

    return tb_step


def main():
    """Build the data pipeline, models and optimizers, then train for NUM_EPOCHS.

    When ``LOAD_MODEL`` is true, generator and critic (with their optimizers)
    are restored from ``CHECKPOINT_GEN`` / ``CHECKPOINT_DISC`` before training.
    When ``SAVE_MODEL`` is true, both are saved to "gen.pth" / "disc.pth" in the
    working directory after every epoch.
    """
    dataset = MyImageFolder(root_dir="/kaggle/input/div2k-dataset/DIV2K_train_HR")
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )

    gen = Generator(in_channels=3).to(DEVICE)
    disc = Discriminator(in_channels=3).to(DEVICE)
    # Only the generator gets the scaled Kaiming initialisation; a loaded
    # checkpoint overwrites these weights below.
    initialize_weights(gen)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
    l1 = nn.L1Loss()
    gen.train()
    disc.train()
    vgg_loss = VGGLoss()

    # Separate AMP scalers because the two networks are stepped independently.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN,
            gen,
            opt_gen,
            LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC,
            disc,
            opt_disc,
            LEARNING_RATE,
        )

    writer = SummaryWriter("logs")
    tb_step = 0
    try:
        for epoch in range(NUM_EPOCHS):
            print("Running :",epoch+1,"/",NUM_EPOCHS)
            tb_step = train_fn(
                loader,
                disc,
                gen,
                opt_gen,
                opt_disc,
                l1,
                vgg_loss,
                g_scaler,
                d_scaler,
                writer,
                tb_step,
            )

            if SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename="gen.pth")
                save_checkpoint(disc, opt_disc, filename="disc.pth")
    finally:
        # Flush and release the event file even when training is interrupted,
        # so already-logged scalars are not lost.
        writer.close()


if __name__ == "__main__":
    # Flip to True to skip training and only run the generator on test images.
    try_model = False

    if try_model:
        # Inference only: restore the generator from CHECKPOINT_GEN and run
        # plot_examples on test_images/. The optimizer exists solely because
        # load_checkpoint requires one.
        gen = Generator(in_channels=3).to(DEVICE)
        opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
        load_checkpoint(
            CHECKPOINT_GEN,
            gen,
            opt_gen,
            LEARNING_RATE,
        )
        plot_examples("test_images/", gen)
    else:
        # Run training. main() restores both checkpoints first when LOAD_MODEL
        # is True, so this is not necessarily a from-scratch run.
        main()


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 240MB/s]


=> Loading checkpoint
=> Loading checkpoint
Running : 1 / 100


100%|██████████| 50/50 [00:52<00:00,  1.06s/it, adversarial=-.0709, critic=-12.1, gp=0.152, l1=0.000527, vgg=1.55]


=> Saving checkpoint
=> Saving checkpoint
Running : 2 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0454, critic=-7, gp=0.0945, l1=0.000385, vgg=0.956]


=> Saving checkpoint
=> Saving checkpoint
Running : 3 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0573, critic=-9.57, gp=0.202, l1=0.000556, vgg=1.18]


=> Saving checkpoint
=> Saving checkpoint
Running : 4 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.054, critic=-9.13, gp=0.225, l1=0.000682, vgg=0.838]


=> Saving checkpoint
=> Saving checkpoint
Running : 5 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0416, critic=-10, gp=0.177, l1=0.000519, vgg=1.33]


=> Saving checkpoint
=> Saving checkpoint
Running : 6 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0662, critic=-10.4, gp=0.089, l1=0.000564, vgg=1.63]


=> Saving checkpoint
=> Saving checkpoint
Running : 7 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.062, critic=-11.4, gp=0.222, l1=0.00054, vgg=1.06]


=> Saving checkpoint
=> Saving checkpoint
Running : 8 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0609, critic=-9.6, gp=0.0681, l1=0.000521, vgg=1.42]


=> Saving checkpoint
=> Saving checkpoint
Running : 9 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0498, critic=-10.1, gp=0.0851, l1=0.000542, vgg=1.87]


=> Saving checkpoint
=> Saving checkpoint
Running : 10 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0388, critic=-7.9, gp=0.121, l1=0.0004, vgg=0.988]


=> Saving checkpoint
=> Saving checkpoint
Running : 11 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0293, critic=-8.31, gp=0.3, l1=0.000551, vgg=1.43]


=> Saving checkpoint
=> Saving checkpoint
Running : 12 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0434, critic=-8.57, gp=0.125, l1=0.000465, vgg=1]


=> Saving checkpoint
=> Saving checkpoint
Running : 13 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0523, critic=-7.35, gp=0.0638, l1=0.000466, vgg=0.987]


=> Saving checkpoint
=> Saving checkpoint
Running : 14 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0559, critic=-9.01, gp=0.179, l1=0.000457, vgg=1.35]


=> Saving checkpoint
=> Saving checkpoint
Running : 15 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0351, critic=-6.1, gp=0.208, l1=0.000334, vgg=0.807]


=> Saving checkpoint
=> Saving checkpoint
Running : 16 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0658, critic=-13.1, gp=0.201, l1=0.000651, vgg=1.68]


=> Saving checkpoint
=> Saving checkpoint
Running : 17 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0676, critic=-10.6, gp=0.227, l1=0.000494, vgg=1.19]


=> Saving checkpoint
=> Saving checkpoint
Running : 18 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0288, critic=-9.3, gp=0.208, l1=0.000534, vgg=1.16]


=> Saving checkpoint
=> Saving checkpoint
Running : 19 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.049, critic=-6.69, gp=0.115, l1=0.000586, vgg=0.897]


=> Saving checkpoint
=> Saving checkpoint
Running : 20 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.028, critic=-7.56, gp=0.101, l1=0.000442, vgg=1.04]


=> Saving checkpoint
=> Saving checkpoint
Running : 21 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0349, critic=-8.78, gp=0.213, l1=0.000451, vgg=1.09]


=> Saving checkpoint
=> Saving checkpoint
Running : 22 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0225, critic=-11.9, gp=0.262, l1=0.000559, vgg=1.58]


=> Saving checkpoint
=> Saving checkpoint
Running : 23 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0555, critic=-8.26, gp=0.124, l1=0.000487, vgg=0.769]


=> Saving checkpoint
=> Saving checkpoint
Running : 24 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0509, critic=-9.66, gp=0.12, l1=0.000542, vgg=1.74]


=> Saving checkpoint
=> Saving checkpoint
Running : 25 / 100


100%|██████████| 50/50 [00:36<00:00,  1.37it/s, adversarial=-.0505, critic=-9.43, gp=0.209, l1=0.000464, vgg=1.12]


=> Saving checkpoint
=> Saving checkpoint
Running : 26 / 100


100%|██████████| 50/50 [00:36<00:00,  1.38it/s, adversarial=-.0741, critic=-8.6, gp=0.129, l1=0.000445, vgg=1.18]


=> Saving checkpoint
=> Saving checkpoint
Running : 27 / 100


100%|██████████| 50/50 [00:36<00:00,  1.37it/s, adversarial=-.055, critic=-9.93, gp=0.167, l1=0.000477, vgg=1.03]


=> Saving checkpoint
=> Saving checkpoint
Running : 28 / 100


100%|██████████| 50/50 [00:36<00:00,  1.38it/s, adversarial=-.0551, critic=-7.58, gp=0.0487, l1=0.000488, vgg=1.17]


=> Saving checkpoint
=> Saving checkpoint
Running : 29 / 100


100%|██████████| 50/50 [00:36<00:00,  1.37it/s, adversarial=-.00861, critic=-11.2, gp=0.494, l1=0.000618, vgg=1.26]


=> Saving checkpoint
=> Saving checkpoint
Running : 30 / 100


100%|██████████| 50/50 [00:36<00:00,  1.37it/s, adversarial=-.0494, critic=-8.67, gp=0.181, l1=0.000433, vgg=1.06]


=> Saving checkpoint
=> Saving checkpoint
Running : 31 / 100


100%|██████████| 50/50 [00:36<00:00,  1.37it/s, adversarial=-.025, critic=-6.84, gp=0.164, l1=0.000399, vgg=0.895]


=> Saving checkpoint
=> Saving checkpoint
Running : 32 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0494, critic=-7.51, gp=0.17, l1=0.000429, vgg=1.41]


=> Saving checkpoint
=> Saving checkpoint
Running : 33 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0259, critic=-9.64, gp=0.43, l1=0.000507, vgg=1.11]


=> Saving checkpoint
=> Saving checkpoint
Running : 34 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0594, critic=-10.2, gp=0.0672, l1=0.000552, vgg=1.31]


=> Saving checkpoint
=> Saving checkpoint
Running : 35 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0528, critic=-7.6, gp=0.162, l1=0.00043, vgg=1.02]


=> Saving checkpoint
=> Saving checkpoint
Running : 36 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.027, critic=-8.29, gp=0.182, l1=0.000486, vgg=1.05]


=> Saving checkpoint
=> Saving checkpoint
Running : 37 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.032, critic=-9.91, gp=0.372, l1=0.000489, vgg=2.07]


=> Saving checkpoint
=> Saving checkpoint
Running : 38 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0773, critic=-9.44, gp=0.168, l1=0.00051, vgg=1.45]


=> Saving checkpoint
=> Saving checkpoint
Running : 39 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0525, critic=-10.2, gp=0.161, l1=0.000511, vgg=1.1]


=> Saving checkpoint
=> Saving checkpoint
Running : 40 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0121, critic=-6.21, gp=0.611, l1=0.000502, vgg=1.16]


=> Saving checkpoint
=> Saving checkpoint
Running : 41 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0502, critic=-8.16, gp=0.181, l1=0.000371, vgg=0.745]


=> Saving checkpoint
=> Saving checkpoint
Running : 42 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0352, critic=-8.72, gp=0.152, l1=0.000474, vgg=0.89]


=> Saving checkpoint
=> Saving checkpoint
Running : 43 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0657, critic=-7.26, gp=0.0925, l1=0.000374, vgg=1.18]


=> Saving checkpoint
=> Saving checkpoint
Running : 44 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0143, critic=-7.19, gp=0.35, l1=0.000487, vgg=1.08]


=> Saving checkpoint
=> Saving checkpoint
Running : 45 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0367, critic=-5.31, gp=0.469, l1=0.000374, vgg=0.968]


=> Saving checkpoint
=> Saving checkpoint
Running : 46 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0447, critic=-8.88, gp=0.436, l1=0.000423, vgg=0.853]


=> Saving checkpoint
=> Saving checkpoint
Running : 47 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0454, critic=-7.32, gp=0.12, l1=0.000436, vgg=1.13]


=> Saving checkpoint
=> Saving checkpoint
Running : 48 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0402, critic=-6.81, gp=0.196, l1=0.000468, vgg=1.1]


=> Saving checkpoint
=> Saving checkpoint
Running : 49 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0341, critic=-8.77, gp=0.263, l1=0.000436, vgg=0.903]


=> Saving checkpoint
=> Saving checkpoint
Running : 50 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0249, critic=-7.5, gp=0.45, l1=0.000472, vgg=0.979]


=> Saving checkpoint
=> Saving checkpoint
Running : 51 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0349, critic=-6.81, gp=0.168, l1=0.00043, vgg=0.803]


=> Saving checkpoint
=> Saving checkpoint
Running : 52 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0137, critic=-5.74, gp=0.102, l1=0.000438, vgg=1.04]


=> Saving checkpoint
=> Saving checkpoint
Running : 53 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0399, critic=-7.98, gp=0.153, l1=0.000418, vgg=0.867]


=> Saving checkpoint
=> Saving checkpoint
Running : 54 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0463, critic=-9.23, gp=0.102, l1=0.000485, vgg=1.14]


=> Saving checkpoint
=> Saving checkpoint
Running : 55 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0565, critic=-10.1, gp=0.26, l1=0.000577, vgg=1.72]


=> Saving checkpoint
=> Saving checkpoint
Running : 56 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0441, critic=-10.5, gp=0.275, l1=0.000492, vgg=1.17]


=> Saving checkpoint
=> Saving checkpoint
Running : 57 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0362, critic=-7.72, gp=0.0775, l1=0.000409, vgg=0.818]


=> Saving checkpoint
=> Saving checkpoint
Running : 58 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0457, critic=-5.46, gp=0.125, l1=0.000405, vgg=0.73]


=> Saving checkpoint
=> Saving checkpoint
Running : 59 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0471, critic=-6.35, gp=0.0678, l1=0.00038, vgg=0.788]


=> Saving checkpoint
=> Saving checkpoint
Running : 60 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0288, critic=-6.47, gp=0.131, l1=0.000357, vgg=0.922]


=> Saving checkpoint
=> Saving checkpoint
Running : 61 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0255, critic=-8.67, gp=0.0733, l1=0.000485, vgg=0.847]


=> Saving checkpoint
=> Saving checkpoint
Running : 62 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0249, critic=-8.42, gp=0.146, l1=0.000444, vgg=1.14]


=> Saving checkpoint
=> Saving checkpoint
Running : 63 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0613, critic=-6.07, gp=0.0239, l1=0.00041, vgg=0.589]


=> Saving checkpoint
=> Saving checkpoint
Running : 64 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0416, critic=-8.64, gp=0.263, l1=0.00045, vgg=1.33]


=> Saving checkpoint
=> Saving checkpoint
Running : 65 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0162, critic=-8.77, gp=0.246, l1=0.000533, vgg=1.1]


=> Saving checkpoint
=> Saving checkpoint
Running : 66 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.00695, critic=-10, gp=0.216, l1=0.000533, vgg=1.89]


=> Saving checkpoint
=> Saving checkpoint
Running : 67 / 100


100%|██████████| 50/50 [00:36<00:00,  1.35it/s, adversarial=-.0223, critic=-9.47, gp=0.315, l1=0.000404, vgg=1.22]


=> Saving checkpoint
=> Saving checkpoint
Running : 68 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0558, critic=-8.89, gp=0.0979, l1=0.000441, vgg=0.919]


=> Saving checkpoint
=> Saving checkpoint
Running : 69 / 100


100%|██████████| 50/50 [00:36<00:00,  1.36it/s, adversarial=-.0567, critic=-10.4, gp=0.174, l1=0.000525, vgg=1.19]


=> Saving checkpoint
=> Saving checkpoint
Running : 70 / 100


100%|██████████| 50/50 [00:38<00:00,  1.31it/s, adversarial=-.0551, critic=-8.9, gp=0.176, l1=0.000441, vgg=1.6]


=> Saving checkpoint
=> Saving checkpoint
Running : 71 / 100


100%|██████████| 50/50 [00:37<00:00,  1.32it/s, adversarial=-.0652, critic=-8.09, gp=0.0882, l1=0.000413, vgg=1.29]


=> Saving checkpoint
=> Saving checkpoint
Running : 72 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0491, critic=-12.3, gp=0.3, l1=0.000653, vgg=0.979]


=> Saving checkpoint
=> Saving checkpoint
Running : 73 / 100


100%|██████████| 50/50 [00:37<00:00,  1.32it/s, adversarial=-.0478, critic=-7.96, gp=0.11, l1=0.000449, vgg=0.841]


=> Saving checkpoint
=> Saving checkpoint
Running : 74 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0381, critic=-7.71, gp=0.225, l1=0.000461, vgg=0.896]


=> Saving checkpoint
=> Saving checkpoint
Running : 75 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0561, critic=-10.8, gp=0.298, l1=0.000504, vgg=1.06]


=> Saving checkpoint
=> Saving checkpoint
Running : 76 / 100


100%|██████████| 50/50 [00:37<00:00,  1.32it/s, adversarial=-.0271, critic=-6.72, gp=0.0558, l1=0.000339, vgg=0.892]


=> Saving checkpoint
=> Saving checkpoint
Running : 77 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0245, critic=-6.87, gp=0.275, l1=0.00044, vgg=1.24]


=> Saving checkpoint
=> Saving checkpoint
Running : 78 / 100


100%|██████████| 50/50 [00:37<00:00,  1.32it/s, adversarial=-.0341, critic=-8.65, gp=0.271, l1=0.000451, vgg=1.12]


=> Saving checkpoint
=> Saving checkpoint
Running : 79 / 100


100%|██████████| 50/50 [00:37<00:00,  1.32it/s, adversarial=-.0194, critic=-7.46, gp=0.126, l1=0.000464, vgg=0.892]


=> Saving checkpoint
=> Saving checkpoint
Running : 80 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0199, critic=-6.55, gp=0.172, l1=0.000405, vgg=0.791]


=> Saving checkpoint
=> Saving checkpoint
Running : 81 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.0518, critic=-8.67, gp=0.137, l1=0.00044, vgg=1.36]


=> Saving checkpoint
=> Saving checkpoint
Running : 82 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0402, critic=-5.83, gp=0.0628, l1=0.000379, vgg=0.543]


=> Saving checkpoint
=> Saving checkpoint
Running : 83 / 100


100%|██████████| 50/50 [00:37<00:00,  1.35it/s, adversarial=-.0598, critic=-10.9, gp=0.0679, l1=0.000606, vgg=1.22]


=> Saving checkpoint
=> Saving checkpoint
Running : 84 / 100


100%|██████████| 50/50 [00:37<00:00,  1.32it/s, adversarial=-.053, critic=-9.33, gp=0.25, l1=0.000477, vgg=1.21]


=> Saving checkpoint
=> Saving checkpoint
Running : 85 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0444, critic=-7.52, gp=0.163, l1=0.000422, vgg=0.771]


=> Saving checkpoint
=> Saving checkpoint
Running : 86 / 100


100%|██████████| 50/50 [00:37<00:00,  1.34it/s, adversarial=-.037, critic=-8.14, gp=0.125, l1=0.000463, vgg=1.31]


=> Saving checkpoint
=> Saving checkpoint
Running : 87 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0685, critic=-6.85, gp=0.102, l1=0.000413, vgg=1.4]


=> Saving checkpoint
=> Saving checkpoint
Running : 88 / 100


100%|██████████| 50/50 [00:38<00:00,  1.31it/s, adversarial=-.0332, critic=-6.03, gp=0.067, l1=0.00037, vgg=0.604]


=> Saving checkpoint
=> Saving checkpoint
Running : 89 / 100


100%|██████████| 50/50 [00:37<00:00,  1.32it/s, adversarial=-.0504, critic=-6.5, gp=0.105, l1=0.000354, vgg=1.04]


=> Saving checkpoint
=> Saving checkpoint
Running : 90 / 100


100%|██████████| 50/50 [00:38<00:00,  1.30it/s, adversarial=-.0558, critic=-10.4, gp=0.201, l1=0.000511, vgg=0.911]


=> Saving checkpoint
=> Saving checkpoint
Running : 91 / 100


100%|██████████| 50/50 [00:39<00:00,  1.27it/s, adversarial=-.0344, critic=-12.4, gp=0.271, l1=0.00059, vgg=1.45]


=> Saving checkpoint
=> Saving checkpoint
Running : 92 / 100


100%|██████████| 50/50 [00:38<00:00,  1.31it/s, adversarial=-.0419, critic=-6.14, gp=0.149, l1=0.00041, vgg=1.23]


=> Saving checkpoint
=> Saving checkpoint
Running : 93 / 100


100%|██████████| 50/50 [00:39<00:00,  1.27it/s, adversarial=-.0722, critic=-9.9, gp=0.0875, l1=0.000551, vgg=1.32]


=> Saving checkpoint
=> Saving checkpoint
Running : 94 / 100


100%|██████████| 50/50 [00:39<00:00,  1.27it/s, adversarial=-.0591, critic=-9.87, gp=0.161, l1=0.000538, vgg=1.27]


=> Saving checkpoint
=> Saving checkpoint
Running : 95 / 100


100%|██████████| 50/50 [00:39<00:00,  1.28it/s, adversarial=-.046, critic=-6.93, gp=0.173, l1=0.000399, vgg=1.05]


=> Saving checkpoint
=> Saving checkpoint
Running : 96 / 100


100%|██████████| 50/50 [00:39<00:00,  1.28it/s, adversarial=-.0562, critic=-8.63, gp=0.0818, l1=0.00044, vgg=1.08]


=> Saving checkpoint
=> Saving checkpoint
Running : 97 / 100


100%|██████████| 50/50 [00:38<00:00,  1.29it/s, adversarial=-.0375, critic=-6.93, gp=0.0589, l1=0.000361, vgg=0.935]


=> Saving checkpoint
=> Saving checkpoint
Running : 98 / 100


100%|██████████| 50/50 [00:38<00:00,  1.29it/s, adversarial=-.0403, critic=-8.24, gp=0.164, l1=0.000505, vgg=1.11]


=> Saving checkpoint
=> Saving checkpoint
Running : 99 / 100


100%|██████████| 50/50 [00:38<00:00,  1.30it/s, adversarial=-.0306, critic=-9.54, gp=0.221, l1=0.00048, vgg=1.07]


=> Saving checkpoint
=> Saving checkpoint
Running : 100 / 100


100%|██████████| 50/50 [00:37<00:00,  1.33it/s, adversarial=-.0361, critic=-10.7, gp=0.111, l1=0.000529, vgg=1.27]


=> Saving checkpoint
=> Saving checkpoint


In [7]:
# Super-resolve a single image with the generator restored from CHECKPOINT_GEN.
# NOTE(review): CHECKPOINT_GEN points into /kaggle/input/saved-models/, not at
# the "gen.pth" that main() writes after each epoch; confirm which weights are
# meant to be evaluated here.
gen = Generator(in_channels=3).to(DEVICE)
# The optimizer is created only because load_checkpoint requires one.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
load_checkpoint(
    CHECKPOINT_GEN,
    gen,
    opt_gen,
    LEARNING_RATE,
)
gen.eval()

# Close the file handle promptly and force three channels so PNGs with an
# alpha channel, a palette or a single grey channel still fit the generator.
with Image.open("/kaggle/input/test-images/woman_LR.png") as image_file:
    image = image_file.convert("RGB")

with torch.no_grad():
    upscaled_img = gen(
        test_transform(image=np.asarray(image))["image"]
        .unsqueeze(0)
        .to(DEVICE)
    )
save_image(upscaled_img, "/kaggle/working/woman.png")

=> Loading checkpoint
