In [1]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
# Set the current working directory to the project root
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
os.chdir(ROOT_DIR)

In [2]:
from src.data_management.loaders import load_labeled_df


dataset = load_labeled_df("phase0_baseline_labeled.parquet")

# Finetuning the Qwenner for the narratives

First let's start by taking only the columns that interest us from the dataset

In [3]:
narratives_dataset = dataset[["text", "narratives"]].dropna().reset_index(drop=True)

Defining the model we are about to finetune

In [7]:
model_name = "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit"

Loading the model

In [8]:
from unsloth import FastModel
import torch

model, tokenizer = FastModel.from_pretrained(
    model_name=model_name,
    max_seq_length=1024,
)

==((====))==  Unsloth 2025.8.9: Fast Qwen3 patching. Transformers: 4.55.4.
   \\   /|    NVIDIA GeForce RTX 4070 Laptop GPU. Num GPUs = 1. Max memory: 7.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [9]:
model = FastModel.get_peft_model(
    model,
    r=8,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth: Making `model.base_model.model.model` require gradients


We will need to use the chat template of qwen, so we import the function from unsloth

In [10]:
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template="qwen3-instruct"
)

In [11]:
# Let's examine the dataset structure
print("Dataset shape:", narratives_dataset.shape)
print("\nFirst few rows:")
print(narratives_dataset.head())
print("\nColumn info:")
print(narratives_dataset.info())
print("\nUnique narratives:")
print(narratives_dataset['narratives'].value_counts())

Dataset shape: (1699, 2)

First few rows:
                                                text  \
0  Опитът на колективния Запад да „обезкърви Руси...   
1  Цончо Ганев, “Възраждане”: Обещали сме на Укра...   
2  Подкрепата за Киев от страна на Запада вече не...   
3  Дмитрий Медведев: НПО-та, спонсорирани от Соро...   
4  Британски дипломат обвини Запада за украинския...   

                                          narratives  
