In [1]:
!pip -q install -U transformers datasets peft accelerate evaluate


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/512.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.3/512.3 kB[0m [31m37.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [1]:
import os, gc, math, random
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
from collections import defaultdict


SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)


torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "tf32"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16

print("DEVICE:", DEVICE, "| dtype:", DTYPE)
!nvidia-smi -L


DEVICE: cuda | dtype: torch.bfloat16
GPU 0: NVIDIA A100-SXM4-80GB (UUID: GPU-d5a1c6c8-adf0-688b-c297-52c168e36ee6)


In [2]:

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
BASE_MODEL_7B = "mistralai/Mistral-7B-Instruct-v0.2"


CUTOFF_LEN = 512
LR = 2e-4
WEIGHT_DECAY = 0.0
DROPOUT = 0.05


MICRO_BS = 1
GRAD_ACCUM = 16        # MICRO_BS * GRAD_ACCUM = 16 (paper effective batch)

CLIP_NORM = 1.0
WARMUP_RATIO = 0.03


EPOCHS = 3             # paper: 2
MAX_TRAIN_SAMPLES = 1119
MAX_EVAL_SAMPLES = 299


LORA_R_BASE, LORA_A_BASE = 8, 16
LORA_R_MIX,  LORA_A_MIX  = 8, 16 


NUM_EXPERTS = 8
TOP_K = 2
AUX_LOSS_COEF = 1e-3


USE_GRAD_CHECKPOINTING = True

print("Effective batch:", MICRO_BS * GRAD_ACCUM)


Effective batch: 16


Loading a pre-trained model and tokenizer

In [3]:
def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def load_model(model_name: str):
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=DTYPE,
        device_map="cuda",
    )

    model.config.use_cache = False
    if USE_GRAD_CHECKPOINTING:
        model.gradient_checkpointing_enable()

    # Freeze all base weights
    for p in model.parameters():
        p.requires_grad = False

    return tok, model

def print_trainable(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable/1e6:.2f}M / Total: {total/1e6:.2f}M ({100*trainable/total:.2f}%)")


Loading dataset and formating the prompt format

In [4]:
arc = load_dataset("allenai/ai2_arc", "ARC-Challenge")
train_raw = arc["train"]
val_raw = arc["validation"]
print("ARC-C train:", len(train_raw), "| val:", len(val_raw))

def format_dataset(ex) -> Tuple[str, str]:
    q = ex["question"].strip()
    labels = ex["choices"]["label"]
    texts  = ex["choices"]["text"]
    gold   = ex["answerKey"].strip()

    options = "\n".join([f"{lab}. {txt.strip()}" for lab, txt in zip(labels, texts)])


    prompt_body = (
        "You are given a multiple-choice science question.\n"
        "Choose the correct option.\n\n"
        f"Question: {q}\n\n"
        f"{options}\n"
    )
    return prompt_body, gold



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

ARC-Challenge/train-00000-of-00001.parqu(…):   0%|          | 0.00/190k [00:00<?, ?B/s]

ARC-Challenge/test-00000-of-00001.parque(…):   0%|          | 0.00/204k [00:00<?, ?B/s]

ARC-Challenge/validation-00000-of-00001.(…):   0%|          | 0.00/55.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1119 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1172 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/299 [00:00<?, ? examples/s]

ARC-C train: 1119 | val: 299


In [14]:
def tokenizing(tok, prompt_body: str, answer_letter: str):
    '''
    tokenizing the formatted question and its correct answer,
    creating input IDs, attention masks, and labels suitable
    for causal language modeling
    '''

    target = f"Final Answer: {answer_letter}"
    full = prompt_body + target

    enc_full = tok(full, truncation=True, max_length=CUTOFF_LEN)
    enc_prompt = tok(prompt_body, truncation=True, max_length=CUTOFF_LEN)

    input_ids = enc_full["input_ids"]
    attn_mask = enc_full["attention_mask"]
    labels = input_ids.copy()

    prompt_len = len(enc_prompt["input_ids"])
    labels[:prompt_len] = [-100] * prompt_len

    return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}




def padding(tok, batch):
    '''
    padding ensuring all inputs have uniform dimensions for batch processing
    '''

    max_len = max(len(x["input_ids"]) for x in batch)

    def pad(seq, pad_id):
        return seq + [pad_id] * (max_len - len(seq))

    input_ids = torch.tensor([pad(x["input_ids"], tok.pad_token_id) for x in batch], dtype=torch.long)
    attention_mask = torch.tensor([pad(x["attention_mask"], 0) for x in batch], dtype=torch.long)
    labels = torch.tensor([pad(x["labels"], -100) for x in batch], dtype=torch.long)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}






