In [1]:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn
from torch.utils.data import DataLoader
import lora
from lora import LoRA
from lora.nn import Linear

In [2]:
def accuracy(y_true, y_pred):
    total_num = len(y_true)
    correct = torch.sum(y_true == y_pred)
    return correct / total_num

In [3]:
BATCH_SIZE = 128
epochs = 10
torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
train_data = MNIST("data/", train=True, transform=ToTensor(), download=True)
test_data = MNIST("data/", train=False, transform=ToTensor(), download=True)

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)

In [5]:
class LinearModel(nn.Module, LoRA):
    def __init__(self, in_features, out_features) -> None:
        super().__init__()
        self.layer1 = Linear(in_features, 1000)
        self.layer2 = Linear(1000, 2000)
        self.layer3 = Linear(2000, out_features)
    
    def forward(self, X):
        out = self.layer1(X)
        out = self.layer2(out)
        out = self.layer3(out)
        return out
    
    def train(self, epochs, train_dataloader, loss_fn, optimizer, device):
        for epoch in range(epochs):
            for X_train, y_train in train_dataloader:
                X_train= X_train.view(-1, 28 * 28).type(torch.float32).to(device)
                y_train = y_train.to(device)
                y_pred = self(X_train)
                loss = loss_fn(y_pred, y_train)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
    def eval(self, test_dataloader, loss_fn, accuracy_fn, device):
        with torch.inference_mode():
            test_loss, test_acc = 0, 0
            for X_test, y_test in test_dataloader:
                X_test= X_test.view(-1, 28 * 28).type(torch.float32).to(device)
                y_test = y_test.to(device)
                y_pred = self(X_test)
                test_loss += loss_fn(y_pred, y_test)
                test_acc += accuracy_fn(y_test, y_pred.argmax(dim=1))
            test_loss /= len(test_dataloader)
            test_acc /= len(test_dataloader)
            print(f"Test Acc: {test_acc} | Test Loss: {test_loss}")
    
    def get_lora_layers(self):
        return [self.layer1, self.layer2, self.layer3]

In [6]:
model = LinearModel(28 * 28, 10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [7]:
model.train(epochs, train_dataloader, loss_fn, optimizer, device)
model.eval(test_dataloader, loss_fn, accuracy, device)

Test Acc: 0.9212816953659058 | Test Loss: 0.29458293318748474


In [8]:
lora_layers = model.get_lora_layers()
for layer in lora_layers:
    layer.set_lora_configs(rank=8, alpha=1)
    layer.set_lora_status(True)
model.to(device)

LinearModel(
  (layer1): Linear(
    (linear): Linear(in_features=784, out_features=1000, bias=True)
    (lora_A): Linear(in_features=784, out_features=8, bias=False)
    (lora_B): Linear(in_features=8, out_features=1000, bias=False)
  )
  (layer2): Linear(
    (linear): Linear(in_features=1000, out_features=2000, bias=True)
    (lora_A): Linear(in_features=1000, out_features=8, bias=False)
    (lora_B): Linear(in_features=8, out_features=2000, bias=False)
  )
  (layer3): Linear(
    (linear): Linear(in_features=2000, out_features=10, bias=True)
    (lora_A): Linear(in_features=2000, out_features=8, bias=False)
    (lora_B): Linear(in_features=8, out_features=10, bias=False)
  )
)

In [9]:
model.train(epochs, train_dataloader, loss_fn, optimizer, device)
model.eval(test_dataloader, loss_fn, accuracy, device)

Test Acc: 0.9226661920547485 | Test Loss: 0.2909885346889496


In [None]:
lora_state = lora.lora_state_dict(model)
torch.save(lora_state, "chkpts/test_lora_state.pth")

In [None]:
new_model = LinearModel(28 * 28, 10)
new_model = lora.set_lora_configs_all(new_model, rank=8, alpha=1, enable_lora=False)
lora_state = torch.load('chkpts/test_lora_state.pth', weights_only=True)
new_model.load_state_dict(lora_state, strict=False)