This notebook shows how to calculate the sensitivity of a trained classifier to human-defined concepts. Here we calculate the sensitivity of the detoxify model at https://huggingface.co/unitary/unbiased-toxic-roberta to the concept of sentiment polarity.

In [None]:
pip install transformers  #version transformers-4.24.0

In [4]:
from Roberta_model_data import RobertaClassifier,ToxicityDataset

## Load model

In [None]:
# Output heads of the classifier in index order; the position of a label is the
# value passed as `desired_class` in the result cells.
labels = [
  'toxicity',         # 0
  'severe_toxicity',  # 1
  'obscene',          # 2
  'identity_attack',  # 3
  'insult',           # 4
  'threat',           # 5
  'sexual_explicit',  # 6
]

In [None]:
# Imported here because this cell runs before the cell that holds the other imports;
# without it a fresh-kernel "Run All" fails with a NameError on RobertaTokenizerFast.
from transformers import RobertaTokenizerFast

# Download the model files from https://huggingface.co/unitary/unbiased-toxic-roberta/tree/main
# and point `model_path` at the directory that contains them.
model_path = '/model'
model = RobertaClassifier(model_path)
tokenizer = RobertaTokenizerFast.from_pretrained(model_path)


## functions to calculate TCAV scores

In [5]:
import torch.nn as nn
import numpy as np
import os
import pickle
import torch
from transformers import RobertaTokenizerFast
from torch.utils.data.dataloader import DataLoader

import random

# Seed every RNG in play: `statistical_testing` draws its bootstrap samples with
# np.random, so seeding only the stdlib generator leaves the TCAV scores non-reproducible.
random.seed(100)
np.random.seed(100)
torch.manual_seed(100)


# Corpus of random texts (presumably tweets, judging by the file name), stored as
# blank-line-separated entries in a single text file.
with open('data/random_stopword_tweets.txt','r') as f_:
  corpus_text = f_.read()
random_examples = corpus_text.split('\n\n')

# The last 1000 entries are the default concept set of `get_preds_tcavs`;
# that function takes its input examples from the front of the same list.
random_concepts = random_examples[-1000:]


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_dataloader(X, y, tokenizer, batch_size):
  """Tokenize the texts in `X` and serve them, paired with `y`, through a DataLoader.

  Args:
    X: sequence of input texts handed to `tokenizer` in one call.
    y: labels, one per text; must have the same length as `X`.
    tokenizer: callable invoked with truncation, padding and PyTorch tensor output.
    batch_size: number of items per batch.

  Returns:
    A DataLoader over `ToxicityDataset(encodings, y)`; no shuffling is requested.

  Raises:
    AssertionError: if `X` and `y` differ in length.
  """
  assert len(X) == len(y)
  tokenized = tokenizer(X, truncation=True, padding=True, return_tensors="pt")
  return DataLoader(ToxicityDataset(tokenized, y), batch_size=batch_size)

def get_reps(model, tokenizer, concept_examples, batch_size=64):
  """Return the position-0 representation of every concept example.

  The examples are tokenized and batched with `get_dataloader`, pushed through
  `model` with gradients disabled, and position 0 of the third model output is
  kept for each sequence (presumably the <s>/CLS token - confirm against
  RobertaClassifier).

  Args:
    model: callable returning a 3-tuple whose last element is indexed as
      [batch, position, hidden]; it must already live on `device`.
    tokenizer: tokenizer forwarded to `get_dataloader`.
    concept_examples: non-empty sequence of texts.
    batch_size: batch size of the forward passes (defaults to 64, the value
      that was previously hard-coded).

  Returns:
    NumPy array with one row per concept example.

  Raises:
    ValueError: if `concept_examples` is empty.
  """
  if len(concept_examples) == 0:
    raise ValueError('concept_examples must contain at least one example')

  # get_dataloader needs one label per example; only input_ids and
  # attention_mask are read from the batches below.
  placeholder_labels = torch.ones([len(concept_examples)])
  concept_dataloader = get_dataloader(concept_examples, placeholder_labels, tokenizer, batch_size)

  concept_repres = []
  with torch.no_grad():
    for batch in concept_dataloader:
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      _, _, representation = model(input_ids, attention_mask=attention_mask)
      concept_repres.append(representation[:, 0, :])

  return torch.cat(concept_repres, dim=0).cpu().detach().numpy()

