In [1]:
import torchvision
import torchvision.transforms as transforms
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re
import csv
import os
import operator
from PIL import Image
from PIL import ImageDraw

In [None]:
class HandPoseDataset(T.utils.data.Dataset):
    """Dataset of images paired with pose values read from a CSV file.

    The first CSV column holds an image filename, resolved relative to
    ``root_path``; every remaining column is a pose value for that image.

    Args:
        csv_path: path to the annotation CSV (loaded once with ``pd.read_csv``).
        root_path: directory that the CSV filenames are resolved against.
        transform: optional callable applied to the opened PIL image.
    """

    def __init__(self, csv_path, root_path, transform = None):
        self.root_path = root_path
        self.csv_data = pd.read_csv(csv_path)
        self.transform = transform

    def __len__(self):
        return len(self.csv_data)

    def _image_path(self, i):
        """Return the full path of the image referenced by CSV row ``i``."""
        # str(): pandas may parse purely numeric filenames as integers.
        return os.path.join(self.root_path, str(self.csv_data.iloc[i, 0]))

    def _pose(self, i):
        """Return the pose columns of CSV row ``i`` as a flat float32 array."""
        # iloc[i, 1:] selects the whole row first, so with a string filename
        # column the Series can be object dtype; an explicit numeric dtype keeps
        # the result collatable by torch's DataLoader.
        return np.asarray(self.csv_data.iloc[i, 1:], dtype=np.float32)

    def __getitem__(self, i):
        """Return ``(image, pose)`` for row ``i``; ``image`` is transformed if a transform was given."""
        image = Image.open(self._image_path(i))
        # Decode eagerly instead of keeping a lazily-opened file handle around.
        image.load()
        if self.transform:
            image = self.transform(image)
        return image, self._pose(i)

In [None]:
class RandomImageDataset(T.utils.data.Dataset):
    """Dataset of unannotated images taken from a single directory.

    Each item is paired with an all-zero pose: a list of ``num_keypoints``
    ``[0, 0]`` pairs.

    Args:
        root_path: directory containing the image files.
        transform: optional callable applied to the opened PIL image.
        num_keypoints: number of ``[0, 0]`` pairs in the returned pose (default 11).
    """

    def __init__(self, root_path, transform = None, num_keypoints = 11):
        self.root_path = root_path
        self.transform = transform
        self.num_keypoints = num_keypoints
        # Snapshot the file list once, sorted for a stable index -> file mapping
        # (os.listdir order is arbitrary); subdirectories are skipped because
        # they cannot be opened as images.
        self.file_names = sorted(
            name for name in os.listdir(root_path)
            if os.path.isfile(os.path.join(root_path, name))
        )

    def __len__(self):
        return len(self.file_names)

    def _image_path(self, i):
        """Return the full path of the ``i``-th image file under ``root_path``."""
        return os.path.join(self.root_path, self.file_names[i])

    def __getitem__(self, i):
        """Return ``(image, pose)`` where pose is ``num_keypoints`` ``[0, 0]`` pairs."""
        img = Image.open(self._image_path(i))
        # Decode eagerly instead of keeping a lazily-opened file handle around.
        img.load()
        if self.transform:
            img = self.transform(img)
        return img, [[0, 0] for _ in range(self.num_keypoints)]

In [None]:
class MergedDataset(T.utils.data.Dataset):
    """Pairs two items drawn from the union of a hand-pose and a random-image dataset.

    Both source datasets are treated as one concatenated index space
    (hand-pose items first, then random images). Two independent permutations
    of that space, seeded by ``random_seed_1`` and ``random_seed_2``, select the
    "left" and "right" item for each index.

    Args:
        hand_pose_datsset: dataset whose items unpack as ``(image, pose)``.
        random_image_dataset: dataset whose items unpack as ``(image, pose)``.
        random_seed_1: seed for the left permutation.
        random_seed_2: seed for the right permutation.
        transform: stored on the instance; not used by ``__getitem__`` yet.
    """

    def __init__(self, hand_pose_datsset: T.utils.data.Dataset, random_image_dataset: T.utils.data.Dataset, random_seed_1, random_seed_2, transform = None):
        self.hand_pose_datsset = hand_pose_datsset
        self.random_image_dataset = random_image_dataset
        total = len(hand_pose_datsset) + len(random_image_dataset)
        self.left_indices = list(range(total))
        self.right_indices = list(range(total))
        self.transform = transform
        # Local RandomState objects yield the same permutations as
        # np.random.seed(...) + np.random.shuffle(...) but leave NumPy's global
        # RNG state untouched.
        np.random.RandomState(random_seed_1).shuffle(self.left_indices)
        np.random.RandomState(random_seed_2).shuffle(self.right_indices)

    def __len__(self):
        return len(self.left_indices)

    def _fetch(self, index):
        """Return the item at ``index`` of the concatenated (hand-pose, random-image) space."""
        n_hand = len(self.hand_pose_datsset)
        if index < n_hand:
            return self.hand_pose_datsset[index]
        return self.random_image_dataset[index - n_hand]

    def __getitem__(self, i):
        """Return ``(img, pose)`` for the i-th left/right pair; ``img`` is currently always None."""
        image_l, pose_l = self._fetch(self.left_indices[i])
        image_r, pose_r = self._fetch(self.right_indices[i])
        # TODO: combine image_l and image_r (e.g. with merge()); until that is
        # implemented no image is returned.
        img = None
        # NOTE(review): '+' adds element-wise for NumPy poses but concatenates
        # list poses -- confirm which combination is intended.
        pose = pose_l + pose_r
        return img, pose

In [None]:
def merge(img1, img2, shift):
    """Overlay ``img2`` on ``img1`` at offset ``shift`` and blend the overlap 50/50.

    ``img1`` is anchored at the top-left corner of a transparent RGBA canvas;
    ``img2`` is pasted with its top-left corner at ``shift``. The canvas is just
    large enough to hold both images when ``shift`` is non-negative. Two layers
    are built (one with each image on top) and averaged, so overlapping pixels
    become an equal mix of both images.

    Args:
        img1: PIL image placed at ``(0, 0)``.
        img2: PIL image placed at ``shift``.
        shift: ``(dx, dy)`` offset of ``img2`` relative to ``img1``.

    Returns:
        A new RGBA PIL image.
    """
    width1, height1 = img1.size
    width2, height2 = img2.size
    canvas_size = (max(width2 + shift[0], width1), max(height2 + shift[1], height1))

    def _layer(bottom, bottom_pos, top, top_pos):
        # Fresh transparent canvas with ``top`` pasted over ``bottom``.
        canvas = Image.new('RGBA', size=canvas_size, color=(0, 0, 0, 0))
        canvas.paste(bottom, bottom_pos)
        canvas.paste(top, top_pos)
        return canvas

    img1_on_top = _layer(img2, shift, img1, (0, 0))
    img2_on_top = _layer(img1, (0, 0), img2, shift)
    return Image.blend(img1_on_top, img2_on_top, alpha=0.5)