In [29]:
import torch.nn.functional as F
from torchvision.models import vgg19

import torch
import torch.nn as nn
import torch.optim as optim
import random

import torchvision as tv
from torchvision import transforms

import matplotlib.pyplot as plt
import cv2

from torchvision.transforms import InterpolationMode
import csv
import os
from RLFN.rlfn import RLFN_S
from utils import tensor2uint, tensor_to_uint8, uint2tensor4
from choose_device import choose_device


In [30]:
# %env PYTORCH_ENABLE_MPS_FALLBACK=1

In [31]:
# NOTE(review): a bare `%env` prints every environment variable of the kernel
# process; the output captured for this cell contains user-specific paths.
# Consider querying a single variable instead and clearing this output before
# the notebook is shared.
%env

{'PATH': '/Users/lukaszbala/projects/super-res-in-face-images/venv/bin:/Users/lukaszbala/miniforge3/bin:/Users/lukaszbala/miniforge3/condabin:/Users/lukaszbala/.nvm/versions/node/v19.7.0/bin:/opt/homebrew/bin:/opt/homebrew/sbin:/usr/local/bin:/System/Cryptexes/App/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/local/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/appleinternal/bin:/Library/Apple/usr/bin:/Applications/VMware Fusion.app/Contents/Public:/Users/lukaszbala/Library/Application Support/JetBrains/Toolbox/scripts',
 'GSETTINGS_SCHEMA_DIR_CONDA_BACKUP': '',
 'XML_CATALOG_FILES': 'file:///Users/lukaszbala/miniforge3/etc/xml/catalog file:///etc/xml/catalog',
 'CONDA_DEFAULT_ENV': 'base',
 'CONDA_EXE': '/Users/lukaszbala/miniforge3/bin/conda',
 'CONDA_PYTHON_EXE': '/Users/lukaszbala/miniforge3/bin/python',
 'HOMEBREW_PREFIX': '/opt/homebrew'

In [32]:
def save_losses_to_csv(training_losses, validation_losses, filename):
    """
    Write per-epoch training and validation losses to a CSV file.

    The file is overwritten. It holds a header row followed by one row per
    epoch; epochs are numbered from 1 in list order.

    Args:
        training_losses (list of float): Training loss of each epoch.
        validation_losses (list of float): Validation loss of each epoch.
        filename (str): Destination path of the CSV file.

    Raises:
        AssertionError: If the two loss lists differ in length.
    """
    # Both series must describe the same epochs.
    assert len(training_losses) == len(validation_losses), "Training and validation losses must have the same length."

    header = ["Epoch", "Training Loss", "Validation Loss"]
    paired_losses = zip(training_losses, validation_losses)
    rows = (
        [epoch_number, train_value, val_value]
        for epoch_number, (train_value, val_value) in enumerate(paired_losses, start=1)
    )

    # newline='' leaves line-ending handling to the csv module.
    with open(filename, mode='w', newline='') as csv_file:
        out = csv.writer(csv_file)
        out.writerow(header)
        out.writerows(rows)

In [33]:
def save_epoch_images(LR_img, HR_img, SR_img, HR_dim, epoch, output_dir='trained-model/images'):
    """
    Save a four-panel comparison figure for one epoch.

    The panels show, left to right: the low-resolution input, the
    high-resolution target, the super-resolved output, and the low-resolution
    input upscaled with bicubic interpolation. The figure is written to
    ``<output_dir>/epoch_<epoch>.png``; the directory is created if missing.

    Args:
        LR_img (np.ndarray): Low-resolution input image.
        HR_img (np.ndarray): High-resolution target image.
        SR_img (np.ndarray): Super-resolved output image.
        HR_dim (tuple): Target size handed to ``cv2.resize`` as ``dsize``
            (OpenCV reads ``dsize`` as (width, height)).
        epoch (int): Current epoch number, used in the file name.
        output_dir (str): Directory the PNG is written to.
    """
    # savefig() does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, dpi=150, figsize=(16, 5))
    try:
        ax1.imshow(LR_img)
        ax1.set_title('Low-Resolution Input')
        ax2.imshow(HR_img)
        ax2.set_title('High-Resolution Target')
        ax3.imshow(SR_img)
        ax3.set_title('Super-Resolution Output')
        ax4.imshow(cv2.resize(LR_img, HR_dim, interpolation=cv2.INTER_CUBIC))
        ax4.set_title('Bicubic Interpolation')

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f'epoch_{epoch}.png'))
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # repeated calls from a training loop do not accumulate open figures.
        plt.close(fig)


