In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

# Record the runtime environment so results can be tied to a torch build / GPU setup.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA is available: {torch.cuda.is_available()}")
# One-time setup when running outside the repo (kept commented out):
# !git clone https://github.com/SysCV/sam-hq.git
# os.chdir('sam-hq')
# !export PYTHONPATH=$(pwd)
# NOTE: `!` commands run in a throwaway subshell, so the `export` above does not
# change the kernel's environment; use sys.path.insert(0, os.getcwd()) instead
# if the repo packages are not importable.
from fdog import run
from segment_anything import sam_model_registry, SamPredictor

In [None]:
from train.utils.dataloader import create_dataloaders, get_im_gt_name_dict
# ThinObject5K test split: .jpg images paired with .png ground-truth masks.
dataset_thin_val = get_im_gt_name_dict([
    dict(
        name="ThinObject5k",
        im_dir="./train/data/thin_object_detection/ThinObject5K/images_test",
        gt_dir="./train/data/thin_object_detection/ThinObject5K/masks_test",
        im_ext=".jpg",
        gt_ext=".png",
    )
])
# create_dataloaders returns a pair; only the first element is kept (indexed as
# dataLoader[0] in later cells, so presumably a list with one loader per dataset).
dataLoader, _ = create_dataloaders(name_im_gt_list=dataset_thin_val)


In [None]:
# Peek at the first validation batch to see which fields the loader yields.
batch = next(iter(dataLoader[0]), None)
if batch is not None:
    print(batch.keys())

In [None]:
# SAM-HQ ViT-B weights, expected one directory above the working directory.
sam_checkpoint = "../sam_hq_vit_b.pth"
model_type = "vit_b"  # registry key; presumably must match the checkpoint's backbone
# Fall back to CPU instead of crashing when no GPU is visible (the first cell
# prints whether CUDA is available).
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

In [None]:
import train.utils.misc as misc
def hsv_i(image, point_coords, point_labels, data_val):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
        hq_token_only=False,
    )
    # print(masks.shape, scores, logits.shape)
    mask_save = masks[0, :].astype(np.uint8)*255
    fname = data_val['ori_im_path'][0].split("\\")
    cv2.imwrite(f'../masks_hsv/{ fname[1] }', mask_save)

def ori_i(image, point_coords, point_labels, data_val):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
        hq_token_only=False,
    )
    # print(masks.shape, scores, logits.shape)
    mask_save = masks[0, :].astype(np.uint8)*255
    fname = data_val['ori_im_path'][0].split("\\")
    cv2.imwrite(f'../masks_ori/{ fname[1] }', mask_save)

def liner_trans(image, point_coords, point_labels, data_val):
    gamma = 0.8
    image = np.float32(image)
    label = cv2.imread(data_val['ori_gt_path'][0], cv2.IMREAD_COLOR)
    label = np.float32(label)
    image[label==0] = image[label==0] * gamma // 1
    image[image<0] = 0
    # image[image>255] = 255
    image = np.uint8(image)
    

    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
        hq_token_only=False,
    )
    # print(masks.shape, scores, logits.shape)
    mask_save = masks[0, :].astype(np.uint8)*255
    fname = data_val['ori_im_path'][0].split("\\")
    cv2.imwrite(f'../masks_lumin/{ fname[1] }', mask_save)

def fdog_i(image, point_coords, point_labels, data_val):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    frangi_edges = run(
                                img=image, sobel_size=5,
                                etf_iter=4, etf_size=7,
                                fdog_iter=2, sigma_c=1.0, rho=0.997, sigma_m=3.0,
                                tau=0.907
                            )
    image[frangi_edges[:,:]==0] = [0,0,0]

    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
        hq_token_only=False,
    )
    # print(masks.shape, scores, logits.shape)
    mask_save = masks[0, :].astype(np.uint8)*255
    fname = data_val['ori_im_path'][0].split("\\")
    cv2.imwrite(f'../masks_fdog/{ fname[1] }', mask_save)

# Seed torch's RNG so the randomly sampled prompt points are reproducible.
# NOTE(review): if misc.masks_sample_points draws from torch's global RNG
# (presumably — TODO confirm), the points for a given image also depend on how
# many images were sampled before it, so changing FIRST_INDEX changes them.
torch.manual_seed(88)
torch.cuda.manual_seed(88)
import gc

FIRST_INDEX = 17  # first dataloader index to process (0..16 are skipped)
LAST_INDEX = 50   # last dataloader index to process (inclusive)
NUM_POINTS = 3    # prompt points sampled from each ground-truth mask

for index, data_val in enumerate(dataLoader[0]):
    gc.collect()
    if index < FIRST_INDEX:
        continue
    im_path = data_val['ori_im_path'][0]
    print(index, im_path)
    labels_val = data_val['label']
    image = cv2.imread(im_path, cv2.IMREAD_COLOR)

    # Sample prompt points from the ground-truth mask; point_coords is an
    # integer array with one row per sampled point, all labelled 1.
    labels_points = misc.masks_sample_points(labels_val[:, 0, :, :], k=NUM_POINTS)
    point_coords = labels_points[0, :, :].cpu().detach().numpy().astype(np.int64)
    point_labels = np.ones(shape=point_coords.shape[0])

    # Sampling happens before these checks so the RNG stream is consumed the
    # same way whether or not an image ends up being skipped.
    if image is None:
        print(f"skipping {im_path}: cv2.imread could not read it")
    elif point_coords.shape[0] == 0:
        print(f"skipping {im_path}: no prompt points sampled from its mask")
    else:
        # Other variants with the same signature: hsv_i, ori_i, liner_trans.
        fdog_i(image, point_coords, point_labels, data_val)

    if index >= LAST_INDEX:
        break