In [1]:
import cv2
from PIL import Image
import numpy as np
import os
os.environ["OPENCV_VIDEOIO_PRIORITY_MSMF"] = "0"
import IPython.display
import time
from IPython.core.display import HTML
import io

import argparse
import sys
import os
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
from torch.utils import data
import cv2
from PIL import Image
import random

# Normalisation layer class shared by every block of the generator.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip connection.

    The channel count is unchanged (``in_features`` in and out), and reflection
    padding of 1 compensates for the unpadded 3x3 kernels, so the output has
    the same shape as the input.

    Args:
        in_features: Number of channels of the input (and output) feature map.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv():
            # pad -> 3x3 conv -> norm; built twice, in order, so the layer
            # indices inside ``conv_block`` stay 0..6 (checkpoint keys rely on them).
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                norm_layer(in_features),
            ]

        first_half = padded_conv()
        second_half = padded_conv()
        self.conv_block = nn.Sequential(
            *first_half,
            nn.ReLU(inplace=True),
            *second_half,
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image network.

    The network is stored as five sequential stages:

    * ``model0`` - reflection-padded 7x7 convolution to 64 channels,
    * ``model1`` - two stride-2 3x3 convolutions, doubling the channels each time,
    * ``model2`` - ``n_residual_blocks`` :class:`ResidualBlock` modules,
    * ``model3`` - two stride-2 transposed convolutions, halving the channels each time,
    * ``model4`` - reflection-padded 7x7 convolution to ``output_nc`` channels,
      followed by a sigmoid when ``sigmoid`` is true.

    Args:
        input_nc: Number of channels of the input image.
        output_nc: Number of channels of the produced image.
        n_residual_blocks: How many residual blocks sit between the
            down-sampling and up-sampling stages.
        sigmoid: Whether to append ``nn.Sigmoid`` to the output stage.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()
        channels = 64

        # Initial convolution block
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, channels, 7),
            norm_layer(channels),
            nn.ReLU(inplace=True),
        )

        # Downsampling: each step halves the resolution and doubles the channels.
        down_layers = []
        for _ in range(2):
            wider = channels * 2
            down_layers.extend([
                nn.Conv2d(channels, wider, 3, stride=2, padding=1),
                norm_layer(wider),
                nn.ReLU(inplace=True),
            ])
            channels = wider
        self.model1 = nn.Sequential(*down_layers)

        # Residual blocks
        res_layers = [ResidualBlock(channels) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*res_layers)

        # Upsampling: mirror image of the downsampling stage.
        up_layers = []
        for _ in range(2):
            narrower = channels // 2
            up_layers.extend([
                nn.ConvTranspose2d(channels, narrower, 3, stride=2, padding=1, output_padding=1),
                norm_layer(narrower),
                nn.ReLU(inplace=True),
            ])
            channels = narrower
        self.model3 = nn.Sequential(*up_layers)

        # Output layer
        out_layers = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            out_layers.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*out_layers)

    def forward(self, x, cond=None):
        """Run ``x`` through the five stages in order.

        ``cond`` is accepted for interface compatibility and is not used.
        """
        stages = (self.model0, self.model1, self.model2, self.model3, self.model4)
        out = x
        for stage in stages:
            out = stage(out)
        return out

# File-name suffixes recognised as images. Matching is case-insensitive, so the
# upper-case spellings are redundant and only kept for backward compatibility.
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']

def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison ignores case, so mixed-case suffixes such as ``.Jpg`` are
    accepted as well as the spellings listed in ``IMG_EXTENSIONS``.

    Args:
        filename: File name or path to test.

    Returns:
        bool: Whether the name carries a known image extension.
    """
    # The tuple is rebuilt on each call so later edits to IMG_EXTENSIONS are honoured.
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)

def make_dataset(dir, stop=10000):
    """Collect the paths of image files found under ``dir``.

    The tree is walked recursively. Directories and the file names inside each
    directory are visited in sorted order, so the result is deterministic
    regardless of the order in which the file system lists entries.

    Args:
        dir: Root directory to scan.
        stop: Maximum number of paths to return. A value of zero or less
            yields an empty list.

    Returns:
        list[str]: At most ``stop`` image paths.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    if stop <= 0:
        # Nothing was requested; without this guard one image could still be returned.
        return images
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in sorted(fnames):
            if not is_image_file(fname):
                continue
            images.append(os.path.join(root, fname))
            if len(images) >= stop:
                return images
    return images

