In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
from torchvision import models

# Fetch the pretrained 'generator' entry point of the AnimeGAN2 PyTorch repo
# through TorchHub and switch it to inference mode.
# NOTE(review): this was previously described as a low-light enhancement
# model, but the repo name indicates an AnimeGAN2 generator -- confirm that
# this is the intended model for the low-light images used below.
# NOTE(review): trust_repo=True presumably bypasses TorchHub's interactive
# trust confirmation for third-party code -- verify the repo is trusted.
model = torch.hub.load(
    repo_or_dir='AK391/animegan2-pytorch',
    model='generator',
    pretrained=True,
    trust_repo=True,
)
model.eval()

# Image preprocessing
def preprocess_image(image_path):
    """Load an image file and turn it into a batched input tensor.

    The image is converted to RGB, resized to 256x256 (aspect ratio is not
    preserved) and converted with ``transforms.ToTensor``.

    Args:
        image_path: Path of the image to read, passed to ``PIL.Image.open``.

    Returns:
        The transformed image tensor with an extra leading batch dimension
        of size 1.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # Use a context manager so the underlying file handle is released
    # deterministically; convert() returns an independent in-memory copy,
    # so the result stays valid after the file is closed.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    # NOTE(review): ToTensor is expected to yield values in [0, 1]; the loaded
    # generator may expect inputs scaled to [-1, 1] -- confirm against the
    # model's reference preprocessing.
    return transform(rgb_image).unsqueeze(0)  # Add batch dimension

# Image post-processing
def postprocess_image(tensor):
    """Convert a model output tensor into an 8-bit HWC image array.

    The values are min-max normalised over the whole array to the range
    0..255 and then truncated to ``uint8``.

    Args:
        tensor: Output tensor; assumed to be shaped (1, C, H, W) so that
            dropping dimension 0 leaves a CHW array -- TODO confirm for
            batches larger than one.

    Returns:
        A ``numpy.uint8`` array in HWC layout. If every element of the input
        has the same value (zero dynamic range), an all-zero image is
        returned instead of dividing by zero.
    """
    array = tensor.squeeze(0).detach().cpu().numpy()
    array = np.transpose(array, (1, 2, 0))  # CHW -> HWC
    lowest = array.min()
    value_range = array.max() - lowest
    if value_range == 0:
        # A constant image has no dynamic range; normalising it would compute
        # 0 / 0 and cast NaN to uint8, so return a well-defined black image.
        return np.zeros(array.shape, dtype=np.uint8)
    scaled = (array - lowest) / value_range * 255  # Normalize to 0..255
    return scaled.astype(np.uint8)

# Enhance image
def enhance_image(image_path, output_path):
    """Run the globally loaded ``model`` on one image and save the result.

    The image is prepared with ``preprocess_image``, passed through ``model``
    with gradient tracking disabled, converted back to an 8-bit array with
    ``postprocess_image`` and written with ``cv2.imwrite``.

    Args:
        image_path: Path of the input image.
        output_path: Path of the file to write.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the image was not written.
    """
    image = preprocess_image(image_path)
    with torch.no_grad():
        enhanced_image = model(image)
    enhanced_image = postprocess_image(enhanced_image)
    # The array is converted from RGB to BGR channel order before writing.
    bgr_image = cv2.cvtColor(enhanced_image, cv2.COLOR_RGB2BGR)
    # imwrite signals failure through its boolean return value; surface it
    # instead of silently producing no output file.
    if not cv2.imwrite(output_path, bgr_image):
        raise OSError(f"Failed to write image to {output_path!r}")

Downloading: "https://github.com/AK391/animegan2-pytorch/zipball/main" to /Users/achintyajha/.cache/torch/hub/main.zip
Downloading: "https://github.com/bryandlee/animegan2-pytorch/raw/main/weights/face_paint_512_v2.pt" to /Users/achintyajha/.cache/torch/hub/checkpoints/face_paint_512_v2.pt
100%|██████████████████████████████████████| 8.20M/8.20M [00:00<00:00, 26.1MB/s]


In [5]:
# Example: run the pipeline on one image and write the result to a file
# in the current working directory.
enhance_image(
    image_path="./lol_dataset/our485/low/2.png",
    output_path="enhanced_output.jpg",
)