In [37]:
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch.nn.functional as F
import torch.nn as nn
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import torch
import os
from datasets import Dataset
from concept_gradient import ConceptGradients
from torch.utils.data import DataLoader

In [38]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [39]:
class X2YModel(nn.Module):
    def __init__(self, model_name='./saved_target_model', num_classes=2):
        super(X2YModel, self).__init__()
        self.model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
        
    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        if inputs_embeds is not None:
            outputs = self.model.roberta(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:
            outputs = self.model.roberta(input_ids=input_ids, attention_mask=attention_mask)
        
        return self.model.classifier(outputs.last_hidden_state)  

class X2CModel(nn.Module):
    def __init__(self, model_name='./saved_concept_model', num_concepts=5):
        super(X2CModel, self).__init__()
        self.model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=num_concepts, ignore_mismatched_sizes=True).to('cuda')

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        if inputs_embeds is not None:
            outputs = self.model.roberta(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:
            outputs = self.model.roberta(input_ids=input_ids, attention_mask=attention_mask)
        
        return self.model.classifier(outputs.last_hidden_state)

x2y_model = X2YModel().to(device)
x2c_model = X2CModel().to(device)

In [40]:
def forward_func(embeddings, attention_mask):
    output = x2y_model(inputs_embeds=embeddings, attention_mask=attention_mask)
    return output

def concept_forward_func(embeddings, attention_mask):
    output = x2c_model(inputs_embeds=embeddings, attention_mask=attention_mask)
    return output

In [41]:
cg = ConceptGradients(forward_func, concept_forward_func=concept_forward_func, x2y_model=x2y_model, x2c_model=x2c_model)

def calculate_concept_gradient(input_ids, attention_mask, target_index, concept_index, mode):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)  

    with torch.no_grad():
        embeddings = x2y_model.model.get_input_embeddings()(input_ids)
    embeddings.requires_grad_(True)
    attention_mask = attention_mask.float()
    attention_mask.requires_grad_(True)
    attr = cg.attribute(
        (embeddings, attention_mask),
        mode=mode,
        target=target_index,
        target_concept=concept_index,
        n_concepts=5,
        target_layer_name='roberta.encoder.layer.10.attention.output.dense',
        concept_layer_name='roberta.encoder.layer.10.attention.output.dense',
    )

    return attr

In [42]:
df_test = pd.read_csv('dataset/test.csv')
df_test = df_test.sample(2000, random_state=42)

In [43]:
df_test['concept_labels'] = df_test[['obscene', 'threat', 'sexual_explicit', 'insult', 'identity_attack']].values.tolist()
df_test.drop(columns=['obscene', 'threat', 'insult', 'severe_toxicity', 'id', 'identity_attack', 'sexual_explicit'], inplace=True, axis=1)
ds_test = Dataset.from_pandas(df_test)
df_test.head()

Unnamed: 0,comment_text,toxicity,concept_labels
6252,"It's an idiotic system offering only one moronic choice. It lacks the integrity to offer a True/False choice, and is essentially meaningless. A lie politely expressed is bogus civility.",1,"[1, 0, 0, 1, 1]"
4684,"I WAS RAPED BY AMY DEMBOSKI\n\nLet's talk about it, either validate or disprove it",1,"[0, 1, 1, 1, 1]"
1731,"I guess it's hard to keep an old race-baiting hypocrite like tangreedo down...\nHowever, if he doesn't understand the basic elements of the First Amendment, then he's not fit to hold any office. Let him waste more money running - as long as he's never voted in again.",1,"[1, 0, 0, 1, 1]"
4742,Mine as well. I gather one can only speculate the race of the perpetrator if it's white.,0,"[0, 0, 0, 0, 0]"
4521,"I didn't see what the Steelers did. The whole team wasn't on the field? I thought it classy what Dallas did. Good for them all. Including Jerry Jones (I think I just threw up in my mouth).\n\nIt would be tough to dump the song entirely. Heck, bring back the F-16 fly-by's and I'm sure we'd get over it. ;-)\n\n""...By the way, someone can be proud of being an American and also not a nationalist nazi....""\nYup. Just used that wording to make being a Cowboy fan that much worse. :-D",1,"[1, 0, 0, 1, 1]"


In [44]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

def tokenize_function(examples):
    return tokenizer(examples["comment_text"], padding="max_length", truncation=True)

tokenized_dataset = ds_test.map(tokenize_function, batched=True)
tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "concept_labels", "toxicity"])

x2y_dl_test = DataLoader(tokenized_dataset, batch_size=8, shuffle=False)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [46]:
results = []

