In [53]:
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
import numpy as np
import random
from itertools import product
from tqdm import trange 
import os, fnmatch
import torch
import torch.nn as nn
import cv2
import matplotlib.pyplot as plt
import sys
# Print numpy arrays in full (no "..." truncation) and without scientific notation.
np.set_printoptions(threshold=sys.maxsize, suppress=True)
# Same for torch tensors: fixed-point notation and no truncation.
torch.set_printoptions(sci_mode=False, threshold=sys.maxsize)
# Make CUDA device index 1 the current device, so later "cuda" placements land there.
# NOTE(review): the index is hard-coded; presumably this fails on a single-GPU host — confirm.
torch.cuda.set_device(1)

def find(pattern, path):
    """Collect the names of all files under ``path`` that match ``pattern``.

    The directory tree is walked recursively with ``os.walk``. Only the bare
    file names are returned (no directory component), in the order the walk
    yields them.

    Args:
        pattern: shell-style wildcard understood by ``fnmatch`` (e.g. "*.png").
        path: root directory of the search.

    Returns:
        list[str]: matching file names; empty when nothing matches.
    """
    matches = []
    for _dirpath, _subdirs, filenames in os.walk(path):
        # fnmatch.filter keeps the input order, same as testing name by name.
        matches.extend(fnmatch.filter(filenames, pattern))
    return matches

In [58]:
# --- SAM model configuration -------------------------------------------------
sam_checkpoint = "/home/poscoict/Desktop/SAM/sam_vit_h_4b8939.pth"  # weights file on disk
model_type = "vit_h"  # key into sam_model_registry
device = "cuda"  # resolves to the current CUDA device

# Build the model from the checkpoint and move it to the target device
# (nn.Module.to returns the module itself, so the calls can be chained).
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

# Two front-ends sharing the same model: one exposes the image embedding,
# the other produces automatic segmentation masks.
predictor = SamPredictor(sam)
mask_generator = SamAutomaticMaskGenerator(sam)

# Small 256 -> 3 projection head followed by ReLU, placed on the same device.
# NOTE(review): not referenced by the visible processing loop — confirm it is still needed.
coords_fea_cmp_layer = nn.Sequential(nn.Linear(256, 3), nn.ReLU()).to(device)

def extract_mask_fea(mask, emb=None):
    """Average the image-embedding vectors sampled at every pixel of one mask.

    Each ``True`` pixel of ``mask`` is converted to a normalised [-1, 1]
    coordinate, the feature map is bilinearly sampled there with
    ``grid_sample``, and the samples are averaged over all mask pixels.

    Coordinate mapping: SAM computes its embedding on the image resized so
    that its longest side fills the encoder input and then zero-padded at the
    bottom/right to a square. A pixel (x, y) of the original image therefore
    sits at (x / L, y / L) of the embedding extent with L = max(H, W), so both
    axes are normalised by the longest side. Normalising rows by H alone
    would sample the zero-padded area for non-square images.
    NOTE(review): this assumes the feature map comes from SamPredictor.

    Args:
        mask: boolean numpy array of shape (H, W) at original image resolution.
        emb: optional feature map of shape (1, C, h, w). When omitted, the
            module-level ``embedding`` is used, so existing
            ``extract_mask_fea(mask)`` calls keep working.

    Returns:
        torch.Tensor of shape (C,) on the feature map's device. An empty mask
        yields a zero vector instead of NaN.
    """
    feat_map = embedding if emb is None else emb

    rows, cols = np.nonzero(mask)  # pixel indices of the mask, row-major order
    if rows.size == 0:
        # A mean over zero samples would be NaN; return a neutral feature.
        return feat_map.new_zeros(feat_map.shape[1])

    longest_side = max(mask.shape[0], mask.shape[1])
    grid_x = cols / longest_side * 2 - 1  # (num mask pixels,)
    grid_y = rows / longest_side * 2 - 1  # (num mask pixels,)

    # grid_sample expects (N, H_out, W_out, 2) with the last axis ordered (x, y).
    grid = torch.as_tensor(
        np.stack([grid_x, grid_y], -1)[None, :, None, :],
        dtype=feat_map.dtype,
        device=feat_map.device,
    )  # (1, num mask pixels, 1, 2)

    # align_corners=False is the library default; passing it avoids the warning.
    sampled = torch.nn.functional.grid_sample(feat_map, grid, align_corners=False)
    return sampled[0, ..., 0].mean(-1)  # (C,)

