In [None]:
import torch
# from models_cmae import *
import models_cmae
import models_mae
import torch.nn.functional as F
from util.datasets import build_dataset
import matplotlib
from einops import rearrange
from matplotlib import pyplot as plt
from dataclasses import dataclass
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os


# DEVICE = torch.device('cpu')

In [None]:
@dataclass
class args:
    """Settings container for this notebook.

    The visible cells read the fields directly off the class
    (e.g. ``args.batch_size``) without instantiating it.
    """
    # Dataset root; the 'val' sub-directory is loaded with ImageFolder.
    data_path: str = '/home/assafsho/data/ILSVRC2012'
    # Not read by the visible cells (the validation transform hard-codes 224).
    input_size: int = 224
    # The next five fields are not read by the visible cells; presumably they are
    # consumed by util.datasets.build_dataset, whose call is commented out -- TODO confirm.
    color_jitter: bool = False
    aa: bool = False
    reprob: bool = False
    remode: bool = False
    recount: bool = False
    # DataLoader settings for the validation loader.
    batch_size: int = 1024
    num_workers: int = 0
    pin_mem: bool = False
    # Forwarded to the models_cmae model as num_groups / group_sz in nn_evaluate.
    num_groups: int = 4
    group_sz: int = 49
    # Forwarded to model_orig.forward_encoder in nn_evaluate.
    mask_ratio: float = 0.


# load my model
chkpt_dir = '/home/assafsho/mae/only_batchwise/checkpoint-799.pth'
model = getattr(models_cmae, 'mae_vit_base_patch16')().to(DEVICE)
checkpoint = torch.load(chkpt_dir, map_location='cuda')
msg = model.load_state_dict(checkpoint['model'], strict=False)
print(msg)


# load original mae model
# !wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth
!wget -nc https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth
# chkpt_dir = 'mae_visualize_vit_large.pth'
chkpt_dir = 'mae_pretrain_vit_base.pth'
# model_orig = getattr(models_mae, 'mae_vit_large_patch16')().cuda()
model_orig = getattr(models_mae, 'mae_vit_base_patch16')().cuda()
checkpoint = torch.load(chkpt_dir, map_location='cuda')
msg = model_orig.load_state_dict(checkpoint['model'], strict=False)
print(msg)


def imshow(x, mask=None, norm=True, orig_sz=True):
    """Display an image, optionally multiplied by a mask, with matplotlib.

    Args:
        x: Image to show. An input with more than three dims has its leading
            axis squeezed. If the first remaining axis has size <= 3 the image
            is treated as channels-first and converted to HWC.
        mask: Optional multiplicative mask, converted to HWC by the same
            first-axis rule. It is applied after de-normalisation.
        norm: When True, undo ImageNet normalisation (``x * std + mean``).
        orig_sz: When True, open a new figure sized so that one image pixel
            corresponds to one figure pixel at the current ``figure.dpi``.
    """
    mean_rgb = torch.tensor([0.485, 0.456, 0.406])
    std_rgb = torch.tensor([0.229, 0.224, 0.225])

    # Drop a leading batch axis, then bring tensors to host memory.
    if x.ndim > 3:
        x = x.squeeze(0)
    x = x.detach().cpu() if torch.is_tensor(x) else x
    mask = mask.detach().cpu() if torch.is_tensor(mask) else mask

    # Channels-first -> channels-last for image and mask alike.
    if x.shape[0] <= 3:
        x = torch.einsum('chw->hwc', x)
    if not (mask is None or mask.shape[0] > 3):
        mask = torch.einsum('chw->hwc', mask)

    if norm:
        x = imagenet_denorm = mean_rgb + std_rgb * x

    if orig_sz:
        height, width, _ = x.shape
        dpi = float(matplotlib.rcParams['figure.dpi'])
        plt.figure(figsize=(width / dpi, height / dpi))

    shown = x if mask is None else x * mask
    plt.imshow(shown.clip(0, 1), vmin=0.)
    plt.show()

