In [None]:
import json
import torch
import numpy as np
from transformers import RobertaTokenizer, BertTokenizer
from utils.vis_utils import vis_concepts
from utils.evaluate_utils import all_evaluate
from utils.concept_utils import concept_sum, concept_sim
from models.AttenConcept import AttenConcept

# Run configuration shared by every cell of this notebook. Besides the
# individual lookups below, the whole dict is handed to concept_sum / concept_sim.
config = dict(
    # Backbone and dataset selection ('base_model_name' also picks the tokenizer class).
    base_model_name='roberta',
    task='IMDB',
    # Values forwarded to the AttenConcept constructor.
    dist_weight=-0.01,
    con_weight=0.1,
    lr=1e-5,
    num_classes=2,
    num_concepts=20,
    # Evaluation batch size and maximum sequence length.
    batch_size=256,
    max_len=256,
    # How many top-ranked samples are kept per concept.
    num_cases=10,
    # Filesystem locations.
    data_path='./datasets/',
    bert_path='./pre-trained-models/',
    checkpoint_path='./checkpoints/',
    prompt_path='./prompts/',
    llm_type='gpt-4o',
)

# Tokenizer ____________________________________________
# Pick the tokenizer class that matches the configured backbone. Later cells
# use `tokenizer` unconditionally, so an unsupported name is rejected here
# instead of surfacing as a NameError several cells further down.
if config['base_model_name'] == 'roberta':
    tokenizer = RobertaTokenizer.from_pretrained(config["bert_path"])
elif config['base_model_name'] == 'bert':
    tokenizer = BertTokenizer.from_pretrained(config["bert_path"])
else:
    raise ValueError(
        "Unsupported base_model_name {!r}; expected 'roberta' or 'bert'".format(
            config['base_model_name']))

# Model ____________________________________________
model = AttenConcept(num_classes=config['num_classes'],
                     num_concepts=config['num_concepts'],
                     dist_weight=config['dist_weight'],
                     con_weight=config['con_weight'],
                     base_model_name=config['base_model_name'],
                     bert_path=config['bert_path'],
                     max_len=config['max_len'])

# Use the GPU when one is available; otherwise stay on the CPU.
# NOTE(review): downstream helpers such as all_evaluate may themselves assume
# CUDA — confirm before relying on the CPU path.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Deserialize onto the CPU first so a checkpoint saved from a GPU process can
# still be read on a CPU-only machine; the model is moved to `device` afterwards.
# NOTE(review): config['checkpoint_path'] as configured looks like a directory
# ('./checkpoints/'), while torch.load expects a checkpoint file — confirm.
model.load_state_dict(torch.load(config['checkpoint_path'], map_location='cpu'))
model.to(device)
model = model.eval()

# First parameter tensor of the classifier head as a NumPy array (presumably
# the weight matrix of a linear layer — verify against AttenConcept).
weights = next(iter(model.classifier.parameters())).detach().cpu().numpy()

In [None]:
# Testing ____________________________________________
# Run the evaluation helper on the 'test' split. All four return values are
# discarded; the call is made for whatever all_evaluate reports on its own.
_, _, _, _ = all_evaluate(
    model,
    config['task'],
    config['num_classes'],
    config['batch_size'],
    'test',
    config['bert_path'],
    config['base_model_name'],
    max_len=config['max_len'],
)

In [None]:
# Get features from training set ____________________________________________
# Same helper as the test run, this time on the 'train' split; its four return
# values are kept because the cells below consume them.
all_contents, all_slot_attn, y_true, y_pred = all_evaluate(
    model,
    config['task'],
    config['num_classes'],
    config['batch_size'],
    'train',
    config['bert_path'],
    config['base_model_name'],
    max_len=config['max_len'],
)

In [None]:
# Concept contributions, top cases and summaries ____________________________________________
slot_attention = np.array(all_slot_attn)

# Collapse the last axis of the attention array. The two-target unpacking
# doubles as a check that the reduced array is 2-D.
input_logits = np.sum(slot_attention, axis=-1)
n, _ = input_logits.shape

# One entry per sample: that sample's row of sums multiplied (NumPy
# broadcasting) by the classifier weights extracted in the setup cell.
contributions = np.array([input_logits[i] * weights for i in range(n)])

# This was previously recomputed with an identical np.sum call; it is the same
# array as input_logits, so alias it instead of reducing a second time.
concept_logits = input_logits

# For every concept, keep the indices of the `num_cases` samples with the
# largest value in that concept's column, highest first. sorted() is stable
# (also with reverse=True), so tied samples stay in ascending index order.
cases_ids = {}
for concept_id in range(config['num_concepts']):
    scores = concept_logits[:, concept_id]
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    cases_ids[concept_id] = ranked[:config['num_cases']]

# Called for their side effects only (return values are ignored).
# NOTE(review): presumably these produce the summary.json read below — confirm.
concept_sum(all_contents, all_slot_attn, tokenizer, config)
concept_sim(all_contents, all_slot_attn, tokenizer, config)

# Read the summaries with an explicit encoding so the result does not depend
# on the platform's default text encoding.
with open(config['prompt_path'] + "summary.json", 'r', encoding='utf-8') as f:
    summary = json.load(f)
summary_list = [item['Summary'] for item in summary]

In [None]:
# Visualization ____________________________________________
# The config dict defined in the setup cell has no 'vis_threshold' entry, so
# indexing it directly raised KeyError. Use the configured value when present
# and fall back to a default otherwise.
# NOTE(review): 0.0 is a placeholder — confirm the threshold vis_concepts
# expects and add an explicit 'vis_threshold' entry to config.
vis_threshold = config.get('vis_threshold', 0.0)

concepts_html = vis_concepts(all_contents,
                            np.array(all_slot_attn),
                            contributions,
                            summary_list,
                            y_true,
                            y_pred,
                            cases_ids,
                            tokenizer,
                            vis_threshold)