In [1]:
%matplotlib inline 
%load_ext autoreload
%autoreload 2

import timm
import timm.models.vision_transformer
import torch

# Use the GPU when torch reports one as available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# timm identifiers of the vision transformers that are analysed side by side.
model_names = [
    'vit_base_patch16_224',
    'vit_base_patch16_224_miil',
    'vit_base_patch32_224',
    'vit_large_patch16_224',
]

# name -> pretrained model, switched to eval mode and moved to `device`.
# Insertion order follows `model_names`.
models = dict(
    (name, timm.create_model(name, pretrained=True).eval().to(device))
    for name in model_names
)

In [3]:
from genericpath import isfile
from src.datasets.ImageNet import ImageNetValDataset
import numpy as np
import matplotlib.pyplot as plt
from src.utils.transformation import get_transforms
from tqdm.auto import tqdm
import sqlite3
import os
from torchvision.models.feature_extraction import create_feature_extractor
from src.utils.extraction import extract_value_vectors
from src.utils.model import embedding_projection
from src.analyzers.vector_analyzer import k_most_predictive_ind_for_classes
import torch.nn.functional as F
from pathlib import Path
import torch.nn as nn

def _k_most_predictive_for(model, k=10):
    """Run the value-vector pipeline for one model.

    Chains ``extract_value_vectors`` -> ``embedding_projection`` ->
    ``k_most_predictive_ind_for_classes`` and returns the result of the last
    call. The consumer in this notebook indexes that result as
    ``[rank, 0 (block) | 1 (vector), class]`` -- presumably a
    ``(k, 2, num_classes)`` index tensor; confirm against the helper.
    """
    value_vectors = extract_value_vectors(model, device)
    projected = embedding_projection(model, value_vectors, device)
    return k_most_predictive_ind_for_classes(projected, k, device)


# model name -> indices of the 10 most predictive value vectors per class.
k_most_pred_by_model = {
    name: _k_most_predictive_for(models[name]) for name in model_names
}

# Per model: the detached bias of every block's second MLP layer (fc2), in block order.
mlp_fc2_biases = {
    name: [block.mlp.fc2.bias.detach() for block in models[name].blocks]
    for name in model_names
}

# Per model: Euclidean norm of each of those biases, as plain Python floats.
mlp_fc2_biases_l2 = {
    name: [bias.pow(2).sum().sqrt().item() for bias in biases]
    for name, biases in mlp_fc2_biases.items()
}

# Per model: the transposed, detached fc2 weight of every block, so that row j
# belongs to hidden unit j (assuming fc2 is an nn.Linear -- TODO confirm).
mlp_fc2_weights = {
    name: [block.mlp.fc2.weight.T.detach() for block in models[name].blocks]
    for name in model_names
}

# Per model: the classification head (kept in eval mode).
model_heads = {name: models[name].head.eval() for name in model_names}

# Per model: the head applied directly to each fc2 bias vector.
mlp_fc2_biases_pred = {
    name: [model_heads[name](bias) for bias in biases]
    for name, biases in mlp_fc2_biases.items()
}

# Let sqlite3 bind numpy integer scalars by converting them to built-in ints.
for _numpy_int_type in (np.int64, np.int32):
    sqlite3.register_adapter(_numpy_int_type, int)

def _create_pred_db_tables(connection, k):
    """Create the ``predictions`` and ``vec_activations`` tables if they are missing.

    ``k`` is the number of ranked ``top_{i}_*`` columns generated per column family.
    """
    connection.execute(f"""
        CREATE TABLE IF NOT EXISTS predictions (
            path TEXT,
            pred_model TEXT,
            imagenet_id TEXT,
            num_idx INTEGER,
            name TEXT,
            {','.join([f'top_{i}_score REAL' for i in range(k)])},
            {','.join([f'top_{i}_ind INTEGER' for i in range(k)])},
            sum_row REAL,
            l1_row REAL,
            l2_row REAL,
            exp_reciprocal REAL,
            min_row REAL
        )
    """)
    connection.execute(f"""
        CREATE TABLE IF NOT EXISTS vec_activations (
            path TEXT,
            pred_model TEXT,
            imagenet_id TEXT,
            num_idx INTEGER,
            name TEXT,
            {','.join([f'top_{i}_cls_token_act REAL' for i in range(k)])},
            {','.join([f'top_{i}_img_token_avg_act REAL' for i in range(k)])},
            {','.join([f'top_{i}_block_ind INTEGER' for i in range(k)])},
            {','.join([f'top_{i}_vec_ind INTEGER' for i in range(k)])},
            {','.join([f'top_{i}_max_cls_token_act REAL' for i in range(k)])},
            {','.join([f'top_{i}_max_cls_act_block_ind INTEGER' for i in range(k)])},
            {','.join([f'top_{i}_max_cls_act_vec_ind INTEGER' for i in range(k)])},
            {','.join([f'top_{i}_max_img_token_avg_act REAL' for i in range(k)])},
            {','.join([f'top_{i}_max_img_avg_act_block_ind INTEGER' for i in range(k)])},
            {','.join([f'top_{i}_max_img_avg_act_vec_ind INTEGER' for i in range(k)])},
            top_bias_l2 REAL,
            top_bias_pl_res_l2 REAL,
            top_bias_pl_vec_noise_l2 REAL,
            top_bias_pl_all_l2 REAL,
            path_bias_pl_res_pred TEXT,
            path_bias_pl_vec_noise_pred TEXT,
            path_bias_pl_all_pred TEXT,
            mean_bias_pl_res_pred REAL,
            mean_bias_pl_vec_noise_pred REAL,
            mean_bias_pl_all_pred REAL,
            std_bias_pl_res_pred REAL,
            std_bias_pl_vec_noise_pred REAL,
            std_bias_pl_all_pred REAL,
            max_bias_pl_res_pred REAL,
            max_bias_pl_vec_noise_pred REAL,
            max_bias_pl_all_pred REAL,
            mean_top_bias_pred REAL,
            std_top_bias_pred REAL,
            max_top_bias_pred REAL,
            mean_resid_pred REAL,
            std_resid_pred REAL,
            max_resid_pred REAL,
            mean_vec_noise_pred REAL,
            std_vec_noise_pred REAL,
            max_vec_noise_pred REAL,
            mean_all_pred REAL,
            std_all_pred REAL,
            max_all_pred REAL
        )

    """)


def _build_insert_sql(table, scalar_columns, ranked_templates, k):
    """Return a parameterised INSERT: ``scalar_columns`` first, then ``k`` columns per template."""
    columns = list(scalar_columns)
    for template in ranked_templates:
        columns.extend(template.format(rank) for rank in range(k))
    placeholders = ', '.join('?' for _ in columns)
    return f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"


def _l2_norm(tensor):
    """Euclidean norm of all elements of ``tensor`` as a Python float."""
    return (tensor ** 2).sum().sqrt().item()


