In [1]:
%matplotlib inline 
%load_ext autoreload
%autoreload 2

import timm
import torch
import os
import tabulate

device = 'cpu' # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
basepath = os.path.abspath('../figures/key-value-and-class-value-agreement')
os.makedirs(basepath, exist_ok=True)

model_names = ['vit_base_patch16_224', 
               'vit_base_patch16_224_miil', 
               'vit_base_patch32_224', 
               'vit_large_patch16_224']
pretty_model_names = {
    'vit_base_patch16_224': 'ViT-B/16', 
    'vit_base_patch16_224_miil': 'ViT-B/16-MIIL', 
    'vit_base_patch32_224': 'ViT-B/32', 
    'vit_large_patch16_224': 'ViT-L/16',
    'imagenet_val': 'ImageNet1k-Val'
}
models = {
    model_name: timm.create_model(model_name, pretrained=True).eval().to(device)
    for model_name in model_names
}

### MLP value vectors

In [22]:
from src.utils.extraction import extract_value_vectors
from src.utils.model import embedding_projection
from src.analyzers.vector_analyzer import (k_most_predictive_ind_for_classes,
                                           shared_value_vectors)

# Number of most-predictive vectors kept per class (passed to
# k_most_predictive_ind_for_classes; also reused by the attention section).
k = 1

# Stage 1: run every model's MLP value vectors through embedding_projection
# (presumably a projection into class space -- see src.utils.model).
projected_values = {}
for name in model_names:
    projected_values[name] = embedding_projection(
        models[name], extract_value_vectors(models[name], device), device)

# Stage 2: per model, the k most predictive vector indices for each class.
k_most_pred_inds = {}
for name in model_names:
    k_most_pred_inds[name] = k_most_predictive_ind_for_classes(
        projected_values[name], k, device)

# Stage 3: group the classes by the vector they have in common. The meaning of the
# `True` flag is defined by shared_value_vectors -- TODO document it there.
shared_vectors = {}
for name in model_names:
    shared_vectors[name] = shared_value_vectors(k_most_pred_inds[name], True)

In [30]:
# Per model: how many vectors are shared by more than one class. Each value of
# `shared` is the collection of classes that share that vector.
for model_name, shared in shared_vectors.items():
    count = sum(1 for classes in shared.values() if len(classes) > 1)
    print(model_name, count)

vit_base_patch16_224 141
vit_base_patch16_224_miil 126
vit_base_patch32_224 148
vit_large_patch16_224 31


In [29]:
# Per model: how many vectors are shared by more than two classes. Each value of
# `shared` is the collection of classes that share that vector.
for model_name, shared in shared_vectors.items():
    count = sum(1 for classes in shared.values() if len(classes) > 2)
    print(model_name, count)

vit_base_patch16_224 35
vit_base_patch16_224_miil 66
vit_base_patch32_224 47
vit_large_patch16_224 3


In [23]:
# How many of the largest shared-vector groups to list per model.
include_top = 4

rows = []
for model_name, shared in shared_vectors.items():
    # Largest groups first. A group is a vector key (its first field, presumably a
    # 0-based block index, is shown 1-based as 'Block') plus the classes sharing it.
    groups = sorted(shared.items(), key=lambda item: len(item[1]), reverse=True)
    # Slice rather than index so a model with fewer than `include_top` groups
    # contributes the groups it has instead of raising IndexError.
    for rank, (vector_key, classes) in enumerate(groups[:include_top], start=1):
        # The first field of each entry is the class name.
        class_names = [entry[0] for entry in classes]
        rows.append((pretty_model_names[model_name], rank, vector_key[0] + 1,
                     len(class_names), ', '.join(class_names)))

tbl_headers = ['Model', 'Top', 'Block', '$|\\text{Classes}|$', 'Classes']
tabulate.tabulate(rows, tbl_headers, floatfmt='.3f', tablefmt='html')

