# **Install and Imports**

In [None]:
!pip install transformers datasets scikit-learn

In [None]:
from datasets import load_dataset
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW, get_scheduler
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score

# **Load dataset and Create splits**
## Hold out 10% of the training data as a validation set, as done in the original paper

In [None]:
# Request the train and test splits in one call; the result is indexed
# positionally below (0 -> train, 1 -> test).
ds = load_dataset("anaumghori/cos_e-rationale", split=['train', 'test'])
train_dataset, test_set = ds[0], ds[1]

# Reserve 10% of the training examples (rounded down) for validation.
n_train_examples = len(train_dataset)
val_size = int(0.1 * n_train_examples)
train_size = n_train_examples - val_size

# A generator with a fixed seed drives the random split.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(train_dataset, [train_size, val_size], generator=split_generator)

In [None]:
# Look at one training example to check the available fields.
first_example = train_set[0]
print(f"Input: {first_example['input']} | \nLabel: {first_example['label']}")
print("Rationale:", first_example['rationale'])

Input: The people wanted to do a demonstration, where did they decide to do it?
Answer Choices:
(a) supermarket
(b) public place
(c) demolition
(d) space shuttle
(e) roadblock | 
Label: public place
Rationale: To determine the most suitable location for a demonstration, I need to consider the characteristics of each option provided. A supermarket is a commercial establishment focused on selling groceries and consumer goods; demonstrations typically take place in public spaces accessible to the general public. A public place, such as a park, square, or plaza, is intentionally designed for gatherings and meetings, allowing for free expression and assembly. Demolition refers to the process of destroying buildings or structures, which is unrelated to hosting a demonstration. Space shuttles are spacecraft used for space travel and are not available for ground-based demonstrations. Roadblocks are temporary barriers set up at intersections to control traffic, not locations for public gatherin

# **Custom Dataset Class**

In [None]:
class T5StepByStepDistillationDataset(Dataset):
    """Dataset yielding paired label-task and rationale-task training examples.

    Every item tokenizes the same source text twice, once prefixed with
    ``[label]`` and once with ``[rationale]``, together with the matching
    target text (``example['label']`` and ``example['rationale']``).

    Args:
        dataset: Indexable collection whose items provide ``'input'``,
            ``'label'`` and ``'rationale'`` entries.
        tokenizer: Callable tokenizer that returns an object exposing
            ``input_ids`` and ``attention_mask`` tensors and that has a
            ``pad_token_id`` attribute.
        max_length: Length that every input and target is padded or
            truncated to.
    """

    def __init__(self, dataset, tokenizer, max_length=526):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def _encode(self, text):
        """Tokenize ``text``, padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

    def _mask_padding(self, token_ids):
        """Return a copy of ``token_ids`` with pad positions replaced by -100."""
        targets = token_ids.clone()
        targets[targets == self.tokenizer.pad_token_id] = -100
        return targets

    def __getitem__(self, idx):
        """Return the six tensors for example ``idx``.

        The returned dict has ``input_ids_*``, ``attention_mask_*`` and
        ``labels_*`` entries for both the ``label`` and ``rationale`` tasks.
        """
        example = self.dataset[idx]
        source_text = example['input']

        # Same source text, different task prefix.
        label_inputs = self._encode(f"[label] {source_text}")
        rationale_inputs = self._encode(f"[rationale] {source_text}")

        # Targets with padding masked out.
        label_targets = self._mask_padding(self._encode(example['label']).input_ids)
        rationale_targets = self._mask_padding(self._encode(example['rationale']).input_ids)

        # squeeze(0) drops only the leading dimension (assumed to be the
        # size-1 batch axis added by return_tensors="pt"). A bare squeeze()
        # would also collapse the sequence axis when max_length == 1.
        return {
            'input_ids_label': label_inputs.input_ids.squeeze(0),
            'attention_mask_label': label_inputs.attention_mask.squeeze(0),
            'labels_label': label_targets.squeeze(0),
            'input_ids_rationale': rationale_inputs.input_ids.squeeze(0),
            'attention_mask_rationale': rationale_inputs.attention_mask.squeeze(0),
            'labels_rationale': rationale_targets.squeeze(0)
        }

In [None]:
class T5TestDataset(Dataset):
    """Dataset of label-task examples used for evaluation.

    Only the ``[label]`` prefix is used; no rationale is tokenized.

    Args:
        dataset: Indexable collection whose items provide ``'input'`` and
            ``'label'`` entries.
        tokenizer: Callable tokenizer that returns an object exposing
            ``input_ids`` and ``attention_mask`` tensors and that has a
            ``pad_token_id`` attribute.
        max_length: Length that the input and the label are padded or
            truncated to.
    """

    def __init__(self, dataset, tokenizer, max_length=128):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def _encode(self, text):
        """Tokenize ``text``, padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return input tensors, masked label ids and the raw label string."""
        example = self.dataset[idx]

        # Only predict labels for testing
        input_encoding = self._encode(f"[label] {example['input']}")
        label_encoding = self._encode(example['label'])

        # Replace padding token ID with -100
        labels = label_encoding.input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        # squeeze(0) drops only the leading dimension (assumed to be the
        # size-1 batch axis added by return_tensors="pt"). A bare squeeze()
        # would also collapse the sequence axis when max_length == 1.
        return {
            'input_ids': input_encoding.input_ids.squeeze(0),
            'attention_mask': input_encoding.attention_mask.squeeze(0),
            'labels': labels.squeeze(0),
            'reference': example['label']  # Store original label for evaluation
        }

# **Validation Function**

In [None]:
def validate_model(model, val_loader, device, lambda_value=0.2):
    """Return the mean per-batch combined loss over ``val_loader``.

    For each batch the model is run once on the label-task tensors and once
    on the rationale-task tensors, and the two losses are combined as
    ``loss_label + lambda_value * loss_rationale``.

    Args:
        model: Module called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose a ``loss`` attribute.
        val_loader: Sized iterable of batches keyed like the items of
            ``T5StepByStepDistillationDataset``.
        device: Device each batch tensor is moved to before the forward pass.
        lambda_value: Weight applied to the rationale loss.

    Returns:
        Sum of the per-batch combined losses divided by ``len(val_loader)``.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            # Forward pass for the label task.
            label_out = model(
                input_ids=batch['input_ids_label'].to(device),
                attention_mask=batch['attention_mask_label'].to(device),
                labels=batch['labels_label'].to(device),
            )

            # Forward pass for the rationale task.
            rationale_out = model(
                input_ids=batch['input_ids_rationale'].to(device),
                attention_mask=batch['attention_mask_rationale'].to(device),
                labels=batch['labels_rationale'].to(device),
            )

            # L = Llabel + λLrationale
            batch_loss = label_out.loss + lambda_value * rationale_out.loss
            running_loss += batch_loss.item()

    return running_loss / len(val_loader)

