In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)


100%|██████████| 9.91M/9.91M [00:04<00:00, 2.04MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 56.8kB/s]
100%|██████████| 1.65M/1.65M [00:07<00:00, 232kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.25MB/s]


In [39]:
class simpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1=nn.Linear(784,256)
        self.l2=nn.Linear(256, 10)
        self.relu=nn.ReLU()
    
    def forward(self,x):
        x=x.view(-1,784)
        x=self.relu(self.l1(x))
        x=self.l2(x)
        return x


In [40]:
model = simpleNN().to(device)

def noofparams(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

totalbefore,trainablebefore=noofparams(model)


In [41]:
def train_model(model, loader, epochs=1, max_iters=None):
    model.train()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    total = 0
    for epoch in range(epochs):
        loop = tqdm(loader, total=max_iters or len(loader))
        for x, y in loop:
            if max_iters and total >= max_iters:
                return
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())
            total += 1


def evaluate(model):
    model.eval()
    correct = [0]*10
    total = [0]*10
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            preds = out.argmax(dim=1)
            for i in range(len(y)):
                total[y[i]] += 1
                if preds[i] == y[i]:
                    correct[y[i]] += 1
    for i in range(10):
        print(f"Digit {i} Accuracy: {correct[i]}/{total[i]}")

In [42]:
print('before LORA')
evaluate(model)
print("no. of total params:",totalbefore)
print("no. of trainable params:",trainablebefore)


original_weights = {name: param.detach().clone() for name, param in model.named_parameters()}


before LORA
Digit 0 Accuracy: 10/980
Digit 1 Accuracy: 623/1135
Digit 2 Accuracy: 42/1032
Digit 3 Accuracy: 58/1010
Digit 4 Accuracy: 1/982
Digit 5 Accuracy: 117/892
Digit 6 Accuracy: 33/958
Digit 7 Accuracy: 40/1028
Digit 8 Accuracy: 2/974
Digit 9 Accuracy: 1/1009
no. of total params: 203530
no. of trainable params: 203530


In [43]:
class Lora(nn.Module):
    def __init__(self,indim,outdim,rank=1,alpha=1):
        super().__init__()
        self.lora_a=nn.Parameter(torch.randn(rank,outdim))
        self.lora_b=nn.Parameter(torch.zeros(indim,rank))
        self.scale=alpha/rank
        self.enabled=True
    
    def forward(self,w):
        if self.enabled:
            return w+(self.lora_b @ self.lora_a).view(w.shape) * self.scale
        else :
            return w



In [44]:
def apply_lora(layer):
    in_f, out_f = layer.weight.shape
    parametrize.register_parametrization(layer, "weight", Lora(in_f, out_f, rank=1, alpha=1))

In [45]:
apply_lora(model.l1)
apply_lora(model.l2)

In [46]:
def toggle_lora(enabled):
    for layer in [model.l1, model.l2]:
        layer.parametrizations["weight"][0].enabled = enabled

for name, param in model.named_parameters():
    if 'lora' not in name:
        param.requires_grad = False

In [47]:
def noofparams(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

totalafter,trainableafter=noofparams(model)
print('after LORA')
evaluate(model)
print("no. of total params:",totalafter)
print("no. of trainable params:",trainableafter)


after LORA
Digit 0 Accuracy: 10/980
Digit 1 Accuracy: 623/1135
Digit 2 Accuracy: 42/1032
Digit 3 Accuracy: 58/1010
Digit 4 Accuracy: 1/982
Digit 5 Accuracy: 117/892
Digit 6 Accuracy: 33/958
Digit 7 Accuracy: 40/1028
Digit 8 Accuracy: 2/974
Digit 9 Accuracy: 1/1009
no. of total params: 204836
no. of trainable params: 1306


In [None]:
## also sanity check
assert torch.allclose(model.l1.parametrizations.weight.original, original_weights['l1.weight'])
assert torch.allclose(model.l2.parametrizations.weight.original, original_weights['l2.weight'])