Model,Top,Block,$|\text{Classes}|$,Classes
ViT-B/16,1,11,20,"green_lizard, vine_snake, pot, cardoon, leaf_beetle, daisy, lacewing, tiger_beetle, bee, coral_fungus, lycaenid, harvestman, sea_anemone, tree_frog, American_chameleon, hip, sulphur_butterfly, cabbage_butterfly, gyromitra, monarch"
ViT-B/16,2,11,19,"ground_beetle, hermit_crab, water_snake, wood_rabbit, hen-of-the-woods, Arctic_fox, dung_beetle, hornbill, mink, echidna, earthstar, guenon, green_snake, alligator_lizard, mongoose, porcupine, crayfish, macaque, sea_cucumber"
ViT-B/16,3,11,13,"carton, toyshop, jersey, shoe_shop, nipple, crossword_puzzle, shower_cap, tobacco_shop, paper_towel, bolo_tie, plunger, Christmas_stocking, potter's_wheel"
ViT-B/16,4,11,10,"Scotch_terrier, Boston_bull, French_bulldog, Norfolk_terrier, diaper, briard, affenpinscher, Italian_greyhound, redbone, Brabancon_griffon"
ViT-B/16-MIIL,1,11,29,"Lhasa, Maltese_dog, Shih-Tzu, golden_retriever, English_setter, Pekinese, beagle, Border_collie, Great_Pyrenees, kelpie, Labrador_retriever, Brittany_spaniel, clumber, Pomeranian, Irish_setter, Norwegian_elkhound, boxer, cocker_spaniel, Saluki, Sussex_spaniel, collie, German_shepherd, toy_poodle, Chihuahua, Shetland_sheepdog, Norwich_terrier, English_springer, pug, Tibetan_terrier"
ViT-B/16-MIIL,2,11,10,"minibus, fire_engine, trolleybus, ambulance, police_van, recreational_vehicle, jeep, tow_truck, pickup, trailer_truck"
ViT-B/16-MIIL,3,11,10,"fireboat, speedboat, trimaran, catamaran, liner, dock, schooner, yawl, pirate, lifeboat"
ViT-B/16-MIIL,4,11,9,"grille, minivan, car_wheel, sports_car, racer, convertible, limousine, beach_wagon, cab"
ViT-B/32,1,9,22,"briard, cairn, bluetick, malamute, white_wolf, Border_terrier, kelpie, Labrador_retriever, wild_boar, chow, Rottweiler, kuvasz, malinois, bull_mastiff, Norwegian_elkhound, Scotch_terrier, groenendael, Newfoundland, giant_schnauzer, Siberian_husky, pug, soft-coated_wheaten_terrier"
ViT-B/32,2,10,18,"quail, partridge, ptarmigan, bee_eater, brambling, magpie, ruffed_grouse, hornbill, house_finch, bittern, coucal, indigo_bunting, jay, bulbul, little_blue_heron, bustard, oystercatcher, black_grouse"


### Attention (MHSA projection) vectors

In [35]:
from src.utils.extraction import extract_mhsa_proj_vectors

# Same three stages as the MLP section, applied to the attention (MHSA) projection
# vectors. Relies on `k`, embedding_projection, k_most_predictive_ind_for_classes
# and shared_value_vectors from the MLP cell above.

# Stage 1: project each model's MHSA projection vectors with embedding_projection.
projected_values = {}
for name in model_names:
    projected_values[name] = embedding_projection(
        models[name], extract_mhsa_proj_vectors(models[name], device), device)

# Stage 2: per model, the k most predictive vector indices for each class.
k_most_pred_inds = {}
for name in model_names:
    k_most_pred_inds[name] = k_most_predictive_ind_for_classes(
        projected_values[name], k, device)

# Stage 3: group classes by shared vector. NOTE: this rebinds the names used by
# the MLP section, so re-run the MLP cells to reproduce their outputs.
shared_vectors = {}
for name in model_names:
    shared_vectors[name] = shared_value_vectors(k_most_pred_inds[name], True)

