

---



LoRA start: we begin from a pretrained BERT checkpoint.

In [None]:
class LORALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank (LoRA) update.

    forward(x) = x @ W0^T + b + (alpha / r) * (dropout(x) @ A^T) @ B^T

    The base weight/bias are shared with `base_linear` (referenced, not copied) and
    frozen. A is Kaiming-uniform initialised and B starts at zero, so the wrapped
    layer initially produces exactly the same output as `base_linear`.

    Args:
        base_linear: the nn.Linear to adapt.
        r: rank of the update (A is [r, in_features], B is [out_features, r]).
        alpha: scaling numerator; the LoRA branch is multiplied by alpha / r.
        dropout: dropout probability on the input of the LoRA branch only
            (0 or None disables it).
    """

    def __init__(self, base_linear: nn.Linear, r=8, alpha=16, dropout=0.1):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)

        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        # Tie to the base Parameters (don't copy, just reference). Assigning them here
        # also registers them on this module, under the same "weight"/"bias" names.
        self.weight = base_linear.weight
        self.bias = base_linear.bias
        for p in (self.weight, self.bias):
            if p is not None:
                p.requires_grad = False

        # Create A/B on the base weight's device *and* dtype. The torch defaults
        # (CPU, float32) would mismatch a GPU or half/double-precision backbone in F.linear.
        factory = {"device": self.weight.device, "dtype": self.weight.dtype}
        self.A = nn.Parameter(torch.empty(r, self.in_features, **factory))
        self.B = nn.Parameter(torch.zeros(self.out_features, r, **factory))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # B stays zero -> update starts at 0

        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity()

    def forward(self, x):
        base = F.linear(x, self.weight, self.bias)  # x @ W0^T + b        -> [..., out]
        delta = F.linear(self.dropout(x), self.A)   # x @ A^T             -> [..., r]
        delta = F.linear(delta, self.B)             # (x @ A^T) @ B^T     -> [..., out]
        return base + self.scaling * delta


In [None]:
def add_lora(bert, r=8, alpha=16, lora_dropout=0.1, targets=("query","value")):
    """Wrap selected nn.Linear projections of every MultiHeadSelfAttention with LORALinear.

    Args:
        bert: the trained BERT backbone (the full PyTorch module); modified in place.
        r: LoRA rank (low-rank dimension); controls how much compression you get.
        alpha: LoRA scaling factor (often set >= r, helps preserve magnitude);
            the update is scaled by alpha / r.
        lora_dropout: dropout applied to the input of the LoRA branch (the A matrix input).
        targets: attribute names of the nn.Linear submodules inside each attention
            block to replace with LoRA versions (by default "query" and "value").

    MultiHeadSelfAttention.forward is left unchanged: it still calls self.query(x),
    self.value(x), etc., which are now LoRA-wrapped. Projections that are already
    LORALinear are skipped, so calling this twice does not double-wrap them.
    """
    replaced = 0  # how many layers we injected LoRA into during this call
    # Snapshot the module list first: we replace children while walking the tree and
    # don't want the traversal to depend on those in-flight mutations.
    for module in list(bert.modules()):
        if not isinstance(module, MultiHeadSelfAttention):
            continue
        for name in targets:
            # e.g. getattr(module, "query") -> the nn.Linear(hidden_size, hidden_size)
            base = getattr(module, name)
            if isinstance(base, LORALinear):
                continue  # already adapted by an earlier call
            # LORALinear freezes the base weight/bias and adds the trainable B @ A update.
            setattr(module, name, LORALinear(base, r=r, alpha=alpha, dropout=lora_dropout))
            replaced += 1
    print(f"LoRA injected into {replaced} Linear layers: {targets}")

In [None]:
from transformers import AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Recreate the *same* model config you used for pretraining
#    (must match hidden_size, heads, layers, etc., exactly)
#    vocab_size is read from the bert-base-uncased tokenizer. This presumes that tokenizer
#    was also used for pretraining (TODO confirm); a mismatch breaks the embedding load.
model_cfg = dict(
    vocab_size=len(AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)),
    hidden_size=512,                 # <-- use the values you actually trained with
    num_heads=8,
    num_layers=6,
    intermediate_size=2048,
    max_position_embeddings=128,
    type_vocab_size=2,
    dropout=0.1,
)

backbone = BERT(**model_cfg).to(device)

# 2) Load the checkpoint (tensors are mapped straight onto `device`)
ckpt_path = "bert_pretrain_epoch10.pt"  # <-- replace if needed
state = torch.load(ckpt_path, map_location=device)

# If you’re not 100% sure the keys match, use strict=False and print diffs.
# NOTE(review): assumes `state` is a plain state_dict. With strict=False, a checkpoint
# saved inside a wrapper dict (e.g. {"model": ..., "optimizer": ...}) would load nothing
# and raise no error. Check that "Missing keys" below is (nearly) empty.
ik = backbone.load_state_dict(state, strict=False)
print("Missing keys:", ik.missing_keys)
print("Unexpected keys:", ik.unexpected_keys)


# 3) Wrap the query/value projections of every attention block with LoRA adapters
#    (after the pretrained weights are loaded).
add_lora(backbone, r=8, alpha=16, lora_dropout=0.1, targets=("query","value"))

Assume that we are fine-tuning for classification, so we add a classification head.

In [None]:
class ClassificationWithBERT(nn.Module):
    """BERT backbone + dropout + linear classifier applied to the first ([CLS]) position.

    Args:
        backbone: BERT module exposing `.embedding(input_ids, token_type_ids)` and
            `.encoder(x, attention_mask)`; the encoder returns a 2-tuple whose first
            item is the hidden states.
        num_labels: number of output classes.
        dropout: dropout probability applied to the [CLS] vector before the head.
    """

    def __init__(self, backbone, num_labels=2, dropout=0.1):
        super().__init__()
        self.bert = backbone
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self._hidden_size_of(backbone), num_labels)

    @staticmethod
    def _hidden_size_of(backbone):
        """Use `backbone.hidden_size` when present, else the token-embedding width."""
        explicit = getattr(backbone, "hidden_size", None)
        if explicit is not None:
            return explicit
        # Fallback that works with typical custom BERTs
        return backbone.embedding.token_embeds.embedding_dim

    def forward(self, input_ids, token_type_ids, attention_mask):
        embedded = self.bert.embedding(input_ids, token_type_ids)        # [B, L, H]
        hidden_states, _ = self.bert.encoder(embedded, attention_mask)   # [B, L, H]
        cls_vector = hidden_states[:, 0]                                 # [B, H]
        return self.classifier(self.dropout(cls_vector))                 # [B, num_labels]

In [None]:
# Attach a 2-way classification head to the LoRA-injected backbone; .to(device) also moves the new head.
model = ClassificationWithBERT(backbone, num_labels=2)
model = model.to(device)

In [None]:
# Freeze everything first, then re-enable gradients only for LoRA factors + classifier.
for p in model.parameters():
    p.requires_grad = False

# Unfreeze ONLY the LoRA factors A and B. LORALinear also registers the frozen base
# `weight`/`bias` (it re-binds the wrapped Linear's Parameters), so enabling all of
# m.parameters() would make the base query/value projections trainable again.
# Those updates would also be missing from the adapter-only checkpoint saved later.
for m in model.modules():
    if isinstance(m, LORALinear):
        m.A.requires_grad = True
        m.B.requires_grad = True
for p in model.classifier.parameters():  # the self.classifier = nn.Linear(H, num_labels) head
    p.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,}  ({100*trainable/total:.2f}%)")


Now we train the LoRA adapters and the classification head.

In [None]:
from torch.optim import AdamW

# torch.cuda.amp.autocast / GradScaler are deprecated (they emit a FutureWarning);
# the torch.amp versions take the device type explicitly. Mixed precision is CUDA-only here.
use_amp = device == "cuda"

opt_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(opt_params, lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# train_loader should yield dicts with input_ids, token_type_ids, attention_mask, label
for epoch in range(3):
    model.train()
    running = 0.0
    num_batches = 0  # counted, so the average also works for loaders without len()
    for batch in train_loader:
        optimizer.zero_grad(set_to_none=True)
        input_ids      = batch["input_ids"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels         = batch["label"].to(device)

        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = model(input_ids, token_type_ids, attention_mask)
            loss   = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norm
        nn.utils.clip_grad_norm_(opt_params, 1.0)
        scaler.step(optimizer)
        scaler.update()
        running += loss.item()
        num_batches += 1
    print(f"Epoch {epoch+1} | loss {running/max(num_batches, 1):.4f}")


In [None]:
def lora_and_head_state_dict(model):
    """Return the part of `model.state_dict()` needed to restore the fine-tuned adapter.

    Keeps LoRA factor tensors (keys ending in ".A" or ".B") and every classification-head
    tensor (keys starting with "classifier."), in state_dict order.
    """
    return {
        key: tensor
        for key, tensor in model.state_dict().items()
        if key.endswith((".A", ".B")) or key.startswith("classifier.")
    }

# Persist only the trainable pieces (LoRA A/B + classifier head); the backbone is rebuilt from the pretraining checkpoint.
adapter_state = lora_and_head_state_dict(model)
torch.save(adapter_state, "bert_lora_cls.pt")
print("Saved LoRA adapters + head → bert_lora_cls.pt")


In [None]:
# Rebuild the fine-tuned classifier: pretrained backbone + fresh LoRA wrappers (same
# r/alpha as training), then fill in the saved adapters (A/B) and classifier head.
backbone2 = BERT(**model_cfg).to(device)
_ = backbone2.load_state_dict(torch.load(ckpt_path, map_location=device), strict=False)
add_lora(backbone2, r=8, alpha=16, lora_dropout=0.1, targets=("query","value"))
model2 = ClassificationWithBERT(backbone2, num_labels=2).to(device)

# Load adapters + head. strict=False is needed because the file holds only A/B and
# classifier.*, so every backbone key is legitimately "missing". Unexpected keys, however,
# mean the adapter file doesn't fit this model and would otherwise be dropped silently.
state = torch.load("bert_lora_cls.pt", map_location=device)
load_result = model2.load_state_dict(state, strict=False)
if load_result.unexpected_keys:
    raise RuntimeError(f"Adapter checkpoint keys not found in model: {load_result.unexpected_keys}")
model2.eval()


In [None]:
@torch.no_grad()
def merge_lora_into_base(model):
    """Fold every LORALinear's low-rank update into its frozen base weight, in place.

    For each wrapper: W <- W0 + scaling * (B @ A), then A and B are zeroed, so the
    LoRA branch contributes nothing and forward() equals a plain linear layer with the
    merged weight. The base weight Parameter is shared with the nn.Linear that was
    wrapped, so that module sees the merged weight too.
    """
    merged = 0
    for m in model.modules():
        if isinstance(m, LORALinear):
            # effective W = W0 + scaling * (B @ A)
            update = m.scaling * (m.B @ m.A)              # [out, in]
            m.weight.add_(update.to(m.weight.dtype))      # in-place add to base
            # zero-out adapters so forward == base
            m.A.zero_()
            m.B.zero_()
            merged += 1
    print(f"Merged {merged} LoRA layers into base weights.")

# Example:
# merge_lora_into_base(model)     # do this on the classifier-wrapped model


SQuAD

In [None]:
# Mount Google Drive so files written under /content/drive outlive the Colab runtime.
from google.colab import drive

drive.mount("/content/drive")


Mounted at /content/drive


In [None]:
import os

# Checkpoints go to Drive; create the folder on first run (no-op if it already exists).
CKPT_DIR = "/content/drive/MyDrive/lora_checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)


In [None]:
# Install the HF stack used below (IPython shell magic: only works inside Jupyter/Colab).
!pip install -q "transformers>=4.43" "datasets>=2.20" accelerate evaluate

# NOTE(review): everything is re-imported here, presumably so the SQuAD part can run
# without the cells above.
import os, math, re, string, collections
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, DistilBertModel
from datasets import load_dataset

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", device)


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hdevice: cuda


In [None]:
# SQuAD v1.1: train on a reproducible 30% sample, evaluate on the full validation split.
raw = load_dataset("squad")  # v1.1
n_train_subset = int(0.3 * len(raw["train"]))
train_small = raw["train"].shuffle(seed=42).select(range(n_train_subset))
dev = raw["validation"]

print("Train(30%):", len(train_small), "| Dev:", len(dev))

# Windows of max_len tokens; consecutive windows of a long context overlap by doc_stride tokens.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)
max_len, doc_stride = 384, 128


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Train(30%): 26279 | Dev: 10570


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [None]:
def preprocess_train(examples):
    """Tokenize SQuAD examples into overlapping windows and label the answer token span.

    Long contexts are split into several features (max_len tokens, doc_stride overlap).
    For each feature, start/end_positions are the token indices of the answer. If the
    answer is not *entirely* inside that window, or the example has no answer, both
    point at the [CLS] token ("no answer in this window").

    Uses the module-level `tokenizer`, `max_len` and `doc_stride`.
    """
    # Tokenize with sliding window
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_len,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map   = tokenized.pop("overflow_to_sample_mapping")  # feature -> example index
    offset_maps  = tokenized["offset_mapping"]

    start_positions = []
    end_positions   = []

    for i, offsets in enumerate(offset_maps):
        # CLS index for "no answer in this window"
        input_ids = tokenized["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        answers = examples["answers"][sample_map[i]]
        if not answers["answer_start"]:
            # no gold answer (e.g. SQuAD v2 unanswerable) -> CLS target
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue
        ans_text  = answers["text"][0]
        ans_start = answers["answer_start"][0]
        ans_end   = ans_start + len(ans_text)

        seq_ids = tokenized.sequence_ids(i)
        # context token range (sequence id 1 = the second segment, i.e. the context)
        ctx_start = next(k for k, s in enumerate(seq_ids) if s == 1)
        ctx_end   = max(k for k, s in enumerate(seq_ids) if s == 1)

        # Label CLS unless the WHOLE answer lies inside this window. A partially covered
        # answer must map to CLS as well: the scans below assume both answer ends fall in
        # [ctx_start, ctx_end], and would otherwise land on [SEP]/question/padding tokens.
        if not (offsets[ctx_start][0] <= ans_start and offsets[ctx_end][1] >= ans_end):
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # token start: last context token whose start offset is <= ans_start
        token_start = ctx_start
        while token_start <= ctx_end and offsets[token_start][0] <= ans_start:
            token_start += 1
        token_start -= 1

        # token end: first context token (scanning back) whose end offset is >= ans_end
        token_end = ctx_end
        while token_end >= ctx_start and offsets[token_end][1] >= ans_end:
            token_end -= 1
        token_end += 1

        start_positions.append(token_start)
        end_positions.append(token_end)

    tokenized["start_positions"] = start_positions
    tokenized["end_positions"]   = end_positions
    # we don't need offsets for training
    tokenized.pop("offset_mapping")
    return tokenized

# One example can yield several window features, so the raw columns are dropped instead of carried along.
train_proc = train_small.map(
    preprocess_train, batched=True, remove_columns=train_small.column_names
)


Map:   0%|          | 0/26279 [00:00<?, ? examples/s]

In [None]:
def preprocess_val(examples):
    """Tokenize dev examples into sliding-window features for prediction.

    Uses the same windowing as training (max_len tokens, doc_stride overlap). Each feature keeps:
      * example_id: the id of the source example, so windows can be regrouped later;
      * offset_mapping: character offsets for context tokens and (None, None) for every
        other token (question, special tokens, padding), so those are ignored at
        prediction time.
    Uses the module-level `tokenizer`, `max_len` and `doc_stride`.
    """
    encoded = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=max_len,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    feature_to_example = encoded.pop("overflow_to_sample_mapping")

    feature_ids, masked_offsets = [], []
    for feat_idx, feat_offsets in enumerate(encoded["offset_mapping"]):
        segment = encoded.sequence_ids(feat_idx)
        masked_offsets.append(
            [span if segment[tok] == 1 else (None, None) for tok, span in enumerate(feat_offsets)]
        )
        feature_ids.append(examples["id"][feature_to_example[feat_idx]])

    encoded["example_id"] = feature_ids
    encoded["offset_mapping"] = masked_offsets
    return encoded

# Features keep example_id + masked offsets (added by preprocess_val); the raw dev columns are dropped.
dev_proc = dev.map(
    preprocess_val, batched=True, remove_columns=dev.column_names
)


Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [None]:
class LORALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank (LoRA) update.

    forward(x) = x @ W0^T + b + (alpha / r) * (dropout(x) @ A^T) @ B^T

    The base weight/bias are shared with `base_linear` (referenced, not copied) and
    frozen. A is Kaiming-uniform initialised and B starts at zero, so the wrapped
    layer initially produces exactly the same output as `base_linear`.

    Args:
        base_linear: the nn.Linear to adapt.
        r: rank of the update (A is [r, in_features], B is [out_features, r]).
        alpha: scaling numerator; the LoRA branch is multiplied by alpha / r.
        dropout: dropout probability on the input of the LoRA branch only
            (0 or None disables it).
    """

    def __init__(self, base_linear: nn.Linear, r=8, alpha=16, dropout=0.1):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)
        self.in_features  = base_linear.in_features
        self.out_features = base_linear.out_features
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        # Re-bind the base Parameters (shared, not copied) and freeze them. They are
        # registered on this module too, under the same "weight"/"bias" names.
        self.weight = base_linear.weight
        self.bias   = base_linear.bias
        for p in (self.weight, self.bias):
            if p is not None:
                p.requires_grad = False

        # Match the base weight's device AND dtype: the torch defaults (CPU, float32)
        # would mismatch a GPU or half/double-precision backbone in F.linear.
        factory = {"device": self.weight.device, "dtype": self.weight.dtype}
        self.A = nn.Parameter(torch.empty(r, self.in_features, **factory))
        self.B = nn.Parameter(torch.zeros(self.out_features, r, **factory))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # B = 0 -> update starts at 0
        self.dropout = nn.Dropout(dropout) if dropout and dropout > 0 else nn.Identity()

    def forward(self, x):
        base  = F.linear(x, self.weight, self.bias)   # x @ W0^T + b      -> [..., out]
        delta = F.linear(self.dropout(x), self.A)     # x @ A^T           -> [..., r]
        delta = F.linear(delta, self.B)               # (x @ A^T) @ B^T   -> [..., out]
        return base + self.scaling * delta

