In [19]:
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"]= "3"
os.environ["OMP_NUM_THREADS"]= "4"

import argparse
import yaml
import torch
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import time
import logging
import random
import numpy as np
import cv2

from models.full_model import ModelAGDsup as Model
from dataset.data import get_loader as get_loader
from models.metric import KLD, SIM, NSS


def set_random_seed(seed, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        INTERPOLATE_MODE = "nearest"
        torch.use_deterministic_algorithms(True)
    else:
        INTERPOLATE_MODE = "bilinear"
    return INTERPOLATE_MODE


def parse_args():
    parser = argparse.ArgumentParser(description='Finetuning on AGD20K')
    parser.add_argument('--config', type=str, help='Path to the configuration file', required=True)
    
    args = parser.parse_args()
    return args
 
  
def load_config(config_path):
    with open(config_path, 'r') as stream:
        config = yaml.safe_load(stream)
    return config


def plot_annotation(image, heatmap, alpha=0.5, name=""):
    """Plot the heatmap on the target image.

    Args:
        image: The target image.
        points: The annotated points.
        heatmap: The generated heatmap.
        alpha: The alpha value of the overlay image.
    """
    # Plot the overlay of heatmap on the target image.
    processed_heatmap = heatmap * 255 / np.max(heatmap)
    processed_heatmap = np.tile(processed_heatmap[:, :, np.newaxis], (1, 1, 3)).squeeze(2)
    processed_heatmap = processed_heatmap.astype('uint8')
    processed_heatmap = cv2.applyColorMap(processed_heatmap, cv2.COLORMAP_JET)
    # print(processed_heatmap.shape, image.shape)
    # assert processed_heatmap.shape == image.shape
    overlay = cv2.addWeighted(processed_heatmap, alpha, image, 1-alpha, 0) # TODO: [:, :, ::-1]
    # cv2.imwrite(name, overlay) # TODO: , cv2.COLOR_BGR2RGB)
    cv2.show(overlay) # TODO: , cv2.COLOR_BGR2RGB)

            
def plot_annotation_with_gt(image, heatmap, gt, alpha=0.5, name=""):
    """Plot the heatmap on the target image.

    Args:
        image: The target image.
        points: The annotated points.
        heatmap: The generated heatmap.
        gt: The ground truth mask.
        alpha: The alpha value of the overlay image.
    """
    # Plot the overlay of heatmap on the target image.
    processed_heatmap = heatmap * 255 / np.max(heatmap)
    processed_heatmap = np.tile(processed_heatmap[:, :, np.newaxis], (1, 1, 3)).squeeze(2)
    processed_heatmap = processed_heatmap.astype('uint8')
    processed_heatmap = cv2.applyColorMap(processed_heatmap, cv2.COLORMAP_JET)
    # print(processed_heatmap.shape, image.shape)
    # assert processed_heatmap.shape == image.shape
    print("shape:",processed_heatmap.shape, image.shape)
    overlay = cv2.addWeighted(processed_heatmap, alpha, image, 1-alpha, 0) # TODO: [:, :, ::-1]
    
    ### Process the ground truth mask
    # Plot the overlay of heatmap on the target image.
    processed_gt = gt * 255 / np.max(gt)
    processed_gt = np.tile(processed_gt[:, :, np.newaxis], (1, 1, 3)).squeeze(2)
    processed_gt = processed_gt.astype('uint8')
    processed_gt = cv2.applyColorMap(processed_gt, cv2.COLORMAP_JET)
    
    overlay_gt = cv2.addWeighted(processed_gt, alpha, image, 1-alpha, 0) # TODO: [:, :, ::-1]

    concat = np.concatenate([overlay, overlay_gt], axis=1)
    # cv2.imwrite(name, concat)
    cv2.imshow(concat)

In [4]:
cfg_pth="configs/seen_test.yaml"
config = load_config(cfg_pth)

os.makedirs(f"{config['work_dir']}", exist_ok=True)
print("save_dir: ", config['work_dir'])

if not os.path.exists(f"{config['work_dir']}/ckpt"):
    os.makedirs(f"{config['work_dir']}/ckpt")
if not os.path.exists(f"{config['work_dir']}/img"):
    os.makedirs(f"{config['work_dir']}/img")

args_text = yaml.safe_dump(config, default_flow_style=False)
print(args_text)

INTERPOLATE_MODE = set_random_seed(1, deterministic=config["deterministic"])



save_dir:  logs/seen_test
PL_mode: refined
aug4imgRatio: 0.5
batch_size: 10
data_dir: ../../AGD20K
deterministic: false
img_size: 224
load:
  all_ckpt: ./logs/seen/ckpt/bestKLD.ckpt
  encoder_ckpt: null
model:
  decoder_embed_dim: 512
  encoder_params:
    heads: 12
    layers: 12
    output_dim: 512
    width: 768
  encoder_type: CLIP
  margin: 0.1
  pred_decoder_args:
    conv_first: true
    depth: 2
    mlp_dim: 2048
    use_additional_token: true
    use_up: 2
  pred_model_type: SAM
num_exo: 1
split_type: Seen
work_dir: logs/seen_test



In [5]:

model_config = config['model']
model = Model(**model_config)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [6]:
load_config = config["load"]
all_ckpt, encoder_ckpt = load_config["all_ckpt"], load_config["encoder_ckpt"]

In [7]:
all_ckpt

'./logs/seen/ckpt/bestKLD.ckpt'

In [8]:
### N분정도 소요됨 (6분 이하..?)
load_config = config["load"]
all_ckpt, encoder_ckpt = load_config["all_ckpt"], load_config["encoder_ckpt"]
if all_ckpt:
    with open("/home/shlee/workspace/02_robotics_miniProject/4_Affordance/WSAG-PLSP/codes/logs/seen/ckpt/bestKLD.ckpt", "rb") as f:
        print(">> start_load")
        state_dict = torch.load(f)["state_dict"]
        print(">> loaded")
    print("Loaded from ", all_ckpt)
    u, w = model.load_state_dict(state_dict, False)
else:
    raise NotImplementedError

num_parameters = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_parameters}')
num_parameters = sum([p.numel() for p in model.encoder.parameters()])
print(f'#Encoder Params: {num_parameters}')
num_parameters = sum([p.numel() for p in model.pred_decoder.parameters()])
print(f'#Final Decoder Params: {num_parameters}')
  

>> start_load
>> loaded
Loaded from  ./logs/seen/ckpt/bestKLD.ckpt
#Params: 112520740
#Encoder Params: 86192640
#Final Decoder Params: 12152576


In [9]:
model = torch.nn.DataParallel(model).cuda()
    
model.eval()
vall_kld = 0.
vall_sim = 0.
vall_nss = 0.
vall_num = 0
vall_num_sum = 0


In [10]:
eval_data_loader = get_loader(
        batch_size=1,
        img_size=config["img_size"], # follow LOCATE, Cross-View-AG, eval at 224*224
        split_file=config["split_type"],
        data_dir=config["data_dir"],
        shuffle=False,
        train=False,
        exo_obj_file=None, 
        ego_obj_file=None, 
        no_pad_gt=True
    )
eval_data_loader

<torch.utils.data.dataloader.DataLoader at 0x7f20ba28d310>

In [11]:
config["data_dir"]

'../../AGD20K'

In [12]:
eval_data_loader.dataset.verb2vid

{'beat': 0,
 'boxing': 1,
 'brush_with': 2,
 'carry': 3,
 'catch': 4,
 'cut': 5,
 'cut_with': 6,
 'drag': 7,
 'drink_with': 8,
 'eat': 9,
 'hit': 10,
 'hold': 11,
 'jump': 12,
 'kick': 13,
 'lie_on': 14,
 'lift': 15,
 'look_out': 16,
 'open': 17,
 'pack': 18,
 'peel': 19,
 'pick_up': 20,
 'pour': 21,
 'push': 22,
 'ride': 23,
 'sip': 24,
 'sit_on': 25,
 'stick': 26,
 'stir': 27,
 'swing': 28,
 'take_photo': 29,
 'talk_on': 30,
 'text_on': 31,
 'throw': 32,
 'type_on': 33,
 'wash': 34,
 'write': 35}

In [13]:
eval_data_loader.dataset.noun2nid

{'drum': 0,
 'punching_bag': 1,
 'toothbrush': 2,
 'skateboard': 3,
 'skis': 4,
 'snowboard': 5,
 'surfboard': 6,
 'frisbee': 7,
 'rugby_ball': 8,
 'soccer_ball': 9,
 'apple': 10,
 'banana': 11,
 'carrot': 12,
 'orange': 13,
 'knife': 14,
 'scissors': 15,
 'suitcase': 16,
 'bottle': 17,
 'cup': 18,
 'wine_glass': 19,
 'broccoli': 20,
 'hot_dog': 21,
 'axe': 22,
 'baseball_bat': 23,
 'hammer': 24,
 'tennis_racket': 25,
 'badminton_racket': 26,
 'book': 27,
 'bowl': 28,
 'fork': 29,
 'golf_clubs': 30,
 'bed': 31,
 'bench': 32,
 'couch': 33,
 'binoculars': 34,
 'microwave': 35,
 'oven': 36,
 'refrigerator': 37,
 'bicycle': 38,
 'motorcycle': 39,
 'chair': 40,
 'camera': 41,
 'cell_phone': 42,
 'baseball': 43,
 'basketball': 44,
 'discus': 45,
 'javelin': 46,
 'keyboard': 47,
 'laptop': 48,
 'pen': 49}

In [13]:
## Check text feature

# data_dir="../../AGD20K"
# nounsFeat = torch.load(os.path.join(data_dir, "sentenceFeatNounAGD.pth"))
# verbsFeat = torch.load(os.path.join(data_dir, "sentenceFeatVerbAGD.pth"))
# partsFeat = torch.load(os.path.join(data_dir, "sentenceFeatPartAGD.pth"))
# print(nounsFeat.keys())
# print(verbsFeat.keys())
# print(partsFeat.keys())
# verbsFeat["beat"].shape

dict_keys(['drum', 'baseball', 'rugby_ball', 'soccer_ball', 'javelin', 'frisbee', 'discus', 'basketball', 'cup', 'bottle', 'wine_glass', 'punching_bag', 'camera', 'cell_phone', 'suitcase', 'skis', 'laptop', 'keyboard', 'refrigerator', 'microwave', 'book', 'oven', 'bowl', 'pen', 'bicycle', 'motorcycle', 'tennis_racket', 'golf_clubs', 'badminton_racket', 'baseball_bat', 'surfboard', 'skateboard', 'snowboard', 'fork', 'knife', 'toothbrush', 'orange', 'apple', 'carrot', 'banana', 'bench', 'couch', 'bed', 'chair', 'scissors', 'hammer', 'axe', 'broccoli', 'hot_dog', 'binoculars'])
dict_keys(['beat', 'throw', 'sip', 'kick', 'pour', 'take_photo', 'drink_with', 'pick_up', 'type_on', 'open', 'stir', 'talk_on', 'catch', 'text_on', 'write', 'ride', 'push', 'swing', 'drag', 'boxing', 'jump', 'wash', 'peel', 'cut', 'lie_on', 'sit_on', 'brush_with', 'cut_with', 'pack', 'hit', 'carry', 'hold', 'eat', 'look_out', 'lift', 'stick'])
dict_keys(['beat drum', 'throw baseball', 'throw rugby_ball', 'throw soc

torch.Size([512])

In [14]:
len(partsFeat.keys())

124

In [60]:
verbsFeat["beat"]

tensor([-2.9030e-03,  4.7852e-02,  1.4282e-01, -3.5217e-02, -4.0112e-01,
         1.9913e-02, -2.3828e-01,  3.1219e-02, -4.6204e-02,  5.0385e-02,
        -1.5762e-02,  1.7297e-01, -8.8989e-02,  4.6661e-02,  1.7908e-01,
        -1.0382e-01, -1.2134e-01, -8.0713e-01,  2.7664e-02, -1.9873e-01,
        -1.4429e-01,  2.2583e-01, -7.9773e-02,  7.0923e-02,  8.2245e-03,
         3.0838e-02,  1.1139e-01,  2.7490e-01, -2.0850e-01, -7.3181e-02,
         1.7065e-01,  3.8300e-02, -2.0142e-01, -1.4124e-01,  4.6289e-01,
         4.3384e-01, -1.1462e-01, -4.3896e-01,  4.3628e-01,  1.2006e-01,
        -2.1118e-01, -1.0907e-01, -1.1047e-01,  3.6678e-03,  2.1582e-01,
        -1.7053e-01, -3.8501e-01,  2.0190e-01,  3.2318e-02, -7.8003e-02,
         1.7322e-01,  3.2080e-01, -1.3451e-02, -3.2422e-01, -1.0345e-01,
         1.6602e-01,  5.9521e-01,  4.6875e-02,  5.2551e-02,  2.0386e-01,
         6.9702e-02, -9.8206e-02, -1.6577e-01,  3.0859e-01,  8.8501e-04,
         2.9510e-02,  5.4962e-02,  2.1118e-01,  2.7

In [46]:
## part
data_sample[7]+' '+data_sample[8]

'beat drum'

In [30]:
##      0         1         2               3           4          5         6        7     8      9     10   11
##  input_img, gt_mask, gt_mask_prob, input_shape, sent_feat, noun_feat, part_feat, verb, noun, input_p, vid, nid
data_sample=eval_data_loader.dataset[0]
print(len(data_sample))
print(data_sample[0].shape)
print(data_sample[1].shape)
print(data_sample[2].shape)
print(data_sample[3])
print(data_sample[4].shape)
print(data_sample[5].shape)
print(data_sample[6].shape)
print(data_sample[7])
print(data_sample[8])
print(data_sample[9])
print(data_sample[10])
print(data_sample[11])  



12
torch.Size([3, 224, 224])
torch.Size([1, 224, 224])
torch.Size([1, 224, 224])
[1500, 1297]
torch.Size([512])
torch.Size([512])
torch.Size([512])
beat
drum
../../AGD20K/Seen/testset/egocentric/beat/drum/drum_000346.jpg
0
0


In [88]:
# from transformers import AutoTokenizer, CLIPModel

# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
# tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")

# inputs = tokenizer(["beat"], padding=True, return_tensors="pt")
# with torch.no_grad():
#     text_features2 = model.get_text_features(**inputs)
# text_features2.squeeze(0).cpu().type(torch.float16)[:40]

In [14]:
# pip install git+https://github.com/openai/CLIP.git

import os
import clip
import torch
from torchvision.datasets import CIFAR100


# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_ckpt = "./ViT-B-16.pt"
clip_model, preprocess = clip.load(clip_ckpt, device)

# text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
text_inputs = torch.cat([clip.tokenize("beat") ]).to(device)

# Calculate features
with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs)

# # Pick the top 5 most similar labels for the image
# text_features /= text_features.norm(dim=-1, keepdim=True)

In [51]:
text_features.squeeze(0).cpu()[:10]

tensor([-0.0025,  0.0491,  0.1433, -0.0350, -0.4016,  0.0185, -0.2386,  0.0314,
        -0.0471,  0.0494], dtype=torch.float16)

In [94]:
verbsFeat["beat"][:40]

tensor([-0.0029,  0.0479,  0.1428, -0.0352, -0.4011,  0.0199, -0.2383,  0.0312,
        -0.0462,  0.0504, -0.0158,  0.1730, -0.0890,  0.0467,  0.1791, -0.1038,
        -0.1213, -0.8071,  0.0277, -0.1987, -0.1443,  0.2258, -0.0798,  0.0709,
         0.0082,  0.0308,  0.1114,  0.2749, -0.2085, -0.0732,  0.1707,  0.0383,
        -0.2014, -0.1412,  0.4629,  0.4338, -0.1146, -0.4390,  0.4363,  0.1201],
       dtype=torch.float16)

In [17]:
text_features.shape

torch.Size([1, 512])

In [15]:
'''
### img path
1. cup - drink_with
    AGD20K/Unseen/testset/egocentric/drink_with/cup/cup_002062.jpg
'''
img_path_list=["../../AGD20K/Unseen/testset/egocentric/drink_with/cup/cup_002062.jpg",
            #   "../../AGD20K/Unseen/testset/egocentric/drink_with/cup/cup_002062.jpg",
            #   "../../AGD20K/Unseen/testset/egocentric/drink_with/cup/cup_002062.jpg",
            #   "../../AGD20K/Unseen/testset/egocentric/drink_with/cup/cup_002062.jpg",
            #   "../../AGD20K/Unseen/testset/egocentric/drink_with/cup/cup_002062.jpg",
            #   "../../AGD20K/Unseen/testset/egocentric/drink_with/cup/cup_002062.jpg",
              ]



from torchvision import transforms
from PIL import Image
from torchvision.transforms import ToTensor, Normalize

def _convert_image_to_rgb(image):
    return image.convert("RGB")

transform_noresize = transforms.Compose([
            _convert_image_to_rgb,
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

In [28]:

with torch.no_grad():
    img_path=img_path_list[0]
    gt_pth=img_path.replace(".jpg", ".png").replace("egocentric", "GT")
    ### 1-a) text feature
    # text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
    text_inputs = torch.cat([clip.tokenize("beat") ]).to(device)
    text_features = clip_model.encode_text(text_inputs)
    
    ### 1-b) image feature
    input_img = Image.open(img_path)
    input_shape = [input_img.size[1], input_img.size[0]]
    input_img_no_resize = transform_noresize(input_img)
    
    if config["img_size"]:
        input_image = F.interpolate(
            input_img_no_resize.unsqueeze(0),
            size=config["img_size"],
            mode="bilinear",
        ).squeeze(0)
    else:
        input_image = F.interpolate(
            input_img_no_resize.unsqueeze(0),
            size=224,
            mode="bilinear",
        ).squeeze(0)
    input_image = input_image.unsqueeze(0)

    ## 2) gt mask process
    gt_ori = torch.tensor(np.array(Image.open(gt_pth))).float().reshape(1, *input_shape)

    if config["img_size"]:
        gt = F.interpolate(
            gt_ori.unsqueeze(0),
            size=config["img_size"],
            mode="bilinear",
        ).squeeze(0)
    else:
        gt = gt_ori
    # gt=gt.unsqueeze(0)

    if gt.max() == 0: # Seen/testset/GT/hold/bottle/bottle_000341.png
        gt = torch.ones_like(gt)
    gt_mask = gt / gt.max()
    gt_mask_prob = gt / gt.sum()

    ### 3) model inference
    aff_res, _, _ = model(
        input_image, text_features,
    )
    pred = aff_res.detach()
    
    r_pred = F.interpolate(
        pred, 
        size=gt_mask.shape[-2:],
        mode=INTERPOLATE_MODE,
    )
    
    gt_prob = gt_mask_prob.cuda().reshape(len(pred), -1)
    r_prob = F.softmax(r_pred.reshape(len(pred), -1), dim=1)
    
    kld_per_sample = KLD(r_prob, gt_prob, "none").sum(dim=1)
    kld = kld_per_sample.sum()
    sim = SIM(r_prob, gt_prob) * len(pred)
    nss = NSS(r_prob, gt_prob) * len(pred)
    vall_kld += kld
    vall_sim += sim
    vall_nss += nss
    vall_num += 1
    vall_num_sum += len(pred)
    
    verbs = None
    nouns = None
    
    
    for bid in range(len(r_pred)):
        pp = r_prob[bid] / r_prob[bid].max()
        
        ii = F.interpolate(
            input_image[bid:bid+1], 
            size=gt_mask.shape[-2:],
            mode=INTERPOLATE_MODE,
        ).reshape(3, *gt_mask.shape[-2:])
        
        ## visualize with gt
        plot_annotation_with_gt(ii, 
                        pp.reshape(*gt_mask[bid].shape).detach().unsqueeze(0).cpu().numpy().transpose(1, 2, 0), 
                        gt_mask[bid].unsqueeze(0).cpu().numpy().transpose(1, 2, 0), 
                        name=None)
        
    
print(
    f"Result on AGD: \nKLD={vall_kld/vall_num_sum}, SIM={vall_sim/vall_num_sum}, NSS={vall_nss/vall_num_sum}")

shape: (224, 224, 3) torch.Size([3, 224, 224])


error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'addWeighted'
> Overload resolution failed:
>  - src2 is not a numpy array, neither a scalar
>  - Expected Ptr<cv::UMat> for argument 'src2'


In [32]:
gt_mask[bid].unsqueeze(0).cpu().numpy().shape

(1, 224, 224)

: 

In [18]:
pp.reshape(*gt_mask[bid].shape).detach().cpu().numpy().shape

(224, 224)

In [45]:
gt_mask[bid].cpu().numpy().transpose(1, 2, 0).shape

(224, 224, 1)