In [34]:
def prune_catalog(data_path):
    """
    Remove every file found under a directory tree.

    The tree rooted at ``data_path`` is walked recursively and each file is
    deleted. Directories themselves (including ``data_path``) are left in place.

    Args:
        data_path (str): Path to the dataset directory.
    """
    for current_dir, _subdirs, filenames in os.walk(data_path):
        doomed_paths = [os.path.join(current_dir, name) for name in filenames]
        for path in doomed_paths:
            os.remove(path)

In [35]:

class L1Loss(nn.Module):
    """Mean absolute error between a prediction and its target."""

    def __init__(self):
        super().__init__()

    def forward(self, prediction, target):
        """Return ``F.l1_loss`` of the two tensors, averaged over all elements."""
        mean_abs_error = F.l1_loss(prediction, target, reduction='mean')
        return mean_abs_error

class PerceptualLoss(nn.Module):
    """Perceptual loss: weighted sum of L1 distances between VGG-19 feature maps.

    For every index in ``layers`` a prefix of ``vgg19(...).features`` ending at
    that index (inclusive) is built. Prediction and target are passed through
    each prefix, and the mean absolute difference of the resulting feature maps
    is accumulated with the matching entry of ``lambda_weights``.

    Args:
        layers (sequence of int): Indices into the VGG-19 ``features`` stack at
            which activations are compared.
        lambda_weights (sequence of float, optional): One weight per entry of
            ``layers``. Defaults to 1.0 for every layer.
        device: Device the feature extractors are placed on.

    Raises:
        ValueError: If ``lambda_weights`` does not have one entry per layer.
    """
    def __init__(self, layers=(1, 3, 5, 9, 13), lambda_weights=None, device='cpu'):
        super().__init__()
        # Copy the sequences so this instance never shares a list with the
        # signature default or with the caller.
        self.layers = list(layers)
        if lambda_weights is None:
            self.lambda_weights = [1.0] * len(self.layers)
        else:
            self.lambda_weights = list(lambda_weights)
        if len(self.lambda_weights) != len(self.layers):
            # Fail at construction instead of with an IndexError mid-training
            # (too few weights) or by silently ignoring the surplus (too many).
            raise ValueError(
                f"lambda_weights has {len(self.lambda_weights)} entries "
                f"but layers has {len(self.layers)}"
            )
        self.device = device

        # Load pre-trained VGG-19
        # NOTE(review): recent torchvision releases report `pretrained=True` as
        # deprecated in favour of `weights=`; kept as-is for compatibility with
        # the installed version — confirm before upgrading.
        vgg = vgg19(pretrained=True).features.to(device).eval()

        # Each extractor is rebuilt from the shared VGG layers and runs from the
        # input on its own, so the prefixes overlap but do not depend on each
        # other's intermediate tensors. nn.ModuleList registers them as
        # sub-modules, so .to()/.eval() on this loss reach them as well.
        self.feature_extractors = nn.ModuleList(
            [nn.Sequential(*list(vgg[:layer + 1])) for layer in self.layers]
        )
        self.feature_extractors.to(device).eval()

        # The extractors are fixed feature functions; they must not be trained.
        for param in self.feature_extractors.parameters():
            param.requires_grad = False

    def forward(self, prediction, target):
        """Return the weighted sum of per-layer mean absolute feature differences."""
        loss = 0.0
        for weight, extractor in zip(self.lambda_weights, self.feature_extractors):
            # Extract features
            pred_features = extractor(prediction)
            target_features = extractor(target)

            # Compute L1 loss on the features, apply the layer weight, accumulate
            loss = loss + weight * F.l1_loss(pred_features, target_features, reduction='mean')

        return loss