class UnpairedDepthDataset(data.Dataset):
    """Dataset that yields RGB images found under ``root``.

    Each item is a dict with the transformed image under ``'r'`` together with
    its path, index and base name. The ``'depth'`` and ``'label'`` entries are
    always the placeholder ``0``.

    Args:
        root: Directory scanned (recursively) for image files.
        root2: Accepted for interface compatibility; not used.
        opt: Options object; stored on the instance as ``self.opt``.
        transforms_r: List of transforms applied to every image. ``None``
            (the default) means "no transform", i.e. the PIL image is returned.
        mode: Stored on the instance as ``self.mode``.
        midas: Stored on the instance as ``self.midas``.
        depthroot: Accepted for interface compatibility; not used.
    """

    def __init__(self, root, root2, opt, transforms_r=None, mode='train', midas=False, depthroot=''):
        self.root = root
        self.mode = mode
        self.midas = midas
        self.depth_maps = 0
        self.data = make_dataset(self.root)
        # Compose(None) constructs fine but raises on the first call, so map the
        # default to an empty pipeline instead.
        self.transform_r = transforms.Compose(transforms_r if transforms_r is not None else [])
        self.opt = opt
        self.min_length = len(self.data)

    def __getitem__(self, index):
        """Load, convert to RGB and transform the image at ``index``."""
        img_path = self.data[index]
        # splitext keeps dots inside the stem ("a.001.png" -> "a.001"), so
        # differently numbered files do not collapse onto the same name.
        base = os.path.splitext(os.path.basename(img_path))[0]
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy.
        with Image.open(img_path) as opened:
            img_r = opened.convert('RGB')
        img_r = self.transform_r(img_r)
        return {
            'r': img_r,
            'depth': 0,
            'path': img_path,
            'index': index,
            'name': base,
            'label': 0,
        }

    def __len__(self):
        """Return the number of images found when the dataset was created."""
        return self.min_length

