In [1]:
import torch
import torch.nn as nn
import math

class LoRALayer(nn.Module):
  """Trainable low-rank update returning ``scale * (x @ lora_B @ lora_A)``.

  Shapes: ``lora_B`` is (in_features, rank) and starts at zero; ``lora_A`` is
  (rank, out_features) with Kaiming-uniform init. Because ``lora_B`` is zero,
  a freshly built layer outputs zeros, so adding it to a frozen layer does not
  change that layer's output until training updates the factors.
  NOTE(review): the factor names/roles differ from the LoRA paper's ``B @ A``
  convention; kept as-is so existing state_dict keys stay valid.

  Args:
    in_features: size of the last dimension of ``x``.
    out_features: size of the last dimension of the result.
    rank: inner dimension of the factorisation.
    alpha: numerator of the output scale ``alpha / rank``.
  """

  def __init__(self, in_features, out_features, rank=4, alpha=8):
    super().__init__()
    # Registration order (lora_B, then lora_A) is kept so state_dict order is unchanged.
    self.lora_B = nn.Parameter(torch.zeros(in_features, rank))
    self.lora_A = nn.Parameter(torch.zeros(rank, out_features))
    self.scale = alpha / rank
    nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

  def forward(self, x):
    down = x @ self.lora_B    # (..., rank)
    up = down @ self.lora_A   # (..., out_features)
    return self.scale * up

In [2]:
class LinearWithLoRA(nn.Module):
  """A frozen ``nn.Linear`` whose output is augmented by a trainable LoRALayer.

  When ``lora_enabled`` is True, forward returns
  ``linear_layer(x) + lora(x)``, i.e. ``x @ W.T + b + scale * (x @ lora_B @ lora_A)``.
  When it is False, only ``linear_layer(x)`` is returned.

  Args:
    linear_layer: the ``nn.Linear`` to wrap; its parameters are frozen in place.
    rank: LoRA rank, forwarded to LoRALayer.
    alpha: LoRA alpha, forwarded to LoRALayer.
  """

  def __init__(self, linear_layer, rank, alpha):
    super().__init__()

    self.linear_layer = linear_layer
    self.lora = LoRALayer(linear_layer.in_features, linear_layer.out_features, rank, alpha)
    # Put the adapter on the wrapped layer's device/dtype. Otherwise, wrapping a
    # layer that already lives on CUDA (or in a non-float32 dtype) leaves the
    # adapter on CPU float32, and forward fails until the caller calls .to().
    weight = linear_layer.weight
    self.lora.to(device=weight.device, dtype=weight.dtype)

    # Toggled externally (e.g. by ModelwithLoRA) to switch the adapter on/off.
    self.lora_enabled = True

    # Freeze the pretrained weights; only self.lora remains trainable.
    for param in self.linear_layer.parameters():
      param.requires_grad = False

  def forward(self, x):
    base_out = self.linear_layer(x)
    if not self.lora_enabled:
      return base_out
    return base_out + self.lora(x)

