In [1]:
!git clone https://github.com/yohan-pg/robust-unsupervised.git

Cloning into 'robust-unsupervised'...
remote: Enumerating objects: 204, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 204 (delta 5), reused 5 (delta 5), pack-reused 194 (from 1)[K
Receiving objects: 100% (204/204), 34.85 MiB | 49.91 MiB/s, done.
Resolving deltas: 100% (80/80), done.


In [2]:
%cd robust-unsupervised

/kaggle/working/robust-unsupervised


In [3]:
!pip install tyro "git+https://github.com/jwblangley/pytorch-fid.git"

Collecting git+https://github.com/jwblangley/pytorch-fid.git
  Cloning https://github.com/jwblangley/pytorch-fid.git to /tmp/pip-req-build-r_1bzznl
  Running command git clone --filter=blob:none --quiet https://github.com/jwblangley/pytorch-fid.git /tmp/pip-req-build-r_1bzznl
  Resolved https://github.com/jwblangley/pytorch-fid.git to commit 3d604a25516746c3a4a5548c8610e99010b2c819
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting tyro
  Downloading tyro-0.9.33-py3-none-any.whl.metadata (12 kB)
Collecting shtab>=1.5.6 (from tyro)
  Downloading shtab-1.7.2-py3-none-any.whl.metadata (7.4 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid==0.2.1)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid==0.2.1)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecti

In [4]:
!wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl -O pretrained_networks/ffhq.pkl

# This cell replaces the StyleGAN2 download command with one for StyleGAN3.
# We are using the 'stylegan3-r' (rotation equivariant) variant trained on FFHQ.
# !wget https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhq-1024x1024.pkl -O pretrained_networks/ffhq.pkl

--2025-10-12 08:30:22--  https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
Resolving nvlabs-fi-cdn.nvidia.com (nvlabs-fi-cdn.nvidia.com)... 18.244.202.81, 18.244.202.50, 18.244.202.77, ...
Connecting to nvlabs-fi-cdn.nvidia.com (nvlabs-fi-cdn.nvidia.com)|18.244.202.81|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 381624121 (364M) [binary/octet-stream]
Saving to: ‘pretrained_networks/ffhq.pkl’


2025-10-12 08:30:35 (31.4 MB/s) - ‘pretrained_networks/ffhq.pkl’ saved [381624121/381624121]



In [5]:
!pip install lpips

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lpips
Successfully installed lpips-0.1.4


In [6]:
# Read the original content from run.py
with open('/kaggle/working/robust-unsupervised/run.py', 'r') as f:
    run_py_content = f.read()

# Read the original content from cli.py
with open('/kaggle/working/robust-unsupervised/cli.py', 'r') as f:
    cli_py_content = f.read()

In [7]:
# This cell combines the early stopping logic, the switch to the Adam optimizer,
# and the addition of detailed metrics logging into a single, sequential update.

import re
import math

# --- STEP 1: READ THE ORIGINAL, UNMODIFIED FILES ---
with open('/kaggle/working/robust-unsupervised/run.py', 'r') as f:
    run_py_content = f.read()

with open('/kaggle/working/robust-unsupervised/cli.py', 'r') as f:
    cli_py_content = f.read()

# --- STEP 2: DEFINE THE `project` FUNCTION WITH EARLY STOPPING ---
def get_project_with_early_stopping():
    return """
def project(
    G,
    target: torch.Tensor,
    original_image: torch.Tensor,
    *,
    config: Config,
    initial_w=None,
    progress=None,
):
    # Phase I
    w = initial_w if initial_w is not None else G.mapping.w_avg.clone()
    w.requires_grad_(True)
    optimizer = Adam([w], lr=config.lr_phase_1)
    
    best_loss_1 = math.inf
    patience_counter_1 = 0

    for step in range(config.n_steps_phase_1):
        synth_images = G.synthesis(w, noise_mode="const")
        loss, loss_dict = loss_fn(synth_images, target, original_image)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        progress.update(1, **loss_dict)

        if loss.item() < best_loss_1:
            best_loss_1 = loss.item()
            patience_counter_1 = 0
        else:
            patience_counter_1 += 1
        
        if patience_counter_1 >= config.patience:
            print(f"Stopping early in Phase 1 at step {step}.")
            break

    # Phase II
    w_plus = w.unsqueeze(1).repeat(1, G.num_ws, 1)
    w_plus.requires_grad_(True)
    optimizer = Adam([w_plus], lr=config.lr_phase_2)

    best_loss_2 = math.inf
    patience_counter_2 = 0

    for step in range(config.n_steps_phase_2):
        synth_images = G.synthesis(w_plus, noise_mode="const")
        loss, loss_dict = loss_fn(synth_images, target, original_image)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        progress.update(1, **loss_dict)

        if loss.item() < best_loss_2:
            best_loss_2 = loss.item()
            patience_counter_2 = 0
        else:
            patience_counter_2 += 1
        
        if patience_counter_2 >= config.patience:
            print(f"Stopping early in Phase 2 at step {step}.")
            break

    # Phase III
    w_plus_plus = w_plus.unsqueeze(1).repeat(1, G.num_filters, 1, 1)
    w_plus_plus.requires_grad_(True)
    optimizer = Adam([w_plus_plus], lr=config.lr_phase_3)
    
    best_loss_3 = math.inf
    patience_counter_3 = 0

    for step in range(config.n_steps_phase_3):
        synth_images = G.synthesis(w_plus_plus, noise_mode="const")
        loss, loss_dict = loss_fn(synth_images, target, original_image)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        progress.update(1, **loss_dict)
        
        if loss.item() < best_loss_3:
            best_loss_3 = loss.item()
            patience_counter_3 = 0
        else:
            patience_counter_3 += 1
        
        if patience_counter_3 >= config.patience:
            print(f"Stopping early in Phase 3 at step {step}.")
            break

    return synth_images
"""