class CombinedL1PerceptualLoss(nn.Module):
    """Weighted sum of a pixel-wise L1 loss and a VGG-19 perceptual loss.

    ``forward`` returns ``alpha * L1Loss + beta * PerceptualLoss``.

    Args:
        alpha (float): Weight for the L1 loss.
        beta (float): Weight for the perceptual loss.
        layers (sequence of int): Forwarded to ``PerceptualLoss``.
        lambda_weights (sequence of float, optional): Forwarded to
            ``PerceptualLoss``.
        device: Device the perceptual loss is built on and to which the
            inputs of ``forward`` are moved.
    """
    def __init__(self, alpha=1.0, beta=1.0, layers=(1, 3, 5, 9, 13), lambda_weights=None, device='cpu'):
        super().__init__()
        self.l1_loss = L1Loss()
        # Hand over a fresh list so the stored layer indices are never shared
        # with the signature default or with the caller's sequence.
        self.perceptual_loss = PerceptualLoss(layers=list(layers), lambda_weights=lambda_weights, device=device)
        self.alpha = alpha  # Weight for L1 loss
        self.beta = beta    # Weight for perceptual loss
        self.device = device

    def forward(self, prediction, target):
        """Return the combined loss for ``prediction`` against ``target``."""
        # Move inputs to the same device as the loss functions
        prediction = prediction.to(self.device)
        target = target.to(self.device)

        l1 = self.l1_loss(prediction, target)
        perceptual = self.perceptual_loss(prediction, target)
        total_loss = self.alpha * l1 + self.beta * perceptual
        return total_loss


In [36]:

# Setup
device = choose_device()
model = RLFN_S(in_channels=3, out_channels=3).to(device)
# model.load_state_dict(torch.load( "epoch_2500.pth"))

# Load dataset and transforms
train_data_path = "train"
val_data_path = "valid"

# Output locations. Created up front so torch.save() / savefig() cannot fail on
# a missing directory after an epoch has already been computed.
SAVE_DIR = 'trained-model/saves'
IMAGE_DIR = 'trained-model/images'
LOSSES_CSV = f'{SAVE_DIR}/losses.csv'
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(IMAGE_DIR, exist_ok=True)

# Optimiser / scheduler settings shared by all phases (only the initial
# learning rate differs between phases).
WEIGHT_DECAY = 1e-5
LR_STEP_SIZE = 200
LR_GAMMA = 0.5
# A numbered checkpoint + preview image is written when epoch % 10 == 1,
# i.e. at epochs 1, 11, 21, ...
CHECKPOINT_EVERY = 10

LR_dim = (64, 64)
HR_dim = (256, 256)
# Produces the low-resolution model input from a high-resolution batch.
resize_obj = transforms.Resize(LR_dim, interpolation=InterpolationMode.BICUBIC)
batch_size = 32

transform = transforms.Compose([
    transforms.RandomCrop(HR_dim, pad_if_needed=True),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])

val_transforms = transforms.Compose([
    transforms.Resize(HR_dim[0], interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(HR_dim[0]),
    transforms.ToTensor(),
])

train_dataset = tv.datasets.ImageFolder(train_data_path, transform=transform)
val_dataset = tv.datasets.ImageFolder(val_data_path, transform=val_transforms)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# Initialize losses
l1_loss = L1Loss()
l2_loss = nn.MSELoss()
combined_loss = CombinedL1PerceptualLoss(alpha=1.0 / 256, beta=255.0 / 256, device=device)

# Loss history across ALL phases; later cells plot these lists.
train_losses = []
val_losses = []


def make_optimizer(lr):
    """Return a fresh ``(Adam, StepLR)`` pair for ``model`` starting at learning rate ``lr``."""
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)
    return optimizer, lr_scheduler


