In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
import json

In [2]:
# Local checkpoint directory passed to from_pretrained by load_llama_model.
base_model = "/home/g4/Llama-3.2-3B-Instruct"
# NOTE(review): presumably the output location for the fine-tuned weights; it is
# not referenced in the visible cells — confirm against the training cells.
new_model="/home/g4/Llama-3.2-3B-Instruct-Finetuned"
def load_llama_model(model_name=None):
    """Load the LLaMA causal-LM checkpoint and its tokenizer.

    Args:
        model_name: Path or hub id handed to ``from_pretrained``. When omitted,
            the module-level ``base_model`` is used (the original behaviour).

    Returns:
        A ``(model, tokenizer, device)`` tuple. ``device`` is CUDA when
        ``torch.cuda.is_available()`` is true and CPU otherwise; the model has
        been moved onto it.
    """
    if model_name is None:
        model_name = base_model

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    if model.config.pad_token_id is None:
        # The config's eos_token_id is allowed to be a list of ids, which is not
        # a valid pad id. Reuse the tokenizer's single pad id instead so model
        # and tokenizer always agree.
        model.config.pad_token_id = tokenizer.pad_token_id

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = model.to(device)

    return model, tokenizer, device

# Load once at cell level; generate_response reads `model` and `tokenizer` as globals.
model, tokenizer, device = load_llama_model()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using device: cuda


In [3]:
# Leaf node ids -> human-readable distortion names.
# The keys are the terminal "yes"/"no" targets used in decision_tree; none of them
# is itself a key of decision_tree. find_distortion_paths treats membership in
# this dict as "this node is a leaf".
distortions = {
    "fortune_telling": "Fortune Telling/Catastrophizing",
    "what_if_statements": "What If Statements",
    "labeling": "Labeling/Global Labeling",
    "unfair_comparisons": "Unfair Comparisons",
    "mind_reading_conclusion": "Mind Reading",
    "should_statements_conclusion": "Should Statements",
    "overgeneralization": "Overgeneralization",
    "all_or_nothing_thinking": "All or Nothing Thinking",
    "blaming_conclusion": "Blaming",
    "emotional_reasoning": "Emotional Reasoning",
    "mental_filter": "Mental Filter",
    "discounting_the_positive": "Discounting the Positive",
    "magnification_minimization": "Magnification/Minimization",
    "personalization_conclusion": "Personalization",
    "jumping_to_conclusions": "Jumping to Conclusions",
    "no_distortion": "No Distortion"
}

# Label string -> leaf node id. create_path_training_data looks up each record's
# "cognitive_distortion" value here, so the keys must match the dataset labels exactly.
# Note: this is not a strict inverse of `distortions` — "Fortune Telling" and
# "What if?" differ from the display names used there.
name_to_node = {
    "Fortune Telling": "fortune_telling",
    "What if?": "what_if_statements",
    "Labeling/Global Labeling": "labeling",
    "Unfair Comparisons": "unfair_comparisons",
    "Mind Reading": "mind_reading_conclusion",
    "Should Statements": "should_statements_conclusion",
    "Overgeneralization": "overgeneralization",
    "All or Nothing Thinking": "all_or_nothing_thinking",
    "Blaming": "blaming_conclusion",
    "Emotional Reasoning": "emotional_reasoning",
    "Mental Filter": "mental_filter",
    "Discounting the Positive": "discounting_the_positive",
    "Magnification/Minimization": "magnification_minimization",
    "Personalization": "personalization_conclusion",
    "Jumping to Conclusions": "jumping_to_conclusions",
    "No Distortion": "no_distortion"
}

