In [None]:
import os

# Dataset to run the experiment on; switch by (un)commenting one of the lines below.
DATASET_NAME = "general_filtered"
# DATASET_NAME = "science-technology_filtered"
# DATASET_NAME = "history_filtered"

# exist_ok=True makes the cell idempotent without a separate exists() check,
# so re-running it cannot fail on a directory that is already there.
os.makedirs('results/' + DATASET_NAME, exist_ok=True)

In [None]:
import pandas as pd
import ast

# Path of the questions CSV. It is left empty here and must be filled in
# before running the cell (pd.read_csv cannot open an empty path).
file_path = ""

df = pd.read_csv(file_path)
# The 'choices' column is read as text; ast.literal_eval parses each cell back
# into the Python literal it represents so it can be indexed later on.
df['choices'] = df['choices'].apply(ast.literal_eval)
# Show only a preview instead of rendering the entire frame in the notebook.
df.head()

In [None]:
# Make sure every output location exists before the experiment cell runs:
# the per-dataset folder, and the 'results/new_results/<DATASET_NAME>' folder
# that the experiment loop opens its pickle files in.
os.makedirs('results/' + DATASET_NAME, exist_ok=True)
os.makedirs('results/new_results/' + DATASET_NAME, exist_ok=True)

In [None]:
import torch
from collections import Counter
import spacy
import random

# spaCy English pipeline used by find_strongest_entities, which reads the
# document's named entities (doc.ents) and each token's pos_ / is_stop flags.
nlp = spacy.load("en_core_web_sm")

In [None]:
def remove_substrings(lst):
    """Drop every string that is contained in another string of the list.

    Args:
        lst: iterable of strings.

    Returns:
        A new list ordered from longest to shortest (ties keep their input
        order) containing only the strings that are not a substring of
        another kept string. Exact duplicates are collapsed into a single
        occurrence instead of eliminating each other.
    """
    kept = []
    # Longest first: a string can only be contained in one that is at least as
    # long, so checking against the strings already kept is sufficient.
    for word in sorted(lst, key=len, reverse=True):
        if not any(word in longer for longer in kept):
            kept.append(word)
    return kept

In [None]:
def find_strongest_entities(sentence):
    """Collect the candidate spans of a sentence that may be perturbed.

    Candidates are gathered in this order, skipping anything already
    collected: named entities whose label is neither MONEY nor DATE, then the
    non-stop-word tokens tagged PROPN, then ADJ, then NOUN. The list is finally
    passed through remove_substrings.

    Args:
        sentence: text to analyse with the module-level spaCy pipeline `nlp`.

    Returns:
        The list of candidate strings returned by remove_substrings.
    """
    doc = nlp(sentence)
    candidates = []

    # Named entities first, ignoring monetary amounts and dates.
    for ent in doc.ents:
        if ent.label_ in {'MONEY', 'DATE'}:
            continue
        if ent.text not in candidates:
            candidates.append(ent.text)

    # Then content words, one part-of-speech class at a time so that proper
    # nouns precede adjectives, which precede common nouns.
    for pos_tag in ("PROPN", "ADJ", "NOUN"):
        for token in doc:
            if token.pos_ != pos_tag or token.is_stop:
                continue
            if token.text not in candidates:
                candidates.append(token.text)

    return remove_substrings(candidates)

In [None]:
import math
import heapq