def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so that both sides are multiples of ``base``.

    Each side is rounded to the nearest multiple of ``base``. When the image
    already has such a size it is returned unchanged; otherwise a size warning
    is issued through ``__print_size_warning`` and a resized copy is returned.

    Args:
        img: Image exposing ``size`` and ``resize`` (a PIL image).
        base: The divisor both sides must be a multiple of.
        method: Resampling filter passed to ``img.resize``.
    """
    src_w, src_h = img.size
    dst_h = int(round(src_h / base) * base)
    dst_w = int(round(src_w / base) * base)
    if (dst_w, dst_h) != (src_w, src_h):
        __print_size_warning(src_w, src_h, dst_w, dst_h)
        return img.resize((dst_w, dst_h), method)
    return img

def __scale_width(img, target_width, method=Image.BICUBIC):
    """Scale ``img`` to ``target_width`` while keeping its aspect ratio.

    The image is returned unchanged when its width already equals
    ``target_width``. The new height is truncated to an integer.

    Args:
        img: Image exposing ``size`` and ``resize`` (a PIL image).
        target_width: Desired width in pixels.
        method: Resampling filter passed to ``img.resize``.
    """
    cur_w, cur_h = img.size
    if cur_w != target_width:
        scaled_h = int(target_width * cur_h / cur_w)
        img = img.resize((target_width, scaled_h), method)
    return img

def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` window out of ``img``, padding short sides with white.

    * Both sides larger than ``size``: plain crop starting at ``pos``.
    * Only one side larger: that side is cropped at ``pos`` and the result is
      centred on a white ``size`` x ``size`` canvas along the other axis.
    * Neither side larger: the image is returned unchanged.

    Args:
        img: Image exposing ``size``, ``mode`` and ``crop`` (a PIL image).
        pos: ``(x, y)`` of the crop window's top-left corner.
        size: Side length of the square window.
    """
    img_w, img_h = img.size
    left, top = pos

    # White in the pixel format of the image's mode.
    white_by_mode = {'L': (255), 'RGBA': (255, 255, 255, 255)}
    fill = white_by_mode.get(img.mode, (255, 255, 255))

    too_wide = img_w > size
    too_tall = img_h > size

    if too_wide and too_tall:
        return img.crop((left, top, left + size, top + size))
    if too_wide:
        strip = img.crop((left, 0, left + size, img_h))
        return add_margin(strip, size, 0, (size - img_h) // 2, fill)
    if too_tall:
        strip = img.crop((0, top, img_w, top + size))
        return add_margin(strip, size, (size - img_w) // 2, 0, fill)
    return img

def add_margin(pil_img, newsize, left, top, color=(255, 255, 255)):
    """Paste ``pil_img`` onto a new square canvas filled with ``color``.

    Args:
        pil_img: Image to place on the canvas; the canvas uses its ``mode``.
        newsize: Side length of the square canvas.
        left: Horizontal offset of the pasted image's top-left corner.
        top: Vertical offset of the pasted image's top-left corner.
        color: Fill colour of the canvas.

    Returns:
        The new ``newsize`` x ``newsize`` image; ``pil_img`` is not modified.
    """
    canvas = Image.new(pil_img.mode, (newsize, newsize), color)
    canvas.paste(pil_img, (left, top))
    return canvas

def __flip(img, flip):
    """Return ``img`` mirrored left-to-right when ``flip`` is truthy, else ``img`` itself."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img

def __print_size_warning(ow, oh, w, h):
    """Print a one-time notice that an image size was adjusted.

    The first call prints the original size ``(ow, oh)`` and the adjusted size
    ``(w, h)`` and records that fact in the ``has_printed`` attribute of this
    function; every later call returns without printing.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return
    message = ("The image size needs to be a multiple of 4. "
               "The loaded image size was (%d, %d), so it was adjusted to "
               "(%d, %d). This adjustment will be done to all images "
               "whose sizes are not multiples of 4" % (ow, oh, w, h))
    print(message)
    __print_size_warning.has_printed = True

def get_params(opt, size):
    """Draw random crop and flip parameters for one image.

    The size the image will have after pre-processing is derived from
    ``opt.preprocess``; the crop origin is then drawn uniformly so that a
    window of ``opt.crop_size`` fits (or is ``0`` along an axis that is too
    small).

    Args:
        opt: Options providing ``preprocess``, ``load_size`` and ``crop_size``.
        size: ``(width, height)`` of the original image.

    Returns:
        dict: ``{'crop_pos': (x, y), 'flip': bool}``.
    """
    width, height = size
    if opt.preprocess == 'resize_and_crop':
        scaled_w = scaled_h = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        scaled_w = opt.load_size
        scaled_h = opt.load_size * height // width
    else:
        scaled_w, scaled_h = width, height

    max_x = np.maximum(0, scaled_w - opt.crop_size)
    max_y = np.maximum(0, scaled_h - opt.crop_size)

    # Draw order (x, y, flip) matters for reproducibility under a fixed seed.
    crop_x = random.randint(0, max_x)
    crop_y = random.randint(0, max_y)
    do_flip = random.random() > 0.5
    return {'crop_pos': (crop_x, crop_y), 'flip': do_flip}

def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True, norm=True):
    """Assemble the pre-processing pipeline described by ``opt``.

    Steps are appended in this order, each only when applicable:
    grayscale conversion, resize or width scaling, crop, rounding the size to
    a multiple of 4 (``opt.preprocess == 'none'``), horizontal flip, and
    finally tensor conversion with optional normalisation.

    Args:
        opt: Options providing ``preprocess``, ``load_size``, ``crop_size``
            and ``no_flip``.
        params: Output of ``get_params``. When given, crop position and flip
            are taken from it instead of being random.
        grayscale: Convert to a single channel first. Normalisation is
            skipped for grayscale pipelines.
        method: Resampling filter for the resize steps.
        convert: Append ``ToTensor`` (and, if enabled, ``Normalize``).
        norm: Normalise RGB tensors with mean and std 0.5 per channel.

    Returns:
        ``transforms.Compose`` wrapping the selected steps.
    """
    steps = []
    mode = opt.preprocess

    if grayscale:
        steps.append(transforms.Grayscale(1))

    if 'resize' in mode:
        steps.append(transforms.Resize([opt.load_size, opt.load_size], method))
    elif 'scale_width' in mode:
        steps.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in mode:
        if params is None:
            cropper = transforms.RandomCrop(opt.crop_size)
        else:
            cropper = transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))
        steps.append(cropper)

    if mode == 'none':
        steps.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            steps.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            steps.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        steps.append(transforms.ToTensor())
        if norm and not grayscale:
            steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

    return transforms.Compose(steps)

