<a href="https://colab.research.google.com/github/atanuc073/Genrative-AI-development-and-deployment/blob/main/Prod_Finetuning_LORA_Basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Step 1: Creating the LoRA Layer

In [None]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    """Low-rank adapter computing ``alpha * (x @ A @ B)``.

    ``A`` (``in_dim x rank``) is drawn from a standard normal scaled by
    ``1 / sqrt(rank)``; ``B`` (``rank x out_dim``) starts at zero, so a freshly
    created layer outputs zeros and does not change whatever it is added to.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension of the factorization; must be positive.
        alpha: Scalar multiplier applied to the low-rank product.

    Raises:
        ValueError: If ``rank`` is zero or negative.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # rank == 0 would otherwise be accepted silently and build empty
        # factors, i.e. an adapter whose output is always zero.
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank!r}")
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        """Return the scaled low-rank projection of ``x``."""
        return self.alpha * (x @ self.A @ self.B)

## Step 2: LoRA applied to an existing linear layer

In [None]:
class LinearWithLoRA(nn.Module):
    """Adds the output of a LoRA adapter to the output of an existing linear layer.

    Args:
        linear: Layer to wrap; its ``in_features`` / ``out_features`` size the adapter.
        rank: Rank passed to the adapter.
        alpha: Scaling factor passed to the adapter.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            in_dim=linear.in_features,
            out_dim=linear.out_features,
            rank=rank,
            alpha=alpha,
        )

    def forward(self, x):
        """Return ``linear(x) + lora(x)``."""
        base_out = self.linear(x)
        adapter_out = self.lora(x)
        return base_out + adapter_out

In [None]:
import torch

# Seed the RNG so the layer weights and the sample input are reproducible.
torch.manual_seed(123)
layer = nn.Linear(in_features=10, out_features=2)
x = torch.randn(1, 10)

print("Original output:", layer(x))

Original output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)


In [None]:
# The adapter's B matrix is initialized to zeros, so right after wrapping the
# LoRA branch contributes nothing and the printed values equal layer(x).
layer_lora_1 = LinearWithLoRA(linear=layer, rank=2, alpha=4)
print("LoRA output:", layer_lora_1(x))

LoRA output: tensor([[0.6639, 0.4487]], grad_fn=<AddBackward0>)


## Step 3: A 3-layer model

In [None]:
num_features = 768  # input width of the first Linear layer
num_hidden_1 = 128  # output width of the first Linear layer
num_hidden_2 = 256  # output width of the second Linear layer
num_classes = 10  # output width of the final Linear layer

class MultilayerPerceptron(nn.Module):
    """Three ``nn.Linear`` layers with a ReLU after each of the first two.

    The stack is stored in ``self.layers`` as an ``nn.Sequential`` whose
    indices 0, 2 and 4 hold the linear layers.

    Args:
        num_features: Input width.
        num_hidden_1: Width of the first hidden layer.
        num_hidden_2: Width of the second hidden layer.
        num_classes: Output width.
    """

    def __init__(self, num_features,
        num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        widths = (num_features, num_hidden_1, num_hidden_2, num_classes)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # no activation after the output layer
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the sequential stack."""
        return self.layers(x)


# Build the plain (non-LoRA) network from the sizes defined above and show its structure.
model = MultilayerPerceptron(num_features, num_hidden_1, num_hidden_2, num_classes)

print(model)

MultilayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=768, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)


## Step 4: 3-layer model with LoRA

In [None]:
# Wrap every plain nn.Linear in the stack (positions 0, 2 and 4) with a LoRA
# adapter. The type check keeps the cell idempotent: on a re-run the entries
# are already LinearWithLoRA modules, which expose no in_features /
# out_features, so they are skipped instead of being wrapped a second time.
for position, sublayer in enumerate(list(model.layers)):
    if isinstance(sublayer, nn.Linear):
        model.layers[position] = LinearWithLoRA(sublayer, rank=4, alpha=8)

print(model)

MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLoRA(
      (linear): Linear(in_features=768, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRA(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLoRA(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)


In [None]:
def freeze_linear_layers(model):
    """Disable gradients for every ``nn.Linear`` inside ``model``, in place.

    ``model.modules()`` yields the module itself followed by all of its
    descendants, so a bare ``nn.Linear`` passed in directly is frozen too
    (walking only ``children()`` would skip the root). Parameters owned by
    other module types are left untouched.

    Args:
        model: Any ``nn.Module``.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.requires_grad_(False)

freeze_linear_layers(model)
# List every parameter with its requires_grad flag to check what is still trainable.
for param_name, tensor in model.named_parameters():
    print(f"{param_name}: {tensor.requires_grad}")

layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [None]:
# You’ll need scikit-learn available for the dataset + standardization
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

## Step 5: Dataset - 768-dim, 10 classes

In [None]:

from torch.utils.data import TensorDataset, DataLoader
import numpy as np

num_features = 768
num_classes  = 10

N_SAMPLES = 12000  # tweak as desired
VAL_SPLIT = 0.10

# Synthetic classification problem; random_state pins the generated data.
X, y = make_classification(
    n_samples=N_SAMPLES,
    n_features=num_features,
    n_informative=64,
    n_redundant=16,
    n_repeated=0,
    n_classes=num_classes,
    n_clusters_per_class=2,
    class_sep=2.0,
    flip_y=0.01,
    random_state=0,
)
y = y.astype(np.int64)

# Split BEFORE standardizing: the scaler must only see training rows, otherwise
# the validation set's mean/variance leak into the features the models train on.
X_train_raw, X_val_raw, y_train, y_val = train_test_split(
    X, y, test_size=VAL_SPLIT, stratify=y, random_state=0
)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw).astype(np.float32)
X_val = scaler.transform(X_val_raw).astype(np.float32)

