In [13]:
import torch
import numpy as np
import einops
from tqdm import tqdm
import os
from datasets import load_dataset
import pickle
import matplotlib.pyplot as plt

# Prefix for every data/feature path in this notebook; the paths appended to it use '\\' separators.
external_path = ''

In [14]:
# CIFAR-10 class names, each treated as a concept.
concepts = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

In [15]:
dataset = load_dataset("uoft-cs/cifar10")
# test_set is the split whose images are fetched by index when plotting further down.
train_set, test_set = dataset["train"], dataset["test"]

In [16]:
class AutoEncoder(torch.nn.Module):
    """Tied-weight sparse autoencoder that learns a feature dictionary.

    The decoder weight ``W`` (shape ``(activation_size, n_dict_components)``)
    also serves as the encoder: codes are ``ReLU(x @ W + encoder_bias)`` and the
    reconstruction is ``decoder(c)``.

    Args:
        activation_size: width of the input activations.
        n_dict_components: number of dictionary atoms (features).
        t_type: dtype of every parameter (decoder weight/bias and encoder bias).
    """

    def __init__(self, activation_size, n_dict_components, t_type=torch.float32):
        super().__init__()

        self.decoder = torch.nn.Linear(n_dict_components, activation_size, bias=True)
        # Create the bias directly in t_type. A float32 bias next to e.g. half-precision
        # weights would promote the codes to float32, and the decoder matmul would then
        # fail on the dtype mismatch.
        self.encoder_bias = torch.nn.Parameter(torch.zeros(n_dict_components, dtype=t_type))
        torch.nn.init.orthogonal_(self.decoder.weight)
        self.decoder = self.decoder.to(t_type)
        self.encoder = torch.nn.Sequential(torch.nn.ReLU()).to(t_type)
        self.activation_size = activation_size
        self.n_dict_components = n_dict_components

    def forward(self, x):
        """Encode ``x`` into non-negative codes and reconstruct it.

        NOTE: every call (also in eval mode) renormalises the decoder columns to unit
        L2 norm in place through ``.data``, so autograd does not track that step.

        Returns:
            (x_hat, c): the reconstruction and the codes.
        """
        c = self.encoder(x @ self.decoder.weight + self.encoder_bias)
        # Unit-norm dictionary atoms (columns). This happens after encoding, so the
        # encoder above saw the pre-normalisation weights.
        self.decoder.weight.data = torch.nn.functional.normalize(self.decoder.weight.data, dim=0)
        x_hat = self.decoder(c)
        return x_hat, c
    
def AutoEncoderLoss(inputs, target, alpha=1e-3):
    """Per-sample sparse-autoencoder objective.

    Args:
        inputs: the ``(x_hat, c)`` pair returned by ``AutoEncoder.forward``.
        target: the activations that were encoded.
        alpha: weight of the L1 penalty on the codes.

    Returns:
        Tensor with one entry per row: ``||target - x_hat||_2^2 + alpha * ||c||_1``.
    """
    reconstruction, codes = inputs[0], inputs[1]
    squared_error = torch.norm(target - reconstruction, p=2, dim=1).pow(2)
    sparsity_penalty = torch.norm(codes, p=1, dim=1)
    return squared_error + alpha * sparsity_penalty