def extract_uv_fea(uvs, img, emb=None, per_point=False):
    """Sample the image embedding at projected point coordinates.

    Each (u, v) pixel coordinate is converted to a normalised [-1, 1]
    coordinate and the feature map is bilinearly sampled there with
    ``grid_sample``. By default the samples are averaged over all points.

    Coordinate mapping: SAM computes its embedding on the image resized so
    that its longest side fills the encoder input and then zero-padded at the
    bottom/right to a square, so both axes are normalised by
    L = max(H, W). Normalising v by H alone would sample the zero-padded area
    for non-square images.
    NOTE(review): this assumes the feature map comes from SamPredictor.

    Args:
        uvs: numpy array of shape (N, 2); column 0 is x (u), column 1 is y (v),
            in pixels of ``img``.
        img: image array; only ``img.shape[0]`` (H) and ``img.shape[1]`` (W)
            are read.
        emb: optional feature map of shape (1, C, h, w). When omitted, the
            module-level ``embedding`` is used, so existing
            ``extract_uv_fea(uvs, img)`` calls keep working.
        per_point: when True, return one feature per point instead of the
            average. The default keeps the original averaged output.
            NOTE(review): the result is stored as "uv_wise_fea" by the caller;
            confirm whether the averaged vector is really what is wanted.

    Returns:
        torch.Tensor of shape (C,) by default, or (N, C) when ``per_point`` is
        True, on the feature map's device. With no points the result is
        zeros of shape (C,) or an empty (0, C) tensor respectively.
    """
    feat_map = embedding if emb is None else emb
    num_channels = feat_map.shape[1]

    if uvs.shape[0] == 0:
        # A mean over zero samples would be NaN; return a neutral result.
        if per_point:
            return feat_map.new_zeros((0, num_channels))
        return feat_map.new_zeros(num_channels)

    uv_x, uv_y = uvs[:, 0], uvs[:, 1]
    longest_side = max(img.shape[0], img.shape[1])
    grid_x = uv_x / longest_side * 2 - 1  # (num uvs,)
    grid_y = uv_y / longest_side * 2 - 1  # (num uvs,)

    # grid_sample expects (N, H_out, W_out, 2) with the last axis ordered (x, y).
    grid = torch.as_tensor(
        np.stack([grid_x, grid_y], -1)[None, :, None, :],
        dtype=feat_map.dtype,
        device=feat_map.device,
    )  # (1, num uvs, 1, 2)

    # align_corners=False is the library default; passing it avoids the warning.
    sampled = torch.nn.functional.grid_sample(feat_map, grid, align_corners=False)
    point_major = sampled[0, ..., 0]  # (C, num uvs)
    if per_point:
        return point_major.transpose(0, 1)  # (num uvs, C)
    return point_major.mean(-1)  # (C,)

# Per-sequence SAM feature extraction.
#
# For every frame of the selected sequences this loop runs SAM on the camera
# image and writes five artefacts next to the sequence data:
#   mask_map/     (H, W) int32 map, 0 = no mask, k = index of mask k-1
#   img_emb/      the SAM image embedding
#   masks_fea/    one averaged embedding vector per mask
#   img_fea_256/  one vector per projected point (feature of the mask it falls in)
#   uv_wise_fea/  output of extract_uv_fea for the frame's points
#
# The loop is intentionally kept at module level: extract_mask_fea /
# extract_uv_fea read the module-level name `embedding`, which is assigned here.

# The dataset root and its sorted listing do not change between iterations,
# so they are resolved once instead of on every pass.
seq_root = '/home/poscoict/Desktop/c3d_semKITTI_refined/dataset/sequences'
seq_dir_root = os.listdir(seq_root)
seq_dir_root.sort()