def create_pred_db(path, data_root='A:\\CVData\\ImageNet', batch_size=15):
    """Run every model in ``models`` over the ImageNet validation set and log the results.

    For each image and each model one row is inserted into each of two SQLite tables:

    * ``predictions``: the top-10 head outputs and their class indices plus
      row-wise aggregates of the head output (sum, L1, L2, ``1 / sum(exp(x))``, min).
    * ``vec_activations``: fc1 activations at the value vectors listed in
      ``k_most_pred_by_model`` for the image's label, the globally largest fc1
      activations, and head statistics for combinations of the top block's fc2
      bias, the CLS-token input of that block's ``norm2`` ("residual") and the
      contribution of all other value vectors of that block ("vec noise").

    The three head outputs ``bias + residual``, ``bias + vec noise`` and
    ``bias + residual + vec noise`` are also written with ``torch.save`` under
    ``<data_root>/val_top_bias_pl_{resid,vec_noise,all}/<imagenet_id>/<stem>.pt``.

    Parameters
    ----------
    path : str
        SQLite database file. Tables are created if missing and rows are appended;
        running this twice on the same file therefore duplicates rows.
    data_root : str, optional
        Directory handed to ``ImageNetValDataset``; also the parent of the ``.pt``
        output folders. Defaults to the location that used to be hard-coded.
    batch_size : int, optional
        Number of images per forward pass (default 15).

    Returns
    -------
    dict
        ``model name -> mean head output at the true class`` over all processed
        images (``nan`` when the dataset is empty).

    Notes
    -----
    The database connection, cursor and the forward hooks registered on the
    global ``models`` are released in a ``finally`` block, so an interrupted run
    (e.g. ``KeyboardInterrupt``) no longer leaves hooks attached or the file open.
    Batches committed before the interruption stay in the database.
    """
    # Number of ranked columns. Must not exceed the k used when building
    # `k_most_pred_by_model` (10 in this notebook).
    k = 10

    # NOTE(review): the preprocessing of the first model is applied to the inputs of
    # every model. Confirm that all checkpoints in `model_names` expect the same
    # resize/normalisation before comparing their numbers.
    transforms = get_transforms(models[model_names[0]])
    dataset = ImageNetValDataset(data_root, transforms=transforms)

    bias_pl_resid_path = os.path.join(data_root, 'val_top_bias_pl_resid')
    bias_pl_vec_noise_path = os.path.join(data_root, 'val_top_bias_pl_vec_noise')
    bias_pl_all_path = os.path.join(data_root, 'val_top_bias_pl_all')

    insert_predictions_sql = _build_insert_sql(
        'predictions',
        ('path', 'pred_model', 'imagenet_id', 'num_idx', 'name',
         'sum_row', 'l1_row', 'l2_row', 'exp_reciprocal', 'min_row'),
        ('top_{}_score', 'top_{}_ind'),
        k)
    insert_activations_sql = _build_insert_sql(
        'vec_activations',
        ('path', 'pred_model', 'imagenet_id', 'num_idx', 'name',
         'top_bias_l2', 'top_bias_pl_res_l2', 'top_bias_pl_vec_noise_l2', 'top_bias_pl_all_l2',
         'path_bias_pl_res_pred', 'path_bias_pl_vec_noise_pred', 'path_bias_pl_all_pred',
         'mean_bias_pl_res_pred', 'mean_bias_pl_vec_noise_pred', 'mean_bias_pl_all_pred',
         'std_bias_pl_res_pred', 'std_bias_pl_vec_noise_pred', 'std_bias_pl_all_pred',
         'mean_top_bias_pred', 'std_top_bias_pred',
         'mean_resid_pred', 'std_resid_pred',
         'mean_vec_noise_pred', 'std_vec_noise_pred',
         'mean_all_pred', 'std_all_pred',
         'max_bias_pl_res_pred', 'max_bias_pl_vec_noise_pred', 'max_bias_pl_all_pred',
         'max_top_bias_pred', 'max_resid_pred', 'max_vec_noise_pred', 'max_all_pred'),
        ('top_{}_cls_token_act', 'top_{}_img_token_avg_act',
         'top_{}_block_ind', 'top_{}_vec_ind',
         'top_{}_max_cls_token_act', 'top_{}_max_cls_act_block_ind', 'top_{}_max_cls_act_vec_ind',
         'top_{}_max_img_token_avg_act', 'top_{}_max_img_avg_act_block_ind',
         'top_{}_max_img_avg_act_vec_ind'),
        k)

    model_sums = {model_name: 0.0 for model_name in model_names}
    seen = 0  # images actually processed; denominator of the returned means

    # model name -> {'blocks.<i>.norm2': input of that norm2 for the current batch}
    residuals = {model_name: {} for model_name in model_names}
    hook_handles = []

    def get_resid_hook(model: str, layer: str):
        """Forward hook that stores the hooked module's first positional input."""
        def hook(module, inputs, output):
            residuals[model][layer] = inputs[0].detach()
        return hook

    def build_activation_row(model_name, ii, item, key_activations_batch):
        """Compute, save and tabulate the ``vec_activations`` row of one image.

        ``key_activations_batch`` stacks the fc1 outputs of all blocks along dim 0;
        ``ii`` is the image's position inside the current batch.
        """
        head = model_heads[model_name]
        k_most_pred = k_most_pred_by_model[model_name]
        idx = item['num_idx']
        key_activations = key_activations_batch[:, ii, :, :]
        hidden = key_activations.shape[-1]

        # Block / vector index of the ranked value vectors for this image's label.
        ranked_blocks = k_most_pred[:, 0, idx]
        ranked_vecs = k_most_pred[:, 1, idx]
        top_i_cls_token_act = key_activations[ranked_blocks, 0, ranked_vecs]
        top_i_img_token_avg_act = key_activations[ranked_blocks, 1:, ranked_vecs].mean(dim=1)

        # Largest activations over all (block, vector) pairs. The flat index is
        # block * hidden + vector, decoded again when the row is assembled.
        max_cls_act, max_cls_flat = key_activations[:, 0, :].flatten().topk(k=k)
        max_img_act, max_img_flat = key_activations[:, 1:, :].mean(dim=1).flatten().topk(k=k)

        top_block = ranked_blocks[0].item()
        top_vec = ranked_vecs[0].item()
        fc2_top_weight = mlp_fc2_weights[model_name][top_block]
        fc2_top_bias = mlp_fc2_biases[model_name][top_block]
        top_bias_pred = mlp_fc2_biases_pred[model_name][top_block]

        # The hook captured norm2's input for the WHOLE batch and all tokens.
        # Select this image's first token; using the full tensor mixed every
        # image of the batch into each per-image statistic and saved file.
        cls_residual = residuals[model_name][f'blocks.{top_block}.norm2'][ii, 0]

        # fc2 output of the first token with the top value vector left out.
        cls_keys = key_activations[top_block, 0]
        other_keys = torch.concat([cls_keys[:top_vec], cls_keys[top_vec + 1:]], dim=-1)
        other_values = torch.concat([fc2_top_weight[:top_vec, :], fc2_top_weight[top_vec + 1:]], dim=0)
        vec_noise = F.gelu(other_keys) @ other_values

        bias_pl_res = cls_residual + fc2_top_bias
        bias_pl_vec_noise = vec_noise + fc2_top_bias
        bias_pl_all = cls_residual + vec_noise + fc2_top_bias

        resid_pred = head(cls_residual)
        vec_noise_pred = head(vec_noise)
        all_pred = head(vec_noise + cls_residual)
        top_bias_pl_res_pred = head(bias_pl_res)
        top_bias_pl_vec_noise_pred = head(bias_pl_vec_noise)
        top_bias_pl_all_pred = head(bias_pl_all)

        # NOTE(review): the file name does not contain the model name, so every
        # model overwrites the previous model's tensor for the same image while
        # all of their rows point at the same file. Left as-is to keep the
        # on-disk layout; consider adding `model_name` as a path component.
        file_name = Path(item['path']).stem + '.pt'
        res_pred_path = os.path.join(bias_pl_resid_path, item['imagenet_id'], file_name)
        vec_noise_pred_path = os.path.join(bias_pl_vec_noise_path, item['imagenet_id'], file_name)
        all_pred_path = os.path.join(bias_pl_all_path, item['imagenet_id'], file_name)

        torch.save(top_bias_pl_res_pred, res_pred_path)
        torch.save(top_bias_pl_vec_noise_pred, vec_noise_pred_path)
        torch.save(top_bias_pl_all_pred, all_pred_path)

        # Order must match the column order of `insert_activations_sql`.
        return (
            item['path'], model_name, item['imagenet_id'], item['num_idx'], item['name'],
            mlp_fc2_biases_l2[model_name][top_block],
            _l2_norm(bias_pl_res), _l2_norm(bias_pl_vec_noise), _l2_norm(bias_pl_all),
            res_pred_path, vec_noise_pred_path, all_pred_path,
            top_bias_pl_res_pred.mean().item(), top_bias_pl_vec_noise_pred.mean().item(),
            top_bias_pl_all_pred.mean().item(),
            top_bias_pl_res_pred.std().item(), top_bias_pl_vec_noise_pred.std().item(),
            top_bias_pl_all_pred.std().item(),
            top_bias_pred.mean().item(), top_bias_pred.std().item(),
            resid_pred.mean().item(), resid_pred.std().item(),
            vec_noise_pred.mean().item(), vec_noise_pred.std().item(),
            all_pred.mean().item(), all_pred.std().item(),
            top_bias_pl_res_pred.max().item(), top_bias_pl_vec_noise_pred.max().item(),
            top_bias_pl_all_pred.max().item(),
            top_bias_pred.max().item(), resid_pred.max().item(),
            vec_noise_pred.max().item(), all_pred.max().item(),
            *top_i_cls_token_act[:k].tolist(),
            *top_i_img_token_avg_act[:k].tolist(),
            *ranked_blocks[:k].tolist(),
            *ranked_vecs[:k].tolist(),
            *max_cls_act.tolist(),
            *(max_cls_flat // hidden).tolist(),
            *(max_cls_flat % hidden).tolist(),
            *max_img_act.tolist(),
            *(max_img_flat // hidden).tolist(),
            *(max_img_flat % hidden).tolist(),
        )

    connection = sqlite3.connect(path)
    cursor = connection.cursor()
    try:
        _create_pred_db_tables(connection, k)

        for category in dataset.get_imagenet_classes():
            for out_dir in (bias_pl_resid_path, bias_pl_vec_noise_path, bias_pl_all_path):
                os.makedirs(os.path.join(out_dir, category), exist_ok=True)

        # Record the head and every block's fc1 output. Hook norm2 only on blocks
        # that hold the top-ranked value vector of at least one class.
        extractor_layers = {}
        for model_name in model_names:
            model = models[model_name]
            n_blocks = len(model.blocks)
            pred_inds = k_most_pred_by_model[model_name]
            block_counts_top1 = pred_inds[0, 0, :].flatten().bincount(minlength=n_blocks)
            layers = ['head']
            for block_idx in range(n_blocks):
                layers.append(f'blocks.{block_idx}.mlp.fc1')
                if block_counts_top1[block_idx].item() > 0:
                    hook_handles.append(model.blocks[block_idx].norm2.register_forward_hook(
                        get_resid_hook(model_name, f'blocks.{block_idx}.norm2')))
            extractor_layers[model_name] = layers

        extractors = {
            model_name: create_feature_extractor(models[model_name], extractor_layers[model_name])
            for model_name in model_names
        }

        # Pure inference: do not build autograd graphs for the forward passes.
        with torch.no_grad():
            for start in tqdm(range(0, len(dataset), batch_size), desc='Batches'):
                length = min(len(dataset) - start, batch_size)
                items = [dataset[start + offset] for offset in range(length)]
                seen += length

                tensor_images = torch.stack([item['img'] for item in items], dim=0).to(device)
                labels = torch.tensor([item['num_idx'] for item in items]).to(device)

                for model_name, extractor in extractors.items():
                    extraction = extractor(tensor_images)
                    results = extraction['head']

                    # Accumulate the head output at each image's true class.
                    model_sums[model_name] += torch.gather(
                        results, 1, labels.unsqueeze(1)).sum().item()

                    pred_vals, pred_inds = results.topk(k=k, dim=1)
                    pred_vals = pred_vals.tolist()
                    pred_inds = pred_inds.tolist()
                    sums = results.sum(dim=1).tolist()
                    l1_rows = results.abs().sum(dim=1).tolist()
                    l2_rows = (results ** 2).sum(dim=1).sqrt().tolist()
                    min_scores = results.min(dim=1)[0].tolist()
                    exp_reciprocals = (1 / results.exp().sum(dim=1)).tolist()

                    key_activations_batch = torch.stack(
                        [extraction[f'blocks.{block_idx}.mlp.fc1']
                         for block_idx in range(len(models[model_name].blocks))],
                        dim=0)

                    rows_pred = []
                    rows_activations = []
                    for ii, item in enumerate(items):
                        rows_pred.append((
                            item['path'], model_name, item['imagenet_id'],
                            item['num_idx'], item['name'],
                            sums[ii], l1_rows[ii], l2_rows[ii],
                            exp_reciprocals[ii], min_scores[ii],
                            *pred_vals[ii], *pred_inds[ii]))
                        rows_activations.append(
                            build_activation_row(model_name, ii, item, key_activations_batch))

                    cursor.executemany(insert_activations_sql, rows_activations)
                    cursor.executemany(insert_predictions_sql, rows_pred)

                # One transaction per batch across all models.
                connection.commit()
                for item in items:
                    del item['img']
                del tensor_images
                torch.cuda.empty_cache()
    finally:
        # Always detach the hooks from the shared models and release the DB handle,
        # even when the run is interrupted.
        for handle in hook_handles:
            handle.remove()
        cursor.close()
        connection.close()

    return {
        model_name: (model_sums[model_name] / seen if seen else float('nan'))
        for model_name in model_names
    }

# SQLite file that receives the prediction and activation tables.
path = 'imgnet_val.db'
create_pred_db(path=path)

Batches:   0%|          | 0/3334 [00:00<?, ?it/s]

KeyboardInterrupt: 