In [1]:
# google/vit-base-patch16-224-in21k
"""
python actions/vanilla/train_vanilla_imagenet.py --need-cls
"""
import torch
# from transformers import ViltProcessor, ViltForQuestionAnswering, ViltConfig
import os
import argparse
# import evaluate
from tqdm.auto import tqdm
from transformers import ViTImageProcessor, AutoModel
from pathlib import Path
import pickle

import numpy as np
import random
from tqdm import tqdm
import h5py
from PIL import Image
import math
from collections import defaultdict
import matplotlib.pyplot as plt
from skimage.transform import resize
import matplotlib
from skimage import measure
from scipy.ndimage import binary_dilation

# Set matplotlib's global default font size to 8pt for all figures created after this point.
plt.rcParams.update({'font.size': 8})


def resize_image_masks(image, masks):
    """Crop ``image`` to each mask's bounding box and resize the crop back to full size.

    For every non-empty mask, the tight (inclusive) bounding box of its non-zero
    pixels is cut out of both the mask and the image. Each crop is then resized back
    to the image's spatial size with ``skimage.transform.resize``. Empty masks
    are skipped.

    Args:
        image: tensor of shape (num_channel, img_dim1, img_dim2).
        masks: tensor of 2-D masks, expected shape (num_masks, img_dim1, img_dim2);
            any non-zero value counts as inside the mask.

    Returns:
        (new_images, new_masks):
            new_images: shape (k, num_channel, img_dim1, img_dim2), on ``image.device``.
            new_masks: int32 0/1 tensor of shape (k, 1, img_dim1, img_dim2), on
                ``masks.device``.
        Here k is the number of non-empty masks. When k == 0, both tensors are empty
        (first dimension 0) rather than raising.
    """
    num_channel, img_dim1, img_dim2 = image.shape
    masks = masks.int()
    new_masks = []
    new_images = []
    for mask in masks:
        non_zeros = torch.nonzero(mask)
        if non_zeros.size(0) == 0:
            continue
        # The max index is the last non-zero row/column, so the slice end needs +1;
        # without it the crop loses one row/column and 1-pixel-wide masks crop to empty.
        row_min = int(non_zeros[:, 0].min())
        row_max = int(non_zeros[:, 0].max()) + 1
        col_min = int(non_zeros[:, 1].min())
        col_max = int(non_zeros[:, 1].max()) + 1
        resized_mask = resize(mask[row_min:row_max, col_min:col_max].cpu().numpy(),
                              (img_dim1, img_dim2), preserve_range=True)
        # Interpolation produces fractional values; re-binarize at 0.5.
        new_masks.append((torch.tensor(resized_mask) > 0.5).int())
        resized_image = resize(image[:, row_min:row_max, col_min:col_max].cpu().numpy(),
                               (num_channel, img_dim1, img_dim2),
                               preserve_range=True)
        new_images.append(torch.tensor(resized_image))
    if not new_images:
        # torch.stack([]) raises, so return correctly shaped empty tensors instead.
        image_dtype = image.dtype if image.is_floating_point() else torch.float64
        empty_images = torch.empty((0, num_channel, img_dim1, img_dim2),
                                   dtype=image_dtype, device=image.device)
        empty_masks = torch.empty((0, 1, img_dim1, img_dim2),
                                  dtype=torch.int32, device=masks.device)
        return empty_images, empty_masks
    return torch.stack(new_images).to(image.device), torch.stack(new_masks).unsqueeze(1).to(masks.device)


def convert_idx_masks_to_bool(masks):
    """Split an integer index map into one boolean mask per distinct index value.

    input: masks (1, img_dim1, img_dim2) -- a plain (img_dim1, img_dim2) map is
        also accepted.
    output: masks_bool (num_masks, img_dim1, img_dim2), a bool tensor.
        num_masks is the number of distinct values in ``masks``.
        masks_bool[k] is True exactly where ``masks`` equals the k-th smallest
        distinct value.
    """
    if masks.dim() == 2:
        masks = masks.unsqueeze(0)
    # torch.unique returns its values sorted ascending by default, so no extra sort is needed.
    idxs = torch.unique(masks).view(-1, 1, 1)
    # (1, H, W) == (K, 1, 1) broadcasts to (K, H, W).
    masks_bool = masks == idxs
    return masks_bool


