In [1]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as T
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader

## Data

In [2]:
transform = T.Compose(
    [T.ToTensor(), T.Normalize((0.1307, ), (0.3081,))]
)

# load the MNIST dataset
mnist_trainset = datasets.MNIST(root = "./data", train = True, download = True, transform = transform)
train_loader = DataLoader(mnist_trainset, batch_size = 10, shuffle = True)

mnist_testset = datasets.MNIST(root = "./data", train = False, download = True, transform = transform)
test_loader = DataLoader(mnist_testset, batch_size = 10, shuffle = True)

# Define the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# SimpleNeuralNetwork

class SimpleNN(nn.Module):
    def __init__(
        self,
        hidden_size_1:int = 1000,
        hidden_size_2:int = 2000
    ):
        super(SimpleNN, self).__init__()
        
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        
        return x
    
model = SimpleNN().to(device)

In [4]:
model

SimpleNN(
  (linear1): Linear(in_features=784, out_features=1000, bias=True)
  (linear2): Linear(in_features=1000, out_features=2000, bias=True)
  (linear3): Linear(in_features=2000, out_features=10, bias=True)
  (relu): ReLU()
)

In [5]:
# only 1 epoch to simulate a complete general pre-training on the data

def train(train_loader, model, epochs:int = 5, total_iterations_limit = None):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
    
    total_iterations = 0
    
    for epoch in range(epochs):
        model.train()
        
        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_loader, desc = f"Epoch {epoch+1}")
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
            
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            
            x, y = data
            x = x.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = criterion(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss = avg_loss)
            loss.backward()
            optimizer.step()
            
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
train(train_loader, model, epochs = 1)

Epoch 1: 100%|██████████| 6000/6000 [00:19<00:00, 301.13it/s, loss=0.239]


In [6]:
original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()

In [7]:
def test():
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 478.32it/s]

Accuracy: 0.95
wrong counts for the digit 0: 34
wrong counts for the digit 1: 18
wrong counts for the digit 2: 93
wrong counts for the digit 3: 36
wrong counts for the digit 4: 104
wrong counts for the digit 5: 14
wrong counts for the digit 6: 39
wrong counts for the digit 7: 54
wrong counts for the digit 8: 77
wrong counts for the digit 9: 32





- LoRA matrices 만들기 전, 기존 model에서 얼마나 많은 파라미터가 있는지 확인

In [8]:
# print the size of the weights matrices of the network
# save the count of the total number of parameters

total_parameters_original = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


- LoRA parameterization 

In [9]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cuda'):
        super().__init__()
        # Section 4.1 of the paper: 
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # Section 4.1 of the paper: 
        #   We then scale ∆Wx by α/r , where α is a constant in r. 
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. 
        #   As a result, we simply set α to the first r we try and do not tune it. 
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

- 기존 model에 parametrization 추가

In [10]:
import torch.nn.utils.parametrize as parametrize

In [11]:
def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.
    
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    model.linear1, "weight", linear_layer_parameterization(model.linear1, device)
)
parametrize.register_parametrization(
    model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
    model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)


def enable_disable_lora(enabled=True):
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

In [12]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear1, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


- 기존 모델의 파라미터는 freeze하고 LoRA 파라미터만 가동하여 파인튜닝.
- 숫자 9에 대해서만 100 배치를 사용

In [13]:
# Freeze the non-Lora parameters
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


In [14]:
for name, param in model.named_parameters():
    if 'lora' in name:
        print(f'Freezing LoRA parameter {name}')
        # param.requires_grad = False

Freezing LoRA parameter linear1.parametrizations.weight.0.lora_A
Freezing LoRA parameter linear1.parametrizations.weight.0.lora_B
Freezing LoRA parameter linear2.parametrizations.weight.0.lora_A
Freezing LoRA parameter linear2.parametrizations.weight.0.lora_B
Freezing LoRA parameter linear3.parametrizations.weight.0.lora_A
Freezing LoRA parameter linear3.parametrizations.weight.0.lora_B


In [16]:
# MNIST 데이터셋 로드, 숫자 9만 유지
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)

Epoch 1:   0%|          | 0/100 [00:00<?, ?it/s, loss=0.128] 

Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 270.73it/s, loss=0.044] 


- 파인튜닝으로 인해 원래 가중치가 변경되지 않고 LoRA에서 도입한 가중치만 변경되었는지 확인

In [17]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(model.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(model.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(model.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(model.linear1.weight, model.linear1.parametrizations.weight.original + (model.linear1.parametrizations.weight[0].lora_B @ model.linear1.parametrizations.weight[0].lora_A) * model.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(model.linear1.weight, original_weights['linear1.weight'])

In [18]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 485.34it/s]

Accuracy: 0.878
wrong counts for the digit 0: 42
wrong counts for the digit 1: 21
wrong counts for the digit 2: 100
wrong counts for the digit 3: 87
wrong counts for the digit 4: 360
wrong counts for the digit 5: 106
wrong counts for the digit 6: 41
wrong counts for the digit 7: 211
wrong counts for the digit 8: 248
wrong counts for the digit 9: 8





In [19]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 516.74it/s]

Accuracy: 0.95
wrong counts for the digit 0: 34
wrong counts for the digit 1: 18
wrong counts for the digit 2: 93
wrong counts for the digit 3: 36
wrong counts for the digit 4: 104
wrong counts for the digit 5: 14
wrong counts for the digit 6: 39
wrong counts for the digit 7: 54
wrong counts for the digit 8: 77
wrong counts for the digit 9: 32



