<a href="https://colab.research.google.com/github/Horizontal-Labs/training-zoo/blob/main/tinyllama_finetuning_peft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!apt update
!apt install -y libmariadb-dev

[33m0% [Working][0m            Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
[33m0% [Connecting to archive.ubuntu.com] [Connecting to security.ubuntu.com (185.1[0m                                                                               Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:3 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:4 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:7 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists... Done
Building dependency tree... Done
Reading state inform

In [2]:
!pip install mysql-connector-python sqlalchemy mariadb



In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
!git clone https://github.com/Horizontal-Labs/Argument-Mining.git

fatal: destination path 'Argument-Mining' already exists and is not an empty directory.


In [5]:
# change direcotry
import sys
sys.path.append('/content/Argument-Mining')

In [6]:
import os
import shutil

cache_dir = '/content/Argument-Mining/.cache'

# Check if the cache folder exists and delete its contents
if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)  # Deletes the entire directory and its contents
    print(f"Cache folder at {cache_dir} has been deleted.")
else:
    print("Cache folder does not exist.")

Cache folder at /content/Argument-Mining/.cache has been deleted.


In [7]:
%cd /content/Argument-Mining/
!git pull origin main

/content/Argument-Mining
From https://github.com/Horizontal-Labs/Argument-Mining
 * branch            main       -> FETCH_HEAD
Already up to date.


In [8]:
import os
cache_dir = '/content/Argument-Mining/.cache'
os.makedirs(cache_dir, exist_ok=True)

In [9]:
from db.queries import get_training_data, get_test_data

# Load training data
claims_train, premises_train, relationships_train = get_training_data()

# Load test data
claims_test, premises_test, relationships_test = get_test_data()

In [10]:
print("Train Claims:", len(claims_train))
print("Train Premises:", len(premises_train))
print("Train Relationships:", len(relationships_train))

print("Test Claims:", len(claims_test))
print("Test Premises:", len(claims_test))
print("Test Relationships:", len(claims_test))

Train Claims: 40923
Train Premises: 40923
Train Relationships: 40923
Test Claims: 10395
Test Premises: 10395
Test Relationships: 10395


In [11]:
# Check attributes available in one of the ADU objects
print(vars(claims_train[0]))  # Check first claim object
print(vars(premises_train[0]))  # Check first premise object

{'_sa_instance_state': <sqlalchemy.orm.state.InstanceState object at 0x7983cc900530>, 'domain_id': 2, 'id': 3, 'text': 'This house believes that the sale of violent video games to minors should be banned', 'type': 'claim'}
{'_sa_instance_state': <sqlalchemy.orm.state.InstanceState object at 0x7983a17fef90>, 'domain_id': 2, 'id': 5, 'text': 'video game violence is not related to serious aggressive behavior in real life', 'type': 'premise'}


In [12]:
# Check if any ADU is None
print(f"Claims Train has {sum(x is None for x in claims_train)} None values.")
print(f"Premises Train has {sum(x is None for x in premises_train)} None values.")
print(f"Relationship Train has {sum(x is None for x in relationships_train)} None values.")

Claims Train has 0 None values.
Premises Train has 0 None values.
Relationship Train has 0 None values.


In [13]:
import pandas as pd

# Create pairs of claims and premises
debate_pairs = []

for i in range(len(claims_train)):
    debate_pairs.append({
        "claim": claims_train[i].text,
        "premise": premises_train[i].text,
        "stance": relationships_train[i],
    })

# Create final DataFrame
train_data = pd.DataFrame(debate_pairs)

print(train_data.head())

                                               claim  \
0  This house believes that the sale of violent v...   
1  This house supports the one-child policy of th...   
2  This house would permit the use of performance...   
3  This house would make physical education compu...   
4  This house believes in the use of affirmative ...   

                                             premise      stance  
0  video game violence is not related to serious ...  stance_con  
1         The policy had proved remarkably effective  stance_pro  
2  The use of drugs to enhance performance is con...  stance_con  
3  Frequent and regular physical exercise boosts ...  stance_pro  
4  In some countries which have laws on racial eq...  stance_con  


In [14]:
# Remove 'stance_' prefix for simplicity
train_data['stance'] = train_data['stance'].str.replace('stance_', '')

print(train_data.head())

                                               claim  \
0  This house believes that the sale of violent v...   
1  This house supports the one-child policy of th...   
2  This house would permit the use of performance...   
3  This house would make physical education compu...   
4  This house believes in the use of affirmative ...   

                                             premise stance  
0  video game violence is not related to serious ...    con  
1         The policy had proved remarkably effective    pro  
2  The use of drugs to enhance performance is con...    con  
3  Frequent and regular physical exercise boosts ...    pro  
4  In some countries which have laws on racial eq...    con  


# **Finetuning**

In [15]:
!pip install trl



In [16]:
!pip install evaluate



In [17]:
!pip install -U bitsandbytes



In [18]:
import numpy as np
import torch # for model training and tensor operations
from datasets import Dataset # interface for working with datasets
from transformers import (
    AutoModelForCausalLM,      # loads a pre-trained CLM
    AutoTokenizer,             # loads corresponding tokenizer for a model
    BitsAndBytesConfig,        # configuration for quantization techniques
    TrainingArguments,          # holds arguments for training
    DataCollatorForSeq2Seq
)

# PEFT (Parameter-Efficient Fine-Tuning)
from peft import (
    get_peft_model,               # wraps the base model with PEFT capabilities
    LoraConfig,                   # configuration for LoRA (Low-Rank Adaptation)
    TaskType,                     # specifies the type of task (e.g. CLM)
    prepare_model_for_kbit_training,  # prepares a quantized model
    PeftModel                     # class for loading and managing PEFT models
)

from trl import SFTTrainer          # Trainer for supervised fine-tuning of language models.
import evaluate # HF library to compute evaluation metrics for ML models
from sklearn.metrics import accuracy_score, f1_score, classification_report # common metrics for classification tasks

In [19]:
from sklearn.model_selection import train_test_split

train_data, eval_data = train_test_split(train_data, test_size=0.2, random_state=42)

# Data Formatting for Multitask Learning with Instructions

In [20]:
def format_for_argument_mining(df):
    formatted_data = []

    for _, row in df.iterrows():
        # Format with instructions for each task

        # Task 1: ADU Identification - Extract ADUs from text
        # Samples for ADU identification using both claims and premises
        claim_adu_sample = {
            "instruction": "Identify whether the following text contains an Argumentative Discourse Unit (ADU). An ADU is a span of text that serves as a claim or premise in an argument.",
            "input": f"Text: {row['claim']}",
            "output": "Yes, this text contains an ADU. It functions as a claim in an argument."
        }

        premise_adu_sample = {
            "instruction": "Identify whether the following text contains an Argumentative Discourse Unit (ADU). An ADU is a span of text that serves as a claim or premise in an argument.",
            "input": f"Text: {row['premise']}",
            "output": "Yes, this text contains an ADU. It functions as a premise in an argument."
        }


        # Task 2: ADU Classification (determine if claim or premise)
        adu_class_sample = {
            "instruction": "Classify the following Argumentative Discourse Unit (ADU) as either a claim or premise. A claim is the main point being argued, while a premise provides support or evidence.",
            "input": f"ADU: {row['claim']}",
            "output": "This ADU is a claim."
        }

        adu_class_sample2 = {
            "instruction": "Classify the following Argumentative Discourse Unit (ADU) as either a claim or premise. A claim is the main point being argued, while a premise provides support or evidence.",
            "input": f"ADU: {row['premise']}",
            "output": "This ADU is a premise."
        }

        # Task 3: Stance Classification
        stance_sample = {
            "instruction": "Determine if the premise supports or counters the given claim.",
            "input": f"Claim: {row['claim']}\nPremise: {row['premise']}",
            "output": f"{row['stance']}"
        }

        # Task 4: Relationship identification between ADUs
        relationship_sample = {
            "instruction": "Identify the relationship between the following claim and premise. Explain how they are connected in the argument structure.",
            "input": f"Claim: {row['claim']}\nPremise: {row['premise']}",
            "output": f"The premise {'supports' if row['stance'] == 'pro' else 'counters'} the claim. The relationship is {'supportive' if row['stance'] == 'pro' else 'contradictory'}, where the premise provides {'evidence for' if row['stance'] == 'pro' else 'evidence against'} the main argument."
        }

        # Add all tasks to our dataset
        formatted_data.extend([
            claim_adu_sample,
            premise_adu_sample,
            adu_class_sample,
            adu_class_sample2,
            stance_sample,
            relationship_sample
        ])

    return formatted_data

In [21]:
def tokenize_function(example, tokenizer, max_length=1024):
    # Combine instruction and input text as prompt
    prompt = example['instruction'] + "\n" + example['input']
    target = example['output']

    # Tokenize inputs (prompt)
    inputs = tokenizer(
        prompt,
        max_length=max_length,
        padding="max_length",
        truncation=True
    )

    # Tokenize targets (outputs/labels)
    targets = tokenizer(
        target,
        max_length=256,  # output length limit
        padding="max_length",
        truncation=True
    )

    # Prepare labels, replacing pad token id with -100 (ignore index)
    labels = targets["input_ids"].copy()
    labels = [l if l != tokenizer.pad_token_id else -100 for l in labels]

    # Return dictionary for dataset
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels
    }

# Language Model Configuration

In [22]:
# Returns both the model and tokenizer, ready for fine-tuning or inference
def setup_model(model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    try:
        # QLoRA configuration - use 4-bit quantization for memory efficiency
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16
        )

        # Load pre-trained model with quantization
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config, # see above
            device_map="auto", # maps model layers automatically to available CPU/GPUs
            trust_remote_code=True
        )

        # Load correspondent tokenizer for the model
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right" # padding directions for filler tokens

        return model, tokenizer
    except Exception as e:
        print(f"Error loading model: {e}")
        raise

# PEFT LoRA configuration

In [23]:
def configure_peft(model):
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32, # scales output of LoRA adapter before its added back to base model weights
        lora_dropout=0.1,
        # Target only key attention modules
        target_modules=["q_proj", "v_proj"],
        bias="none",       # Don't train biases for more stability
    )

    # Prepare model for training without gradient checkpointing (trade off memory vs slowing down training)
    # model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    model = get_peft_model(model, peft_config)

    # Print trainable parameters info
    print_trainable_parameters(model)

    return model

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

In [24]:
def setup_training(model, train_dataset, eval_dataset, output_dir="./argument-mining-lora"):
    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,            # Low-ish learning rate for stability
        weight_decay=0.05,             # weight decay to combat overfitting
        logging_steps=20,
        save_strategy="epoch",
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        fp16=True,
        eval_strategy="epoch",   # Add evaluation during training
        save_total_limit=2,            # Save disk space
        load_best_model_at_end=True,   # Automatically use best model
        remove_unused_columns=False    # Keep all columns for potential use
    )

    # Set up the trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset= eval_dataset
    )

    return trainer

