In [None]:
import os
import torch
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Display which device is available: 'cuda:0' when CUDA is usable, else 'cpu'.
# The result is only shown as the cell output; it is not assigned to a variable.
torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

for demo
- 필요한 파일: 소스(인물 사진), 장면 이미지, 포즈 센터 및 포즈 키포인트
- keypoint 결과는 `keypoints_path`에 지정한 파일(예: example_imgs/keypoints_6.txt)에 저장하기
- pose center 값을 "position" 변수에 저장


- 나머지 source/target 파일 경로 변경

1. 배경 이미지 - 즉, 장면 이미지

In [None]:
background_path = './example_imgs/friends_5.png'
# Pose centre as (x, y); it is drawn on the preview below and later passed to
# the pose-generation inference call.
position = (350, 550)

try:
    # Open inside a context manager so the file handle is released before the
    # same path is (possibly) rewritten further down.
    with Image.open(background_path) as scene:
        original_mode = scene.mode
        # convert() always returns a new, fully loaded RGB image, so `img`
        # stays valid after the file is closed. Every non-RGB mode (RGBA
        # included) is handled by this single call.
        img = scene.convert('RGB')
    w, h = img.size

    # Display the image
    plt.imshow(img)
    plt.axis('off')  # Turn off axis labels

    # Plot the pose centre on the image
    plt.scatter([position[0]], [position[1]], c='red', s=100)

    plt.show()

    # Rewrite the file only when a conversion actually happened; an image that
    # was already RGB is left untouched on disk instead of being re-encoded on
    # every run.
    corrected_image_path = background_path
    if original_mode != 'RGB':
        img.save(corrected_image_path)
        print(f"Corrected image saved at: {corrected_image_path}")

except Exception as e:
    # Best-effort preview: report the problem instead of aborting the cell.
    print(f"An error occurred: {e}")

2. Insert 할 인물 이미지

In [None]:
# Path of the person (source) image; reused later by the CFLD and face-swap
# commands with a '.' prefix.
source_path = './example_imgs/harrystyles.png'

# Preview the source image. Note: this rebinds `img`, which previously held
# the background image.
img = Image.open(source_path)
plt.imshow(img)
plt.axis('off') 
plt.show()

In [None]:
# Keypoint file path and result path.
# Both are written relative to the notebook's starting directory; later cells
# prepend '.' after changing into a sub-directory with %cd.
keypoints_path = './example_imgs/keypoints_6.txt'
save_path = './example_imgs/CFLD_result_6.png'

### Pose generation

In [None]:
# Change the working directory to pose-generation; the following cells use
# paths relative to it (presumably inference_demo_server lives here - confirm).
%cd pose-generation

In [None]:
from inference_demo_server import inference
# background_path starts with './'; prefixing '.' turns it into '../...' so it
# resolves from the sub-directory entered by the preceding %cd.
path = '.' + background_path
# 'generated_pose' is final output!
# The call returns five values; only generated_pose, pose_image and
# base_pose_image are used by later cells in this notebook.
generated_pose, base_pose, pose_image, base_pose_image, position_marked_image = inference(path, position)

In [None]:
# Preview `pose_image` returned by inference(), without axes.
plt.imshow(pose_image)
plt.axis('off')
plt.show()

In [None]:
# Preview `base_pose_image` returned by inference(), without axes.
plt.imshow(base_pose_image)
plt.axis('off')
plt.show()

### Keypoint 변경

Model Definition

In [None]:
import torch.nn as nn