class AdversarialAttackMistral:
    """Token-level perturbation of selected spans ("entities") of a question.

    For each of `top_k` variants, every entity found in the tokenized question
    has its tokens replaced by nearby vocabulary tokens (dot-product similarity
    in the model's input-embedding space). The replacement is then refined for
    `k` steps: the entity embeddings are moved by `alpha * sign(gradient)` of a
    KL divergence between the model's output distributions, and re-projected
    onto the vocabulary.

    Args:
        model: causal LM exposing `get_input_embeddings()` and accepting
            `inputs_embeds`.
        tokenizer: tokenizer matching `model`.
        k: number of gradient refinement steps per entity and variant.
        alpha: step size applied to the sign of the gradient.
        top_k: number of variants produced, and number of nearest vocabulary
            tokens considered per position.
        device: optional device for the tensors created here. When omitted it
            is taken from the model's input-embedding weights, falling back to
            "cuda:1" if the model does not expose them.
    """

    def __init__(self, model, tokenizer, k=1, alpha=0.001, top_k=15, device=None):
        self.k = k
        self.alpha = alpha
        self.top_k = top_k

        self.tokenizer = tokenizer
        self.model = model

        if device is None:
            # Follow the model rather than pinning a GPU index, so the inputs
            # always live where the embedding weights are.
            try:
                device = model.get_input_embeddings().weight.device
            except AttributeError:
                device = "cuda:1"
        self.device = torch.device(device)

    def perturb(self, question, entities):
        """Return perturbed versions of `question`.

        Args:
            question: text to perturb.
            entities: strings expected to occur in `question`; an entity whose
                token ids are not found in the tokenized question is skipped.

        Returns:
            `[question]` when `entities` is empty, otherwise a list of
            `self.top_k` decoded strings (one per variant).
        """
        if len(entities) == 0:
            return [question]

        embedding_layer = self.model.get_input_embeddings()

        input_tokens = self.tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=128)
        input_ids = input_tokens["input_ids"].to(self.device)
        vocab_embeds = embedding_layer.weight.to(self.device)

        entity_infos = []
        for e in entities:
            tok = self.tokenizer(e, return_tensors="pt")
            # Drop the first id of the entity encoding.
            # NOTE(review): this assumes the tokenizer prepends exactly one
            # special token (e.g. BOS) - confirm for the tokenizer in use.
            tokens = tok["input_ids"][:, 1:].to(self.device)

            keyword_positions = self.find_sequence_from_end(input_ids[0], tokens[0].tolist())
            if not keyword_positions:
                continue

            # Only the top-k indices are needed, so no autograd graph is built.
            with torch.no_grad():
                original_embeddings = embedding_layer(tokens)
                similarity = torch.mm(original_embeddings.squeeze(0), vocab_embeds.T)

            initial_alternatives = []
            for pos in range(tokens.shape[1]):
                top_indices = torch.topk(similarity[pos], self.top_k + 1)[1].tolist()
                # top_k + 1 candidates are fetched so that top_k remain after
                # removing the original token itself.
                top_alts = [idx for idx in top_indices if idx != tokens[0][pos].item()][:self.top_k]
                initial_alternatives.append(top_alts)

            entity_infos.append({
                "tokens": tokens,
                "positions": keyword_positions,
                "initial_alternatives": initial_alternatives,
            })

        perturbed_questions = []
        for variant_idx in range(self.top_k):
            new_input_ids = input_ids.clone()

            for entity in entity_infos:
                tokens = entity["tokens"].clone()
                positions = entity["positions"]
                alternatives = entity["initial_alternatives"]

                # Start from the variant_idx-th nearest neighbour of each token.
                perturbed_ids = tokens.clone()
                for pos in range(tokens.shape[1]):
                    if variant_idx < len(alternatives[pos]):
                        perturbed_ids[0, pos] = alternatives[pos][variant_idx]

                # Reference distribution; it is a constant of the optimisation,
                # so it is computed without tracking gradients.
                with torch.no_grad():
                    initial_logits = self.model(inputs_embeds=embedding_layer(perturbed_ids)).logits

                for _ in range(self.k):
                    perturbed_embeddings = embedding_layer(perturbed_ids).detach().clone()
                    perturbed_embeddings.requires_grad_()

                    logits = self.model(inputs_embeds=perturbed_embeddings).logits

                    # Same divergence as kl_div(softmax(a).log(), softmax(b)),
                    # but kept in log-space so that probabilities underflowing
                    # to zero cannot produce -inf / NaN.
                    # NOTE(review): on the first step these embeddings are the
                    # ones initial_logits was computed from, so the divergence
                    # starts at (numerically) zero - confirm this is intended.
                    loss = torch.nn.functional.kl_div(
                        torch.log_softmax(initial_logits, dim=-1),
                        torch.log_softmax(logits, dim=-1),
                        reduction='batchmean',
                        log_target=True,
                    )

                    # The graph is used once, so it does not need to be retained.
                    grads = torch.autograd.grad(loss, perturbed_embeddings)[0]

                    with torch.no_grad():
                        perturbed_embeddings = perturbed_embeddings + self.alpha * grads.sign()
                        similarity = torch.mm(perturbed_embeddings.squeeze(0), vocab_embeds.T)

                        # Project back to the vocabulary: nearest token that
                        # differs from the one currently at that position.
                        for pos in range(tokens.shape[1]):
                            top_indices = torch.topk(similarity[pos], self.top_k + 1)[1].tolist()
                            for new_token_id in top_indices:
                                if new_token_id != perturbed_ids[0, pos].item():
                                    perturbed_ids[0, pos] = new_token_id
                                    break

                # Write the perturbed entity tokens back into the question.
                for i_pos, r_pos in enumerate(positions):
                    new_input_ids[0][r_pos] = perturbed_ids[0][i_pos]

            perturbed_text = self.tokenizer.decode(new_input_ids[0], skip_special_tokens=True)
            perturbed_questions.append(perturbed_text)

        return perturbed_questions


    def find_sequence_from_end(self, v1, v2):
        """Locate the last occurrence of sequence `v2` inside sequence `v1`.

        Args:
            v1: sequence to search in (list or 1-D tensor of ids).
            v2: sequence to look for (list or 1-D tensor of ids).

        Returns:
            The list of consecutive indices of `v1` covered by the last match,
            or `[]` when `v2` is empty or does not occur.
        """
        # Convert once to plain Python lists so each window comparison is a
        # cheap list equality instead of element-wise tensor comparisons.
        haystack = v1.tolist() if hasattr(v1, "tolist") else list(v1)
        needle = v2.tolist() if hasattr(v2, "tolist") else list(v2)
        if not needle:
            return []

        width = len(needle)
        for i in range(len(haystack) - width, -1, -1):
            if haystack[i:i + width] == needle:
                return list(range(i, i + width))
        return []


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F