# Define the decision tree structure.
# Each entry maps a node id to a yes/no question plus the id to follow for each
# answer. A target is either another key of this dict (an internal node) or a key
# of `distortions` (a leaf). Traversal starts at "root". Several nodes are reachable
# from more than one parent (e.g. "judgments"), so this is a DAG rather than a
# strict tree.
decision_tree = {
    "root": {
        "question": "Does the text provided show any signs of cognitive distortions?",
        "yes": "prediction",
        "no": "no_distortion"
    },
    # --- Predictions about the future ---
    "prediction": {
        "question": "Does the text provided show any predictions about future events?",
        "yes": "terrible_future",
        "no": "judgments"
    },
    "terrible_future": {
        "question": "Does the text provided show predictions of terrible or catastrophic outcomes that seem unbearable?",
        "yes": "fortune_telling",
        "no": "what_if"
    },
    "what_if": {
        "question": "Does the text provided show repeated 'what if' questioning about potential scenarios?",
        "yes": "what_if_statements",
        "no": "judgments"
    },
    # --- Judgments about self or others ---
    "judgments": {
        "question": "Does the text provided show judgments being made about self or others?",
        "yes": "fixed_labels",
        "no": "assumptions"
    },
    "fixed_labels": {
        "question": "Does the text provided show fixed, broad labels being applied to self or others?",
        "yes": "labeling",
        "no": "comparisons"
    },
    "comparisons": {
        "question": "Does the text provided show unfavorable comparisons to others?",
        "yes": "unfair_comparisons",
        "no": "assumptions"
    },
    # --- Assumptions and rules ---
    "assumptions": {
        "question": "Does the text provided show assumptions about thoughts or intentions?",
        "yes": "mind_reading",
        "no": "should_statements"
    },
    "mind_reading": {
        "question": "Does the text provided show assumptions about others' thoughts without clear evidence?",
        "yes": "mind_reading_conclusion",
        "no": "should_statements"
    },
    "should_statements": {
        "question": "Does the text provided show usage of words like 'should', 'must', or 'ought to'?",
        "yes": "should_statements_conclusion",
        "no": "extremes"
    },
    # --- Extreme / black-and-white thinking ---
    "extremes": {
        "question": "Does the text provided show extreme or black-and-white thinking?",
        "yes": "always_never",
        "no": "responsibility"
    },
    "always_never": {
        "question": "Does the text provided show usage of words like 'always', 'never', or 'every time'?",
        "yes": "overgeneralization",
        "no": "two_categories"
    },
    "two_categories": {
        "question": "Does the text provided show the situation being viewed in only two extreme categories?",
        "yes": "all_or_nothing_thinking",
        "no": "responsibility"
    },
    # --- Responsibility: blaming and personalization ---
    "responsibility": {
        "question": "Does the text provided show assignment of responsibility for feelings or events?",
        "yes": "blaming",
        "no": "emotional_state"
    },
    "blaming": {
        "question": "Does the text provided show someone being entirely blamed for feelings or experiences?",
        "yes": "blaming_conclusion",
        "no": "personalization"
    },
    "personalization": {
        "question": "Does the text provided show assumption that external events or others' behaviors are personally directed?",
        "yes": "personalization_conclusion",
        "no": "emotional_state"
    },
    # --- Emotional reasoning and selective focus ---
    "emotional_state": {
        "question": "Does the text provided show emotional reasoning or selective focus?",
        "yes": "emotions_evidence",
        "no": "conclusions"
    },
    "emotions_evidence": {
        "question": "Does the text provided show emotions being used as evidence for reality?",
        "yes": "emotional_reasoning",
        "no": "selective_focus"
    },
    "selective_focus": {
        "question": "Does the text provided show focus on specific aspects while ignoring others?",
        "yes": "negative_focus",
        "no": "conclusions"
    },
    "negative_focus": {
        "question": "Does the text provided show exclusive focus on negative details?",
        "yes": "mental_filter",
        "no": "positive_discount"
    },
    "positive_discount": {
        "question": "Does the text provided show dismissal or disqualification of positive experiences?",
        "yes": "discounting_the_positive",
        "no": "magnification"
    },
    "magnification": {
        "question": "Does the text provided show exaggeration of negatives or minimization of positives?",
        "yes": "magnification_minimization",
        "no": "conclusions"
    },
    # --- Final catch-all before falling through to "no_distortion" ---
    "conclusions": {
        "question": "Does the text provided show conclusions being drawn with little or no supporting evidence?",
        "yes": "jumping_to_conclusions",
        "no": "no_distortion"
    },
}