# DistilBERT attention class (robust across versions): instead of importing a class name
# that differs between transformers releases, read the type off a loaded model's first
# attention layer.
# NOTE(review): this loads a full pretrained DistilBertModel (weights are downloaded on
# first run) just to obtain a type; the model object is then discarded. The class may
# also depend on which attention implementation transformers selects. Confirm it matches
# the model passed to add_lora_to_distilbert, otherwise 0 layers get wrapped.
DistilBertSelfAttention = type(
    DistilBertModel.from_pretrained("distilbert-base-uncased").transformer.layer[0].attention
)

def add_lora_to_distilbert(model, r=8, alpha=16, lora_dropout=0.1, targets=("q_lin","v_lin")):
    """Replace the named nn.Linear projections in each DistilBERT attention block with LORALinear.

    Args:
        model: any module containing DistilBertSelfAttention blocks; modified in place.
        r, alpha, lora_dropout: forwarded to LORALinear (rank, scale numerator, branch dropout).
        targets: attribute names to wrap in each attention block. Names that are missing
            or not an nn.Linear (e.g. already wrapped) are skipped.
    """
    attention_blocks = [mod for mod in model.modules() if isinstance(mod, DistilBertSelfAttention)]
    count = 0
    for block in attention_blocks:
        for attr in targets:
            linear = getattr(block, attr, None)
            if not isinstance(linear, nn.Linear):
                continue
            setattr(block, attr, LORALinear(linear, r=r, alpha=alpha, dropout=lora_dropout))
            count += 1
    print(f"LoRA injected into {count} Linear(s): {targets}")


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [None]:
# Pretrained DistilBERT with a freshly initialised span-prediction head (qa_outputs).
qa = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased").to(device)
add_lora_to_distilbert(qa, r=8, alpha=16, lora_dropout=0.1, targets=("q_lin","v_lin"))

