# Simplified LoRA adaptation of FFN

We will show how to apply LoRA to a simple FFN by first pre-training it on Fashion MNIST and then fine-tuning it on MNIST. Since these two datasets have little in common, the performance will be quite bad, but the goal is to show how to do PEFT in general, regardless of the model.

## Pre-Training

In [1]:
pip install datasets

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194

In [2]:
import tensorflow as tf
from datasets import load_dataset

from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization
from tensorflow.keras.datasets import fashion_mnist

# Load the Fashion MNIST dataset
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Normalize the images
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model: five Dense -> BatchNormalization -> Dropout stages of
# decreasing width, followed by a softmax classifier.
model = keras.Sequential([
    Flatten(input_shape=(28, 28)),

    Dense(1024, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(512, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(256, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(128, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(64, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(10, activation='softmax'),  # 10 classes in Fashion MNIST
])

# Compile the model. `metrics` is passed as a list (not a bare string), which
# is the documented form and matches the compile call used for fine-tuning.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)




Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [3]:
# Print the layer-by-layer table (output shapes and parameter counts) of the
# pre-training model.
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 1024)              803840    
                                                                 
 batch_normalization (Batch  (None, 1024)              4096      
 Normalization)                                                  
                                                                 
 dropout (Dropout)           (None, 1024)              0         
                                                                 
 dense_1 (Dense)             (None, 512)               524800    
                                                                 
 batch_normalization_1 (Bat  (None, 512)               2048      
 chNormalization)                                       

In [4]:
# Pre-train on Fashion MNIST for at least 15 epochs, holding out 20% of the
# training data for validation.
model.fit(
    train_images,
    train_labels,
    batch_size=256,
    epochs=15,
    validation_split=0.2,
)

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


<keras.src.callbacks.History at 0x7adfdb0560b0>

## LoRA Adaptation

Load the new dataset

In [5]:
from tensorflow.keras.datasets import mnist

# Fetch MNIST; this rebinds the same variable names used for Fashion MNIST.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize each split, then reshape it to (-1, 28, 28, 1) for the model.
train_images = (train_images / 255.0).reshape((-1, 28, 28, 1))
test_images = (test_images / 255.0).reshape((-1, 28, 28, 1))


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Let's implement a LoRA layer. Remember that the LoRA implementation consists of two low-rank dense layers:



<img src='https://www.dropbox.com/scl/fi/dfhuc42h5ohcbfny14gg8/lora.png?rlkey=7ku1ocyzibdgmnkup7kmsd8gb&raw=1'  />


In [6]:
class LoraLayer(keras.layers.Layer):
    """Wrap a frozen layer and add a trainable low-rank (LoRA) branch to it.

    The wrapped layer is always evaluated as-is. When this layer is
    trainable, the output of two stacked bias-free dense layers
    (``lora_A`` with ``rank`` units, then ``lora_B`` with ``dim`` units) is
    added to the wrapped layer's output.

    Args:
        original_layer: The layer to adapt. It is marked non-trainable in
            place, so any other model holding the same layer object sees it
            frozen as well.
        rank: Number of units of ``lora_A`` (the low-rank bottleneck).
        num_heads: Unused; accepted only for backward compatibility.
        dim: Number of units of ``lora_B``. It has to be compatible with the
            last dimension of ``original_layer``'s output for the sum in
            ``call`` to work; for a Dense layer pass ``original_layer.units``.
        trainable: Whether the LoRA branch is trainable. When false the
            branch is skipped and this layer behaves like ``original_layer``.
        **kwargs: Forwarded to ``keras.layers.Layer``. A ``name`` entry is
            discarded because the wrapped layer's name is reused.
    """

    def __init__(
        self,
        original_layer,
        rank=8,
        num_heads=1,
        dim=1,
        trainable=False,
        **kwargs,
    ):
        # We want to keep the name of this layer the same as the original
        # dense layer.
        original_layer_config = original_layer.get_config()
        name = original_layer_config["name"]

        kwargs.pop("name", None)

        super().__init__(name=name, trainable=trainable, **kwargs)

        self.rank = rank

        # Original dense layer. No matter whether the LoRA branch is
        # trainable or not, this layer should stay frozen.
        self.original_layer = original_layer
        self.original_layer.trainable = False

        # LoRA dense layers: A projects the input down to `rank` features,
        # B projects back up to `dim` features. Neither uses a bias.
        self.A = Dense(
            units=rank, use_bias=False, trainable=trainable, name="lora_A"
        )
        # B starts at zero so that B(A(x)) == 0 before any fine-tuning step,
        # i.e. wrapping a layer does not change the pre-trained model's
        # output until the LoRA weights are actually updated.
        self.B = Dense(
            units=dim,
            use_bias=False,
            kernel_initializer="zeros",
            trainable=trainable,
            name="lora_B",
        )

    def call(self, inputs):
        """Return the wrapped layer's output, plus the LoRA update if trainable."""
        original_output = self.original_layer(inputs)
        if self.trainable:
            # If we are fine-tuning the model, we add the LoRA layers' output
            # to the original layer's output.
            # NOTE(review): the update is added to the wrapped layer's final
            # output, i.e. after any activation that layer applies (relu /
            # softmax in this notebook). LoRA is normally described as an
            # update to the weight matrix, which would add it before the
            # activation - confirm this is intended.
            lora_output = self.B(self.A(inputs))
            return original_output + lora_output

        # When this layer is not trainable the LoRA branch is skipped
        # entirely; no merging of the LoRA weights into the original layer's
        # weights is performed here.
        return original_output

We will randomly replace some of the Dense layers with LoRA-adapted layers.

In [10]:
import random
# Define a function to replace dense layers with LoraLayer
def replace_with_lora(model, rank=4, probability=0.5, seed=None, freeze_base=True):
    """Build a Sequential copy of ``model`` with random Dense layers LoRA-wrapped.

    Layers are reused, not cloned: the returned model shares layer objects
    (and therefore weights) with ``model``, and freezing a layer here also
    freezes it in ``model``.

    Args:
        model: Model whose ``layers`` are traversed in order.
        rank: Rank given to every ``LoraLayer`` that is created.
        probability: Chance, in ``[0, 1]``, that a given Dense layer is
            wrapped.
        seed: Optional seed for a private random generator, making the
            selection reproducible. When ``None`` the module-level ``random``
            state is used.
        freeze_base: When true, every layer that is *not* wrapped is marked
            non-trainable so that only the LoRA weights are updated during
            fine-tuning. Note that the random selection may pick no layer at
            all, in which case nothing is left to train.

    Returns:
        A new ``keras.Sequential`` containing the (partly wrapped) layers.

    Raises:
        ValueError: If ``probability`` is outside ``[0, 1]``.
    """
    if not 0.0 <= probability <= 1.0:
        raise ValueError(
            f"probability must be within [0, 1], got {probability!r}"
        )

    rng = random if seed is None else random.Random(seed)
    # With the default probability this is the original `random() > 0.5` test.
    threshold = 1.0 - probability

    new_model = keras.Sequential()
    for layer in model.layers:
        # The short-circuit keeps it to one random draw per Dense layer.
        if isinstance(layer, Dense) and rng.random() > threshold:
            new_model.add(
                LoraLayer(
                    original_layer=layer,
                    rank=rank,
                    dim=layer.units,
                    trainable=True,
                )
            )
        else:
            if freeze_base:
                # Keep the pre-trained weights of untouched layers fixed.
                layer.trainable = False
            new_model.add(layer)
    return new_model

# Swap a random subset of the Dense layers for LoRA-wrapped ones.
lora_model = replace_with_lora(model)

# Build for the reshaped MNIST inputs: (batch, 28, 28, 1).
lora_model.build(input_shape=(None, 28, 28, 1))

# Compile for fine-tuning with Adam at a learning rate of 1e-4.
lora_model.compile(
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
)

In [11]:
# Print the layer-by-layer table (output shapes and parameter counts) of the
# LoRA-adapted model.
lora_model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 1024)              803840    
                                                                 
 batch_normalization (Batch  (None, 1024)              4096      
 Normalization)                                                  
                                                                 
 dropout (Dropout)           (None, 1024)              0         
                                                                 
 dense_1 (Dense)             (None, 512)               524800    
                                                                 
 batch_normalization_1 (Bat  (None, 512)               2048      
 chNormalization)                                     

Notice the non-trainable parameters

In [None]:
# Fine-tune the LoRA-adapted model on MNIST, holding out 20% of the training
# data for validation.
lora_model.fit(
    train_images,
    train_labels,
    batch_size=128,
    epochs=10,
    validation_split=0.2,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10

As mentioned, the performance is poor, but the important thing is that we fine-tuned only the LoraLayers.