# Run-time options for this notebook, kept as an argparse.Namespace so they can
# be read as attributes (opt.size, opt.cuda, ...).
opt = argparse.Namespace(
    checkpoints_dir='checkpoints',
    results_dir='results',
    geom_name='feats2Geom',
    batchSize=1,
    dataroot='examples/test',
    depthroot='',
    input_nc=3,
    output_nc=1,
    geom_nc=3,
    every_feat=1,
    num_classes=55,
    midas=0,
    ngf=64,
    n_blocks=3,
    size=256,  # 1080
    cuda=True,
    n_cpu=8,
    which_epoch='latest',
    aspect_ratio=1.0,
    mode='test',
    load_size=256,  # 1080
    crop_size=256,  # 1080
    max_dataset_size=float("inf"),
    preprocess='resize_and_crop',
    no_flip=False,  # Default is False because it's a store_true argument
    norm='instance',
    predict_depth=0,
    save_input=0,
    reconstruct=0,
    how_many=100,
)

# Style checkpoint used by the first generator; the alternatives are
# 'opensketch_style' and 'contour_style'.
opt.name = 'anime_style'
# Force CPU execution regardless of the default above.
opt.cuda = False
print(opt)

# opt.no_flip = True
# Check for CUDA availability and set device
_cuda_present = torch.cuda.is_available()
device = torch.device("cuda" if (_cuda_present and opt.cuda) else "cpu")

if _cuda_present and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

import onnx
from onnx2torch import convert
# Build the cartoonGAN network from its ONNX graph, then load the trained
# weights that are stored separately as a PyTorch state dict.
onnx_model = onnx.load(os.path.join(opt.checkpoints_dir, "cartoonGAN", "cartoonGAN.onnx"))

# Convert the ONNX model to a PyTorch model
cartoonGAN_model = convert(onnx_model)

cartoonGAN_model_path = os.path.join(opt.checkpoints_dir, "cartoonGAN", "cartoonGAN.pth")
# The converted model is not moved to another device here, so map the stored
# tensors to the CPU; this also lets a checkpoint saved on a GPU load on a
# machine without CUDA.
state_dict = torch.load(cartoonGAN_model_path, map_location="cpu", weights_only=True)

# Load the weights into the model
cartoonGAN_model.load_state_dict(state_dict)

def pre_processing(image_path, style="", expand_dim=True):
    """Turn an image into a float32 tensor with its channel order reversed.

    Args:
        image_path: Despite the name, the image itself (anything
            ``np.asarray`` accepts, laid out as height x width x 3), not a
            file path.
        style: Unused.
        expand_dim: Prepend a leading axis of length 1 when true.

    Returns:
        torch.Tensor: float32 tensor whose last axis is the input's channels
        in the order 2, 1, 0.
    """
    pixels = np.asarray(image_path).astype(np.float32)
    reversed_channels = [2, 1, 0]
    pixels = pixels[:, :, reversed_channels]
    if expand_dim:
        pixels = np.expand_dims(pixels, axis=0)
    return torch.from_numpy(pixels)

def post_processing(transformed_image, style=""):
    """Convert the first image of a model output batch to the 0..255 range.

    The first element along axis 0 is taken, its last axis is reordered to
    2, 1, 0 and the values are mapped with ``(v * 0.5 + 0.5) * 255``. Values
    are not clipped or cast, so the caller decides how to convert to uint8.

    Args:
        transformed_image: ``numpy.ndarray`` or ``torch.Tensor`` laid out as
            batch x height x width x channels.
        style: Unused.

    Returns:
        numpy.ndarray: height x width x channels floating-point array.
    """
    if not isinstance(transformed_image, np.ndarray):
        # detach()/cpu() make this work for tensors that track gradients or
        # live on an accelerator; a bare .numpy() raises for both.
        transformed_image = transformed_image.detach().cpu().numpy()
    first = transformed_image[0]
    first = first[:, :, [2, 1, 0]]
    return (first * 0.5 + 0.5) * 255

