### downloading data

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import zipfile
from pathlib import Path

# Path to the downloaded zip file on Google Drive
zip_file_path = "/content/drive/MyDrive/celebS.zip"

# Create the directory to extract to
extract_dir = Path("images")
extract_dir.mkdir(parents=True, exist_ok=True)

# Extract the contents
try:
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    print(f"Extracted files to: {extract_dir}")
except FileNotFoundError:
    print(f"Error: The file {zip_file_path} was not found.")
except zipfile.BadZipFile:
    print(f"Error: The file {zip_file_path} is not a valid zip file.")
except Exception as e:
    print(f"An error occurred during extraction: {e}")

Extracted files to: images


In [3]:
import os
import shutil
from pathlib import Path

# Define source and destination directories
source_dir = Path("images/img_align_celeba") # Assuming the images are in a subdirectory after extraction
train_dir = Path("images/train")
valid_dir = Path("images/valid")

# Create destination directories
train_dir.mkdir(parents=True, exist_ok=True)
valid_dir.mkdir(parents=True, exist_ok=True)

# Get list of image files
image_files = list(source_dir.glob("*.jpg"))

# Define split ratio (e.g., 80% train, 20% valid)
split_ratio = 0.8
split_index = int(len(image_files) * split_ratio)

# Split files
train_files = image_files[:split_index]
valid_files = image_files[split_index:]

# Move files to respective directories
print("Moving training images...")
for file in train_files:
    shutil.move(str(file), str(train_dir / file.name))

print("Moving validation images...")
for file in valid_files:
    shutil.move(str(file), str(valid_dir / file.name))

print("Image splitting complete.")

Moving training images...
Moving validation images...
Image splitting complete.


In [4]:
import os
import shutil
from pathlib import Path

# Define source and destination directories
source_dir = Path("images/img_align_celeba") # Assuming the images are in a subdirectory after extraction
train_dir = Path("images/train")
valid_dir = Path("images/valid")

# Create destination directories
train_dir.mkdir(parents=True, exist_ok=True)
valid_dir.mkdir(parents=True, exist_ok=True)

# Get list of image files
image_files = list(source_dir.glob("*.jpg"))

# Define split ratio (e.g., 80% train, 20% valid)
split_ratio = 0.8
split_index = int(len(image_files) * split_ratio)

# Split files
train_files = image_files[:split_index]
valid_files = image_files[split_index:]

# Move files to respective directories
print("Moving training images...")
for file in train_files:
    shutil.move(str(file), str(train_dir / file.name))

print("Moving validation images...")
for file in valid_files:
    shutil.move(str(file), str(valid_dir / file.name))

print("Image splitting complete.")

Moving training images...
Moving validation images...
Image splitting complete.


In [5]:
!pip install piq

Collecting piq
  Downloading piq-0.8.0-py3-none-any.whl.metadata (17 kB)
Downloading piq-0.8.0-py3-none-any.whl (106 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.9/106.9 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: piq
Successfully installed piq-0.8.0


In [6]:
import os
from pathlib import Path
from PIL import Image
import random

import torch
from torch import nn
from torch.utils.data import Dataset
import torchvision.transforms.functional as F
import torchvision.transforms as tf
from torch.utils.data import DataLoader

from piq import ssim
import matplotlib.pyplot as plt
import tqdm

In [7]:
class UFaceDataset(Dataset):
    def __init__(self, path, transform):
        super().__init__()
        self.paths = [p for p in Path(path).rglob("*.jpg")]
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        try:
            img = Image.open(self.paths[index])
            main = self.transform(img)

            kernel_size = random.choice([3, 5, 7])
            sigma = random.uniform(0.1, 2.0)
            blured = F.gaussian_blur(main, kernel_size=kernel_size, sigma=sigma)

            return {
                'main': main,
                'blured': blured
            }

        except Exception as e:
            print(f"error while opening image at index {index}, path: {self.paths[index]}")
            raise e


In [8]:
class FaceUNet(nn.Module):
    def __init__(self):
        super().__init__()

        # feature size / 2
        # channel 3 -> 64
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # /2
        )

        # feature size * 2
        # channel 64 -> 3
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        # after concatenating output of up 3 with the input image
        self.conv1 = nn.Conv2d(6, 3, kernel_size=3, padding=1, stride=1)  # 3 channels

        # ===============================================================================================

        # feature size / 2
        # channel 64 -> 128
        self.down2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=1),  # [128]
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # /2
        )

        # feature size * 2
        # channel 128 -> 64
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        # after concatenating output of up 2 with output of down 1
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1, stride=1)  # 64 channels

        # ===============================================================================================

        # feature size / 2
        # channel 128 -> 256
        self.down3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1, stride=1),  # [256]
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # /2
        )

        # feature size * 2
        # channel 256 -> 128
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # *2 , [64]
            nn.ReLU(),
        )

        # after concatenating output of up 1 with output of down 2
        self.conv3 = nn.Conv2d(256, 128,  kernel_size=3, padding=1, stride=1)  # 128 channels

        self.sigmoid = nn.Sigmoid()


    def forward(self, x):
        identity0 = x  # c = 3

        # ==--==--==--==-
        x = self.down1(x)  # c = 64
        identity1 = x  # c = 64

        # ==--==--==--==-
        x = self.down2(x)  # c = 128
        identity2 = x  # c = 128

        # ==--==--==--==-
        x = self.down3(x)  # c = 256

        # ==--==--==--==-
        x = self.up3(x)  # c = 128
        x = torch.cat([x, identity2], dim=1)  # c = 256
        x = self.conv3(x)  # c = 128  |  recover the channels again

        # ==--==--==--==-
        x = self.up2(x)  # c = 64
        x = torch.cat([x, identity1], dim=1)  # c = 128
        x = self.conv2(x)  # c = 64  |  recover the channels again

        # ==--==--==--==-
        x = self.up1(x)
        x = torch.cat([x, identity0], dim=1)  # c = 6
        x = self.sigmoid(self.conv1(x))  # c = 3  |  recover the channels again

        return x