# Free-text definitions of the distortions, one per line, written in the first person.
# Not referenced by any function in the visible cells.
distortion_definitions = """All or Nothing Thinking/Polarized Thinking: I view a situation, a person or an event in “either-or” terms, fitting them into only two extreme categories instead of on a continuum.
Fortune telling (also called catastrophizing): I predict the future in negative terms and believe that what will happen will be so awful that I will not be able to stand it.
Emotional reasoning:  I believe my emotions reflect reality and let them guide my attitudes and judgments.
Labeling/Global Labeling: I put a fixed, global label, usually negative, on myself or others.
Mental Filter(): I pay attention to one or a few details and fail to see the whole picture.
Mind reading: I believe that I know the thoughts or intentions of others (or that they know my thoughts or intentions) without having sufficient evidence
Overgeneralization: I take isolated negative cases and generalize them, transforming them in a never-ending pattern, by repeatedly using words such as “always”, “never”, “ever”, “whole”, “entire”, etc
Personalization: I assume that others’ behaviors and external events concern (or are directed to) myself without considering other plausible explanations.
Should statements (also “musts”, “oughts”, “have tos”): I tell myself that events, people’s behaviors, and my own attitudes “should” be the way I expected them to be and not as they really are.
Blaming (others or oneself): I direct my attention to others as sources of my negative feelings and experiences, failing to consider my own responsibility; or, conversely, I take responsibility for others’ behaviors and attitudes.
What if?: I keep asking myself questions such as “what if something happens?”
Discounting the positive: I disqualify positive experiences or events insisting that they do not count.
Magnification/minimization: I evaluate myself, others, and situations placing greater importance on the negatives and/or placing much less importance on the positives.
Jumping to conclusions (also called arbitrary inference): I draw conclusions (negative or positive) from little or no confirmatory evidence.
Unfair comparisons: I compare myself with others who seem to do better than I do and place myself in a disadvantageous position.
"""
# System prompt used by make_llama for every conversation; it restricts the
# assistant to a bare "Yes" or "No". The trailing \" escapes the final quote so it
# does not merge with the closing triple quote.
system_message = """You are a helpful assistant that can only respond with "Yes" or "No". Follow these rules:
1. You must answer with exactly "Yes" or "No" (case-sensitive)
2. Do not add any other words or punctuation
3. If you are not completely certain, answer "No"
4. These are the only two valid responses: "Yes" or "No\""""

def is_leaf_node(node_id: str) -> bool:
    """Check if a node is a leaf node (final distortion).

    Leaf ids are the keys of ``distortions``. They never appear as keys of
    ``decision_tree``, so indexing the tree with one (as the previous
    implementation did) raised ``KeyError`` for every real leaf.

    Args:
        node_id: A node id taken from the tree (a key, or a "yes"/"no" target).

    Returns:
        True when ``node_id`` is a known distortion leaf, or when the tree
        stores a plain string for it; False otherwise, including for ids that
        are not known at all.
    """
    if node_id in distortions:
        return True
    # Backward compatibility: a tree entry stored as a bare string also counts as a leaf.
    return isinstance(decision_tree.get(node_id), str)

def get_distortion_description(distortion: str) -> str:
    """Return the display name for a leaf id.

    Looks ``distortion`` up in the module-level ``distortions`` table and
    falls back to the literal "Unknown distortion" when it is absent.
    """
    if distortion in distortions:
        return distortions[distortion]
    return "Unknown distortion"

In [4]:
def make_llama(text: str) -> list:
    """Build the chat-message list for a single user prompt.

    Args:
        text: Content of the user turn.

    Returns:
        A new two-element list of ``{"role", "content"}`` dicts: the fixed
        system message followed by the user message.
    """
    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": text},
    ]
def make_example(text: str, answer: str):
    """Return the system+user messages for ``text`` followed by an assistant turn holding ``answer``."""
    assistant_turn = {"role": "assistant", "content": answer}
    return [*make_llama(text), assistant_turn]
def make_question(text: str, question: str):
    """Format the user prompt: the passage, then the yes/no question to answer about it."""
    prompt = f'Given the following text:\n\n{text}\n\n Answer the question: {question}\n\n'
    return prompt