In [17]:
def feature_dictionary_construction(concepts,layer,expansion_factor=4,epochs=50,batch_size=128,lr=1e-3,alpha=1e-3):
    """Train a sparse autoencoder on the pooled activations of ``concepts``.

    For every concept, loads ``layer{layer}_{k}.pt`` (k = 1..10) from
    ``{external_path}\\concept_token_activations\\{concept}``, squeezes dim 1 and
    concatenates everything in concept/file order. It then trains an
    ``AutoEncoder`` with ``expansion_factor * d`` components using Adam, on
    mini-batches of ``batch_size`` rows taken in order (no shuffling).

    Args:
        concepts: non-empty sequence of concept (category) names.
        layer: layer number used in the activation file names.
        expansion_factor: dictionary size as a multiple of the activation width.
        epochs: number of passes over the data.
        batch_size: rows per optimisation step; the last batch holds the remainder.
        lr: Adam learning rate.
        alpha: L1 sparsity weight passed to ``AutoEncoderLoss``.

    Returns:
        (sparse_autoencoder, epoch_loss). ``epoch_loss`` is the row-weighted mean
        loss of the final epoch as a detached 0-d tensor (NaN if ``epochs == 0``).

    Raises:
        ValueError: if ``concepts`` is empty.
    """
    if not concepts:
        raise ValueError('concepts must contain at least one category')

    # Gather every chunk, then concatenate once; repeated torch.cat would be quadratic.
    chunks=[]
    for category in concepts:
        for k in range(1,11):
            chunks.append(torch.load(f'{external_path}\\concept_token_activations\\{category}\\layer{layer}_{k}.pt').squeeze(1))
    activations=torch.cat(chunks).detach()

    sparse_autoencoder=AutoEncoder(activations.shape[1],expansion_factor*activations.shape[1])
    optimizer=torch.optim.Adam(sparse_autoencoder.parameters(),lr=lr)

    dataset_size=activations.shape[0]
    epoch_loss=torch.tensor(float('nan'))
    for epoch in range(epochs):
        epoch_loss=0.0
        # Stepping by batch_size yields ceil(dataset_size / batch_size) batches, none
        # of them empty. An empty batch would make the mean loss, and so the epoch loss, NaN.
        for start in range(0,dataset_size,batch_size):
            batch=activations[start:start+batch_size]
            optimizer.zero_grad()
            outputs=sparse_autoencoder(batch)
            loss=AutoEncoderLoss(outputs,batch,alpha).mean()
            loss.backward()
            optimizer.step()
            # Detach so the running total does not keep every batch's autograd graph alive.
            epoch_loss+=loss.detach()*batch.shape[0]
        epoch_loss/=dataset_size
    return sparse_autoencoder,epoch_loss


In [18]:
layer=1
layer_dir=f'{external_path}\\features\\{layer}'
losses_path=f'{layer_dir}\\losses'
os.makedirs(layer_dir,exist_ok=True)

# Resume from losses recorded by a previous (possibly interrupted) run. This loop is the
# only writer of the file, so reading it once is equivalent to re-reading it every pair.
# NOTE: pickle can execute arbitrary code; only load this notebook's own output.
if os.path.exists(losses_path):
    with open(losses_path,'rb') as losses_file:
        losses=pickle.load(losses_file)
else:
    losses={}

pbar=tqdm(concepts)
for positive_concept in pbar:
    for negative_concept in concepts:
        if positive_concept==negative_concept:
            continue
        pair=f'{positive_concept}_{negative_concept}'
        reverse_pair=f'{negative_concept}_{positive_concept}'
        pbar.set_description(pair)

        os.makedirs(f'{layer_dir}\\{pair}',exist_ok=True)

        if os.path.exists(f'{layer_dir}\\{reverse_pair}'):
            # The pair was already trained with the roles swapped (the SAE is fit on the
            # pooled activations of both concepts), so copy that model and its loss.
            model_state_dict=torch.load(f'{layer_dir}\\{reverse_pair}\\model.pt')
            torch.save(model_state_dict,f'{layer_dir}\\{pair}\\model.pt')
            losses[pair]=losses[reverse_pair]
        else:
            pbar.set_description(f'{pair}...training model...')
            model,loss=feature_dictionary_construction([positive_concept,negative_concept],layer)
            torch.save(model.state_dict(),f'{layer_dir}\\{pair}\\model.pt')
            losses[pair]=loss
        # Persist after every pair so an interrupted run keeps its progress.
        with open(losses_path,'wb') as losses_file:
            pickle.dump(losses,losses_file)

truck_ship: 100%|██████████| 10/10 [46:01<00:00, 276.16s/it]                           