In [None]:
from ollama import chat
from ollama import ChatResponse

In [None]:
def perturbation_step(row, entities):
    """Build the working record for one question and its perturbed variants.

    Args:
        row: mapping with the keys 'question', 'choices' and 'answer'.
        entities: spans of the question handed to `adversary.perturb`.

    Returns:
        A dict with the keys 'question', 'perturbed', 'real_choices' and
        'real_answer'. 'perturbed' holds only the generated variants that
        differ from the original question once case and spaces are ignored.
    """
    original = row['question']
    record = {'question': original}

    candidates = adversary.perturb(original, entities)

    def _squash(text):
        # Comparison key: case-insensitive and blind to spaces.
        return text.lower().replace(' ', '')

    reference = _squash(original)
    record['perturbed'] = [variant for variant in candidates if _squash(variant) != reference]

    record['real_choices'] = row['choices']
    record['real_answer'] = row['answer']

    return record

In [None]:
def refinement_step(d):
    """Ask the chat model to repair every perturbed question, in place.

    Each entry of d['perturbed'] is sent to the 'deepseek-r1:14b' model and
    replaced by the last line of the reply that follows any '</think>' and
    '</s>' marker, stripped and lower-cased.

    Args:
        d: record containing the list d['perturbed'].

    Returns:
        The same dict `d`, with d['perturbed'] rewritten.
    """
    system_prompt = "You are a chatbot that must modify a given question to ensure it is fully correct.\
                Given a question, change it so that it meets the following conditions:\n\n\
                - The question must be grammatically correct.\n - The question must be logically and factually coherent.\n\n\
                If the question is already correct, repeat the question exactly as given.\n\nYour response must be strictly the modified or repeated question.\n\n\
                No additional words or explanations. Any deviation from this format is not acceptable."

    questions = d['perturbed']

    for idx, candidate in enumerate(questions):
        reply : ChatResponse = chat(model='deepseek-r1:14b', messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': f'Question: {candidate}'},
        ])

        # Keep what follows the reasoning block, then its last line, then
        # whatever follows an end-of-sequence marker.
        after_reasoning = reply.message.content.split('</think>')[-1]
        last_line = after_reasoning.split("\n")[-1]
        questions[idx] = last_line.split("</s>")[-1].strip().lower()

    return d

