# LoRA fine-tuning
> Example of fine-tuning with LoRA

In this notebook we use my custom implementations of LoRA and DoRA to fine-tune a simple model.

### General Imports

In [1]:
import numpy as np
from tinygrad import Tensor, nn
import copy

# from extra.training import evaluate, train
from utils import *

##### Importing custom LoRA library

In [2]:
import os
import sys

# Get the path of the current working directory
current_dir = os.path.abspath(os.getcwd())

# Get the path of the parent directory
parent_dir = os.path.abspath(os.path.join(current_dir, ".."))

# Add the parent directory to the system path
sys.path.append(parent_dir)

# Now you can import the LoRA module
from lora_tinygrad import LoRA

# Now you can import the DoRA module
from dora_tinygrad import DoRA

### Define a simple model 

In [3]:
class TinyNet:
    def __init__(self):
        self.l1 = nn.Linear(784, 784 * 3, bias=False)
        self.l2 = nn.Linear(784 * 3, 784, bias=False)
        self.l3 = nn.Linear(784, 128, bias=False)
        self.l4 = nn.Linear(128, 10, bias=False)

    def __call__(self, x):
        x = self.l1(x).leakyrelu()
        x = self.l2(x).leakyrelu()
        x = self.l3(x).leakyrelu()
        x = self.l4(x)
        return x

## Model pre-training 

#### Hyperparameters & Fetching Dataset

In [4]:
lr = 1e-3
epochss = 3
BS = 128
n_outputs = 10

X_train, Y_train, X_test, Y_test = fetch_fashion_mnist()
steps = len(X_train) // BS

#### Defining the model and loss function

In [5]:
# Define the model
model = TinyNet()

# Define loss function
lossfn = Tensor.sparse_categorical_crossentropy
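
`Tensor.sparse_categorical_crossentropy` takes raw logits and integer class labels. As a rough illustration of the standard definition (the mean negative log-softmax probability of the true class), here is a NumPy sketch. It is not tinygrad's implementation and ignores any extra options the tinygrad method supports; the function name below is only illustrative.

```python
import numpy as np

def sparse_cce_sketch(logits, labels):
    """Mean negative log-probability of the true class (illustrative only)."""
    # Subtract the row-wise max for numerical stability before the softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class for every sample.
    return -log_probs[np.arange(len(labels)), labels].mean()

# With all-zero logits every class has probability 1/10, so the loss is log(10).
logits = np.zeros((4, 10))
labels = np.array([0, 3, 7, 9])
loss = sparse_cce_sketch(logits, labels)
print(loss)
```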

#### Training the model

In [6]:
# Pre-training the model
for _ in range(epochss):
    optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=lr)
    train(model, X_train, Y_train, optimizer, lossfn=lossfn, steps=steps, BS=BS)
    accuracy, Y_test_pred = evaluate(model, X_test, Y_test, return_predict=True)
    lr /= 1.2
    print(f"reducing lr to {lr:.7f}")

loss 0.33 accuracy 0.88: 100%|████████████████████████████████| 468/468 [00:03<00:00, 123.01it/s]
100%|███████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 258.96it/s]


test set accuracy is 0.834200
reducing lr to 0.0008333


loss 0.35 accuracy 0.88: 100%|████████████████████████████████| 468/468 [00:03<00:00, 137.96it/s]
100%|███████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 339.93it/s]


test set accuracy is 0.854100
reducing lr to 0.0006944


loss 0.26 accuracy 0.91: 100%|████████████████████████████████| 468/468 [00:03<00:00, 140.84it/s]
100%|███████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 253.79it/s]

test set accuracy is 0.852600
reducing lr to 0.0005787
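
The three learning rates printed above follow from dividing `lr` by 1.2 after each epoch, starting from 1e-3. A quick check in plain Python:

```python
# Reproduce the printed schedule: lr is divided by 1.2 at the end of every epoch.
lr = 1e-3
schedule = []
for epoch in range(3):
    lr /= 1.2
    schedule.append(f"{lr:.7f}")

print(schedule)
```

Note that `lr` keeps its final value after the loop, so any later cell that reads `lr` sees roughly 5.8e-4.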





#### Get mislabeled predictions

In [7]:
mislabeled_counts = get_mislabeled_counts(Y_test, Y_test_pred, n_output=n_outputs)
worst_class = max(mislabeled_counts, key=lambda k: mislabeled_counts[k])
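
The `utils` module is not shown in this notebook, so the exact behaviour of `get_mislabeled_counts` is not visible here. A minimal sketch of what such a helper could compute, assuming it counts, for each true class, the test samples that were predicted as something else (the name `count_mislabeled` is hypothetical):

```python
import numpy as np

def count_mislabeled(y_true, y_pred, n_output):
    """Map each class to the number of its samples that were misclassified."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return {c: int(np.sum((y_true == c) & (y_pred != c))) for c in range(n_output)}

# Tiny example with three classes.
y_true = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 0, 1, 2])
counts = count_mislabeled(y_true, y_pred, n_output=3)
worst = max(counts, key=lambda k: counts[k])
print(counts, worst)
```

The counts printed further down for the pre-trained model sum to 1474, which is consistent with the reported test accuracy of 0.8526 on the 10,000 Fashion-MNIST test images, so the real helper evidently counts each test error exactly once.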

## Finetuning

Let's start by creating a dataset for fine-tuning that focuses on the worst-performing class, to see whether fine-tuning actually improves it.

In [8]:
pretty_print_mislabeled_counts(mislabeled_counts)
print(f"Fine-tuning the worst class, {worst_class}..")
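# NOTE: the fine-tuning loops below pass `lr`, not `lrs`, so this value is unused
# and the runs shown here use the last pre-training learning rate (about 5.8e-4).
lrs = 1e-5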
epochss = 1
BS = 64

# Get a mixture which is mostly filled with the worst class
X_train, Y_train = mix_old_and_new_data(X_train, Y_train, worst_class, ratio = 0.3)
steps = len(X_train) // BS

Class 0: Missing 192
Class 1: Missing 61
Class 2: Missing 238
Class 3: Missing 74
Class 4: Missing 391
Class 5: Missing 28
Class 6: Missing 309
Class 7: Missing 64
Class 8: Missing 62
Class 9: Missing 55
Fine-tuning the worst class, 4..


### Fine-tuning without Lora (full fine-tuning)

Let's first do a full fine-tuning of the model, so that we have a baseline to compare against.

In [9]:
# Creating a copy of the model
model_full_finetuning = copy.deepcopy(model) 

# Finetuning the model
for _ in range(epochss):
    optimizer = nn.optim.Adam(nn.state.get_parameters(model_full_finetuning), lr=lr)
    # Default loss function is sparse_categorical_crossentropy
    train(model_full_finetuning, X_train, Y_train, optimizer, steps=steps, BS=BS)
    accuracy, Y_test_pred = evaluate(model_full_finetuning, X_test, Y_test, return_predict=True)

loss 0.29 accuracy 0.91: 100%|████████████████████████████████| 375/375 [00:02<00:00, 147.25it/s]
100%|███████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 305.93it/s]

test set accuracy is 0.858800





#### Visualize results

In [10]:
mislabeled_counts = get_mislabeled_counts(Y_test, Y_test_pred, n_output=n_outputs)
pretty_print_mislabeled_counts(mislabeled_counts)
print(f"New worst class: {max(mislabeled_counts, key=lambda k: mislabeled_counts[k])}")

Class 0: Missing 171
Class 1: Missing 44
Class 2: Missing 319
Class 3: Missing 77
Class 4: Missing 117
Class 5: Missing 78
Class 6: Missing 474
Class 7: Missing 61
Class 8: Missing 32
Class 9: Missing 39
New worst class: 6


### Fine-tuning with LoRA

Now let's do the LoRA fine-tuning on the same data, with a rank of 64.
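
As a reminder of what the rank controls: LoRA keeps the pre-trained weight `W` frozen and learns a low-rank correction `B @ A`, where `A` has shape `(rank, in_features)` and `B` has shape `(out_features, rank)`. The NumPy sketch below illustrates the idea under stated assumptions. It is not the code of `lora_tinygrad`; the `alpha` scaling factor and the initialisation (small random `A`, zero `B`) are assumptions borrowed from the usual LoRA formulation.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, rank = 784, 128, 64
alpha = 64  # hypothetical scaling hyperparameter; the update is scaled by alpha / rank

W = rng.standard_normal((out_features, in_features)) * 0.01  # frozen pre-trained weight
A = rng.standard_normal((rank, in_features)) * 0.01          # trainable, small random init
B = np.zeros((out_features, rank))                           # trainable, zero init

def lora_linear(x, W, A, B, scale):
    # Frozen path plus the low-rank correction (x A^T) B^T.
    return x @ W.T + (x @ A.T) @ B.T * scale

x = rng.standard_normal((2, in_features))
y0 = lora_linear(x, W, A, B, alpha / rank)

# Because B starts at zero, the adapted layer initially matches the frozen one.
print(np.allclose(y0, x @ W.T))
```

Only `A` and `B` receive gradients, which is why the number of trainable parameters grows with the rank rather than with `in_features * out_features`.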

In [11]:
# Get the LoRA model from the original model without modifying the original one
lora_model = LoRA.from_module(model, rank=64, inplace=False)

# Fine-tune only the LoRA parameters
for _ in range(epochss):
    optimizer = nn.optim.Adam(lora_model.parameters(), lr=lr)
    # Default loss function is sparse_categorical_crossentropy
    train(lora_model, X_train, Y_train, optimizer, steps=steps, BS=BS)
    accuracy, Y_test_pred = evaluate(lora_model, X_test, Y_test, return_predict=True)

loss 1.56 accuracy 0.64: 100%|████████████████████████████████| 375/375 [00:02<00:00, 172.22it/s]
100%|███████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 120.35it/s]

test set accuracy is 0.576100





#### Visualize results

In [12]:
mislabeled_counts = get_mislabeled_counts(Y_test, Y_test_pred, n_output=n_outputs)
pretty_print_mislabeled_counts(mislabeled_counts)
print(f"New worst class: {max(mislabeled_counts, key=lambda k: mislabeled_counts[k])}")

Class 0: Missing 852
Class 1: Missing 75
Class 2: Missing 671
Class 3: Missing 251
Class 4: Missing 208
Class 5: Missing 401
Class 6: Missing 938
Class 7: Missing 335
Class 8: Missing 446
Class 9: Missing 62
New worst class: 6


#### Show the parameters we trained in the model

In [13]:
original_parameters = sum(p.numel() for p in nn.state.get_parameters(model_full_finetuning))
lora_parameters = sum(p.numel() for p in lora_model.parameters())

print(f"{original_parameters = }")
print(f"{lora_parameters = }")
print(f"Percentage of parameters we update: {(lora_parameters / original_parameters) * 100:.2f}%")

original_parameters = 3789568
lora_parameters = 468608
Percentage of parameters we update: 12.37%
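
Both numbers above can be reproduced by hand from the layer shapes of `TinyNet`. The sketch below assumes that every `Linear` layer receives one adapter pair, an `A` of shape `(rank, in_features)` and a `B` of shape `(out_features, rank)`; that assumption matches the printed count.

```python
# (in_features, out_features) of the four bias-free Linear layers in TinyNet
layers = [(784, 784 * 3), (784 * 3, 784), (784, 128), (128, 10)]
rank = 64

# Full fine-tuning updates every weight matrix entry.
full = sum(i * o for i, o in layers)
# LoRA trains rank * in_features (A) plus out_features * rank (B) per layer.
lora = sum(rank * (i + o) for i, o in layers)

print(full, lora, f"{lora / full * 100:.2f}%")
```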


### Fine-tuning with DoRA

Now let's do the DoRA fine-tuning on the same data, with a rank of 64.
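
DoRA builds on the same low-rank update but splits each weight into a direction and a learned magnitude. The NumPy sketch below shows one common formulation, with one magnitude entry per output feature, initialised to the row norms of the pre-trained weight. It is an illustration under those assumptions, not the code of `dora_tinygrad`.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, out_features, rank = 6, 4, 2

W = rng.standard_normal((out_features, in_features))   # frozen pre-trained weight
A = rng.standard_normal((rank, in_features)) * 0.01    # trainable
B = np.zeros((out_features, rank))                     # trainable, zero init
m = np.linalg.norm(W, axis=1, keepdims=True)           # trainable magnitude, shape (out, 1)

def dora_weight(W, A, B, m):
    # Direction: the LoRA-updated weight, normalised row by row.
    V = W + B @ A
    return m * V / np.linalg.norm(V, axis=1, keepdims=True)

# With B = 0 and m set to the row norms of W, the effective weight equals W.
print(np.allclose(dora_weight(W, A, B, m), W))

# After an update to B, each row still has norm m: only the direction changed.
B_updated = rng.standard_normal((out_features, rank))
W_eff = dora_weight(W, A, B_updated, m)
print(np.allclose(np.linalg.norm(W_eff, axis=1, keepdims=True), m))
```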

In [14]:
# Get the DoRA model from the original model without modifying the original one
dora_model = DoRA.from_module(model, rank=64, inplace=False)

# Fine-tune only the DoRA parameters
for _ in range(epochss):
    optimizer = nn.optim.Adam(dora_model.parameters(), lr=lr)
    # Default loss function is sparse_categorical_crossentropy
    train(dora_model, X_train, Y_train, optimizer, steps=steps, BS=BS)
    accuracy, Y_test_pred = evaluate(dora_model, X_test, Y_test, return_predict=True)

loss 1.95 accuracy 0.66: 100%|████████████████████████████████| 375/375 [00:02<00:00, 182.14it/s]
100%|███████████████████████████████████████████████████████████| 79/79 [00:00<00:00, 129.30it/s]

test set accuracy is 0.691000





#### Visualize results

In [15]:
mislabeled_counts = get_mislabeled_counts(Y_test, Y_test_pred, n_output=n_outputs)
pretty_print_mislabeled_counts(mislabeled_counts)
print(f"New worst class: {max(mislabeled_counts, key=lambda k: mislabeled_counts[k])}")

Class 0: Missing 364
Class 1: Missing 88
Class 2: Missing 401
Class 3: Missing 136
Class 4: Missing 153
Class 5: Missing 548
Class 6: Missing 672
Class 7: Missing 409
Class 8: Missing 49
Class 9: Missing 270
New worst class: 6


#### Show the parameters we trained in the model

In [16]:
original_parameters = sum(p.numel() for p in nn.state.get_parameters(model_full_finetuning))
dora_parameters = sum(p.numel() for p in dora_model.parameters())

print(f"{original_parameters = }")
print(f"{dora_parameters = }")
print(f"Percentage of parameters we update: {(dora_parameters / original_parameters) * 100:.2f}%")

original_parameters = 3789568
dora_parameters = 471882
Percentage of parameters we update: 12.45%
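
The DoRA count is slightly larger than the LoRA count at the same rank. The difference is consistent with one extra magnitude entry per output feature of each layer, which is an inference from the printed numbers rather than something read from the library source:

```python
# (in_features, out_features) of the four bias-free Linear layers in TinyNet
layers = [(784, 784 * 3), (784 * 3, 784), (784, 128), (128, 10)]
rank = 64

full = sum(i * o for i, o in layers)
lora = sum(rank * (i + o) for i, o in layers)
# Assumed: one magnitude value per output feature in every layer.
magnitudes = sum(o for _, o in layers)
dora = lora + magnitudes

print(magnitudes, dora, f"{dora / full * 100:.2f}%")
```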


## Other functionalities

In this section we test some other features I implemented in the library: disabling and re-enabling the adapters, and merging them into the original weights.

In [17]:
# Getting a random example to test the model
x = Tensor.randn(1, 28, 28).reshape(-1)

# The outputs must differ, which shows that fine-tuning changed the model
assert not np.allclose(model(x).numpy(), lora_model(x).numpy()), "The outputs are too close!"

# Disable the lora parameters
lora_model.disable_lora()

# With LoRA disabled the outputs must match, which shows the original weights are untouched
assert np.allclose(model(x).numpy(), lora_model(x).numpy()), "The outputs differ!"

# Showcase that lora can be re-enabled
lora_model.enable_lora()


# Merge LoRA into a copy of the original weights (not in place)
new_model = lora_model.merge_lora(inplace=False)

# The merged model must give the same outputs as the LoRA model
assert np.allclose(new_model(x).numpy(), lora_model(x).numpy()), "The outputs differ!"

# NOTE: new_model has the same type as the original model! Inference is just as fast as in the original model.
assert isinstance(new_model, TinyNet)

print("Everything works as expected")

Everything works as expected
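
Why the three checks above hold can be seen on a single toy layer. The class below is a hypothetical stand-in written in NumPy, not the library's implementation: disabling the adapter falls back to the frozen weight, and merging folds `B @ A` into the weight so that a plain linear layer reproduces the adapted output.

```python
import numpy as np

class ToyLoRALinear:
    """Hypothetical stand-in for a LoRA-wrapped linear layer."""

    def __init__(self, W, A, B):
        self.W, self.A, self.B = W, A, B
        self.enabled = True

    def __call__(self, x):
        y = x @ self.W.T
        if self.enabled:
            # Low-rank correction, applied only while the adapter is enabled.
            y = y + (x @ self.A.T) @ self.B.T
        return y

    def merged_weight(self):
        # Fold the adapter into a single dense weight of the original shape.
        return self.W + self.B @ self.A

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 5))
A = rng.standard_normal((2, 5))
B = rng.standard_normal((3, 2))
layer = ToyLoRALinear(W, A, B)
x = rng.standard_normal((4, 5))

with_lora = layer(x)
layer.enabled = False
without_lora = layer(x)
layer.enabled = True
merged = x @ layer.merged_weight().T

print(np.allclose(without_lora, x @ W.T), np.allclose(merged, with_lora))
```

Since the merged weight has the same shape as `W`, the merged layer costs no more at inference time than the original one, which is the point of the `isinstance(new_model, TinyNet)` check.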