In [36]:
# Per model: how many vectors are shared by more than one class. Each value of
# `shared` is the collection of classes that share that vector.
for model_name, shared in shared_vectors.items():
    count = sum(1 for classes in shared.values() if len(classes) > 1)
    print(model_name, count)

vit_base_patch16_224 93
vit_base_patch16_224_miil 125
vit_base_patch32_224 155
vit_large_patch16_224 93


In [37]:
# Per model: how many vectors are shared by more than two classes. Each value of
# `shared` is the collection of classes that share that vector.
for model_name, shared in shared_vectors.items():
    count = sum(1 for classes in shared.values() if len(classes) > 2)
    print(model_name, count)

vit_base_patch16_224 72
vit_base_patch16_224_miil 41
vit_base_patch32_224 104
vit_large_patch16_224 68


In [38]:
# How many of the largest shared-vector groups to list per model.
include_top = 4

rows = []
for model_name, shared in shared_vectors.items():
    # Largest groups first. A group is a vector key (its first field, presumably a
    # 0-based block index, is shown 1-based as 'Block') plus the classes sharing it.
    groups = sorted(shared.items(), key=lambda item: len(item[1]), reverse=True)
    # Slice rather than index so a model with fewer than `include_top` groups
    # contributes the groups it has instead of raising IndexError.
    for rank, (vector_key, classes) in enumerate(groups[:include_top], start=1):
        # The first field of each entry is the class name.
        class_names = [entry[0] for entry in classes]
        rows.append((pretty_model_names[model_name], rank, vector_key[0] + 1,
                     len(class_names), ', '.join(class_names)))

tbl_headers = ['Model', 'Top', 'Block', '$|\\text{Classes}|$', 'Classes']
tabulate.tabulate(rows, tbl_headers, floatfmt='.3f', tablefmt='html')