for batch in tqdm(x2y_dl_test, leave=True):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['toxicity'].to(device)
    concept_labels = batch['concept_labels']

    target_logits = x2y_model(input_ids=input_ids, attention_mask=attention_mask)
    target_preds = torch.argmax(target_logits, dim=-1)

    incorrect_indices = (target_preds != labels).nonzero(as_tuple=True)[0]

    if len(incorrect_indices) > 0:
        for idx in incorrect_indices:
            sample_input_ids = input_ids[idx].unsqueeze(0)
            sample_attention_mask = attention_mask[idx].unsqueeze(0)
            sample_label = labels[idx].item()
            sample_concept_label = concept_labels[idx].unsqueeze(0)
            sample_sentence = tokenizer.decode(sample_input_ids.squeeze(), skip_special_tokens=True)

            concept_gradient = calculate_concept_gradient(sample_input_ids, sample_attention_mask, target_index=sample_label, concept_index=None, mode='chain_rule_joint')
            concept_gradient = concept_gradient[0].detach().cpu().numpy()

            concept_logits = x2c_model(input_ids=sample_input_ids, attention_mask=sample_attention_mask)
            concept_logits = torch.sigmoid(concept_logits).detach().cpu().numpy()

            target_logits_final = x2y_model(input_ids=sample_input_ids, attention_mask=sample_attention_mask)
            target_logits_final = torch.softmax(target_logits_final, dim=-1).detach().cpu().numpy()

            results.append({
                "sentence": sample_sentence,
                "target_logits": target_logits_final,
                "concept_logits": concept_logits,
                "concept_gradient": concept_gradient,
                "label": sample_label,
                "concept_label": sample_concept_label.cpu().numpy()  
            })

df_results = pd.DataFrame(results)
df_results.to_csv("analysis_sheets/final_misclassified_samples_with_concept_gradients.csv", index=False)

print("Process complete. Results saved to 'misclassified_samples_with_concept_gradients.csv'")


  0%|          | 0/250 [00:00<?, ?it/s]

Process complete. Results saved to 'misclassified_samples_with_concept_gradients.csv'


In [41]:
df_results.to_csv("classified_samples_with_concept_gradients.csv", index=False)

In [16]:
import re

def clean_and_parse_gradient(example):
    if isinstance(example, str):
        cleaned_example = example.replace(' ', ',').replace(',,', ',').replace('[,', '[').replace(',]', ']')
        try:
            return ast.literal_eval(cleaned_example)  
        except:
            return None  
    return example

import numpy as np
def mean_concepts(x):
    return np.mean(x)

df_cc = pd.read_csv('analysis_sheets/final_classified_samples_with_concept_gradients.csv')
df_mc = pd.read_csv('analysis_sheets/final_misclassified_samples_with_concept_gradients.csv')
df_cc['concept_gradient'] = df_cc['concept_gradient'].apply(clean_and_parse_gradient)
df_mc['concept_gradient'] = df_mc['concept_gradient'].apply(clean_and_parse_gradient)
df_cc.to_csv('analysis_sheets/final_classified_samples_with_concept_gradients.csv', index=False)
df_mc.to_csv('analysis_sheets/final_misclassified_samples_with_concept_gradients.csv', index=False)

### TCAV

In [22]:
import os
import sys
import yaml
import torch
import glob
import numpy as np
import pandas as pd
from time import sleep
from scipy.stats import ttest_ind
from captum.attr import LayerActivation
from captum._utils.gradient import compute_layer_gradients_and_eval
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torchvision.datasets import CelebA, ImageFolder
from torch.utils.data import DataLoader, Dataset
import torchmetrics
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm, trange
import PIL
import seaborn

def save_tcav_results(trials, tcavs, save_npz_fname=None, force=False):
    
    stacked_tcavs = np.stack([np.stack(list(tcavs_.values()), axis=0) for tcavs_ in tcavs], axis=0)
    stacked_accs = np.stack([np.stack(list(trial[1].values()), axis=0)
                             for trial in trials], axis=0)
    
    if save_npz_fname is not None:
        if os.path.exists(save_npz_fname) and not force:
            print(f"{save_npz_fname} already exists.")
        else:
            np.savez(save_npz_fname, tcavs=stacked_tcavs, accs=stacked_accs)
    
    return stacked_tcavs, stacked_accs