In [9]:
class FaceUNetLoss(nn.Module):
    def __init__(self, lambda_ssim: float = 0.3):
        super().__init__()
        self.lambda_ssim = lambda_ssim
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        mse_part = self.mse_loss(pred, target)
        ssim_part = 1 - ssim(pred, target)
        return mse_part + self.lambda_ssim * ssim_part

In [10]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"starting operation on device: {DEVICE}")

starting operation on device: cuda


In [11]:
def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(8, 5))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

In [12]:
# paths
train_path = "images/train"
valid_path = "images/valid"

# transforms
transform = tf.Compose([
    tf.Resize([176, 216]),
    tf.ToTensor(),
])

In [13]:
train_dataset = UFaceDataset(train_path, transform)
valid_dataset = UFaceDataset(valid_path, transform)

print(f"found {len(train_dataset)} training images")
print(f"found {len(valid_dataset)} validation images")

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

found 162079 training images
found 40520 validation images


In [14]:
model = FaceUNet().to(DEVICE)
criterion = FaceUNetLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
epochs = 15

# early stopping
patience_number = 3
delta = 0.001
best_val_loss = float("inf")
patience = 0

# outputs
os.makedirs("outputs", exist_ok=True)
os.makedirs("outputs/checkpoints", exist_ok=True)
train_losses = []
val_losses = []

In [None]:
print(f"🚀 Start training with {epochs} epochs ...")

for epoch in range(epochs):
    # train
    model.train()
    train_loss = 0.0

    for batch in tqdm.tqdm(train_loader, total=len(train_loader), desc="training  "):
        blrd_img, main_img = batch["blured"].to(DEVICE), batch["main"].to(DEVICE)

        optimizer.zero_grad()
        output = model(blrd_img)
        loss = criterion(output, main_img)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)

    # validation
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for batch in tqdm.tqdm(valid_loader, total=len(valid_loader), desc="validating"):
            blrd_img, main_img = batch["blured"].to(DEVICE), batch["main"].to(DEVICE)

            output = model(blrd_img)
            loss = criterion(output, main_img)
            valid_loss += loss.item()

    avg_val_loss = valid_loss / len(valid_loader)

    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}", end="")
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    # early stopping
    if avg_val_loss < best_val_loss - delta:
        best_val_loss = avg_val_loss
        patience = 0
        torch.save(model.state_dict(), "outputs/checkpoints/superres_best.pth")
        print(f"✅ Validation improved. Saving model.")
    else:
        patience += 1
        print(f"⚠️ No improvement. Early stop counter: {patience}/{patience_number}")
        if patience >= patience_number:
            print("⛔ Early stopping triggered!")
            break

🚀 Start training with 15 epochs ...


training  : 100%|██████████| 2533/2533 [28:15<00:00,  1.49it/s]
validating: 100%|██████████| 634/634 [04:27<00:00,  2.37it/s]


[Epoch 1] Train Loss: 0.0218 | Val Loss: 0.0122✅ Validation improved. Saving model.


training  : 100%|██████████| 2533/2533 [27:22<00:00,  1.54it/s]
validating: 100%|██████████| 634/634 [04:23<00:00,  2.40it/s]


[Epoch 2] Train Loss: 0.0106 | Val Loss: 0.0095✅ Validation improved. Saving model.


training  : 100%|██████████| 2533/2533 [27:32<00:00,  1.53it/s]
validating: 100%|██████████| 634/634 [04:26<00:00,  2.38it/s]


[Epoch 3] Train Loss: 0.0090 | Val Loss: 0.0085✅ Validation improved. Saving model.


training  : 100%|██████████| 2533/2533 [27:42<00:00,  1.52it/s]
validating: 100%|██████████| 634/634 [04:24<00:00,  2.39it/s]


[Epoch 4] Train Loss: 0.0083 | Val Loss: 0.0085⚠️ No improvement. Early stop counter: 1/3


training  : 100%|██████████| 2533/2533 [27:44<00:00,  1.52it/s]
validating: 100%|██████████| 634/634 [04:25<00:00,  2.39it/s]


[Epoch 5] Train Loss: 0.0079 | Val Loss: 0.0076⚠️ No improvement. Early stop counter: 2/3


training  : 100%|██████████| 2533/2533 [27:41<00:00,  1.52it/s]
validating: 100%|██████████| 634/634 [04:26<00:00,  2.38it/s]


[Epoch 6] Train Loss: 0.0076 | Val Loss: 0.0074✅ Validation improved. Saving model.


training  :  68%|██████▊   | 1712/2533 [18:48<08:58,  1.52it/s]

In [None]:
plot_losses(train_losses, val_losses)