# Freeze everything, then re-enable gradients only for the LoRA factors and the QA head.
for p in qa.parameters():
    p.requires_grad = False
# Only A and B: LORALinear also registers the frozen base weight/bias as its own
# parameters, so enabling all of m.parameters() would make q_lin/v_lin fully trainable
# again (~7.1M extra params). The adapter-only checkpoint (A/B + qa_outputs) would
# then silently lose those updates.
for m in qa.modules():
    if isinstance(m, LORALinear):
        m.A.requires_grad = True
        m.B.requires_grad = True
for p in qa.qa_outputs.parameters():
    p.requires_grad = True

trainable = sum(p.numel() for p in qa.parameters() if p.requires_grad)
total     = sum(p.numel() for p in qa.parameters())
print(f"Trainable {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")


Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LoRA injected into 12 Linear(s): ('q_lin', 'v_lin')
Trainable 7,236,098 / 66,511,874 (10.88%)


In [None]:
def collate(features):
    """Stack a list of tokenized features into a batch of tensors.

    start_positions / end_positions are included as tensors when the first feature has
    them (training features), and are None otherwise.
    """
    def stack(key):
        return torch.tensor([feat[key] for feat in features])

    batch = {"input_ids": stack("input_ids"), "attention_mask": stack("attention_mask")}
    first = features[0]
    for label_key in ("start_positions", "end_positions"):
        batch[label_key] = stack(label_key) if label_key in first else None
    return batch