In [19]:
def activation_decomposition_to_features(concept_1,concept_2,layer,expansion_factor=4):
    """Encode both concepts' activations with the pair's SAE and save the codes.

    Loads ``layer{layer}_{k}.pt`` (k = 1..10) for each concept and rebuilds an
    ``AutoEncoder`` with ``expansion_factor * d`` components from
    ``features\\{layer}\\{concept_1}_{concept_2}\\model.pt``. It then saves
    ``ReLU(activations @ W + encoder_bias)`` for each concept as
    ``{concept}_decompositions.pt`` in that same directory.
    """
    def _load_concept_activations(concept):
        # The ten activation chunks of one concept, squeezed on dim 1 and stacked in file order.
        return torch.cat([torch.load(f'{external_path}\\concept_token_activations\\{concept}\\layer{layer}_{k}.pt').squeeze(1) for k in range(1,11)])

    concept_1_activations=_load_concept_activations(concept_1)
    concept_2_activations=_load_concept_activations(concept_2)

    pair_dir=f'{external_path}\\features\\{layer}\\{concept_1}_{concept_2}'
    sparse_autoencoder=AutoEncoder(concept_1_activations.shape[1],expansion_factor*concept_1_activations.shape[1])
    sparse_autoencoder.load_state_dict(torch.load(f'{pair_dir}\\model.pt'))
    sparse_autoencoder.eval()

    # Pure inference. Without no_grad the codes would carry an autograd graph and be
    # saved with requires_grad=True.
    with torch.no_grad():
        # This is the encoder half of AutoEncoder.forward, applied without that
        # method's in-place renormalisation of the decoder weights.
        dictionary=sparse_autoencoder.decoder.weight
        encoder_bias=sparse_autoencoder.encoder_bias
        c_concept_1=sparse_autoencoder.encoder(concept_1_activations@dictionary+encoder_bias)
        c_concept_2=sparse_autoencoder.encoder(concept_2_activations@dictionary+encoder_bias)

    torch.save(c_concept_1,f'{pair_dir}\\{concept_1}_decompositions.pt')
    torch.save(c_concept_2,f'{pair_dir}\\{concept_2}_decompositions.pt')


In [20]:
layer=1
pbar=tqdm(concepts)
for positive_concept in pbar:
    # Decompose every ordered pair of distinct concepts with that pair's SAE.
    for negative_concept in (other for other in concepts if other!=positive_concept):
        pbar.set_description(f'{positive_concept}_{negative_concept}')
        activation_decomposition_to_features(positive_concept,negative_concept,layer)

truck_ship: 100%|██████████| 10/10 [00:42<00:00,  4.22s/it]        