0  [URW: Blaming the war on others rather than th...  
1                        [URW: Discrediting Ukraine]  
2  [URW: Discrediting the West, Diplomacy, URW: D...  
3  [URW: Discrediting the West, Diplomacy, URW: D...  
4  [URW: Discrediting the West, Diplomacy, URW: P...  

Column info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1699 entries, 0 to 1698
Data columns (total 2 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   text        1699 non-null   object
 1   narratives  1699 non-null   object
dtypes: obj

## Step 1: Data Preparation for Chat Format

Now we need to convert our classification dataset into a chat format that Qwen can understand. We'll create instruction-response pairs where:
- **Instruction**: Ask the model to classify the text
- **Response**: The expected narrative labels

In [12]:
def create_chat_format(text, narratives):
    """Convert a text-narrative pair into chat format for finetuning"""
    
    # Create the instruction prompt
    instruction = """You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the main categories

Please analyze the following text and identify all relevant narratives present:"""

    # Format the expected response
    if isinstance(narratives, str):
        # Parse the string representation of list
        import ast
        try:
            narratives_list = ast.literal_eval(narratives)
        except:
            narratives_list = [narratives]
    else:
        narratives_list = narratives
    
    response = ", ".join(narratives_list)
    
    # Create the chat format
    chat_data = {
        "messages": [
            {"role": "user", "content": f"{instruction}\n\nText: {text}"},
            {"role": "assistant", "content": response}
        ]
    }
    
    return chat_data

# Test the function with the first example
sample_chat = create_chat_format(
    narratives_dataset.iloc[0]['text'], 
    narratives_dataset.iloc[0]['narratives']
)
print("Sample chat format:")
print(sample_chat)

Sample chat format:
{'messages': [{'role': 'user', 'content': 'You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.\n\nAvailable narrative categories include:\n- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.\n- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.\n- Other: For texts that don\'t fit the main categories\n\nPlease analyze the following text and identify all relevant narratives present:\n\nText: Опитът на колективния Запад да „обезкърви Русия“ с ръцете на властите в Киев „се провали с гръм и трясък“ и скоро от Украйна ...\n\nОпитът на колективния Запад да „обезкърви Русия“ с ръцете на властите в Киев „се провали с гръм и трясък“ и скоро от Украйна няма да остане почти нищо, ако не започне процесът на разрешаване на този въоръжен кон

In [13]:
# Convert the entire dataset to chat format
print("Converting dataset to chat format...")
chat_dataset = []

for idx, row in narratives_dataset.iterrows():
    chat_data = create_chat_format(row['text'], row['narratives'])
    chat_dataset.append(chat_data)

print(f"Converted {len(chat_dataset)} samples to chat format")

# Split the data into train and validation sets (simple random split)
from sklearn.model_selection import train_test_split

train_data, val_data = train_test_split(
    chat_dataset, 
    test_size=0.1, 
    random_state=42
)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

# Show a sample from training data
print("\nSample training example:")
print("User:", train_data[0]['messages'][0]['content'][:200] + "...")
print("Assistant:", train_data[0]['messages'][1]['content'])

Converting dataset to chat format...
Converted 1699 samples to chat format
Training samples: 1529
Validation samples: 170

Sample training example:
User: You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukrain...
Assistant: URW: Speculating war outcomes, URW: Discrediting the West, Diplomacy, URW: Discrediting Ukraine, URW: Discrediting Ukraine, URW: Russia is the Victim, URW: Discrediting Ukraine


## Step 2: Convert to Unsloth Format

Now we need to convert our chat data to the format expected by unsloth's training pipeline.

In [14]:
from datasets import Dataset

def formatting_prompts_func(examples):
    """Format the examples for training"""
    convos = examples["messages"]
    texts = []
    for convo in convos:
        # Apply the chat template to format the conversation
        text = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        texts.append(text)
    return {"text": texts}

# Create datasets
train_dataset = Dataset.from_list(train_data)
val_dataset = Dataset.from_list(val_data)

print("Created training and validation datasets")
print(f"Training dataset: {train_dataset}")
print(f"Validation dataset: {val_dataset}")

# Test the formatting function
sample_formatted = formatting_prompts_func(train_dataset[:1])
print("\nSample formatted text:")
print(sample_formatted["text"][0][:500] + "...")

Created training and validation datasets
Training dataset: Dataset({
    features: ['messages'],
    num_rows: 1529
})
Validation dataset: Dataset({
    features: ['messages'],
    num_rows: 170
})

Sample formatted text:
<|im_start|>user
You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the mai...


## Step 3: Set Up Training Configuration

Now let's configure the trainer for supervised fine-tuning (SFT).

In [15]:
# Let's pre-format the text and create simple text datasets
def format_training_text(text, narratives):
    """Format a single example into training text"""
    if isinstance(narratives, str):
        import ast
        try:
            narratives_list = ast.literal_eval(narratives)
        except:
            narratives_list = [narratives]
    else:
        narratives_list = narratives
    
    response = ", ".join(narratives_list)
    
    # Create the formatted text using the chat template
    messages = [
        {"role": "user", "content": f"""You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the main categories

Please analyze the following text and identify all relevant narratives present:

Text: {text}"""},
        {"role": "assistant", "content": response}
    ]
    
    # Apply the chat template
    formatted_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    return formatted_text

# Convert datasets to formatted text
print("Formatting training data...")
train_texts = []
for data in train_data:
    messages = data['messages']
    user_text = messages[0]['content']
    assistant_text = messages[1]['content']
    
    # Extract the actual text from the user message (after "Text: ")
    text_start = user_text.find("Text: ") + 6
    original_text = user_text[text_start:]
    
    formatted = format_training_text(original_text, assistant_text)
    train_texts.append(formatted)

print("Formatting validation data...")
val_texts = []
for data in val_data:
    messages = data['messages']
    user_text = messages[0]['content']
    assistant_text = messages[1]['content']
    
    # Extract the actual text from the user message (after "Text: ")
    text_start = user_text.find("Text: ") + 6
    original_text = user_text[text_start:]
    
    formatted = format_training_text(original_text, assistant_text)
    val_texts.append(formatted)

# Create simple text datasets
from datasets import Dataset

train_dataset_formatted = Dataset.from_dict({"text": train_texts})
val_dataset_formatted = Dataset.from_dict({"text": val_texts})

print(f"Created formatted datasets:")
print(f"Training: {len(train_texts)} samples")
print(f"Validation: {len(val_texts)} samples")
print(f"\nSample formatted text:")
print(train_texts[0][:500] + "...")

Formatting training data...
Formatting validation data...
Created formatted datasets:
Training: 1529 samples
Validation: 170 samples

Sample formatted text:
<|im_start|>user
You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the mai...


In [16]:
from trl import SFTTrainer, SFTConfig
from transformers import TrainingArguments

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset_formatted,
    eval_dataset=val_dataset_formatted,
    args=SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4, # Use GA to mimic batch size!
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 60,
        learning_rate = 2e-4, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none", # Use this for WandB etc
    ),
)

print("Trainer configured successfully!")
print(f"Training will run for {trainer.args.max_steps} steps")
print(f"Batch size: {trainer.args.per_device_train_batch_size}")
print(f"Gradient accumulation: {trainer.args.gradient_accumulation_steps}")
print(f"Effective batch size: {trainer.args.per_device_train_batch_size * trainer.args.gradient_accumulation_steps}")

Unsloth: Tokenizing ["text"] (num_proc=2): 100%|██████████| 1529/1529 [00:03<00:00, 386.37 examples/s]
Unsloth: Tokenizing ["text"] (num_proc=2): 100%|██████████| 170/170 [00:01<00:00, 132.48 examples/s]

Trainer configured successfully!
Training will run for 60 steps
Batch size: 2
Gradient accumulation: 4
Effective batch size: 8





In [17]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

'<|im_start|>user\nYou are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.\n\nAvailable narrative categories include:\n- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.\n- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.\n- Other: For texts that don\'t fit the main categories\n\nPlease analyze the following text and identify all relevant narratives present:\n\nText: The Laughing Stock of COP28 – How the UAE Event Became a Farce \n\n\nIt was always going to be a balancing act keeping the credibility of the global environmental talk shop COP28 in check while holding it in a country where they are producing fossil fuels like it is going out of fashion. It might not have been a wise choice of the UAE’s president Mohamed Bin Zaid to give the top job o

## Step 4: Start Training

Now you're ready to start the finetuning process! Run the next cell to begin training.

In [18]:
# Start training
from unsloth.chat_templates import train_on_responses_only
print("Starting training...")
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user\n",
    response_part="<|im_start|>assistant\n"
)
trainer_stats = trainer.train()

print("Training completed!")
print(f"Training time: {trainer_stats.metrics['train_runtime']:.2f} seconds")
print(f"Final training loss: {trainer_stats.metrics['train_loss']:.4f}")

# Save the final model
trainer.save_model("outputs/final_model")
print("Model saved to outputs/final_model")

Starting training...


Map (num_proc=22): 100%|██████████| 1529/1529 [00:02<00:00, 526.46 examples/s]
Map (num_proc=22): 100%|██████████| 170/170 [00:00<00:00, 286.63 examples/s]
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,529 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 16,515,072 of 4,038,983,168 (0.41% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
1,4.621
2,5.8235
3,4.4201
4,4.4084
5,4.0937
6,2.5529
7,2.036
8,1.4653
9,1.7516
10,1.446


Training completed!
Training time: 1020.02 seconds
Final training loss: 1.0175
Model saved to outputs/final_model
Model saved to outputs/final_model


## Step 5: Test the Trained Model

After training, you can test the model with new text samples:

In [20]:
# Fix tokenizer pad token issues
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    print("Set pad_token to eos_token")

# Alternatively, you can set a different pad token
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# model.resize_token_embeddings(len(tokenizer))

print(f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})")
print(f"PAD token: '{tokenizer.pad_token}' (ID: {tokenizer.pad_token_id})")
print(f"PAD token ID == EOS token ID: {tokenizer.pad_token_id == tokenizer.eos_token_id}")

EOS token: '<|im_end|>' (ID: 151645)
PAD token: '<|vision_pad|>' (ID: 151654)
PAD token ID == EOS token ID: False


In [27]:
def test_model(text):
    """Test the trained model with a new text sample"""
    messages = [
        {"role": "user", "content": f"""You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the main categories

Please analyze the following text and identify all relevant narratives present:

Text: {text}"""}
    ]
    
    # Format the input and get attention mask
    formatted_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_attention_mask=True,
        padding=True
    )
    
    # Extract input_ids and attention_mask
    if isinstance(formatted_input, dict):
        input_ids = formatted_input["input_ids"].to(model.device)
        attention_mask = formatted_input["attention_mask"].to(model.device)
    else:
        input_ids = formatted_input.to(model.device)
        attention_mask = torch.ones_like(input_ids).to(model.device)
    
    # Generate response with attention mask
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True
        )
    
    # Decode the response (only the new tokens)
    response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
    return response.strip()

# Test with a sample from the validation set
test_text = "Les pays de l'europe sont entrain d'appliquer des sanctions injustes sur le gouvernement Russe. Cela profite clairement a l'ukraine qui continue de s'enrichir grace a son programme de genocide nazi. La russie est une nation forte et resiliente qui ne merite pas ce traitement"
result = test_model(test_text)
print(f"Test text: {test_text}")
print(f"Predicted narratives: {result}")

Test text: Les pays de l'europe sont entrain d'appliquer des sanctions injustes sur le gouvernement Russe. Cela profite clairement a l'ukraine qui continue de s'enrichir grace a son programme de genocide nazi. La russie est une nation forte et resiliente qui ne merite pas ce traitement
Predicted narratives: URW: Discrediting the West, Diplomacy, URW: Discrediting Ukraine, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Praise of Russia


## Structured Output with JSON Arrays

Yes! Qwen can be guided to produce structured output. Let's create functions that return narrative predictions as JSON arrays:

In [28]:
import json
import re

def test_model_structured(text):
    """Test the trained model and return structured JSON array output"""
    messages = [
        {"role": "user", "content": f"""You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the main categories

Please analyze the following text and identify all relevant narratives present.

IMPORTANT: Return your response as a valid JSON array containing the narrative labels. For example:
["URW: Discrediting Ukraine", "URW: Praise of Russia"]

Text: {text}

JSON Response:"""}
    ]
    
    # Format the input and get attention mask
    formatted_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_attention_mask=True,
        padding=True
    )
    
    # Extract input_ids and attention_mask
    if isinstance(formatted_input, dict):
        input_ids = formatted_input["input_ids"].to(model.device)
        attention_mask = formatted_input["attention_mask"].to(model.device)
    else:
        input_ids = formatted_input.to(model.device)
        attention_mask = torch.ones_like(input_ids).to(model.device)
    
    # Generate response with attention mask
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=150,
            do_sample=False,  # Use greedy decoding for more consistent JSON
            temperature=0.1,   # Low temperature for consistency
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True
        )
    
    # Decode the response (only the new tokens)
    response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
    return response.strip()

def parse_structured_output(response):
    """Parse the model response and extract JSON array"""
    try:
        # Try to find JSON array in the response
        json_match = re.search(r'\[.*?\]', response, re.DOTALL)
        if json_match:
            json_str = json_match.group()
            narratives = json.loads(json_str)
            return narratives
        else:
            # Fallback: split by comma if no JSON found
            return [item.strip() for item in response.split(',')]
    except json.JSONDecodeError:
        # Fallback: return as list of strings
        return [response.strip()]

# Test structured output
print("Testing structured output:")
test_text = "Les pays de l'europe sont entrain d'appliquer des sanctions injustes sur le gouvernement Russe. Cela profite clairement a l'ukraine qui continue de s'enrichir grace a son programme de genocide nazi. La russie est une nation forte et resiliente qui ne merite pas ce traitement"

raw_response = test_model_structured(test_text)
print(f"Raw response: {raw_response}")

parsed_narratives = parse_structured_output(raw_response)
print(f"Parsed narratives: {parsed_narratives}")
print(f"Type: {type(parsed_narratives)}")
print(f"Number of narratives found: {len(parsed_narratives)}")

Testing structured output:
Raw response: URW: Discrediting the West, Diplomacy, URW: Discrediting Ukraine, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW: Discrediting the West, URW:
Parsed narratives: ['URW: Discrediting the West', 'Diplomacy', 'URW: Discrediting Ukraine', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discrediting the West', 'URW: Discredi

## Option 1: Retrain with JSON Format

For best results, you should retrain the model with JSON-formatted responses. Here's how to modify your training data:

In [30]:
def create_json_chat_format(text, narratives):
    """Convert a text-narrative pair into chat format with JSON response for finetuning"""
    
    # Create the instruction prompt for JSON output
    instruction = """You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the main categories

Please analyze the following text and identify all relevant narratives present.

Return your response as a valid JSON array containing the narrative labels."""

    # Format the expected JSON response
    if isinstance(narratives, str):
        import ast
        try:
            narratives_list = ast.literal_eval(narratives)
        except:
            narratives_list = [narratives]
    else:
        # Convert numpy array or other types to list
        narratives_list = list(narratives) if hasattr(narratives, '__iter__') else [str(narratives)]
    
    # Ensure all items are strings
    narratives_list = [str(item) for item in narratives_list]
    
    # Create JSON response
    json_response = json.dumps(narratives_list, ensure_ascii=False)
    
    # Create the chat format
    chat_data = {
        "messages": [
            {"role": "user", "content": f"{instruction}\n\nText: {text}"},
            {"role": "assistant", "content": json_response}
        ]
    }
    
    return chat_data

# Test the JSON format function
sample_json_chat = create_json_chat_format(
    narratives_dataset.iloc[0]['text'], 
    narratives_dataset.iloc[0]['narratives']
)
print("Sample JSON chat format:")
print("User:", sample_json_chat['messages'][0]['content'][:200] + "...")
print("Assistant:", sample_json_chat['messages'][1]['content'])

Sample JSON chat format:
User: You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukrain...
Assistant: ["URW: Blaming the war on others rather than the invader", "URW: Discrediting the West, Diplomacy", "URW: Discrediting the West, Diplomacy", "URW: Amplifying war-related fears"]


## Option 2: Better Structured Inference with Current Model

Let's create a better inference function that works with your current model:

In [None]:
def test_model_json_output(text, max_retries=3):
    """Test the model and return clean JSON array output"""
    
    messages = [
        {"role": "user", "content": f"""You are an expert at analyzing text for propaganda narratives. Your task is to classify the given text and identify which narratives are present.

Available narrative categories include:
- URW (Ukraine, Russia, War related): Various subcategories like "Discrediting Ukraine", "Praise of Russia", "Discrediting the West", etc.
- CC (Climate Change): Various subcategories like "Amplifying Climate Fears", "Criticism of climate policies", etc.
- Other: For texts that don't fit the main categories

Please analyze the following text and identify all relevant narratives present. Respond with a comma-separated list of narratives.

Text: {text}

Narratives:"""}
    ]
    
    # Format the input
    formatted_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_attention_mask=True,
        padding=True
    )
    
    # Extract input_ids and attention_mask
    if isinstance(formatted_input, dict):
        input_ids = formatted_input["input_ids"].to(model.device)
        attention_mask = formatted_input["attention_mask"].to(model.device)
    else:
        input_ids = formatted_input.to(model.device)
        attention_mask = torch.ones_like(input_ids).to(model.device)
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=80,  # Limit tokens to prevent repetition
            do_sample=False,    # Use greedy decoding
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            repetition_penalty=1.2  # Prevent repetition
        )
    
    # Decode the response
    response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
    
    # Clean and process the response
    return clean_and_parse_narratives(response)

