# Example of fine-tuning the 176-billion-parameter Bloom with 8-bit weights

This notebook shows an example of how to fine-tune Bloom with Low-Rank Adapters (LoRA). It is heavily inspired by [Hivemind's work](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es).

In [1]:
#!g1.1
import transformers
transformers.__version__

  from .autonotebook import tqdm as notebook_tqdm


'4.22.2'

In [3]:
#!g1.1
#%pip uninstall transformers tokenizers
#%pip install transformers==4.20.1
%pip install bitsandbytes
%pip install datasets
%pip install accelerate

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.
Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


### Load and convert original Bloom structure to 8-bit LoRA

You can load an already compressed 8-bit version of Bloom from [joaoalvarenga/bloom-8bit](https://huggingface.co/joaoalvarenga/bloom-8bit), but first we need to make some adaptations to the original model structure. Some of the following code is adapted from [Hivemind's GPT-J 8-bit fine-tuning notebook](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es).

In [4]:
#!g1.1
import transformers

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

from tqdm.auto import tqdm


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link
CUDA SETUP: Loading binary /home/dm/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cpu.so...


  warn(
  warn(
  warn(


In [5]:
import gc
import json
import os
import re
import shutil
import tempfile
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from packaging import version
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss


In [6]:
from torch.cuda.amp import custom_fwd, custom_bwd

class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is kept as frozen, blockwise-quantized 8-bit data.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``)
    are registered as buffers with ``requires_grad=False``. ``bias`` (if any)
    is kept as given. A trainable ``adapter`` module can be attached after
    construction by assigning ``self.adapter``; its output is added to the
    frozen projection.

    Args:
        weight: quantized weight of shape ``(out_features, in_features)``.
        absmax: quantization state passed to ``dequantize_blockwise``.
        code: quantization state passed to ``dequantize_blockwise``.
        bias: optional ``nn.Parameter``, or ``None``.
    """

    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map and add the adapter output when one is attached."""
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        # Compare against None explicitly: module truthiness falls back to __len__,
        # so a container-like adapter of length 0 would otherwise be skipped.
        if self.adapter is not None:
            # Out-of-place add, so the autograd function's output is not mutated.
            output = output + self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Quantize ``linear.weight`` and build a frozen layer that reuses ``linear.bias``."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function):
    """Linear map over blockwise-quantized weights, differentiable w.r.t. input and bias only.

    The weights are dequantized on the fly in both passes; no gradient is
    produced for the quantized weights or their quantization state.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        """Return ``F.linear(input, dequantized_weights, bias)``; ``bias`` may be ``None``."""
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Only the quantization state is needed to rebuild the weights in backward.
        # The input activation is not saved: no weight gradient is ever computed,
        # so keeping it would only hold memory until the backward pass.
        ctx.save_for_backward(weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``(input, weights_quantized, absmax, code, bias)``."""
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Dequantize only when the input gradient is actually requested.
            weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
            grad_input = grad_output @ weights_deq
        grad_bias = None
        if ctx._has_bias:
            # Collapse all leading dims; reshape also handles a 1-D grad_output.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(dim=0)
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    """Embedding table kept as frozen, blockwise-quantized 8-bit data.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``)
    are registered as buffers with ``requires_grad=False``. A trainable
    ``adapter`` module can be attached after construction by assigning
    ``self.adapter``; its output is added to the frozen lookup.

    Args:
        weight: quantized table of shape ``(num_embeddings, embedding_dim)``.
        absmax: quantization state passed to ``dequantize_blockwise``.
        code: quantization state passed to ``dequantize_blockwise``.
    """

    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table; ``kwargs`` are forwarded to ``F.embedding``."""
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        # Compare against None explicitly: module truthiness falls back to __len__,
        # so a container-like adapter of length 0 would otherwise be skipped.
        if self.adapter is not None:
            # Out-of-place add keeps the no-grad lookup result untouched.
            output = output + self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Quantize ``embedding.weight`` and wrap it in a frozen 8-bit embedding."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Quantize ``matrix`` with ``quantize_blockwise`` one chunk at a time.

    The flattened tensor is processed in slices of ``chunk_size`` elements, each
    cloned before quantization. The ``code`` returned by one call is fed into
    the next, so every chunk after the first reuses it.

    Args:
        matrix: tensor to quantize; it is flattened with ``view(-1)``.
        chunk_size: elements per chunk; must be a multiple of 4096
            (presumably the quantizer's block size -- TODO confirm).

    Returns:
        ``(matrix_i8, (absmax, code))``: the quantized data reshaped like
        ``matrix``, the concatenated per-chunk ``absmax``, and the shared ``code``.
    """
    assert chunk_size % 4096 == 0
    code = None
    quantized_parts, absmax_parts = [], []
    flat = matrix.view(-1)
    num_chunks = (matrix.numel() - 1) // chunk_size + 1
    for index in range(num_chunks):
        offset = index * chunk_size
        piece = flat[offset: offset + chunk_size].clone()
        quantized_piece, (absmax_piece, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(quantized_piece)
        absmax_parts.append(absmax_piece)

    matrix_i8 = torch.cat(quantized_parts).reshape_as(matrix)
    return matrix_i8, (torch.cat(absmax_parts), code)


def convert_to_int8(model):
    """Swap every ``nn.Linear`` / ``nn.Embedding`` inside ``model`` for an 8-bit placeholder, in place.

    The replacements are ``FrozenBNBLinear`` / ``FrozenBNBEmbedding`` instances
    built from all-zero tensors of matching shape. The original weights are not
    quantized here; a linear layer's ``bias`` is carried over as-is.
    """
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    # one absmax entry per 4096 weight elements, rounded up
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    # one absmax entry per 4096 weight elements, rounded up
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                )
            else:
                # Any other module type is left untouched.
                continue
            setattr(parent, child_name, replacement)

In [7]:
class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):
    """Bloom block whose attention and MLP sub-modules are passed through ``convert_to_int8``."""

    def __init__(self, config, layer_number=None):
        super().__init__(config, layer_number)

        # Attention first, then the MLP; each call swaps that sub-module's
        # nn.Linear / nn.Embedding children for 8-bit placeholders.
        for submodule in (self.self_attention, self.mlp):
            convert_to_int8(submodule)


class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):
    """Bloom model that is passed through ``convert_to_int8`` right after construction."""

    def __init__(self, config):
        super().__init__(config)
        # Swap every nn.Linear / nn.Embedding found under this model for its 8-bit placeholder.
        convert_to_int8(self)
        

class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):
    """Bloom causal-LM model that is passed through ``convert_to_int8`` right after construction."""

    def __init__(self, config):
        super().__init__(config)
        # Swap every nn.Linear / nn.Embedding found under this model for its 8-bit placeholder.
        convert_to_int8(self)
        
#transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock

In [8]:
import transformers
print(transformers.__path__)

['/home/dm/.local/lib/python3.8/site-packages/transformers']


In [9]:
%load_ext autoreload
%autoreload 2

In [10]:
#!g1.1
import psutil
def get_mem():
    """Return the available system memory reported by ``psutil``, in MiB (bytes / 1024 / 1024)."""
    available_bytes = psutil.virtual_memory().available
    return available_bytes / 1024 / 1024
get_mem()

52904.12890625

In [11]:
#!g1.1
import torch
import torch.nn as nn
import torch.nn.functional as F

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
import transformers

# Checkpoint name shared by the model and tokenizer loads below.
model_name = "bigscience/bloom-7b1"
# Load the causal-LM model and its tokenizer, using "mycache" as the cache directory.
gpt = transformers.BloomForCausalLM.from_pretrained(model_name, cache_dir="mycache")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,  cache_dir="mycache")