def make_train_loader(tok, raw_split, max_samples: int):
    '''
    It balances the distribution of answer choices (A, B, C, D, E)
    to prevent bias and creates a batch of encoded examples.
    '''

    buckets = defaultdict(list)
    for i in range(len(raw_split)):
        _, gold = format_dataset(raw_split[i])
        buckets[gold].append(i)

    letters = ["A","B","C","D","E"]
    present = [L for L in letters if len(buckets[L]) > 0]


    per_class = max(1, max_samples // len(present))

    indices = []
    for L in present:
        random.shuffle(buckets[L])
        indices.extend(buckets[L][:per_class])

    random.shuffle(indices)
    indices = indices[:max_samples]

    examples = []
    for idx in indices:
        prompt, gold = format_dataset(raw_split[idx])
        examples.append(tokenizing(tok, prompt, gold))

    return DataLoader(
        examples,
        batch_size=MICRO_BS,
        shuffle=True,
        collate_fn=lambda b: padding(tok, b),
    )



Evaluating the model

In [6]:
@torch.no_grad()
def scoring(tok, model, prompt: str, continuation: str) -> float:
    '''
    calculating the log-likelihood of a given continuation following a prompt. 
    It's used to determine which answer option the model finds most probable.
    '''
    model.eval()

    enc_prompt = tok(prompt, return_tensors="pt", truncation=True, max_length=CUTOFF_LEN).to(model.device)
    enc_full   = tok(prompt + continuation, return_tensors="pt", truncation=True, max_length=CUTOFF_LEN).to(model.device)

    input_ids = enc_full["input_ids"]
    prompt_len = enc_prompt["input_ids"].shape[1]

    out = model(input_ids=input_ids)
    logits = out.logits  # [1, L, V]

    ll = 0.0
    for pos in range(prompt_len, input_ids.shape[1]):
        token_id = int(input_ids[0, pos].item())
        ll += float(torch.log_softmax(logits[0, pos - 1], dim=-1)[token_id].item())

    return ll


@torch.no_grad()
def evaluating(tok, model, raw_split, max_eval: int = None) -> float:
    '''
    evaluating the model's accuracy on the dataset.
    '''
    model.eval()

    if max_eval is None:
        max_eval = MAX_EVAL_SAMPLES

    letters = ["A", "B", "C", "D", "E"]
    total = min(len(raw_split), max_eval)
    correct = 0

    prompt_suffix = "Final Answer:"

    for i in range(total):
        prompt_body, gold = format_dataset(raw_split[i])
        prompt = prompt_body + prompt_suffix

        scores = {L: scoring(tok, model, prompt, " " + L) for L in letters}
        pred = max(scores, key=scores.get)

        correct += int(pred == gold)

    return correct / total


Training

In [15]:
def mix_aux_loss(model) -> torch.Tensor:
    '''
    collecting auxiliary loss components from Mixture of Experts (MoE)
    layers within the model. This auxiliary loss encourages balanced
    expert usage during training.
    '''

    aux_list = []
    for m in model.modules():
        if hasattr(m, "_aux") and m._aux is not None:
            aux_list.append(m._aux)
    if not aux_list:
        return torch.tensor(0.0, device=model.device)
    return torch.stack(aux_list).mean()



def print_peak_vram(prefix=""):
    peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)

    peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
    print(f"{prefix} peak_alloc={peak_alloc:.2f} GB | peak_reserved={peak_reserved:.2f} GB")



def train_model(tok, model, method_name: str):
    train_loader = make_train_loader(tok, train_raw, MAX_TRAIN_SAMPLES)  # <-- FIX

    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=LR, weight_decay=WEIGHT_DECAY)

    best_acc = -1.0
    model.train()

    torch.cuda.reset_peak_memory_stats()  # track peak VRAM for this run

    for epoch in range(EPOCHS):
        opt.zero_grad(set_to_none=True)

        for i, batch in enumerate(train_loader):
            batch = {k: v.to(model.device) for k, v in batch.items()}
            out = model(**batch)
            loss = out.loss

            (loss / GRAD_ACCUM).backward()

            if (i + 1) % GRAD_ACCUM == 0:
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                opt.step()
                opt.zero_grad(set_to_none=True)

        acc = evaluating(tok, model, val_raw)
        print(f"[{method_name}] epoch {epoch+1}/{EPOCHS} | val acc={acc:.3f}")

        # Peak VRAM used so far (allocated + reserved)
        peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
        peak_resv  = torch.cuda.max_memory_reserved() / (1024**3)
        print(f"[{method_name}] epoch {epoch+1} | peak_alloc={peak_alloc:.2f} GB | peak_reserved={peak_resv:.2f} GB")

        best_acc = max(best_acc, acc)

    print(f"[{method_name}] BEST val acc={best_acc:.3f}")
    return best_acc




