# Notebook 4: LoRA Teacher Model Training (roberta-large on AG News)
Purpose:
1. Load pre-processed original and augmented AG News data.
2. Combine original and augmented data.
3. Split the combined data into Train and Validation sets (rule compliant).
4. Load roberta-large model and apply LoRA configuration.
5. Configure Trainer for LoRA fine-tuning.
6. Fine-tune the LoRA roberta-large model on the new train split, validating on the new validation split.
7. Save the fine-tuned LoRA adapter to be used as a teacher.

In [17]:
# --- Essential Setup ---
import os
import time
import pickle
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import traceback
import random
import shutil
import gc
from sklearn.utils.class_weight import compute_class_weight

from datasets import load_dataset, Dataset, ClassLabel, load_from_disk, concatenate_datasets, Features, Value
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification, 
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    TrainerCallback,
    SchedulerType,
    TrainerState,
    TrainerControl,
)
from peft import LoraConfig, get_peft_model,TaskType 
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.manifold import TSNE

print("Setting up environment...")
!rm -rf /kaggle/working/*

# --- Cache Directory Setup ---
cache_dir = "/kaggle/working/hf_datasets_cache"
os.environ['HF_DATASETS_CACHE'] = cache_dir
os.environ['DATASETS_CACHE'] = cache_dir
os.makedirs(cache_dir, exist_ok=True)
print(f"INFO: Hugging Face datasets cache directory set to: {os.environ.get('HF_DATASETS_CACHE')}")

Setting up environment...
INFO: Hugging Face datasets cache directory set to: /kaggle/working/hf_datasets_cache


In [10]:
# --- Configuration ---
teacher_base_model_name = 'roberta-large' 
dataset_name = 'ag_news'

# Paths to pre-processed data
cleaned_original_load_path = "/kaggle/input/cleanedorig"
tokenized_augmented_load_path = "/kaggle/input/cleanedaugmenteddata"

# Output paths for the LoRA TEACHER model adapter
teacher_output_dir = "/kaggle/working/lora_teacher_training_output" # Training checkpoints/logs
teacher_adapter_save_path = "/kaggle/working/roberta_large_lora_teacher_adapter" # Final saved adapter

# --- LoRA settings FOR TEACHER ---
TEACHER_LORA_R = 16         
TEACHER_LORA_ALPHA = 32        
TEACHER_LORA_DROPOUT = 0.1    
TEACHER_LORA_TARGET_MODULES = ['query', 'value', 'key']
# Tokenizer settings
TOKENIZER_MAX_LENGTH = 512

# Validation split size
VALIDATION_SET_SIZE = 0.1 # 10% for validation

In [11]:
# --- GPU Check ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"GPU is available. Using device: {device}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("WARNING: GPU not available, using CPU. Training will be slow.")

GPU is available. Using device: cuda
GPU Name: Tesla P100-PCIE-16GB


In [12]:
# --- Load Tokenizer ---
print(f"Loading tokenizer for: {teacher_base_model_name}")
try:
    tokenizer = AutoTokenizer.from_pretrained(teacher_base_model_name)
except Exception as e: print(f"ERROR: Failed to load tokenizer: {e}"); raise e

Loading tokenizer for: roberta-large


In [13]:
# --- Label Info ---
num_labels = 4
id2label = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}
label2id = {'World': 0, 'Sports': 1, 'Business': 2, 'Sci/Tech': 3}
class_names = list(id2label.values())
labels_feature_definition = ClassLabel(num_classes=num_labels, names=class_names)

In [14]:
print("Loading and preparing Original + Augmented AG News data...")
# --- Load CLEANED ORIGINAL data ---
original_ds_reloaded = None
if os.path.exists(cleaned_original_load_path):
    try:
        # ... (Insert save/reload logic for original_ds_reloaded here, including feature casting) ...
        loaded_cleaned_original_ds_input = load_from_disk(cleaned_original_load_path)
        temp_orig_save_path = "/kaggle/working/original_dataset_temp"; loaded_cleaned_original_ds_input.save_to_disk(temp_orig_save_path)
        original_features_ref = loaded_cleaned_original_ds_input.features; del loaded_cleaned_original_ds_input; gc.collect(); torch.cuda.empty_cache()
        original_ds_reloaded = load_from_disk(temp_orig_save_path)
        if original_ds_reloaded.features != original_features_ref: original_ds_reloaded = original_ds_reloaded.cast(original_features_ref)
        original_labels_feature = original_ds_reloaded.features['labels'] # Get definitive label feature
        if original_labels_feature.num_classes != num_labels or original_labels_feature.names != class_names: # Recast if needed
             original_ds_reloaded = original_ds_reloaded.cast_column('labels', labels_feature_definition); original_labels_feature = original_ds_reloaded.features['labels']
        print(f"INFO: Original dataset reloaded ({len(original_ds_reloaded)} examples).")
        try: shutil.rmtree(temp_orig_save_path)
        except Exception as e_rm: print(f"WARNING: Could not remove {temp_orig_save_path}: {e_rm}")
    except Exception as e: print(f"ERROR loading/processing original: {e}"); raise e
else: raise FileNotFoundError(f"Path not found: {cleaned_original_load_path}")

# --- Load TOKENIZED AUGMENTED data ---
augmented_ds_reloaded = None
if os.path.exists(tokenized_augmented_load_path):
     try:
         loaded_tokenized_augmented_ds_input = load_from_disk(tokenized_augmented_load_path)
         temp_aug_save_path = "/kaggle/working/augmented_dataset_temp"; loaded_tokenized_augmented_ds_input.save_to_disk(temp_aug_save_path)
         augmented_features_ref = loaded_tokenized_augmented_ds_input.features; del loaded_tokenized_augmented_ds_input; gc.collect(); torch.cuda.empty_cache()
         augmented_ds_reloaded = load_from_disk(temp_aug_save_path)
         if augmented_ds_reloaded.features != augmented_features_ref: augmented_ds_reloaded = augmented_ds_reloaded.cast(augmented_features_ref)
         augmented_ds_reloaded = augmented_ds_reloaded.cast_column('labels', original_labels_feature) # Cast labels using feature from original
         print(f"INFO: Augmented dataset reloaded ({len(augmented_ds_reloaded)} examples).")
         try: shutil.rmtree(temp_aug_save_path)
         except Exception as e_rm: print(f"WARNING: Could not remove {temp_aug_save_path}: {e_rm}")
     except Exception as e: print(f"ERROR loading/processing augmented: {e}"); raise e
else: raise FileNotFoundError(f"Path not found: {tokenized_augmented_load_path}")

# --- Combine Original + Augmented Data ---
print("INFO: Combining RELOADED original and RELOADED augmented datasets...")
required_columns = ['input_ids', 'attention_mask', 'labels']
try:
    train_dataset_for_concat = original_ds_reloaded.select_columns(required_columns)
    tokenized_augmented_dataset_for_concat = augmented_ds_reloaded.select_columns(required_columns)
except Exception as e: print(f"ERROR during column selection/prep for concat: {e}"); raise e
combined_dataset_all = concatenate_datasets([train_dataset_for_concat, tokenized_augmented_dataset_for_concat])
print(f"Combined dataset created (Orig+Aug) with {len(combined_dataset_all)} examples.")
del train_dataset_for_concat, tokenized_augmented_dataset_for_concat, original_ds_reloaded, augmented_ds_reloaded
gc.collect(); torch.cuda.empty_cache()

Loading and preparing Original + Augmented AG News data...


Saving the dataset (0/1 shards):   0%|          | 0/114832 [00:00<?, ? examples/s]

INFO: Original dataset reloaded (114832 examples).


Saving the dataset (0/1 shards):   0%|          | 0/114832 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/114832 [00:00<?, ? examples/s]

INFO: Augmented dataset reloaded (114832 examples).
INFO: Combining RELOADED original and RELOADED augmented datasets...
Combined dataset created (Orig+Aug) with 229664 examples.


In [15]:
# --- Create Train/Validation Split from Combined Data ---
print(f"INFO: Splitting combined data into Train/Validation ({1.0-VALIDATION_SET_SIZE:.0%}/{VALIDATION_SET_SIZE:.0%})...")
combined_dataset_shuffled = combined_dataset_all.shuffle(seed=42)
split_datasets = combined_dataset_shuffled.train_test_split(test_size=VALIDATION_SET_SIZE, seed=42, shuffle=False) # Already shuffled

teacher_train_dataset = split_datasets['train']
teacher_eval_dataset = split_datasets['test']

print(f"Teacher Training set size: {len(teacher_train_dataset)}")
print(f"Teacher Validation set size: {len(teacher_eval_dataset)}")
del combined_dataset_all, combined_dataset_shuffled
gc.collect(); torch.cuda.empty_cache()

INFO: Splitting combined data into Train/Validation (90%/10%)...
Teacher Training set size: 206697
Teacher Validation set size: 22967


In [18]:
# --- Prepare LoRA Teacher Model ---
print(f"Loading teacher base model: {teacher_base_model_name}")
model = AutoModelForSequenceClassification.from_pretrained(
    teacher_base_model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)
print("Teacher base model loaded.")

print(f"Applying LoRA config to Teacher: r={TEACHER_LORA_R}, alpha={TEACHER_LORA_ALPHA}, dropout={TEACHER_LORA_DROPOUT}, targets={TEACHER_LORA_TARGET_MODULES}")
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=TEACHER_LORA_R,
    lora_alpha=TEACHER_LORA_ALPHA,
    target_modules=TEACHER_LORA_TARGET_MODULES,
    lora_dropout=TEACHER_LORA_DROPOUT
)
# Apply LoRA to the roberta-large model
model = get_peft_model(model, lora_config) # model variable now holds the PEFT model

print("LoRA applied to teacher model.")
model.print_trainable_parameters() # Print LoRA parameter count for teacher

Loading teacher base model: roberta-large


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Teacher base model loaded.
Applying LoRA config to Teacher: r=16, alpha=32, dropout=0.1, targets=['query', 'value', 'key']
LoRA applied to teacher model.
trainable params: 3,412,996 || all params: 358,776,840 || trainable%: 0.9513


In [19]:
# --- Define Metrics & Collator ---
def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average='weighted')
    return {"accuracy": acc, "f1": f1}

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

class MetricsCollectorCallback(TrainerCallback):
    def __init__(self): self.logs = []
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None: self.logs.append((state.global_step, logs))
        return control
metrics_collector = MetricsCollectorCallback()

In [20]:
# --- Define Training Arguments for LoRA Teacher Model ---
print("Defining Training Arguments for LoRA Teacher Model...")
# Adjust parameters for LoRA fine-tuning roberta-large
training_args = TrainingArguments(
    output_dir=teacher_output_dir,
    eval_strategy="steps",
    eval_steps=500,
    logging_steps=100,
    save_steps=500,
    save_total_limit=2, # Keep best 2 checkpoints of the adapter

    learning_rate=5e-5,       
    lr_scheduler_type=SchedulerType.LINEAR, # Or cosine
    warmup_ratio=0.06,

    # Batch size can likely be larger than full fine-tuning large model
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=4, # effective batch size (16*4=64)

    num_train_epochs=1,        # Start with 1 epoch for LoRA teacher

    weight_decay=0.1,         
    label_smoothing_factor=0.1,

    load_best_model_at_end=True, 
    metric_for_best_model="accuracy", 
    greater_is_better=True,

    fp16=torch.cuda.is_available(),

    dataloader_num_workers=2,
    report_to=[],
    logging_first_step=True,
    logging_dir=f"{teacher_output_dir}/logs"
)

Defining Training Arguments for LoRA Teacher Model...


In [21]:
# --- Initialize STANDARD Trainer for LoRA Teacher ---
print("Initializing Trainer for LoRA Teacher Model...")
trainer = Trainer(
    model=model, # Pass the PEFT model (roberta-large + LoRA)
    args=training_args,
    data_collator=data_collator,
    train_dataset=teacher_train_dataset, # Use the split train set
    eval_dataset=teacher_eval_dataset,   # Use the split validation set
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[metrics_collector],
)
print("LoRA Teacher Trainer initialized.")

Initializing Trainer for LoRA Teacher Model...


  trainer = Trainer(
No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


LoRA Teacher Trainer initialized.


In [22]:
# --- Train the LoRA Teacher Model ---
import torch.serialization; import numpy.core.multiarray
try: from numpy.dtypes import UInt32DType
except ImportError: UInt32DType = None
safe_numpy_globals = []; print("Attempting NumPy allowlist...")
try:
    safe_numpy_globals.append(np.core.multiarray._reconstruct)
    safe_numpy_globals.append(np.ndarray); safe_numpy_globals.append(np.dtype)
    if UInt32DType: safe_numpy_globals.append(UInt32DType)
    else: safe_numpy_globals.append(np.uint32)
    safe_numpy_globals = list(set(safe_numpy_globals))
    torch.serialization.add_safe_globals(safe_numpy_globals)
    print(f"Added NumPy components: {[c.__name__ if hasattr(c, '__name__') else str(c) for c in safe_numpy_globals]}")
except Exception as e_gen: print(f"WARNING: Error setting safe globals for numpy: {e_gen}")

print("Starting LoRA Teacher Model Training...")
start_train_time = time.time()
try:
    # Train fresh
    train_result = trainer.train()
    print("LoRA Teacher training finished.")
    metrics = train_result.metrics
    print("--- LoRA Teacher Training Metrics ---")
    for key, value in metrics.items(): print(f"{key}: {value}")
    trainer.save_metrics("train", metrics)
except Exception as e: print(f"ERROR during teacher training: {e}"); traceback.print_exc(); raise e
end_train_time = time.time()
print(f"LoRA Teacher Training Duration: {end_train_time - start_train_time:.2f} seconds")

Attempting NumPy allowlist...
Added NumPy components: ['UInt32DType', '_reconstruct', 'ndarray', 'dtype']
Starting LoRA Teacher Model Training...


Step,Training Loss,Validation Loss,Accuracy,F1
500,0.4583,0.448451,0.952541,0.952533
1000,0.4363,0.434302,0.959768,0.959749
1500,0.437,0.431839,0.959507,0.95944
2000,0.4378,0.424477,0.96225,0.962219
2500,0.4257,0.423432,0.962816,0.962829
3000,0.425,0.42093,0.963643,0.963626


LoRA Teacher training finished.
--- LoRA Teacher Training Metrics ---
train_runtime: 6047.8523
train_samples_per_second: 34.177
train_steps_per_second: 0.534
total_flos: 3.566797779426816e+16
train_loss: 0.46969956109751926
epoch: 0.9997677838842015
LoRA Teacher Training Duration: 6048.39 seconds


In [23]:
# --- Save the Final LoRA Teacher Adapter ---
print(f"\nSaving the fine-tuned LoRA teacher adapter to {teacher_adapter_save_path}...")
try:
    # save_model for PEFT model saves the adapter & config correctly
    trainer.save_model(teacher_adapter_save_path)
    tokenizer.save_pretrained(teacher_adapter_save_path) # Save tokenizer with adapter
    print(f"LoRA Teacher adapter and tokenizer saved to {teacher_adapter_save_path}")
    # Save training args as well
    final_args_path = os.path.join(teacher_adapter_save_path, "teacher_training_args.json")
    with open(final_args_path, 'w') as f: f.write(training_args.to_json_string())
    print(f"Teacher training arguments saved to {final_args_path}")
except Exception as e:
    print(f"ERROR saving final teacher adapter/tokenizer/args: {e}")
    traceback.print_exc()


Saving the fine-tuned LoRA teacher adapter to /kaggle/working/roberta_large_lora_teacher_adapter...
LoRA Teacher adapter and tokenizer saved to /kaggle/working/roberta_large_lora_teacher_adapter
Teacher training arguments saved to /kaggle/working/roberta_large_lora_teacher_adapter/teacher_training_args.json


In [24]:
# --- Optional: Evaluate Teacher on Validation Set ---
print("\nRunning final evaluation on the validation set for the Teacher Model...")
try:
    eval_metrics = trainer.evaluate(eval_dataset=teacher_eval_dataset)
    print("--- Teacher Final Validation Metrics ---")
    for key, value in eval_metrics.items(): print(f"{key}: {value}")
    trainer.save_metrics("eval", eval_metrics)
except Exception as e: print(f"ERROR during final teacher evaluation: {e}")


Running final evaluation on the validation set for the Teacher Model...


--- Teacher Final Validation Metrics ---
eval_loss: 0.42092961072921753
eval_accuracy: 0.9636434884834763
eval_f1: 0.9636262982151969
eval_runtime: 253.1916
eval_samples_per_second: 90.71
eval_steps_per_second: 2.836
epoch: 0.9997677838842015


In [None]:
# --- t-SNE / UMAP Visualization ---
print("\nSetting up for visualization...")
try:
    trainer.model.config.output_hidden_states = True
    print("Model config set to output hidden states.")
except Exception as e: print(f"ERROR setting output_hidden_states: {e}")

print("Extracting CLS token embeddings for visualization...")
subset_size = 300
if len(eval_dataset) < subset_size: subset_size = len(eval_dataset)
features = None; label_list = []
if subset_size > 0:
    try:
        subset_indices = random.sample(range(len(eval_dataset)), subset_size)
        subset = eval_dataset.select(subset_indices)
        feature_list = []
        model_device = trainer.model.device
        print(f"Extracting embeddings using device: {model_device}")
        for sample in subset:
            input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(model_device)
            attention_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(model_device)
            with torch.no_grad(): outputs = trainer.model(input_ids=input_ids, attention_mask=attention_mask)
            if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                 hidden_state = outputs.hidden_states[-1]; cls_embedding = hidden_state[0, 0, :].cpu().numpy()
                 feature_list.append(cls_embedding); label_list.append(sample["labels"])
            else: print("WARNING: Could not get hidden_states for sample.")
        if feature_list: features = np.array(feature_list); print(f"Extracted {len(features)} embeddings.")
        else: print("WARNING: No features were extracted.")
    except Exception as e: print(f"ERROR extracting embeddings: {e}")
else: print("WARNING: Evaluation dataset too small, skipping visualization.")

print("Plotting t-SNE")
if features is not None and features.shape[0] > 1:
    try:
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, features.shape[0] - 1))
        features_2d_tsne = tsne.fit_transform(features); plt.figure(figsize=(8, 6))
        scatter = plt.scatter(features_2d_tsne[:, 0], features_2d_tsne[:, 1], c=label_list, cmap="viridis", alpha=0.7)
        cbar = plt.colorbar(scatter, label="Class Label", ticks=range(len(class_names))); cbar.ax.set_yticklabels(class_names)
        plt.title("t-SNE Visualization"); plt.xlabel("t-SNE Comp 1"); plt.ylabel("t-SNE Comp 2"); plt.show()
    except Exception as e: print(f"ERROR in t-SNE: {e}")
else: print("Skipping plots (not enough features extracted).")