# **Train the model**

In [None]:
def train_step_by_step_distillation(train_set, val_set, model_name="google-t5/t5-small", max_length=526,
                                    batch_size=16, num_epochs=25, learning_rate=3e-5, weight_decay=0.01,
                                    lambda_value=0.2):
    """Fine-tune a T5 model on the label and rationale tasks jointly.

    Each training step runs the model on the ``[label]`` and ``[rationale]``
    versions of a batch and optimizes
    ``loss_label + lambda_value * loss_rationale``. After every epoch the
    validation loss is computed with ``validate_model``. The weights from the
    epoch with the lowest validation loss are restored before returning.

    Args:
        train_set: Training examples, wrapped in
            ``T5StepByStepDistillationDataset``.
        val_set: Validation examples, wrapped the same way.
        model_name: Checkpoint name passed to ``from_pretrained`` for both
            the tokenizer and the model.
        max_length: Padding/truncation length handed to the dataset wrapper.
        batch_size: Batch size for both data loaders.
        num_epochs: Number of passes over the training loader.
        learning_rate: Learning rate for the optimizer.
        weight_decay: Weight decay for the optimizer.
        lambda_value: Weight applied to the rationale loss.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    # Function-scope import: only the best-model snapshot below needs it.
    import copy

    # Load model and tokenizer
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare datasets for combined training
    train_dataset = T5StepByStepDistillationDataset(train_set, tokenizer, max_length)
    val_dataset = T5StepByStepDistillationDataset(val_set, tokenizer, max_length)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Initialize optimizer and scheduler
    # NOTE(review): `AdamW` is imported from `transformers` at file level; that
    # class is reportedly deprecated upstream in favour of `torch.optim.AdamW`
    # -- confirm against the installed transformers version.
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    total_steps = len(train_loader) * num_epochs

    # Linear schedule over all optimizer steps, with no warmup.
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # Store the best model
    best_val_loss = float('inf')
    best_model_state = None

    # Training loop
    for epoch in range(num_epochs):
        model.train()  # validate_model() leaves the model in eval mode
        total_train_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Training - Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            # Process label task
            input_ids_label = batch['input_ids_label'].to(device)
            attention_mask_label = batch['attention_mask_label'].to(device)
            labels_label = batch['labels_label'].to(device)

            # Process rationale task
            input_ids_rationale = batch['input_ids_rationale'].to(device)
            attention_mask_rationale = batch['attention_mask_rationale'].to(device)
            labels_rationale = batch['labels_rationale'].to(device)

            # Get losses for both tasks
            outputs_label = model(
                input_ids=input_ids_label,
                attention_mask=attention_mask_label,
                labels=labels_label
            )

            outputs_rationale = model(
                input_ids=input_ids_rationale,
                attention_mask=attention_mask_rationale,
                labels=labels_rationale
            )

            # Calculate combined loss: L = Llabel + λLrationale
            loss_label = outputs_label.loss
            loss_rationale = outputs_rationale.loss
            combined_loss = loss_label + lambda_value * loss_rationale

            total_train_loss += combined_loss.item()

            # Backward pass
            optimizer.zero_grad()
            combined_loss.backward()
            optimizer.step()
            lr_scheduler.step()

            progress_bar.set_postfix({
                "loss": combined_loss.item(),
                "label_loss": loss_label.item(),
                "rationale_loss": loss_rationale.item()
            })

        avg_train_loss = total_train_loss / len(train_loader)
        print(f"Epoch {epoch+1} - Average training loss: {avg_train_loss:.4f}")

        # Validation
        print("Running validation...")
        val_loss = validate_model(model, val_loader, device, lambda_value)
        print(f"Epoch {epoch+1} - Validation loss: {val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() hands back references to the live parameter
            # tensors, and dict.copy() is shallow, so a plain copy would be
            # overwritten in place by later optimizer steps. A deep copy
            # freezes the weights as they are at this epoch.
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"New best model saved with validation loss: {val_loss:.4f}")

    # Load best model for final evaluation
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with validation loss: {best_val_loss:.4f}")

    print("Training completed!")
    return model, tokenizer

In [None]:
# Train with the function's default hyperparameters; keeps the returned model and tokenizer.
model, tokenizer = train_step_by_step_distillation(train_set, val_set)

# **Evaluate the model**

In [None]:
def evaluate_model(model, tokenizer, test_dataset, max_length=128, batch_size=16):
    """Generate a label for every test example and score exact-match accuracy.

    Args:
        model: Model providing ``generate``; it is moved to CUDA when
            available, otherwise CPU, and put in eval mode.
        tokenizer: Tokenizer used to build the inputs and decode the outputs.
        test_dataset: Raw test examples, wrapped in ``T5TestDataset``.
        max_length: Used both as the tokenization length and as the
            ``max_length`` passed to ``generate``.
        batch_size: Batch size of the evaluation loader.

    Returns:
        Tuple ``(accuracy, predictions, references)``, where ``accuracy`` is
        ``accuracy_score(references, predictions) * 100`` and the other two
        are lists of strings.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Wrap the raw examples and batch them in their original order.
    eval_loader = DataLoader(
        T5TestDataset(test_dataset, tokenizer, max_length),
        batch_size=batch_size,
    )

    model.eval()
    predictions, references = [], []

    print("Starting evaluation...")
    with torch.no_grad():
        for batch in tqdm(eval_loader, desc="Evaluating"):
            generated = model.generate(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                max_length=max_length,
            )

            # Decode each generated sequence, dropping special tokens.
            for sequence in generated:
                predictions.append(tokenizer.decode(sequence, skip_special_tokens=True))
            references.extend(batch["reference"])

    # Accuracy as a percentage.
    accuracy = accuracy_score(references, predictions) * 100
    return accuracy, predictions, references

