### Data Download

In [1]:
# RUN ONLY IF YOU DON'T HAVE THE DATA SET YET.
# Re-running is safe: an archive whose extracted folder already exists is
# skipped entirely. This matters because re-extracting the training archive
# after the train/test split would put the held-out test images back into the
# training folder (test data leaking into training).

from zipfile import ZipFile
import os
import requests

urls = [
    "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip",
    "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
]

data_dir = "../loppuprojekti/data"
os.makedirs(data_dir, exist_ok=True)

for url in urls:
    filename = os.path.join(data_dir, os.path.basename(url))
    # Assumes each archive extracts into a folder named after the zip
    # (e.g. DIV2K_train_HR.zip -> DIV2K_train_HR/), which is the folder the
    # data preparation cell reads from.
    extracted_dir = os.path.splitext(filename)[0]
    if os.path.isdir(extracted_dir):
        print(f"Already extracted, skipping: {extracted_dir}")
        continue

    if not os.path.exists(filename):
        # Download the zip file under a temporary name and rename it only once
        # complete, so an interrupted download never leaves a truncated .zip
        # that a later run would try to extract.
        print(f"Downloading {url}...")
        partial = filename + ".part"
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(partial, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial, filename)
        print("Download ready:", filename)

    # Extract zip file
    print(f"Extracting {filename}...")
    with ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall(data_dir)
    print("Extracted files!")

print("Data pulling ready:", data_dir)

Downloading http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip...
Download ready: ../loppuprojekti/data/DIV2K_valid_HR.zip
Extracting ../loppuprojekti/data/DIV2K_valid_HR.zip...
Extracted files!
Downloading http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip...
Download ready: ../loppuprojekti/data/DIV2K_train_HR.zip
Extracting ../loppuprojekti/data/DIV2K_train_HR.zip...
Extracted files!
Data pulling ready: ../loppuprojekti/data


### Data Preparation

In [None]:
import os
import random
import shutil

train = "../loppuprojekti/data/DIV2K_train_HR/"
validation = "../loppuprojekti/data/DIV2K_valid_HR/"
test = "../loppuprojekti/data/DIV2K_test_HR"

# Fraction of the training images held out as a test set, and the seed that
# makes the selection of held-out images reproducible.
TEST_FRACTION = 0.15
SPLIT_SEED = 42

os.makedirs(test, exist_ok=True)

# Sort the listing: os.listdir order is arbitrary, so even a seeded sample
# would otherwise select different files on different machines.
train_files = sorted(os.listdir(train))

num_test_files = int(len(train_files) * TEST_FRACTION)

# A private Random instance makes the split reproducible without reseeding the
# global `random` module that other code may rely on.
files_to_move = random.Random(SPLIT_SEED).sample(train_files, num_test_files)

# Split only once: a non-empty test folder means the split already happened,
# and moving another batch would shrink the training set again.
if not os.listdir(test):
    for file in files_to_move:
        src = os.path.join(train, file)
        dst = os.path.join(test, file)
        shutil.move(src, dst)

    print(f"Moved {len(files_to_move)} files to test directory")

else:
    print("Test files already exist")

In [None]:
import os

import tensorflow as tf


def list_image_files(directory, extensions=(".png", ".jpg", ".jpeg")):
    """Return the full paths of the image files directly inside `directory`.

    Files are matched by filename suffix, case-insensitively (the name is
    lower-cased before comparison). The result is sorted, so dataset order does
    not depend on the arbitrary order returned by os.listdir.

    Args:
        directory: Folder to scan (not recursive).
        extensions: Lower-case filename suffixes to accept.

    Returns:
        A sorted list of os.path.join(directory, filename) strings.
    """
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.lower().endswith(extensions)
    )


train_files = list_image_files(train)
val_files = list_image_files(validation)
test_files = list_image_files(test)


def load_image(path):
    """Read an image file and return it as a float32 RGB tensor.

    Args:
        path: File path (Python str or scalar string tensor, as produced by
            tf.data.Dataset.from_tensor_slices over a list of paths).

    Returns:
        A float32 tensor of shape (height, width, 3); convert_image_dtype
        rescales the decoded uint8 values to [0, 1].
    """
    img = tf.io.read_file(path)
    # expand_animations=False makes decode_image always return a 3-D tensor
    # (the first frame for animated GIFs). Without it the rank is unknown inside
    # tf.data graphs, which breaks ops that need a static rank.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img