def statistical_testing(model, tokenizer, concept_examples, num_runs=10, sample_size=50):
  """Build `num_runs` concept activation vectors (CAVs) by bootstrap sampling.

  The representations of all concept examples are computed once with
  `get_reps`. Each run then draws `sample_size` of them with replacement
  (np.random.choice default) and uses their mean as that run's CAV.

  Args:
    model: model forwarded to `get_reps`.
    tokenizer: tokenizer forwarded to `get_reps`.
    concept_examples: texts that express the concept.
    num_runs: number of CAVs to produce.
    sample_size: representations averaged per CAV (defaults to 50, the value
      that was previously hard-coded).

  Returns:
    List of `num_runs` 1-D NumPy arrays.
  """
  concept_repres = get_reps(model, tokenizer, concept_examples)

  cavs = []
  for _ in range(num_runs):
    sampled_ids = np.random.choice(len(concept_repres), sample_size)
    cavs.append(np.mean(concept_repres[sampled_ids], axis=0))
  return cavs

def get_logits_grad(model, tokenizer, sample, desired_class):
  """Return the logits for `sample` and the gradient of one logit w.r.t. its representation.

  The sample is tokenized and run through `model`; `logits[0, desired_class]`
  is back-propagated and the gradient is read from `model.grad_representation`
  (populated by the model during the backward pass - presumably through a hook
  on the representation; confirm in RobertaClassifier).

  Args:
    model: callable returning a 3-tuple whose first element is the logits; it
      must expose `zero_grad()` and, after backward, `grad_representation`.
    tokenizer: tokenizer called on `sample`.
    sample: input text.
    desired_class: index of the logit to differentiate.

  Returns:
    (logits, grad): `logits` is the model's first output, detached from the
    autograd graph; `grad` is `grad_representation[0][0]` as a NumPy array,
    i.e. the entry for the first sequence at position 0.
  """
  # `encoded` rather than `input`, which would shadow the builtin.
  encoded = tokenizer(sample, truncation=True, padding=True, return_tensors="pt")
  model.zero_grad()
  input_ids = encoded['input_ids'].to(device)
  attention_mask = encoded['attention_mask'].to(device)
  logits, _, _ = model(input_ids, attention_mask=attention_mask)

  logits[0, desired_class].backward()
  grad = model.grad_representation[0][0].cpu().numpy()

  # Detach so that callers collecting logits for many samples do not keep a
  # reference to each sample's autograd graph.
  return logits.detach(), grad

def get_preds_tcavs(model, tokenizer, desired_class=1, examples_set='random', concept_examples=random_concepts, num_runs=10):
  """Compute logits, concept sensitivities and TCAV scores for one output class.

  Steps:
    1. `statistical_testing` builds `num_runs` CAVs from `concept_examples`.
    2. For each of the first 2000 `random_examples`, `get_logits_grad` returns
       the logits and the gradient of logit `desired_class`.
    3. sensitivities[j, k] is the dot product of example j's gradient with CAV k.
    4. tcavs[k] is the fraction of examples whose sensitivity to CAV k is > 0.
  The mean and standard deviation of the TCAV scores are printed.

  Args:
    model: classifier; it is moved to `device` as a side effect.
    tokenizer: tokenizer matching `model`.
    desired_class: index of the logit to explain.
    examples_set: only 'random' is supported.
    concept_examples: texts expressing the concept. The default is the value
      `random_concepts` had when this function was defined.
    num_runs: number of CAVs, i.e. of TCAV scores.

  Returns:
    (logits, sensitivities, tcavs): a list with one logits tensor per example,
    an array of shape (n_examples, num_runs), and a list of `num_runs` floats.
    Returns None (after printing a message) when `examples_set` is unknown.
  """
  if examples_set == 'random':
    examples = random_examples[:2000]   # input examples
  else:
    print('examples are unknown')
    return

  print('calculating cavs...')
  # NOTE(review): the model is never switched to eval mode in this notebook;
  # confirm that RobertaClassifier does so itself, otherwise dropout adds noise.
  model.to(device)
  concept_cavs = statistical_testing(model, tokenizer, concept_examples, num_runs=num_runs)

  print('calculating logits and grads...')
  logits = []
  grads = []
  for sample in examples:
    logit, grad = get_logits_grad(model, tokenizer, sample, desired_class)
    grads.append(grad)
    logits.append(logit)

  # One row per example, one column per CAV.
  sensitivities = np.array([[np.dot(grad, cav) for cav in concept_cavs] for grad in grads])

  # TCAV score of a run = share of examples that the concept pushes towards the class.
  tcavs = [np.count_nonzero(sensitivities[:, run] > 0) / len(examples) for run in range(num_runs)]

  print('TCAV score for the concept: ')
  print(np.mean(tcavs), np.std(tcavs))

  return logits, sensitivities, tcavs

