In [1]:
import os
import shutil
import subprocess
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from tqdm import tqdm

In [None]:
# Downscale vid.mp4 so that its shorter side becomes SHORT_SIDE px, keeping the aspect ratio.
SHORT_SIDE = 256

# Probe the first video stream's size; ffprobe prints it as "WIDTHxHEIGHT".
probe = subprocess.run(
    ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
     '-show_entries', 'stream=width,height', '-of', 'csv=s=x:p=0', 'vid.mp4'],
    capture_output=True, text=True, errors='replace',
)
if probe.returncode != 0 or not probe.stdout.strip():
    raise RuntimeError('ffprobe could not read the dimensions of vid.mp4:\n' + probe.stderr)
# Use only the first line / first two fields, in case ffprobe appends extra separators.
width_str, height_str = probe.stdout.strip().splitlines()[0].split('x')[:2]
current_dimensions = (int(width_str), int(height_str))
print('Current dimensions:', current_dimensions)
# NOTE(review): ffprobe reports the stored frame size. If vid.mp4 carries a rotation flag
# (common for phone footage), ffmpeg auto-rotates before scaling and the two sides may be
# swapped relative to these numbers — confirm for rotated inputs.

# Scale the shorter side to SHORT_SIDE and round the longer side down to an even number of
# pixels (libx264 with 4:2:0 chroma subsampling rejects odd frame sizes).
width, height = current_dimensions
if width > height:
    new_dimensions = (int(SHORT_SIDE * width / height) // 2 * 2, SHORT_SIDE)
else:
    new_dimensions = (SHORT_SIDE, int(SHORT_SIDE * height / width) // 2 * 2)
print('New dimensions:', new_dimensions)

# Remove any previous output first: without -y, ffmpeg stops to ask before overwriting.
if os.path.exists('vid_downscaled.mp4'):
    os.remove('vid_downscaled.mp4')
output = subprocess.run(
    ['ffmpeg', '-i', 'vid.mp4', '-vf', f'scale={new_dimensions[0]}:{new_dimensions[1]}',
     'vid_downscaled.mp4'],
    capture_output=True, text=True, errors='replace',
)
if output.returncode != 0:
    # Raise instead of sys.exit(1): in a notebook kernel SystemExit surfaces only as a bare
    # "SystemExit: 1", whereas this keeps ffmpeg's error output in the traceback.
    raise RuntimeError('ffmpeg failed to downscale vid.mp4:\n' + output.stderr)

# Input: extract frames from the downscaled video

In [None]:
# Use ffmpeg to split the downscaled video into numbered PNG frames.
target_dir = 'images'
# Start from an empty folder so frames left over from a previous (longer) run cannot mix in.
shutil.rmtree(target_dir, ignore_errors=True)
os.makedirs(target_dir, exist_ok=True)
vid_path = 'vid_downscaled.mp4'
# Probe the frame rate of the first video stream only (v:0, as in the downscale cell), so a
# second video stream cannot add extra lines to the output. ffprobe prints r_frame_rate as a
# rational "num/den" (e.g. "30/1"); both parts are kept so later ffmpeg calls get the exact rate.
fps_probe = subprocess.run(
    ['ffprobe', '-v', 'error', '-select_streams', 'v:0',
     '-of', 'default=noprint_wrappers=1:nokey=1',
     '-show_entries', 'stream=r_frame_rate', vid_path],
    capture_output=True, text=True, errors='replace',
)
fps = fps_probe.stdout.strip().split('/')
if fps_probe.returncode != 0 or len(fps) != 2:
    raise RuntimeError(f'Could not read the frame rate of {vid_path}:\n{fps_probe.stderr}')
print(fps)

In [None]:
# Dump every frame as a 1-based, zero-padded PNG (images/0001.png, 0002.png, ...).
# The fps filter pins the output to the probed r_frame_rate, i.e. a constant frame rate.
extract = subprocess.run(
    ['ffmpeg', '-i', vid_path, '-vf', f'fps={fps[0]}/{fps[1]}', f'{target_dir}/%04d.png'],
    capture_output=True, text=True, errors='replace',
)
if extract.returncode != 0:
    raise RuntimeError(f'ffmpeg frame extraction failed:\n{extract.stderr}')

# Processing: magnification experiments on individual frames

In [None]:
# Experiment: shift a whole batch of images in one vectorised gather instead of a per-image
# loop. Image i is sampled at (row + crop[i, 1], col + crop[i, 0]), so its content moves up
# and left by crop[i] = (x, y); indices past the far edge are clamped to the last row/column.
img1 = cv2.imread('images/0001.png')
img2 = cv2.imread('images/0002.png')
if img1 is None or img2 is None:
    raise FileNotFoundError('images/0001.png and images/0002.png are required; run the frame-extraction cells first')
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
imgs = np.stack([img1, img2], axis=0)  # (B, H, W, C)
h, w, c = imgs.shape[1:]
crop = np.array([[50, 50], [100, 100]])  # (B, 2): per-image (x, y) offset

# Broadcastable index grids: batch (B, 1, 1), rows (B, H, 1), columns (B, 1, W).
batch_indices = np.arange(imgs.shape[0])[:, None, None]
row_indices = np.arange(h)[None, :, None] + crop[:, 1][:, None, None]
col_indices = np.arange(w)[None, None, :] + crop[:, 0][:, None, None]

# Clamp (not drop) indices that run past the image edge, so the output keeps shape (H, W).
row_indices = np.clip(row_indices, 0, h - 1)
col_indices = np.clip(col_indices, 0, w - 1)

# Advanced indexing gathers all B shifted images at once -> (B, H, W, C).
sliced = imgs[batch_indices, row_indices, col_indices, :]

fig, ax = plt.subplots(1, 2)
ax[0].imshow(sliced[0])
ax[1].imshow(sliced[1])
plt.show()

In [None]:
def magnify(img, center, factor):
    """Return img zoomed in by `factor` around `center`, cropped back to the original size.

    The image is upscaled with bilinear interpolation and then cropped so that the pixel at
    `center` keeps its position while everything else moves away from it.

    Args:
        img: Image of shape (H, W) or (H, W, C).
        center: (x, y) pixel coordinates of the fixed point; must lie inside the image.
            Non-integer values are truncated to int.
        factor: Magnification factor; must be >= 1.

    Returns:
        Array with the same shape as img.

    Raises:
        ValueError: if factor < 1 (the crop would start at a negative offset and silently
            wrap around) or if center lies outside the image.
    """
    if factor < 1:
        raise ValueError(f'factor must be >= 1, got {factor}')
    h, w = img.shape[:2]
    cx, cy = int(center[0]), int(center[1])
    if not (0 <= cx < w and 0 <= cy < h):
        raise ValueError(f'center {center} lies outside the {w}x{h} image')
    magnified_img = cv2.resize(img, (int(w*factor), int(h*factor)), interpolation=cv2.INTER_LINEAR)
    # The fixed point lands at center*factor in the upscaled image, so the crop's upper-left
    # corner is center*factor - center. For an in-bounds center and factor >= 1 the h x w
    # window always fits inside the upscaled image.
    left = int(cx*factor) - cx
    top = int(cy*factor) - cy
    return magnified_img[top:top+h, left:left+w]

def magnify_tensor(img, center, factor):
    """Same as magnify, but for a batch of images of shape (B, H, W, C).

    Args:
        img: Batch of images, shape (B, H, W, C) (numpy array).
        center: Either one (x, y) fixed point shared by the whole batch (tuple, list or
            (2,) array) or one point per image as a (B, 2) array. Values are truncated to int
            and must lie inside the image.
        factor: Magnification factor; must be >= 1.

    Returns:
        numpy array of shape (B, H, W, C).

    Raises:
        ValueError: if factor < 1 or any center lies outside the image.
    """
    assert len(img.shape) == 4, 'Input tensor must have shape (B, H, W, C)'
    if factor < 1:
        raise ValueError(f'factor must be >= 1, got {factor}')
    b, h, w = img.shape[:3]
    centers = np.asarray(center)
    if centers.ndim == 1:
        centers = centers[None, :]  # one shared centre -> (1, 2), broadcast over the batch
    assert centers.shape in ((1, 2), (b, 2)), 'Center must be (x, y) or an array of shape (B, 2)'
    centers = centers.astype(int)
    if (centers < 0).any() or (centers[:, 0] >= w).any() or (centers[:, 1] >= h).any():
        raise ValueError(f'every center must lie inside the {w}x{h} image')

    batch = torch.as_tensor(img).permute(0, 3, 1, 2)  # (B, C, H, W)
    magnified_img = torchvision.transforms.functional.resize(
        batch, [int(h*factor), int(w*factor)],
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    ).permute(0, 2, 3, 1).numpy()  # (B, H', W', C)

    # Per-image upper-left corner of the crop in magnified coordinates: center*factor - center.
    upper_left = (centers*factor).astype(int) - centers  # (1 or B, 2)
    batch_indices = np.arange(b)[:, None, None]
    row_indices = np.arange(h)[None, :, None] + upper_left[:, 1][:, None, None]
    col_indices = np.arange(w)[None, None, :] + upper_left[:, 0][:, None, None]
    return magnified_img[batch_indices, row_indices, col_indices, :]

def get_circle_mask(h, w, center, radius):
    """Draw a filled, anti-aliased circle onto a blank single-channel canvas.

    Returns an (h, w) uint8 array that is 255 inside the circle and 0 outside, with
    intermediate values along the anti-aliased rim. center is an (x, y) pixel pair and
    radius is in pixels (both passed straight to cv2.circle).
    """
    canvas = np.zeros((h, w), dtype=np.uint8)
    cv2.circle(canvas, center, radius, 255, thickness=-1, lineType=cv2.LINE_AA)
    return canvas

def get_circle_mask_tensor(h, w, center, radius):
    """Return a batch of hard-edged circle masks as a (B, H, W) uint8 array of 0s and 1s.

    Args:
        h, w: Height and width shared by every mask in the batch.
        center: Circle centres as (x, y) pixel coordinates, shape (B, 2). A single (x, y)
            pair (tuple, list or (2,) array) is treated as a batch of one.
        radius: Radii in pixels, shape (B,) or (1,), or a scalar; broadcast against the
            centres.

    Unlike get_circle_mask (255 inside, anti-aliased rim), a pixel here is exactly 1 when
    (x - cx)**2 + (y - cy)**2 <= r**2 and 0 otherwise.
    """
    center = np.atleast_2d(np.asarray(center))
    radius = np.atleast_1d(np.asarray(radius))
    assert center.ndim == 2 and center.shape[1] == 2, 'center must have shape (B, 2)'
    assert radius.ndim == 1, 'radius must have shape (B,)'
    dy = np.arange(h)[None, :, None] - center[:, 1][:, None, None]  # (B, H, 1)
    dx = np.arange(w)[None, None, :] - center[:, 0][:, None, None]  # (B, 1, W)
    mask = dy**2 + dx**2 <= radius[:, None, None]**2
    return mask.astype(np.uint8)
    

def get_pill_mask(h, w, center, height, width, angle):
    """Rasterise a capsule ("pill") shape into an (h, w) uint8 mask (255 inside, 0 outside).

    The shape is a rotated rectangle of size (width, height - width) centred on
    center = (x, y), rotated by `angle` degrees as cv2.boxPoints expects. It is capped at
    its two edges of length `width` by anti-aliased discs of radius width // 2, giving a
    total length of about `height`.
    """
    canvas = np.zeros((h, w), np.uint8)
    straight_len = height - width  # length of the straight section between the two caps
    corners = cv2.boxPoints(((center[0], center[1]), (width, straight_len), angle)).astype(np.intp)
    cv2.drawContours(canvas, [corners], 0, 255, -1)
    # Cap each end with a disc centred on the midpoint of that end's edge:
    # corners 1-2 form one width-long edge, corners 0-3 the opposite one.
    for first, second in ((1, 2), (0, 3)):
        cap_center = (corners[first] + corners[second]) // 2
        cv2.circle(canvas, tuple(cap_center), width // 2, 255, -1, cv2.LINE_AA)
    return canvas

def feather_edges(img, magnified, mask, feather_size=10):
    """Alpha-blend `magnified` over `img`, using a Gaussian-softened copy of `mask` as alpha.

    mask is an (H, W) mask expected in 0..255 (255 = show magnified), as produced by
    get_circle_mask / get_pill_mask. It is blurred with a separable Gaussian of size
    2*feather_size+1 (sigma derived from the size), rescaled to 0..1 and used as a
    per-pixel weight: img where it is 0, magnified where it is 1, linear in between.
    Returns a uint8 image (values truncated by astype, not rounded).
    """
    taps = cv2.getGaussianKernel(2 * feather_size + 1, 0)
    # Separable filtering (one vertical + one horizontal pass) instead of a full 2-D blur.
    alpha = cv2.sepFilter2D(mask, -1, taps, taps)[:, :, None] / 255
    return (img * (1 - alpha) + magnified * alpha).astype(np.uint8)

def feather_edges_tensor(img, magnified, mask, feather_size=10):
    """Same as feather_edges, but for batches of shape (B, H, W, C).

    Args:
        img: Background images, shape (B, H, W, C), values in 0..255.
        magnified: Foreground images, same shape as img.
        mask: Blend masks of shape (B, H, W), or (H, W) to share one mask across the batch.
            Either 0/1 (as from get_circle_mask_tensor) or 0..255 (as from get_circle_mask);
            the scale is detected from the unblurred mask.
        feather_size: Half-width in pixels of the Gaussian kernel (size 2*feather_size+1);
            0 disables feathering.

    Returns:
        uint8 array of shape (B, H, W, C): img where the feathered mask is 0, magnified where
        it is 1, linearly interpolated in between.
    """
    assert len(img.shape) == 4, 'Input tensor must have shape (B, H, W, C)'
    assert img.shape == magnified.shape, 'img and magnified must have the same shape'
    mask = np.asarray(mask)
    if mask.ndim == 2:
        mask = mask[None]  # share a single (H, W) mask across the batch
    # Pick the 0/1 vs 0/255 scale *before* blurring: after blurring, a 0/1 mask can exceed 1
    # by float32 rounding of the kernel sum, which would wrongly trigger a /255 rescale.
    scale = 255.0 if mask.size and mask.max() > 1 else 1.0
    mask_t = torch.as_tensor(mask, dtype=torch.float32).unsqueeze(1) / scale  # (B, 1, H, W)

    # Separable Gaussian blur: a vertical (K, 1) pass, then a horizontal (1, K) pass. Each pass
    # zero-pads only along the axis it convolves, so the output stays (H, W) without any
    # cropping (which also makes feather_size == 0 work).
    kernel = torch.as_tensor(cv2.getGaussianKernel(feather_size*2+1, 0), dtype=torch.float32)  # (K, 1)
    kernel = kernel[None, None]  # (1, 1, K, 1)
    blurred = torch.nn.functional.conv2d(mask_t, kernel, padding=(feather_size, 0))
    blurred = torch.nn.functional.conv2d(blurred, kernel.transpose(2, 3), padding=(0, feather_size))
    alpha = blurred.squeeze(1).clamp(0, 1).numpy()[..., None]  # (B, H, W, 1)

    blended = img*(1-alpha) + magnified*alpha
    # Clip before the cast so float round-off can never wrap around in uint8.
    return np.clip(blended, 0, 255).astype(np.uint8)

# Batch image test: magnify two frames around different centres, then blend the magnified
# copies back in through feathered circular masks.
frame_paths = ('images/0015.png', 'images/0016.png')
img1, img2 = (cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) for p in frame_paths)
imgs = np.stack([img1, img2], axis=0)  # (B, H, W, C)
# Per-image (x, y) centres.
# NOTE(review): y=300 assumes frames taller than 300 px. The downscale cell makes the short
# side 256 px, so on a landscape source this centre is off-image — confirm orientation.
center = np.array([[100, 100], [100, 300]])
magnified_imgs = magnify_tensor(imgs, center, 1.5)

h, w = imgs.shape[1:3]
radius = np.array([200])  # a single radius, broadcast across the batch
mask = get_circle_mask_tensor(h, w, center, radius)  # (B, H, W), values 0/1

feathered = feather_edges_tensor(imgs, magnified_imgs, mask, feather_size=30)

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for panel, frame in zip(axs, feathered):
    panel.imshow(frame)
plt.show()

In [None]:
# Single image test: tune the magnifier's position, size and feathering on one frame.
# Centre, radius and feather size are given as fractions of the frame and converted to pixels.
img = cv2.cvtColor(cv2.imread('images/0015.png'), cv2.COLOR_BGR2RGB)
frame_h, frame_w = img.shape[:2]
rel_center, rel_radius, rel_feather = (0.54, 0.75), 0.74, 0.19
center = (int(rel_center[0] * frame_w), int(rel_center[1] * frame_h))
radius = int(rel_radius * frame_w)
feather_size = int(rel_feather * frame_w)
magnified = magnify(img, center, 1.5)
# Alternative mask: get_circle_mask(frame_h, frame_w, center, radius)
mask = get_pill_mask(frame_h, frame_w, center, radius * 2, radius, 0)
plt.imshow(mask, cmap='gray')
plt.show()
feathered_circle = feather_edges(img, magnified, mask, feather_size=feather_size)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(feathered_circle)
ax.plot(center[0], center[1], 'rx')  # magnification centre
# Dotted guides: the circle of `radius` (red) and radius -/+ feather_size (blue) bounding
# the feathered band.
# NOTE(review): these guides match the circle mask, not the pill mask currently in use.
for r, colour in ((radius, 'r'), (radius - feather_size, 'b'), (radius + feather_size, 'b')):
    circle = plt.Circle(center, r, color=colour, fill=False, linestyle='dotted')
    ax.add_artist(circle)
plt.show()

In [None]:
# Another method: instead of pasting a feathered, magnified copy, warp the image.
# This is a non-linear (radial) remapping: the area around the centre is magnified and the
# magnification blends smoothly into the rest of the image.
def warp_image(img, center, factor):
    """Return img radially warped around center; the output has the same size as img.

    Offsets from center are normalised per axis (x by the image width, y by the height).
    An output pixel at normalised offset d samples the source at d * |d|**(factor - 1), so
    factor > 1 magnifies the region around center, factor == 1 is (up to float round-off)
    the identity, and factor < 1 shrinks it. Source positions outside the image get
    cv2.remap's default constant border (black).

    Args:
        img: Image of shape (H, W) or (H, W, C).
        center: (x, y) pixel coordinates of the warp centre.
        factor: Warp strength, as described above.
    """
    h, w = img.shape[:2]
    xs, ys = np.meshgrid(np.arange(w), np.arange(h))
    xs = xs.astype(np.float32)  # cv2.remap needs float32 maps
    ys = ys.astype(np.float32)
    # Offsets from center divided by the full width/height. This is not a -1..1 range in
    # general; it only reaches +/-1 when center sits on the opposite border.
    dx = (xs - center[0]) / w  # (h, w)
    dy = (ys - center[1]) / h  # (h, w)
    dist = np.sqrt(dx**2 + dy**2)
    # Radial scale dist**(factor - 1). At the centre dist == 0, where factor < 1 would give
    # 0 * inf = nan; the offset there is 0 anyway, so any finite scale (1) is equivalent.
    scale = np.ones_like(dist)
    np.power(dist, factor - 1, out=scale, where=dist > 0)
    map_x = dx * scale * w + center[0]
    map_y = dy * scale * h + center[1]
    return cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR)

# Warp demo: original frame (left) next to the radially warped frame (right).
img = cv2.cvtColor(cv2.imread('images/0015.png'), cv2.COLOR_BGR2RGB)
frame_h, frame_w = img.shape[:2]
center = (int(0.54 * frame_w), int(0.75 * frame_h))  # (x, y) in pixels
factor = 1.1
warped = warp_image(img, center, factor)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, shown in zip(ax, (img, warped)):
    panel.imshow(shown)
ax[1].plot(center[0], center[1], 'rx')  # warp centre
plt.show()

# Output: process every frame and re-encode the video

In [None]:
def feather_magnify(img, center, radius, feather_size, factor=1.5):
    """Magnify a circular region of one frame and blend it back in with feathered edges.

    Args:
        img: Single frame of shape (H, W, C).
        center: (x, y) integer pixel centre of the circle, also the magnification fixed point.
        radius: Circle radius in pixels (int).
        feather_size: Gaussian half-width in pixels used to soften the circle's edge.
        factor: Magnification factor (default 1.5, the previously hard-coded value).

    Returns:
        uint8 frame of shape (H, W, C).
    """
    # Use the single-image helpers: the *_tensor variants require a leading batch
    # dimension, which a single (H, W, C) frame does not have.
    magnified = magnify(img, center, factor)
    mask = get_circle_mask(img.shape[0], img.shape[1], center, radius)
    return feather_edges(img, magnified, mask, feather_size=feather_size)

def warp_magnify(img, center, factor):
    """Apply warp_image's radial magnification around center to one frame and return it."""
    return warp_image(img, center, factor)

# Process every extracted frame and save it, under the same file name, to new_dir.
# NOTE: relies on `center` (last set by the warp demo cell) and, for the feathered magnifier,
# on `radius` / `feather_size` (set by the single-image test cell).
USE_WARP = True  # False -> circular magnifier with feathered edges (feather_magnify)
new_dir = 'processed_images'
# Empty the output folder with the standard library instead of shelling out to `rm -rf`.
shutil.rmtree(new_dir, ignore_errors=True)
os.makedirs(new_dir, exist_ok=True)
for img_path in tqdm(sorted(os.listdir(target_dir))):
    src = os.path.join(target_dir, img_path)
    img = cv2.imread(src)
    if img is None:
        raise FileNotFoundError(f'Could not read {src} as an image')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if USE_WARP:
        processed_img = warp_magnify(img, center=center, factor=1.2)
    else:
        processed_img = feather_magnify(img, center=center, radius=radius, feather_size=feather_size)
    processed_img = cv2.cvtColor(processed_img, cv2.COLOR_RGB2BGR)
    dst = os.path.join(new_dir, img_path)
    if not cv2.imwrite(dst, processed_img):
        raise OSError(f'Could not write {dst}')

In [None]:
# Re-encode the processed frames into vid_out.mp4 at the source frame rate.
# -framerate must be an *input* option: the image2 demuxer otherwise reads the PNGs at its
# default 25 fps, and an output-side fps filter would then duplicate or drop frames to reach
# the target rate, changing the playback speed relative to the source.
# Audio is not carried over: only the image sequence is used as input.
if os.path.exists('vid_out.mp4'):
    os.remove('vid_out.mp4')
encode = subprocess.run(
    ['ffmpeg', '-framerate', f'{fps[0]}/{fps[1]}', '-i', f'{new_dir}/%04d.png',
     '-c:v', 'libx264', '-profile:v', 'high', '-crf', '20', '-pix_fmt', 'yuv420p',
     'vid_out.mp4'],
    capture_output=True, text=True, errors='replace',
)
if encode.returncode != 0:
    raise RuntimeError(f'ffmpeg failed to encode vid_out.mp4:\n{encode.stderr}')