In [None]:
class HeadEstimator(nn.Module):
    """Fully connected network mapping 8 input features to 10 outputs.

    Layout: Linear(8, 64) -> Linear(64, 512) -> Linear(512, 512) -> Linear(512, 10).
    The three hidden layers are each followed by BatchNorm1d and ReLU; dropout
    (p=0.2) is applied after the first two hidden layers only.

    The attribute names (fc1..fc4, bn1..bn3, dropout) are the state_dict keys
    and must stay stable so that saved checkpoints keep loading.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 64)
        self.fc2 = nn.Linear(64, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.2)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(512)

        # Re-initialise every linear layer (in registration order):
        # Kaiming-uniform weights, zero biases.
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            nn.init.kaiming_uniform_(layer.weight.data)
            nn.init.zeros_(layer.bias.data)

    def forward(self, x):
        """Run the network on `x`; the first Linear layer expects 8 features in the last dimension."""
        hidden = self.dropout(torch.relu(self.bn1(self.fc1(x))))
        hidden = self.dropout(torch.relu(self.bn2(self.fc2(hidden))))
        # No dropout after the third hidden layer.
        hidden = torch.relu(self.bn3(self.fc3(hidden)))
        return self.fc4(hidden)

In [None]:
model = HeadEstimator()
# The model is only used for inference here. nn.Module starts in training
# mode, where Dropout is active and BatchNorm1d uses per-batch statistics
# (and rejects a batch of a single sample), so switch to eval mode explicitly.
model.eval()
# NOTE(review): the working directory was changed to 'pose-generation' (hyphen)
# earlier, while this path uses 'pose_generation' (underscore) - confirm the
# checkpoint location.
state_dict = torch.load('./pose_generation/model.pth', map_location=torch.device('cpu'))
# Kept as the last expression so the notebook displays the key-matching report.
model.load_state_dict(state_dict)

In [None]:
import importlib

# 'pose-mapping' contains a hyphen, so it is not a valid identifier and
# `from pose-mapping.utils import ...` is a SyntaxError. importlib accepts the
# module name as a plain string, which works for such directory names.
_pose_mapping_utils = importlib.import_module('pose-mapping.utils')
convert_pose = _pose_mapping_utils.convert_pose
resize_pose = _pose_mapping_utils.resize_pose

In [None]:
# Inspect the pose returned by inference() before it is converted.
generated_pose

In [None]:
# Convert the generated pose with the HeadEstimator model on the CPU, then
# resize it. Both helpers come from pose-mapping.utils; their exact semantics
# are not visible in this notebook.
pose = convert_pose(model, 'cpu', generated_pose)
pose = resize_pose(pose) 

In [None]:
# Remove the imported helper from the notebook namespace.
# NOTE(review): after this cell runs, re-running the conversion cell above
# fails with NameError until the import cell is executed again - confirm the
# deletion is really needed.
del resize_pose

In [None]:
# Inspect the converted and resized pose.
pose

Save the keypoints to a file

In [None]:
# Split the pose into x / y lists, shifting each coordinate by 10 and rounding
# to 6 decimals.
# NOTE(review): the reason for the +10 offset is not visible here - confirm.
keypoints_x = [round(point[0]+10, 6) for point in pose]
keypoints_y = [round(point[1]+10, 6) for point in pose]

keypoints_y_str = str(keypoints_y)
keypoints_x_str = str(keypoints_x)

# Derive the output file from `keypoints_path` (defined with the other paths)
# instead of repeating the literal, so changing the path in one place is
# enough. The '.' prefix resolves it from the current sub-directory, matching
# how the other cells re-base paths after %cd.
with open('.' + keypoints_path, "w") as file:
    # The y coordinates go on the first line, the x coordinates on the second.
    file.write(keypoints_y_str + '\n')
    file.write(keypoints_x_str + '\n')

수정된 키포인트 확인

In [None]:
def plot_points(coordinates):
    """Scatter-plot 2D points on a fixed 0..1000 square with a y=x guide line.

    Args:
        coordinates: iterable of points; element [0] of each point is used as
            x and element [1] as y.
    """
    # Plot the argument that was passed in rather than notebook globals, so
    # the function shows whatever the caller hands over.
    points = list(coordinates)
    x_coords = [point[0] for point in points]
    y_coords = [point[1] for point in points]

    # Plotting the coordinates
    plt.figure(figsize=(10, 10))
    plt.scatter(x_coords, y_coords, c='blue', marker='o')
    plt.title('Scatter Plot of Given Coordinates with y=x Line')
    plt.xlabel('X')
    plt.ylabel('Y')

    # Set the aspect ratio of the plot to be equal
    plt.gca().set_aspect('equal', adjustable='box')

    # Same fixed range on both axes
    min_val, max_val = 0, 1000
    plt.xlim(min_val, max_val)
    plt.ylim(min_val, max_val)

    # Plot the y=x line
    plt.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--')

    # Display the plot
    plt.grid(True)
    plt.show()

# Show the keypoints exactly as they were written to the file (offset
# included), which is what the plot displayed before.
plot_points(list(zip(keypoints_x, keypoints_y)))

In [None]:
# Print the current working directory (sanity check before the next %cd).
!pwd

### CFLD

In [None]:
# Move from the current sub-directory into the sibling CFLD directory.
%cd ../CFLD

run model

In [None]:
# Re-base the './...' paths to '../...' so they resolve from inside CFLD.
source_path_an = '.' + source_path
keypoints_path_an = '.' + keypoints_path

In [None]:
!python clfd_app.py \
    --source_path $source_path_an \
        --keypoints_path $keypoints_path_an --save_path $'../example_imgs/CFLD_result_6.png'

### Face Swap

In [None]:
# Show whether CUDA is usable before starting the face-swap step.
torch.cuda.is_available()

In [None]:
# Move into the sibling sber-swap directory for the face-swap step.
%cd ../sber-swap

In [None]:
# Face-swap input: the image written by the CFLD step. The CFLD result path
# defined earlier is './example_imgs/CFLD_result_6.png', so the target must
# use the same file name (it previously pointed at CFLD_result_7.png, which
# this notebook never writes).
target_path = '../example_imgs/CFLD_result_6.png'
# Face-swap output. This rebinds `save_path`; the insert-image cells read it
# afterwards.
save_path = '../example_imgs/SWAP_result_7.png'

run model

In [None]:
import torch

# Abort early when no CUDA device is usable.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please check your CUDA setup.")

# NOTE(review): the index 4 is hard-coded and `device_index` is not passed to
# anything later in this notebook, so this check raises on any machine with
# fewer than five GPUs - confirm which device the swap script really needs.
device_index = 4
num_cuda_devices = torch.cuda.device_count()
if device_index >= num_cuda_devices:
    raise RuntimeError(f"Invalid CUDA device index: {device_index}. Available devices: {num_cuda_devices}")

In [None]:
# Run the face-swap script. {…} expressions are evaluated by IPython before
# the shell sees the command; source_path is re-based with '.' because the
# working directory is a sub-directory.
!python swap_app.py --target_path {target_path} --source_path {'.' + source_path} --save_path {save_path}

### Insert image

In [None]:
# Go back up one directory level so the './example_imgs/...' paths resolve again.
%cd ..

In [None]:
import torch
import cv2
import supervision as sv
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Load the SAM weights.
CHECKPOINT_PATH = 'SAM/weights/sam_vit_h_4b8939.pth'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"

# Move the model to DEVICE: it was computed above but the model was then
# pinned to 'cpu', leaving DEVICE unused. Set DEVICE = torch.device('cpu')
# above if the model must stay on the CPU (e.g. when GPU memory is scarce).
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
# Create the automatic mask generator; it runs on whatever device `sam` is on.
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
from PIL import Image
from rembg import remove
import numpy as np

def overlay_images(image_path1, image2_path):
    """Return the segmentation mask of the second-largest SAM region in an image.

    Despite its name, this function does not composite anything: it runs the
    module-level `mask_generator` on `image2_path` and returns one mask, which
    `overlay_images_2` then uses for pasting.

    Args:
        image_path1: unused; kept so existing calls keep working.
        image2_path: path of the image to segment.

    Returns:
        The 'segmentation' entry of the mask with the second-largest 'area'.
        NOTE(review): presumably the largest region is the background and the
        second-largest is the person - confirm on real outputs.

    Raises:
        FileNotFoundError: if `image2_path` cannot be read.
        IndexError: if fewer than two masks were produced.
    """
    image_bgr = cv2.imread(image2_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image2_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    sam_result = mask_generator.generate(image_rgb)

    masks_by_area = sorted(sam_result, key=lambda mask: mask['area'], reverse=True)
    if len(masks_by_area) < 2:
        raise IndexError(
            f"Expected at least two masks for {image2_path}, got {len(masks_by_area)}"
        )
    return masks_by_area[1]['segmentation']

def overlay_images_2(image_path1, image2_path, object_mask, position):
    """Paste one image onto another through a mask, centred on `position`.

    Args:
        image_path1: path of the base image (opened as RGBA).
        image2_path: path of the image to paste (opened as RGB).
        object_mask: array turned into an 'L' mode paste mask.
        position: point where the centre of the pasted image should land;
            elements [0] and [1] are used.

    Returns:
        The base image (RGBA) with the masked overlay pasted in.
    """
    canvas = Image.open(image_path1).convert("RGBA")
    cutout = Image.open(image2_path).convert('RGB')

    # paste() takes the upper-left corner, so shift the requested centre by
    # half the overlay size.
    cutout_w, cutout_h = cutout.size
    top_left = (position[0] - cutout_w // 2, position[1] - cutout_h // 2)

    paste_mask = Image.fromarray(object_mask).convert('L')
    canvas.paste(cutout, top_left, mask=paste_mask)
    return canvas

In [None]:
# save_path is '../example_imgs/...' at this point; dropping its first character
# yields './example_imgs/...', which resolves after the `%cd ..` above.
# Despite the name, IMG holds the mask returned by overlay_images, not an image.
IMG = overlay_images(background_path, save_path[1:])

In [None]:
# Overrides the `position` set with the background image; only the paste below uses it.
position = (1000, 650) # adjust here if the insert position needs changing

In [None]:
# Paste the face-swapped image onto the background through the mask, centred
# on `position`; the bare name on the last line displays the composite.
RESULT = overlay_images_2(background_path, save_path[1:], IMG, position)
RESULT