## sentiment concept examples

In [8]:
import pandas as pd
# Word-level sentiment lexicon; later cells read its 'Word' and 'Sentiment category' columns.
data = pd.read_table('data/VAD-V-ADJ-filtered.txt') 

In [9]:
# Preview the lexicon (pandas abbreviates the middle rows of a long frame).
data

Unnamed: 0,Word,Sentiment category,Valence score
0,happy,very pos,1.000
1,generous,very pos,1.000
2,magnificent,very pos,1.000
3,cheerful,very pos,0.980
4,passionate,very pos,0.980
...,...,...,...
495,criminal,very neg,-0.958
496,dangerous,very neg,-0.960
497,corrupt,very neg,-0.960
498,afraid,very neg,-0.980


In [10]:
# Distinct sentiment categories; these are the values the words are bucketed by.
data['Sentiment category'].unique()

array(['very pos', 'pos', 'neutral', 'neg', 'very neg'], dtype=object)

In [11]:
def _words_in_category(category):
  """Return the 'Word' entries of `data` whose 'Sentiment category' equals `category`."""
  return data[data['Sentiment category'] == category]['Word'].to_list()

# Five sentiment levels, from most negative to most positive.
low_sent = _words_in_category('very neg')
midlow_sent = _words_in_category('neg')
neut_sent = _words_in_category('neutral')
midhigh_sent = _words_in_category('pos')
high_sent = _words_in_category('very pos')

# Report the size of each level, in the same order as above.
for _level_words in (low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent):
  print(len(_level_words))

100
100
100
100
100


# Results

In [15]:
# Subjects that can be slotted into the "<target> are <word>." concept sentences.
targets = [
  'Women',
  'Trans people',
  'Gay people',
  'Black people',
  'Disabled people',
  'Muslims',
  'Immigrants',
  'These people',
  'These things',
]
# Index 8 selects the last entry, 'These things'.
target = targets[8]

In [16]:
# Show which target the concept sentences will be built around.
target

'These things'

## Sentiment concepts

#### label #1: toxicity

In [19]:
# toxicity (output index 0): TCAV score per sentiment level, most negative to most positive
level_names = ['low_sent','midlow_sent','neut_sent','midhigh_sent','high_sent']
level_words = [low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent]
for level_name, sent_level in zip(level_names, level_words):
  # Concept sentences of the form "<target> are <word>."
  concept = [target + ' are ' + item + '.' for item in sent_level]
  print('\n', '-'*10, level_name, '-'*10, '\n')
  logits, sensitivity, TCAV = get_preds_tcavs(
      model, tokenizer,
      desired_class=0,
      examples_set='random',
      concept_examples=concept,
      num_runs=100,
  )



 ---------- low_sent ---------- 

calculating cavs...
calculating logits and grads...
TCAV score for the concept: 
0.000105 0.0009957283766168367

 ---------- midlow_sent ---------- 

calculating cavs...
calculating logits and grads...
TCAV score for the concept: 
0.0 0.0

 ---------- neut_sent ---------- 

calculating cavs...
calculating logits and grads...
TCAV score for the concept: 
0.0 0.0

 ---------- midhigh_sent ---------- 

calculating cavs...
calculating logits and grads...
TCAV score for the concept: 
0.0 0.0

 ---------- high_sent ---------- 

calculating cavs...
calculating logits and grads...
TCAV score for the concept: 
0.0 0.0


#### label #2: severe toxicity

In [None]:
# severe toxicity (output index 1): TCAV score per sentiment level, most negative to most positive
level_names = ['low_sent','midlow_sent','neut_sent','midhigh_sent','high_sent']
level_words = [low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent]
for level_name, sent_level in zip(level_names, level_words):
  # Concept sentences of the form "<target> are <word>."
  concept = [target + ' are ' + item + '.' for item in sent_level]
  print('\n', '-'*10, level_name, '-'*10, '\n')
  logits, sensitivity, TCAV = get_preds_tcavs(
      model, tokenizer,
      desired_class=1,
      examples_set='random',
      concept_examples=concept,
      num_runs=1000,
  )

