In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
import copy

# LoRA Layer
class LoRALayer(nn.Module):
    """Trainable low-rank branch computing ``alpha * (x @ A @ B)``.

    ``A`` has shape (in_dim, rank) and starts as standard-normal noise scaled
    by 1/sqrt(rank); ``B`` has shape (rank, out_dim) and starts at zero, so a
    freshly constructed layer outputs all zeros.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.alpha = alpha
        # Keep the scale as a float32 tensor so the initial values of A are
        # computed entirely in torch arithmetic.
        init_scale = 1 / torch.tensor(rank).float().sqrt()
        self.A = nn.Parameter(init_scale * torch.randn(in_dim, rank))
        self.B = nn.Parameter(torch.zeros(rank, out_dim))

    def forward(self, x):
        """Project ``x`` through A, then B, and scale the result by alpha."""
        low_rank = x @ self.A
        return (low_rank @ self.B) * self.alpha

# Wrapper Linear Layer with LoRA
class LinearWithLoRA(nn.Module):
    """Wrap a linear layer and add a LoRA branch to its output.

    The wrapped ``linear`` module is stored as-is (not copied); the LoRA
    branch is sized from its ``in_features`` / ``out_features``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        in_dim, out_dim = linear.in_features, linear.out_features
        self.lora = LoRALayer(in_dim, out_dim, rank, alpha)

    def forward(self, x):
        """Return the wrapped layer's output plus the LoRA branch's output."""
        base_out = self.linear(x)
        lora_out = self.lora(x)
        return base_out + lora_out

# Demo: apply LoRA to a single small linear layer
random_seed = 123
torch.manual_seed(random_seed)  # seed first so the random input and layer weights are reproducible
x = torch.randn(1, 10)  # a single example with 10 features
layer = nn.Linear(10, 5)
print(x)
print(layer)
print('Original output:', layer(x))

# Applying LoRA to Linear Layer
# LoRALayer initializes B to zeros, so before any training the LoRA branch
# adds nothing and this prints the same values as the original layer.
layer_lora_1 = LinearWithLoRA(layer, rank=4, alpha=8)
print(layer_lora_1(x))