In [None]:
def answers_generation_step(d):
    """Generate one short answer per perturbed question.

    Every question in d['perturbed'] is sent to the 'deepseek-r1:14b' model.
    Replies that do not contain the marker 'answer: ' after the reasoning
    block are ignored; otherwise the text after the last marker is stored,
    stripped and lower-cased.

    Args:
        d: record containing the list d['perturbed'].

    Returns:
        The same dict `d`, with the collected answers in d['answers_perturbed'].
    """
    system_prompt = "Answer concisely in the format: 'answer: [your response]'.\
                    Do not provide explanations, context, or preambles. Assume every question makes sense, even if incorrect. If unclear, give the closest plausible answer without corrections.\
                    Do not refer to or repeat any part of the question in your answer. Provide only the response, without restating or referencing the subject of the question. \
                    No context, no explanations, no full sentences. Just the answer."

    collected = []

    for q in d['perturbed']:
        response : ChatResponse = chat(model='deepseek-r1:14b', messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': f'Question: {q}'},
        ])

        visible = response.message.content.split('</think>')[-1].strip()
        if 'answer: ' not in visible:
            continue
        # Text after the last occurrence of the marker.
        collected.append(visible.rpartition('answer: ')[2].strip().lower())

    d['answers_perturbed'] = collected

    return d

In [None]:
def get_top3_response_probability(question, answer, max_words=3):
    """Score an answer from the model's next-token distributions.

    The answer is cut to its first `max_words` whitespace-separated words,
    appended to the question, and the text is run through the module-level
    `model`. The softmax is taken over a window of at most three positions
    starting at the last question token, and the largest probability found in
    that window is returned.

    NOTE(review): the maximum is taken over every vocabulary entry in the
    window, not over the ids of the answer tokens - confirm this is the
    intended score.
    NOTE(review): the question length is counted without special tokens while
    the combined text is tokenized with the tokenizer defaults; if a BOS token
    is prepended the window is shifted by one position - confirm.

    Args:
        question: prompt text.
        answer: candidate answer text.
        max_words: number of leading words of `answer` that are kept.

    Returns:
        The maximum probability in the window, as a Python float.
    """
    target_device = torch.device("cuda:1")

    short_answer = " ".join(answer.split()[:max_words])

    # Token counts only; no tensors are requested for these two encodings.
    question_ids = tokenizer(question, add_special_tokens=False).to(target_device).input_ids
    answer_ids = tokenizer(short_answer, add_special_tokens=False).to(target_device).input_ids

    window = min(len(answer_ids), 3)

    encoded = tokenizer(f"{question} {short_answer}", return_tensors="pt").to(target_device)

    with torch.no_grad():
        # Drop the last position: it predicts a token beyond the input.
        next_token_logits = model(**encoded).logits[:, :-1, :]

    first = len(question_ids) - 1
    window_logits = next_token_logits[:, first:first + window, :]

    return F.softmax(window_logits, dim=-1).max().item()