In [21]:
def features_similar_to_concept_vector(positive_concept,negative_concept,layer,concept_vector_type='cav',expansion_factor=4):
    """Plot and save the five alive SAE features most aligned with a concept vector.

    Uses the SAE trained for ``{positive_concept}_{negative_concept}`` at ``layer``:

    * loads the pair's CAV (``concept_vector_type='cav'``) or CBV (``'cbv'``);
    * keeps the "alive" features, i.e. those with a positive code on at least one
      row of the saved decompositions of either concept;
    * subtracts the mean alive decoder column from each alive column and scores
      each column ``f`` by ``concept_vector . f / ||f||``;
    * for the five highest-scoring features, finds the five rows on which each
      fires hardest and maps them to ``test_set`` indices through
      ``concept_correctly_classified_indices``;
    * saves a 5x5 image grid (``.png``) and the five feature ids in descending
      score order (``.npy``) into the pair's feature directory.

    Raises:
        ValueError: if ``concept_vector_type`` is neither ``'cav'`` nor ``'cbv'``.
    """
    # Validate first so a bad argument fails before any file is read.
    if concept_vector_type=='cav':
        concept_vector_dir='concept_activation_vectors'
        title='Features Most Similar to CAV'
    elif concept_vector_type=='cbv':
        concept_vector_dir='concept_boundary_vectors'
        title='Features Most Similar to CBV'
    else:
        raise ValueError('Enter valid concept vector, either cav or cbv')

    pair=f'{positive_concept}_{negative_concept}'
    pair_dir=f'{external_path}\\features\\{layer}\\{pair}'

    # Only the activation width is needed to rebuild the SAE. All chunks of a concept share
    # it (they are concatenated along dim 0 elsewhere), so reading one file is enough.
    activation_size=torch.load(f'{external_path}\\concept_token_activations\\{positive_concept}\\layer{layer}_1.pt').squeeze(1).shape[1]
    sparse_autoencoder=AutoEncoder(activation_size,expansion_factor*activation_size)
    sparse_autoencoder.load_state_dict(torch.load(f'{pair_dir}\\model.pt'))
    sparse_autoencoder.eval()
    # Columns are the dictionary features.
    feature_dictionary=sparse_autoencoder.decoder.weight.data

    concept_vector=torch.load(f'{external_path}\\{concept_vector_dir}\\{layer}\\{pair}.pt')

    activation_decompositions=torch.cat([torch.load(f'{pair_dir}\\{positive_concept}_decompositions.pt'),torch.load(f'{pair_dir}\\{negative_concept}_decompositions.pt')])

    # Fraction of rows on which each feature is active; "alive" features fire at least once.
    sparsities=torch.sum(activation_decompositions>0,axis=0)/activation_decompositions.shape[0]
    alive_features=torch.where(sparsities>0)[0]
    alive_features_dictionary=feature_dictionary[:,alive_features]
    alive_features_dictionary=alive_features_dictionary-torch.mean(alive_features_dictionary,axis=1,keepdim=True)

    # Score all alive features at once: <concept_vector, f> / ||f|| (assumes a 1-D concept vector).
    concept_vector=concept_vector.to(alive_features_dictionary.dtype)
    dots_with_concept_vector=((concept_vector@alive_features_dictionary)/torch.norm(alive_features_dictionary,dim=0)).detach().numpy()

    # Ascending by score: the last entry is the most similar feature.
    most_similar_features=alive_features[np.argsort(dots_with_concept_vector)[-5:]]
    # Row n of the grid shows the (n+1)-th most similar feature.
    ranked_features=[most_similar_features[-(n+1)] for n in range(5)]

    firing_images_relative=np.zeros((5,5),dtype=int)
    for n,feature in enumerate(ranked_features):
        firing_images_relative[n,:]=torch.argsort(activation_decompositions[:,feature],descending=True)[:5].detach().numpy()

    # NOTE: pickle can execute arbitrary code; only load this notebook's own output.
    with open(f'{external_path}\\concept_correctly_classified_indices','rb') as indices_file:
        concept_correctly_classified_indices=pickle.load(indices_file)

    # Row i of activation_decompositions is mapped to entry i of this array
    # (positive concept first, then negative, matching the concatenation above).
    correctly_classified_indices=np.concatenate([concept_correctly_classified_indices[positive_concept],concept_correctly_classified_indices[negative_concept]])
    firing_images_absolute=correctly_classified_indices[firing_images_relative].astype(int)

    fig,axs=plt.subplots(nrows=5,ncols=5)
    fig.suptitle(title)
    for n,feature in enumerate(ranked_features):
        for col in range(5):
            ax=axs[n,col]
            # Fetch only the needed image by index instead of iterating over
            # (and decoding) the whole test_set for every row.
            ax.imshow(np.array(test_set[int(firing_images_absolute[n,col])]['img']))
            if col==0:
                ax.set_ylabel(str(feature.item()))
                ax.xaxis.set_visible(False)
                ax.tick_params(left=False, labelleft=False)
            else:
                ax.axis('off')
    fig.savefig(f'{pair_dir}\\most_similar_{pair}_{concept_vector_type}.png')
    plt.close(fig)

    np.save(f'{pair_dir}\\most_similar_{pair}_{concept_vector_type}.npy',np.flip(most_similar_features.numpy()))

