In [1]:
import copy
import itertools
import os
import sys
sys.path.append("/workspace/mta_vision_transformers/")
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple

import matplotlib
import matplotlib.colors
import numpy as np
import einops
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import torch.utils.data
from matplotlib import pyplot as plt
from tensordict import TensorDict
from torch.utils._pytree import tree_flatten

from core.monitor import Monitor
from infrastructure import utils
from infrastructure.settings import DEVICE, OUTPUT_DEVICE, DTYPE
from dataset.construct import ImageDataset
from dataset.library import DATASETS


dataset_name, n_classes = DATASETS["Common"][1]
OUTPUT_DIR = "experiments/compression"
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
# Ocean: 901085904
# Rose: 100390212
torch.set_printoptions(linewidth=400, sci_mode=False)

Seed: 1149496617


In [4]:
from dataset.evaluation import ImageTextDataset, run_retrieval_evaluation, print_retrieval_metrics, DEFAULT_DATASET
from modeling.image_features import ImageFeatures
from modeling.openclip_vit import OpenCLIPViT
from modeling.vit_attention import OpenCLIPAttentionViT


utils.reset_seed()
# Run evaluation
evaluation_kwargs: Dict[str, Any] = {"subsample": 5000, "n_ev": 1}

# Evaluate base model
print("=" * 120)
print("Base model")
print("=" * 120)

baseline_fname = f"{OUTPUT_DIR}/metrics/baseline.pt"
if not os.path.exists(baseline_fname):
    baseline_model = OpenCLIPViT().to(DEVICE)
    baseline_metrics: TensorDict = run_retrieval_evaluation(baseline_model, **evaluation_kwargs)
    torch.save(baseline_metrics, baseline_fname)
else:
    baseline_metrics: TensorDict = torch.load(baseline_fname, map_location=DEVICE)
print_retrieval_metrics(baseline_metrics)
print()

# Evaluate compression model
print("=" * 120)
print("Compression model")
print("=" * 120)

mode: OpenCLIPAttentionViT.ModeOptions = "sink"
mask_type: OpenCLIPAttentionViT.MaskOptions = "X -> T"

MA_mask: torch.Tensor = torch.load(f"experiments/saved_masks/MA_mask.pt", map_location=DEVICE)
AS_mask: torch.Tensor = torch.load(f"experiments/saved_masks/AS_mask.pt", map_location=DEVICE)
mask_dict: Dict[str, torch.Tensor] = {
    # "all": (torch.arange(ImageFeatures.N + 1) > 0).expand((len(DEFAULT_DATASET), ImageFeatures.N + 1)),
    # "normal": (torch.arange(ImageFeatures.N + 1) > 0) * ~AS_mask,
    "MA": MA_mask,
    # "AS": AS_mask,
}

lo, hi = 13, 24
for k, mask in mask_dict.items():
    print(f"{k}:")
    compression_fname = f"{OUTPUT_DIR}/metrics/{mode}_({mask_type})_{k}[{lo}:{hi}].pt"
    if not os.path.exists(compression_fname):
        mask_layers: Iterable[int] = range(lo, hi)

        dataset = copy.copy(DEFAULT_DATASET)
        dataset.load_cache({"mask": mask})
        compression_model = OpenCLIPAttentionViT({i: (mode, mask_type) for i in range(lo, hi)}).to(DEVICE)
        compression_metrics: TensorDict = run_retrieval_evaluation(compression_model, dataset=dataset, **evaluation_kwargs)
        torch.save(compression_metrics, compression_fname)
    else:
        compression_metrics: TensorDict = torch.load(compression_fname, map_location=DEVICE)
    print_retrieval_metrics(compression_metrics)

# result_grid = np.empty((ImageFeatures.NUM_LAYERS, ImageFeatures.NUM_LAYERS), dtype=object).tolist()
# for lo in range(12, ImageFeatures.NUM_LAYERS):
#     for hi in range(lo + 1, ImageFeatures.NUM_LAYERS + 1):
#         compression_fname = f"{OUTPUT_DIR}/metrics/remove_normal[{lo}:{hi}].pt"
#         if not os.path.exists(compression_fname):
#             mask_layers: Iterable[int] = range(lo, hi)

#             AS_mask: torch.Tensor = torch.load(f"experiments/saved_masks/AS_mask.pt", map_location=DEVICE)
#             normal_mask: torch.Tensor = (torch.arange(ImageFeatures.N + 1) > 0) * ~AS_mask

#             dataset = copy.copy(DEFAULT_DATASET)
#             dataset.load_cache({"mask": normal_mask})
#             compression_model = OpenCLIPAttentionViT(mode, mask_type, mask_layers).to(DEVICE)
#             compression_metrics: TensorDict = run_retrieval_evaluation(compression_model, dataset=dataset, **evaluation_kwargs)
#             torch.save(compression_metrics, compression_fname)
#         else:
#             compression_metrics: TensorDict = torch.load(compression_fname, map_location=DEVICE)
        
#         result_grid[lo, hi] = compression_metrics
#         print_retrieval_metrics(compression_metrics)

Base model
Text-to-Image Retrieval Metrics:
R@1: 39.62%
R@2: 50.40%
R@5: 64.62%
Image-to-Text Retrieval Metrics:
R@1: 36.92%
R@2: 47.74%
R@5: 61.44%

Compression model
MA:


  baseline_metrics: TensorDict = torch.load(baseline_fname, map_location=DEVICE)
  MA_mask: torch.Tensor = torch.load(f"experiments/saved_masks/MA_mask.pt", map_location=DEVICE)
  AS_mask: torch.Tensor = torch.load(f"experiments/saved_masks/AS_mask.pt", map_location=DEVICE)
100%|██████████| 157/157 [00:55<00:00,  2.84it/s]

Text-to-Image Retrieval Metrics:
R@1: 39.64%
R@2: 50.50%
R@5: 64.68%
Image-to-Text Retrieval Metrics:
R@1: 37.24%
R@2: 47.88%
R@5: 61.36%



