## SciCite Distillation Process with Gemma-3-12b

**Objective**: Fine-tune Gemma-3-12b-it using reasoning-enhanced data from teacher models (Llama-3.3/Gemma2)

In [1]:
%pip install -q transformers datasets accelerate peft bitsandbytes

Note: you may need to restart the kernel to use updated packages.


### Load Augmented Dataset

We are using the partitioned dataset from the Teacher model.

In [1]:
import pandas as pd
from datasets import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dotenv import load_dotenv
load_dotenv()

True

In [3]:
def load_partition(partition_path: str) -> Dataset:
    df = pd.read_csv(partition_path)
    return Dataset.from_pandas(df[["id", "model_classification", "reasoning"]])

# Replace with dataset path
train_dataset = load_partition("./results/Gemma2_27b/first_partition.csv")

### Student Model and Tokenizer Steup
We are using Hugging Face API with Gemma-3-12b-it as the student model. The tokenizer is also from the same model.

In [None]:
# from transformers import AutoTokenizer, AutoModelForMaskedLM, TrainerCallback
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os

In [None]:
%pip install --upgrade transformers
%pip install --upgrade torch torchvision torchaudio

In [None]:
hf_token = os.getenv('HUGGINGFACE_API_KEY')

model_id = "google/mobilebert-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    token=hf_token,
    num_labels=3,
)

Some weights of MobileBertForSequenceClassification were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Data Preparation
Format data with our Teacher reasoning

In [6]:
label_map = {"background": 0, "method": 1, "result": 2}

def format_for_distillation(examples):
    tokenized_text = tokenizer(
        f"Text: {examples['id']}\nTeacher Reasoning: {examples['reasoning']}\nClassification:",
        padding="max_length",
        truncation=True,
        max_length=512
    )
    
    # # Create a labels tensor of the same length as input_ids
    # labels = [label_map[examples["model_classification"]]] * len(tokenized_text["input_ids"])

    # Create a single label for the entire sequence
    labels = label_map[examples["model_classification"]]

    return {
        "input_ids": tokenized_text["input_ids"],
        "attention_mask": tokenized_text["attention_mask"],
        "labels": labels  # Now same shape as input_ids / Now a single scalar value
    }

tokenized_dataset = train_dataset.map(
    format_for_distillation,
    remove_columns=['id', 'model_classification', 'reasoning']
)

Map:   0%|          | 0/1365 [00:00<?, ? examples/s]

Map: 100%|██████████| 1365/1365 [00:02<00:00, 496.53 examples/s]


### Custom Distillation Trainer
Aligns student model with teacher reasoning

In [7]:
import time

In [None]:
from transformers import Trainer, TrainingArguments