def clean_and_parse_narratives(response):
    """Clean the model response and convert to array"""
    # Remove extra whitespace and newlines
    response = response.strip()
    
    # Split by comma and clean each item
    narratives = []
    if ',' in response:
        parts = response.split(',')
    else:
        parts = [response]
    
    for part in parts:
        part = part.strip()
        if part and not part.startswith('URW:') and not part.startswith('CC:') and part != 'Other':
            # This might be a continuation, skip
            if len(part) > 100:  # Very long, likely repetitive
                break
        if part:
            narratives.append(part)
    
    # Remove duplicates while preserving order
    seen = set()
    unique_narratives = []
    for narrative in narratives:
        if narrative not in seen:
            seen.add(narrative)
            unique_narratives.append(narrative)
    
    return unique_narratives

# Test the improved function
print("Testing improved structured output:")
test_text = "Les pays de l'europe sont entrain d'appliquer des sanctions injustes sur le gouvernement Russe. Cela profite clairement a l'ukraine qui continue de s'enrichir grace a son programme de genocide nazi. La russie est une nation forte et resiliente qui ne merite pas ce traitement"

narratives_array = test_model_json_output(test_text)
print(f"Input: {test_text}")
print(f"Output array: {narratives_array}")
print(f"Type: {type(narratives_array)}")
print(f"Number of unique narratives: {len(narratives_array)}")