# Equivalent: Merged version
class LinearWithLoRAMerged(nn.Module):
    """Linear layer whose weight is combined with a LoRA update at call time.

    Rather than running a separate LoRA branch, each forward pass builds
    ``W + alpha * (A @ B).T`` and applies a single linear transform with the
    wrapped layer's bias.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        in_dim, out_dim = linear.in_features, linear.out_features
        self.lora = LoRALayer(in_dim, out_dim, rank, alpha)

    def forward(self, x):
        """Apply the linear transform using the LoRA-adjusted weight."""
        # A @ B is (in_features, out_features); transpose it to match the
        # (out_features, in_features) layout of nn.Linear.weight.
        delta_weight = (self.lora.A @ self.lora.B).T
        merged_weight = self.linear.weight + self.lora.alpha * delta_weight
        return F.linear(x, merged_weight, self.linear.bias)

# Wrap the same demo layer with the merged-weight variant. Its LoRA factors
# are newly created (B starts at zero), so the printed output again matches
# the original layer's output.
layer_lora_2 = LinearWithLoRAMerged(layer, rank=4, alpha=8)
print(layer_lora_2(x))

# Multilayer Perceptron with 3 linear layers
class MultilayerPerceptron(nn.Module):
    """Three-layer fully connected network with ReLU between the layers.

    The modules live in ``self.layers`` (an ``nn.Sequential``) with the linear
    layers at indices 0, 2 and 4 and the ReLU activations at indices 1 and 3.
    No activation is applied after the final linear layer.
    """

    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes):
        super().__init__()
        first_hidden = nn.Linear(num_features, num_hidden_1)
        second_hidden = nn.Linear(num_hidden_1, num_hidden_2)
        output = nn.Linear(num_hidden_2, num_classes)
        self.layers = nn.Sequential(
            first_hidden,
            nn.ReLU(),
            second_hidden,
            nn.ReLU(),
            output,
        )

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        out = self.layers(x)
        return out

# Architecture parameters
num_features = 28 * 28  # length of each input once a batch is flattened to (batch, -1)
num_hidden_1 = 256
num_hidden_2 = 128
num_classes = 10

# Training settings
# torch.backends.mps.is_available() is the documented availability check for
# the MPS backend; torch.mps.is_available is missing from older PyTorch
# releases and fails there with AttributeError.
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
learning_rate = 0.001
num_epochs = 5

# Build the baseline network and move it to the target device before creating
# the optimizer over its parameters.
model = MultilayerPerceptron(
    num_features, num_hidden_1, num_hidden_2, num_classes
).to(DEVICE)

optimizer_pretrained = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Echo the configuration
print(DEVICE)
print(model)
print(optimizer_pretrained)

# Dataset loading
BATCH_SIZE = 64
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
)

# Only the training batches are shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

# Print the shapes of the first training batch, then stop iterating.
for images, labels in train_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

# Evaluation function
def compute_accuracy(model, data_loader, device):
    """Return the classification accuracy of ``model`` on ``data_loader`` in percent.

    Args:
        model: Module called on a flattened ``(batch, -1)`` feature tensor and
            returning one row of class scores per example.
        data_loader: Iterable yielding ``(features, targets)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The percentage (0-100) of examples whose highest-scoring class equals
        the target, or ``0.0`` when the loader yields no examples.

    Note:
        The model is switched to evaluation mode and is left in that mode.
    """
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, targets in data_loader:
            # Flatten everything except the batch dimension.
            features = features.view(features.size(0), -1).to(device)
            targets = targets.to(device)
            logits = model(features)
            _, predicted_labels = torch.max(logits, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum().item()
    # An empty loader would otherwise raise ZeroDivisionError below.
    if num_examples == 0:
        return 0.0
    return correct_pred / num_examples * 100

# Training loop
def train(num_epochs, model, optimizer, train_loader, device):
    """Train ``model`` with cross-entropy loss for ``num_epochs`` epochs.

    Each batch is flattened to ``(batch, -1)`` and moved to ``device`` before
    the forward pass. The loss is printed every 400 batches. After each epoch
    the accuracy on ``train_loader`` and the elapsed time are printed, and the
    total time is printed once at the end. Returns nothing.
    """
    began = time.time()

    def minutes_elapsed():
        # Wall-clock minutes since training started.
        return (time.time() - began) / 60

    for epoch_idx in range(num_epochs):
        epoch_tag = f'Epoch: {epoch_idx+1:03d}/{num_epochs:03d}'
        model.train()

        for step, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.view(inputs.size(0), -1).to(device)
            labels = labels.to(device)

            loss = F.cross_entropy(model(inputs), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 400 == 0:
                print(f'{epoch_tag} | Batch {step:03d}/{len(train_loader)} | Loss: {loss:.4f}')

        with torch.no_grad():
            epoch_acc = compute_accuracy(model, train_loader, device)
            print(f'{epoch_tag} | Training Accuracy: {epoch_acc:.2f}%')

        print(f'Time elapsed: {minutes_elapsed():.2f} min')
    print(f'Total Training Time: {minutes_elapsed():.2f} min')

# Train baseline model
train(num_epochs, model, optimizer_pretrained, train_loader, DEVICE)
print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')

# Replace Linear layers with LoRA
# Work on a deep copy so the baseline model is left untouched; the linear
# layers sit at indices 0, 2 and 4 of the sequential stack.
model_lora = copy.deepcopy(model)
for layer_idx in (0, 2, 4):
    model_lora.layers[layer_idx] = LinearWithLoRAMerged(
        model_lora.layers[layer_idx], rank=4, alpha=8
    )

model_lora.to(DEVICE)
# NOTE: this optimizer covers every parameter; optimizer_lora is assigned
# again after the linear layers are frozen, before fine-tuning starts.
optimizer_lora = torch.optim.Adam(model_lora.parameters(), lr=learning_rate)
print(model_lora)

# Compare both models before any fine-tuning.
print(f'Test accuracy orig model: {compute_accuracy(model, test_loader, DEVICE):.2f}%')
print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')

# Freeze original linear weights
def freeze_linear_layers(model):
    """Disable gradients for every ``nn.Linear`` nested inside ``model``.

    Walks the module tree depth-first. Children that are ``nn.Linear`` have
    all their parameters frozen; other children are searched recursively.
    ``model`` itself is never frozen, only its descendants. Returns nothing.
    """
    for submodule in model.children():
        if not isinstance(submodule, nn.Linear):
            freeze_linear_layers(submodule)
            continue
        # Sets requires_grad=False on each parameter of this linear layer.
        submodule.requires_grad_(False)

freeze_linear_layers(model_lora)

# List which parameters remain trainable after freezing.
for name, param in model_lora.named_parameters():
    print(f'{name}: {param.requires_grad}')

# Train LoRA fine-tuned model
# Only parameters that still require gradients are handed to the optimizer.
optimizer_lora = torch.optim.Adam(
    (p for p in model_lora.parameters() if p.requires_grad),
    lr=learning_rate,
)
train(num_epochs, model_lora, optimizer_lora, train_loader, DEVICE)

print(f'Test accuracy LoRA finetune: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')
print(f'Test accuracy orig model: {compute_accuracy(model, test_loader, DEVICE):.2f}%')
print(f'Test accuracy LoRA model: {compute_accuracy(model_lora, test_loader, DEVICE):.2f}%')


tensor([[-0.1115,  0.1204, -0.3696, -0.2404, -1.1969,  0.2093, -0.9724, -0.7550,
          0.3239, -0.1085]])
Linear(in_features=10, out_features=5, bias=True)
Original output: tensor([[ 0.4257,  0.0299, -0.1865, -0.3084,  0.4543]],
       grad_fn=<AddmmBackward0>)
tensor([[ 0.4257,  0.0299, -0.1865, -0.3084,  0.4543]], grad_fn=<AddBackward0>)
tensor([[ 0.4257,  0.0299, -0.1865, -0.3084,  0.4543]],
       grad_fn=<AddmmBackward0>)
mps
MultilayerPerceptron(
  (layers): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    decoupled_weight_decay: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)
Image batch dimensions: torch.Size([6