In [9]:
import os
import time
import argparse
import numpy as np
from PIL import Image
from glob import glob
from ntpath import basename
from os.path import join, exists
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import torchvision.transforms as transforms

In [10]:
class Options:
    data_dir = "data/A/1.jpg"
    sample_dir = "data/output/"
    model_name = "funiegan"  # or "ugan"
    model_path = "models/funie_generator.pth"

opt = Options()

In [12]:
assert exists(opt.model_path), "model not found"
os.makedirs(opt.sample_dir, exist_ok=True)
is_cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor 
print(is_cuda)

True


In [13]:
if opt.model_name.lower()=='funiegan':
    from nets import funiegan
    model = funiegan.GeneratorFunieGAN()
else:
    pass

In [14]:
model.load_state_dict(torch.load(opt.model_path))
if is_cuda: model.cuda()
model.eval()
print ("Loaded model from %s" % (opt.model_path))

Loaded model from models/funie_generator.pth


In [None]:
img_width, img_height, channels = 512, 512, 3
transforms_ = [transforms.Resize((img_height, img_width), Image.BICUBIC),
               transforms.ToTensor(),
               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]
transform = transforms.Compose(transforms_)

In [16]:
print(f"Model loaded: {opt.model_name}")
print(f"Input folder: {opt.data_dir}")
print(f"Output folder: {opt.sample_dir}")
print(f"Looking for test images...")

test_files = sorted(glob(join(opt.data_dir, "*")))
print(f"Found {len(test_files)} images.\n")

Model loaded: funiegan
Input folder: data/A/1.jpg
Output folder: data/output/
Looking for test images...
Found 0 images.



In [8]:
input_image_path = "data/A/1.jpg"  # <-- Change this to your image path

try:
    print(f"Processing {input_image_path}")
    inp_img = transform(Image.open(input_image_path).convert("RGB"))
    inp_img = Variable(inp_img).type(Tensor).unsqueeze(0)

    s = time.time()
    gen_img = model(inp_img)
    elapsed = time.time() - s
    print(f"Inference time: {elapsed:.3f} sec")

    # Save result (side-by-side input and output)
    img_sample = torch.cat((inp_img.data, gen_img.data), -1)
    save_path = join(opt.sample_dir, basename(input_image_path))
    save_image(img_sample, save_path, normalize=True)

    print(f"Saved output to: {save_path}")

except Exception as e:
    print(f"Error processing image: {e}")

Processing data/A/1.jpg
Inference time: 1.127 sec
Saved output to: data/output/1.jpg


In [17]:
import cv2

input_video_path = "data/A/vid.mp4"  # <-- your video input
output_video_path = join(opt.sample_dir, "enhanced_" + basename(input_video_path))

try:
    print(f"Processing video: {input_video_path}")

    # Open video capture
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {input_video_path}")

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Output writer: double width for side-by-side display
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width * 2, height))

    times = []
    frame_count = 0

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Resize + convert to PIL for model input
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).resize((256, 256), Image.BICUBIC)
        inp_tensor = transform(pil_img).unsqueeze(0).type(Tensor)

        # Inference
        s = time.time()
        with torch.no_grad():
            gen_tensor = model(inp_tensor)
        times.append(time.time() - s)

        # Convert output to numpy
        inp_np = inp_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
        gen_np = gen_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)

        # Denormalize
        inp_np = ((inp_np * 0.5 + 0.5) * 255).astype(np.uint8)
        gen_np = ((gen_np * 0.5 + 0.5) * 255).astype(np.uint8)

        # Resize output back to original video size
        inp_resized = cv2.resize(inp_np, (width, height))
        gen_resized = cv2.resize(gen_np, (width, height))

        # Stack side-by-side and convert to BGR for OpenCV
        side_by_side = cv2.hconcat([
            cv2.cvtColor(inp_resized, cv2.COLOR_RGB2BGR),
            cv2.cvtColor(gen_resized, cv2.COLOR_RGB2BGR)
        ])

        out.write(side_by_side)
        frame_count += 1
        print(f"Processed frame {frame_count}")

    # Release everything
    cap.release()
    out.release()

    if len(times) > 1:
        Ttime, Mtime = np.sum(times[1:]), np.mean(times[1:])
        print(f"\nProcessed {frame_count} frames.")
        print(f"Time taken: {Ttime:.2f} sec at {1. / Mtime:.2f} fps")
        print(f"Saved enhanced video to: {output_video_path}")

except Exception as e:
    print(f"Error processing video: {e}")


Processing video: data/A/vid.mp4
Processed frame 1
Processed frame 2
Processed frame 3
Processed frame 4
Processed frame 5
Processed frame 6
Processed frame 7
Processed frame 8
Processed frame 9
Processed frame 10
Processed frame 11
Processed frame 12
Processed frame 13
Processed frame 14
Processed frame 15
Processed frame 16
Processed frame 17
Processed frame 18
Processed frame 19
Processed frame 20
Processed frame 21
Processed frame 22
Processed frame 23
Processed frame 24
Processed frame 25
Processed frame 26
Processed frame 27
Processed frame 28
Processed frame 29
Processed frame 30
Processed frame 31
Processed frame 32
Processed frame 33
Processed frame 34
Processed frame 35
Processed frame 36
Processed frame 37
Processed frame 38
Processed frame 39
Processed frame 40
Processed frame 41
Processed frame 42
Processed frame 43
Processed frame 44
Processed frame 45
Processed frame 46
Processed frame 47
Processed frame 48
Processed frame 49
Processed frame 50
Processed frame 51
Process

