In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import CIFAR10

import lora
from lora.nn import Conv2d
from lora import LoRA

In [15]:
def accuracy(y_true, y_pred):
    total_num = len(y_true)
    correct = torch.sum(y_true == y_pred)
    return correct / total_num

In [16]:
BATCH_SIZE = 128
epochs = 10
torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [17]:
train_data = CIFAR10(root="./data/", train=True, transform=ToTensor(), download=True)
test_data = CIFAR10(root="./data/", train=False, transform=ToTensor(), download=True)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)

In [19]:
class ConvModel(nn.Module, LoRA):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def train(self, epochs, train_dataloader, loss_fn, optimizer, device):
        for epoch in range(epochs):
            for X_train, y_train in train_dataloader:
                X_train= X_train.type(torch.float32).to(device)
                y_train = y_train.to(device)
                y_pred = self(X_train)
                loss = loss_fn(y_pred, y_train)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
    def eval(self, test_dataloader, loss_fn, accuracy_fn, device):
        with torch.inference_mode():
            test_loss, test_acc = 0, 0
            for X_test, y_test in test_dataloader:
                X_test= X_test.type(torch.float32).to(device)
                y_test = y_test.to(device)
                y_pred = self(X_test)
                test_loss += loss_fn(y_pred, y_test)
                test_acc += accuracy_fn(y_test, y_pred.argmax(dim=1))
            test_loss /= len(test_dataloader)
            test_acc /= len(test_dataloader)
            print(f"Test Acc: {test_acc} | Test Loss: {test_loss}")

    def get_lora_layers(self):
        return [self.conv1, self.conv2]

In [20]:
model = ConvModel().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [21]:
model.train(epochs, train_dataloader, loss_fn, optimizer, device)
model.eval(train_dataloader, loss_fn, accuracy, device)

Test Acc: 0.63995760679245 | Test Loss: 1.0225789546966553


In [23]:
model = lora.set_lora_configs_all(model, 8, 1, True).to(device)

In [24]:
model.train(epochs, train_dataloader, loss_fn, optimizer, device)
model.eval(train_dataloader, loss_fn, accuracy, device)

Test Acc: 0.7079923748970032 | Test Loss: 0.8303349614143372


In [27]:
lora_state = lora.lora_state_dict(model)
torch.save(lora_state, "./data/conv_lora_state.pth")