In [1]:
import math
import torch
import torch.nn as nn

# LoRA: low-rank adaptation of a linear layer

In [3]:


class LORALayer(nn.Module):
    """Linear layer with a trainable low-rank (LoRA) adapter.

    Computes ``y = x @ W.T + b + alpha * (x @ A) @ B``.

    Shapes: ``W`` is ``(output_dim, input_dim)``, ``A`` is ``(input_dim, rank)``
    and ``B`` is ``(rank, output_dim)``. The adapter is therefore equivalent to
    using the merged weight ``W + alpha * (A @ B).T``, for any batch size and
    for non-square layers.

    Unlike the LoRA paper, which scales the update by ``alpha / rank``, the
    update here is scaled by ``alpha`` directly.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factors ``A`` and ``B``.
        alpha: Scalar multiplier applied to the adapter output.
    """

    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super().__init__()
        self.rank = rank
        self.alpha = alpha

        # Base weight and bias of the linear layer. They are left trainable here;
        # classic LoRA would freeze them and train only A and B.
        self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
        self.bias = nn.Parameter(torch.empty(output_dim))

        # Low-rank adapter factors.
        self.A = nn.Parameter(torch.empty(input_dim, rank))
        self.B = nn.Parameter(torch.empty(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise all parameters in place."""
        # Same scheme nn.Linear uses for its weight.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        # NOTE: the LoRA paper zero-initialises one factor so the adapter starts
        # as a no-op; here both factors get small Gaussian noise.
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        """Apply the base linear map plus the scaled low-rank update.

        Args:
            x: Tensor whose last dimension is ``input_dim``; leading (batch)
                dimensions are preserved.

        Returns:
            Tensor of shape ``(*x.shape[:-1], output_dim)``.
        """
        base = nn.functional.linear(x, self.weight, self.bias)
        # The low-rank term lives in activation space (shape (..., output_dim)),
        # so it is added to the layer output. Adding it to the
        # (output_dim, input_dim) weight matrix would only broadcast when the
        # batch is 1 and the layer is square, and would be wrong even then.
        lora_adjustment = self.alpha * (x @ self.A) @ self.B
        return base + lora_adjustment

# Example usage: a square 512 -> 512 layer with a rank-16 adapter.
input_dim, output_dim = 512, 512
rank = 16   # inner dimension of the low-rank factors A and B
alpha = 2   # multiplier applied to the LoRA adjustment

lora_layer = LORALayer(input_dim, output_dim, rank=rank, alpha=alpha)
lora_layer  # last expression: rendered as the cell output

LORALayer()

In [8]:
# Smoke-test LORALayer: output shape and a single SGD step.
# torch / nn come from the imports cell at the top of the notebook.
torch.manual_seed(0)  # reproducible parameters and input

# Initialize the LORALayer
input_dim, output_dim, rank, alpha = 512, 512, 16, 2
lora_layer = LORALayer(input_dim, output_dim, rank, alpha)

# Initialize an optimizer, using SGD for simplicity
optimizer = torch.optim.SGD(lora_layer.parameters(), lr=0.01)

# Create a dummy input tensor
x = torch.randn(1, input_dim)

# Forward pass
output = lora_layer(x)

# Assert the output shape is as expected
assert output.shape == (1, output_dim), "Output shape is incorrect"
print("Output shape test passed for LORALayer.")

# Snapshot every trainable parameter (weight, bias and the LoRA factors A, B)
params_before = {name: p.detach().clone() for name, p in lora_layer.named_parameters()}

# Compute a simple loss (sum of the output) and take one optimization step
loss = output.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Every parameter, including the adapter factors, should have moved
for name, param in lora_layer.named_parameters():
    assert not torch.equal(params_before[name], param.detach()), f"{name} did not update after optimization step"
print("Weights update test passed for LORALayer.")


Output shape test passed for LORALayer.
Weights update test passed for LORALayer.


# QLoRA-style layer: low-rank adapter with quantized factors

In [10]:


class QLORALayer(nn.Module):
    """Linear layer with a low-rank adapter whose factors are fake-quantized.

    Computes ``y = x @ W.T + b + alpha * (x @ A_q) @ B_q``.

    ``A_q`` and ``B_q`` are ``A`` and ``B`` after a symmetric, per-tensor,
    round-to-nearest quantize/dequantize round trip at ``quantization_bits``
    bits. Gradients reach ``A`` and ``B`` through a straight-through estimator.

    NOTE: the QLoRA paper quantizes the frozen *base* weight (4-bit NF4) and
    keeps the adapters in higher precision. This layer instead quantizes the
    adapter factors and keeps ``weight`` in full precision and trainable.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factors ``A`` and ``B``.
        alpha: Scalar multiplier applied to the adapter output.
        quantization_bits: Bit width used when quantizing ``A`` and ``B``.
    """

    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Base weight and bias (full precision, trainable)
        self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
        self.bias = nn.Parameter(torch.empty(output_dim))

        # Low-rank adapter factors (fake-quantized in forward)
        self.A = nn.Parameter(torch.empty(input_dim, rank))
        self.B = nn.Parameter(torch.empty(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise all parameters in place."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        """Symmetric per-tensor quantization of ``x`` to signed integer levels.

        Args:
            x: Floating-point tensor to quantize.
            num_bits: Bit width. Values are mapped to integers in
                ``[-qmax, qmax]`` with ``qmax = 2**(num_bits - 1) - 1``, which
                fits in a signed ``num_bits``-bit integer.

        Returns:
            ``(x_quantized, scale)``. ``x_quantized`` holds the integer levels
            (stored in ``x``'s dtype). ``scale`` is ``max|x|``, or 1 for an
            all-zero tensor. Together they satisfy
            ``x ~= x_quantized / qmax * scale``.

        Raises:
            ValueError: if ``num_bits < 2`` (no non-zero symmetric level exists).
        """
        if num_bits < 2:
            raise ValueError(f"num_bits must be >= 2, got {num_bits}")
        qmax = 2 ** (num_bits - 1) - 1
        scale = x.abs().max()
        # Avoid 0/0 -> NaN when the tensor is all zeros.
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        x_quantized = torch.round(x / scale * qmax)
        return x_quantized, scale

    def dequantize(self, x_quantized, scale, num_bits):
        """Map integer levels produced by :meth:`quantize` back to real values."""
        qmax = 2 ** (num_bits - 1) - 1
        return x_quantized / qmax * scale

    def _fake_quantize(self, w):
        """Quantize-dequantize ``w`` with a straight-through gradient."""
        w_quantized, scale = self.quantize(w, self.quantization_bits)
        w_dequantized = self.dequantize(w_quantized, scale, self.quantization_bits)
        # The forward value is w_dequantized. The backward pass treats rounding
        # as the identity, because torch.round has zero gradient almost
        # everywhere and would otherwise stop A/B from training.
        return w + (w_dequantized - w).detach()

    def forward(self, x):
        """Apply the base linear map plus the scaled, quantized low-rank update.

        Args:
            x: Tensor whose last dimension is ``input_dim``; leading (batch)
                dimensions are preserved.

        Returns:
            Tensor of shape ``(*x.shape[:-1], output_dim)``.
        """
        A_q = self._fake_quantize(self.A)
        B_q = self._fake_quantize(self.B)

        base = nn.functional.linear(x, self.weight, self.bias)
        # The adapter term lives in activation space, so it is added to the
        # output rather than to the (output_dim, input_dim) weight matrix.
        lora_adjustment = self.alpha * (x @ A_q) @ B_q
        return base + lora_adjustment

# Example usage: same 512 -> 512 shape as the LoRA example, with 8-bit adapters.
input_dim, output_dim = 512, 512
rank, alpha = 16, 2
quantization_bits = 8  # bit width used when quantizing A and B

qlora_layer = QLORALayer(
    input_dim,
    output_dim,
    rank=rank,
    alpha=alpha,
    quantization_bits=quantization_bits,
)
qlora_layer  # last expression: rendered as the cell output

QLORALayer()

In [11]:
# Smoke-test QLORALayer: output shape, a single SGD step, and quantization.
# torch / nn come from the imports cell at the top of the notebook.
torch.manual_seed(0)  # reproducible parameters and input

# Initialize the QLORALayer
input_dim, output_dim, rank, alpha, quantization_bits = 512, 512, 16, 2, 8
qlora_layer = QLORALayer(input_dim, output_dim, rank, alpha, quantization_bits)

# Initialize an optimizer, using SGD for simplicity
optimizer = torch.optim.SGD(qlora_layer.parameters(), lr=0.01)

# Create a dummy input tensor
x = torch.randn(1, input_dim)

# Forward pass
output = qlora_layer(x)

# Assert the output shape is as expected
assert output.shape == (1, output_dim), "Output shape is incorrect"
print("Output shape test passed for QLORALayer.")

# Compute a simple loss (sum of the output) and take one optimization step
loss = output.sum()
optimizer.zero_grad()
loss.backward()

original_weight = qlora_layer.weight.detach().clone()
optimizer.step()

# Check if weights have been updated
assert not torch.equal(original_weight, qlora_layer.weight.detach()), "Weights did not update after optimization step"
print("Weights update test passed for QLORALayer.")

# Test for quantization effectiveness.
# forward() quantizes A and B on the fly and never modifies the stored
# parameters, so quantize() is called directly on copies of them.
A_original = qlora_layer.A.detach().clone()
B_original = qlora_layer.B.detach().clone()
A_quantized, _ = qlora_layer.quantize(A_original, qlora_layer.quantization_bits)
B_quantized, _ = qlora_layer.quantize(B_original, qlora_layer.quantization_bits)

# Quantization should collapse the values onto fewer distinct levels
for label, before, after in (("A", A_original, A_quantized), ("B", B_original, B_quantized)):
    assert torch.unique(after).numel() < torch.unique(before).numel(), f"Quantization not effective on {label}"
print("Quantization effectiveness test passed for QLORALayer.")


Output shape test passed for QLORALayer.
Weights update test passed for QLORALayer.
Quantization effectiveness test passed for QLORALayer.