In [5]:
def generate_response(text: str, count) -> str:
    """Sample a completion for ``text`` from the global model.

    Args:
        text: User prompt; it is wrapped with the system message by ``make_llama``.
        count: Passed to ``generate`` as ``max_new_tokens``.

    Returns:
        The decoded tokens produced after the prompt. Special tokens are not
        stripped because ``decode`` is called with its defaults.
    """
    encoded = tokenizer.apply_chat_template(
        make_llama(text),
        tokenize=True,
        return_dict=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Move every tensor of the encoded prompt onto the model's device.
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    sampling_options = dict(
        max_new_tokens=count,
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
    )
    outputs = model.generate(**inputs, **sampling_options)

    # The returned sequence begins with the prompt tokens; keep only what follows them.
    prompt_length = len(inputs["input_ids"][0])
    completion_ids = outputs[0][prompt_length:]
    return tokenizer.decode(completion_ids)

In [6]:
def find_distortion_paths(tree, leaves=None):
    """Map every reachable distortion leaf to the Yes/No answers that lead to it.

    The tree is walked depth-first from ``'root'``, taking the ``'yes'`` edge
    before the ``'no'`` edge. When a leaf can be reached along several routes,
    the route visited last in that order is the one kept.

    Args:
        tree: Mapping of node id to a dict with ``'yes'`` and ``'no'`` targets.
        leaves: Container of ids to treat as leaves. Defaults to the
            module-level ``distortions`` table.

    Returns:
        Dict of leaf id to a list of ``'Yes'``/``'No'`` strings. The
        ``'no_distortion'`` leaf is always left out.

    Raises:
        KeyError: If an edge points at an id that is neither a leaf nor a key
            of ``tree`` (including a missing ``'yes'``/``'no'`` entry).
    """
    # Resolved at call time so the default follows the current `distortions` table.
    leaf_ids = distortions if leaves is None else leaves
    distortion_paths = {}

    def traverse(node_key, current_path):
        if node_key in leaf_ids:
            distortion_paths[node_key] = current_path
            return

        if node_key not in tree:
            # Same exception type as before, but with a message that names the dangling edge.
            raise KeyError(
                f"Node {node_key!r} (reached via {current_path}) is neither a leaf nor defined in the tree"
            )

        node = tree[node_key]
        traverse(node.get('yes'), current_path + ['Yes'])
        traverse(node.get('no'), current_path + ['No'])

    traverse('root', [])
    # A tree is not required to contain a 'no_distortion' leaf, so do not fail when it is absent.
    distortion_paths.pop('no_distortion', None)
    return distortion_paths

# Precompute the Yes/No answer path for every distortion leaf ('no_distortion' is
# removed inside find_distortion_paths). create_path_training_data reads this global.
paths = find_distortion_paths(decision_tree)

def create_path_training_data(json_data):
    """Expand labelled records into one chat example per decision-tree step.

    For each record, the label is mapped to its leaf via ``name_to_node`` and
    to that leaf's answer path via ``paths``. Walking from ``'root'``, every
    node on the path yields one example: the node's question about the
    record's text, answered with that step's 'Yes' or 'No'.

    Args:
        json_data: Iterable of mappings with ``"generated_text"`` and
            ``"cognitive_distortion"`` keys.

    Returns:
        A flat list of message lists as produced by ``make_example``.

    Raises:
        KeyError: If a label is missing from ``name_to_node`` or its leaf has
            no entry in ``paths`` (``'no_distortion'`` is removed from it).
    """
    examples = []
    for record in json_data:
        passage = record["generated_text"]
        answers = paths[name_to_node[record["cognitive_distortion"]]]

        current = 'root'
        for answer in answers:
            prompt = make_question(passage, decision_tree[current]['question'])
            examples.append(make_example(prompt, answer))
            # Follow the edge matching this answer ('Yes' -> 'yes', 'No' -> 'no').
            current = decision_tree[current][answer.lower()]
    return examples

# Build the training conversations from a JSON dataset at a hard-coded absolute path.
# NOTE(review): presumably a list of records carrying "generated_text" and
# "cognitive_distortion", as read by create_path_training_data — confirm against the file.
with open('/home/g4/Mindwell/data/2024-11-27_16:12:43.json') as f:
    examples = create_path_training_data(json.load(f))