# Convert to JSON if needed
json_output = json.dumps(narratives_array, ensure_ascii=False, indent=2)
print(f"JSON format:\n{json_output}")

Testing improved structured output:
Input: Les pays de l'europe sont entrain d'appliquer des sanctions injustes sur le gouvernement Russe. Cela profite clairement a l'ukraine qui continue de s'enrichir grace a son programme de genocide nazi. La russie est une nation forte et resiliente qui ne merite pas ce traitement
Output array: ['URW: Discrediting the West', 'Diplomacy', 'URW: Praise of Russia']
Type: <class 'list'>
Number of unique narratives: 3
JSON format:
[
  "URW: Discrediting the West",
  "Diplomacy",
  "URW: Praise of Russia"
]


In [32]:
# Example usage for batch processing
def batch_classify_texts(texts):
    """Classify multiple texts and return structured results"""
    results = []
    for text in texts:
        narratives = test_model_json_output(text)
        results.append({
            'text': text,
            'narratives': narratives,
            'narrative_count': len(narratives)
        })
    return results

# Test with multiple examples
sample_texts = [
    "The Western sanctions are destroying the global economy.",
    "Climate change is a hoax created by politicians.",
    "Ukraine is defending its sovereignty against Russian aggression."
]

batch_results = batch_classify_texts(sample_texts)