Downloading: 100%|██████████████████████████████| 734/734 [00:00<00:00, 362kB/s]
Downloading: 100%|██████████████████████████| 27.5k/27.5k [00:00<00:00, 238kB/s]
Downloading: 100%|█████████████████████████| 9.98G/9.98G [14:38<00:00, 11.4MB/s]
Downloading: 100%|█████████████████████████| 4.16G/4.16G [06:09<00:00, 11.3MB/s]


In [12]:
#!g1.1
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Quantize ``matrix`` with ``quantize_blockwise`` one chunk at a time.

    The flattened tensor is processed in slices of ``chunk_size`` elements, each
    cloned before quantization. The ``code`` returned by one call is fed into
    the next, so every chunk after the first reuses it.

    Args:
        matrix: tensor to quantize; it is flattened with ``view(-1)``.
        chunk_size: elements per chunk; must be a multiple of 4096
            (presumably the quantizer's block size -- TODO confirm).

    Returns:
        ``(matrix_i8, (absmax, code))``: the quantized data reshaped like
        ``matrix``, the concatenated per-chunk ``absmax``, and the shared ``code``.
    """
    assert chunk_size % 4096 == 0
    code = None
    quantized_parts, absmax_parts = [], []
    flat = matrix.view(-1)
    num_chunks = (matrix.numel() - 1) // chunk_size + 1
    for index in range(num_chunks):
        offset = index * chunk_size
        piece = flat[offset: offset + chunk_size].clone()
        quantized_piece, (absmax_piece, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(quantized_piece)
        absmax_parts.append(absmax_piece)

    matrix_i8 = torch.cat(quantized_parts).reshape_as(matrix)
    return matrix_i8, (torch.cat(absmax_parts), code)


In [13]:
#!g1.1
from typing import Tuple
from torch.cuda.amp import custom_fwd, custom_bwd


class DequantizeAndLinear(torch.autograd.Function):
    """Linear map over blockwise-quantized weights, differentiable w.r.t. input and bias only.

    The weights are dequantized on the fly in both passes; no gradient is
    produced for the quantized weights or their quantization state.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        """Return ``F.linear(input, dequantized_weights, bias)``; ``bias`` may be ``None``."""
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Only the quantization state is needed to rebuild the weights in backward.
        # The input activation is not saved: no weight gradient is ever computed,
        # so keeping it would only hold memory until the backward pass.
        ctx.save_for_backward(weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``(input, weights_quantized, absmax, code, bias)``."""
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Dequantize only when the input gradient is actually requested.
            weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
            grad_input = grad_output @ weights_deq
        grad_bias = None
        if ctx._has_bias:
            # Collapse all leading dims; reshape also handles a 1-D grad_output.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(dim=0)
        return grad_input, None, None, None, grad_bias


class BNBLinearWithAdapter(nn.Module):
    """Frozen 8-bit linear layer with an optional trainable low-rank adapter.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``)
    are registered as buffers with ``requires_grad=False``. When
    ``adapter_dim > 0`` a two-layer bias-free adapter
    (``in_features -> adapter_dim -> out_features``) is created; its second
    layer is zero-initialised, so the adapter contributes nothing at first.

    Args:
        weight: quantized weight of shape ``(out_features, in_features)``.
        absmax: quantization state passed to ``dequantize_blockwise``.
        code: quantization state passed to ``dequantize_blockwise``.
        bias: optional ``nn.Parameter``, or ``None``.
        adapter_dim: adapter bottleneck width; ``0`` disables the adapter.
    """

    def __init__(self, weight, absmax, code,  bias=None, adapter_dim=0):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.bias = bias

        if adapter_dim > 0:
            self.adapter = nn.Sequential(
                nn.Linear(self.in_features, adapter_dim, bias=False),
                nn.Linear(adapter_dim, self.out_features, bias=False),
            )

            # Zero the up-projection so the layer initially matches the frozen one.
            nn.init.zeros_(self.adapter[1].weight)
        else:
            self.adapter = None

    def forward(self, input):
        """Apply the dequantized linear map and add the adapter output when one exists."""
        out = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)

        # Explicit None check: nn.Sequential truthiness depends on its length.
        if self.adapter is not None:
            return self.adapter(input) + out

        return out

    @classmethod
    def from_linear(cls, linear: nn.Linear, **kwargs) -> "BNBLinearWithAdapter":
        """Quantize ``linear.weight`` and build a layer reusing ``linear.bias``; ``kwargs`` go to ``__init__``."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias, **kwargs)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"


class BNBEmbeddingWithAdapter(nn.Module):
    """Frozen 8-bit embedding table with an optional trainable low-rank adapter.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``)
    are registered as buffers with ``requires_grad=False``. When
    ``adapter_dim > 0`` an adapter made of a small embedding
    (``num_embeddings -> adapter_dim``) followed by a bias-free linear layer
    (``adapter_dim -> embedding_dim``) is created; the linear layer is
    zero-initialised, so the adapter contributes nothing at first.

    Args:
        weight: quantized table of shape ``(num_embeddings, embedding_dim)``.
        absmax: quantization state passed to ``dequantize_blockwise``.
        code: quantization state passed to ``dequantize_blockwise``.
        adapter_dim: adapter bottleneck width; ``0`` disables the adapter.
    """

    def __init__(self, weight, absmax, code, adapter_dim=0):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))

        if adapter_dim > 0:
            self.adapter = nn.Sequential(
                nn.Embedding(self.num_embeddings, adapter_dim),
                nn.Linear(adapter_dim, self.embedding_dim, bias=False),
            )

            # Zero the up-projection so the layer initially matches the frozen one.
            nn.init.zeros_(self.adapter[1].weight)
        else:
            self.adapter = None

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table and add the adapter output when one exists.

        ``kwargs`` are forwarded to ``F.embedding`` for the frozen lookup only.
        """
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            out = F.embedding(input, weight_deq, **kwargs)
        # Explicit None check: nn.Sequential truthiness depends on its length.
        if self.adapter is not None:
            # nn.Sequential.forward accepts only the input, so forwarding **kwargs
            # to it raised TypeError whenever any keyword was supplied.
            return out + self.adapter(input)

        return out

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding, **kwargs) -> "BNBEmbeddingWithAdapter":
        """Quantize ``embedding.weight`` and wrap it; ``kwargs`` go to ``__init__``."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state, **kwargs)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"