train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
val_ds   = TensorDataset(torch.from_numpy(X_val),   torch.from_numpy(y_val))

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False)

len(train_ds), len(val_ds), X_train.shape, y_train[:5]



(10800, 1200, (10800, 768), array([5, 7, 8, 4, 6]))

## Step 6: Build comparable models (LoRA student vs Full student)

In [None]:
# Use the GPU when PyTorch reports one as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from copy import deepcopy

# Relies on MultilayerPerceptron, LinearWithLoRA and LoRALayer defined in earlier cells.

# Hidden-layer widths. num_features and num_classes are not redefined here;
# they come from the dataset cell and are the sizes that must match the data.
num_hidden_1 = 128
num_hidden_2 = 256

# Reference network whose initial weights are the shared starting point.
base = MultilayerPerceptron(num_features, num_hidden_1, num_hidden_2, num_classes)
# Deep-copy the state dict so the snapshot is independent of `base`'s tensors.
base_sd = deepcopy(base.state_dict())

# --- LoRA student ---
lora_student = MultilayerPerceptron(
    num_features=num_features,
    num_hidden_1=num_hidden_1,
    num_hidden_2=num_hidden_2,
    num_classes=num_classes
)
# Load the shared base weights BEFORE wrapping. Once the linears are wrapped
# their parameters are named "layers.N.linear.*" while base_sd holds
# "layers.N.*"; a strict=False load at that point matches no key and silently
# leaves the student on its own random initialization.
lora_student.load_state_dict(base_sd)

# wrap linears with LoRA (positions 0, 2 and 4 of the Sequential)
for linear_idx in (0, 2, 4):
    lora_student.layers[linear_idx] = LinearWithLoRA(
        lora_student.layers[linear_idx], rank=4, alpha=8
    )

# Move the parameters to `device` before any optimizer is built from them.
lora_student.to(device)

# freeze base linear weights; keep LoRA A/B trainable
def freeze_linear_layers(module: nn.Module):
    """Disable gradients for every ``nn.Linear`` inside ``module``, in place.

    ``module.modules()`` yields the module itself followed by all of its
    descendants, so a bare ``nn.Linear`` passed in directly is frozen too
    (walking only ``children()`` would skip the root). Parameters owned by
    other module types are left untouched.

    Args:
        module: Any ``nn.Module``.
    """
    for submodule in module.modules():
        if isinstance(submodule, nn.Linear):
            submodule.requires_grad_(False)

# Turns off gradients for the wrapped nn.Linear weights and biases; the
# LoRALayer A/B parameters are not nn.Linear parameters and stay trainable.
freeze_linear_layers(lora_student)

# --- Full student ---
full_student = MultilayerPerceptron(
    num_features=num_features,
    num_hidden_1=num_hidden_1,
    num_hidden_2=num_hidden_2,
    num_classes=num_classes
)
# Same architecture as `base`, so the keys match exactly; use the default
# strict load so any mismatch raises instead of being silently ignored.
full_student.load_state_dict(base_sd)

# Move the parameters to `device` before any optimizer is built from them.
full_student.to(device)

def count_trainable_params(model):
    """Return the total element count of the parameters that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Compare how many parameters each student will actually update.
n_lora_trainable = count_trainable_params(lora_student)
n_full_trainable = count_trainable_params(full_student)
print(f"LoRA trainables: {n_lora_trainable:,}")
print(f"Full  trainables: {n_full_trainable:,}")


LoRA trainables: 6,184
Full  trainables: 134,026


## Step 7: Optimizers

In [None]:

# Both students are scored with the same loss.
criterion = nn.CrossEntropyLoss()

# The LoRA optimizer receives only the parameters that still require gradients.
lora_trainable_params = [p for p in lora_student.parameters() if p.requires_grad]
opt_lora = torch.optim.AdamW(lora_trainable_params, lr=2e-3, weight_decay=0.0)

# The full student updates every parameter.
opt_full = torch.optim.AdamW(full_student.parameters(), lr=1e-3, weight_decay=0.0)


## Step 8: Training loop definition

In [None]:
def run_epoch(model, loader, optimizer=None, criterion=None, device=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is switched to training mode and updated
    after every batch; without one it is switched to eval mode and the forward
    pass runs with gradients disabled.

    Args:
        model: Module mapping a batch of inputs to per-class logits.
        loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        optimizer: Optimizer to step after each batch; ``None`` means evaluate.
        criterion: Callable ``criterion(logits, targets)`` returning a scalar
            loss. Required. The value is re-weighted by batch size, which
            assumes it is a per-batch mean.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has none), so batches
            always land where the model lives.

    Returns:
        ``(loss, accuracy)``, both averaged over the samples seen.

    Raises:
        ValueError: If ``criterion`` is ``None`` or ``loader`` yields no samples.
    """
    if criterion is None:
        raise ValueError("run_epoch requires a criterion")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    train = optimizer is not None
    model.train(mode=train)
    total_loss, total_correct, total_n = 0.0, 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        if train:
            optimizer.zero_grad(set_to_none=True)
        # One forward path for both modes; gradients are tracked only when training.
        with torch.set_grad_enabled(train):
            logits = model(xb)
            loss = criterion(logits, yb)
        if train:
            loss.backward()
            optimizer.step()
        batch_size = xb.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=-1) == yb).sum().item()
        total_n += batch_size
    if total_n == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total_n, total_correct / total_n


