In [1]:
import json
import os
import tarfile
import urllib.request

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from model.ESRGAN import RRDBNet, Discriminator_VGG, train_esrgan, VGGFeatureExtractor, EfficientNetLiteDiscriminator
from model.ResNetSR import ResNetSR
from data_utils.dataset import LocalImageDataset, get_split_indices
from model.utils import get_device, train_model_single_epoch, validate_model_single_epoch, save_checkpoint, save_samples, CombinedLoss

In [2]:
# Fetch the EfficientNet-Lite0 checkpoint archive once and unpack it into ./checkpoints.
# Download and extraction are guarded independently, so an interrupted run resumes
# cleanly instead of leaving a truncated archive or a half-extracted directory.
os.makedirs("checkpoints", exist_ok=True)

checkpoint_url = "https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/lite/efficientnet-lite0.tar.gz"
checkpoint_path = "checkpoints/efficientnet-lite0.tar.gz"
# Written only after a successful extraction; its presence means "nothing left to do".
extracted_marker = checkpoint_path + ".extracted"


def _download_atomically(url: str, dest: str) -> None:
    """Download ``url`` to ``dest`` through a temporary ``.part`` file.

    The file only appears at ``dest`` once the transfer has completed, so a
    crash mid-download never leaves a truncated archive that later runs would
    mistake for a finished one. Any partial temp file is removed on failure.
    """
    tmp_path = dest + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _safe_extract(archive: str, dest: str) -> None:
    """Extract a ``.tar.gz`` archive into ``dest`` without path traversal.

    Raises:
        ValueError: (fallback path only) if a member would be written outside ``dest``.
    """
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Python 3.12+ (and security backports): the "data" filter rejects absolute
            # paths, ".." components and links escaping the destination.
            tar.extractall(path=dest, filter="data")
            return
        # Older interpreters: at least refuse member paths resolving outside dest.
        # NOTE(review): this fallback does not validate symlink/hardlink targets.
        dest_root = os.path.realpath(dest)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise ValueError(f"Refusing to extract {member.name!r} outside {dest!r}")
        tar.extractall(path=dest)


if os.path.exists(extracted_marker):
    print("Checkpoint already exists")
else:
    if not os.path.exists(checkpoint_path):
        print("Downloading EfficientNet-Lite0 checkpoint...")
        _download_atomically(checkpoint_url, checkpoint_path)
    print("Extracting checkpoint...")
    _safe_extract(checkpoint_path, "checkpoints")
    open(extracted_marker, "w").close()
    print("Checkpoint downloaded and extracted")

Checkpoint already exists


In [3]:
# Seed torch's global RNG *before* building the networks so weight initialisation
# is reproducible across Restart & Run All. DataLoaders created later without an
# explicit generator also draw their shuffle seed from this global RNG.
SEED = 42
torch.manual_seed(SEED)

# 2x super-resolution generator; num_channels=1 presumably means grayscale/luma input.
generator = ResNetSR(upscale_factor=2, num_res_blocks=2, num_channels=1, num_features=32)
# Discriminator_VGG (also imported) is the alternative discriminator if needed.
discriminator = EfficientNetLiteDiscriminator()
# NOTE(review): '_' is a placeholder, not a VGGFeatureExtractor instance. Presumably
# train_esrgan skips the perceptual (VGG) loss for it — confirm against train_esrgan.
vgg_extractor = '_'



In [4]:
# Load the held-out test indices from disk (presumably saved so the test split
# stays identical across runs) and show how many there are.
with open("data/test_indices.json") as indices_file:
    test_indices = json.loads(indices_file.read())
len(test_indices)

594

In [5]:
# --- Data locations ---------------------------------------------------------
INPUT_DIR = "data/resolution_128"    # model inputs
TARGET_DIR = "data/resolution_256"   # super-resolution targets

# --- Optimisation -----------------------------------------------------------
NUM_EPOCHS = 30
BATCH_SIZE = 64
# NOTE(review): 1e-2 is far above the 1e-4 the ESRGAN paper uses with Adam;
# consider lowering it if generator/discriminator losses oscillate.
LEARNING_RATE = 1e-2

# --- Loss / regularisation knobs --------------------------------------------
# NOTE(review): none of these is passed to train_esrgan in the cells shown.
GRAD_CLIP = 1
EDGE_WEIGHT = 0.3
PSNR_WEIGHT = 0.3

# Metric log: one (initially empty) list per tracked metric.
history = {metric: [] for metric in ("train_loss", "train_psnr", "val_loss", "val_psnr")}

In [6]:
# Count input images and split the non-test indices into train/val.
# The extension match is case-sensitive (str.endswith). This count presumably has
# to agree with how LocalImageDataset enumerates INPUT_DIR — confirm if either changes.
num_images = len([f for f in os.listdir(INPUT_DIR) if f.endswith(('.jpg', '.png', '.jpeg', '.webp'))])
assert num_images > 0, f"No images found in {INPUT_DIR!r}"
assert all(0 <= i < num_images for i in test_indices), "test index out of range for INPUT_DIR"

train_indices, val_indices = get_split_indices(num_images, test_indices)

# Guard against leakage between the three splits.
assert set(train_indices).isdisjoint(val_indices), "train/val splits overlap"
assert set(test_indices).isdisjoint(train_indices), "test indices leaked into train split"
assert set(test_indices).isdisjoint(val_indices), "test indices leaked into val split"

In [7]:
def _make_loader(indices, shuffle):
    """Build a DataLoader of BATCH_SIZE over the INPUT_DIR/TARGET_DIR images selected by ``indices``."""
    dataset = LocalImageDataset(INPUT_DIR, TARGET_DIR, indices)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


# Only training batches are shuffled; val/test iterate in a fixed order.
train_loader = _make_loader(train_indices, shuffle=True)
val_loader = _make_loader(val_indices, shuffle=False)
test_loader = _make_loader(test_indices, shuffle=False)

In [8]:
# One AdamW optimizer per network, both at LEARNING_RATE (other AdamW settings at defaults).
g_optimizer, d_optimizer = (
    optim.AdamW(network.parameters(), lr=LEARNING_RATE)
    for network in (generator, discriminator)
)

In [None]:
# Adversarial training run. The epoch count comes from the config cell, so NUM_EPOCHS
# is the single knob; previously a hard-coded num_epochs=10 silently overrode NUM_EPOCHS = 30.
# Lower NUM_EPOCHS in the config cell if the shorter run was intended.
# NOTE(review): training is pinned to CPU (~85 s per iteration in the logged output)
# while get_device is imported but unused — consider using it if a GPU is available.
train_esrgan(
    generator,
    discriminator,
    vgg_extractor,
    train_loader,
    val_loader,
    g_optimizer,
    d_optimizer,
    num_epochs=NUM_EPOCHS,
    device="cpu",
)

Training:   1%|▏         | 1/67 [01:26<1:35:14, 86.58s/it]

G_loss: 0.1410 | D_loss: 0.6929 | G_lr: 0.010000 | D_lr: 0.010000

Training:   3%|▎         | 2/67 [02:50<1:32:18, 85.20s/it]

G_loss: 1.8100 | D_loss: 0.6911 | G_lr: 0.010000 | D_lr: 0.010000