In [1]:
import math
import time
import os
import pandas as pd
import torch
from transformers import RobertaModel, RobertaTokenizer, TrainingArguments, Trainer, DataCollatorWithPadding, RobertaForSequenceClassification, RobertaConfig
# from peft import LoraConfig, get_peft_model, PeftModel
from datasets import load_dataset, Dataset, ClassLabel
import pickle
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
print(device)
batch_size_val = 4 if device == "cpu" else 16 if device == "cuda" else 4
num_workers_val = 4 if device == "cuda" else 0
print(batch_size_val, num_workers_val)

mps
4 0


In [4]:
# Load tokenizer and model
base_model = 'roberta-base'
tokenizer = RobertaTokenizer.from_pretrained(base_model)

# Load datasets
train_dataset = load_dataset('ag_news', split='train', cache_dir='./data/')
test_dataset = load_dataset('ag_news', split='test', cache_dir='./data/')

# Extract the number of classess and their names
num_labels = test_dataset.features['label'].num_classes
class_names = test_dataset.features["label"].names
print(f"number of labels: {num_labels}")
print(f"the labels: {class_names}")

# Create an id2label mapping
# We will need this for our classifier.
id2label = {i: label for i, label in enumerate(class_names)}

# Tokenization function
def preprocess(examples):
    return tokenizer(examples['text'], truncation=True, padding="max_length", max_length=512)

# Apply tokenization
tokenized_train = train_dataset.map(preprocess, batched=True, remove_columns=["text"])
tokenized_test = test_dataset.map(preprocess, batched=True, remove_columns=["text"])

# Rename label column
tokenized_train = tokenized_train.rename_column("label", "labels")
tokenized_test = tokenized_test.rename_column("label", "labels")

# Set format for PyTorch
tokenized_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
tokenized_test.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Create data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

# Create DataLoaders
train_dataloader = DataLoader(tokenized_train, batch_size=batch_size_val, shuffle=True, collate_fn=data_collator, num_workers=num_workers_val)
test_dataloader = DataLoader(tokenized_test, batch_size=batch_size_val, shuffle=False, collate_fn=data_collator, num_workers=num_workers_val)

number of labels: 4
the labels: ['World', 'Sports', 'Business', 'Sci/Tech']


In [5]:
def get_accuracy(y_pred, targets):
  predictions = torch.log_softmax(y_pred, dim=1).argmax(dim=1)
  accuracy = (predictions == targets).sum() / len(targets)
  return accuracy

In [6]:
def load_checkpoint(model, optimizer, checkpoint_path, resume_training=True):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    if resume_training:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        step = checkpoint.get('step', 0)
        epoch = checkpoint.get('epoch', 0)
        print(f"Resumed training from epoch {epoch}, step {step}")
        return model, optimizer, step, epoch
    else:
        print("Model loaded for inference")
        return model.to(device)

In [7]:
from functools import partial
import torch.nn as nn

def setup_lora_model(base_model, r, alpha):
    model = RobertaForSequenceClassification.from_pretrained(
        base_model,
        id2label=id2label,
        cache_dir="./model_dir"
    )

    # Freeze base model params
    for param in model.parameters():
        param.requires_grad = False

    # Define LoRA layers
    class LoRALayer(nn.Module):
        def __init__(self, in_dim, out_dim, r, alpha):
            super().__init__()
            self.A = nn.Parameter(torch.empty(r, in_dim))
            self.B = nn.Parameter(torch.empty(out_dim, r))
            torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
            torch.nn.init.zeros_(self.B)
            self.scaling = alpha / r

        def forward(self, x):
            A = self.A.to(x.device)
            B = self.B.to(x.device)
            return self.scaling * (x @ A.T @ B.T)

    class LinearWithLoRA(nn.Module):
        def __init__(self, linear, r, alpha):
            super().__init__()
            self.linear = linear
            self.lora = LoRALayer(linear.in_features, linear.out_features, r, alpha)

            # add bias terms:
            if linear.bias is not None:
                linear.bias.requires_grad = True

        def forward(self, x):
            return self.linear(x) + self.lora(x)

    assign_lora = partial(LinearWithLoRA, r=r, alpha=alpha)

    # Adding interjecting parameters to other layers:
    # for layer in model.roberta.encoder.layer:
    #     layer.attention.self.query = assign_lora(layer.attention.self.query)
    #     layer.attention.self.value = assign_lora(layer.attention.self.value)
    #     layer.output.dense = assign_lora(layer.output.dense)

    for i, layer in enumerate(model.roberta.encoder.layer):
        if i >= 8:
            layer.attention.self.query = assign_lora(layer.attention.self.query)
            layer.attention.self.value = assign_lora(layer.attention.self.value)
            # Optional: output dense layer
            layer.output.dense = assign_lora(layer.output.dense)

    for name, param in model.named_parameters():
        if "classifier" in name:
            param.requires_grad = True
    
    return model.to(device)