# --- STEP 3: APPLY ALL MODIFICATIONS SEQUENTIALLY TO `run.py` ---

# 3.1: Replace the original `project` function with our new version.
# This version already includes the Adam optimizer and the correct function signature.
project_function_string = get_project_with_early_stopping()
updated_run_py_content = re.sub(r"def project\\(.*?\\):.*?(?=return synth_images)", project_function_string.strip(), run_py_content, flags=re.S)

# 3.2: Add all necessary imports at the top of the file.
imports_to_add = """
import math
from torch.optim import Adam
from metrics import calculate_accuracy_lpips, calculate_realism_fid, calculate_fidelity_lpips
import shutil
"""
updated_run_py_content = re.sub(r'(from robust_unsupervised.prelude import \*)', r'\\1' + imports_to_add, updated_run_py_content)


# 3.3: Update the main loop to pass the original image to the project function.
updated_run_py_content = updated_run_py_content.replace(
    'projected_w = project(G, target=target, progress=progress)',
    'projected_w = project(G, target=target, original_image=batch["image"], progress=progress)'
)


# 3.4: Add the metrics calculation logic to the end of the `run` function.
metrics_logging_code = """
    # --- DETAILED METRICS CALCULATION ---
    ground_truth_dir = config.dataset_path
    restored_dir = os.path.join(task_dir, os.path.basename(config.dataset_path))
    degraded_dir = task_dir
    
    accuracy = calculate_accuracy_lpips(restored_dir, ground_truth_dir)
    realism = calculate_realism_fid(restored_dir, ground_truth_dir)
    fidelity = calculate_fidelity_lpips(restored_dir, degraded_dir, degradation.f)

    shutil.rmtree(restored_dir)

    summary = f\"\"\"
    -----------------------------------------------------
    PERFORMANCE REPORT FOR TASK: {task_name}
    -----------------------------------------------------
    - Accuracy (LPIPS ↓): {accuracy:.4f}
    - Realism (FID ↓):    {realism:.2f}
    - Fidelity (LPIPS ↓): {fidelity:.4f}
    -----------------------------------------------------
    \"\"\"
    print(summary)
    with open(os.path.join(task_dir, 'performance_report.txt'), 'w') as f:
        f.write(summary)
"""
updated_run_py_content = re.sub(r'(if __name__ == "__main__":)', metrics_logging_code + r'\\n\\1', updated_run_py_content, flags=re.S)


# --- STEP 4: APPLY MODIFICATIONS TO `cli.py` ---
# Add the 'patience' parameter for the early stopping feature.
updated_cli_py_content = cli_py_content.replace(
    '    n_steps_phase_3: int = 150',
    '    n_steps_phase_3: int = 150\\n    patience: int = 15  # Steps to wait for improvement before stopping'
)


# --- STEP 5: WRITE THE FINAL, MODIFIED CONTENT TO THE FILES ---
with open('/kaggle/working/robust-unsupervised/run.py', 'w') as f:
    f.write(updated_run_py_content)
print("✅ run.py updated with early stopping, Adam optimizer, and metrics logging.")

with open('/kaggle/working/robust-unsupervised/cli.py', 'w') as f:
    f.write(updated_cli_py_content)
print("✅ cli.py updated with 'patience' parameter for early stopping.")

✅ run.py updated with early stopping, Adam optimizer, and metrics logging.
✅ cli.py updated with 'patience' parameter for early stopping.


