In [2]:
import os
import matplotlib.pyplot as plt

from loadData import DIV2K
from model import Generator, Discriminator
from trainer import SrganTrainer, SrganGeneratorTrainer

%matplotlib inline

In [4]:
# Download and extract the DIV2K archives used below (HR and x8 LR images for
# both the train and valid splits) into .div2k/images, skipping any archive
# whose extracted directory is already present.
for archive_name in ('DIV2K_train_HR', 'DIV2K_train_LR_x8',
                     'DIV2K_valid_HR', 'DIV2K_valid_LR_x8'):
    # Guard each archive with its OWN extracted directory, so a present HR
    # folder can never suppress the download of the matching LR archive.
    if not os.path.exists(f'.div2k/images/{archive_name}'):
        DIV2K.download_archive(f'{archive_name}.zip', '.div2k/images')

In [5]:
# One DIV2K dataset per split (train first, then valid), both at scale 8 with
# bicubic downgrading; DIV2K comes from the project's loadData module.
div2k_train, div2k_valid = (
    DIV2K(scale=8, subset=split, downgrade='bicubic')
    for split in ('train', 'valid')
)


Caching dataset to .div2k/caches/DIV2K_train_HR.cache ...
Dataset cached to .div2k/caches/DIV2K_train_HR.cache.
Caching dataset to .div2k/caches/DIV2K_train_LR_x8_X8.cache ...
Dataset cached to .div2k/caches/DIV2K_train_LR_x8_X8.cache.
Caching dataset to .div2k/caches/DIV2K_valid_HR.cache ...
Dataset cached to .div2k/caches/DIV2K_valid_HR.cache.
Caching dataset to .div2k/caches/DIV2K_valid_LR_x8_X8.cache ...
Dataset cached to .div2k/caches/DIV2K_valid_LR_x8_X8.cache.


In [6]:
# Ask the training dataset where it expects its HR and LR images (via the
# DIV2K private directory helpers) and report whether each directory exists.
hr_dir = div2k_train._hr_images_dir()
lr_dir = div2k_train._lr_images_dir()

print(f"HR Directory Exists: {os.path.exists(hr_dir)}")
print(f"LR Directory Exists: {os.path.exists(lr_dir)}")

# Fallback: download and extract any missing archive into images_dir. The
# archive-name helpers are kept uncalled so each name is only looked up when
# its directory is actually missing.
for expected_dir, archive_name_of in ((hr_dir, div2k_train._hr_images_archive),
                                      (lr_dir, div2k_train._lr_images_archive)):
    if os.path.exists(expected_dir):
        continue
    DIV2K.download_archive(archive_name_of(), div2k_train.images_dir, extract=True)


HR Directory Exists: True
LR Directory Exists: True


In [None]:
# NOTE(review): this cell rebinds div2k_train/div2k_valid to scale 4, replacing
# the x8 datasets built and cached earlier, whose LR archives are the only ones
# the download cell fetches. In the saved run it was never executed, so the
# loaders were built on x8 data; Restart & Run All will use x4 instead.
# Confirm which scale Generator expects and keep only one of the two cells.
div2k_train, div2k_valid = (
    DIV2K(scale=4, subset=split, downgrade='bicubic')
    for split in ('train', 'valid')
)

In [7]:
from torch.utils.data import DataLoader

# Settings common to both loaders: batches of 16, four worker subprocesses,
# and pinned host memory so batches can be copied to a CUDA device faster.
loader_kwargs = dict(batch_size=16, num_workers=4, pin_memory=True)

# Training batches are drawn in a fresh random order every epoch.
train_loader = DataLoader(dataset=div2k_train, shuffle=True, **loader_kwargs)

# Validation keeps a fixed order; shuffling would not change the metrics.
valid_loader = DataLoader(dataset=div2k_valid, shuffle=False, **loader_kwargs)


In [8]:
# Build the generator network and wrap it in SrganGeneratorTrainer for the
# generator-only (pre-training) run; './ckpt/pre_generator' is handed to the
# trainer as its checkpoint directory.
generator = Generator()
pre_trainer = SrganGeneratorTrainer(
    model=generator,
    checkpoint_dir='./ckpt/pre_generator',
)


In [None]:
# `torch` itself is not bound by the notebook's other imports
# (`from torch.utils.data import DataLoader` does not bind it), so without
# this import the torch.save call below would raise NameError after training.
import torch

# Run the generator pre-training loop on the loaders built above.
# evaluate_every / save_best_only are passed straight to
# SrganGeneratorTrainer.train; presumably validation runs every 1000 steps and
# checkpoints are kept regardless of score — confirm in trainer.py.
pre_trainer.train(train_loader,
                  valid_loader,
                  steps=1_000_000,
                  evaluate_every=1000,
                  save_best_only=False)

# Persist the generator's final parameters as a state_dict; reload with
# generator.load_state_dict(torch.load('pre_generator.pth')).
torch.save(generator.state_dict(), 'pre_generator.pth')