def train_one_epoch(criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        float: Sum of (batch loss * batch size) divided by the dataset size.
    """
    model.train()
    running_loss = 0.0

    for images, _ in train_loader:
        HR_crop = images.to(device)
        LR_crop = resize_obj(HR_crop).to(device)

        SR_img = model(LR_crop)

        loss = criterion(SR_img, HR_crop)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item() * HR_crop.size(0)

    return running_loss / len(train_loader.dataset)


def validate(criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    The loss is weighted by batch size and divided by the dataset size, the
    same way the training loss is computed, so a smaller final batch does not
    get extra weight.

    Returns:
        tuple: ``(mean_loss, (LR_batch, HR_batch, SR_batch))`` where the second
        element holds the tensors of the last validation batch (``None`` if the
        loader yielded nothing).
    """
    model.eval()
    running_loss = 0.0
    last_batch = None

    with torch.no_grad():
        for images, _ in val_loader:
            HR_crop = images.to(device)
            LR_crop = resize_obj(HR_crop).to(device)

            SR_img = model(LR_crop)
            running_loss += criterion(SR_img, HR_crop).item() * HR_crop.size(0)
            last_batch = (LR_crop, HR_crop, SR_img)

    return running_loss / len(val_loader.dataset), last_batch


def run_training_phase(epochs, criterion, optimizer, lr_scheduler, best_checkpoint_path, best_val_loss):
    """Train and validate ``model`` for every epoch index in ``epochs``.

    Per epoch: train, step the LR scheduler, validate, append both losses to
    ``train_losses`` / ``val_losses``, save the weights to
    ``best_checkpoint_path`` when the validation loss improves on
    ``best_val_loss``, and write a numbered checkpoint, a preview image and
    the loss CSV when ``(epoch + 1) % CHECKPOINT_EVERY == 1``. The loss CSV is
    written once more when the phase ends.

    Args:
        epochs (iterable of int): Zero-based epoch indices; reported as ``epoch + 1``.
        criterion (callable): Loss used for both training and validation.
        optimizer: Optimiser stepping ``model``'s parameters.
        lr_scheduler: Scheduler stepped once per epoch.
        best_checkpoint_path (str): File receiving the best weights of this phase.
        best_val_loss (float): Validation loss to beat (``inf`` to start fresh).

    Returns:
        float: The best validation loss after this phase.
    """
    for epoch in epochs:
        epoch_number = epoch + 1

        train_loss = train_one_epoch(criterion, optimizer)
        train_losses.append(train_loss)
        print(f'Epoch {epoch_number}, Training Loss: {train_loss:.4f}')

        lr_scheduler.step()

        val_loss, (LR_crop, HR_crop, SR_img) = validate(criterion)
        val_losses.append(val_loss)
        print(f'Epoch {epoch_number}, Validation Loss: {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_checkpoint_path)
            print(f'Saved new best model at epoch {epoch_number}')

        # Save model periodically
        if epoch_number % CHECKPOINT_EVERY == 1:
            torch.save(model.state_dict(), f'{SAVE_DIR}/epoch_{epoch_number}.pth')
            # Preview a random sample from the last validation batch.
            idx = random.randint(0, SR_img.shape[0] - 1)
            SR_img_single = tensor2uint(SR_img[idx])
            HR_img = tensor2uint(HR_crop[idx])
            LR_img = tensor2uint(LR_crop[idx])
            save_epoch_images(LR_img, HR_img, SR_img_single, HR_dim, epoch_number)
            save_losses_to_csv(train_losses, val_losses, LOSSES_CSV)

    save_losses_to_csv(train_losses, val_losses, LOSSES_CSV)
    return best_val_loss


# Phase 1 (epochs 1-1000): L1 loss, starting from the freshly initialised model.
opt, scheduler = make_optimizer(lr=0.0001)
best_val_loss = run_training_phase(
    range(1000), l1_loss, opt, scheduler,
    f'{SAVE_DIR}/best_model_l1_loss.pth', float('inf'),
)

# Phase 2 (epochs 1001-2000): warm start — keep the trained weights but restart
# Adam and the LR schedule. The best validation loss is carried over from
# phase 1, so best_model_l1_loss.pth is only replaced by a better model.
# NOTE(review): this phase starts at lr=0.001, ten times the phase-1 rate,
# so it is not "the same settings" as phase 1 — confirm this is intended.
opt, scheduler = make_optimizer(lr=0.001)
best_val_loss = run_training_phase(
    range(1000, 2000), l1_loss, opt, scheduler,
    f'{SAVE_DIR}/best_model_l1_loss.pth', best_val_loss,
)

# Phase 3 (epochs 2001-2300): fine-tune with the combined L1 + perceptual loss.
# The best-loss tracker is reset because this phase uses a different loss.
opt, scheduler = make_optimizer(lr=1e-4)
best_val_loss = run_training_phase(
    range(2000, 2300), combined_loss, opt, scheduler,
    f'{SAVE_DIR}/best_model_combined.pth', float('inf'),
)

# Phase 4 (epochs 2301-2501): fine-tune with the L2 (MSE) loss; tracker reset again.
opt, scheduler = make_optimizer(lr=1e-4)
best_val_loss = run_training_phase(
    range(2300, 2501), l2_loss, opt, scheduler,
    f'{SAVE_DIR}/best_model_l2_loss.pth', float('inf'),
)

Epoch 1, Training Loss: 0.2694
Epoch 1, Validation Loss: 0.1413
Saved new best model at epoch 1
Epoch 2, Training Loss: 0.1039
Epoch 2, Validation Loss: 0.0945
Saved new best model at epoch 2
Epoch 3, Training Loss: 0.0783
Epoch 3, Validation Loss: 0.0814
Saved new best model at epoch 3
Epoch 4, Training Loss: 0.0702
Epoch 4, Validation Loss: 0.0742
Saved new best model at epoch 4
Epoch 5, Training Loss: 0.0643
Epoch 5, Validation Loss: 0.0681
Saved new best model at epoch 5
Epoch 6, Training Loss: 0.0584
Epoch 6, Validation Loss: 0.0620
Saved new best model at epoch 6
Epoch 7, Training Loss: 0.0528
Epoch 7, Validation Loss: 0.0577
Saved new best model at epoch 7
Epoch 8, Training Loss: 0.0487
Epoch 8, Validation Loss: 0.0546
Saved new best model at epoch 8
Epoch 9, Training Loss: 0.0464
Epoch 9, Validation Loss: 0.0521
Saved new best model at epoch 9
Epoch 10, Training Loss: 0.0445
Epoch 10, Validation Loss: 0.0509
Saved new best model at epoch 10
Epoch 11, Training Loss: 0.0426
Epoch

KeyboardInterrupt: 

In [None]:
# Loss curves: one point per recorded epoch, with the x-axis numbered from 1.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()

In [None]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

def preprocess_image(image_path, device):
    """
    Load an image file as a batched RGB tensor on ``device``.

    The image is not resized; it is converted to RGB, turned into a tensor
    with ``transforms.ToTensor()`` and given a leading batch dimension.

    Args:
        image_path (str): Path to the input image.
        device (torch.device): Device the tensor is moved to.

    Returns:
        torch.Tensor: Preprocessed image tensor with a batch dimension of 1.
    """
    # The context manager releases the file handle deterministically;
    # convert() returns an independent image that stays usable afterwards.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')

    batch = transforms.ToTensor()(rgb_image).unsqueeze(0)  # Add batch dimension
    return batch.to(device)

def super_resolve_image(model, image_path, upscale_factor, device, output_image_path):
    """
    Super-resolve an image with ``model``, save the result and plot a comparison.

    The model output is written to ``output_image_path``. A three-panel figure
    is then shown: the input image, the model output, and the input upscaled
    by ``upscale_factor`` with bicubic interpolation.

    Note:
        ``model`` is switched to eval mode and left that way.

    Args:
        model (torch.nn.Module): Trained super-resolution model.
        image_path (str): Path to the input low-resolution image.
        upscale_factor (int or float): Scale applied to the input's width and
            height for the bicubic reference image. It does not affect the
            model output.
        device (torch.device): Device the input tensor is moved to.
        output_image_path (str): Path to save the super-resolved output image.
    """
    # Preprocess the input image
    input_image = preprocess_image(image_path, device)

    # Run the model on the preprocessed image
    model.eval()
    with torch.no_grad():
        SR_image = model(input_image)

    # Postprocess and save the output image
    output_image = tensor2uint(SR_image)
    Image.fromarray(output_image).save(output_image_path)

    # Load the original image for comparison
    with Image.open(image_path) as source:
        original_image = np.array(source.convert('RGB'))
    original_height, original_width = original_image.shape[:2]

    new_width = int(original_width * upscale_factor)
    new_height = int(original_height * upscale_factor)

    # cv2.resize takes the target size as (width, height).
    bicubic_image = cv2.resize(original_image, (new_width, new_height), interpolation=cv2.INTER_CUBIC)

    # Plot the input, output, and bicubic images
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(original_image)
    axes[0].set_title('Low-Resolution Input')

    axes[1].imshow(output_image)
    axes[1].set_title('Super-Resolved Output')

    axes[2].imshow(bicubic_image)
    axes[2].set_title('Bicubic Interpolation')

    plt.tight_layout()
    plt.show()

# Example usage
# NOTE: this cell rebinds `device` and `model`; after it runs, `model` holds the
# weights loaded below rather than the in-memory model from the training cell.
device = choose_device()

# --------------------------------
# load model
# --------------------------------

model = RLFN_S(in_channels=3, out_channels=3)

# map_location='cpu' lets a checkpoint saved from an accelerator load on a
# machine where that device is unavailable; the model is moved to `device`
# right after.
model.load_state_dict(torch.load("trained-model/saves/best_model_l2_loss.pth", map_location="cpu"))
model.to(device)

image_path = 'output_faces/face_39.jpg'  # Replace with the path to your low-res input image
upscale_factor = 4  # Scale of the bicubic reference image shown next to the model output
output_image_path = 'super_resolved_image.png'  # Replace with the desired path to save the output image

super_resolve_image(model, image_path, upscale_factor, device, output_image_path)
