# In [None]:
import os
import sys
import subprocess
import cv2
import torch
import numpy as np
from pathlib import Path
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from tqdm.notebook import tqdm

def load_model(model_path):
    """Build the RRDBNet generator and load its weights from *model_path*.

    The architecture hyper-parameters are hard-coded below and must match
    the checkpoint (NOTE(review): assumed to be the ClearReality settings --
    confirm against the checkpoint's source).

    Args:
        model_path: path to a ``.pth`` checkpoint holding either a bare state
            dict or one wrapped under ``"params"`` / ``"params_ema"``.

    Returns:
        The model in eval mode, on MPS when available, otherwise on CPU.
    """
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

    # ClearReality model parameters
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=32,
        num_block=16,
        num_grow_ch=32,
    )

    # Load onto CPU first: a hard-coded 'mps' map_location raises on machines
    # without MPS. The finished model is moved to `device` at the end.
    # NOTE(security): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources (newer torch offers weights_only=True).
    state_dict = torch.load(model_path, map_location='cpu')

    # Unwrap trainer-style checkpoints. 'params' is checked first to keep the
    # previous behaviour; 'params_ema'-only checkpoints are now accepted too.
    for wrapper_key in ('params', 'params_ema'):
        if wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
            break

    model_keys = set(model.state_dict().keys())
    clean_state_dict = {}
    for key, value in state_dict.items():
        # DataParallel-style wrappers prefix keys with 'module.'.
        key = key.replace('module.', '')
        # Strip a leading 'body.' only when doing so yields a real model key
        # and the unstripped key does not: if the architecture has its own
        # 'body' submodule, unconditional stripping would orphan those
        # weights, and strict=False below would skip them silently.
        if key not in model_keys and key.startswith('body.'):
            stripped = key[len('body.'):]
            if stripped in model_keys:
                key = stripped
        clean_state_dict[key] = value

    result = model.load_state_dict(clean_state_dict, strict=False)
    # strict=False tolerates mismatches; surface them instead of running
    # quietly with randomly initialised layers.
    if result.missing_keys or result.unexpected_keys:
        print(
            f"Warning: {len(result.missing_keys)} missing and "
            f"{len(result.unexpected_keys)} unexpected keys when loading {model_path}"
        )
    model.eval()
    return model.to(device)

def get_unique_output_path(output_dir, base_name, suffix):
    """Return a path in *output_dir* for the upscaled copy of *base_name*.

    The first choice is ``<base_name>_4xUpscaled<suffix>``. If that exists,
    ``_001``, ``_002``, ... are appended until an unused name is found.
    Nothing is created on disk; only existence is checked.
    """
    attempt = 0
    candidate = output_dir / f"{base_name}_4xUpscaled{suffix}"
    while candidate.exists():
        attempt += 1
        candidate = output_dir / f"{base_name}_4xUpscaled_{attempt:03d}{suffix}"
    return candidate

def find_model(model_dir):
    """Return the first ``.pth`` file in *model_dir*, in sorted name order.

    Sorting makes the choice deterministic; the raw order of ``glob`` results
    depends on the filesystem.

    Raises:
        FileNotFoundError: if no ``.pth`` file is found in *model_dir*.
    """
    model_files = sorted(model_dir.glob('*.pth'))
    if not model_files:
        raise FileNotFoundError(f"No .pth model found in {model_dir}")
    return model_files[0]

# Most recently loaded model, keyed by str(model_path), so a batch run does
# not rebuild and reload the network for every image.
_MODEL_CACHE = {}


def _get_model(model_path):
    """Return the model for *model_path*, calling load_model only on first use."""
    key = str(model_path)
    model = _MODEL_CACHE.get(key)
    if model is None:
        _MODEL_CACHE.clear()  # keep at most one model resident in memory
        model = load_model(model_path)
        _MODEL_CACHE[key] = model
    return model


def process_image(img_path, model_path, output_path, model=None):
    """Upscale a single image and write the result to *output_path*.

    Args:
        img_path: image file readable by ``cv2.imread``.
        model_path: checkpoint to load when *model* is not given. The loaded
            model is cached and reused by later calls with the same path.
        output_path: ``pathlib.Path`` destination. Its suffix selects the
            encoder ``cv2.imwrite`` uses.
        model: optional already-loaded model. When given, no loading happens.

    Returns:
        True on success, False if the image could not be read or written.
    """
    # Load and preprocess image
    img = cv2.imread(str(img_path))
    if img is None:
        print(f"Failed to load image: {img_path}")
        return False

    if model is None:
        model = _get_model(model_path)
    # Send the input to wherever the model's weights live so the devices can
    # never disagree.
    device = next(model.parameters()).device

    # BGR uint8 (H, W, 3) -> RGB float in [0, 1] shaped (1, 3, H, W)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_t = torch.from_numpy(img).float() / 255.
    img_t = img_t.permute(2, 0, 1).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_t)

    # Clamp before scaling: values outside [0, 1] would otherwise wrap around
    # when cast to uint8 and show up as speckle artifacts.
    output = output.squeeze(0).clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    output = (output * 255.0).round().astype(np.uint8)
    output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    # imwrite reports failure (e.g. unwritable path) via its return value.
    if not cv2.imwrite(str(output_path), output):
        print(f"Failed to write image: {output_path}")
        return False

    print(f"Saved result to {output_path.name}")
    return True

def process_folder():
    """Upscale every .jpg/.jpeg/.png image in ./input into ./output.

    The model comes from ``find_model('./model')``. Output names come from
    ``get_unique_output_path``, so existing results are never overwritten.
    Problems (no model, no images) are reported with ``print``.
    """
    base_path = Path('.')
    input_dir = base_path / 'input'
    model_dir = base_path / 'model'
    output_dir = base_path / 'output'
    output_dir.mkdir(exist_ok=True)

    try:
        model_path = find_model(model_dir)
        print(f"Using model: {model_path.name}")
    except FileNotFoundError as e:
        print(e)
        return

    # Match extensions case-insensitively: pathlib's glob is case-sensitive on
    # POSIX systems (including macOS), so '*.jpg' skipped e.g. 'IMG_01.JPG'.
    # Sorting gives a deterministic processing order.
    image_exts = {'.jpg', '.jpeg', '.png'}
    image_files = sorted(
        p for p in input_dir.glob('*')
        if p.suffix.lower() in image_exts and p.is_file()
    )

    if not image_files:
        print("No images found in input folder!")
        return

    print(f"Found {len(image_files)} images to process")
    failed = 0
    for img_path in tqdm(image_files, desc="Processing images"):
        output_file = get_unique_output_path(output_dir, img_path.stem, img_path.suffix)
        if not process_image(img_path, model_path, output_file):
            failed += 1

    if failed:
        print(f"{failed} of {len(image_files)} images failed")

if __name__ == "__main__":
    # Run the batch upscaler when executed as a script. In a notebook,
    # __name__ is also "__main__", so running this cell starts the batch.
    process_folder()