In [71]:
import itertools
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from tqdm import tqdm
from transformers import ViTImageProcessor, ViTForImageClassification

# Root folder holding the concept activations read below and the CAVs written
# below. This is an absolute Windows path on the author's machine; edit it locally.
external_path=r'c:\Users\thoma\Documents\working_docs\LIoT_aidos_external\ViT'

In [72]:
# Ordered concept names; each one names a folder under concept_token_activations.
# The names and their order match the CIFAR-10 class list (presumably the source
# dataset of the activations; confirm).
concepts='airplane automobile bird cat deer dog frog horse ship truck'.split()

In [73]:
class LinearClassifier(torch.nn.Module):
    """Logistic-regression probe: one affine unit followed by a sigmoid.

    Its single weight row is what get_cav extracts as the concept activation
    vector.
    """

    def __init__(self, in_dimension: int):
        """Create the probe for inputs whose last dimension is ``in_dimension``."""
        super(LinearClassifier, self).__init__()
        self.in_dimension = in_dimension
        # A single output unit produces one score per input row.
        self.linear = torch.nn.Linear(in_features=in_dimension, out_features=1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs of shape (..., in_dimension) to probabilities of shape (..., 1)."""
        logits = self.linear(x)
        probabilities = self.sigmoid(logits)
        return probabilities

In [74]:
def _load_concept_activations(concept,layer,n_chunks=10):
    """Load the chunk files ``layer{layer}_{k}.pt`` (k = 1..n_chunks) for one concept and concatenate them.

    The files live in the concept's folder under ``concept_token_activations``
    inside ``external_path``. Each loaded tensor has dimension 1 squeezed
    (a no-op unless that dimension has size 1). The chunks are concatenated
    along dim 0 in a single ``torch.cat``.
    NOTE(review): chunks are assumed to be (n_samples, 1, hidden_dim), giving a
    (n_samples_total, hidden_dim) result; confirm against the extraction code.
    torch.load unpickles data, so only point it at trusted files.

    Raises:
        ValueError: if ``n_chunks`` < 1 (torch.cat would fail on an empty list).
    """
    if n_chunks<1:
        raise ValueError(f'n_chunks must be >= 1, got {n_chunks}')
    chunks=[
        torch.load(f'{external_path}\\concept_token_activations\\{concept}\\layer{layer}_{k}.pt').squeeze(1)
        for k in range(1,n_chunks+1)
    ]
    return torch.cat(chunks)


def get_dataset(positive_concept,negative_concept,layer,n_chunks=10):
    """Build a labelled dataset of activations for a positive-vs-negative concept probe.

    Args:
        positive_concept: concept folder whose activations are labelled 1.
        negative_concept: concept folder whose activations are labelled 0.
        layer: layer index embedded in the chunk file names.
        n_chunks: number of chunk files per concept (``_1`` .. ``_{n_chunks}``).

    Returns:
        (dataset, labels): ``dataset`` is all positive rows followed by all
        negative rows, concatenated along dim 0. ``labels`` is a float tensor
        with 1.0 for each positive row and 0.0 for each negative row.
    """
    positive_concept_activations=_load_concept_activations(positive_concept,layer,n_chunks)
    negative_concept_activations=_load_concept_activations(negative_concept,layer,n_chunks)

    dataset=torch.cat([positive_concept_activations,negative_concept_activations])
    labels=torch.cat([torch.ones(positive_concept_activations.shape[0]),
                      torch.zeros(negative_concept_activations.shape[0])])

    return dataset,labels

In [75]:
def get_cav(positive_concept,negative_concept,layer,lr=1e-3,batch_size=32,epochs=50,shuffle=True):
    """Train a linear probe separating two concepts and return its unit-norm weight vector (CAV).

    A LinearClassifier is fitted with Adam and BCE loss. Activations of
    ``positive_concept`` are labelled 1 and those of ``negative_concept`` are
    labelled 0; the data comes from ``get_dataset`` at ``layer``. The probe's
    weight row, divided by its L2 norm, is the concept activation vector.

    Args:
        positive_concept: concept labelled 1.
        negative_concept: concept labelled 0.
        layer: layer index forwarded to ``get_dataset``.
        lr: Adam learning rate.
        batch_size: mini-batch size; the final batch holds the remainder.
        epochs: number of passes over the data.
        shuffle: reshuffle the sample order every epoch (uses torch's global RNG;
            seed it for reproducible CAVs).

    Returns:
        (cav, loss): ``cav`` is the detached, unit-norm weight vector. ``loss``
        is the sample-weighted mean BCE of the last epoch, accumulated from
        per-batch losses computed before each optimizer step. It is NaN when
        ``epochs == 0``.

    Raises:
        ValueError: if the dataset is empty or ``batch_size`` < 1.
    """
    dataset,labels=get_dataset(positive_concept,negative_concept,layer)
    n_samples=dataset.shape[0]
    if n_samples==0:
        raise ValueError(f'empty dataset for {positive_concept} vs {negative_concept} at layer {layer}')
    if batch_size<1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    model=LinearClassifier(dataset.shape[1])
    optimizer=torch.optim.Adam(model.parameters(), lr=lr)
    criterion=torch.nn.BCELoss()

    epoch_loss=float('nan')  # returned as-is when epochs == 0
    for _ in range(epochs):
        if shuffle:
            # get_dataset stacks all positives before all negatives; without a
            # per-epoch permutation most mini-batches would hold a single class.
            order=torch.randperm(n_samples)
            epoch_data,epoch_labels=dataset[order],labels[order]
        else:
            epoch_data,epoch_labels=dataset,labels
        epoch_loss=0.0
        # Stepping by batch_size yields ceil(n_samples / batch_size) non-empty
        # batches. An empty batch would give a NaN mean BCE, and n_samples <
        # batch_size would otherwise produce zero batches.
        for start in range(0,n_samples,batch_size):
            batch_data=epoch_data[start:start+batch_size]
            batch_labels=epoch_labels[start:start+batch_size]
            optimizer.zero_grad()
            outputs=model(batch_data).squeeze(1)  # (B, 1) -> (B,) to match labels
            loss=criterion(outputs,batch_labels)
            loss.backward()
            optimizer.step()
            epoch_loss+=loss.item()*len(batch_labels)
        epoch_loss/=n_samples
    cav=model.linear.weight[0].detach()
    return cav/torch.norm(cav),epoch_loss

In [76]:
# For each layer in `layers`, train one CAV per ordered (positive, negative)
# concept pair. Each unit-norm CAV is saved as `<positive>_<negative>.pt` in
# the layer's folder. Each pair's last-epoch training loss is stored in a dict
# that is pickled to a `losses` file.
layers=range(11,12)  # only layer 11 is processed; widen the range to sweep more layers
pbar=tqdm(layers)
for layer in pbar:
    layer_dir=f'{external_path}\\concept_activation_vectors\\{layer}'
    # makedirs also creates a missing parent `concept_activation_vectors` folder
    # (os.mkdir raised in that case) and is a no-op if the folder already exists.
    os.makedirs(layer_dir,exist_ok=True)
    losses={}
    # permutations(concepts, 2) yields every ordered pair of distinct concepts,
    # in the same order as a nested loop that skips positive == negative.
    for positive_concept,negative_concept in itertools.permutations(concepts,2):
        pbar.set_description(f'Layer {layer}: {positive_concept}_{negative_concept}')
        cav,loss=get_cav(positive_concept,negative_concept,layer)
        torch.save(cav,f'{layer_dir}\\{positive_concept}_{negative_concept}.pt')
        losses[f'{positive_concept}_{negative_concept}']=loss
    # The context manager closes the file even if pickling raises.
    with open(f'{layer_dir}\\losses','wb') as losses_file:
        pickle.dump(losses,losses_file)

Layer 11: truck_ship: 100%|██████████| 1/1 [19:25<00:00, 1165.66s/it]