In [8]:
def get_base_model_layers(model) -> List[nn.Module]:
    '''
    extract the transformer layers from the base model.
    '''

    base = model.get_base_model() if hasattr(model, "get_base_model") else model
    return base.model.layers

class LoRADelta(nn.Module):
    """
    low-rank adaptation matrix
    """
    def __init__(self, in_f, out_f, r, alpha, dropout):
        super().__init__()
        self.scale = alpha / r
        self.drop = nn.Dropout(dropout)
        self.A = nn.Linear(in_f, r, bias=False)
        self.B = nn.Linear(r, out_f, bias=False)
        nn.init.kaiming_uniform_(self.A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.B.weight)

    def forward(self, x):
        return self.B(self.A(self.drop(x))) * self.scale

class DoRAProj(nn.Module):
    """
    Simple DoRA-like projection for an expert:
      y_dir = W0(x) + Δ(x)
      y = y_dir * (m / W0)  (per output channel)
    """
    def __init__(self, base_linear: nn.Linear, r, alpha, dropout):
        super().__init__()
        self.base = base_linear  # frozen
        self.delta = LoRADelta(base_linear.in_features, base_linear.out_features, r, alpha, dropout)

        with torch.no_grad():
            w = self.base.weight.data
            row_norm = torch.norm(w, dim=1) + 1e-8
        self.m = nn.Parameter(row_norm)

    def forward(self, x):
        y_dir = self.base(x) + self.delta(x)
        with torch.no_grad():
            row_norm = torch.norm(self.base.weight.data, dim=1) + 1e-8
        scale = (self.m / row_norm).to(y_dir.dtype)
        return y_dir * scale

class MixFFN(nn.Module):
    """
    Wrap a LLaMA-style MLP (gate_proj, up_proj, down_proj), distributing tokens to multiple experts with:
      - shared frozen FFN (base MLP)
      - router top-2 (token-wise)
      - 8 experts (LoRA or DoRA style)
      - aux loss for load balancing (stored as self._aux)
    """
    def __init__(self, base_mlp: nn.Module, hidden_size: int, intermediate_size: int, expert_mode: str):
        super().__init__()
        self.base = base_mlp  # frozen
        self.E = NUM_EXPERTS
        self.K = TOP_K

        # Router
        self.router = nn.Linear(hidden_size, self.E, bias=False).to(torch.float32)

        # Experts
        self.experts = nn.ModuleList()
        for _ in range(self.E):
            if expert_mode == "lora":
                self.experts.append(nn.ModuleDict({
                    "gate_d": LoRADelta(hidden_size, intermediate_size, LORA_R_MIX, LORA_A_MIX, DROPOUT),
                    "up_d":   LoRADelta(hidden_size, intermediate_size, LORA_R_MIX, LORA_A_MIX, DROPOUT),
                    "down_d": LoRADelta(intermediate_size, hidden_size, LORA_R_MIX, LORA_A_MIX, DROPOUT),
                }))
            elif expert_mode == "dora":
                self.experts.append(nn.ModuleDict({
                    "gate": DoRAProj(self.base.gate_proj, LORA_R_MIX, LORA_A_MIX, DROPOUT),
                    "up":   DoRAProj(self.base.up_proj,   LORA_R_MIX, LORA_A_MIX, DROPOUT),
                    "down": DoRAProj(self.base.down_proj, LORA_R_MIX, LORA_A_MIX, DROPOUT),
                }))
            else:
                raise ValueError("expert_mode must be 'lora' or 'dora'")

        self._aux = None  # filled in forward

    def forward(self, x):
        b, s, h = x.shape
        T = b * s
        x_flat = x.reshape(T, h)

        # Router probs
        logits = self.router(x_flat.to(torch.float32))
        probs = torch.softmax(logits, dim=-1)  # (T, E)

        # Top-k routing
        topv, topi = torch.topk(probs, k=self.K, dim=-1)  # (T, K)
        topv = topv / (topv.sum(dim=-1, keepdim=True) + 1e-9)

        # Aux loss: encourage balanced expert usage
        with torch.no_grad():
            assign = torch.zeros((T, self.E), device=x.device)
            assign.scatter_add_(1, topi, torch.ones_like(topv, device=x.device))
            f = assign.sum(dim=0) / (T * self.K + 1e-9)         # fraction of assignments
            p = probs.to(x.device).mean(dim=0)                  # mean prob mass
        self._aux = self.E * torch.sum(f * p)

        # Shared base FFN output (frozen)
        base_out = self.base(x).reshape(T, h)
        out_flat = base_out.clone()

        # Sparse expert compute: only tokens routed to expert e
        for e in range(self.E):
            mask = (topi == e)  # (T, K)
            if not mask.any():
                continue

            token_ids, slot_ids = torch.where(mask)
            w = topv[token_ids, slot_ids].to(base_out.dtype).unsqueeze(-1)  # (N,1)
            x_tok = x_flat[token_ids].to(base_out.dtype)

            if "gate" in self.experts[e]:
                # MixDoRA experts: full expert projection
                gate_e = self.experts[e]["gate"](x_tok)
                up_e   = self.experts[e]["up"](x_tok)
                hid_e  = F.silu(gate_e) * up_e
                y_e    = self.experts[e]["down"](hid_e)

                out_flat[token_ids] += w * (y_e - base_out[token_ids])
            else:
                # MixLoRA experts: use deltas on top of base projections
                gate0 = self.base.gate_proj(x_tok)
                up0   = self.base.up_proj(x_tok)

                gate = gate0 + self.experts[e]["gate_d"](x_tok)
                up   = up0   + self.experts[e]["up_d"](x_tok)

                hid = F.silu(gate) * up

                y0 = self.base.down_proj(hid)
                y  = y0 + self.experts[e]["down_d"](hid)

                out_flat[token_ids] += w * (y - base_out[token_ids])

        return out_flat.reshape(b, s, h)