def dist_func(left, right):
    """Compute an ``(x, y)`` offset from the edge running from ``left`` to ``right``.

    The offset length is ``80 * sqrt(2)`` times the edge length divided by
    300, and its direction is the edge direction (from ``arctan``) turned by
    a further 45 degrees. A vertical edge (equal x coordinates) uses an angle
    of 90 degrees.

    Args:
        left: ``(x, y)`` of the first end point.
        right: ``(x, y)`` of the second end point.

    Returns:
        tuple: ``(x, y)`` offset components.
    """
    dx = right[0] - left[0]
    dy = right[1] - left[1]
    scale = np.sqrt(dx**2 + dy**2) / 300.0
    if dx != 0:
        theta = np.arctan(dy / dx) + np.pi / 4
    else:
        theta = np.pi / 2
    slope = np.tan(theta)
    x = scale * np.sqrt(2) * 80 / np.sqrt(slope**2 + 1)
    return (x, slope * x)

def mask_picture_frame(image, corners, second):
    """
    Replaces the area inside the quadrilateral defined by corners with the
    corresponding pixels of a second image.

    Parameters:
        image (numpy.ndarray): The input image. It is modified in place.
        corners (list of tuple): Four corner points (x, y) of the quadrilateral in any order.
        second (numpy.ndarray): Image the replacement pixels are taken from.
            Its height and width must match those of ``image``.

    Returns:
        numpy.ndarray: ``image`` itself, with the quadrilateral filled from ``second``.
    """
    points = np.array(corners, dtype=np.float32)

    # Order the corners by their angle around the centroid so that they trace
    # the outline of the polygon instead of possibly crossing over it.
    center = np.mean(points, axis=0)
    angles = np.arctan2(points[:, 1] - center[1], points[:, 0] - center[0])
    outline = points[np.argsort(angles, kind="stable")].astype(np.int32)

    # A single-channel mask is enough to mark the pixels to replace and works
    # for images with any number of channels.
    mask = np.zeros(image.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, [outline], 255)
    inside = mask == 255

    image[inside] = second[inside]
    return image

Namespace(checkpoints_dir='checkpoints', results_dir='results', geom_name='feats2Geom', batchSize=1, dataroot='examples/test', depthroot='', input_nc=3, output_nc=1, geom_nc=3, every_feat=1, num_classes=55, midas=0, ngf=64, n_blocks=3, size=256, cuda=False, n_cpu=8, which_epoch='latest', aspect_ratio=1.0, mode='test', load_size=256, crop_size=256, max_dataset_size=inf, preprocess='resize_and_crop', no_flip=False, norm='instance', predict_depth=0, save_input=0, reconstruct=0, how_many=100, name='anime_style')