In [8]:
!pip install deepface

Collecting deepface
  Downloading deepface-0.0.95-py3-none-any.whl.metadata (35 kB)
Collecting flask-cors>=4.0.1 (from deepface)
  Downloading flask_cors-6.0.1-py3-none-any.whl.metadata (5.3 kB)
Collecting mtcnn>=0.1.0 (from deepface)
  Downloading mtcnn-1.0.0-py3-none-any.whl.metadata (5.8 kB)
Collecting retina-face>=0.0.14 (from deepface)
  Downloading retina_face-0.0.17-py3-none-any.whl.metadata (10 kB)
Collecting fire>=0.4.0 (from deepface)
  Downloading fire-0.7.1-py3-none-any.whl.metadata (5.8 kB)
Collecting gunicorn>=20.1.0 (from deepface)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting lz4>=4.3.3 (from mtcnn>=0.1.0->deepface)
  Downloading lz4-4.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading deepface-0.0.95-py3-none-any.whl (128 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m128.3/128.3 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fire-0.7.1-py3-none-a

In [9]:
# This code will completely overwrite your loss_function.py file

new_loss_function_content = """
from robust_unsupervised.prelude import *
from lpips import LPIPS
from deepface import DeepFace

# Helper function to preprocess images for the face recognition model
def preprocess_for_face_recognition(tensor_image):
    # The model expects images in a specific format (BGR, specific size etc.)
    # We convert our PyTorch tensor to a NumPy array that deepface can use.
    # The tensor is expected to be in range [-1, 1], so we shift to [0, 255]
    image_np = (tensor_image.permute(0, 2, 3, 1) * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).cpu().numpy()
    # DeepFace handles the rest of the preprocessing internally
    return image_np

class IDLoss(nn.Module):
    def __init__(self):
        super(IDLoss, self).__init__()
        # Load the ArcFace model. It will be downloaded automatically the first time.
        # We only need the model for embedding extraction.
        self.model = DeepFace.build_model('ArcFace')
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        # x and y are the generated and ground-truth images, respectively.
        # Both are PyTorch tensors.
        
        # Preprocess images for the model
        x_np = preprocess_for_face_recognition(x)
        y_np = preprocess_for_face_recognition(y)

        # Get the embeddings. DeepFace expects a list of images.
        x_embedding = DeepFace.represent(img_path=x_np[0], model=self.model, enforce_detection=False)
        y_embedding = DeepFace.represent(img_path=y_np[0], model=self.model, enforce_detection=False)
        
        # Convert embeddings back to PyTorch tensors
        x_embedding_tensor = torch.tensor(x_embedding[0]['embedding']).unsqueeze(0).to(x.device)
        y_embedding_tensor = torch.tensor(y_embedding[0]['embedding']).unsqueeze(0).to(y.device)

        # Calculate Cosine Similarity Loss
        loss = 1 - torch.cosine_similarity(x_embedding_tensor, y_embedding_tensor)
        return loss.mean()

class LossFunction(nn.Module):
    def __init__(self, lambda_l1: float = 0.1, lambda_id: float = 0.1):
        super().__init__()
        self.lambda_l1 = lambda_l1
        self.lambda_id = lambda_id
        self.l1 = nn.L1Loss()
        self.lpips = LPIPS(net="vgg").cuda()
        self.id_loss = IDLoss().cuda()

    def forward(self, synth_images, target_images, original_images):
        # The main forward pass now accepts the original image
        
        # L1 Loss (Pixel-wise)
        l1_loss = self.l1(synth_images, target_images)

        # LPIPS Loss (Perceptual)
        lpips_loss = self.lpips(synth_images, target_images).mean()
        
        # Identity Loss
        id_loss = self.id_loss(synth_images, original_images)

        # Combine the losses
        total_loss = lpips_loss + self.lambda_l1 * l1_loss + self.lambda_id * id_loss
        
        return total_loss, {
            "loss": total_loss.item(),
            "lpips": lpips_loss.item(),
            "l1": l1_loss.item(),
            "id": id_loss.item(),
        }

# Global instance
loss_fn = LossFunction()
"""

# Write the new content to the file
with open('/kaggle/working/robust-unsupervised/loss_function.py', 'w') as f:
    f.write(new_loss_function_content)

print("✅ loss_function.py has been updated with the Identity Loss.")

✅ loss_function.py has been updated with the Identity Loss.


In [10]:
# This cell creates a new file, 'metrics.py', to handle performance evaluation.
# We will calculate Accuracy (LPIPS), Realism (FID), and Fidelity (LPIPS on re-degraded images).

metrics_py_content = """
from robust_unsupervised.prelude import *
from pytorch_fid.fid_score import calculate_fid_given_paths
from lpips import LPIPS
import os

# Initialize the LPIPS model once to be reused.
lpips_fn = LPIPS(net="vgg").cuda()

def calculate_accuracy_lpips(restored_dir: str, ground_truth_dir: str) -> float:
    \"\"\"
    Calculates Accuracy, defined as the average LPIPS between restored
    images and their ground truth counterparts. 
    \"\"\"
    print("Calculating Accuracy (LPIPS)...")
    restored_files = sorted([os.path.join(restored_dir, f) for f in os.listdir(restored_dir)])
    gt_files = sorted([os.path.join(ground_truth_dir, f) for f in os.listdir(ground_truth_dir)])
    
    total_lpips = 0.0
    num_images = len(restored_files)
    
    for restored_path, gt_path in zip(restored_files, gt_files):
        restored_img = read_image_tensor(restored_path).cuda()
        gt_img = read_image_tensor(gt_path).cuda()
        lpips_score = lpips_fn(restored_img, gt_img).item()
        total_lpips += lpips_score
        
    return total_lpips / num_images if num_images > 0 else 0.0

def calculate_realism_fid(restored_dir: str, ground_truth_dir: str) -> float:
    \"\"\"
    Calculates Realism using Frechet Inception Distance (FID) between the set of 
    restored images and the set of ground truth images. [cite: 226]
    The paper uses a patch-based FID (pFID), but we use standard FID as a strong proxy.
    \"\"\"
    print("Calculating Realism (FID)...")
    # These parameters are commonly used for FID calculation.
    dims = 2048
    batch_size = 32
    device = torch.device("cuda")
    
    fid_value = calculate_fid_given_paths(
        paths=[restored_dir, ground_truth_dir],
        batch_size=batch_size,
        device=device,
        dims=dims
    )
    return fid_value
    
def calculate_fidelity_lpips(restored_dir: str, degraded_dir: str, degradation_fn) -> float:
    \"\"\"
    Calculates Fidelity, defined as the average LPIPS between a re-degraded
    restored image and the original degraded target image. [cite: 228]
    \"\"\"
    print("Calculating Fidelity (LPIPS)...")
    restored_files = sorted([os.path.join(restored_dir, f) for f in os.listdir(restored_dir)])
    degraded_files = sorted([os.path.join(degraded_dir, f) for f in os.listdir(degraded_dir)])

    total_lpips = 0.0
    num_images = len(restored_files)

    for restored_path, degraded_path in zip(restored_files, degraded_files):
        restored_img = read_image_tensor(restored_path).cuda()
        degraded_target_img = read_image_tensor(degraded_path).cuda()
        
        # Apply the same degradation to our restored image. [cite: 229]
        re_degraded_img = degradation_fn(restored_img)
        
        lpips_score = lpips_fn(re_degraded_img, degraded_target_img).item()
        total_lpips += lpips_score

    return total_lpips / num_images if num_images > 0 else 0.0

def read_image_tensor(path: str) -> torch.Tensor:
    \"\"\"Helper to read an image and convert it to a PyTorch tensor in [-1, 1] range.\"\"\"
    img = Image.open(path).convert("RGB")
    img_tensor = F.to_tensor(img) * 2 - 1
    return img_tensor.unsqueeze(0)
"""

# Write the new content to the file
with open('/kaggle/working/robust-unsupervised/metrics.py', 'w') as f:
    f.write(metrics_py_content)

print("✅ metrics.py created successfully with functions for Accuracy, Realism, and Fidelity.")

✅ metrics.py created successfully with functions for Accuracy, Realism, and Fidelity.


In [11]:
%cd /kaggle/working/robust-unsupervised

!python run.py --dataset_path datasets/samples

/kaggle/working/robust-unsupervised
restored_samples
Loading generator from pretrained_networks/ffhq.pkl...
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|█████████████████████████████████████████| 528M/528M [00:02<00:00, 217MB/s]
out/restored_samples/2025-10-12T083055/single_tasks/upsampling/XL/
/kaggle/working/robust-unsupervised/out/restored_samples/2025-10-12T083055/single_tasks/upsampling/XL/datasets/samples
- 0000
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Done.
W: 100%|██████████████████████████████████████| 150/150 [00:41<00:00,  3.59it/s]
W+: 100%|█████████████████████████████████████| 150/150 [00:18<00:00,  8.32it/s]
W++: 100%|████████████████████████████████████| 150/150 [00:20<00:00,  7.27it/s]
- 0001
W: 100%|██████████████████████████████████████| 150/150 [00:18<00:00,  8.30it/s]
W+: 100%|█████████████████████████████████████| 150/150 [00:18<00:00,