def inject_mix_ffn(model, expert_mode: str):
    '''
    replacing the standard FFN modules with the custom MixFFN modules,
     effectively injecting the Mixture of Experts architecture into the mode
    '''

    layers = get_base_model_layers(model)
    replaced = 0

    for layer in layers:
        base_mlp = layer.mlp
        hidden = base_mlp.gate_proj.in_features
        inter  = base_mlp.gate_proj.out_features

        mix = MixFFN(base_mlp, hidden, inter, expert_mode=expert_mode)


        mix = mix.to(device=model.device, dtype=DTYPE)


        mix.router = mix.router.to(dtype=torch.float32)

        layer.mlp = mix
        replaced += 1



In [9]:
def infer_target_modules(model):
    '''
    identifiing the specific linear projection layers
    '''

    want_suffixes = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    found = set()

    for name, mod in model.named_modules():
        if not isinstance(mod, torch.nn.Linear):
            continue
        for s in want_suffixes:
            if name.endswith(s):
                found.add(s)


    return [s for s in want_suffixes if s in found]


In [10]:
def show_trainable_modules(model, keywords=("q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj")):
    '''
    inspect which specific modules within the model have
    trainable parameters after applying PEFT configurations.
    '''
    hits = {k: 0 for k in keywords}

    for name, p in model.named_parameters():
        if p.requires_grad:
            for k in keywords:
                if k in name:
                    hits[k] += 1

    print(hits)
    return hits

Single task fine-tuning with different methods

In [11]:
cleanup()
tok, model = load_model(BASE_MODEL)

TARGETS = infer_target_modules(model)

cfg = LoraConfig(
    r=LORA_R_BASE,
    lora_alpha=LORA_A_BASE,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=TARGETS,
)
model = get_peft_model(model, cfg)
model.to(DEVICE)

show_trainable_modules(model)
print_trainable(model)

lora_acc = train_model(tok, model, "LoRA")
print("LoRA val acc:", lora_acc)


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

{'q_proj': 44, 'k_proj': 44, 'v_proj': 44, 'o_proj': 44, 'gate_proj': 44, 'up_proj': 44, 'down_proj': 44}
Trainable: 6.31M / Total: 1106.36M (0.57%)
[LoRA] epoch 1/3 | val acc=0.294
[LoRA] epoch 2/3 | val acc=0.247
[LoRA] epoch 3/3 | val acc=0.261
[LoRA] BEST val acc=0.294
LoRA val acc: 0.29431438127090304


In [12]:
cleanup()
tok, model = load_model(BASE_MODEL)

TARGETS = infer_target_modules(model)