class ReasoningDistiller(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.step_counter = 0
        self.start_time = time.time()
        self.total_steps = 0
        self.batch_size = kwargs.get('args').per_device_train_batch_size
        self.epochs = kwargs.get('args').num_train_epochs
        
        # Calculate total steps
        if hasattr(kwargs.get('train_dataset'), '__len__'):
            dataset_size = len(kwargs.get('train_dataset'))
            grad_accum = kwargs.get('args').gradient_accumulation_steps
            self.total_steps = (dataset_size // (self.batch_size * grad_accum)) * self.epochs
            print(f"\n===== TRAINING INFO =====")
            print(f"Dataset size: {dataset_size} examples")
            print(f"Batch size: {self.batch_size}")
            print(f"Gradient accumulation steps: {grad_accum}")
            print(f"Epochs: {self.epochs}")
            print(f"Estimated total steps: {self.total_steps}")
            print(f"========================\n")

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # print("compute_loss called!")
        # print(f"Input keys: {inputs.keys()}")
        # print("Model config:", model.config)
        # for key, value in inputs.items():
        #     print(f"{key}: {value.shape} {value.dtype}")

        # step = self.state.global_step if hasattr(self, 'state') else 0
        # verbose = (step % 10 == 0)
        # if verbose:
        #     print(f"\nStep {step}: Computing loss...")
        #     print(f"Input keys: {inputs.keys()}")

        self.step_counter += 1
        elapsed_time = time.time() - self.start_time
        if self.total_steps > 0:
            progress = (self.step_counter / self.total_steps) * 100
            
            # Only print status every 10 steps to avoid cluttering
            if self.step_counter % 10 == 0 or self.step_counter == 1:
                # Calculate time estimates
                if self.step_counter > 1:
                    avg_time_per_step = elapsed_time / self.step_counter
                    remaining_steps = self.total_steps - self.step_counter
                    remaining_time = avg_time_per_step * remaining_steps
                    
                    # Format time remaining
                    mins_remaining = int(remaining_time // 60)
                    hrs_remaining = mins_remaining // 60
                    mins_remaining = mins_remaining % 60
                    
                    elapsed_mins = int(elapsed_time // 60)
                    elapsed_hrs = elapsed_mins // 60
                    elapsed_mins = elapsed_mins % 60
                    
                    time_info = f"Elapsed: {elapsed_hrs}h {elapsed_mins}m | Remaining: {hrs_remaining}h {mins_remaining}m"
                else:
                    time_info = "Calculating time remaining..."
                
                print(f"\nStep {self.step_counter}/{self.total_steps} ({progress:.2f}% complete)")
                print(f"{time_info}")


        
        # Forward pass
        outputs = model(**inputs, output_hidden_states=True)
        # Extract logits
        logits = outputs.logits  # Shape: (batch_size, sequence_length, vocab_size)
        labels = inputs["labels"]  # Shape: (batch_size, sequence_length)
        
        # Flatten logits and labels for CrossEntropyLoss
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        # Get hidden states for distillation
        # student_hidden = outputs.hidden_states[-1].mean(dim=1)  # Shape: (batch_size, hidden_dim)
        student_hidden = outputs.hidden_states[-1][:, 0, :]  # Use [CLS] token representation
        teacher_hidden = torch.randn_like(student_hidden)  

        # Alignment loss
        reasoning_loss = torch.nn.functional.mse_loss(student_hidden, teacher_hidden)

        # Combined loss
        total_loss = loss + 0.3 * reasoning_loss

        if self.step_counter % 10 == 0 or self.step_counter == 1:
            print(f"Classification Loss: {loss.item():.4f} | Reasoning Loss: {reasoning_loss.item():.4f}")
            print(f"Total Loss: {total_loss.item():.4f}")

        return (total_loss, outputs) if return_outputs else total_loss


### Training Configuration
For 24FB GPUs as per documentation

In [13]:
training_args = TrainingArguments(
    output_dir="gemma3-distilled",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    fp16=True,
    logging_steps=50,
    report_to="none",
    save_strategy="steps",
    eval_steps=200,
    remove_unused_columns=False
)

In [14]:
# def tokenize_function(examples):
#     return tokenizer(
#         examples["text"],
#         padding="max_length",
#         truncation=True,
#         max_length=512,
#         return_tensors="pt"
#     )

# tokenized_train = tokenized_dataset.map(
#     tokenize_function,
#     batched=True,
#     batch_size=32
# )

trainer = ReasoningDistiller(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)


===== TRAINING INFO =====
Dataset size: 1365 examples
Batch size: 4
Gradient accumulation steps: 4
Epochs: 1
Estimated total steps: 85



In [15]:
for batch in trainer.get_train_dataloader():
    print(f"Batch Input Shape: {batch['input_ids'].shape}")
    print(f"Batch Labels Shape: {batch['labels'].shape}")
    break  # Only print the first batch

Batch Input Shape: torch.Size([4, 512])
Batch Labels Shape: torch.Size([4])


In [None]:
trainer.train()


Step 1/85 (1.18% complete)
Calculating time remaining...
Classification Loss: nan | Reasoning Loss: nan
Total Loss: nan


Step,Training Loss



Step 10/85 (11.76% complete)
Elapsed: 0h 23m | Remaining: 2h 57m
Classification Loss: nan | Reasoning Loss: nan
Total Loss: nan


In [None]:
model.save_pretrained("gemma3-distilled-scicite")
tokenizer.save_pretrained("gemma3-distilled-scicite")