def transform(image, processor=None):
    """Convert an image to RGB and turn it into model input.

    Args:
        image: object exposing ``.convert(mode)`` (e.g. a ``PIL.Image.Image``).
        processor: optional callable such as a ``ViTImageProcessor``. When
            given, it is called on the RGB image with ``return_tensors='pt'``.

    Returns:
        If a processor is given, its ``'pixel_values'`` output with the leading
        batch axis squeezed away. Otherwise, the RGB image as a NumPy array.
    """
    rgb = image.convert("RGB")
    if processor is None:
        return np.asarray(rgb)
    pixel_values = processor(rgb, return_tensors='pt')['pixel_values']
    return pixel_values.squeeze(0)

In [2]:
def sum_weights_for_unique_masks(masks, masks_weights, preds): #, poolers):
    """Merge duplicate masks, summing their weights and averaging their predictions.

    Two masks count as duplicates when their zero / non-zero pixel patterns are
    identical.

    Args:
        masks: iterable of mask tensors (e.g. a (num_masks, H, W) tensor).
        masks_weights: one weight per mask, in the same order as ``masks``.
        preds: one prediction per mask, in the same order as ``masks``.

    Returns:
        (unique_masks, summed_weights, mean_preds): three aligned lists with one
        entry per distinct mask.
            unique_masks: the first occurrence of each distinct mask.
            summed_weights: the sum of the weights of all its duplicates.
            mean_preds: the mean of their predictions.
        Entries are ordered lexicographically by the mask's flattened 0/1 pattern.

    The input tensors are not modified.
    """
    first_mask = {}
    weight_sums = {}
    pred_sums = {}
    counts = {}

    for mask, weight, pred in zip(masks, masks_weights, preds):
        # Bytes of the flattened boolean pattern: hashable and cheap to build.
        # For equal-size masks they sort in the same order as a '0'/'1' string.
        key = mask.bool().flatten().cpu().numpy().tobytes()
        if key in first_mask:
            # Out-of-place adds on purpose: iterating a tensor yields views into
            # it, so `+=` would write the running sums back into the inputs.
            weight_sums[key] = weight_sums[key] + weight
            pred_sums[key] = pred_sums[key] + pred
            counts[key] += 1
        else:
            first_mask[key] = mask
            weight_sums[key] = weight
            pred_sums[key] = pred
            counts[key] = 1

    unique_keys = sorted(first_mask)
    unique_masks = [first_mask[key] for key in unique_keys]
    summed_weights = [weight_sums[key] for key in unique_keys]
    mean_preds = [pred_sums[key] / counts[key] for key in unique_keys]

    return unique_masks, summed_weights, mean_preds #, mean_poolers

In [None]:
# Load the per-image result pickles from input_dir, merge duplicate masks within
# each image, and index the merged entries by (filename, unique-mask index).
input_dir = '/shared_data0/weiqiuy/explainable_attn/exps/cosmogrid_wrapper_lr0005_phs32_watershed_diagonal_os16/best/val_results'
MAX_FILES = 1000  # max number of result files to load; set to None to load all

X = []
images = []
mask_paths = []
mask_idxs = []
preds_all = []
labels_all = []
# (filename, unique-mask index) -> (image, label, mask, summed weight, merged prediction)
mask_id_feat_dict = {}


count = 0
for filename in tqdm(sorted(os.listdir(input_dir))):
    if MAX_FILES is not None and count >= MAX_FILES:
        break
    # NOTE: pickle.load can execute arbitrary code. Only point input_dir at
    # result files produced by a trusted pipeline.
    with open(os.path.join(input_dir, filename), 'rb') as fh:
        data = pickle.load(fh)

    image = data['image']
    image_array = np.array(image)[0]
    masks_used = data['masks_used'].cpu().numpy()
    mask_weights = data["mask_weights"]
    pred = data['outputs_avg']
    preds = data['outputs']
    original_preds = data['outputs_original']
    label = data['label']

    unique_masks, summed_weights, mean_preds = sum_weights_for_unique_masks(
        torch.tensor(masks_used),
        torch.tensor(mask_weights),
        torch.tensor(preds),
    )

    for i in range(len(unique_masks)):
        mask_id_feat_dict[(filename, i)] = (
            image,
            label,
            unique_masks[i],
            summed_weights[i],
            mean_preds[i],
        )
    count += 1