In [3]:
import torch
from torch import nn
import requests
import os
import config
from torch import optim
from utils import load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator
from tqdm import tqdm
from dataset import MyImageFolder

In [None]:
def download_file(url, save_path):
    response = requests.get(url, stream=True)
    with open(save_path, 'wb') as out_file:
        for chunk in response.iter_content(chunk_size=8192):
            out_file.write(chunk)

# Define the URLs and the paths where you want to save the datasets
urls = [
    "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip",
    "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip",
    "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic/X4/DIV2K_train_LR_bicubic_X4.zip",
    "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic/X4/DIV2K_valid_LR_bicubic_X4.zip"

]

save_paths = [
    "DIV2K_train_HR.zip",
    "DIV2K_valid_HR.zip",
    "DIV2K_train_LR_bicubic_X4.zip",
    "DIV2K_valid_LR_bicubic_X4.zip"
]

def test():
    for url, save_path in zip(urls, save_paths):
      # print(f"Downloading {url}...")
      download_file(url, save_path)
      # print(f"Saved to {save_path}")

      print("Download complete!")

    !unzip /content/DIV2K_train_HR.zip -d /content/DIV2K_train_HR
    dataset = MyImageFolder(root_dir = "/content/DIV2K_train_HR/")
    loader = DataLoader(dataset,batch_size = 1, num_workers = 8)

    for low_res,high_res in loader:
      print(low_res.shape)
      print(high_res.shape)
test()

TRAIN THE NETWORK

In [18]:
torch.backends.cudnn.benchmark = True


def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        fake = gen(low_res)
        disc_real = disc(high_res)
        disc_fake = disc(fake.detach())
        disc_loss_real = bce(
            disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
        )
        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = disc_loss_fake + disc_loss_real

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        disc_fake = disc(fake)
        #l2_loss = mse(fake, high_res)
        adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
        loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
        gen_loss = loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        if idx % 200 == 0:
            plot_examples("test_images/", gen)


def main():
    dataset = MyImageFolder(root_dir="/content/DIV2K_train_HR/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )
    gen = Generator(inChannels=3).to(config.DEVICE)
    disc = Discriminator(inChannels=3).to(config.DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.9, 0.999))
    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    vgg_loss = VGGLoss()


    # if config.LOAD_MODEL:
    #     load_checkpoint(
    #         config.CHECKPOINT_GEN,
    #         gen,
    #         opt_gen,
    #         config.LEARNING_RATE,
    #     )
    #     load_checkpoint(
    #        config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
    #     )

    for epoch in range(config.NUM_EPOCHS):
        train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)
        print("EPOCH : ",epoch)
        print("MSE Loss : ", mse)
        print("VGG_loss:",vgg_loss)

        # if config.SAVE_MODEL:
        #     save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
        #     save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)




In [19]:
main()

100%|██████████| 50/50 [01:05<00:00,  1.31s/it]


EPOCH :  0
MSE Loss :  MSELoss()
VGG_loss: VGGLoss(
  (vgg): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 25

 26%|██▌       | 13/50 [00:21<01:00,  1.63s/it]


KeyboardInterrupt: ignored

In [23]:




# Download the datasets



Downloading http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip...
Saved to DIV2K_train_HR.zip
Downloading http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip...
Saved to DIV2K_valid_HR.zip
Downloading http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic/X4/DIV2K_train_LR_bicubic_X4.zip...
Saved to DIV2K_train_LR_bicubic_X4.zip
Downloading http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic/X4/DIV2K_valid_LR_bicubic_X4.zip...
Saved to DIV2K_valid_LR_bicubic_X4.zip
Download complete!
