# In[5]:
import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import image as mpimg

# Build "foreground" frames for video15: multiply each original frame by its
# segmentation mask so that only the masked (foreground) pixels survive, and
# write the result as result_<frame-stem>.png into output_path.
# Masks and frames are paired purely by sorted filename order, so both
# directories are assumed to use matching naming schemes — TODO confirm.

segment_path = '/mnt/DATA/dronetracking/final_composition/foreground/video15/original_sized_mask'
img_path = "/mnt/DATA/dronetracking/Anti-UAV-Tracking-V0/video15"
output_path = "/mnt/DATA/dronetracking/final_composition/foreground/video15/foreground_image"


def _drop_alpha(arr):
    """Return ``arr`` without its alpha channel if it is an (H, W, 4) array."""
    if arr.ndim == 3 and arr.shape[2] == 4:
        return arr[:, :, :3]
    return arr


def apply_mask(segment, image):
    """Return ``image`` with every pixel outside the mask ``segment`` zeroed.

    Parameters
    ----------
    segment : numpy.ndarray
        Mask array of shape (H, W), (H, W, 3) or (H, W, 4); an alpha channel
        is discarded. Values are rounded to integers with ``np.around`` (note
        that NumPy rounds exactly 0.5 to 0), so a float mask in [0, 1] becomes
        a 0/1 mask.
    image : numpy.ndarray
        Frame of shape (H, W) or (H, W, C); an alpha channel is discarded.

    Returns
    -------
    numpy.ndarray
        ``mask * image``. The dtype follows NumPy promotion of uint8 with
        ``image.dtype`` (uint8 for uint8 frames).

    Raises
    ------
    ValueError
        If the mask and the image differ in height/width.
    """
    segment = _drop_alpha(segment)
    image = _drop_alpha(image)
    rounded_segment = np.around(segment).astype(np.uint8)
    # A single-channel (H, W) mask cannot broadcast against (H, W, C): NumPy
    # would align W with C. Add a trailing channel axis so it applies to
    # every channel of the frame.
    if rounded_segment.ndim == 2 and image.ndim == 3:
        rounded_segment = rounded_segment[:, :, np.newaxis]
    if rounded_segment.shape[:2] != image.shape[:2]:
        raise ValueError(
            f"Mask shape {segment.shape} does not match image shape {image.shape}."
        )
    return rounded_segment * image


# __name__ is "__main__" both in a Jupyter/IPython kernel and when run as a
# script, so the loop still runs there (and its variables stay global); it is
# skipped only when this file is imported to reuse apply_mask.
if __name__ == "__main__":
    seg_file_list = sorted(f for f in os.listdir(segment_path) if f.endswith('.png'))
    img_file_list = sorted(f for f in os.listdir(img_path) if f.endswith('.jpg'))

    # Check if the number of segment images matches the number of images
    # (before creating the output directory, so a bad input leaves no trace).
    if len(seg_file_list) != len(img_file_list):
        raise ValueError("The number of segment images does not match the number of original images.")

    # Ensure the output path exists
    os.makedirs(output_path, exist_ok=True)

    for seg_file, img_file in zip(seg_file_list, img_file_list):
        segment = mpimg.imread(os.path.join(segment_path, seg_file))
        image = mpimg.imread(os.path.join(img_path, img_file))

        sgm = apply_mask(segment, image)

        # Save the result. cmap/vmin/vmax only affect 2-D (grayscale) results:
        # without them plt.imsave would false-colour the frame with the default
        # colormap and autoscale each frame independently. They are ignored
        # for RGB(A) arrays.
        output_file = os.path.join(output_path, f"result_{os.path.splitext(img_file)[0]}.png")
        vmax = 255 if np.issubdtype(sgm.dtype, np.integer) else 1.0
        plt.imsave(output_file, sgm, cmap="gray", vmin=0, vmax=vmax)