def _collate_dev(features):
    """Batch dev features: tensors for model inputs, plain lists for offsets and example ids."""
    return {
        "input_ids": torch.tensor([f["input_ids"] for f in features]),
        "attention_mask": torch.tensor([f["attention_mask"] for f in features]),
        "offset_mapping": [f["offset_mapping"] for f in features],
        "example_id": [f["example_id"] for f in features],
    }


train_loader = DataLoader(train_proc, batch_size=12, shuffle=True, collate_fn=collate)
dev_loader = DataLoader(dev_proc, batch_size=32, shuffle=False, collate_fn=_collate_dev)


In [None]:
from torch.optim import AdamW

# torch.cuda.amp.autocast / GradScaler are deprecated (they emit a FutureWarning);
# the torch.amp versions take the device type explicitly. Mixed precision is CUDA-only here.
use_amp = device == "cuda"

opt_params = [p for p in qa.parameters() if p.requires_grad]
optimizer = AdamW(opt_params, lr=5e-4, weight_decay=0.01)
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

for epoch in range(10):
    qa.train()
    running = 0.0
    num_batches = 0  # counted, so the average also works for loaders without len()
    for batch in train_loader:
        batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            # with start/end positions supplied, the model returns the span loss in out.loss
            out = qa(input_ids=batch["input_ids"],
                     attention_mask=batch["attention_mask"],
                     start_positions=batch["start_positions"],
                     end_positions=batch["end_positions"])
            loss = out.loss
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norm
        torch.nn.utils.clip_grad_norm_(opt_params, 1.0)
        scaler.step(optimizer)
        scaler.update()
        running += loss.item()
        num_batches += 1
    print(f"Epoch {epoch+1} | train loss {running/max(num_batches, 1):.4f}")


  scaler = GradScaler(enabled=(device=="cuda"))
  with autocast(enabled=(device=="cuda")):