crop_size = 96
scale = 2

# The LR patch is crop_size // scale pixels wide. If crop_size is not a multiple
# of scale, upscaling the LR patch by `scale` would not reproduce the HR size,
# so model output and target shapes would disagree.
if crop_size % scale != 0:
    raise ValueError(f"crop_size ({crop_size}) must be divisible by scale ({scale})")


def random_crop_and_downscale(hr_img):
    """Cut a random HR patch and create its low-resolution counterpart.

    Args:
        hr_img: Image tensor of shape (height, width, 3); both sides must be at
            least crop_size, otherwise tf.image.random_crop fails.

    Returns:
        (lr_patch, hr_patch): the patch area-downscaled to
        (crop_size // scale, crop_size // scale, 3), and the original
        (crop_size, crop_size, 3) patch.
    """
    hr_patch = tf.image.random_crop(hr_img, size=[crop_size, crop_size, 3])
    lr_size = crop_size // scale
    # "area" resampling averages source pixels, a common choice for
    # generating downscaled training inputs.
    lr_patch = tf.image.resize(hr_patch, [lr_size, lr_size], method="area")
    return lr_patch, hr_patch


def center_crop_and_downscale(hr_img):
    """Deterministic counterpart of random_crop_and_downscale, for evaluation.

    Takes the central crop_size x crop_size patch of `hr_img` and area-downscales
    it by `scale`, so validation/test pairs are identical on every pass.

    Args:
        hr_img: Float image tensor of shape (height, width, 3); both sides must
            be at least crop_size, otherwise the slice is invalid.

    Returns:
        (lr_patch, hr_patch) with shapes (crop_size // scale, crop_size // scale, 3)
        and (crop_size, crop_size, 3).
    """
    # Pin the rank so the shape arithmetic below is valid even if the decoder
    # left the rank unknown inside the tf.data graph.
    hr_img = tf.ensure_shape(hr_img, [None, None, 3])
    img_shape = tf.shape(hr_img)
    top = (img_shape[0] - crop_size) // 2
    left = (img_shape[1] - crop_size) // 2
    # A constant slice size gives the patch a static (crop_size, crop_size, 3) shape.
    hr_patch = tf.slice(hr_img, [top, left, 0], [crop_size, crop_size, 3])
    lr_patch = tf.image.resize(
        hr_patch, [crop_size // scale, crop_size // scale], method="area"
    )
    return lr_patch, hr_patch


def build_image_dataset(files):
    """Return a tf.data.Dataset yielding load_image(path) for each path in `files`."""
    return tf.data.Dataset.from_tensor_slices(files).map(
        load_image, num_parallel_calls=tf.data.AUTOTUNE
    )


# Build datasets
train_ds = build_image_dataset(train_files)
val_ds = build_image_dataset(val_files)
test_ds = build_image_dataset(test_files)

# Training uses random crops as augmentation. Validation/test use a fixed
# center crop so their metrics are comparable across epochs and runs.
train_sr = train_ds.map(random_crop_and_downscale, num_parallel_calls=tf.data.AUTOTUNE)
val_sr = val_ds.map(center_crop_and_downscale, num_parallel_calls=tf.data.AUTOTUNE)
test_sr = test_ds.map(center_crop_and_downscale, num_parallel_calls=tf.data.AUTOTUNE)

# Example: print shapes
for lr, hr in train_sr.take(1):
    print("LR patch shape:", lr.shape)
    print("HR patch shape:", hr.shape)


In [None]:
import matplotlib.pyplot as plt

# Visual sanity check: show one LR/HR pair from the training pipeline side by side.
lr, hr = next(iter(train_sr.take(1)))
lr_img = lr.numpy()
hr_img = hr.numpy()

fig, (ax_lr, ax_hr) = plt.subplots(1, 2, figsize=(8, 4))
for ax, img, title in (
    (ax_lr, lr_img, "Low-Resolution"),
    (ax_hr, hr_img, "High-Resolution"),
):
    # imshow expects float images in [0, 1]; clipping guards against
    # resampling round-off slightly outside that range (imshow would warn).
    ax.imshow(img.clip(0.0, 1.0))
    ax.set_title(f"{title} ({img.shape[1]}x{img.shape[0]})")
    ax.axis("off")

plt.show()