cfg = LoraConfig(
    r=LORA_R_BASE,
    lora_alpha=LORA_A_BASE,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=TARGETS,
)
setattr(cfg, "use_dora", True)

model = get_peft_model(model, cfg)
model.to(DEVICE)

show_trainable_modules(model)
print_trainable(model)

dora_acc = train_model(tok, model, "DoRA")
print("DoRA val acc:", dora_acc)


{'q_proj': 66, 'k_proj': 66, 'v_proj': 66, 'o_proj': 66, 'gate_proj': 66, 'up_proj': 66, 'down_proj': 66}
Trainable: 6.70M / Total: 1106.75M (0.61%)
[DoRA] epoch 1/3 | val acc=0.234
[DoRA] epoch 2/3 | val acc=0.207
[DoRA] epoch 3/3 | val acc=0.278
[DoRA] BEST val acc=0.278
DoRA val acc: 0.27759197324414714


In [13]:
cleanup()
tok, model = load_model(BASE_MODEL)

# Attention adapters: independent, non-expert LoRA on Q,K,V,O
cfg_attn = LoraConfig(
    r=LORA_R_MIX,
    lora_alpha=LORA_A_MIX,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
)
model = get_peft_model(model, cfg_attn)
model.to(DEVICE)

# FFN becomes MixLoRA experts (8 experts, top-2)
inject_mix_ffn(model, expert_mode="lora")

# Ensure router + expert params trainable (base FFN remains frozen)
for name, p in model.named_parameters():
    if ("router" in name) or ("experts" in name) or ("lora_" in name):
        p.requires_grad = True

print_trainable(model)
mixlora_acc = train_model(tok, model, "MixLoRA")
print("MixLoRA val acc:", mixlora_acc)


Trainable: 35.05M / Total: 1135.10M (3.09%)
[MixLoRA] epoch 1/3 | val acc=0.244
[MixLoRA] epoch 2/3 | val acc=0.274
[MixLoRA] epoch 3/3 | val acc=0.237
[MixLoRA] BEST val acc=0.274
MixLoRA val acc: 0.27424749163879597


In [14]:
cleanup()
tok, model = load_model(BASE_MODEL)

# Attention adapters: DoRA on Q,K,V,O
cfg_attn = LoraConfig(
    r=LORA_R_MIX,
    lora_alpha=LORA_A_MIX,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
)
setattr(cfg_attn, "use_dora", True)
model = get_peft_model(model, cfg_attn)
model.to(DEVICE)

# FFN becomes MixDoRA experts
inject_mix_ffn(model, expert_mode="dora")

# Ensure router + expert params trainable
for name, p in model.named_parameters():
    if ("router" in name) or ("experts" in name) or ("lora_" in name) or name.endswith(".m"):
        p.requires_grad = True

print_trainable(model)
mixdora_acc = train_model(tok, model, "MixDoRA")
print("MixDoRA val acc:", mixdora_acc)


Trainable: 37.50M / Total: 1137.55M (3.30%)
[MixDoRA] epoch 1/3 | val acc=0.251
[MixDoRA] epoch 2/3 | val acc=0.291
[MixDoRA] epoch 3/3 | val acc=0.284
[MixDoRA] BEST val acc=0.291
MixDoRA val acc: 0.2909698996655518


Single task fine-tuning with different methods - Mistral 7b

In [19]:
cleanup()
tok, model = load_model(BASE_MODEL_7B)

TARGETS = infer_target_modules(model)

cfg = LoraConfig(
    r=LORA_R_BASE,
    lora_alpha=LORA_A_BASE,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=TARGETS,
)
model = get_peft_model(model, cfg)
model.to(DEVICE)

show_trainable_modules(model)
print_trainable(model)

lora_acc = train_model(tok, model, "LoRA")
print("LoRA val acc:", lora_acc)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

{'q_proj': 64, 'k_proj': 64, 'v_proj': 64, 'o_proj': 64, 'gate_proj': 64, 'up_proj': 64, 'down_proj': 64}
Trainable: 20.97M / Total: 7262.70M (0.29%)
[LoRA] epoch 1/3 | val acc=0.672
[LoRA] epoch 1 | peak_alloc=41.37 GB | peak_reserved=55.15 GB
[LoRA] epoch 2/3 | val acc=0.709
[LoRA] epoch 2 | peak_alloc=43.19 GB | peak_reserved=55.15 GB
[LoRA] epoch 3/3 | val acc=0.726
[LoRA] epoch 3 | peak_alloc=43.19 GB | peak_reserved=55.15 GB
[LoRA] BEST val acc=0.726
LoRA val acc: 0.725752508361204