In [None]:
# Score the trained model on the test split using the default max_length and batch_size.
accuracy, predictions, references = evaluate_model(model, tokenizer, test_set)

# **Results**

In [None]:
def display_evaluation_results(accuracy, predictions, references, num_samples=3):
    """Print the accuracy and the first ``num_samples`` reference/prediction pairs.

    Args:
        accuracy: Accuracy value, printed with two decimals and a ``%`` sign.
        predictions: Sequence of predicted strings.
        references: Sequence of reference strings, indexed in parallel with
            ``predictions``.
        num_samples: How many leading examples to print.
    """
    print(f"Final Evaluation Accuracy: {accuracy:.2f}%")

    print("\nSample predictions:")
    divider = "-" * 50
    # Slicing the range gives the same leading indices as slicing a list would.
    for position in range(len(predictions))[:num_samples]:
        print(f"Reference: {references[position]}")
        print(f"Prediction: {predictions[position]}")
        print(divider)

# Print the test accuracy followed by the first three reference/prediction pairs.
display_evaluation_results(accuracy, predictions, references, num_samples=3)

Final Evaluation Accuracy: 52.60%

Reference: wooded area
Prediction: wooded area
--------------------------------------------------
Reference: go downtown
Prediction: go downtown
--------------------------------------------------
Reference: play tag
Prediction: play tag
--------------------------------------------------