In [22]:
layer=1
pbar=tqdm(concepts)
for positive_concept in pbar:
    # Every ordered pair of distinct concepts, once per concept-vector type.
    for negative_concept in (other for other in concepts if other!=positive_concept):
        for label,vector_type in (('CAV','cav'),('CBV','cbv')):
            pbar.set_description(f'Layer {layer}: {positive_concept}_{negative_concept} {label}')
            features_similar_to_concept_vector(positive_concept,negative_concept,layer,concept_vector_type=vector_type)

Layer 1: truck_ship CBV: 100%|██████████| 10/10 [40:26<00:00, 242.68s/it]        


In [23]:
def random_features(positive_concept,negative_concept,layer,expansion_factor=4,seed=None):
    """Plot five randomly chosen alive SAE features and the images each fires on most.

    "Alive" features have a positive code on at least one row of the saved
    decompositions of either concept for the ``{positive_concept}_{negative_concept}``
    SAE at ``layer``. Five of them are drawn without replacement. For each, the five
    highest-activating rows are mapped to ``test_set`` indices through
    ``concept_correctly_classified_indices`` and shown as one row of a 5x5 grid,
    saved as ``random_{pair}.png`` in the pair's feature directory.

    Args:
        expansion_factor: not used by this function; kept so the signature matches
            the other feature helpers.
        seed: optional seed for the draw. ``None`` (default) uses NumPy's global RNG.
    """
    pair=f'{positive_concept}_{negative_concept}'
    pair_dir=f'{external_path}\\features\\{layer}\\{pair}'

    activation_decompositions=torch.cat([torch.load(f'{pair_dir}\\{positive_concept}_decompositions.pt'),torch.load(f'{pair_dir}\\{negative_concept}_decompositions.pt')])

    # Fraction of rows on which each feature is active; "alive" features fire at least once.
    sparsities=torch.sum(activation_decompositions>0,axis=0)/activation_decompositions.shape[0]
    alive_features=torch.where(sparsities>0)[0]

    rng=np.random if seed is None else np.random.default_rng(seed)
    sampled_features=alive_features[rng.choice(len(alive_features),size=5,replace=False)]

    firing_images_relative=np.zeros((5,5),dtype=int)
    for n in range(5):
        firing_images_relative[n,:]=torch.argsort(activation_decompositions[:,sampled_features[n]],descending=True)[:5].detach().numpy()

    # NOTE: pickle can execute arbitrary code; only load this notebook's own output.
    with open(f'{external_path}\\concept_correctly_classified_indices','rb') as indices_file:
        concept_correctly_classified_indices=pickle.load(indices_file)

    # Row i of activation_decompositions is mapped to entry i of this array
    # (positive concept first, then negative, matching the concatenation above).
    correctly_classified_indices=np.concatenate([concept_correctly_classified_indices[positive_concept],concept_correctly_classified_indices[negative_concept]])
    firing_images_absolute=correctly_classified_indices[firing_images_relative].astype(int)

    fig,axs=plt.subplots(nrows=5,ncols=5)
    fig.suptitle('Random Features')

    for n in range(5):
        for col in range(5):
            ax=axs[n,col]
            # Fetch only the needed image by index instead of scanning the whole test_set.
            ax.imshow(np.array(test_set[int(firing_images_absolute[n,col])]['img']))
            if col==0:
                # Label row n with the feature whose top images it actually shows.
                ax.set_ylabel(str(sampled_features[n].item()))
                ax.xaxis.set_visible(False)
                ax.tick_params(left=False, labelleft=False)
            else:
                ax.axis('off')

    fig.savefig(f'{pair_dir}\\random_{pair}.png')

    plt.close(fig)


In [24]:
layer=1
pbar=tqdm(concepts)
for positive_concept in pbar:
    for negative_concept in concepts:
        if positive_concept==negative_concept:
            continue
        # This cell samples random features, so the progress label says so.
        pbar.set_description(f'Layer {layer}: {positive_concept}_{negative_concept} random')
        random_features(positive_concept,negative_concept,layer)

Layer 1: truck_ship CAV: 100%|██████████| 10/10 [22:12<00:00, 133.27s/it]        