class LinearModel(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.model = nn.Linear(in_dim, out_dim)
    def forward(self, x):
        return self.model(x)
    
class NoReduceMSE():
    def __init__(self):
        self.se = 0
        self.total = 0
    def __call__(self, pred, gt):
        self.se += ((pred - gt)**2).sum(0)
        self.total += pred.shape[0]
    def compute(self):
        return self.se / self.total
    
class TCAVScore(torchmetrics.Metric):
    def __init__(self, CAV, signed=True, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        
        assert len(CAV.shape) == 2
        self.CAV = CAV
        
        self.signed = signed
        
        self.add_state("sum", default=torch.zeros([self.CAV.shape[0]]), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, grads: torch.Tensor):
        
        with torch.no_grad():
            assert len(grads.shape) == 2
            assert grads.shape[-1] == self.CAV.shape[-1]

            grads = grads.unsqueeze(1)

            cos = F.cosine_similarity(grads, self.CAV, dim=-1)
            if self.signed:
                score = (cos > 0).sum(0)
            else:
                score = cos.sum(0)

            self.sum += score
            self.total += cos.shape[0]

    def compute(self):
        return self.sum.float() / self.total

import gc
class TCAV(nn.Module):
    def __init__(self, target_model, layer_names=None, cache_dir=None):
        super().__init__()
        
        self.target_model = target_model.eval()
        
        self.CAVs = None
        self.random_CAVs = None
        self.metrics = None
        self.cache_dir = cache_dir
        
        assert (layer_names is not None) or (cache_dir is not None)
        
        # reload from cache
        if self.cache_dir is not None and \
            os.path.exists(os.path.join(self.cache_dir, 'random_CAVs.npz')) and \
            os.path.exists(os.path.join(self.cache_dir, 'CAVs.npz')) and \
            os.path.exists(os.path.join(self.cache_dir, 'metrics.npz')):

            print("Loading `random_CAVs.npz`, `CAVs.npz`, and `metrics.npz` from cache...")
            with np.load(os.path.join(self.cache_dir, 'random_CAVs.npz')) as f:
                random_CAVs = {k: v for k, v in f.items()}
            assert all([len(v) > 0 for v in random_CAVs.values()])
            self.random_CAVs = random_CAVs

            with np.load(os.path.join(self.cache_dir, 'CAVs.npz')) as f:
                CAVs = {k: v for k, v in f.items()}
            assert all([len(v) > 0 for v in CAVs.values()])
            self.CAVs = CAVs
            
            with np.load(os.path.join(self.cache_dir, 'metrics.npz')) as f:
                metrics = {k: v for k, v in f.items()}
            assert all([len(v) > 0 for v in metrics.values()])
            self.metrics = metrics

            assert list(self.random_CAVs.keys()) == list(self.CAVs.keys())
            assert list(self.metrics.keys()) == list(self.CAVs.keys())

            self.layer_names = list(self.random_CAVs.keys())
            print(f"Using cached layer names: {self.layer_names}")
        else:
            self.layer_names = layer_names
        
        # searching for layers in target_model
        self.layers = {}
        # print(list(target_model.named_modules()))
        for name, layer in target_model.named_modules():
            if name in self.layer_names:
                self.layers[name] = layer
        if sorted(self.layer_names) != sorted(list(self.layers.keys())):
            raise ValueError(f"Keys {sorted(self.layer_names)} and {sorted(list(self.layers.keys()))} don't match.")
    
    @staticmethod            
    def get_class_balanced_sampler(ys, y_index):
        ys = ys[:, y_index]
        pos_ratio = ys.sum() / ys.shape[0]
        weights = ys * (1 - pos_ratio) + (1 - ys).abs() * pos_ratio
        sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
        return sampler

    def _generate_CAVs(self, dset_train, dset_valid, hparams=None, verbose=True):
        
        default_hparams = dict(task='classification', n_epochs=100, lr=1e-4, weight_decay=1e-2, 
                               batch_size=32, patience=10, pos_weight=None, num_workers=2)
        
        if hparams is None:
            hparams = default_hparams
        else:
            default_hparams.update(hparams)
            hparams = default_hparams
        
        dl_train = DataLoader(dset_train, batch_size=hparams['batch_size'], drop_last=False, 
                              num_workers=hparams['num_workers'], shuffle=False, persistent_workers=False)
        
        dl_valid = DataLoader(dset_valid, batch_size=hparams['batch_size'], drop_last=False, 
                              num_workers=hparams['num_workers'], shuffle=False, persistent_workers=False)
        
        device = next(self.target_model.parameters())
        cs_train = torch.cat([batch['labels'].detach().clone() for batch in tqdm(dl_train, leave=False)], dim=0)
        cs_valid = torch.cat([batch['labels'].detach().clone() for batch in tqdm(dl_valid, leave=False)], dim=0)
        
        CAVs, metrics = {}, {}
        
        for layer_name, layer in tqdm(self.layers.items(), leave=False, desc="Layers: "):
            layer_act = LayerActivation(self.target_model, layer)
            
            # extract activations
            acts_train = []
            max_batches = 140
            for batch_idx, batch in enumerate(tqdm(dl_train, leave=False)):
                try:
                    # if batch_idx >= max_batches:
                    #     break
                    input_ids = batch['input_ids'].to(device).long()
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    
                    with torch.no_grad():
                        embeddings = self.target_model.model.get_input_embeddings()(input_ids)
                    
                    embeddings.requires_grad_(True)
                    attention_mask = attention_mask.float()
                    attention_mask.requires_grad_(True)
                    act = layer_act.attribute((embeddings, attention_mask), attribute_to_layer_input=True).flatten(start_dim=1)

                    np.save(os.path.join('cav', f"cav_acts/act_batch_{batch_idx}.npy"), act.detach().cpu().numpy())
                    # acts_train.append(act.detach().cpu().numpy())
                    del act, embeddings, attention_mask, input_ids, labels
                    torch.cuda.empty_cache()
                    gc.collect()
                except Exception as e:
                    print(f"Skipping batch due to error: {e}")
                    continue
                    
            acts_train = [np.load(os.path.join('cav', f"cav_acts/act_batch_{i}.npy")) for i in range(len(dl_train))]
            acts_train = [torch.tensor(act, dtype=torch.float32) for act in acts_train]

            # Now concatenate the list of tensors
            acts_train = torch.cat(acts_train, dim=0)
            layer_dset_train = torch.utils.data.TensorDataset(acts_train, cs_train)
            acts_valid = []
            for batch_idx, batch in enumerate(tqdm(dl_valid, leave=False)):
                try:
                    input_ids = batch['input_ids'].to(device).long()
                    attention_mask = batch['attention_mask'].to(device)
                    labels = batch['labels'].to(device)
                    
                    with torch.no_grad():
                        embeddings = self.target_model.model.get_input_embeddings()(input_ids)
                        
                    embeddings.requires_grad_(True)
                    attention_mask = attention_mask.float()
                    attention_mask.requires_grad_(True)
                    act = layer_act.attribute((embeddings, attention_mask), attribute_to_layer_input=True).flatten(start_dim=1)
                    in_dim, out_dim = act.shape[1], labels.shape[1]
                    np.save(os.path.join('cav', f"cav_acts_val/act_batch_{batch_idx}.npy"), act.detach().cpu().numpy())
                    del act, embeddings, attention_mask, input_ids, labels  
                    torch.cuda.empty_cache()
                    gc.collect()
                except Exception as e:
                    print(f"Skipping batch due to error: {e}")
                    continue
                
            acts_valid = [np.load(os.path.join('cav', f'cav_acts_val/act_batch_{i}.npy')) for i in range(len(dl_valid))]
            acts_valid = [torch.tensor(act, dtype=torch.float32) for act in acts_valid]

            # Now concatenate the list of tensors
            acts_valid = torch.cat(acts_valid, dim=0)
            # acts_valid = torch.cat(acts_valid, dim=0)
            
            layer_dset_valid = torch.utils.data.TensorDataset(acts_valid, cs_valid)
            
            if hparams['pos_weight'] is not None:
                if isinstance(hparams['pos_weight'], torch.Tensor):
                    pos_weight = hparams['pos_weight'].to(device)
                else:
                    pos_weight = hparams['pos_weight'] * torch.ones([out_dim]).to(device)
            else:
                pos_weight = None

            # sampler = get_class_balanced_sampler(all_cs, nc)
            layer_dl_train = DataLoader(layer_dset_train, batch_size=hparams['batch_size'], drop_last=False, 
                                        num_workers=hparams['num_workers'], shuffle=True)
            layer_dl_valid = DataLoader(layer_dset_valid, batch_size=hparams['batch_size'], drop_last=False, 
                                        num_workers=hparams['num_workers'], shuffle=False)

            # define model and optimizer
            linear_model = LinearModel(in_dim, out_dim).to(device)
            optimizer = optim.Adam(linear_model.parameters(), lr=hparams['lr'], 
                                   weight_decay=hparams['weight_decay'])
            if hparams['task'] == 'classification':
                loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                metric = torchmetrics.Accuracy(threshold=0.0, task = 'multilabel', num_classes=out_dim, num_labels=6, average=None).to(device)
            elif hparams['task'] == 'regression':
                loss_fn = torch.nn.MSELoss()
                metric = NoReduceMSE()
            else:
                raise NotImplementedError

            # train
            patience = 0
            min_loss = np.inf
            linear_model.train()
            with trange(hparams['n_epochs'], leave=False, desc="Epochs: ") as tepochs:
                for epoch in tepochs:
                    losses = []
                    for xs, cs in layer_dl_train:
                        xs = xs.to(device)
                        cs = cs.to(device)
                        optimizer.zero_grad()

                        logits_cs = linear_model(xs)
                        loss = loss_fn(logits_cs, cs.float())

                        loss.backward()
                        optimizer.step()
                        
                        losses.append(loss.item())
                    
                    if min_loss > np.mean(losses):
                        min_loss = np.mean(losses)
                        patience = 0
                    else:
                        patience += 1
                    tepochs.set_postfix(loss=f"{np.mean(losses):.4f}/{min_loss:.4f}")
                    sleep(0.1)
                    if patience > hparams['patience']:
                        tepochs.update(n=hparams['n_epochs'] - epoch)
                        tepochs.close()
                        break

            # eval
            linear_model.eval()
            for xs, cs in layer_dl_valid:
                xs = xs.to(device)
                cs = cs.to(device).to(cs_valid.dtype)
                
                with torch.no_grad():
                    pred_cs = linear_model(xs)
                    metric(pred_cs, cs)
            
            CAV = linear_model.model.weight.detach().clone()
            CAV = CAV / torch.norm(CAV, dim=1, keepdim=True)
            CAVs[layer_name] = CAV.cpu().numpy()
            metrics[layer_name] = metric.compute().detach().cpu().numpy()
            
        return CAVs, metrics
    
    def _generate_random_CAVs(self, dset_train, dset_valid):
        dl = DataLoader(dset_train, batch_size=32, drop_last=False, 
                        num_workers=8, shuffle=False)
        device = next(self.target_model.parameters())
        
        CAVs = {}
        for layer_name, layer in tqdm(self.layers.items(), leave=False):
            layer_act = LayerActivation(self.target_model, layer)
            # get in_dim
            for batch in dl:
                input_ids = batch['input_ids'].to(device).long()
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                with torch.no_grad():
                    embeddings = self.target_model.model.get_input_embeddings()(input_ids)
                    
                embeddings.requires_grad_(True)
                attention_mask = attention_mask.float()
                attention_mask.requires_grad_(True)
                act = layer_act.attribute((embeddings, attention_mask), attribute_to_layer_input=True).flatten(start_dim=1)
                in_dim, out_dim = act.shape[1], labels.shape[1]
                break
            CAV = (torch.rand(out_dim, in_dim) - 1)
            CAV = CAV / torch.norm(CAV, dim=1, keepdim=True)
            CAVs[layer_name] = CAV.cpu().numpy()
        
        return CAVs
    
    def generate_CAVs(self, dset_train, dset_valid, n_repeat=5, hparams=None, force_rewrite_cache=False):
        
        self.CAVs = {layer_name: [] for layer_name in self.layer_names}
        metrics = {layer_name: [] for layer_name in self.layer_names}
        
        # reload from cache
        if (self.cache_dir is not None) and (not force_rewrite_cache) and \
           (os.path.exists(os.path.join(self.cache_dir, 'CAVs.npz'))) and \
           (os.path.exists(os.path.join(self.cache_dir, 'metrics.npz'))):
            
            raise ValueError("Cached directory already exist. Use `force_rewrite_cache = True` to overwrite.")
            '''
            print("Loading from cache...")
            
            with np.load(os.path.join(self.cache_dir, 'CAVs.npz')) as f:
                self.CAVs.update({k: v for k, v in f.items()})
            with np.load(os.path.join(self.cache_dir, 'metrics.npz')) as f:
                metrics.update({k: v for k, v in f.items()})
            
            if all([len(v) > 0 for v in self.CAVs.values()]):
                return self.CAVs, metrics
            '''
        
        update_layer_names = [k for k, v in self.CAVs.items() if len(v) == 0]
        print(f"Generating TCAV for layers: {update_layer_names}")
        
        # generate
        for _ in trange(n_repeat, desc="#repeats: "):
            CAVs_, metrics_ = self._generate_CAVs(dset_train, dset_valid, hparams=hparams)
            for layer_name in update_layer_names:
                self.CAVs[layer_name].append(CAVs_[layer_name])
                metrics[layer_name].append(metrics_[layer_name])
                
        for layer_name in update_layer_names:
            self.CAVs[layer_name] = np.stack(self.CAVs[layer_name], axis=0)
            metrics[layer_name] = np.stack(metrics[layer_name], axis=0)

        if self.cache_dir is not None:
            os.makedirs(self.cache_dir, exist_ok=True)
            np.savez_compressed(os.path.join(self.cache_dir, 'CAVs.npz'), **self.CAVs)
            np.savez_compressed(os.path.join(self.cache_dir, 'metrics.npz'), **metrics)
            
        return self.CAVs, metrics
    
    def generate_random_CAVs(self, dset_train, dset_valid, n_repeat=5, force_rewrite_cache=False):
        
        random_CAVs = {layer_name: [] for layer_name in self.layer_names}
        
        # reload from cache
        if (self.cache_dir is not None) and (not force_rewrite_cache) and \
           (os.path.exists(os.path.join(self.cache_dir, 'random_CAVs.npz'))):
            
            raise ValueError("Cached directory already exist. Use `force_rewrite_cache = True` to overwrite.")
            
            '''
            print("Loading from cache...")
            
            with np.load(os.path.join(self.cache_dir, 'random_CAVs.npz')) as f:
                random_CAVs.update({k: v for k, v in f.items()})
            
            if all([len(v) > 0 for v in random_CAVs.values()]):
                return random_CAVs
            '''
        
        update_layer_names = [k for k, v in random_CAVs.items() if len(v) == 0]
        print(f"Generating random TCAV for layers: {update_layer_names}")
        
        for _ in trange(n_repeat):
            random_CAVs_ = self._generate_random_CAVs(dset_train, dset_valid)
            for layer_name in update_layer_names:
                random_CAVs[layer_name].append(random_CAVs_[layer_name])
                
        for layer_name in update_layer_names:
            random_CAVs[layer_name] = np.stack(random_CAVs[layer_name], axis=0)

        if self.cache_dir is not None:
            os.makedirs(self.cache_dir, exist_ok=True) 
            np.savez_compressed(os.path.join(self.cache_dir, 'random_CAVs.npz'), **random_CAVs)
        
        self.random_CAVs = random_CAVs
        return self.random_CAVs
    
    def generate_TCAVs(self, dset_valid, layer_name, target_index=None, score_signed=True, 
                       return_ttest_results=False, ttest_threshold=0.05):
        
        assert len(self.CAVs[layer_name]) == len(self.random_CAVs[layer_name])
        n_repeat = len(self.CAVs[layer_name])
        
        device = next(self.target_model.parameters())
        tcavs, random_tcavs = [], []
        
        concept_dl = DataLoader(dset_valid, batch_size=16, shuffle=False, 
                                drop_last=False, num_workers=8)
        
        for i in trange(n_repeat, leave=False):
            
            CAV = torch.from_numpy(self.CAVs[layer_name][i]).float().to(device)
            random_CAV = torch.from_numpy(self.random_CAVs[layer_name][i]).float().to(device)

            tcavs_ = TCAVScore(CAV, signed=score_signed).to(device)
            random_tcavs_ = TCAVScore(random_CAV, signed=score_signed).to(device)

            for xs, cs in concept_dl:
                xs = xs.to(device)
                cs = cs.to(device)

                layer_grads_, _ = compute_layer_gradients_and_eval(
                    self.target_model, layer, xs, target_ind=target_index, 
                    attribute_to_layer_input=True)
                del _
                layer_grads_ = layer_grads_[0].flatten(start_dim=1)

                tcavs_(layer_grads_)
                random_tcavs_(layer_grads_)

            tcavs.append(tcavs_.compute().detach().cpu().numpy())
            random_tcavs.append(random_tcavs_.compute().detach().cpu().numpy())
        
        random_tcavs = np.stack(random_tcavs, axis=0)
        tcavs = np.stack(tcavs, axis=0)
        
        print('random_tcavs:', random_tcavs.mean(0))
        print('tcavs:', tcavs.mean(0))
        
        # run two-sided test
        ttest_results = []
        for i in range(tcavs.shape[1]):
            ttest_result = ttest_ind(tcavs[:, i], random_tcavs[:, i])
            ttest_results.append(ttest_result.pvalue)
        ttest_results = np.array(ttest_results)
        
        avg_tcav_scores = tcavs.mean(0)
        avg_tcav_scores[~(ttest_results < ttest_threshold)] = np.nan
        
        if return_ttest_results:
            return avg_tcav_scores, ttest_results
        else:
            return avg_tcav_scores

    def attribute(self, inputs, layer_name, mode, target=None, abs=False, use_random=False, select_index=None):
        
        assert mode in ['inner_prod', 'cosine_similarity']
        
        
        if use_random:
            assert self.random_CAVs is not None
            
            if select_index is None:
                CAV_ = self.random_CAVs[layer_name].mean(0)
            else:
                CAV_ = self.random_CAVs[layer_name][select_index]
        else:
            assert self.CAVs is not None
            
            if select_index is None:
                CAV_ = self.CAVs[layer_name].mean(0)
            else:
                CAV_ = self.CAVs[layer_name][select_index]
        CAV = torch.from_numpy(CAV_).float().to(device)
        # print(CAV)
        # CAV.float().to(inputs.device)
        
        with torch.no_grad():
            grads, _ = compute_layer_gradients_and_eval(
                self.target_model, self.layers[layer_name], inputs, 
                target_ind=target, attribute_to_layer_input=True)
            del _
            grads = grads[0].flatten(start_dim=1)
            
            if mode == 'inner_prod':
                attributions = grads @ CAV.T
            elif mode == 'cosine_similarity':
                attributions = grads @ CAV.T / (torch.norm(grads, dim=1, keepdim=True) * \
                                                torch.norm(CAV.T, dim=0, keepdim=True))
            else:
                raise NotImplementedError
            
            if abs:
                attributions = torch.abs(attributions)
                
        return attributions
    
device='cuda' if torch.cuda.is_available() else 'cpu'

class X2YModel(nn.Module):
    def __init__(self, model_name='saved_target_model', num_classes=2):
        super(X2YModel, self).__init__()
        self.model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
        
    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
        if inputs_embeds is not None:
            outputs = self.model.roberta(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        else:

            outputs = self.model.roberta(inputs_embeds=input_ids, attention_mask=attention_mask)
        
        return self.model.classifier(outputs.last_hidden_state)  

x2y_model = X2YModel().to(device)

chosen_model = x2y_model#x2c_model
chosen_cache_dir = 'cav'#'cav_x2c'



In [32]:
# Its for spreadsheet

layer_names = ['model.roberta.encoder.layer.10.attention.output.dense']#, 'model.roberta.encoder.layer.10.attention.output.dense', 'model.roberta.encoder.layer.9.attention.output.dense', 'model.roberta.encoder.layer.11.output.dense']
import pandas as pd
import torch
import numpy as np
from tqdm.auto import tqdm, trange

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaTokenizer, RobertaModel
from datasets import Dataset

df_test = df_cc#pd.read_csv("analyze_score_cc.csv")
ds_test = Dataset.from_pandas(df_test)

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
def tokenize_function(examples):
    return tokenizer(examples["sentence"], padding="max_length", truncation=True)

tokenized_dataset = ds_test.map(tokenize_function, batched=True)

tokenized_dataset.set_format("torch", columns=["input_ids", "attention_mask", "label"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x2y_dl_test = DataLoader(tokenized_dataset, batch_size=8, shuffle=False)
n_concepts = 5
print(f'Num concepts: {n_concepts}')
tcav = TCAV(x2y_model, layer_names=['model.roberta.encoder.layer.10.attention.output.dense'], cache_dir='cav')

attrs = []

for batch in tqdm(x2y_dl_test, leave=True):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    with torch.no_grad():
        embeddings = x2y_model.model.get_input_embeddings()(input_ids)
        # embeddings = x2c_model.model.get_input_embeddings()(input_ids)
        
    embeddings.requires_grad_(True)
    attention_mask = attention_mask.float()
    attention_mask.requires_grad_(True)
    y = batch['label'].to(device)
    attr = tcav.attribute((embeddings, attention_mask), 'model.roberta.encoder.layer.10.attention.output.dense', 'cosine_similarity', target=y)
    attr = attr.detach().cpu().numpy()
    attrs.append(attr)

attrs = np.concatenate(attrs, axis=0)

np.save('cav/attr_npy_classification.npy', attrs)


attrs = np.load('cav/attr_npy_classification.npy')
# df.shape, attrs.shape
df = df_test.copy()
# another column for the tcav-scores using attrs
# df['tcav_score'] =  np.nan
# df['tcav_score'] = attrs
# df.tcav_score should contain a list of all the tcav scores for each concept
df['tcav_score'] = attrs.tolist()
# reduce the precision of the tcav scores
df.tcav_score = df.tcav_score.apply(lambda x: [round(i, 8) for i in x])
df.tcav_score[0]


Map:   0%|          | 0/1886 [00:00<?, ? examples/s]

Num concepts: 5
Loading `random_CAVs.npz`, `CAVs.npz`, and `metrics.npz` from cache...
Using cached layer names: ['model.roberta.encoder.layer.10.attention.output.dense']


  0%|          | 0/236 [00:00<?, ?it/s]

[0.00142059, -0.00031433, 0.00088187, 0.0011738, -0.00229424]

In [33]:
df.to_csv('analysis_sheets/final_classified_samples_with_concept_gradients.csv', index=False)

In [26]:
df.head()

Unnamed: 0,sentence,target_logits,concept_logits,concept_gradient,label,concept_label,tcav_score
0,"It's an idiotic system offering only one moronic choice. It lacks the integrity to offer a True/False choice, and is essentially meaningless. A lie politely expressed is bogus civility.",[[0.05489601 0.9617731 ]],[[0.99932384 0.23486891 0.756429 0.9986303 0.18379843]],"[[0.04189841, -0.02674945, 0.02142245, -0.00010259, 0.06312781]]",1,[[1 0 0 1 1]],"[0.00142059, -0.00031433, 0.00088187, 0.0011738, -0.00229424]"
1,"I WAS RAPED BY AMY DEMBOSKI\n\nLet's talk about it, either validate or disprove it",[[0.2097058 0.8583711]],[[0.99266475 0.99864525 0.8775576 0.9963966 0.99759716]],"[[0.07840082, -0.13348918, 0.11335696, -0.05296306, -0.08456826]]",1,[[0 1 1 1 1]],"[0.00129855, 0.00147589, -0.00079129, 0.00112493, -0.00903469]"
2,"I guess it's hard to keep an old race-baiting hypocrite like tangreedo down...\nHowever, if he doesn't understand the basic elements of the First Amendment, then he's not fit to hold any office. Let him waste more money running - as long as he's never voted in again.",[[0.13946097 0.9065232 ]],[[0.8774967 0.04580401 0.92773587 0.9996799 0.09558382]],"[[0.06866927, -0.09557968, 0.11821295, 0.14562763, -0.06304058]]",1,[[1 0 0 1 1]],"[-0.00100404, -0.0036003, -0.00078455, -0.00366171, -0.00616127]"
3,"I find it hard to take this congresswoman clown serious, she looks like a transgendered cowboy/girl. What a joker!",[[0.05246165 0.9623934 ]],[[0.9769464 0.59346807 0.9991326 0.99851865 0.2192553 ]],,1,[[1 0 1 1 1]],"[0.00286534, -0.00363649, 0.00037329, 0.00138659, -0.00115875]"
4,"Sounds like a loser. Or at least, an angry loser, lol",[[0.04923106 0.96299654]],[[0.9166289 0.04369685 0.18025956 0.999726 0.12345581]],"[[0.02209041, -0.00520486, 0.00913896, -0.01408261, 0.02067645]]",1,[[1 1 0 1 0]],"[-0.00015554, -0.0024519, -0.00013447, -4.24e-06, -0.00194035]"


In [35]:
df.sample(10)

Unnamed: 0,sentence,target_logits,concept_logits,concept_gradient,label,concept_label,tcav_score
1177,"What is worse, ticks or liberals? They both are blood suckers!",[[0.05387298 0.9611835 ]],[[0.99850523 0.88956136 0.99865794 0.9990465 0.55715215]],"[[0.11190026, -0.00711629, 0.10685235, -0.01391387, -0.0223923]]",1,[[1 0 0 1 1]],"[0.00255218, -0.00270456, 0.00349239, 0.00188917, -0.00032008]"
1238,"Forwards, more likely..............................",[[0.92232734 0.09227892]],[[0.00063929 0.00078733 0.00080577 0.00072127 0.0007089 ]],,0,[[0 0 0 0 0]],"[0.00198666, 0.00461079, 0.00252969, 0.00196735, 0.00063642]"
774,"Remember also what Jesus says about giving scandal, and the millstone.",[[0.92169785 0.10087423]],[[0.00056885 0.00086256 0.00084678 0.00065399 0.00081226]],"[[-0.14696527, 0.02438674, -0.0330714, -0.14015268, -0.05634261]]",0,[[0 0 0 0 0]],"[0.00304911, -0.00279948, -0.00223824, -0.00041029, -0.00215613]"
85,I guess the Ontario liberals and Alberta ndp share stupid ideas.,[[0.05040782 0.9641423 ]],[[0.99807763 0.09881116 0.97711986 0.99898475 0.30146742]],,1,[[1 0 0 1 1]],"[-0.00174784, 0.00041132, 0.00098502, -0.00068852, -0.00106911]"
993,Or buy a one way ticket...,[[0.9293104 0.07693341]],[[0.00051418 0.00064927 0.00082675 0.00082647 0.00081356]],"[[0.17667936, 0.2650457, 0.12466749, 0.14965369, 0.15520136]]",0,[[0 0 0 0 0]],"[0.00550162, 0.00096809, 0.00549318, 0.004867, 0.00254748]"
1746,"You can say that since you have already left Canada, funny thing is you are making a living on Chinese soil.",[[0.8987819 0.11612054]],[[0.00037696 0.00054174 0.0018553 0.00131972 0.00098301]],"[[-0.5158155, -0.2566018, -0.06607198, -0.22067684, -0.2235776]]",0,[[0 0 0 0 0]],"[0.00450565, -0.00639469, -0.00578705, -0.00338063, 0.00805975]"
577,Eddy deserves the hangman's noose! Hang him as high as his clearance and a little more. A traitor is a traitor.,[[0.08954755 0.94145256]],[[0.94592726 0.1085472 0.7380514 0.9967552 0.99913967]],"[[-0.01484637, 0.29299054, 0.13260087, -0.24077478, -0.41638708]]",1,[[0 1 0 1 0]],"[-0.00172819, -0.00179959, 0.00091168, -0.00319312, -0.00376357]"
192,You have history upside down and inside out. Are you talking about the Missiles to Megawatts program where the US bought uranium FROM Russia to power OUR country? \n\nhttps://en.wikipedia.org/wiki/Megatons_to_Megawatts_Program\n\nThat program was widely praised as a swords-to-plowshares type of initiative that was beneficial and helped reduce the aging nuclear stockpile. \n\nIt's as if your facts got put in a blender and you typed what poured out.,[[0.93097913 0.07181426]],[[0.00061594 0.0007757 0.00071952 0.00066045 0.00070036]],"[[-0.46955225, -0.37156785, -0.35248777, -0.62343365, -0.45276725]]",0,[[0 0 0 0 0]],"[0.00162354, 0.00743648, 0.00128449, 0.00285041, 0.00210718]"
88,"OK, you're right, ""x"" the texting. I never text when I drive, regardless of circumstances. But I might hit the speed dial to make a phone call if I'm not in traffic.\n\nMy pet peeve is, ""Why to I have to wait for the light to change when there's no one coming from any direction?"" Common sense is against the law.\n\nThen there's that speed trap in Thorton on Washington between 88th and 84th where the speed limit drops from 45 to 35 for 4-5 blocks with no change in driving condtions. I saw six cars pulled over in that stretch at the same time.",[[0.9369667 0.06480987]],[[0.00062289 0.00083278 0.00070992 0.00062419 0.00071651]],,0,[[0 0 0 0 0]],"[0.00402275, 0.00629554, 0.0015151, 0.00314647, -0.00208538]"
430,"Also, the PM is equating a lone wolf attack with ISIS, etc who commit terrorism, they say, because it's part of their long-term strategy to create a global caliphate. The PM conveniently ignores that difference.",[[0.74431205 0.2967079 ]],[[0.00049197 0.00061045 0.00092101 0.00116962 0.00068518]],"[[-0.3030023, -0.14486581, -0.21068025, -0.15728547, -0.2243253]]",0,[[0 0 0 0 0]],"[0.00380062, 0.00455174, -0.00263239, 0.00392534, 0.00875435]"


All files have been transferred successfully!