Model,Top,Block,$|\text{Classes}|$,Classes
ViT-B/16,1,12,63,"space_heater, grand_piano, stove, upright, rotisserie, iron, vacuum, hand_blower, radiator, toaster, hand-held_computer, barber_chair, computer_keyboard, home_theater, Polaroid_camera, gas_pump, scale, oscilloscope, grille, washer, typewriter_keyboard, mailbox, notebook, harmonica, space_bar, hard_disc, parking_meter, carpenter's_kit, cash_machine, dishwasher, loudspeaker, microphone, dial_telephone, desktop_computer, cassette_player, power_drill, cassette, digital_clock, printer, projector, file, safe, sewing_machine, desk, laptop, slot, remote_control, iPod, joystick, stethoscope, screen, monitor, mouse, radio, cellular_telephone, microwave, reflex_camera, television, CD_player, tape_player, dumbbell, photocopier, pay-phone"
ViT-B/16,2,12,57,"jinrikisha, harp, torch, sunglass, oboe, parallel_bars, cinema, neck_brace, balance_beam, plunger, bearskin, bobsled, ballplayer, French_horn, stretcher, Leonberg, oxygen_mask, minibus, go-kart, barbershop, military_uniform, police_van, violin, sax, accordion, bullet_train, gasmask, rugby_ball, unicycle, banjo, horizontal_bar, mountain_bike, ambulance, lab_coat, mortarboard, cowboy_hat, bow, stage, toyshop, steel_drum, bicycle-built-for-two, football_helmet, crutch, assault_rifle, cornet, basketball, flute, limousine, trombone, crash_helmet, marimba, soccer_ball, volleyball, bassoon, golfcart, cello, library"
ViT-B/16,3,12,55,"Crock_Pot, ice_cream, saltshaker, perfume, ashcan, goblet, piggy_bank, bottlecap, chocolate_sauce, cup, coffee_mug, whiskey_jug, carton, soap_dispenser, nipple, water_bottle, cheeseburger, pill_bottle, wine_bottle, dough, ice_lolly, refrigerator, milk_can, espresso, lotion, packet, eggnog, measuring_cup, jigsaw_puzzle, bakery, tray, toilet_tissue, crossword_puzzle, cocktail_shaker, confectionery, pitcher, paper_towel, book_jacket, medicine_chest, sunscreen, teapot, menu, water_jug, beer_bottle, pop_bottle, restaurant, trifle, beer_glass, red_wine, beaker, coffeepot, vase, vending_machine, candle, pretzel"
ViT-B/16,4,12,45,"minivan, space_shuttle, liner, missile, solar_dish, amphibian, pier, radio_telescope, maypole, street_sign, trimaran, planetarium, projectile, sports_car, dock, moving_van, tobacco_shop, beach_wagon, sandbar, balloon, drilling_platform, paddlewheel, pirate, container_ship, fireboat, speedboat, convertible, schooner, traffic_light, cab, water_tower, lifeboat, mountain_tent, turnstile, comic_book, wing, carousel, parachute, catamaran, gondola, breakwater, airliner, flagpole, umbrella, airship"
ViT-B/16-MIIL,1,11,19,"West_Highland_white_terrier, crane, ice_bear, oystercatcher, sulphur-crested_cockatoo, spoonbill, Sealyham_terrier, king_penguin, African_grey, pelican, black_stork, albatross, kuvasz, dalmatian, Samoyed, miniature_poodle, Eskimo_dog, American_egret, white_stork"
ViT-B/16-MIIL,2,11,6,"llama, sorrel, cougar, goose, African_hunting_dog, bison"
ViT-B/16-MIIL,3,11,6,"standard_poodle, toy_poodle, guenon, Scottish_deerhound, lesser_panda, Bedlington_terrier"
ViT-B/16-MIIL,4,12,6,"French_bulldog, Boston_bull, Cardigan, miniature_pinscher, kelpie, Mexican_hairless"
ViT-B/32,1,12,86,"brown_bear, Ibizan_hound, flat-coated_retriever, langur, basset, Chesapeake_Bay_retriever, llama, zebra, English_setter, Leonberg, EntleBucher, baboon, Staffordshire_bullterrier, bull_mastiff, boxer, cocker_spaniel, collie, miniature_schnauzer, redbone, Brabancon_griffon, Tibetan_mastiff, orangutan, dhole, Irish_water_spaniel, Gordon_setter, lion, Afghan_hound, Doberman, meerkat, Saint_Bernard, timber_wolf, cheetah, Irish_setter, Rottweiler, malinois, borzoi, Rhodesian_ridgeback, Saluki, Newfoundland, jaguar, Chihuahua, bison, tiger, kit_fox, water_buffalo, wombat, American_Staffordshire_terrier, ox, warthog, sloth_bear, gibbon, chimpanzee, beagle, gorilla, ostrich, Old_English_sheepdog, German_short-haired_pointer, ram, horse_cart, koala, Indian_elephant, Irish_wolfhound, siamang, groenendael, German_shepherd, Eskimo_dog, Shetland_sheepdog, papillon, hippopotamus, ibex, Mexican_hairless, cougar, Bernese_mountain_dog, Greater_Swiss_Mountain_dog, Border_collie, whippet, kelpie, Labrador_retriever, giant_panda, spider_monkey, kuvasz, Arabian_camel, Australian_terrier, oxcart, African_elephant, Weimaraner"
ViT-B/32,2,12,29,"ballpoint, syringe, knot, hand_blower, screwdriver, lighter, rubber_eraser, combination_lock, pencil_sharpener, harmonica, pill_bottle, corkscrew, thimble, hammer, microphone, matchstick, ladle, measuring_cup, nail, oil_filter, lens_cap, iPod, drumstick, stethoscope, mousetrap, spindle, modem, fountain_pen, chain"