In [25]:
def evaluate_model(model, tokenizer, test_data, task="stance"):
    model.eval()
    predictions = []
    references = []

    for item in test_data:
        if task == "stance":
            prompt = f"Instruction: Determine if the premise supports or counters the given claim.\nInput: Claim: {item['claim']}\nPremise: {item['premise']}\nOutput:"
        elif task == "adu_identification":
            prompt = f"Instruction: Identify whether the following text contains an Argumentative Discourse Unit (ADU).\nInput: Text: {item['text']}\nOutput:"
        elif task == "adu_classification":
            prompt = f"Instruction: Classify the following Argumentative Discourse Unit (ADU) as either a claim or premise.\nInput: ADU: {item['text']}\nOutput:"
        elif task == "relationship":
            prompt = f"Instruction: Identify the relationship between the following claim and premise.\nInput: Claim: {item['claim']}\nPremise: {item['premise']}\nOutput:"
        else:
            continue

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                temperature=0.1,
                do_sample=False,
                num_beams=1
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        prediction = generated_text.split("Output:")[-1].strip().lower()

        # Extract relevant information based on task
        if task == "stance":
            if "pro" in prediction:
                predictions.append("pro")
            elif "con" in prediction:
                predictions.append("con")
            else:
                predictions.append("unknown")

            references.append(item['stance'])

        elif task == "adu_identification":
            if "yes" in prediction:
                predictions.append("contains_adu")
            elif "no" in prediction:
                predictions.append("no_adu")
            else:
                predictions.append("unknown")

            references.append(item['contains_adu'])

        elif task == "adu_classification":
            if "claim" in prediction:
                predictions.append("claim")
            elif "premise" in prediction:
                predictions.append("premise")
            else:
                predictions.append("unknown")

            references.append(item['adu_type'])

    # Calculate metrics
    accuracy = accuracy_score(references, predictions)

    # For binary tasks like ADU identification
    if task == "adu_identification":
        f1 = f1_score(references, predictions, average='binary', pos_label="contains_adu")
    # For multi-class tasks
    else:
        f1 = f1_score(references, predictions, average='weighted')

    print(f"Task: {task}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(classification_report(references, predictions))

    return accuracy, f1

In [26]:
# Training Pipeline

def train_argument_mining_model(train_data, model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0"):
    # Setup model and tokenizer
    model, tokenizer = setup_model(model_id)

    # Format data for training
    formatted_train_data = format_for_argument_mining(train_data)
    formatted_eval_data = format_for_argument_mining(eval_data)

    # Create HF Dataset from list of dicts
    hf_train_dataset = Dataset.from_list(formatted_train_data)
    hf_eval_dataset = Dataset.from_list(formatted_eval_data)

    # Apply tokenizer function to the train dataset (batched for speed)
    train_dataset = hf_train_dataset.map(
    lambda x: tokenize_function(x, tokenizer),
    batched=False
    )

    # Apply tokenizer function to the eval dataset (batched for speed)
    eval_dataset = hf_eval_dataset.map(
    lambda x: tokenize_function(x, tokenizer),
    batched=False
    )

    # Configure PEFT/LoRA
    model = configure_peft(model)

    # Setup training
    trainer = setup_training(model, train_dataset, eval_dataset)

    # Train the model
    trainer.train()

    # Save the model and tokenizer
    peft_model_id = f"argument-mining-{model_id.split('/')[-1]}"
    trainer.model.save_pretrained(peft_model_id)
    tokenizer.save_pretrained(peft_model_id)

    print(f"Model saved to {peft_model_id}")
    return model, tokenizer, peft_model_id

In [27]:
def run_inference(model, tokenizer, text, task="stance", claim=None, premise=None):
    """
    Run inference on different argument mining tasks

    Args:
        model: The fine-tuned model
        tokenizer: The tokenizer
        text: Text to analyze (for ADU tasks)
        task: Which task to perform - "adu_identification", "adu_classification", "stance", or "relationship"
        claim: The claim text (for stance and relationship tasks)
        premise: The premise text (for stance and relationship tasks)

    Returns:
        Prediction result as a string
    """
    if task == "stance":
        if not claim or not premise:
            return "Error: Claim and premise required for stance classification"
        prompt = f"Instruction: Determine if the premise supports or counters the given claim.\nInput: Claim: {claim}\nPremise: {premise}\nOutput:"

    elif task == "adu_identification":
        prompt = f"Instruction: Identify whether the following text contains an Argumentative Discourse Unit (ADU). An ADU is a span of text that serves as a claim or premise in an argument.\nInput: Text: {text}\nOutput:"

    elif task == "adu_classification":
        prompt = f"Instruction: Classify the following Argumentative Discourse Unit (ADU) as either a claim or premise. A claim is the main point being argued, while a premise provides support or evidence.\nInput: ADU: {text}\nOutput:"

    elif task == "relationship":
        if not claim or not premise:
            return "Error: Claim and premise required for relationship identification"
        prompt = f"Instruction: Identify the relationship between the following claim and premise. Explain how they are connected in the argument structure.\nInput: Claim: {claim}\nPremise: {premise}\nOutput:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.1,
            do_sample=False  # Deterministic generation for evaluation
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    prediction = generated_text.split("Output:")[-1].strip()

    return prediction

# Training

In [None]:
if __name__ == "__main__":
    # Load your data

    print(torch.cuda.is_available())
    print(torch.cuda.get_device_name(0))

    print("Starting training pipeline...")
    model, tokenizer, model_id = train_argument_mining_model(train_data)

    # Example inference for all tasks
    print("\n--- Example Inferences ---")

    # 1. ADU Identification example
    example_text = "The government should invest more in renewable energy."
    adu_prediction = run_inference(model, tokenizer, example_text, task="adu_identification")
    print(f"ADU Identification: {adu_prediction}")

    # 2. ADU Classification example
    example_adu = "Studies show that renewable energy creates more jobs than fossil fuels."
    class_prediction = run_inference(model, tokenizer, example_adu, task="adu_classification")
    print(f"ADU Classification: {class_prediction}")

    # 3. Stance Classification example
    example_claim = "This house believes that social media is harmful to society."
    example_premise = "Social media has been linked to increased rates of depression in teenagers."
    stance_prediction = run_inference(model, tokenizer, None, task="stance",
                                     claim=example_claim, premise=example_premise)
    print(f"Stance Classification: {stance_prediction}")

    # 4. Relationship Identification example
    relationship_prediction = run_inference(model, tokenizer, None, task="relationship",
                                           claim=example_claim, premise=example_premise)
    print(f"Relationship Identification: {relationship_prediction}")

True
Tesla T4
Starting training pipeline...


Map:   0%|          | 0/196428 [00:00<?, ? examples/s]

Map:   0%|          | 0/49110 [00:00<?, ? examples/s]

trainable params: 1126400 || all params: 616732672 || trainable%: 0.18


Truncating train dataset:   0%|          | 0/196428 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/49110 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
[34m[1mwandb[0m: Currently logged in as: [33mlen-rtz[0m ([33mlen-rtz-th-k-ln[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss
