In [7]:
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.models import vgg19

In [8]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# throughout this notebook live under that mount point (Colab-only dependency).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
# --- Training configuration --------------------------------------------------
LOAD_MODEL = False  # when True, main() restores both networks from the checkpoints below
SAVE_MODEL = True  # when True, main() overwrites both checkpoints after every epoch
CHECKPOINT_GEN = "/content/drive/MyDrive/Models/last_trained/srgan_gen.pth.tar"
CHECKPOINT_DISC = "/content/drive/MyDrive/Models/last_trained/srgan_disc.pth.tar"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU whenever torch can see one
LEARNING_RATE = 1e-4  # shared by the generator and discriminator Adam optimizers
NUM_EPOCHS = 1000
BATCH_SIZE = 8
NUM_WORKERS = 4  # DataLoader worker processes used by main()
HIGH_RES = 96  # side length (pixels) of the square high-resolution training crop
LOW_RES = HIGH_RES // 4  # 24; the generator stacks two 2x upsample blocks, i.e. 4x overall
IMG_CHANNELS = 3  # not referenced anywhere else in the visible notebook

In [10]:
# Geometric augmentations applied once to the loaded image, so the high-res and
# low-res views produced afterwards come from the same HIGH_RES x HIGH_RES crop.
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# High-res target: normalised with mean 0.5 / std 0.5 and converted to a tensor.
# Assuming uint8 input and albumentations' default max_pixel_value, this maps
# pixels to [-1, 1], the same range as the generator's tanh output.
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)

# Low-res generator input: the crop is shrunk to LOW_RES x LOW_RES and normalised
# with mean 0 / std 1 (presumably yielding values in [0, 1] for uint8 input).
# NOTE(review): Image.BICUBIC is a PIL constant (integer value 3), but
# albumentations documents `interpolation` as an OpenCV flag, where 3 appears to
# be cv2.INTER_AREA and bicubic is cv2.INTER_CUBIC (2). Confirm which resampling
# is actually intended before relying on "bicubic" downsampling.
lowres_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=Image.BICUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Inference-time preprocessing: same normalisation as lowres_transform but no
# resizing, so test images are fed to the generator at their native size.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

In [11]:
class MyImageFolder(Dataset):
    """Dataset of ``(low_res, high_res)`` tensor pairs read from class sub-folders.

    ``root_dir`` must contain one sub-directory per class. Every regular file
    inside those sub-directories is treated as an image. Entries of
    ``root_dir`` that are not directories, and nested directories inside a
    class folder, are ignored instead of crashing the scan.

    Attributes:
        root_dir: the directory passed to the constructor.
        class_names: sorted list of class sub-directory names.
        data: list of ``(file_name, class_index)`` tuples, ordered by class
            and then by file name so indexing is reproducible across runs.
    """

    def __init__(self, root_dir):
        super().__init__()
        self.data = []
        self.root_dir = root_dir
        # os.listdir order is arbitrary; sort for a stable index -> file mapping
        # and drop stray files (e.g. .DS_Store) that cannot be listed as folders.
        self.class_names = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )

        for index, name in enumerate(self.class_names):
            class_dir = os.path.join(root_dir, name)
            files = sorted(
                file_name
                for file_name in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, file_name))
            )
            self.data += [(file_name, index) for file_name in files]

    def __len__(self):
        """Return the number of image files discovered under ``root_dir``."""
        return len(self.data)

    def __getitem__(self, index):
        """Load image ``index`` and return its ``(low_res, high_res)`` views.

        Both views are derived from the same random crop/flip/rotation produced
        by ``both_transforms``.
        """
        img_file, label = self.data[index]
        path = os.path.join(self.root_dir, self.class_names[label], img_file)

        # Close the file handle deterministically and force three channels so
        # grayscale / RGBA / palette images match the 3-channel Normalize stats.
        with Image.open(path) as img:
            image = np.array(img.convert("RGB"))

        image = both_transforms(image=image)["image"]
        high_res = highres_transform(image=image)["image"]
        low_res = lowres_transform(image=image)["image"]
        return low_res, high_res

In [12]:
def test():
    """Smoke-test the dataset: fetch one batch and print both tensor shapes."""
    train_set = MyImageFolder(root_dir="/content/drive/MyDrive/Datasets/srgan/train_images")
    batches = DataLoader(train_set, batch_size=1, num_workers=8)

    # Only the first batch is inspected; an empty loader prints nothing.
    first_batch = next(iter(batches), None)
    if first_batch is not None:
        low_res, high_res = first_batch
        print(low_res.shape)
        print(high_res.shape)
test()

torch.Size([1, 3, 24, 24])
torch.Size([1, 3, 96, 96])


In [13]:
class VGGLoss(nn.Module):
    """Feature-space loss: MSE between VGG-19 activations of two image batches.

    The first 36 layers of torchvision's pretrained VGG-19 ``features`` stack
    are used as a fixed feature extractor (eval mode, gradients disabled for
    its parameters) placed on ``DEVICE``.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True).features[:36]
        self.vgg = backbone.eval().to(DEVICE)
        self.loss = nn.MSELoss()

        # The extractor is never trained; freeze every weight.
        for weight in self.vgg.parameters():
            weight.requires_grad = False

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``."""
        return self.loss(self.vgg(input), self.vgg(target))

In [14]:
class ConvBlock(nn.Module):
    """Conv2d followed by optional BatchNorm2d and an optional activation.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution.
        discriminator: pick ``LeakyReLU(0.2)`` when True, per-channel ``PReLU``
            otherwise.
        use_act: apply the activation in ``forward``. The activation module is
            created either way, so it always appears in the ``state_dict``.
        use_bn: insert ``BatchNorm2d``; when False an ``Identity`` is used and
            the convolution gets a bias term instead.
        **kwargs: forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act

        # A conv bias is redundant when BatchNorm follows, so enable it only without BN.
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=not use_bn, **kwargs)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Apply conv -> norm, then the activation when ``use_act`` is set."""
        out = self.bn(self.cnn(x))
        if self.use_act:
            out = self.act(out)
        return out