In [3]:
# Apply LoRA to a whole model
class ModelwithLoRA(nn.Module):
  """Freeze ``base_model`` and wrap selected ``nn.Linear`` children with LinearWithLoRA.

  ``base_model`` is modified in place:
  - Every parameter it already has gets ``requires_grad = False``.
  - Every child, at any nesting depth, whose attribute name is in
    ``target_modules`` and which is an ``nn.Linear`` is replaced by a
    LinearWithLoRA.

  The adapters are created after the freeze, so they are the only trainable
  parameters.

  Args:
    base_model: module to adapt (mutated in place).
    target_modules: child attribute names to wrap. They are matched against each
      child's own name (e.g. 'layer1'), not its dotted path.
    rank: LoRA rank, forwarded to every LinearWithLoRA.
    alpha: LoRA alpha, forwarded to every LinearWithLoRA.
  """

  def __init__(self, base_model, target_modules, rank, alpha):
    super().__init__()
    self.base_model = base_model

    # Freeze first; adapters inserted below keep their default requires_grad=True.
    for param in base_model.parameters():
      param.requires_grad = False

    self._apply_lora(base_model, target_modules, rank, alpha)

  def _apply_lora(self, model, target_modules, rank, alpha):
    """Recursively replace matching nn.Linear children of ``model``."""
    for child_name, child in model.named_children():
      has_grandchildren = any(True for _ in child.children())
      if has_grandchildren:
        self._apply_lora(child, target_modules, rank, alpha)

      is_target = child_name in target_modules and isinstance(child, nn.Linear)
      if is_target:
        setattr(model, child_name, LinearWithLoRA(child, rank, alpha))

  def enabled_disable_lora(self, enabled):
    """Set ``lora_enabled`` on every LinearWithLoRA inside ``base_model``."""
    lora_layers = [m for m in self.base_model.modules() if isinstance(m, LinearWithLoRA)]
    for layer in lora_layers:
      layer.lora_enabled = enabled

  def forward(self, *args, enabled=True, **kwargs):
    """Run ``base_model`` with the adapters switched on or off by ``enabled``.

    The flag is written onto the layers on every call. It stays set after
    the call returns, until the next forward overwrites it.
    """
    self.enabled_disable_lora(enabled)
    return self.base_model(*args, **kwargs)

In [4]:
class DummyModel(nn.Module):
  """Toy MLP classifier: 784 -> 256 -> 64 -> 10, with ReLU after the first two layers.

  The 784 inputs correspond to a flattened 28x28 image; callers flatten with
  ``x.view(-1, 784)``. The output is one row of 10 logits per input.
  """

  def __init__(self):
    super().__init__()
    # Dummy image classification.
    # Submodules are created in this exact order so that parameter init (which
    # draws from the global RNG) and named_modules() order are unchanged.
    self.layer1 = nn.Linear(784, 256)
    self.layer2 = nn.Linear(256, 64)
    self.layer3 = nn.Linear(64, 10)
    self.relu = nn.ReLU()

  def forward(self, x):
    hidden = x
    for hidden_layer in (self.layer1, self.layer2):
      hidden = self.relu(hidden_layer(hidden))
    return self.layer3(hidden)

In [None]:
model = DummyModel()

# named_modules() yields the root module first (with an empty name), then each submodule.
module_listing = [f"name: {name}, module: {module}" for name, module in model.named_modules()]
for line in module_listing:
  print(line)

name: , module: DummyModel(
  (layer1): Linear(in_features=784, out_features=256, bias=True)
  (layer2): Linear(in_features=256, out_features=64, bias=True)
  (layer3): Linear(in_features=64, out_features=10, bias=True)
  (relu): ReLU()
)
name: layer1, module: Linear(in_features=784, out_features=256, bias=True)
name: layer2, module: Linear(in_features=256, out_features=64, bias=True)
name: layer3, module: Linear(in_features=64, out_features=10, bias=True)
name: relu, module: ReLU()


In [5]:
# train it
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from tqdm import tqdm

In [6]:
# Normalize(mean, std) for MNIST's single channel.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=10, shuffle=True)

test_data = MNIST(root='./data', train=False, download=True, transform=transform)
# Evaluation does not need shuffling; a fixed order keeps test runs comparable.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=10, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 495kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.93MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.54MB/s]


In [6]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [7]:
loss_fn = nn.CrossEntropyLoss()
# nn.Module.to() moves the module in place and returns it, so this is a single step.
model = DummyModel().to(device)