In [14]:
#!g1.1
def bnbfy_(model, adapter_dim: int = 0):
    """Quantize every ``nn.Linear`` / ``nn.Embedding`` inside ``model`` to 8 bits, in place.

    Each matching child is replaced by ``BNBLinearWithAdapter.from_linear`` or
    ``BNBEmbeddingWithAdapter.from_embedding`` and printed as it is converted.

    Args:
        model: module whose sub-modules are scanned; it is modified in place.
        adapter_dim: forwarded to the replacement layers.

    Returns:
        None.
    """
    # The module list is snapshotted up front because children are replaced
    # while walking. No `.transformer.h` attribute is required on `model`.
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                print(name, child)
                setattr(module, name, BNBLinearWithAdapter.from_linear(child, adapter_dim=adapter_dim))

            elif isinstance(child, nn.Embedding):
                print(name, child)
                setattr(module, name, BNBEmbeddingWithAdapter.from_embedding(child, adapter_dim=adapter_dim))

In [15]:
#!g1.1
gpt

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 4096)
    (word_embeddings_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0): BloomBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
      (1): BloomBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, eleme

In [16]:
#!g1.1
# Replace every nn.Linear / nn.Embedding in `gpt` with its 8-bit counterpart, in place;
# adapter_dim=0 means no adapter is attached to the replacement layers.
bnbfy_(gpt, adapter_dim=0)

