# Simplified LoRA adaptation of FFN

We will show how to apply LoRA to a simple feed-forward network (FFN) by first pre-training it on Fashion MNIST and then fine-tuning it on MNIST. Because the two datasets have little in common, the final performance will be quite poor, but the goal is to show how parameter-efficient fine-tuning (PEFT) works in general, regardless of the model.

## Pre-Training

In [None]:
pip install datasets

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization
from tensorflow.keras.datasets import fashion_mnist

# Load the Fashion MNIST dataset as (images, labels) pairs for the train and
# test splits.
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Normalize the images from the 0-255 pixel range to 0-1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model: five Dense -> BatchNormalization -> Dropout stages of
# decreasing width, followed by a softmax classifier.
model = keras.Sequential([
    Flatten(input_shape=(28, 28)),

    Dense(1024, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(512, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(256, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(128, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(64, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),

    Dense(10, activation='softmax')  # 10 classes in Fashion MNIST
])

# Compile the model. The labels are integer class ids and the last layer
# already applies softmax, so the sparse categorical cross-entropy loss is
# used on probabilities (not logits).
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)




In [None]:
# Print a layer-by-layer overview of the pre-training model, including its
# parameter counts.
model.summary()

In [None]:
# Train the model for 15 epochs, evaluating on the test split after each epoch.
# The result is bound to a name so the notebook cell does not echo the
# returned History object.
history = model.fit(
    train_images,
    train_labels,
    epochs=15,
    validation_data=(test_images, test_labels),
)

## Lora-Adaptation

Load the new dataset

In [None]:
from tensorflow.keras.datasets import mnist

# Load the MNIST dataset as (images, labels) pairs for the train and test
# splits. This rebinds the variables that previously held Fashion MNIST.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the images from the 0-255 pixel range to 0-1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Reshape images for the model: add an explicit trailing channel axis.
train_images = train_images.reshape((-1, 28, 28, 1))
test_images = test_images.reshape((-1, 28, 28, 1))


Let's implement a LoRA layer. Remember that the LoRA implementation consists of two low-rank dense layers:



<img src='https://www.dropbox.com/scl/fi/dfhuc42h5ohcbfny14gg8/lora.png?rlkey=7ku1ocyzibdgmnkup7kmsd8gb&raw=1'  />


In [None]:
class LoraLayer(keras.layers.Layer):
    """Wrap a frozen layer and add a trainable low-rank (LoRA) update to its output.

    When this layer is trainable, ``call`` returns
    ``original_layer(inputs) + B(A(inputs))``, where ``A`` projects the inputs
    down to ``rank`` features and ``B`` projects them back up to ``dim``
    features. When it is not trainable, only the original layer's output is
    returned.

    NOTE: the update is added to the *output* of the wrapped layer, i.e. after
    whatever activation that layer applies.

    Args:
        original_layer: The layer to adapt. It is frozen by this wrapper, and
            this wrapper reuses its name.
        rank: Number of units of the down-projection ``A``.
        num_heads: Currently unused; kept so existing callers keep working.
        dim: Number of units of the up-projection ``B``. It must be
            broadcast-compatible with the last axis of the wrapped layer's
            output (normally the wrapped layer's number of units).
        trainable: Whether the LoRA branch is active and trainable.
        **kwargs: Forwarded to ``keras.layers.Layer``. A ``name`` entry is
            discarded because the original layer's name is used instead.
    """

    def __init__(
        self,
        original_layer,
        rank=8,
        num_heads =1,
        dim = 1,
        trainable=False,
        **kwargs,
    ):
        # We want to keep the name of this layer the same as the original
        # dense layer.
        original_layer_config = original_layer.get_config()
        name = original_layer_config["name"]

        kwargs.pop("name", None)

        super().__init__(name=name, trainable=trainable, **kwargs)

        self.rank = rank
        self.dim = dim

        # Layers.

        # Original dense layer.
        self.original_layer = original_layer
        # No matter whether we are training the model or are in inference mode,
        # this layer should be frozen.
        self.original_layer.trainable = False

        # LoRA dense layers.
        # A: down-projection to `rank` features, no bias.
        self.A = keras.layers.Dense(
            units=rank,
            use_bias=False,
            trainable=trainable,
            name="lora_A",
        )

        # B: up-projection to `dim` features, no bias. It starts at zero so
        # that the LoRA branch initially contributes nothing and the wrapped
        # layer's pre-trained behavior is the starting point of fine-tuning.
        self.B = keras.layers.Dense(
            units=dim,
            use_bias=False,
            kernel_initializer="zeros",
            trainable=trainable,
            name="lora_B",
        )

    def call(self, inputs):
        """Return the wrapped layer's output, plus the LoRA update if trainable."""
        original_output = self.original_layer(inputs)
        if self.trainable:
            # If we are fine-tuning the model, we will add LoRA layers' output
            # to the original layer's output.
            lora_output = self.B(self.A(inputs))
            return original_output + lora_output

        # When this layer is not trainable the LoRA branch is skipped entirely.
        # No weight merging happens here: folding the LoRA weights into the
        # original layer's kernel has to be done separately.
        return original_output

We will randomly replace some Dense layers with LoRA-adapted layers.

In [None]:
import random
# Define a function to replace dense layers with LoraLayer
def replace_with_lora(model):
    """Return a new Sequential model with some Dense layers wrapped in LoraLayer.

    Every layer of ``model`` is visited in order. Each Dense layer is wrapped
    in a trainable ``LoraLayer`` when a fresh ``random.random()`` draw exceeds
    0.5; all other layers are added unchanged. The selection is not seeded, so
    it differs between runs.

    The layer objects are shared with ``model`` (they are not copied), so the
    freezing performed by ``LoraLayer`` also affects ``model``.

    Args:
        model: Model whose ``layers`` are reused to assemble the new model.

    Returns:
        A new, not yet built, ``keras.Sequential`` model.
    """
    new_model = keras.Sequential()
    for layer in model.layers:
        if isinstance(layer, Dense) and random.random() > 0.5:
            # The up-projection width must equal the wrapped layer's number of
            # units so the LoRA update can be added to its output.
            new_model.add(LoraLayer(layer, dim=layer.units, trainable=True))
        else:
            # NOTE: layers added here keep whatever `trainable` flag they
            # already have; they are not frozen by this function.
            new_model.add(layer)
    return new_model

# Replace layers in the model
lora_model = replace_with_lora(model)

# Build the model so that all weights (including the LoRA projections) are
# created before the summary is printed. The input shape matches the MNIST
# images after they were reshaped to (28, 28, 1).
lora_model.build(input_shape=(None, 28, 28, 1))


# Compile the model. The labels are integer class ids and the last layer
# outputs softmax probabilities, as in pre-training.
lora_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)


In [None]:
# Print a layer-by-layer overview of the LoRA-adapted model, including its
# trainable and non-trainable parameter counts.
lora_model.summary()

Notice the non-trainable parameters

In [None]:
# Fine-tune the model on MNIST, evaluating on the test split after each epoch.
# The result is bound to a name so the notebook cell does not echo the
# returned History object.
lora_history = lora_model.fit(
    train_images,
    train_labels,
    epochs=5,
    validation_data=(test_images, test_labels),
)

As mentioned, the performance is poor, but the important point is that we fine-tuned only the LoRA layers.