print("Batch Classification Results:")
print("=" * 50)
for i, result in enumerate(batch_results, 1):
    print(f"\n{i}. Text: {result['text']}")
    print(f"   Narratives: {result['narratives']}")
    print(f"   Count: {result['narrative_count']}")

# Convert to JSON for API responses
import pandas as pd

# Create DataFrame for easier analysis
df_results = pd.DataFrame(batch_results)
print(f"\nDataFrame shape: {df_results.shape}")
print(df_results.head())

# Export as JSON
json_results = json.dumps(batch_results, ensure_ascii=False, indent=2)
print(f"\nJSON Export (first 300 chars):")
print(json_results[:300] + "...")

Batch Classification Results:

1. Text: The Western sanctions are destroying the global economy.
   Narratives: ['URW: Discrediting the West', 'Diplomacy']
   Count: 2

2. Text: Climate change is a hoax created by politicians.
   Narratives: ['CC: Criticism of institutions and authorities', 'CC: Discrediting science']
   Count: 2

3. Text: Ukraine is defending its sovereignty against Russian aggression.
   Narratives: ['URW: Discrediting Ukraine', 'Diplomacy']
   Count: 2

DataFrame shape: (3, 3)
                                                text  \
0  The Western sanctions are destroying the globa...   
1   Climate change is a hoax created by politicians.   
2  Ukraine is defending its sovereignty against R...   

                                          narratives  narrative_count  
0            [URW: Discrediting the West, Diplomacy]                2  
1  [CC: Criticism of institutions and authorities...                2  
2             [URW: Discrediting Ukraine, Diplomacy]     