In [20]:
import cv2
import time
import numpy as np
from os.path import join, basename
from PIL import Image
from torchvision import transforms
from torch.autograd import Variable

# Paths
input_video_path = "data/A/vid.mp4"
output_video_path = join(opt.sample_dir, "enhanced_" + basename(input_video_path))

# Transform (assumed same as used in model training)
transform = transforms.Compose([
    transforms.Resize((256, 256), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

try:
    print(f"Processing video: {input_video_path}")

    # Open video
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {input_video_path}")

    # Video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    max_frames = int(fps * 4)  # Only process first 4 seconds

    # Output writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width * 2, height))

    times = []
    frame_count = 0

    while frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break

        # Convert frame to PIL and apply transform
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        inp_tensor = transform(pil_img).unsqueeze(0)
        inp_tensor = Variable(inp_tensor).type(Tensor)

        # Inference
        s = time.time()
        with torch.no_grad():
            gen_tensor = model(inp_tensor)
        times.append(time.time() - s)

        # Convert tensors to numpy and denormalize
        inp_np = inp_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
        gen_np = gen_tensor.squeeze().cpu().numpy().transpose(1, 2, 0)
        inp_np = ((inp_np * 0.5 + 0.5) * 255).astype(np.uint8)
        gen_np = ((gen_np * 0.5 + 0.5) * 255).astype(np.uint8)

        # Resize back to original size
        inp_resized = cv2.resize(inp_np, (width, height))
        gen_resized = cv2.resize(gen_np, (width, height))

        # Side-by-side stack and save
        side_by_side = cv2.hconcat([
            cv2.cvtColor(inp_resized, cv2.COLOR_RGB2BGR),
            cv2.cvtColor(gen_resized, cv2.COLOR_RGB2BGR)
        ])
        out.write(side_by_side)
        frame_count += 1
        print(f"Processed frame {frame_count}/{max_frames}")

    cap.release()
    out.release()

    if len(times) > 0:
        Ttime, Mtime = np.sum(times), np.mean(times)
        print(f"\nProcessed {frame_count} frames.")
        print(f"Time taken: {Ttime:.2f} sec at {1. / Mtime:.2f} fps")
        print(f"Saved enhanced video to: {output_video_path}")

except Exception as e:
    print(f"Error processing video: {e}")


Processing video: data/A/vid.mp4
Processed frame 1/95
Processed frame 2/95
Processed frame 3/95
Processed frame 4/95
Processed frame 5/95
Processed frame 6/95
Processed frame 7/95
Processed frame 8/95
Processed frame 9/95
Processed frame 10/95
Processed frame 11/95
Processed frame 12/95
Processed frame 13/95
Processed frame 14/95
Processed frame 15/95
Processed frame 16/95
Processed frame 17/95
Processed frame 18/95
Processed frame 19/95
Processed frame 20/95
Processed frame 21/95
Processed frame 22/95
Processed frame 23/95
Processed frame 24/95
Processed frame 25/95
Processed frame 26/95
Processed frame 27/95
Processed frame 28/95
Processed frame 29/95
Processed frame 30/95
Processed frame 31/95
Processed frame 32/95
Processed frame 33/95
Processed frame 34/95
Processed frame 35/95
Processed frame 36/95
Processed frame 37/95
Processed frame 38/95
Processed frame 39/95
Processed frame 40/95
Processed frame 41/95
Processed frame 42/95
Processed frame 43/95
Processed frame 44/95
Processe

In [21]:
times = []
test_files = sorted(glob(join(opt.data_dir, "*.*")))
for path in test_files:
    try:
        print(f"Processing {path}")
        inp_img = transform(Image.open(path).convert("RGB"))
        inp_img = Variable(inp_img).type(Tensor).unsqueeze(0)
        s = time.time()
        gen_img = model(inp_img)
        times.append(time.time()-s)
        img_sample = torch.cat((inp_img.data, gen_img.data), -1)
        save_image(img_sample, join(opt.sample_dir, basename(path)), normalize=True)
        print ("Tested: %s" % path)
    except Exception as e:
        print(f"Error processing {path}: {e}")


## run-time    
if (len(times) > 1):
    print ("\nTotal samples: %d" % len(test_files)) 
    # accumulate frame processing times (without bootstrap)
    Ttime, Mtime = np.sum(times[1:]), np.mean(times[1:]) 
    print ("Time taken: %d sec at %0.3f fps" %(Ttime, 1./Mtime))
    print("Saved generated images in in %s\n" %(opt.sample_dir))