In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
from collections import defaultdict

import cv2
import imutils
import numpy as np
import pandas as pd
from tqdm import tqdm
from ultralytics import YOLO
from denku import show_images, show_image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [2]:
def draw_mask(image, mask, alpha=0.3):
    """Blend ``mask`` onto a copy of ``image`` and return the result.

    The blend is a weighted sum computed by ``cv2.addWeighted``: ``image``
    contributes with weight 1 and ``mask`` with weight ``alpha`` (no constant
    offset). The input ``image`` is not modified.

    Args:
        image: Base image array.
        mask: Overlay array; ``cv2.addWeighted`` expects it to have the same
            size and number of channels as ``image``.
        alpha: Weight applied to ``mask``.

    Returns:
        The blended image as a new array.
    """
    canvas = image.copy()
    blended = cv2.addWeighted(canvas, 1, mask, alpha, 0)
    return blended

In [3]:
# Dataset location. The ROSATOM_DATASET_ROOT environment variable, when set,
# overrides the directory so the notebook can run on another machine; when it
# is unset the original hard-coded path is used.
DATASET_ROOT = os.environ.get(
    'ROSATOM_DATASET_ROOT',
    '/home/raid_storage/datasets/rosatom',
)
# Annotation table that lives directly under the dataset root.
CSV_PATH = os.path.join(DATASET_ROOT, 'filtered_dataset.csv')

In [4]:
# Read the annotation CSV (its first column becomes the index) and keep only
# the rows whose `stage` column equals 'test'.
df = (
    pd.read_csv(CSV_PATH, index_col=0)
    .loc[lambda frame: frame['stage'] == 'test']
)
# Bare expression so the notebook renders the filtered frame.
df

Unnamed: 0,filename,x,y,class,stage
7,FRAMES/0/1538/frame0009.bmp,576,313,3,test
8,FRAMES/0/1538/frame0009.bmp,724,509,3,test
13,FRAMES/0/1538/frame0012.bmp,279,326,3,test
14,FRAMES/0/1538/frame0012.bmp,432,503,3,test
16,FRAMES/0/1538/frame0013.bmp,242,331,3,test
...,...,...,...,...,...
33463,FRAMES/2023.10.25/5_498.bmp,493,364,1,test
33536,FRAMES/2023.10.25/5_809.bmp,267,371,8,test
33563,FRAMES/2023.10.25/5_882.bmp,254,95,1,test
33564,FRAMES/2023.10.25/5_882.bmp,250,255,1,test


In [5]:
# Segment Anything checkpoint file and the registry key used to build it.
sam_checkpoint = './sam_vit_h_4b8939.pth'
model_type = 'vit_h'

# Look up the model builder in the registry, load the checkpoint and move the
# model to 'cuda:0'. CUDA_VISIBLE_DEVICES is set to '2' in the import cell, so
# 'cuda:0' here is the single GPU exposed to this process.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device='cuda:0')

# Prompt-based predictor wrapping the loaded model.
predictor = SamPredictor(sam)

In [9]:
# Path to the detector weights (`best.pt` under a `train7` run directory).
# NOTE(review): absolute, user-specific path — confirm it exists on the machine running this.
model_name = '/home/raid_storage/isakov/hacks/notebooks/runs/detect/train7/weights/best.pt'
# Load the weights through the ultralytics YOLO wrapper; used by the inference loop.
model = YOLO(model_name)

In [14]:
# When True, predicted points and class ids are drawn onto `predict_img`.
# The call that actually displays the image sits in the commented-out section
# further down in this cell.
visualize = False
# Detections whose (rounded) confidence is below this value are discarded.
CONF_THRESHOLD = 0.25

# One row per kept detection: filename, box-centre x/y and class id.
out = defaultdict(list)
# Only filled by the SAM/contour post-processing that is currently commented
# out further down in this cell, so it stays empty in this version.
c_out = defaultdict(list)

u_images = df['filename'].unique()