Epoch 1 | train loss 2.1148
Epoch 2 | train loss 1.5385
Epoch 3 | train loss 1.3139
Epoch 4 | train loss 1.1686
Epoch 5 | train loss 1.0623
Epoch 6 | train loss 0.9867
Epoch 7 | train loss 0.9308
Epoch 8 | train loss 0.8795
Epoch 9 | train loss 0.8336
Epoch 10 | train loss 0.8148


In [None]:
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b")
_PUNCTUATION = frozenset(string.punctuation)  # built once, not once per character


def normalize_text(s):
    """SQuAD answer normalisation: lowercase, drop punctuation, drop articles, squeeze whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in _PUNCTUATION)
    s = _ARTICLES_RE.sub(" ", s)
    return " ".join(s.split())


def f1_score(prediction, ground_truth):
    """Token-overlap F1 between normalised prediction and gold answer (1.0 if both are empty)."""
    pred_tokens  = normalize_text(prediction).split()
    truth_tokens = normalize_text(ground_truth).split()
    if not pred_tokens or not truth_tokens:
        return float(pred_tokens == truth_tokens)
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    prec = num_same / len(pred_tokens)
    rec  = num_same / len(truth_tokens)
    return 2 * prec * rec / (prec + rec)


def exact_match_score(prediction, ground_truth):
    """1.0 if prediction and gold answer are equal after normalisation, else 0.0."""
    return float(normalize_text(prediction) == normalize_text(ground_truth))


def _best_span(start_row, end_row, offsets):
    """Greedy best span over context tokens: (score, char_start, char_end), or None if no context."""
    valid = [j for j, (a, b) in enumerate(offsets) if a is not None and b is not None]
    if not valid:
        return None
    s = start_row[valid]
    e = end_row[valid]
    s_idx = int(torch.argmax(s))
    # best end at or after the chosen start (so e_idx >= s_idx always holds)
    e_idx = int(torch.argmax(e[s_idx:])) + s_idx
    return float(s[s_idx] + e[e_idx]), offsets[valid[s_idx]][0], offsets[valid[e_idx]][1]


@torch.no_grad()
def evaluate(model, tokenizer, dev_ds, dev_proc, batch_size=32):
    """Compute SQuAD Exact-Match and F1 (both in percent) on a dev set.

    Args:
        model: QA model; model(input_ids=..., attention_mask=...) must return an object
            with `start_logits` / `end_logits` indexed [feature, token].
        tokenizer: unused; kept for call-site compatibility.
        dev_ds: raw examples with "id", "context" and "answers" {"text": [...]}.
        dev_proc: tokenized features (as produced by preprocess_val) with "input_ids",
            "attention_mask", "offset_mapping" ((None, None) for non-context tokens)
            and "example_id". This argument is what gets evaluated.
        batch_size: number of features per forward pass.

    Returns:
        (em, f1) averaged over all examples in dev_ds. Each example scores the max over
        all its gold answers, as in the official SQuAD script; examples without a
        prediction score 0.
    """
    model.eval()
    model_device = next(model.parameters()).device

    # Build id -> context and id -> all gold answers once (fast lookup)
    id2ctx, id2golds = {}, {}
    for ex in dev_ds:
        ex_id = str(ex["id"])
        id2ctx[ex_id] = ex["context"]
        id2golds[ex_id] = list(ex["answers"]["text"])

    def _collate(features):
        return {
            "input_ids": torch.tensor([f["input_ids"] for f in features]),
            "attention_mask": torch.tensor([f["attention_mask"] for f in features]),
            "offset_mapping": [f["offset_mapping"] for f in features],
            "example_id": [f["example_id"] for f in features],
        }

    loader = DataLoader(dev_proc, batch_size=batch_size, shuffle=False, collate_fn=_collate)

    # A long context yields several features; keep the highest-scoring span per example
    # (first feature wins ties).
    pred_text_by_id, score_by_id = {}, {}
    for batch in loader:
        out = model(input_ids=batch["input_ids"].to(model_device),
                    attention_mask=batch["attention_mask"].to(model_device))
        starts = out.start_logits.cpu()
        ends = out.end_logits.cpu()
        for s_row, e_row, offsets, ex_id in zip(starts, ends, batch["offset_mapping"], batch["example_id"]):
            best = _best_span(s_row, e_row, offsets)
            if best is None:
                continue
            span_score, char_start, char_end = best
            ex_id = str(ex_id)
            if ex_id not in score_by_id or span_score > score_by_id[ex_id]:
                score_by_id[ex_id] = span_score
                pred_text_by_id[ex_id] = id2ctx[ex_id][char_start:char_end]

    n = len(dev_ds)
    if n == 0:
        return 0.0, 0.0
    em = f1 = 0.0
    for ex_id, golds in id2golds.items():
        pred = pred_text_by_id.get(ex_id, "")
        golds = golds or [""]  # no gold answer: only an empty prediction matches
        em += max(exact_match_score(pred, g) for g in golds)
        f1 += max(f1_score(pred, g) for g in golds)
    return 100 * em / n, 100 * f1 / n


In [None]:
# Score the fine-tuned QA model on the full SQuAD v1.1 dev split.
em, f1 = evaluate(qa, tokenizer, dev_ds=dev, dev_proc=dev_proc)
print(f"Dev EM: {em:.2f} | F1: {f1:.2f}")


Dev EM: 50.17 | F1: 65.01


In [None]:
def lora_and_head_state_dict(model):
    """Select the tensors needed to restore the fine-tuned QA adapter.

    Returns the `model.state_dict()` entries for LoRA factors (keys ending in ".A" or
    ".B") and the span head (keys starting with "qa_outputs."), in state_dict order.
    """
    lora_suffixes = (".A", ".B")
    head_prefix = "qa_outputs."
    selected = {}
    for name, tensor in model.state_dict().items():
        is_lora = name.endswith(lora_suffixes)
        is_head = name.startswith(head_prefix)
        if is_lora or is_head:
            selected[name] = tensor
    return selected

# Persist only LoRA A/B + the qa_outputs head to Drive; the backbone comes from the hub checkpoint.
save_path = os.path.join(CKPT_DIR, "distilbert_squad_lora.pt")
adapter_state = lora_and_head_state_dict(qa)
torch.save(adapter_state, save_path)
print("Saved →", save_path)


Saved → /content/drive/MyDrive/lora_checkpoints/distilbert_squad_lora.pt