# Indices 1..10 of the sorted listing are processed; index 0 is not visited here.
for i in range(1, 11):
    # ============================== dir setting ==============================
    seq_dir = seq_dir_root[i]
    seq_path = os.path.join(seq_root, seq_dir)

    # Inputs.
    image_root = os.path.join(seq_path, 'image_2')
    uvs_root = os.path.join(seq_path, 'uvs')
    label_root = os.path.join(seq_path, 'labels')

    # Outputs.
    mask_map_root = os.path.join(seq_path, 'mask_map')
    img_emb_root = os.path.join(seq_path, 'img_emb')
    masks_fea_root = os.path.join(seq_path, 'masks_fea')
    coords_fea_root = os.path.join(seq_path, 'img_fea_256')
    uv_wise_fea_root = os.path.join(seq_path, 'uv_wise_fea')

    for out_root in (mask_map_root, img_emb_root, masks_fea_root,
                     coords_fea_root, uv_wise_fea_root):
        os.makedirs(out_root, exist_ok=True)

    # Frames are paired across folders purely by position in the sorted lists.
    image_dirs = sorted(find("*.png", image_root))
    uvs_dirs = sorted(find("*.npy", uvs_root))
    label_dirs = sorted(find("*.npy", label_root))  # listed but not consumed below

    print("current seq : ", seq_dir)

    # Fail before the (slow) processing starts rather than with an IndexError
    # part-way through the sequence.
    if len(uvs_dirs) < len(image_dirs):
        raise ValueError(
            f"sequence {seq_dir}: {len(image_dirs)} images but only "
            f"{len(uvs_dirs)} uv files"
        )
    # ============================== dir setting ==============================

    for idx in trange(len(image_dirs)):
        img_path = os.path.join(image_root, image_dirs[idx])
        uvs_path = os.path.join(uvs_root, uvs_dirs[idx])

        # Integer pixel coordinates of the projected points: column 0 = x, column 1 = y.
        uvs = np.load(uvs_path).astype(np.int32)
        x_coords, y_coords = uvs[:, 0], uvs[:, 1]

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        predictor.set_image(img)

        masks = mask_generator.generate(img)
        embedding = predictor.get_image_embedding()  # [1, 256, 64, 64] per the original note

        uv_wise_fea = extract_uv_fea(uvs, img)

        num_pts = uvs.shape[0]
        fea_dim = embedding.shape[1]
        coords_fea = torch.zeros((num_pts, fea_dim), device=device)    # [num pts, C]
        masks_fea = torch.zeros((len(masks), fea_dim), device=device)  # [num masks, C]
        mask_map = np.zeros((img.shape[0], img.shape[1]), dtype=np.int32)  # [H, W]

        # Points outside the image must not index the mask: negative values
        # would silently wrap around and too-large ones would raise.
        # Such points simply keep their zero feature.
        in_img = (
            (x_coords >= 0) & (x_coords < img.shape[1])
            & (y_coords >= 0) & (y_coords < img.shape[0])
        )
        x_in, y_in = x_coords[in_img], y_coords[in_img]

        for m_idx, mask in enumerate(masks):
            mask = mask['segmentation']
            mask_fea = extract_mask_fea(mask)  # [C]

            # Boolean flag per point: does the point fall inside this mask?
            hit = np.zeros(num_pts, dtype=bool)
            hit[in_img] = mask[y_in, x_in]
            m_idxs = torch.from_numpy(hit).to(device)  # (num pts,)
            true_coords = torch.where(m_idxs)[0]       # (num True pts,)

            # When masks overlap, a later mask overwrites an earlier one,
            # both for the point features and for the mask map.
            coords_fea[true_coords] = mask_fea
            masks_fea[m_idx] = mask_fea
            mask_map[mask] = m_idx + 1  # 0 is reserved for "no mask"

        # Outputs are named by the frame's position in the sorted image list.
        frame_id = str(idx).zfill(6)
        np.save(os.path.join(mask_map_root, frame_id + ".npy"), mask_map)  # mask map
        torch.save(embedding, os.path.join(img_emb_root, frame_id + ".pt"))  # image embedding from SAM encoder
        torch.save(masks_fea, os.path.join(masks_fea_root, frame_id + ".pt"))  # per-mask features, row = mask index
        torch.save(coords_fea, os.path.join(coords_fea_root, frame_id + ".pt"))  # per-point SAM image feature
        torch.save(uv_wise_fea, os.path.join(uv_wise_fea_root, frame_id + ".pt"))

current seq :  01


100%|██████████| 1101/1101 [52:37<00:00,  2.87s/it]


current seq :  02


100%|██████████| 4661/4661 [3:04:13<00:00,  2.37s/it]  


current seq :  03


100%|██████████| 801/801 [32:34<00:00,  2.44s/it]


current seq :  04


100%|██████████| 271/271 [10:48<00:00,  2.39s/it]


current seq :  05


100%|██████████| 2761/2761 [1:52:23<00:00,  2.44s/it]


current seq :  06


100%|██████████| 1101/1101 [46:52<00:00,  2.55s/it]