In [2]:
with torch.no_grad():
    # Live demo: grab webcam frames, stylise each frame with the selected
    # network and show the stylised pixels only inside the region spanned by
    # ArUco markers 1-4 (a small default square until all four have been seen).
    # Keys: 'q' quits, 's' saves the current stylised frame, 'w' switches style.
    import traceback

    switch_it = 0  # 0-2 select net_G / net_G2 / net_G3, 3 selects cartoonGAN_model

    # Networks
    net_G = Generator(opt.input_nc, opt.output_nc, opt.n_blocks).to(device)
    net_G2 = Generator(opt.input_nc, opt.output_nc, opt.n_blocks).to(device)
    net_G3 = Generator(opt.input_nc, opt.output_nc, opt.n_blocks).to(device)

    # Load state dicts and set the models to test mode
    for _net, _style in ((net_G, opt.name), (net_G2, "opensketch_style"), (net_G3, "contour_style")):
        _ckpt = os.path.join(opt.checkpoints_dir, _style, f'netG_A_{opt.which_epoch}.pth')
        _net.load_state_dict(torch.load(_ckpt, map_location=device, weights_only=True))
        _net.eval()
    cartoonGAN_model.eval()

    sketch_nets = (net_G, net_G2, net_G3)
    num_styles = len(sketch_nets) + 1  # the extra style is cartoonGAN_model

    def _generator_output_to_bgr(net, chw_input):
        """Run ``net`` on a channels-first tensor and return a 3-channel uint8 BGR image."""
        out = net(chw_input).detach().cpu().numpy()
        if out.shape[0] == 1:
            # cvtColor's RGB2BGR rejects single-channel input; expand gray to BGR instead.
            gray = (out[0] * 255).astype(np.uint8)
            return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        rgb = np.ascontiguousarray((np.transpose(out, (1, 2, 0)) * 255).astype(np.uint8))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    transforms_r = [transforms.Resize(int(opt.size), Image.BICUBIC),
                    transforms.ToTensor()]
    the_transformation = transforms.Compose(transforms_r)

    cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)

    dictionary = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_4X4_100)
    parameters = cv2.aruco.DetectorParameters()
    detector = cv2.aruco.ArucoDetector(dictionary, parameters)

    # Corners of the region to fill, one per marker id (1-4). They are kept
    # across frames, so a marker that is missed in one frame keeps its last position.
    a = b = c = d = None

    # Default region used until all four markers have been seen at least once.
    new_w = 20
    new_h = 20
    full_w = 50
    full_h = 50
    _x0 = (full_w - new_w) / 2
    _y0 = (full_h - new_h) / 2
    default_src = np.float32([(_x0, _y0),
                              (_x0, _y0 + new_h),
                              (_x0 + new_w, _y0 + new_h),
                              (_x0 + new_w, _y0)])

    cv2.startWindowThread()
    if not cap.isOpened():
        print("Could not open camera 0")
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            image = cv2.resize(frame, (456, 256))

            marker_corners, ids, _rejected = detector.detectMarkers(image)
            if len(marker_corners) > 0:
                # loop over the detected ArUco markers
                for marker_pts, markerID in zip(marker_corners, ids.flatten()):
                    topLeft, topRight, bottomRight, bottomLeft = [
                        (int(px), int(py)) for px, py in marker_pts.reshape((4, 2))
                    ]
                    if markerID == 1:
                        off = dist_func(topLeft, topRight)
                        a = (topLeft[0] - int(off[0]), topLeft[1] - int(off[1]))
                    elif markerID == 2:
                        off = dist_func(topLeft, topRight)
                        b = (topRight[0] + int(off[0]), topRight[1] - int(off[1]))
                    elif markerID == 3:
                        off = dist_func(bottomLeft, bottomRight)
                        c = (bottomLeft[0] - int(off[0]), bottomLeft[1] + int(off[1]))
                    elif markerID == 4:
                        off = dist_func(bottomLeft, bottomRight)
                        d = (bottomRight[0] + int(off[0]), bottomRight[1] + int(off[1]))
                    # Outline the marker in green.
                    outline = [topLeft, topRight, bottomRight, bottomLeft]
                    for start, end in zip(outline, outline[1:] + outline[:1]):
                        cv2.line(image, start, end, (0, 255, 0), 2)

            if all(pt is not None for pt in (a, b, c, d)):
                src = np.float32([a, c, d, b])
            else:
                src = default_src

            frame_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)

            if switch_it < len(sketch_nets):
                real_A = the_transformation(pil_image).to(device)
                the_output_image = _generator_output_to_bgr(sketch_nets[switch_it], real_A)
            else:
                CG_input_image = pre_processing(pil_image)
                CG_transformed_image = cartoonGAN_model(CG_input_image)
                CG_output_image = post_processing(CG_transformed_image)
                # Clip before the cast: out-of-range values would wrap around in uint8.
                CG_output_image = np.clip(CG_output_image, 0, 255).astype(np.uint8)
                # Swap the first and last channel for display with OpenCV.
                the_output_image = cv2.cvtColor(CG_output_image, cv2.COLOR_RGB2BGR)

            # The compositing step indexes both images with the same mask, so
            # their height and width have to agree.
            if the_output_image.shape[:2] != image.shape[:2]:
                the_output_image = cv2.resize(the_output_image, (image.shape[1], image.shape[0]))

            masked = mask_picture_frame(image, src, the_output_image)
            cv2.imshow('masked', masked)

            pressedKey = cv2.waitKey(1) & 0xFF
            if pressedKey == ord('q'):
                break
            elif pressedKey == ord('s'):
                # Save the stylised frame of whichever style is active.
                cv2.imwrite('full_frame_output.png', the_output_image)
            elif pressedKey == ord('w'):
                switch_it = (switch_it + 1) % num_styles
    except KeyboardInterrupt:
        # Interrupting the cell is a normal way to stop the demo.
        pass
    except Exception:
        # Report the failure instead of ending the loop silently.
        traceback.print_exc()
    finally:
        cv2.destroyAllWindows()
        for i in range(5):
            cv2.waitKey(1)
        cap.release()

In [None]:
# Scratch cell: repeat the compositing call on the image, src and
# the_output_image values left behind by the webcam loop.
mask_picture_frame(image, src, the_output_image)

In [None]:
# Scratch cell: shape of the last stylised frame (compare with image.shape).
the_output_image.shape

In [None]:
# Scratch cell: shape of the last resized camera frame.
image.shape