In [15]:
class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling: conv -> PixelShuffle -> per-channel PReLU.

    The 3x3 convolution expands ``in_c`` channels to ``in_c * scale_factor**2``;
    PixelShuffle then trades those channels for resolution, giving
    ``(in_c, H * scale_factor, W * scale_factor)``.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        expanded_channels = in_c * scale_factor ** 2
        self.conv = nn.Conv2d(in_c, expanded_channels, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        """Upsample ``x`` spatially by ``scale_factor``, keeping ``in_c`` channels."""
        x = self.conv(x)
        x = self.ps(x)
        return self.act(x)

In [16]:
class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks with an identity skip connection.

    The first block applies its activation and the second does not; the block
    input is added to the result, so channel count and spatial size are kept.
    """

    def __init__(self, in_channels):
        super().__init__()
        same_size_conv = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **same_size_conv)
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **same_size_conv)

    def forward(self, x):
        """Return ``block2(block1(x)) + x``."""
        return self.block2(self.block1(x)) + x

In [17]:
class Generator(nn.Module):
    """SRGAN-style generator that upsamples its input 4x (two 2x stages).

    Args:
        in_channels: channels of the input (and output) image.
        num_channels: width of the internal feature maps.
        num_blocks: number of residual blocks in the trunk.
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(
            in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
        )
        trunk = [ResidualBlock(num_channels) for _ in range(num_blocks)]
        self.residuals = nn.Sequential(*trunk)
        self.convblock = ConvBlock(
            num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False
        )
        self.upsamples = nn.Sequential(
            UpsampleBlock(num_channels, 2),
            UpsampleBlock(num_channels, 2),
        )
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Map a low-res batch to a 4x larger batch with values in [-1, 1] (tanh)."""
        stem = self.initial(x)
        # Long skip connection around the whole residual trunk.
        features = self.convblock(self.residuals(stem)) + stem
        upscaled = self.upsamples(features)
        return torch.tanh(self.final(upscaled))

In [18]:
class Discriminator(nn.Module):
    """Convolutional discriminator that returns one raw logit per image.

    Args:
        in_channels: channels of the input image.
        features: output width of each successive ConvBlock. Blocks at odd
            positions use stride 2, so spatial size halves every second block.
            The first block has no BatchNorm.

    The classifier head adapts to ``features[-1]`` (previously hard-coded to
    512), so custom ``features`` sequences build a consistent network. No
    sigmoid is applied: the output is a logit, intended for a
    logits-based criterion.
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,  # 1, 2, 1, 2, ... -> downsample on odd idx
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature

        # After the loop in_channels holds the width of the last conv block.
        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            # Pool to a fixed 6x6 grid so the linear layer works for any input size.
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real/fake logits for ``x``."""
        x = self.blocks(x)
        return self.classifier(x)

In [19]:
def test():
    """Shape check: push a random low-res batch through generator and discriminator."""
    side = 24  # 96x96 -> 24x24
    with torch.cuda.amp.autocast():
        lr_batch = torch.randn((5, 3, side, side))
        generator = Generator()
        sr_batch = generator(lr_batch)
        critic = Discriminator()
        verdict = critic(sr_batch)

        for tensor in (sr_batch, verdict):
            print(tensor.shape)

test()

torch.Size([5, 3, 96, 96])
torch.Size([5, 1])


In [20]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename`` via ``torch.save``.

    The file holds a dict with keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer), the layout ``load_checkpoint`` reads back.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )

In [21]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` in place from ``checkpoint_file``.

    The checkpoint is mapped onto ``DEVICE`` while loading. After the optimizer
    state is restored, every parameter group's learning rate is set to ``lr``.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # Restoring the optimizer also restores the learning rate stored in the
    # checkpoint; overwrite it so the caller's value is the one in effect.
    for group in optimizer.param_groups:
        group["lr"] = lr

In [22]:
def plot_examples(low_res_folder, gen, output_folder="/content/drive/MyDrive/evaluation/"):
    """Upscale every image in ``low_res_folder`` with ``gen`` and save the results.

    Args:
        low_res_folder: directory whose files are opened as input images.
            (Previously the files were listed from here but opened from a
            hard-coded dataset path, so the argument was effectively ignored.)
        gen: generator module; switched to eval mode for the duration of the
            call and restored to its previous mode afterwards, even on error.
        output_folder: directory that receives one output image per input,
            under the same file name. Created if missing. The default is the
            location used before this parameter existed.
    """
    os.makedirs(output_folder, exist_ok=True)

    was_training = gen.training
    gen.eval()
    try:
        for file in os.listdir(low_res_folder):
            # Force three channels so grayscale / RGBA inputs match the generator.
            with Image.open(os.path.join(low_res_folder, file)) as img:
                image = np.asarray(img.convert("RGB"))

            with torch.no_grad():
                upscaled_img = gen(
                    test_transform(image=image)["image"]
                    .unsqueeze(0)
                    .to(DEVICE)
                )
            # Generator output is tanh-bounded in [-1, 1]; rescale to [0, 1] for saving.
            save_image(upscaled_img * 0.5 + 0.5, os.path.join(output_folder, file))
    finally:
        gen.train(was_training)

In [23]:
def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    """Run one epoch of alternating discriminator / generator updates.

    Args:
        loader: iterable yielding ``(low_res, high_res)`` batches.
        disc: discriminator module.
        gen: generator module.
        opt_gen: optimizer stepping the generator.
        opt_disc: optimizer stepping the discriminator.
        mse: accepted for signature compatibility; not used in the body.
        bce: criterion applied to the discriminator outputs and their targets.
        vgg_loss: callable giving the feature-space loss between fake and real.

    After the first batch of the epoch, example outputs are written via
    ``plot_examples``.
    """
    progress = tqdm(loader, leave=True)

    for step, (low_res, high_res) in enumerate(progress):
        high_res = high_res.to(DEVICE)
        low_res = low_res.to(DEVICE)

        # ---- Discriminator step: max log(D(x)) + log(1 - D(G(z))) ----
        fake = gen(low_res)
        real_logits = disc(high_res)
        # detach(): this step must not push gradients into the generator.
        fake_logits = disc(fake.detach())
        # One-sided label smoothing: real targets are drawn from (0.9, 1.0].
        smoothed_real = torch.ones_like(real_logits) - 0.1 * torch.rand_like(real_logits)
        real_term = bce(real_logits, smoothed_real)
        fake_term = bce(fake_logits, torch.zeros_like(fake_logits))
        loss_disc = fake_term + real_term

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # ---- Generator step: min log(1 - D(G(z))) <-> max log(D(G(z))) ----
        fake_logits = disc(fake)
        adversarial_loss = 1e-3 * bce(fake_logits, torch.ones_like(fake_logits))
        loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
        gen_loss = loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        # Dump example upscales once per epoch, right after the first update.
        if step == 0:
            plot_examples("/content/drive/MyDrive/Datasets/srgan/test_images/", gen)

In [None]:
def main():
    """Build data, networks, optimizers and losses, then train for NUM_EPOCHS.

    Checkpoints are restored first when ``LOAD_MODEL`` is set, and both
    networks are saved after every epoch when ``SAVE_MODEL`` is set.
    """
    train_set = MyImageFolder(root_dir="/content/drive/MyDrive/Datasets/srgan/train_images/")
    loader = DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )

    gen = Generator(in_channels=3).to(DEVICE)
    disc = Discriminator(in_channels=3).to(DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    vgg_loss = VGGLoss()

    if LOAD_MODEL:
        # Generator first, then discriminator, each with its own optimizer.
        for ckpt_path, net, opt in (
            (CHECKPOINT_GEN, gen, opt_gen),
            (CHECKPOINT_DISC, disc, opt_disc),
        ):
            load_checkpoint(ckpt_path, net, opt, LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)

        if SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

        print(f"Epoch[{epoch}/{NUM_EPOCHS}]")


main()

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

100%|██████████| 63/63 [01:18<00:00,  1.25s/it]


=> Saving checkpoint
=> Saving checkpoint
Epoch[0/1000]


100%|██████████| 63/63 [00:57<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[1/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[2/1000]


100%|██████████| 63/63 [00:54<00:00,  1.16it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[3/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[4/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[5/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[6/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[7/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[8/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[9/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[10/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[11/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[12/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[13/1000]


100%|██████████| 63/63 [00:57<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[14/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[15/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[16/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[17/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[18/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[19/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[20/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[21/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[22/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[23/1000]


100%|██████████| 63/63 [00:57<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[24/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[25/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[26/1000]


100%|██████████| 63/63 [00:59<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[27/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[28/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[29/1000]


100%|██████████| 63/63 [00:57<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[30/1000]


100%|██████████| 63/63 [00:55<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[31/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[32/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[33/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[34/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[35/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[36/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[37/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[38/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[39/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[40/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[41/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[42/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[43/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[44/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[45/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[46/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[47/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[48/1000]


100%|██████████| 63/63 [00:58<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[49/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[50/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[51/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[52/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[53/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[54/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[55/1000]


100%|██████████| 63/63 [00:59<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[56/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[57/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[58/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[59/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[60/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[61/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[62/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[63/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[64/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[65/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[66/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[67/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[68/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[69/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[70/1000]


100%|██████████| 63/63 [00:54<00:00,  1.16it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[71/1000]


100%|██████████| 63/63 [00:58<00:00,  1.08it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[72/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[73/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[74/1000]


100%|██████████| 63/63 [00:59<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[75/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[76/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[77/1000]


100%|██████████| 63/63 [00:59<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[78/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[79/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[80/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[81/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[82/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[83/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[84/1000]


100%|██████████| 63/63 [00:59<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[85/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[86/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[87/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[88/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[89/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[90/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[91/1000]


100%|██████████| 63/63 [00:57<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[92/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[93/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[94/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[95/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[96/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[97/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[98/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[99/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[100/1000]


100%|██████████| 63/63 [00:58<00:00,  1.08it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[101/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[102/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[103/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[104/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[105/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[106/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[107/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[108/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[109/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[110/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[111/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[112/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[113/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[114/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[115/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[116/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[117/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[118/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[119/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[120/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[121/1000]


100%|██████████| 63/63 [00:57<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[122/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[123/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[124/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[125/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[126/1000]


100%|██████████| 63/63 [00:54<00:00,  1.15it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[127/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[128/1000]


100%|██████████| 63/63 [00:56<00:00,  1.12it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[129/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[130/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[131/1000]


100%|██████████| 63/63 [00:58<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[132/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[133/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[134/1000]


100%|██████████| 63/63 [00:58<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[135/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[136/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[137/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[138/1000]


100%|██████████| 63/63 [00:58<00:00,  1.08it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[139/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[140/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[141/1000]


100%|██████████| 63/63 [00:57<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[142/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[143/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[144/1000]


100%|██████████| 63/63 [00:59<00:00,  1.07it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[145/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[146/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[147/1000]


100%|██████████| 63/63 [00:58<00:00,  1.09it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[148/1000]


100%|██████████| 63/63 [00:55<00:00,  1.14it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[149/1000]


100%|██████████| 63/63 [00:55<00:00,  1.13it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[150/1000]


100%|██████████| 63/63 [00:56<00:00,  1.11it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[151/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[152/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[153/1000]


100%|██████████| 63/63 [00:57<00:00,  1.10it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[154/1000]


100%|██████████| 63/63 [00:59<00:00,  1.06it/s]


=> Saving checkpoint
=> Saving checkpoint
Epoch[155/1000]


 27%|██▋       | 17/63 [00:17<00:31,  1.47it/s]

In [None]:
# Rebuild both networks and their optimizers outside main(), then restore the
# most recently saved checkpoints so the cells below can run inference.
gen = Generator(in_channels=3).to(DEVICE)
disc = Discriminator(in_channels=3).to(DEVICE)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

In [None]:
# Upscale every test image with the restored generator and write each result
# into the current working directory under the input's file name.
test_dir = '/content/drive/MyDrive/Datasets/srgan/test_images/'
files = os.listdir(test_dir)

# Inference must run in eval mode: in train mode the generator's BatchNorm
# layers normalise with the statistics of the single image being processed and
# keep updating their running statistics on every forward pass (no_grad does
# not prevent that), silently altering the loaded model. This also matches
# what plot_examples does during training.
gen.eval()
for file in files:
    # Force three channels so grayscale / RGBA inputs match the generator.
    with Image.open(os.path.join(test_dir, file)) as img:
        image = img.convert("RGB")
    with torch.no_grad():
        upscaled_img = gen(
            test_transform(image=np.asarray(image))["image"]
            .unsqueeze(0)
            .to(DEVICE)
        )
    # Generator output is tanh-bounded in [-1, 1]; rescale to [0, 1] for saving.
    save_image(upscaled_img*0.5 + 0.5, f"{file}")

In [None]:
# Debugging leftover: prints the repr of `image`, which at this point is
# presumably the last test image opened by the inference loop in the previous cell.
print(image)