#### label #3: obscene

In [None]:

# obscene (output index 2): TCAV score per sentiment level, most negative to most positive
level_names = ['low_sent','midlow_sent','neut_sent','midhigh_sent','high_sent']
level_words = [low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent]
for level_name, sent_level in zip(level_names, level_words):
  # Concept sentences of the form "<target> are <word>."
  concept = [target + ' are ' + item + '.' for item in sent_level]
  print('\n', '-'*10, level_name, '-'*10, '\n')
  logits, sensitivity, TCAV = get_preds_tcavs(
      model, tokenizer,
      desired_class=2,
      examples_set='random',
      concept_examples=concept,
      num_runs=1000,
  )

#### label #4: identity attack

In [None]:
# identity attack (output index 3): TCAV score per sentiment level, most negative to most positive
level_names = ['low_sent','midlow_sent','neut_sent','midhigh_sent','high_sent']
level_words = [low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent]
for level_name, sent_level in zip(level_names, level_words):
  # Concept sentences of the form "<target> are <word>."
  concept = [target + ' are ' + item + '.' for item in sent_level]
  print('\n', '-'*10, level_name, '-'*10, '\n')
  logits, sensitivity, TCAV = get_preds_tcavs(
      model, tokenizer,
      desired_class=3,
      examples_set='random',
      concept_examples=concept,
      num_runs=1000,
  )

#### label #5: insult

In [None]:
# insult (output index 4): TCAV score per sentiment level, most negative to most positive
level_names = ['low_sent','midlow_sent','neut_sent','midhigh_sent','high_sent']
level_words = [low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent]
for level_name, sent_level in zip(level_names, level_words):
  # Concept sentences of the form "<target> are <word>."
  concept = [target + ' are ' + item + '.' for item in sent_level]
  print('\n', '-'*10, level_name, '-'*10, '\n')
  logits, sensitivity, TCAV = get_preds_tcavs(
      model, tokenizer,
      desired_class=4,
      examples_set='random',
      concept_examples=concept,
      num_runs=1000,
  )

#### label #6: threat

In [None]:
# threat (output index 5): TCAV score per sentiment level, most negative to most positive
level_names = ['low_sent','midlow_sent','neut_sent','midhigh_sent','high_sent']
level_words = [low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent]
for level_name, sent_level in zip(level_names, level_words):
  # Concept sentences of the form "<target> are <word>."
  concept = [target + ' are ' + item + '.' for item in sent_level]
  print('\n', '-'*10, level_name, '-'*10, '\n')
  logits, sensitivity, TCAV = get_preds_tcavs(
      model, tokenizer,
      desired_class=5,
      examples_set='random',
      concept_examples=concept,
      num_runs=1000,
  )

#### label #7: sexual explicit

In [None]:
# sexual explicit (output index 6): TCAV score per sentiment level, most negative to most positive
level_names = ['low_sent','midlow_sent','neut_sent','midhigh_sent','high_sent']
level_words = [low_sent, midlow_sent, neut_sent, midhigh_sent, high_sent]
for level_name, sent_level in zip(level_names, level_words):
  # Concept sentences of the form "<target> are <word>."
  concept = [target + ' are ' + item + '.' for item in sent_level]
  print('\n', '-'*10, level_name, '-'*10, '\n')
  logits, sensitivity, TCAV = get_preds_tcavs(
      model, tokenizer,
      desired_class=6,
      examples_set='random',
      concept_examples=concept,
      num_runs=1000,
  )

## Non-coherent concepts


In [None]:
# toxicity (output index 0) with a non-coherent concept:
# the most negative and the most positive words mixed into one concept set
mixed_words = low_sent + high_sent
concept = [target + ' are ' + item + '.' for item in mixed_words]
logits, sensitivity, TCAV = get_preds_tcavs(
    model, tokenizer,
    desired_class=0,
    examples_set='random',
    concept_examples=concept,
    num_runs=100,
)

In [None]:
# insult (output index 4) with a non-coherent concept:
# the most negative and the most positive words mixed into one concept set
mixed_words = low_sent + high_sent
concept = [target + ' are ' + item + '.' for item in mixed_words]
logits, sensitivity, TCAV = get_preds_tcavs(
    model, tokenizer,
    desired_class=4,
    examples_set='random',
    concept_examples=concept,
    num_runs=100,
)