def select_hardest_wrong_answers(question, correct_answer, wrong_real_answers, wrong_answers, top_n=3):
    """Pick up to `top_n` distractors whose score is closest to the correct answer's.

    Candidates are lower-cased, scored with get_top3_response_probability and
    visited by increasing distance from the correct answer's score. A candidate
    that is one of the original wrong choices is accepted directly; any other
    candidate is accepted only when the 'deepseek-r1:14b' verifier replies
    'answer: no', i.e. it is neither correct nor equivalent to an answer that
    was already selected.

    Args:
        question: question text.
        correct_answer: text of the correct choice.
        wrong_real_answers: wrong choices of the original question.
        wrong_answers: candidate distractors to rank.
        top_n: maximum number of distractors returned.

    Returns:
        A list with 1..top_n distractors ('</s>' removed, lower-cased), or
        None when no candidate was accepted.
    """
    verifier_prompt = (
        "You are a chatbot that can only answer with 'yes' or 'no'. Given a question, an answer to verify, and a set of alternative answers, "
        "you must provide a single response based on the following rules:\n\n"
        "- Respond 'yes' if the answer to verify is either a correct response to the given question, is present among the alternative answers (either exactly or as a semantically equivalent answer), \
                    or if the answer to verify expresses uncertainty, lack of information, or an inability to determine the correct answer (e.g., 'I don't know', 'I'm not sure', 'There is not enough information').\n"
        "- Respond 'no' if the answer to verify is neither correct nor present among the alternative answers.\n\n"
        "Your response **must** be strictly formatted as:\n\n"
        "answer: yes/no\n\n"
        "No additional words, explanations, or variations are allowed. Any deviation from this format is not acceptable."
    )

    correct_prob = get_top3_response_probability(question, correct_answer)

    wrong_probs = [(wa.lower(), get_top3_response_probability(question, wa.lower())) for wa in wrong_answers]

    # Closest score to the correct answer first.
    wrong_probs.sort(key=lambda pair: abs(pair[1] - correct_prob))

    known_wrong = {s.lower() for s in wrong_real_answers}

    hardest_wrong_answers = []
    for text, _prob in wrong_probs:
        if len(hardest_wrong_answers) >= top_n:
            break

        w = text.replace('</s>', '')

        # Compare the candidate text (not the (text, score) pair) with the
        # original wrong choices: those are wrong by construction and do not
        # need to be verified by the chat model.
        if text in known_wrong:
            hardest_wrong_answers.append(w)
            continue

        r : ChatResponse = chat(model='deepseek-r1:14b', messages=[
            {'role': 'system', 'content': verifier_prompt},
            {
                'role': 'user',
                'content': f'Question: {question} Answer to verify: {w} Alternative answers: {hardest_wrong_answers}'
            }
        ])
        gen = str(r.message.content)

        if 'answer: no' in gen.split('</think>')[-1].strip().lower():
            hardest_wrong_answers.append(w)

    if len(hardest_wrong_answers) >= 1:
        return hardest_wrong_answers
    else:
        return None

def answers_choice_step(d):
    """Choose the three distractors of the new multiple-choice question.

    Candidates from d['total_answers'] are de-duplicated (ignoring '</s>',
    dots, surrounding blanks and case) and the correct answer is removed. The
    remaining ones are ranked by select_hardest_wrong_answers; if fewer than
    three are selected, the list is completed with the best-scoring original
    wrong choices that are not already part of it.

    Args:
        d: record with the keys 'total_answers', 'real_choices' and
            'real_answer', plus 'question' and 'wrong_real_answers' when at
            least one candidate survives the filtering.

    Returns:
        The same dict `d`, with d['choices_perturbed'] set to the chosen list,
        or to None when no candidate is left after filtering.
    """
    def _normalize(text):
        return text.replace('</s>', '').replace('.', '').strip().lower()

    correct_answer = d['real_choices'][d['real_answer']]
    correct_norm = _normalize(correct_answer)

    choices_perturbed = []
    seen = set()
    for ap in d['total_answers']:
        x = _normalize(ap)
        if x not in seen and x != correct_norm:
            choices_perturbed.append(ap)
        seen.add(x)

    if len(choices_perturbed) == 0:
        d['choices_perturbed'] = None
        return d

    # The selector returns None when nothing was accepted; treat that as an
    # empty selection so the fallback below can still fill the three slots.
    cp = select_hardest_wrong_answers(d['question'], correct_answer, d['wrong_real_answers'], choices_perturbed, 3) or []
    if len(cp) >= 3:
        d['choices_perturbed'] = cp
    else:
        wrong_old_alternatives = list(d['real_choices'])
        wrong_old_alternatives.remove(correct_answer)

        already_selected = set(cp)
        alt_prob = {}
        for alt in wrong_old_alternatives:
            key = alt.lower()
            # Do not offer the same option twice.
            if key not in already_selected:
                alt_prob[key] = get_top3_response_probability(d['question'], key)

        top_alt = heapq.nlargest(3 - len(cp), alt_prob, key=alt_prob.get)
        d['choices_perturbed'] = cp + top_alt

    return d