In [16]:
cleanup()
tok, model = load_model(BASE_MODEL_7B)

TARGETS = infer_target_modules(model)

cfg = LoraConfig(
    r=LORA_R_BASE,
    lora_alpha=LORA_A_BASE,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=TARGETS,
)
setattr(cfg, "use_dora", True)

model = get_peft_model(model, cfg)
model.to(DEVICE)

show_trainable_modules(model)
print_trainable(model)

dora_acc = train_model(tok, model, "DoRA")
print("DoRA val acc:", dora_acc)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

{'q_proj': 96, 'k_proj': 96, 'v_proj': 96, 'o_proj': 96, 'gate_proj': 96, 'up_proj': 96, 'down_proj': 96}
Trainable: 22.35M / Total: 7264.08M (0.31%)
[DoRA] epoch 1/3 | val acc=0.753
[DoRA] epoch 1 | peak_alloc=43.44 GB | peak_reserved=43.70 GB
[DoRA] epoch 2/3 | val acc=0.756
[DoRA] epoch 2 | peak_alloc=46.21 GB | peak_reserved=46.73 GB
[DoRA] epoch 3/3 | val acc=0.719
[DoRA] epoch 3 | peak_alloc=46.21 GB | peak_reserved=46.73 GB
[DoRA] BEST val acc=0.756
DoRA val acc: 0.7558528428093646


In [17]:
cleanup()
tok, model = load_model(BASE_MODEL_7B)

# Attention adapters: independent, non-expert LoRA on Q,K,V,O
cfg_attn = LoraConfig(
    r=LORA_R_MIX,
    lora_alpha=LORA_A_MIX,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
)
model = get_peft_model(model, cfg_attn)
model.to(DEVICE)

# FFN becomes MixLoRA experts (8 experts, top-2)
inject_mix_ffn(model, expert_mode="lora")

# Ensure router + expert params trainable (base FFN remains frozen)
for name, p in model.named_parameters():
    if ("router" in name) or ("experts" in name) or ("lora_" in name):
        p.requires_grad = True

print_trainable(model)
mixlora_acc = train_model(tok, model, "MixLoRA")
print("MixLoRA val acc:", mixlora_acc)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Trainable: 121.11M / Total: 7362.84M (1.64%)
[MixLoRA] epoch 1/3 | val acc=0.645
[MixLoRA] epoch 1 | peak_alloc=41.86 GB | peak_reserved=55.41 GB
[MixLoRA] epoch 2/3 | val acc=0.739
[MixLoRA] epoch 2 | peak_alloc=44.81 GB | peak_reserved=55.77 GB
[MixLoRA] epoch 3/3 | val acc=0.712
[MixLoRA] epoch 3 | peak_alloc=44.81 GB | peak_reserved=55.77 GB
[MixLoRA] BEST val acc=0.739
MixLoRA val acc: 0.7391304347826086


In [18]:
cleanup()
tok, model = load_model(BASE_MODEL_7B)

# Attention adapters: DoRA on Q,K,V,O
cfg_attn = LoraConfig(
    r=LORA_R_MIX,
    lora_alpha=LORA_A_MIX,
    lora_dropout=DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj"],
)
setattr(cfg_attn, "use_dora", True)
model = get_peft_model(model, cfg_attn)
model.to(DEVICE)

# FFN becomes MixDoRA experts
inject_mix_ffn(model, expert_mode="dora")

# Ensure router + expert params trainable
for name, p in model.named_parameters():
    if ("router" in name) or ("experts" in name) or ("lora_" in name) or name.endswith(".m"):
        p.requires_grad = True

print_trainable(model)
mixdora_acc = train_model(tok, model, "MixDoRA")
print("MixDoRA val acc:", mixdora_acc)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Trainable: 129.83M / Total: 7371.56M (1.76%)
[MixDoRA] epoch 1/3 | val acc=0.763
[MixDoRA] epoch 1 | peak_alloc=42.19 GB | peak_reserved=55.47 GB
[MixDoRA] epoch 2/3 | val acc=0.709
[MixDoRA] epoch 2 | peak_alloc=46.23 GB | peak_reserved=56.13 GB
[MixDoRA] epoch 3/3 | val acc=0.709
[MixDoRA] epoch 3 | peak_alloc=46.23 GB | peak_reserved=56.13 GB
[MixDoRA] BEST val acc=0.763
MixDoRA val acc: 0.7625418060200669
