# In [None]:
# Imports cell
import os
import sys
import subprocess
import cv2
import torch
import numpy as np
from pathlib import Path
from basicsr.archs.rrdbnet_arch import RRDBNet
from facexlib.detection import init_detection_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image
from IPython.display import display, Image as IPImage
from tqdm.notebook import tqdm

# Setup cell
def check_dependencies(required_packages=None):
    """Install any required third-party package that cannot be imported.

    Args:
        required_packages: Optional mapping of *import name* to the list of
            arguments handed to ``pip install`` when that import fails.
            Defaults to the packages this script needs.

    Raises:
        subprocess.CalledProcessError: If a ``pip install`` invocation fails.

    Note:
        The import statements at the top of this file execute before this
        function can be called, so it only helps when invoked from a context
        where those imports have not already failed.
    """
    if required_packages is None:
        # Keys are import names, which can differ from the distribution name
        # (``opencv-python`` is imported as ``cv2``). Values are separate
        # argv tokens so that options such as ``--index-url`` reach pip as
        # options instead of being glued into one bogus requirement string.
        required_packages = {
            'torch': [
                'torch', 'torchvision', 'torchaudio',
                '--index-url', 'https://download.pytorch.org/whl/nightly/cpu',
            ],
            'cv2': ['opencv-python'],
            'basicsr': ['basicsr'],
            'facexlib': ['facexlib'],
        }

    for module_name, pip_args in required_packages.items():
        try:
            __import__(module_name)
        except ImportError:
            print(f"Installing {module_name}...")
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", *pip_args]
            )
        else:
            print(f"{module_name} already installed")

# Model and processing cell
def load_model(model_path):
    """Build the RRDBNet upscaler and load its weights from ``model_path``.

    Args:
        model_path: Path to the checkpoint file readable by ``torch.load``.

    Returns:
        The network in eval mode, placed on the MPS device when available
        and on the CPU otherwise.
    """
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)

    # Deserialize onto the CPU: mapping straight to 'mps' raises on machines
    # without MPS, which would defeat the CPU fallback below.
    state = torch.load(model_path, map_location='cpu')

    # Some published ESRGAN-family checkpoints nest the weights under
    # 'params_ema' or 'params'; unwrap them when present. A plain state dict
    # contains neither key and is used as-is.
    if isinstance(state, dict):
        for wrapper_key in ('params_ema', 'params'):
            if wrapper_key in state:
                state = state[wrapper_key]
                break

    model.load_state_dict(state)
    model.eval()
    return model.to(device)

def process_face(img_path, model_path, output_path, scale=4):
    """Restore every face found in one image and save the composited result.

    Args:
        img_path: Path (or string) of the image to read.
        model_path: Path to the upscaler checkpoint, passed to ``load_model``.
        output_path: Path (or string) where the result image is written.
        scale: Upscale factor given to ``FaceRestoreHelper``.

    Returns:
        True when a result image was written; False when the image could not
        be read, no face was detected, or the result could not be saved.
    """
    img_path = Path(img_path)
    output_path = Path(output_path)

    # Read first so an unreadable file does not pay for model initialisation.
    img = cv2.imread(str(img_path))
    if img is None:
        print(f"Failed to load image: {img_path}")
        return False

    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    # The helper builds its own detection model, so no separate detector is
    # created here.
    face_helper = FaceRestoreHelper(upscale_factor=scale, device=device)
    face_helper.read_image(img)

    print(f"Processing {img_path.name}...")
    print("Detecting faces...")
    # FaceRestoreHelper exposes detection through get_face_landmarks_5();
    # it has no detect_faces() method (that one belongs to the detector).
    face_helper.get_face_landmarks_5()
    face_helper.align_warp_face()

    if not face_helper.cropped_faces:
        print("No faces detected!")
        return False

    print(f"Found {len(face_helper.cropped_faces)} faces")
    model = load_model(model_path)
    # Follow the device the model actually ended up on.
    model_device = next(model.parameters()).device

    for idx, cropped_face in enumerate(face_helper.cropped_faces, start=1):
        print(f"Processing face {idx}...")
        # NOTE(review): cv2 yields BGR data and it is fed to the network
        # unchanged; confirm the checkpoint was not trained on RGB input.
        face_t = torch.from_numpy(cropped_face).float() / 255.
        face_t = face_t.permute(2, 0, 1).unsqueeze(0).to(model_device)

        with torch.no_grad():
            output = model(face_t)

        # Clamp before the uint8 cast: values outside [0, 1] would otherwise
        # wrap around and show up as speckle artifacts.
        output = output.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy()
        output = (output * 255.0).round().astype(np.uint8)

        # The helper warps restored faces back with a matrix derived from the
        # aligned-crop geometry, so the restored face has to match the crop
        # size; an upscaling network returns a larger image.
        crop_h, crop_w = cropped_face.shape[:2]
        if output.shape[:2] != (crop_h, crop_w):
            output = cv2.resize(output, (crop_w, crop_h), interpolation=cv2.INTER_AREA)

        face_helper.add_restored_face(output)

    face_helper.get_inverse_affine(None)
    restored_img = face_helper.paste_faces_to_input_image()

    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(str(output_path), restored_img):
        print(f"Failed to save result to {output_path}")
        return False

    print(f"Saved result to {output_path.name}")
    return True

# Main processing cell
def process_folder(base_path='py_4x_upscaler'):
    """Run ``process_face`` on every image in ``<base_path>/input``.

    Expected layout under ``base_path``: ``input/`` holds the source images,
    ``model/upscaler.pth`` is the checkpoint, and results are written to
    ``output/`` (created when missing).

    Args:
        base_path: Working directory of the upscaler. Defaults to
            ``'py_4x_upscaler'`` relative to the current directory.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # Setup paths
    base_path = Path(base_path)
    input_path = base_path / 'input'
    model_path = base_path / 'model' / 'upscaler.pth'
    output_path = base_path / 'output'
    # parents=True so a missing base directory does not abort the run here.
    output_path.mkdir(parents=True, exist_ok=True)

    # Validate model exists
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found at {model_path}")

    # Collect images by lower-cased suffix so '.JPG' / '.PNG' are included,
    # and sort them so runs are processed in a reproducible order.
    extensions = {'.jpg', '.jpeg', '.png'}
    image_files = []
    if input_path.is_dir():
        image_files = sorted(
            entry for entry in input_path.iterdir()
            if entry.is_file() and entry.suffix.lower() in extensions
        )

    if not image_files:
        print("No images found in input folder!")
        return

    # Process each image
    print(f"Found {len(image_files)} images to process")
    succeeded = 0
    for img_path in tqdm(image_files, desc="Processing images"):
        output_file = output_path / f"{img_path.stem}_4xUpscaled{img_path.suffix}"
        if process_face(img_path, model_path, output_file):
            succeeded += 1

    print(f"Finished: {succeeded} of {len(image_files)} images processed successfully")

# Execution cell
def _run_pipeline():
    """Check the dependencies, then process the default folder."""
    check_dependencies()
    process_folder()


# Only run the pipeline when executed directly, not when imported.
if __name__ == "__main__":
    _run_pipeline()