In [1]:
import cv2
import sys, os
import argparse
import numpy as np
import torch
from pathlib import Path
from matplotlib import pyplot as plt
from typing import Any, Dict, List
import pandas as pd

from sam_segment import predict_masks_with_sam
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
    show_mask, show_points, get_clicked_point

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
img_folder_path =f'../wim_data/final_dataset/train/images/'
txt_folder_path = f'../wim_data/final_dataset/train/labels/'

In [3]:
file_names = [os.path.splitext(f)[0] for f in os.listdir(img_folder_path) if f.endswith('.jpg') or f.endswith('.png')]
print((file_names))

['2023-04-24_45.mp4#t=280', '20231031_19a155ec-253b-4ee5-8ec0-7cbb7ff842a9', '20231031_cad5af4d-c1e7-4a72-872c-19a07299300b', '2023-04-21_48.mp4#t=46', '20231031_fe33d25a-c856-4675-9739-3254e0020fb8', '20231031_4cd5143e-27d2-407d-910a-cce658c7757f', '2023-04-24_46.mp4#t=36', '20231031_60afff24-4415-43ae-8a39-2f39b21df4b2', '20231031_846eaf04-2b22-4a49-9feb-66bc1b75ff71', '20231031_e0adb3be-4057-4478-8f75-2bccfda92fed', '20231031_58b7eb28-620c-4006-8722-035db2cc6fba', '20231031_a64de3fd-fa98-4db5-9fcd-373284cf4462', '20231031_eca4edaa-0749-41d9-8d56-0dcac6c325d9', '20231031_752f3809-0117-4997-bdfe-71b01cdb5ebb', '20231031_ef04491c-3334-4dd6-acbd-ed65a2ef0479', '2023-04-21_61.mp4#t=40', '2023-04-24_63.mp4#t=156', '2023-04-21_48.mp4#t=102', '2023-04-24_57.mp4#t=72', '20231031_cfa401b5-4a6a-4bcb-b625-4eb0049ec261', '20231031_35e3524d-47c4-4e1b-922e-39e9a3bb548a', '20231031_798f0698-7819-4dcb-a614-9bcc30981551', '2023-04-24_45.mp4#t=120', '20231031_7e5d73b1-4bf2-42da-a715-a1a8de6efe01', '20

In [4]:
def normalized_coordinate_to_absolute(norm_x, norm_y, image_width=1280, image_height=720):
    abs_x = int(norm_x * image_width)
    abs_y = int(norm_y * image_height)
    return [abs_x, abs_y]

In [5]:
def extract_coordinates(txt_folder_path, img_folder_path):
    # 이미지 폴더에서 모든 파일명을 가져옴
    file_names = [os.path.splitext(f)[0] for f in os.listdir(img_folder_path) if f.endswith('.jpg') or f.endswith('.png')]

    # 데이터를 저장할 리스트 초기화
    data = []

    for file_name in file_names:
        txt_file = os.path.join(txt_folder_path, file_name + '.txt')

        # 해당 .txt 파일이 존재하는지 확인
        if os.path.exists(txt_file):
            with open(txt_file, 'r') as file:
                lines = file.readlines()
                for line in lines:
                    try:
                        class_id, x_center, y_center, width, height = line.strip().split()
                        data.append({
                            "file_id": file_name,
                            "x_center": float(x_center),
                            "y_center": float(y_center),
                            "width": float(width),
                            "height": float(height),
                            "label": int(class_id)
                        })
                    except ValueError:
                        print(f"Line parsing error in file {file_name}: {line}")
        else:
            print(f"No annotation for image {file_name}")

    # 데이터프레임 생성
    return pd.DataFrame(data)

In [6]:
df = extract_coordinates(txt_folder_path, img_folder_path)
df

Unnamed: 0,file_id,x_center,y_center,width,height,label
0,2023-04-24_45.mp4#t=280,0.462414,0.574525,0.167586,0.250000,23
1,2023-04-24_45.mp4#t=280,0.620000,0.470358,0.124138,0.127451,1
2,2023-04-24_45.mp4#t=280,0.625862,0.374157,0.084828,0.121324,23
3,2023-04-24_45.mp4#t=280,0.646897,0.282246,0.053793,0.111520,25
4,2023-04-24_45.mp4#t=280,0.518276,0.311045,0.080690,0.147059,12
...,...,...,...,...,...,...
17231,2023-04-24_64.mp4#t=140,0.637330,0.193738,0.101226,0.210913,8
17232,2023-04-24_64.mp4#t=140,0.604080,0.061949,0.155527,0.117286,0
17233,2023-04-24_64.mp4#t=140,0.465657,0.217663,0.124204,0.243011,13
17234,2023-04-24_64.mp4#t=140,0.358363,0.383565,0.097756,0.130002,3


In [11]:
input_img = '../wim_data/train/images/2023-04-21_48.mp4#t=0.jpg'
point_labels = [1]
sam_model_type = "vit_h"
sam_ckpt = './pretrained_models/sam_vit_h_4b8939.pth'

In [7]:
len(df)

17236

In [6]:
import tqdm

In [19]:
df1 = df.iloc[[0]].copy()
df1

Unnamed: 0,file_id,x_center,y_center,width,height,label
0,2023-04-24_45.mp4#t=280,0.462414,0.574525,0.167586,0.25,23


In [20]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [23]:
ob_bd_path = '../wim_data/objects/images_bounding/'
ob_path = '../wim_data/objects/images_object/'
mask_object = '../wim_data/objects/masks_object/'
mask_full_path = '../wim_data/objects/masks_full/'

for index, row in df1.iterrows():
    input_img = img_folder_path + row.file_id + '.jpg'
    img = load_img_to_array(input_img)
    x_center = row.x_center
    y_center = row.y_center
    latest_coords = normalized_coordinate_to_absolute(x_center, y_center)
    
    masks, _, _ = predict_masks_with_sam(
        img,
        [latest_coords],
        point_labels,
        model_type=sam_model_type,
        ckpt_p=sam_ckpt,
        device=device,
    )
    masks = masks.astype(np.uint8)* 255
    mask_full = masks[2]
    
    extracted_object = cv2.bitwise_and(img, img, mask=mask_full)
    mask_full_cropped = cv2.bitwise_and(mask_full, mask_full)
    
    x, y, w, h = cv2.boundingRect(mask_full)
    cropped_image = extracted_object[y:y+h, x:x+w]
    cropped_mask = mask_full_cropped[y:y+h, x:x+w]
    
    common_name = f"{row.file_id}_{row.label}_{row.x_center}_{row.y_center}.jpg"
    
    cropped_object = cv2.bitwise_and(cropped_image, cropped_image, mask=cropped_mask)
    
    cropped_bd_object = cv2.cvtColor(extracted_object[y:y+h, x:x+w], cv2.COLOR_BGR2RGB)
    
    # 배경을 투명하게 만들기 위해 알파 채널 추가
    b_channel, g_channel, r_channel = cv2.split(cropped_object)
    alpha_channel = np.where(cropped_mask==255, 255, 0).astype(np.uint8)  # 마스크에 따라 알파 채널 설정
    rgba_image = cv2.merge((b_channel, g_channel, r_channel, alpha_channel))
    
    save_array_to_img(cropped_mask, mask_object+common_name)
    save_array_to_img(mask_full, mask_full_path+common_name)
    save_array_to_img(rgba_image, ob_path+common_name)
    save_array_to_img(cropped_bd_object, ob_bd_path+common_name)

In [None]:


masks, _, _ = predict_masks_with_sam(
        img,
        [latest_coords],
        point_labels,
        model_type=sam_model_type,
        ckpt_p=sam_ckpt,
        device=device,
    )