## Step 9: Joint training loop

In [None]:
import time
EPOCHS = 1000  # you can lower to 100–200 for quick runs

# First epoch at which the LoRA student matches or beats the full student on
# validation loss / accuracy; None until that happens.
parity_epoch_val_loss = None
parity_epoch_val_acc  = None
epochs_completed = 0

t0 = time.perf_counter()
try:
    for ep in range(1, EPOCHS + 1):
        # Train both students for one epoch on the same loader, then validate both.
        tr_loss_lora, tr_acc_lora = run_epoch(lora_student, train_loader, optimizer=opt_lora, criterion=criterion)
        tr_loss_full, tr_acc_full = run_epoch(full_student, train_loader, optimizer=opt_full, criterion=criterion)

        va_loss_lora, va_acc_lora = run_epoch(lora_student, val_loader, optimizer=None, criterion=criterion)
        va_loss_full, va_acc_full = run_epoch(full_student, val_loader, optimizer=None, criterion=criterion)

        if ep % 5 == 0 or ep == 1:
            print(f"[Ep {ep:04d}] "
                  f"LoRA: train {tr_loss_lora:.4f}/{tr_acc_lora:.3f} | val {va_loss_lora:.4f}/{va_acc_lora:.3f}   "
                  f"Full: train {tr_loss_full:.4f}/{tr_acc_full:.3f} | val {va_loss_full:.4f}/{va_acc_full:.3f}")

        if parity_epoch_val_loss is None and va_loss_lora <= va_loss_full:
            parity_epoch_val_loss = ep
        if parity_epoch_val_acc is None and va_acc_lora >= va_acc_full:
            parity_epoch_val_acc = ep

        epochs_completed = ep
except KeyboardInterrupt:
    # A long run stopped by hand should still reach the summary below
    # instead of ending the cell with a traceback.
    print(f"\nInterrupted during epoch {epochs_completed + 1}; summarizing completed epochs.")

total_time = time.perf_counter() - t0

print("\n=== Parity summary (validation) ===")
print("Val loss parity:", parity_epoch_val_loss if parity_epoch_val_loss is not None else "Not reached")
print("Val acc  parity:", parity_epoch_val_acc  if parity_epoch_val_acc  is not None else "Not reached")
# Report the epochs actually finished; the loop variable is undefined when EPOCHS is 0.
print(f"Total time: {total_time:.2f}s for {epochs_completed} epoch(s)")

[Ep 0001] LoRA: train 2.3323/0.118 | val 2.3046/0.123   Full: train 1.7489/0.437 | val 0.9704/0.677
[Ep 0005] LoRA: train 2.1728/0.195 | val 2.1780/0.202   Full: train 0.0326/0.998 | val 1.0874/0.722
[Ep 0010] LoRA: train 1.7629/0.354 | val 1.8206/0.339   Full: train 0.0021/1.000 | val 1.2437/0.730
[Ep 0015] LoRA: train 1.4709/0.471 | val 1.6153/0.415   Full: train 0.0008/1.000 | val 1.3399/0.728
[Ep 0020] LoRA: train 1.3128/0.528 | val 1.5185/0.466   Full: train 0.0004/1.000 | val 1.4141/0.732
[Ep 0025] LoRA: train 1.2322/0.555 | val 1.4932/0.481   Full: train 0.0003/1.000 | val 1.4761/0.728
[Ep 0030] LoRA: train 1.2050/0.565 | val 1.4704/0.486   Full: train 0.0002/1.000 | val 1.5330/0.728
[Ep 0035] LoRA: train 1.1520/0.583 | val 1.4418/0.495   Full: train 0.0001/1.000 | val 1.5840/0.728
[Ep 0040] LoRA: train 1.1269/0.595 | val 1.4856/0.478   Full: train 0.0001/1.000 | val 1.6312/0.727
[Ep 0045] LoRA: train 1.1263/0.597 | val 1.4874/0.489   Full: train 0.0001/1.000 | val 1.6732/0.727


KeyboardInterrupt: 