lm_head Linear(in_features=4096, out_features=250880, bias=False)
word_embeddings Embedding(250880, 4096)
query_key_value Linear(in_features=4096, out_features=12288, bias=True)
dense Linear(in_features=4096, out_features=4096, bias=True)
dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)
dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)
query_key_value Linear(in_features=4096, out_features=12288, bias=True)
dense Linear(in_features=4096, out_features=4096, bias=True)
dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)
dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)
query_key_value Linear(in_features=4096, out_features=12288, bias=True)
dense Linear(in_features=4096, out_features=4096, bias=True)
dense_h_to_4h Linear(in_features=4096, out_features=16384, bias=True)
dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)
query_key_value Linear(in_features=4096, out_features=12288, bias=True)
den

dense_4h_to_h Linear(in_features=16384, out_features=4096, bias=True)


In [None]:
# Load a second copy of the checkpoint with the same loader that produced `gpt`.
model = transformers.BloomForCausalLM.from_pretrained(model_name, cache_dir="mycache")
# Overwrite the `code` / `absmax` quantization buffers of gpt's word embeddings and
# LM head with the corresponding attributes of `model`.
# NOTE(review): in this notebook `code` / `absmax` only exist on the BNB* wrapper
# modules; a freshly loaded transformers model presumably still has plain
# nn.Embedding / nn.Linear here, in which case these reads raise AttributeError.
# This cell has no execution count -- confirm the intended source of these buffers.
gpt.transformer.word_embeddings.code = model.transformer.word_embeddings.code
gpt.lm_head.absmax = model.lm_head.absmax
gpt.transformer.word_embeddings.absmax = model.transformer.word_embeddings.absmax
gpt.lm_head.code = model.lm_head.code

In [None]:
#!g1.1
# Total bytes held by the model's parameters (element count x bytes per element).
param_size = 0
for param in gpt.parameters():
    param_size += param.nelement() * param.element_size()
# Total bytes held by the model's buffers; after bnbfy_ the quantized weights,
# absmax and code tensors are buffers, so they are counted here.
buffer_size = 0
for buffer in gpt.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

# Convert bytes to MiB (1024**2 bytes) for display.
size_all_mb = (param_size + buffer_size) / 1024**2
print('model size: {:.3f}MB'.format(size_all_mb))
# Save the model via save_pretrained under the given name.
# NOTE(review): save_pretrained presumably writes a directory, so the '.pt'
# suffix would name a folder rather than a single file -- confirm.
gpt.save_pretrained('bloom-8bit-v3.pt')

In [None]:
#!g1.1
# Tokenize a short prompt into PyTorch tensors.
prompt = tokenizer("A cat sat on a mat and", return_tensors='pt')
# Sample a continuation; min_length and max_length are both set to 10.
# NOTE(review): this generates from `model` (defined in an earlier cell that has no
# execution count), not from the quantized `gpt` -- confirm which one is intended.
out = model.generate(**prompt, min_length=10, max_length=10, do_sample=True)
# Decode the first generated sequence back to text; shown as the cell output.
tokenizer.decode(out[0])

In [None]:
# Save the quantized model via save_pretrained under the name 'bloom-8bit-v2'.
# NOTE(review): an earlier cell saves the same `gpt` as 'bloom-8bit-v3.pt' -- confirm both are needed.
gpt.save_pretrained('bloom-8bit-v2')