In [8]:
loss_function = torch.nn.CrossEntropyLoss()
# optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
# model = setup_lora_model(base_model, r=20, alpha=40)

In [9]:
def evaluate(model, test_loader, return_values=False):
    interval = len(test_loader) // 5 if len(test_loader) >= 5 else 1
    total_test_loss = 0
    total_test_acc = 0

    model.eval()
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            acc = get_accuracy(logits, labels)

            total_test_loss += loss.item()
            total_test_acc += acc.item()

            if (batch_idx + 1) % interval == 0:
                print(f"Batch: {batch_idx+1}/{len(test_loader)} | Test loss: {loss:.4f} | accuracy: {acc:.4f}")

    test_loss = total_test_loss / len(test_loader)
    test_acc = total_test_acc / len(test_loader)

    print(f"Test loss: {test_loss:.4f} acc: {test_acc:.4f}\n")

    if return_values:
        return test_loss, test_acc


In [10]:
import matplotlib.pyplot as plt

def plot_metrics(train_data, test_data, title, ylabel, save_path_base):
    steps, train_vals = zip(*train_data)
    _, test_vals = zip(*test_data)

    # Plot with interactivity
    fig, ax = plt.subplots()
    ax.plot(steps, train_vals, label='Train')
    ax.plot(steps, test_vals, label='Test')
    ax.set_xlabel("Steps")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    # Save interactive plot as HTML
    try:
        import plotly.graph_objects as go
        import plotly.io as pio

        fig_plotly = go.Figure()
        fig_plotly.add_trace(go.Scatter(x=steps, y=train_vals, mode='lines+markers', name='Train'))
        fig_plotly.add_trace(go.Scatter(x=steps, y=test_vals, mode='lines+markers', name='Test'))
        fig_plotly.update_layout(title=title, xaxis_title='Steps', yaxis_title=ylabel)
        interactive_path = save_path_base.replace(".png", ".html")
        pio.write_html(fig_plotly, file=interactive_path, auto_open=False)
    except ImportError:
        print("Plotly not installed — skipping interactive plot save.")

    # Save static plot
    static_path = save_path_base.replace(".html", ".png")
    fig.savefig(static_path)
    plt.close(fig)

    # Save raw data
    df = pd.DataFrame({
        "Step": steps,
        "Train": train_vals,
        "Test": test_vals
    })
    csv_path = save_path_base.replace(".png", "_data.csv").replace(".html", "_data.csv")
    df.to_csv(csv_path, index=False)
    print(f"Saved plot to {static_path} and interactive/data files to {interactive_path if 'interactive_path' in locals() else 'N/A'}")


In [11]:
def evaluate_unlabelled(model, data_loader):
    model.eval()
    preds = []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits  # [B, num_classes]
            predictions = torch.argmax(logits, dim=1)  # get predicted class indices
            preds.append(predictions.cpu())

    return torch.cat(preds, dim=0)  # combine into a single tensor


In [12]:
def train(model, train_loader, optimizer, lr_scheduler, save_every_steps=200, output_dir="./checkpoints", epochs=None, max_steps=None):
    import time, os, torch
    os.makedirs(output_dir, exist_ok=True)
    recent_checkpoints = []

    total_training_time = 0
    total_steps = 0
    epoch = 0
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []
    log_every_steps = 500  # how often to log + evaluate

    print("Starting training...\n")
    training_start = time.time()

    while True:
        if epochs is not None and epoch >= epochs:
            break

        model.train()
        epoch_start = time.time()
        epoch_loss = 0.0
        epoch_acc = 0.0
        epoch_steps = 0

        for batch_idx, batch in enumerate(train_loader):
            if max_steps is not None and total_steps >= max_steps:
                break

            step_start = time.time()

            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss, logits = outputs.loss, outputs.logits

            acc = get_accuracy(logits, labels)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
            epoch_steps += 1
            total_steps += 1
            
            if total_steps % log_every_steps == 0 or total_steps == max_steps:
                avg_train_loss = epoch_loss / epoch_steps
                avg_train_acc = epoch_acc / epoch_steps
                train_losses.append((total_steps, avg_train_loss))
                train_accuracies.append((total_steps, avg_train_acc))
            
                # Evaluate on test set
                test_loss, test_acc = evaluate(model, test_dataloader, return_values=True)
                test_losses.append((total_steps, test_loss))
                test_accuracies.append((total_steps, test_acc))
            
            loss.backward()
            # Grad Norm Clipping: 
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            if lr_scheduler is not None:
                lr_scheduler.step()
            
            step_end = time.time()
            step_time = step_end - step_start

            if total_steps % 250 == 0 or total_steps == 1:
                avg_loss = epoch_loss / epoch_steps
                avg_acc = epoch_acc / epoch_steps
                print(f"[Step {total_steps}] Avg Loss (epoch): {avg_loss:.4f} | Avg Acc (epoch): {avg_acc:.4f} | Step Time: {step_time:.2f}s")

            if total_steps % save_every_steps == 0 or total_steps == max_steps:
                ckpt_path = os.path.join(output_dir, f"model_step_{total_steps}.pt")
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'step': total_steps,
                    'epoch': epoch,
                }, ckpt_path)

                print(f"Checkpoint saved: {ckpt_path}")

                recent_checkpoints.append(ckpt_path)
                if len(recent_checkpoints) > 1:
                    old_ckpt = recent_checkpoints.pop(0)
                    if os.path.exists(old_ckpt):
                        os.remove(old_ckpt)
                        print(f"Old checkpoint removed: {old_ckpt}")

        epoch_end = time.time()
        epoch_time = epoch_end - epoch_start
        total_training_time += epoch_time

        # Per-epoch average loss & accuracy
        avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0
        avg_epoch_acc = epoch_acc / epoch_steps if epoch_steps > 0 else 0
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Avg Loss: {avg_epoch_loss:.4f}")
        print(f"  Avg Accuracy: {avg_epoch_acc:.4f}")
        print(f"  Epoch Time: {epoch_time:.2f}s\n")

        epoch += 1

        if max_steps is not None and total_steps >= max_steps:
            print(f"Reached max_steps={max_steps}, stopping training.")
            break
    ckpt_path = os.path.join(output_dir, f"model_step_{total_steps}_final.pt")
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step': total_steps,
        'epoch': epoch,
    }, ckpt_path)
    print(f"Final checkpoint saved: {ckpt_path}")

    overall_time = time.time() - training_start
    print(f"Training completed in {epoch} epoch(s)")
    print(f"Total training time: {overall_time:.2f}s")
    return train_losses, train_accuracies, test_losses, test_accuracies

