# In [16]:
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2
from PIL import Image
import numpy as np
import os
os.environ["OPENCV_VIDEOIO_PRIORITY_MSMF"] = "0"
import IPython.display
import time
from IPython.core.display import HTML
import io
import argparse
import sys
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np
from torch.utils import data
import cv2
from PIL import Image
import random
from typing import Tuple, Union
import math


# Normalisation layer class used by every block of the generator.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions wrapped in an identity skip.

    Args:
        in_features: number of channels of both the input and the output.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv():
            # ReflectionPad2d(1) followed by a 3x3 conv keeps the spatial size.
            return [nn.ReflectionPad2d(1), nn.Conv2d(in_features, in_features, 3)]

        layers = padded_conv() + [norm_layer(in_features), nn.ReLU(inplace=True)]
        layers += padded_conv() + [norm_layer(in_features)]

        # The attribute name and the layer order determine the state_dict keys
        # (conv_block.1.*, conv_block.5.*), so both must stay as they are.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Stages, in order: a 7x7 stem (``model0``), two stride-2 downsampling
    convolutions (``model1``), ``n_residual_blocks`` residual blocks
    (``model2``), two stride-2 transposed convolutions (``model3``) and a 7x7
    output convolution optionally followed by a sigmoid (``model4``).

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        n_residual_blocks: how many ResidualBlock modules form ``model2``.
        sigmoid: when True, append ``nn.Sigmoid`` so outputs lie in [0, 1].
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()
        base_width = 64

        # Initial convolution block
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, base_width, 7),
            norm_layer(base_width),
            nn.ReLU(inplace=True),
        )

        # Downsampling: the channel count doubles at each of the two steps.
        channels = base_width
        down_layers = []
        for _ in range(2):
            doubled = channels * 2
            down_layers.extend([
                nn.Conv2d(channels, doubled, 3, stride=2, padding=1),
                norm_layer(doubled),
                nn.ReLU(inplace=True),
            ])
            channels = doubled
        self.model1 = nn.Sequential(*down_layers)

        # Residual blocks operate at the widest channel count.
        self.model2 = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(n_residual_blocks)]
        )

        # Upsampling: the channel count halves at each of the two steps.
        up_layers = []
        for _ in range(2):
            halved = channels // 2
            up_layers.extend([
                nn.ConvTranspose2d(channels, halved, 3, stride=2, padding=1, output_padding=1),
                norm_layer(halved),
                nn.ReLU(inplace=True),
            ])
            channels = halved
        self.model3 = nn.Sequential(*up_layers)

        # Output layer
        head = [nn.ReflectionPad2d(3), nn.Conv2d(base_width, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """Run ``x`` through the five stages; ``cond`` is accepted but unused."""
        out = x
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            out = stage(out)
        return out

# File-name suffixes (matched case-sensitively) that count as images.
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    # str.endswith accepts a tuple of candidate suffixes.
    return filename.endswith(tuple(IMG_EXTENSIONS))

def make_dataset(dir, stop=10000):
    """Collect the paths of image files found recursively under ``dir``.

    Args:
        dir: root directory to walk (the name shadows the builtin but is kept
            so keyword callers keep working).
        stop: maximum number of paths to return; a value <= 0 yields ``[]``.

    Returns:
        A list of at most ``stop`` file paths. Directories are visited in
        sorted order and the files inside each directory in sorted name order,
        so the result does not depend on the file system's listing order.

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    if stop <= 0:
        return images
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk gives file names in arbitrary order; sort for reproducibility.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
                if len(images) >= stop:
                    return images
    return images

class UnpairedDepthDataset(data.Dataset):
    """Dataset that yields transformed RGB images found under ``root``.

    Args:
        root: directory scanned (recursively) for image files.
        root2: accepted for interface compatibility; not used.
        opt: options object, stored on the instance as ``self.opt``.
        transforms_r: list of torchvision transforms applied to every image;
            ``None`` means no transform (the PIL image is returned as-is).
        mode: stored on the instance as ``self.mode``.
        midas: stored on the instance as ``self.midas``.
        depthroot: accepted for interface compatibility; not used.
    """

    def __init__(self, root, root2, opt, transforms_r=None, mode='train', midas=False, depthroot=''):
        self.root = root
        self.mode = mode
        self.midas = midas
        self.opt = opt
        self.depth_maps = 0
        self.data = make_dataset(self.root)
        self.min_length = len(self.data)
        # Compose(None) would only blow up later, on the first __getitem__
        # call, so treat a missing transform list as an empty pipeline.
        self.transform_r = transforms.Compose(transforms_r if transforms_r is not None else [])

    def __getitem__(self, index):
        """Load image ``index`` and return it with its metadata.

        Returns:
            dict with keys ``'r'`` (transformed image), ``'depth'`` (always 0),
            ``'path'``, ``'index'``, ``'name'`` (file name up to its first
            dot) and ``'label'`` (always 0).
        """
        img_path = self.data[index]
        base = os.path.basename(img_path).split('.')[0]
        # convert() returns a fully loaded copy, so the file handle can be
        # closed as soon as the with-block ends.
        with Image.open(img_path) as source:
            img_r = source.convert('RGB')
        # Only self.transform_r is applied. The previous code also built a
        # get_params/get_transform pipeline per item and discarded it right
        # away; that dead work (and its draws from `random`) is gone.
        img_r = self.transform_r(img_r)
        return {'r': img_r, 'depth': 0, 'path': img_path, 'index': index, 'name': base, 'label': 0}

    def __len__(self):
        """Return the number of images found at construction time."""
        return self.min_length

def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so that both sides are multiples of ``base``.

    Each side is rounded to the nearest multiple of ``base``. An image that
    already satisfies this is returned unchanged; otherwise the one-time size
    warning is triggered and a resized image (resampled with ``method``) is
    returned.
    """
    orig_w, orig_h = img.size
    new_h = int(round(orig_h / base) * base)
    new_w = int(round(orig_w / base) * base)
    if (new_w, new_h) != (orig_w, orig_h):
        __print_size_warning(orig_w, orig_h, new_w, new_h)
        return img.resize((new_w, new_h), method)
    return img

def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to ``target_width``, keeping its aspect ratio.

    The new height is ``target_width * height / width`` truncated to an int.
    An image that already has the target width is returned unchanged.
    """
    cur_w, cur_h = img.size
    if cur_w != target_width:
        scaled_h = int(target_width * cur_h / cur_w)
        img = img.resize((target_width, scaled_h), method)
    return img

def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` window out of ``img`` starting at ``pos``.

    Args:
        img: PIL image.
        pos: ``(x, y)`` of the window's top-left corner.
        size: side length of the square window.

    Returns:
        The cropped window when the image is larger than ``size`` in both
        directions. When only one side is larger, that side is cropped and the
        strip is centred on a white ``size`` x ``size`` canvas along the other
        side. When neither side is larger the image is returned unchanged.
    """
    img_w, img_h = img.size
    left, top = pos

    # White fill expressed in the image's own mode.
    fill_by_mode = {'L': 255, 'RGBA': (255, 255, 255, 255)}
    fill = fill_by_mode.get(img.mode, (255, 255, 255))

    wider = img_w > size
    taller = img_h > size
    if wider and taller:
        return img.crop((left, top, left + size, top + size))
    if wider:
        strip = img.crop((left, 0, left + size, img_h))
        return add_margin(strip, size, 0, (size - img_h) // 2, fill)
    if taller:
        strip = img.crop((0, top, img_w, top + size))
        return add_margin(strip, size, (size - img_w) // 2, 0, fill)
    return img

def add_margin(pil_img, newsize, left, top, color=(255, 255, 255)):
    """Paste ``pil_img`` at ``(left, top)`` onto a new square canvas.

    The canvas is ``newsize`` x ``newsize``, has the same mode as ``pil_img``
    and is filled with ``color``.
    """
    canvas = Image.new(pil_img.mode, (newsize, newsize), color)
    canvas.paste(pil_img, (left, top))
    return canvas

def __flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy, else return it as-is."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img

def __print_size_warning(ow, oh, w, h):
    """Print, at most once, that an image was resized from (ow, oh) to (w, h)."""
    # The "already printed" flag is stored on the function object itself.
    if hasattr(__print_size_warning, 'has_printed'):
        return
    message = ("The image size needs to be a multiple of 4. "
               "The loaded image size was (%d, %d), so it was adjusted to "
               "(%d, %d). This adjustment will be done to all images "
               "whose sizes are not multiples of 4" % (ow, oh, w, h))
    print(message)
    __print_size_warning.has_printed = True

def get_params(opt, size):
    """Draw the random crop offset and flip decision for one image.

    Args:
        opt: options object providing ``preprocess``, ``load_size`` and
            ``crop_size``.
        size: ``(width, height)`` of the source image.

    Returns:
        ``{'crop_pos': (x, y), 'flip': bool}`` where ``x``/``y`` are uniform in
        ``[0, max(0, scaled_side - crop_size)]`` and ``flip`` is True with
        probability 0.5.
    """
    width, height = size
    if opt.preprocess == 'resize_and_crop':
        scaled_w = scaled_h = opt.load_size
    elif opt.preprocess == 'scale_width_and_crop':
        scaled_w = opt.load_size
        scaled_h = opt.load_size * height // width
    else:
        scaled_w, scaled_h = width, height

    max_x = max(0, scaled_w - opt.crop_size)
    max_y = max(0, scaled_h - opt.crop_size)

    # The draw order (x, then y, then flip) is kept so seeded runs reproduce.
    crop_x = random.randint(0, max_x)
    crop_y = random.randint(0, max_y)
    do_flip = random.random() > 0.5
    return {'crop_pos': (crop_x, crop_y), 'flip': do_flip}

def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True, norm=True):
    """Assemble the torchvision preprocessing pipeline selected by ``opt``.

    Args:
        opt: options object providing ``preprocess``, ``load_size``,
            ``crop_size`` and ``no_flip``.
        params: optional dict with ``'crop_pos'`` and ``'flip'`` (as produced
            by ``get_params``); when ``None`` the crop and flip are random
            torchvision transforms instead.
        grayscale: prepend a single-channel grayscale conversion.
        method: resampling filter used by the resize steps.
        convert: append ``ToTensor`` (and possibly ``Normalize``).
        norm: with ``convert`` and without ``grayscale``, normalise each
            channel with mean 0.5 and std 0.5.

    Returns:
        A ``transforms.Compose`` holding the selected steps in order.
    """
    steps = []
    add = steps.append

    if grayscale:
        add(transforms.Grayscale(1))

    # Scaling: 'resize*' wins over 'scale_width*'.
    if 'resize' in opt.preprocess:
        add(transforms.Resize([opt.load_size, opt.load_size], method))
    elif 'scale_width' in opt.preprocess:
        add(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    # Cropping: fixed position when params are given, random otherwise.
    if 'crop' in opt.preprocess:
        if params is not None:
            add(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
        else:
            add(transforms.RandomCrop(opt.crop_size))

    if opt.preprocess == 'none':
        add(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    # Flipping: a flip step is only added when it will actually flip.
    if not opt.no_flip:
        if params is None:
            add(transforms.RandomHorizontalFlip())
        elif params['flip']:
            add(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        add(transforms.ToTensor())
        if norm and not grayscale:
            add(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)

# parser = argparse.ArgumentParser()
# parser.add_argument('--name', required=True, type=str, help='name of this experiment')
# parser.add_argument('--checkpoints_dir', type=str, default='checkpoints', help='Where the model checkpoints are saved')
# parser.add_argument('--results_dir', type=str, default='results', help='where to save result images')
# parser.add_argument('--geom_name', type=str, default='feats2Geom', help='name of the geometry predictor')
# parser.add_argument('--batchSize', type=int, default=1, help='size of the batches')
# parser.add_argument('--dataroot', type=str, default='', help='root directory of the dataset')
# parser.add_argument('--depthroot', type=str, default='', help='dataset of corresponding ground truth depth maps')
# parser.add_argument('--input_nc', type=int, default=3, help='number of channels of input data')
# parser.add_argument('--output_nc', type=int, default=1, help='number of channels of output data')
# parser.add_argument('--geom_nc', type=int, default=3, help='number of channels of geometry data')
# parser.add_argument('--every_feat', type=int, default=1, help='use transfer features for the geometry loss')
# parser.add_argument('--num_classes', type=int, default=55, help='number of classes for inception')
# parser.add_argument('--midas', type=int, default=0, help='use midas depth map')
# parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
# parser.add_argument('--n_blocks', type=int, default=3, help='number of resnet blocks for generator')
# parser.add_argument('--size', type=int, default=256, help='size of the data (squared assumed)')
# parser.add_argument('--cuda', action='store_true', help='use GPU computation', default=True)
# parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
# parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load from')
# parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
# parser.add_argument('--mode', type=str, default='test', help='train, val, test, etc')
# parser.add_argument('--load_size', type=int, default=256, help='scale images to this size')
# parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
# parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
# parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
# parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
# parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
# parser.add_argument('--predict_depth', type=int, default=0, help='run geometry prediction on the generated images')
# parser.add_argument('--save_input', type=int, default=0, help='save input image')
# parser.add_argument('--reconstruct', type=int, default=0, help='get reconstruction')
# parser.add_argument('--how_many', type=int, default=100, help='number of images to test')

# opt = parser.parse_args()

# Inference settings. The attribute order matches the original dict (with
# `name` last) so that print(opt) shows the same text.
opt = argparse.Namespace(
    checkpoints_dir='checkpoints',
    results_dir='results',
    geom_name='feats2Geom',
    batchSize=1,
    dataroot='examples/test',
    depthroot='',
    input_nc=3,
    output_nc=1,
    geom_nc=3,
    every_feat=1,
    num_classes=55,
    midas=0,
    ngf=64,
    n_blocks=3,
    size=1080,  # 256
    cuda=False,  # GPU use is switched off explicitly
    n_cpu=8,
    which_epoch='latest',
    aspect_ratio=1.0,
    mode='test',
    load_size=1080,  # 256
    crop_size=1080,  # 256
    max_dataset_size=float("inf"),
    preprocess='resize_and_crop',
    no_flip=False,  # Default is False because it's a store_true argument
    norm='instance',
    predict_depth=0,
    save_input=0,
    reconstruct=0,
    how_many=100,
    name='opensketch_style',
)
print(opt)

# opt.no_flip = True
# Use the GPU only when one is available AND opt.cuda asks for it.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available and opt.cuda else "cpu")

if cuda_available and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")


def visualize(image,detection_result) -> np.ndarray:
    """Draw bounding boxes, keypoints and labels on a copy of the input image.

    Args:
        image: The input RGB image (array with shape height x width x channels).
        detection_result: Object whose ``detections`` attribute lists the
            "Detection" entities to visualize.

    Returns:
        An annotated copy of ``image``; the input array is not modified.
    """
    annotated_image = image.copy()
    height, width, _ = image.shape
    keypoint_color, keypoint_thickness, keypoint_radius = (0, 255, 0), 2, 2

    for detection in detection_result.detections:
        # Draw bounding_box
        bbox = detection.bounding_box
        start_point = bbox.origin_x, bbox.origin_y
        end_point = bbox.origin_x + bbox.width, bbox.origin_y + bbox.height
        cv2.rectangle(annotated_image, start_point, end_point, TEXT_COLOR, 3)

        # Draw keypoints
        for keypoint in detection.keypoints:
            keypoint_px = _normalized_to_pixel_coordinates(keypoint.x, keypoint.y,
                                                           width, height)
            if keypoint_px is None:
                # No pixel position is available for this keypoint;
                # cv2.circle cannot take None as a centre.
                continue
            # cv2.circle takes (img, center, radius, color, thickness).
            cv2.circle(annotated_image, keypoint_px, keypoint_radius,
                       keypoint_color, keypoint_thickness)

        # Draw label and score once per detection (this used to sit inside the
        # keypoint loop, so it was redrawn for every keypoint and skipped
        # entirely when a detection had none).
        if not detection.categories:
            continue
        category = detection.categories[0]
        category_name = category.category_name
        category_name = '' if category_name is None else category_name
        probability = round(category.score, 2)
        result_text = category_name + ' (' + str(probability) + ')'
        text_location = (MARGIN + bbox.origin_x,
                         MARGIN + ROW_SIZE + bbox.origin_y)
        cv2.putText(annotated_image, result_text, text_location, cv2.FONT_HERSHEY_PLAIN,
                    FONT_SIZE, TEXT_COLOR, FONT_THICKNESS)

    return annotated_image

def crop_it(image, detection_result, pad_x=50, pad_top=100, pad_bottom=50):
    """Crop a copy of ``image`` around the last detection's bounding box.

    Each margin is all-or-nothing: it is added only when the enlarged box
    still lies strictly inside the image on that side; otherwise that side
    keeps the plain bounding-box edge.

    Args:
        image: array with shape height x width x channels.
        detection_result: object whose ``detections`` items expose a
            ``bounding_box`` with ``origin_x``, ``origin_y``, ``width`` and
            ``height``.
        pad_x: margin added on the left and on the right.
        pad_top: margin added above the box.
        pad_bottom: margin added below the box.

    Returns:
        The cropped region (a view into a copy of ``image``). When there are
        no detections the whole copy is returned instead of failing on an
        unbound ``bbox``.
    """
    annotated_image = image.copy()
    height, width, _ = image.shape

    # As before, the box of the LAST detection decides the crop.
    bbox = None
    for detection in detection_result.detections:
        bbox = detection.bounding_box
    if bbox is None:
        return annotated_image

    left, right = bbox.origin_x, bbox.origin_x + bbox.width
    top, bottom = bbox.origin_y, bbox.origin_y + bbox.height

    if left - pad_x > 0:
        left -= pad_x
    if right + pad_x < width:
        right += pad_x
    if top - pad_top > 0:
        top -= pad_top
    if bottom + pad_bottom < height:
        bottom += pad_bottom

    # A box that starts outside the frame has negative coordinates, which
    # numpy would read as "count from the end"; clamp to the image bounds.
    left, right = min(max(left, 0), width), min(max(right, 0), width)
    top, bottom = min(max(top, 0), height), min(max(bottom, 0), height)
    return annotated_image[top:bottom, left:right]


# Drawing constants read by visualize().
# MARGIN and ROW_SIZE offset the label text from the bounding-box origin:
# x = MARGIN + origin_x, y = MARGIN + ROW_SIZE + origin_y.
MARGIN = 10  # pixels
ROW_SIZE = 10  # pixels
# Passed to cv2.putText as the font scale and the stroke thickness.
FONT_SIZE = 1
FONT_THICKNESS = 1
# Used for both the bounding-box rectangle and the label text.
TEXT_COLOR = (255, 0, 0)  # red (assuming the image is in RGB channel order)

def _normalized_to_pixel_coordinates(
    normalized_x: float, normalized_y: float, image_width: int,
    image_height: int) -> Union[None, Tuple[int, int]]:
    """Converts normalized value pair to pixel coordinates.

    Args:
        normalized_x: horizontal position as a fraction of the image width.
        normalized_y: vertical position as a fraction of the image height.
        image_width: image width in pixels.
        image_height: image height in pixels.

    Returns:
        ``(x_px, y_px)`` with each value clamped to the last valid pixel
        index, or ``None`` when either input is not a valid normalized value.
    """
    if not (is_valid_normalized_value(normalized_x) and
            is_valid_normalized_value(normalized_y)):
        # TODO: Draw coordinates even if it's outside of the image bounds.
        return None

    # A value of exactly 1.0 would land one past the last pixel, hence min().
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px


def is_valid_normalized_value(value: float) -> bool:
    """Checks if the float value is between 0 and 1.

    Both bounds are inclusive; values that ``math.isclose`` (default
    tolerances) considers equal to a bound are accepted as well.
    """
    return (value > 0 or math.isclose(0, value)) and (value < 1 or math.isclose(1, value))

# Output: Namespace(checkpoints_dir='checkpoints', results_dir='results', geom_name='feats2Geom', batchSize=1, dataroot='examples/test', depthroot='', input_nc=3, output_nc=1, geom_nc=3, every_feat=1, num_classes=55, midas=0, ngf=64, n_blocks=3, size=1080, cuda=False, n_cpu=8, which_epoch='latest', aspect_ratio=1.0, mode='test', load_size=1080, crop_size=1080, max_dataset_size=inf, preprocess='resize_and_crop', no_flip=False, norm='instance', predict_depth=0, save_input=0, reconstruct=0, how_many=100, name='opensketch_style')


# In [17]:
# Live demo: grab webcam frames, run the sketch generator and the MediaPipe
# image segmenter on each one, and show a composite in which the mask selects
# between the generated sketch and the camera frame.
# Keys: 'q' quits, 'w' swaps which layer shows the sketch.
import traceback

BG_COLOR = (192, 192, 192) # gray
MASK_COLOR = (255, 255, 255) # white

# Create the options that will be used for ImageSegmenter
base_options = python.BaseOptions(model_asset_path='deeplab_v3.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,output_category_mask=True)
# Even: sketch where the mask is set, camera elsewhere. Odd: the reverse.
switch_it = 0
with torch.no_grad():
    # Networks
    net_G = Generator(opt.input_nc, opt.output_nc, opt.n_blocks).to(device)
    # Load state dicts
    net_G.load_state_dict(torch.load(os.path.join(opt.checkpoints_dir, opt.name, f'netG_A_{opt.which_epoch}.pth'), map_location=device, weights_only=True))
    # Set model's test mode
    net_G.eval()
    transforms_r = [transforms.Resize(int(opt.size), Image.BICUBIC),
                    transforms.ToTensor()]
    the_transformation = transforms.Compose(transforms_r)
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        cap = cv2.VideoCapture(0)
        try:
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)
            # Only needs to be called once, not on every frame.
            cv2.startWindowThread()
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                try:
                    # Force a fixed frame size regardless of what the camera delivers.
                    half = cv2.resize(frame, (1920, 1080))
                    # OpenCV frames are BGR; MediaPipe and the generator get RGB.
                    frame_rgb = cv2.cvtColor(half, cv2.COLOR_BGR2RGB)
                    pil_image = Image.fromarray(frame_rgb)
                    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_image))

                    segmentation_result = segmenter.segment(mp_image)
                    category_mask = segmentation_result.category_mask

                    # The tensor is passed without a batch dimension.
                    real_A = the_transformation(pil_image).to(device)
                    new_image = net_G(real_A)
                    image_np = new_image.cpu().detach().numpy()  # Move to CPU and convert to NumPy

                    # Transpose from (C, H, W) to (H, W, C)
                    image_np = np.transpose(image_np, (1, 2, 0))

                    # Scale from [0, 1] to [0, 255] and convert to uint8
                    image_np = (image_np * 255).astype(np.uint8)

                    # Build the two layers; assigning into a 3-channel buffer
                    # broadcasts the generator's single output channel.
                    image_data = mp_image.numpy_view()
                    fg_image = np.zeros(image_data.shape, dtype=np.uint8)
                    bg_image = np.zeros(image_data.shape, dtype=np.uint8)
                    if switch_it%2==0:
                        fg_image[:] = image_np
                        bg_image[:] = half
                    else:
                        fg_image[:] = half
                        bg_image[:] = image_np

                    condition = np.stack((category_mask.numpy_view(),) * 3, axis=-1) > 0.2
                    output_image = np.where(condition, fg_image, bg_image)
                    cv2.imshow('segmented', output_image)
                    cv2.imshow('true', half)
                    pressedKey = cv2.waitKey(1) & 0xFF
                    if pressedKey == ord('q'):
                        break
                    elif pressedKey == ord('w'):
                        switch_it+=1
                except KeyboardInterrupt:
                    # Interrupting the kernel still stops the loop quietly.
                    break
                except Exception:
                    # Still stop on the first failure, but report it instead of
                    # swallowing it the way the old bare `except:` did.
                    traceback.print_exc()
                    break
        finally:
            # Always close the windows and release the camera, even when
            # something above raised.
            cv2.destroyAllWindows()
            for i in range(5):
                cv2.waitKey(1)
            cap.release()