In [None]:
!pip install transformers torch datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel

# -----------------------------
# LoRA Module (based on Hu et al. 2021)
# -----------------------------
class LoRA(nn.Module):
    """Low-rank adapter wrapped around a frozen linear layer (Hu et al. 2021).

    The output is ``base_layer(x) + (alpha / r) * dropout(x @ B @ A)`` where
    ``B`` has shape ``(in_features, r)`` and ``A`` has shape
    ``(r, out_features)``. ``B`` starts at zero, so the adapter initially
    leaves the wrapped layer's output unchanged.
    """

    def __init__(self, r, alpha, layer, dropout=0.5):
        """
        Implements LoRA injection: W + ΔW, where ΔW = BA, rank r

        Args:
          r: rank of the low-rank update.
          alpha: scaling numerator; the delta is multiplied by ``alpha / r``.
          layer: layer to adapt; must expose ``in_features``, ``out_features``
            and (for ``merge``) a ``weight`` of shape
            ``(out_features, in_features)``.
          dropout: dropout probability applied to the delta path. The default
            keeps the value that used to be hard-coded.
        """
        super().__init__()
        self.base_layer = layer                     # Original frozen weight (e.g., query or value)
        self.scale = alpha / r                      # LoRA scaling factor (from paper)

        self.n = layer.in_features
        self.m = layer.out_features

        # Low-rank trainable matrices
        self.B = nn.Parameter(torch.zeros(self.n, r))    # B ∈ R^{in_features × r}
        self.A = nn.Parameter(torch.rand(r, self.m))     # A ∈ R^{r × out_features}

        # True once the delta has been folded into base_layer.weight
        self.merged = False

        # Dropout on the delta path (optional; improves generalization)
        self.dropout = nn.Dropout(p=dropout)

        # Freeze the base layer's weights
        for param in self.base_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the frozen layer's output plus the scaled low-rank delta.

        Args:
          x: tensor whose last dimension is ``in_features``.

        Returns:
          Tensor whose last dimension is ``out_features``.
        """
        base = self.base_layer(x)                         # frozen forward
        if self.merged:
            # The delta already lives inside base_layer.weight; adding it
            # again would apply the update twice.
            return base
        # (..., in) @ (in, r) @ (r, out) -> (..., out). This matches the
        # weight layout merge() writes and works for non-square layers.
        delta = x @ self.B @ self.A                       # LoRA delta
        delta = self.dropout(delta)                       # optional dropout
        return base + self.scale * delta                  # Inject LoRA

    def merge(self):
        """
        Merges LoRA weights into base weights (for inference).
        Only needed if you want to export or remove dependency on LoRA modules.
        Calling it more than once has no further effect.
        """
        if self.merged:
            return
        with torch.no_grad():
            # Transpose to the [out_features, in_features] layout of weight.
            delta_w = torch.matmul(self.B, self.A).T
            self.base_layer.weight += self.scale * delta_w
        self.merged = True

# -----------------------------
# LoRA Applied to BERT (query/value)
# -----------------------------
class loraBERT(nn.Module):
    """Pretrained encoder with LoRA adapters on attention query/value layers.

    A linear classification head is applied to the hidden state of the first
    token of the encoder output.
    """

    def __init__(self, model_name, num_classes, r=4, alpha=10):
        """
        Wraps a BERT model and injects LoRA adapters into attention (query and value).

        Args:
          model_name: checkpoint name passed to ``AutoModel.from_pretrained``.
            The loaded model must expose
            ``encoder.layer[i].attention.self.query`` / ``.value``.
          num_classes: output size of the classification head.
          r: LoRA rank used for every adapter (default keeps the old value).
          alpha: LoRA scaling numerator (default keeps the old value).
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)

        # Inject LoRA into each transformer layer (Q and V only, per original paper)
        for layer in self.model.encoder.layer:
            attention = layer.attention.self
            attention.query = LoRA(r=r, alpha=alpha, layer=attention.query)
            attention.value = LoRA(r=r, alpha=alpha, layer=attention.value)

        # Freeze all model parameters by default
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze only the low-rank matrices. Iterating module.parameters()
        # here would also reach the wrapped base layer (it is a submodule of
        # the adapter) and silently make the pretrained weights trainable.
        for module in self.model.modules():
            if isinstance(module, LoRA):
                module.A.requires_grad = True
                module.B.requires_grad = True

        # Task-specific classifier head
        self.classifier = nn.Linear(self.model.config.hidden_size, num_classes)

    def count_params(self):
        """
        Returns number of trainable LoRA params vs total model size.

        Only ``self.model`` is counted; the classifier head is not included
        in either number.
        """
        lora = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        return {'lora_params': lora, 'total_params': total}

    def forward(self, input_ids, attention_mask):
        """Return classifier logits computed from the first token's hidden state."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = output.last_hidden_state[:, 0, :]  # CLS token
        return self.classifier(pooled)

# -----------------------------
# LoRA-augmented MLP (example outside transformers)
# -----------------------------
class loraMLP(nn.Module):
    """Two-layer perceptron whose linear layers are each wrapped in LoRA."""

    def __init__(self, n, m, h):
        """
        Build ``Linear(m, h) -> ReLU -> Linear(h, n)`` with a LoRA adapter
        (rank 4, alpha 10) around each linear layer.

        Args:
          n: output_dim
          m: input_dim
          h: hidden_dim
        """
        super().__init__()
        # Both plain layers are created before either adapter.
        input_proj = nn.Linear(m, h)
        output_proj = nn.Linear(h, n)
        self.relu = nn.ReLU()

        self.new_fc1 = LoRA(r=4, alpha=10, layer=input_proj)
        self.new_fc2 = LoRA(r=4, alpha=10, layer=output_proj)

    def forward(self, x):
        """Apply the first adapted layer, ReLU, then the second adapted layer."""
        hidden = self.new_fc1(x)
        activated = self.relu(hidden)
        return self.new_fc2(activated)

# -----------------------------
# Instantiation Example
# -----------------------------
# Checkpoint used for the BERT example below.
model_name = 'bert-base-uncased'
# Two-class classifier on top of the LoRA-adapted encoder.
lora_bert = loraBERT(model_name=model_name, num_classes=2)

# Toy MLP: 2 inputs -> 4 hidden units -> 10 outputs.
lora_mlp = loraMLP(m=2, h=4, n=10)


In [None]:
# === Setup ===
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.optim import AdamW

# Fetch the GLUE SST-2 dataset and the tokenizer for the BERT checkpoint.
# The two loads do not depend on each other.
dataset = load_dataset('glue', 'sst2')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Tokenization function
def tokenize_fn(batch):
    """Tokenize the ``'sentence'`` entries of ``batch``.

    Every sequence is padded or truncated to exactly 128 tokens, using the
    module-level ``tokenizer``.
    """
    sentences = batch['sentence']
    return tokenizer(
        sentences,
        truncation=True,
        padding='max_length',
        max_length=128,
    )

# Tokenize every split in batches, then expose only the columns the training
# and validation loops read, as torch tensors.
encoded = dataset.map(tokenize_fn, batched=True)
encoded.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# Batches of 32; only the training split is shuffled.
train_loader = DataLoader(encoded['train'], batch_size=32, shuffle=True)
val_loader = DataLoader(encoded['validation'], batch_size=32)

# Build the LoRA-BERT classifier and move it to the GPU when one is available.
model = loraBERT('bert-base-uncased', num_classes=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Cross-entropy loss; the optimizer only receives parameters that require grad.
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=2e-5)

# === Training Loop ===
# Three passes over the training set, printing the mean batch loss per epoch.
for epoch in range(3):
    model.train()
    total_loss = 0
    for batch in train_loader:
        # Move the three tensors the model and loss need onto the device.
        input_ids, attention_mask, labels = (
            batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'label')
        )

        optimizer.zero_grad()
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the scalar loss for the epoch average.
        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Train Loss: {total_loss / len(train_loader):.4f}")

# === Validation ===
# Evaluate accuracy on the validation split with gradients disabled.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        labels = batch['label'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        logits = model(input_ids, attention_mask)
        # Predicted class is the index of the largest logit in each row.
        preds = logits.argmax(dim=1)
        correct += int((preds == labels).sum())
        total += labels.shape[0]

print(f"Validation Accuracy: {correct / total:.4f}")