In [None]:
def results_step(q):
    """Ask the model the re-built multiple-choice question.

    The correct answer and the three entries of q['choices_perturbed'] are
    shuffled, lettered A-D and appended to the question; the module-level
    `model` then generates a single token.

    NOTE(review): generation uses do_sample=True, so the produced letter is
    not deterministic - confirm sampling is wanted for this evaluation.

    Args:
        q: record with the keys 'question', 'real_choices', 'real_answer' and
            'choices_perturbed'.

    Returns:
        None when q['choices_perturbed'] is None or does not hold exactly
        three entries; otherwise a dict with the keys 'new_prompt',
        'new_ground_truth', 'new_gen' and 'question'.
    """
    choices = q['choices_perturbed']
    # Without exactly three distractors there is nothing to evaluate; the
    # caller already treats a falsy result as "no result".
    if choices is None or len(choices) != 3:
        return None

    system = {'role': 'system', 'content': 'You are a chatbot that has to answer with only the letter correspondent to the correct answer.'}

    def _clean(text):
        return text.replace('</s>', '').replace('"', '').strip()

    # Each option is tagged 'r' (real answer) or 'p' (perturbed distractor) so
    # the correct one can be found again after shuffling.
    new_answers = [(_clean(q['real_choices'][q['real_answer']]).lower(), 'r')]
    new_answers.extend((_clean(choice), 'p') for choice in choices)
    random.shuffle(new_answers)

    letters = 'ABCD'
    options = '; '.join(f'{letter}) {text.lower()}' for letter, (text, _tag) in zip(letters, new_answers))
    user = {'role': 'user', 'content': q['question'] + ' ' + options + '.'}
    new_prompt = [system, user]

    real = next(i for i, (_text, tag) in enumerate(new_answers) if tag == 'r')
    new_ground_truth = letters[real]

    # Place the inputs on the model's own device instead of a fixed GPU index.
    new_inputs = tokenizer.apply_chat_template(new_prompt, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
    new_inputs = {k: v for k, v in new_inputs.items()}

    new_out = model.generate(**new_inputs, max_new_tokens=1, do_sample=True)
    # Decode only the newly generated token, not the prompt.
    new_gen = tokenizer.decode(new_out[0][len(new_inputs["input_ids"][0]):])

    res = {'new_prompt': new_prompt, 'new_ground_truth': new_ground_truth, 'new_gen': new_gen, 'question': q['question']}
    return res

In [None]:
from itertools import combinations
import numpy

def generate_all_combinations(entities):
    """Enumerate every non-empty combination of `entities`.

    Combinations are listed by increasing size and, within a size, in the
    order produced by itertools.combinations. The number of combinations
    grows as 2**len(entities) - 1.

    Args:
        entities: sequence of items.

    Returns:
        A list of lists, one per combination; empty when `entities` is empty.
    """
    return [
        list(subset)
        for size in range(1, len(entities) + 1)
        for subset in combinations(entities, size)
    ]

In [None]:
# Load Mistral-7B-Instruct v0.3 on device "cuda:1"; the weights are cached under ../models.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", device_map="cuda:1", cache_dir='../models')
# NOTE(review): the tokenizer is loaded without cache_dir, so it presumably uses the
# default Hugging Face cache rather than ../models - confirm this is intended.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# Use the end-of-sequence token for padding (AdversarialAttackMistral.perturb tokenizes with padding=True).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import timeit
from itertools import combinations
import pickle
import os

# Grid over the attack hyper-parameters. For every (k, alpha) pair each
# question is attacked with successive entity combinations until the model
# answers the re-built multiple-choice question incorrectly. Records are
# appended to a per-configuration pickle file and the last processed row index
# is stored so an interrupted run can resume.
for k_value in [1, 2, 3]:
    for alpha_value in [0.01, 0.001, 0.0001]:        
        
        adversary = AdversarialAttackMistral(model, tokenizer, k=k_value, alpha=alpha_value, top_k=15)
        
        l = []
        results = []
        start_index = 0

        results_file = f'results/new_results/{DATASET_NAME}/{DATASET_NAME}_final_results_circular_k{k_value}_alpha{alpha_value}.pkl'
        index_file = f'results/new_results/{DATASET_NAME}/{DATASET_NAME}_index_k{k_value}_alpha{alpha_value}.pkl'
        error_file = f'results/new_results/{DATASET_NAME}/{DATASET_NAME}_error_k{k_value}_alpha{alpha_value}.pkl'

        # The files are opened in append/write mode further down, which fails
        # if their folder does not exist yet.
        os.makedirs(os.path.dirname(results_file), exist_ok=True)

        # SECURITY: pickle.load executes code embedded in the file; only load
        # files produced by this notebook.
        if os.path.exists(results_file):
            with open(results_file, 'rb') as f:
                # The file is written with one pickle.dump per record, so it
                # has to be read record by record until the end of the file.
                while True:
                    try:
                        record = pickle.load(f)
                    except EOFError:
                        break
                    # Keep the entries that match what is appended to
                    # `results` during a run: result dicts whose generated
                    # letter differs from the ground truth.
                    results.extend(
                        item for item in record
                        if isinstance(item, dict) and item.get('new_gen') != item.get('new_ground_truth')
                    )

        if os.path.exists(index_file):
            with open(index_file, 'rb') as f:
                try:
                    start_index = pickle.load(f)
                except EOFError:
                    start_index = 0

        for index, row in df.iterrows():
            # Skip rows handled by a previous run.
            # NOTE(review): start_index defaults to 0 and the row labelled 0
            # is also excluded explicitly, so that row is never processed -
            # confirm this is intended.
            if index <= start_index or index==0:
                continue  

            question = row['question']
            wrong = False
            it = 0
            # Defined before the loop so the record written for this row never
            # reuses the previous row's result (or an undefined name) when the
            # question yields no entity combination.
            result = None

            entities = find_strongest_entities(question)
            entities_combinations = generate_all_combinations(entities)

            while not wrong and it < len(entities_combinations):
                try:
                    result = None
                    start = timeit.default_timer()
                    first_op = start 
                    d = perturbation_step(row, entities_combinations[it])
                    stop = timeit.default_timer()

                    start = timeit.default_timer()
                    d = refinement_step(d)
                    stop = timeit.default_timer()

                    start = timeit.default_timer()
                    d = answers_generation_step(d)
                    wrong_real_answers = [d['real_choices'][i] for i in range(len(d['real_choices'])) if i != d['real_answer']]
                    d['wrong_real_answers'] = wrong_real_answers.copy()
                    d['total_answers'] = list(set(d['answers_perturbed'] + wrong_real_answers))
                    stop = timeit.default_timer()

                    start = timeit.default_timer()
                    d = answers_choice_step(d)
                    stop = timeit.default_timer()
                    last_op = stop
                    if d['choices_perturbed'] is not None:
                        # Nothing new was produced: the options are exactly the
                        # original wrong choices, so try the next combination.
                        if set(d['choices_perturbed']) == set(wrong_real_answers):
                            it += 1
                            continue
                    
                        start = timeit.default_timer()
                        result = results_step(d)
                        stop = timeit.default_timer()

                        if result and result['new_gen'] != result['new_ground_truth']:
                            wrong = True
                            results.append(result)

                            with open(results_file, 'ab') as f:
                                pickle.dump([result], f)  
                            with open(index_file, 'wb') as f:
                                pickle.dump(index, f)
                except Exception as e:
                    # Failures are recorded only for the last combination of a
                    # question; earlier ones simply move on to the next attempt.
                    if it == (len(entities_combinations)-1):
                        with open(results_file, 'ab') as f:
                            pickle.dump([e, index, it], f)
                it += 1

            if not wrong:
                # No combination fooled the model: store the outcome of the
                # last attempt (None if there was none or it failed).
                with open(results_file, 'ab') as f:
                    pickle.dump([result], f)

            with open(index_file, 'wb') as f:
                pickle.dump(index, f)