In [14]:
def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [15]:
import os
import pandas as pd

lora_r_values = [14, 14, 14]
lora_alpha_values = [28, 36, 48]
max_steps = None
save_every_steps = 1500
max_epochs = 4
if max_epochs is not None:
    num_training_steps = len(train_dataloader) * max_epochs
else:
    num_training_steps = max_steps


# Load unlabelled test set
unlabelled_df = pd.read_pickle("test_unlabelled.pkl")
tokenized_unlabelled = unlabelled_df.map(preprocess, batched=True, remove_columns=["text"])
tokenized_unlabelled.set_format(type='torch', columns=['input_ids', 'attention_mask'])
unlabelled_loader = DataLoader(tokenized_unlabelled, batch_size=batch_size_val, shuffle=False, collate_fn=data_collator)

for r, alpha in zip(lora_r_values, lora_alpha_values):
    tag = f"r{r}_alpha{alpha}"
    print(f"\n=== Training LoRA {tag} ===")
    model = setup_lora_model(base_model, r, alpha)
    print("Trainable Parameter Count: {}".format(count_parameters(model)))
    
    optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps
    )
    
    ckpt_dir = f"./results/{tag}/checkpoints/"
    train_losses, train_accuracies, test_losses, test_accuracies = train(model, train_dataloader, optimizer, save_every_steps=save_every_steps, output_dir=ckpt_dir, epochs=max_epochs, max_steps=max_steps, lr_scheduler=lr_scheduler)
    evaluate(model, test_dataloader, return_values=False)
    print(f"=== Inference for {tag} ===")
    inference_start = time.time()
    preds = evaluate_unlabelled(model, unlabelled_loader)

    output_dir = "./results"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"inference_output_{tag}.csv")

    df_output = pd.DataFrame({
      'ID': range(len(preds)),
      'Label': preds.numpy()
    })
    df_output.to_csv(output_path, index=False)

    print(f"Predictions saved to {output_path}")
    inference_end = time.time()
    inference_time = inference_end - inference_start
    print(f"Inference time: {inference_time:.2f}s")
    
    plot_dir = f"./results/{tag}"
    os.makedirs(plot_dir, exist_ok=True)
    plot_metrics(train_losses, test_losses, f"{tag} - Loss", "Loss", f"./results/{tag}/loss_plot_{tag}.png")
    plot_metrics(train_accuracies, test_accuracies, f"{tag} - Accuracy", "Accuracy", f"./results/{tag}/accuracy_plot_{tag}.png")

    # Cleanup after inference
    del train_losses, train_accuracies, test_losses, test_accuracies
    del model
    del optimizer
    del preds
    del df_output
    torch.cuda.empty_cache()  # safe to call on CPU too

    gc.collect()


Map: 100%|█████████████████████████| 8000/8000 [00:01<00:00, 4661.94 examples/s]
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



=== Training LoRA r14_alpha28 ===
Trainable Parameter Count: 989956
Starting training...



Consider using tensor.detach() first. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/Scalar.cpp:23.)
  epoch_loss += loss.item()


[Step 1] Avg Loss (epoch): 1.3106 | Avg Acc (epoch): 0.5000 | Step Time: 5.01s


KeyboardInterrupt: 