In [9]:
# Training
def train(model, dataloader, lr=3e-4):
  """Run one epoch of Adam over ``dataloader``, updating only parameters that require grad.

  Uses the notebook-level ``device`` and ``loss_fn``. Inputs are flattened to
  784 features. A fresh optimizer is built on every call, so Adam state is not
  carried over between calls.

  Args:
    model: module to train. Frozen parameters (requires_grad=False) are
      skipped, which is what lets the same function fine-tune only LoRA adapters.
    dataloader: iterable of ``(x, y)`` batches that supports ``len()``.
    lr: Adam learning rate (default is the previously hard-coded 3e-4).

  Returns:
    Mean training loss over the epoch (0.0 if the dataloader yields no batches).
  """
  params_to_optimize = [p for p in model.parameters() if p.requires_grad]
  optimizer = torch.optim.Adam(params_to_optimize, lr=lr)

  total_loss = 0.0
  num_batches = 0
  loop = tqdm(dataloader, total=len(dataloader), desc="trainning...")
  for x, y in loop:
    x = x.view(-1, 784).to(device)
    y = y.to(device)

    output = model(x)
    loss = loss_fn(output, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Read the scalar once per step (on CUDA each .item() forces a sync).
    batch_loss = loss.item()
    loop.set_postfix({'loss': batch_loss})
    total_loss += batch_loss
    num_batches += 1

  return total_loss / num_batches if num_batches else 0.0


In [14]:
# Evaluation
def test(model, testloader, enabled=True):
  """Print accuracy and per-digit error counts of ``model`` on ``testloader``.

  Inputs are flattened to 28*28 features and moved to the notebook-level
  ``device``. For a ModelwithLoRA, ``enabled`` is passed to its forward() to
  switch the adapters on or off; for any other model it is ignored.

  Args:
    model: classifier producing one row of 10 logits per input.
    testloader: iterable of ``(x, y)`` batches that supports ``len()``.
    enabled: LoRA on/off flag, used only for ModelwithLoRA.

  Returns:
    Accuracy as a fraction in [0, 1] (0.0 if the loader yields no samples).
  """
  correct = 0
  total = 0
  wrong_counts = [0 for _ in range(10)]
  # Evaluation needs no gradients. This avoids building an autograd graph
  # (the LoRA parameters still have requires_grad=True).
  with torch.no_grad():
    # Size the progress bar from the loader actually being iterated
    # (previously it used the global test_loader for every loader).
    for x, y in tqdm(testloader, total=len(testloader)):
      x = x.view(-1, 28*28).to(device)
      y = y.to(device)  # (b,)

      if isinstance(model, ModelwithLoRA):
        output = model(x, enabled=enabled)  # (b, 10)
      else:
        output = model(x)  # (b, 10)

      hits = output.argmax(dim=1) == y
      correct += int(hits.sum().item())
      total += hits.numel()
      # One host transfer per batch instead of one comparison per sample.
      for label in y[~hits].tolist():
        wrong_counts[label] += 1

  acc = correct / total if total else 0.0
  print(f"\n Accuracy: {acc*100:.2f}%")
  print("--- wrong counts... ---")
  for digit, count in enumerate(wrong_counts):
    print(f"wrong counts for number {digit}: {count}")
  return acc

In [11]:
train(model=model, dataloader=train_loader)

trainning...: 100%|██████████| 6000/6000 [00:35<00:00, 169.64it/s, loss=0.195]


In [15]:
test(model=model, testloader=test_loader)

100%|██████████| 1000/1000 [00:02<00:00, 344.32it/s]


 Accuracy: 96.69%
--- wrong counts... ---
wrong counts for number 0: 8
wrong counts for number 1: 19
wrong counts for number 2: 40
wrong counts for number 3: 48
wrong counts for number 4: 54
wrong counts for number 5: 20
wrong counts for number 6: 21
wrong counts for number 7: 29
wrong counts for number 8: 40
wrong counts for number 9: 52





In [None]:
# Alternative freezing strategy (NOT used). ModelwithLoRA instead freezes every
# base parameter before inserting the adapters, so only the LoRA weights train.
# NOTE(review): as written, this snippet would NOT freeze the wrapped base Linear
# weights. Their names (e.g. "base_model.layer1.linear_layer.weight") also contain
# the target module name. Kept for reference only.
# for name, param in self.named_parameters():
#         is_lora_param = any(module_name in name for module_name in target_modules)
#         if not is_lora_param:  # Freeze only non-LoRA parameters
#             param.requires_grad = False

In [8]:
import copy

# Deep-copy first: ModelwithLoRA freezes and rewires its base model in place,
# and the original `model` must stay trainable/unchanged for comparison.
target_layers = ['layer1', 'layer2', 'layer3']
base_model = copy.deepcopy(model)
Lora_model = ModelwithLoRA(base_model, target_modules=target_layers, rank=8, alpha=16)

In [9]:
Lora_model.to(device=device)

ModelwithLoRA(
  (base_model): DummyModel(
    (layer1): LinearWithLoRA(
      (linear_layer): Linear(in_features=784, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (layer2): LinearWithLoRA(
      (linear_layer): Linear(in_features=256, out_features=64, bias=True)
      (lora): LoRALayer()
    )
    (layer3): LinearWithLoRA(
      (linear_layer): Linear(in_features=64, out_features=10, bias=True)
      (lora): LoRALayer()
    )
    (relu): ReLU()
  )
)

In [None]:
Lora_model.base_model.layer3.lora.lora_B.size()

torch.Size([64, 8])

In [18]:
import copy

# Keep only the digit-9 samples for the LoRA fine-tuning run. Filter a shallow
# copy rather than mutating train_data in place: train_loader (built earlier)
# wraps the original dataset object and would otherwise silently start yielding
# only 9s. Re-running this cell is harmless, since filtering 9s again is a no-op.
keep_mask = train_data.targets == 9
train_data = copy.copy(train_data)
train_data.data = train_data.data[keep_mask]
train_data.targets = train_data.targets[keep_mask]

In [19]:
# Sanity check after filtering: an arbitrary sample's label should be 9.
sample_image, sample_label = train_data[500]
sample_label

9

In [20]:
loader = torch.utils.data.DataLoader(dataset=train_data, shuffle=True, batch_size=10)

In [42]:
train(model=Lora_model, dataloader=loader)

trainning...: 100%|██████████| 595/595 [00:03<00:00, 174.93it/s, loss=0.000493]


In [43]:
# Evaluated on the 9-only *training* loader: this measures fit to the
# fine-tuning data, not generalisation to unseen digits.
test(model=Lora_model, testloader=loader)

 60%|█████▉    | 595/1000 [00:02<00:01, 287.52it/s]


 Accuracy: 100.00%
--- wrong counts... ---
wrong counts for number 0: 0
wrong counts for number 1: 0
wrong counts for number 2: 0
wrong counts for number 3: 0
wrong counts for number 4: 0
wrong counts for number 5: 0
wrong counts for number 6: 0
wrong counts for number 7: 0
wrong counts for number 8: 0
wrong counts for number 9: 0





In [44]:
# With the adapters disabled, the frozen base should reproduce the original model's accuracy.
test(model=Lora_model, testloader=test_loader, enabled=False)

100%|██████████| 1000/1000 [00:02<00:00, 342.72it/s]


 Accuracy: 96.69%
--- wrong counts... ---
wrong counts for number 0: 8
wrong counts for number 1: 19
wrong counts for number 2: 40
wrong counts for number 3: 48
wrong counts for number 4: 54
wrong counts for number 5: 20
wrong counts for number 6: 21
wrong counts for number 7: 29
wrong counts for number 8: 40
wrong counts for number 9: 52





In [11]:
# Trainable parameters of the plain model (nothing in it has been frozen).
tot_param = sum(map(lambda p: p.numel() if p.requires_grad else 0, model.parameters()))
tot_param

218058

In [14]:
# Only the LoRA adapters are trainable; the wrapped base weights were frozen.
lora_param = sum(map(lambda p: p.numel() if p.requires_grad else 0, Lora_model.parameters()))
lora_param

11472

In [None]:
total_param = sum(map(lambda p: p.numel(), Lora_model.parameters()))
total_param  # frozen base (218058) + LoRA adapters (11472) = 229530

**Next step: try LoRA on the GPT-2 checkpoint I saved.**