## 1. Import packages

In [1]:
import math
from dataclasses import dataclass, asdict
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms


import bitsandbytes as bnb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
def check_memory_allocated(model, device):
    initial_mem = torch.cuda.memory_allocated()

    model = model.to(device)

    final_mem = torch.cuda.memory_allocated()

    mem_used = final_mem - initial_mem

    print(f"GPU Memmory used: {mem_used/ (1024 ** 2):.2f} MB")


def trainable_params(model):
    total_params = 0
    for param in model.parameters():
        if param.requires_grad:
            total_params += param.numel()

    print(f"Trainable params: {total_params/1e6:.2f} M")

## 2. Data loader

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 
])

BATCH_SIZE = 256

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

## 3. LoRA and BnB quantize

In [3]:
@dataclass
class Quantize_config:
    quant_type: str = "nf4"
    compute_dtype: torch.dtype = torch.float16
    compress_statistics: bool = False
    device: str = "cuda"

@dataclass
class LoRA_config():
    r: int = 16
    lora_alpha: int = 1
    lora_dropout: float = 0.
    merge_weights: bool = True

In [4]:
class LoRA_Layer():
    def __init__(
        self, 
        r: int, 
        lora_alpha: int, 
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.merged = False
        self.merge_weights = merge_weights


In [5]:
class QLoRA_Linear(bnb.nn.Linear4bit, LoRA_Layer):
    def __init__(self, 
                 quantize_config: Quantize_config,
                 lora_config: LoRA_config,
                 source_linear: nn.Linear,
                ) -> None:
        
        bnb.nn.Linear4bit.__init__(self, 
                                   source_linear.in_features, 
                                   source_linear.out_features, 
                                   source_linear.bias is not None,
                                   **asdict(quantize_config))
        LoRA_Layer.__init__(self, **asdict(lora_config))
        
        if self.r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((self.r, source_linear.in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((source_linear.out_features, self.r)))
            self.scaling = self.lora_alpha / self.r

        if hasattr(self, 'lora_A'):
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
        
        self.weight = bnb.nn.Params4bit(data=source_linear.weight, 
                                        quant_type=quantize_config.quant_type, 
                                        requires_grad=False)

    def forward(self, x: torch.Tensor):
        result = self.forward_impl(x)
        result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result


    def forward_impl(self, x: torch.Tensor):
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            if getattr(self, "quant_state", None) is not None:
                assert self.weight.shape[1] == 1
                if not isinstance(self.weight, bnb.nn.Params4bit):
                    self.weight = bnb.nn.Params4bit(self.weight, quant_storage=self.quant_storage)
                self.weight.quant_state = self.quant_state
            else:
                print(
                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
                )
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
        out = out.to(inp_dtype)
        return out


## 4. Modeling

In [6]:
def get_peft_model(model, quantize_config, lora_config):
    new_layers = []
    for layer in model.children():
        if isinstance(layer, nn.Linear):
            new_layer = QLoRA_Linear(quantize_config, lora_config, layer)
            new_layers.append(new_layer)
        else:
            new_layers.append(layer)
    return nn.Sequential(*new_layers)

In [7]:
from torchvision.models import vgg16, VGG16_Weights

pretrain_vgg16 = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)


lora_classifier = get_peft_model(pretrain_vgg16.classifier, 
                                          Quantize_config(), 
                                          LoRA_config())


class CLS_model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.features = pretrain_vgg16.features.eval()
        for param in self.features.parameters():
            param.requires_grad_(False)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier =  nn.Sequential(
            *lora_classifier,
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=1000, out_features=10, bias=True)
        )
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [8]:
model = CLS_model()

In [9]:
check_memory_allocated(model, device)

GPU Memmory used: 117.69 MB


In [10]:
trainable_params(model)


Trainable params: 0.70 M


## 5. Training

In [12]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())


In [13]:
import time 

model.to(device)
model.train()

start = time.time()
for epoch in range(10):  
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch + 1}/{10}], Average Loss: {running_loss / len(trainloader):.4f}, GPU used: {torch.cuda.memory_allocated(0)/1e9:.2f} G')

print('Finished Training')
print(f'Training time: {time.time() - start:.2f} s')

Epoch [1/10], Average Loss: 1.5365, GPU used: 0.16 G
Epoch [2/10], Average Loss: 1.2916, GPU used: 0.16 G
Epoch [3/10], Average Loss: 1.2154, GPU used: 0.16 G
Epoch [4/10], Average Loss: 1.1515, GPU used: 0.16 G
Epoch [5/10], Average Loss: 1.1128, GPU used: 0.16 G
Epoch [6/10], Average Loss: 1.0824, GPU used: 0.16 G
Epoch [7/10], Average Loss: 1.0618, GPU used: 0.16 G
Epoch [8/10], Average Loss: 1.0405, GPU used: 0.16 G
Epoch [9/10], Average Loss: 1.0250, GPU used: 0.16 G
Epoch [10/10], Average Loss: 1.0062, GPU used: 0.16 G
Finished Training
Training time: 130.53 s


## 6. Evaluate

In [14]:
correct = 0
total = 0
model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the 10000 test images: {100 * correct / total} %')

Accuracy of the model on the 10000 test images: 64.75 %