In [None]:
# Validation preprocessing: resize to 256 (interpolation code 3, i.e. PIL bicubic),
# centre-crop to 224, convert to a tensor, then normalise with the ImageNet
# mean/std that imshow() later undoes for display.
transform_val = transforms.Compose(
    [
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# dataset_val = build_dataset(is_train=False, args=args)
dataset_val = datasets.ImageFolder(
    root=os.path.join(args.data_path, 'val'),
    transform=transform_val,
)

# Defined for reference only: it is not handed to the DataLoader below, which
# already iterates in dataset order because shuffle=False.
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

# Order must stay deterministic: nn_evaluate uses positions in the loader's
# iteration order as dataset indices.
data_loader_val = torch.utils.data.DataLoader(
    dataset=dataset_val,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
)

In [None]:
@torch.no_grad()
def nn_evaluate(data_loader, model, model_orig, device, num_queries, num_neighbors):
    """Retrieve the nearest neighbours of random query images in embedding space.

    Random query images are drawn from ``data_loader.dataset`` and embedded with
    ``model`` (and with ``model_orig`` when given). Every batch of the loader is
    embedded the same way, and the cosine similarity between queries and batch
    items is accumulated. The ``num_neighbors`` most similar images per query
    are then re-read from the dataset.

    Assumptions (as in the original code):
      * ``use_cls_token=True`` and ``data_loader`` iterates with ``shuffle=False``.
        Column positions of the similarity matrix are used directly as dataset
        indices, so the loader must visit the dataset in index order.
      * The module-level ``args`` provides ``num_groups``, ``group_sz`` and
        ``mask_ratio``.

    Notes:
      * Queries come from the same dataset that is searched, so a query's own
        image is among its candidates whenever the loader covers the dataset.
      * Query selection uses ``torch.randperm`` without a fixed seed here, so
        the chosen queries depend on the global RNG state.

    Args:
        data_loader: Loader over a map-style dataset yielding ``(image, label)``.
        model: models_cmae-style model, called with ``encode_only=True``. The
            mean of its first ``args.num_groups`` output tokens is the embedding.
        model_orig: Optional reference model exposing ``forward_encoder``. Its
            first output token is the embedding. May be ``None``.
        device: Device on which images are embedded.
        num_queries: Number of random query images.
        num_neighbors: Number of neighbours retrieved per query.

    Returns:
        ``(query_ims, neighbor_ims, neighbor_ims_orig)``:
          * ``query_ims`` -- stacked query images on ``device``.
          * ``neighbor_ims`` -- ``num_queries * num_neighbors`` stacked images in
            query-major order (all neighbours of query 0 first), most similar
            first; not moved to ``device``.
          * ``neighbor_ims_orig`` -- the same for ``model_orig``, or ``None``
            when ``model_orig`` is ``None``.
    """
    dataset = data_loader.dataset
    # Use the real dataset / loader sizes instead of a hard-coded 50000.
    num_images = len(dataset)
    num_batches = len(data_loader)

    # switch to evaluation mode
    model.eval()
    if model_orig is not None:
        model_orig.eval()

    def embed(images):
        # L2-normalised mean of the first `num_groups` tokens -> [N, E]
        with torch.cuda.amp.autocast():
            tokens = model(images, num_groups=args.num_groups, group_sz=args.group_sz,
                           encode_only=True, group_duplicates=False)
            return F.normalize(tokens[:, :args.num_groups, :].mean(1), dim=-1)

    def embed_orig(images):
        # L2-normalised first token of the reference encoder -> [N, E]
        with torch.cuda.amp.autocast():
            tokens, _, _ = model_orig.forward_encoder(images, mask_ratio=args.mask_ratio)
            return F.normalize(tokens[:, 0, :], dim=-1)

    def gather_neighbors(sim_chunks):
        # One concatenation at the end instead of re-copying the matrix per batch.
        sim_matrix = torch.cat(sim_chunks, dim=1)  # [Q, N]
        _, inds = torch.topk(sim_matrix, num_neighbors)  # [Q, K]
        return torch.stack([dataset[int(i)][0] for i in inds.reshape(-1)])

    query_inds = torch.randperm(num_images)[:num_queries]
    query_ims = torch.stack([dataset[int(i)][0] for i in query_inds]).to(device)
    queries = embed(query_ims)
    queries_orig = embed_orig(query_ims) if model_orig is not None else None

    sim_chunks, sim_chunks_orig = [], []
    for batch_idx, (samples, _) in enumerate(data_loader, start=1):
        # get a batch
        samples = samples.to(device, non_blocking=True)

        # calculate representations and their similarity to every query
        sims = torch.einsum('QE,BE->QB', queries, embed(samples))
        sim_chunks.append(sims.float().cpu())

        if model_orig is not None:
            sims_orig = torch.einsum('QE,BE->QB', queries_orig, embed_orig(samples))
            sim_chunks_orig.append(sims_orig.float().cpu())

        # Progress in percent; reported whether or not a reference model is used.
        print(100 * batch_idx / num_batches)

    neighbor_ims = gather_neighbors(sim_chunks)
    neighbor_ims_orig = gather_neighbors(sim_chunks_orig) if model_orig is not None else None

    return query_ims, neighbor_ims, neighbor_ims_orig

In [None]:
# Retrieve 20 nearest neighbours for 50 random validation queries with both models.
# nn_evaluate embeds every batch of data_loader_val with each model, so this is the
# expensive cell of the notebook.
query_ims, neighbor_ims, neighbor_ims_orig = nn_evaluate(data_loader=data_loader_val, 
                                                         model=model,
                                                         model_orig=model_orig,
                                                         device=DEVICE, 
                                                         num_queries=50, 
                                                         num_neighbors=20)

In [None]:
# Show every query followed by its neighbours from each model, a few per row.
num_neighbors_row = 5
# Derived from the returned tensors (neighbours are stored query-major, the same
# count for every query) so it cannot drift from the value passed to nn_evaluate.
num_neighbors = neighbor_ims.shape[0] // query_ims.shape[0] if len(query_ims) else 0

for q, im in enumerate(query_ims):
    imshow(im)
    first = q * num_neighbors
    last = first + num_neighbors
    for label, ims in (('my', neighbor_ims), ('orig', neighbor_ims_orig)):
        # neighbor_ims_orig is None when nn_evaluate ran without a reference model
        if ims is None:
            continue
        print(label)
        # Stepping by the row width also shows a shorter final row when
        # num_neighbors is not a multiple of num_neighbors_row.
        for row_start in range(first, last, num_neighbors_row):
            row_end = min(row_start + num_neighbors_row, last)
            # Tile the row's images side by side into one wide image.
            imshow(rearrange(ims[row_start:row_end], 'B C H W -> C H (B W)'))
    print('****************************************************************************************')

In [None]:
# Pick one query/reference pair for the patch-level analysis below.
# Neighbours are stored query-major with 20 per query, so index 141 = 7 * 20 + 1 is
# query 7's second-ranked neighbour (rank 0 is presumably the query image itself,
# since queries are drawn from the searched dataset -- TODO confirm).
# NOTE(review): queries are drawn at random in nn_evaluate, so these hard-coded
# indices select different images on every run.
query_im = query_ims[7].to(DEVICE, non_blocking=True)
ref_im = neighbor_ims[141].to(DEVICE, non_blocking=True)

In [None]:
def get_patch_sim_mat(query_im, ref_im, device=DEVICE):
    """Patch-to-patch similarity between a query image and a reference image.

    Both images are encoded in one forward pass of the module-level ``model``
    (``num_groups=1, group_sz=196``). The first output token is dropped and the
    remaining token embeddings are L2-normalised. Row ``i`` of the result is a
    softmax, over the reference tokens, of the cosine similarity to query token
    ``i`` divided by a temperature of 0.1.

    Args:
        query_im: Query image, either a single 3-D image or a batch of one.
        ref_im: Reference image in the same layout as ``query_im``.
        device: Device on which the model is run.

    Returns:
        Similarity matrix with one row per query token and one column per
        reference token (presumably 196 x 196 for group_sz=196 -- confirm
        against models_cmae). Each row sums to 1.

    Note:
        The encoder output is unpacked as ``q, r``, so the combined batch must
        contain exactly two images.
    """
    query_im = query_im.to(device, non_blocking=True)
    ref_im = ref_im.to(device, non_blocking=True)

    # Single images are stacked into a batch of two; batched inputs are concatenated.
    qr_ims = torch.stack([query_im, ref_im]) if query_im.ndim == 3 else torch.cat([query_im, ref_im])

    # Visualisation only: no autograd graph is needed for the forward pass.
    with torch.no_grad(), torch.cuda.amp.autocast():
        qr = model(qr_ims, num_groups=1, group_sz=196, encode_only=True, group_duplicates=False)
        # Drop the first token (presumably the single group/CLS token -- confirm).
        qr = qr[:, 1:, :]
        # Normalise each token embedding along the feature axis. F.normalize
        # defaults to dim=1, which here is the token axis and would not give
        # cosine similarities; dim=-1 matches nn_evaluate.
        q, r = F.normalize(qr, dim=-1)

    sim_mat = (q @ r.transpose(-2, -1) / 0.1).softmax(dim=-1)

    return sim_mat

In [None]:
def show_heat_map(q_im, r_im, sim_mat, q_patch_nums):
    """Show the selected query patches and their best-matching reference patch.

    Two images are displayed through ``imshow``:
      1. ``q_im`` multiplied by a binary map that keeps only the patches listed
         in ``q_patch_nums``.
      2. ``r_im`` masked so that only the reference patch(es) with the highest
         score remain.

    The score per reference patch is the mean of the selected rows of
    ``sim_mat``, scaled by 100 and passed through a softmax. The scores and the
    resulting one-hot grid are printed.

    Args:
        q_im: Query image; multiplied with a (1, 3, 224, 224) map.
        r_im: Reference image, passed to ``imshow`` with the heat-map mask.
        sim_mat: Similarity matrix indexed by query patch along its first axis,
            with 196 entries per row.
        q_patch_nums: Indices (0..195) of query patches in the flattened
            14 x 14 grid.
    """
    # Binary 14x14 grid marking the chosen query patches, upsampled to 224x224.
    picked = torch.zeros(196, device=q_im.device)
    picked[q_patch_nums] = 1.
    picked = picked.reshape(1, 1, 14, 14).expand(1, 3, 14, 14)
    qmap = F.interpolate(picked, size=(224, 224))

    # Average the selected rows, sharpen with a softmax, lay out as a 14x14 grid.
    scores = sim_mat[q_patch_nums].mean(0) * 100
    scores = scores.reshape(196).softmax(-1).reshape(14, 14)
    print(scores)

    # Rescale, then keep only the maximal entries. The rescale is monotone, so it
    # does not change which entry is largest (up to float rounding).
    scores = (scores - scores.min()) / scores.max()
    peak = (scores == scores.max()).float()
    print(peak)

    # Replicate over three channels and upsample to image resolution.
    peak_rgb = torch.stack([peak, peak, peak]).unsqueeze(0)
    heatmap = F.interpolate(peak_rgb, size=(224, 224)).squeeze(0)

    imshow(qmap * q_im, None)
    imshow(r_im, heatmap)

In [None]:
# Patch-to-patch similarity between the selected query and reference images.
sim_mat = get_patch_sim_mat(query_im, ref_im)

In [None]:
# Highlight query patch 125 (flattened 14x14 grid index) and its best-matching reference patch.
show_heat_map(query_im, ref_im, sim_mat, [125])

In [None]:
# Full query image used above (same index as query_im).
imshow(query_ims[7])

In [None]:
# Full reference image used above (same index as ref_im).
imshow(neighbor_ims[141])