for u_image in tqdm(u_images[:]):
    image_path = os.path.join(DATASET_ROOT, u_image)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports an unreadable/missing file by returning None
        # rather than raising; fail here with the offending path instead of
        # with an opaque error inside cvtColor.
        raise FileNotFoundError(f'Could not read image: {image_path}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predict_img = image.copy()

    # The detector is given the file path, not the RGB array loaded above.
    results = model.predict(image_path, verbose=False, device='cuda:0')
    result = results[0]

    # `masks` and `predicts` are consumed only by the commented-out
    # SAM/contour section below; they are kept so that section still works
    # if it is re-enabled.
    img_h, img_w = result.orig_shape
    masks = np.zeros((img_h, img_w), dtype=np.uint8)

    predicts = defaultdict(list)
    for box in result.boxes:
        # Filter first so rejected boxes cost nothing further. The confidence
        # is rounded to two decimals before the comparison, as before.
        conf = round(box.conf[0].item(), 2)
        if conf < CONF_THRESHOLD:
            continue

        # Centre of the box, computed from integer-truncated corners.
        x1, y1, x2, y2 = [int(v) for v in box.xyxy[0].tolist()]
        x = int(x1 + (x2 - x1) / 2)
        y = int(y1 + (y2 - y1) / 2)

        # Model class index shifted by one.
        # NOTE(review): presumably to match 1-based class ids in the CSV — confirm.
        class_id = int(box.cls[0].item()) + 1
        predicts[class_id].append([x, y])

        out['filename'].append(u_image)
        out['x'].append(x)
        out['y'].append(y)
        out['class'].append(class_id)

        ## visualizing predict points ==========
        if visualize:
            predict_img = cv2.circle(predict_img, (x, y), radius=5, color=(255, 0, 255), thickness=-1)
            predict_img = cv2.putText(predict_img, str(class_id), (x - 5, y - 5), cv2.FONT_HERSHEY_COMPLEX, 0.9, (128, 0, 0), 2)
        ## =====================================
        
#     for label, points in predicts.items():
#         predictor.set_image(image)
#         mask, _, _ = predictor.predict(
#             point_coords=np.array(points),
#             point_labels=[label] * len(points),
#             box=None, 
#             multimask_output=True,
#         )
#         masks[mask[0]] = class_id #* 18
    
#     u_labels = np.unique(masks)
#     for u_label in u_labels:
#         if u_label == 0:
#             continue
#         mask = (masks == u_label).astype(np.uint8) * 255
#         cnts = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
#         cnts = imutils.grab_contours(cnts)
        
#         if len(cnts):
#             for c in cnts:
#                 M = cv2.moments(c)
#                 if M["m00"] == 0:
#                     for box in result.boxes:
#                         x1, y1, x2, y2 = [int(x) for x in box.xyxy[0].tolist()]
#                         x = int(x1 + (x2 - x1) / 2)
#                         y = int(y1 + (y2 - y1) / 2)

#                         conf = round(box.conf[0].item(), 2)
#                         if conf < 0.25:
#                             continue

#                         class_id = int(box.cls[0].item()) + 1
#                         predicts[class_id].append([x, y])

#                         c_out['filename'].append(u_image)
#                         c_out['x'].append(x)
#                         c_out['y'].append(y)
#                         c_out['class'].append(class_id)  
#                     continue
#                 cX = int(M["m10"] / M["m00"])
#                 cY = int(M["m01"] / M["m00"])

#                 c_out['filename'].append(u_image)
#                 c_out['class'].append(u_label)
#                 c_out['x'].append(cX)
#                 c_out['y'].append(cY)

#                 if visualize:
#                     predict_img = cv2.drawContours(predict_img, [c], -1, (0, 128, 255), 2)
#                     predict_img = cv2.circle(predict_img, (cX, cY), 7, (32, 32, 255), -1)
#         else:
#             for box in result.boxes:
#                 x1, y1, x2, y2 = [int(x) for x in box.xyxy[0].tolist()]
#                 x = int(x1 + (x2 - x1) / 2)
#                 y = int(y1 + (y2 - y1) / 2)

#                 conf = round(box.conf[0].item(), 2)
#                 if conf < 0.25:
#                     continue

#                 class_id = int(box.cls[0].item()) + 1
#                 predicts[class_id].append([x, y])

#                 c_out['filename'].append(u_image)
#                 c_out['x'].append(x)
#                 c_out['y'].append(y)
#                 c_out['class'].append(class_id)  
    
#     #### VISUALIZE GROUND TRUE POINTS
#     if visualize:
#         for i, row in df[df['filename'] == u_image].iterrows():
#             predict_img = cv2.circle(predict_img, (row.x, row.y), radius=5, color=(255, 0, 255), thickness=-1)
#             label = str(row['class'])
#             predict_img = cv2.putText(predict_img, label, (row.x - 5, row.y - 5), cv2.FONT_HERSHEY_COMPLEX, 0.7, (0, 255, 0), 2)
#     #### =============================
    
#     if visualize:
#         mask = np.stack([np.zeros_like(masks), np.zeros_like(masks), masks]).transpose(1, 2, 0).astype(np.uint8) * 100
#         predict_img = draw_mask(predict_img, mask, 0.8)
#         show_images([predict_img], figsize=(15, 15))
#     break

 30%|██▉       | 269/908 [00:06<00:14, 43.25it/s]

KeyboardInterrupt



In [13]:
# Build a table from the collected `c_out` columns and write it to CSV in the
# current working directory (the DataFrame index is written as well).
# NOTE(review): in the inference cell every line that appends to `c_out` is
# commented out, so with that version of the loop this table is empty —
# confirm whether `out` was intended here.
result_df = pd.DataFrame(c_out)
result_df.to_csv('pr_segments_max_model.csv')