In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as tfs
import matplotlib.pyplot as plt

SEED = 0
# Seed torch's global RNG (used by nn.Linear init, nn.init.normal_ and shuffled
# DataLoaders). The trailing ';' keeps the returned Generator's repr out of the output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x108df35d0>

In [2]:
from torch.utils.data import DataLoader

# Convert images to tensors, then standardise with the commonly quoted MNIST mean/std.
mnist_norm = tfs.Normalize(mean=(0.1307,), std=(0.3081,))
transform = tfs.Compose([tfs.ToTensor(), mnist_norm])

# Full MNIST train/test splits, downloaded to ./data on first use.
train_ds, test_ds = (
    datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

batch_size = 10
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# Prefer the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
import torch.nn.functional as F

class Network(nn.Module):
    """Fully connected MNIST classifier: 784 -> hidden_size1 -> hidden_size2 -> 10 logits.

    Args:
        hidden_size1: width of the first hidden layer.
        hidden_size2: width of the second hidden layer.
    """

    def __init__(self, hidden_size1=1000, hidden_size2=2000):
        super().__init__()
        n_pixels, n_classes = 28 * 28, 10
        # Layers are created in this order so parameter init consumes the RNG identically;
        # the attribute names linear1..linear3 are what the LoRA cells parametrize.
        self.linear1 = nn.Linear(n_pixels, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, n_classes)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return unnormalised class logits."""
        flat = x.view(-1, 28 * 28)
        hidden_a = F.relu(self.linear1(flat))
        hidden_b = F.relu(self.linear2(hidden_a))
        return self.linear3(hidden_b)
    
# Build the classifier with its default widths and move it onto the selected device.
model = Network()
model = model.to(device)

In [4]:
# Passes over the data made by each train() call (base training and the LoRA fine-tune).
num_epochs: int = 1

In [5]:
def train(train_loader, model, num_epochs=1, lr=0.001, device=None):
    """Train ``model`` with Adam and cross-entropy, printing the mean batch loss per epoch.

    A fresh optimizer is built on every call, so a second call (e.g. for
    fine-tuning) starts without any Adam state from the first.

    Args:
        train_loader: iterable of ``(imgs, labels)`` batches that supports ``len()``
            (the number of batches), e.g. a ``DataLoader``.
        model: module mapping a ``(batch, 784)`` tensor to class logits.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: where each batch is moved; defaults to the device of the model's
            first parameter instead of relying on a notebook-global ``device``.
    """
    if device is None:
        device = next(model.parameters()).device

    loss_fn = nn.CrossEntropyLoss()
    # Only trainable tensors go to the optimizer; frozen ones (requires_grad=False)
    # never receive gradients, so Adam would skip them anyway.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)

    for epoch in range(num_epochs):
        model.train()

        train_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()

            out = model(imgs.view(-1, 28*28))
            loss = loss_fn(out, labels)
            train_loss += loss.item()

            loss.backward()
            optimizer.step()

        # Mean of per-batch losses over this epoch.
        train_loss /= len(train_loader)

        print(f'Epoch {epoch+1} / {num_epochs} | Train Loss: {train_loss: .4f}')

# Ordinary full-network training on the complete MNIST training split.
train(train_loader, model, num_epochs=num_epochs)

Epoch 1 / 1 | Train Loss:  0.2387


In [6]:
# Snapshot every parameter's current values, keyed by parameter name (e.g. 'linear1.weight').
# detach() drops autograd history and clone() gives independent storage.
# NOTE(review): after register_parametrization the weights are exposed as
# '<layer>.parametrizations.weight.original', so these keys will no longer match
# model.named_parameters() at that point.
original_weights = {name: param.detach().clone() for name, param in model.named_parameters()}

In [7]:
def test():
    correct, total = 0, 0
    wrong_counts = [0 for _ in range(10)]

    model.eval()
    with torch.inference_mode():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)

            out = model(imgs.view(-1, 28*28))

            for idx, num in enumerate(out):
                if torch.argmax(num) == labels[idx]:
                    correct += 1
                else:
                    wrong_counts[labels[idx]] += 1
                
                total += 1

    print(f'Accuracy: {correct / total}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for digit {i}: {wrong_counts[i]}')

# Baseline evaluation of the fully trained network, before any LoRA layers are attached.
test()

Accuracy: 0.9518
wrong counts for digit 0: 28
wrong counts for digit 1: 12
wrong counts for digit 2: 49
wrong counts for digit 3: 68
wrong counts for digit 4: 33
wrong counts for digit 5: 18
wrong counts for digit 6: 103
wrong counts for digit 7: 44
wrong counts for digit 8: 24
wrong counts for digit 9: 103


In [8]:
# Size of the full (pre-LoRA) network: sum of element counts over every parameter tensor.
total_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters: {total_params}')

Total number of parameters: 2807010


In [9]:
class LoRALinear(nn.Module):
    """Low-rank additive weight parametrization: ``W -> W + (B @ A) * alpha / rank``.

    Meant for ``torch.nn.utils.parametrize.register_parametrization``: ``forward``
    receives the stored original weight and returns the effective weight.

    ``B`` has shape ``(in_features, rank)`` and ``A`` has shape ``(rank, out_features)``,
    so ``B @ A`` is ``(in_features, out_features)`` and must equal the parametrized
    weight's shape exactly. ``nn.Linear.weight`` is stored as
    ``(out_features, in_features)``, so for an ``nn.Linear`` pass ``weight.shape`` in
    order (rows first, then columns).

    Args:
        in_features: number of rows of the parametrized weight.
        out_features: number of columns of the parametrized weight.
        rank: inner dimension of the low-rank update.
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.
        device: device on which ``A`` and ``B`` are allocated.
    """

    def __init__(self, in_features, out_features, rank=1, alpha=1.0, device='cpu'):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha

        # B starts at zero and A is standard-normal, so delta W = B @ A is exactly zero at
        # initialization and the parametrized layer initially reproduces the original weight.
        self.B = nn.Parameter(torch.zeros((in_features, rank), device=device))
        self.A = nn.Parameter(torch.zeros((rank, out_features), device=device))
        nn.init.normal_(self.A, mean=0, std=1)

        self.scale = alpha / rank

    def forward(self, original_weights):
        """Return ``original_weights + (B @ A) * scale``.

        Raises:
            ValueError: if ``B @ A`` does not have exactly the shape of
                ``original_weights``. Reshaping with ``.view`` instead would silently
                scramble a same-sized but differently-oriented update (e.g. its transpose).
        """
        delta = torch.matmul(self.B, self.A)
        if delta.shape != original_weights.shape:
            raise ValueError(
                f'LoRA update has shape {tuple(delta.shape)} but the weight has shape '
                f'{tuple(original_weights.shape)}'
            )
        return original_weights + delta * self.scale

In [10]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, alpha=1):
    """Build a LoRALinear parametrization sized to ``layer.weight``.

    ``nn.Linear.weight`` is stored as ``(out_features, in_features)``. LoRALinear builds
    ``B @ A`` with shape ``(first arg, second arg)`` and adds it to the weight, so it is
    given the weight's rows then columns, i.e. the layer's out- then in-features.

    Args:
        layer: an ``nn.Linear`` (anything with a 2-D ``weight``).
        device: device for the LoRA factors.
        rank: inner dimension of the low-rank update.
        alpha: scaling numerator for the update.

    Returns:
        A LoRALinear whose ``B @ A`` matches ``layer.weight.shape``.
    """
    # The previous unpacking named these (in_features, out_features), which contradicts
    # nn.Linear's weight layout; the values passed on are unchanged.
    weight_rows, weight_cols = layer.weight.shape  # (out_features, in_features) of the nn.Linear

    return LoRALinear(weight_rows, weight_cols, rank=rank, alpha=alpha, device=device)

# register_parametrization() moves each layer's stored weight to
# `parametrizations.weight.original`; reading `layer.weight` then returns
# LoRALinear.forward(original), i.e. W + (B @ A) * scale.
# Layers are visited in the same order as before, so LoRA init consumes the RNG identically.
for layer in (model.linear1, model.linear2, model.linear3):
    parametrize.register_parametrization(layer, "weight", linear_layer_parameterization(layer, device))

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRALinear()
    )
  )
)

In [11]:
# Compare the element count of the LoRA factors (A, B) with that of the base layers.
lora_layers = [model.linear1, model.linear2, model.linear3]
total_params_lora = 0
total_params_non_lora = 0

for layer in lora_layers:
    adapter = layer.parametrizations['weight'][0]
    total_params_lora += sum(t.nelement() for t in (adapter.A, adapter.B))
    total_params_non_lora += sum(t.nelement() for t in (layer.weight, layer.bias))

print(f'Number of non-LoRA parameters: {total_params_non_lora}')
print(f'Number of LoRA parameters: {total_params_lora}')

Number of non-LoRA parameters: 2807010
Number of LoRA parameters: 6794


In [12]:
# Freeze everything except the LoRA factors. Parametrized weights appear as
# '<layer>.parametrizations.weight.original' and the factors as
# '<layer>.parametrizations.weight.0.A' / '<layer>.parametrizations.weight.0.B'.
lora_factor_names = ('A', 'B')
for name, param in model.named_parameters():
    # Compare the last dotted component exactly: a substring test for 'A'/'B' would
    # also match (and wrongly leave trainable) any parameter whose path merely
    # contains those letters.
    if name.rsplit('.', 1)[-1] not in lora_factor_names:
        print(f'Freezing {name}')
        param.requires_grad = False

Freezing linear1.bias
Freezing linear1.parametrizations.weight.original
Freezing linear2.bias
Freezing linear2.parametrizations.weight.original
Freezing linear3.bias
Freezing linear3.parametrizations.weight.original


In [13]:
# LoRA fine-tuning stage: train only the unfrozen LoRA factors on a single digit.
FINETUNE_DIGIT = 9

finetune_ds = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Filter a fresh dataset instance (and build a separately named loader) so the full
# `train_ds` / `train_loader` from the data cell are not clobbered.
keep_indices = finetune_ds.targets == FINETUNE_DIGIT
finetune_ds.data = finetune_ds.data[keep_indices]
finetune_ds.targets = finetune_ds.targets[keep_indices]
assert len(finetune_ds) > 0, f'no training samples for digit {FINETUNE_DIGIT}'

finetune_loader = DataLoader(finetune_ds, batch_size=10, shuffle=True)

train(finetune_loader, model, num_epochs)

Epoch 1 / 1 | Train Loss:  0.0216


In [14]:
# Re-evaluate on the full test split after LoRA fine-tuning on the single-digit subset.
test()

Accuracy: 0.4966
wrong counts for digit 0: 957
wrong counts for digit 1: 540
wrong counts for digit 2: 202
wrong counts for digit 3: 692
wrong counts for digit 4: 471
wrong counts for digit 5: 587
wrong counts for digit 6: 259
wrong counts for digit 7: 502
wrong counts for digit 8: 823
wrong counts for digit 9: 1
