<a href="https://colab.research.google.com/github/RyanGoslenko/Data_Science/blob/main/colab/finetune_gpt_j_6B_8bit_working.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers==4.14.1
!pip install bitsandbytes
!pip install datasets==1.16.1 

In [None]:
import numpy as np
import transformers
import torch
import torch.nn.functional as functional
import matplotlib.pyplot as plt

from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from bitsandbytes.optim import Adam8bit
from tqdm.auto import tqdm
from IPython.display import clear_output

In [3]:
class FrozenBNBLinear(nn.Module):
    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias
 
    def forward(self, input):
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias).clone()
        if self.adapter:
            output += self.adapter(input)
        return output
 
    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)
 
    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function): 
    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return functional.linear(input, weights_deq, bias)
 
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        grad_input = grad_output @ weights_deq
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
 
    def forward(self, input, **kwargs):
        with torch.no_grad():
            # note: both quantuized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = functional.embedding(input, weight_deq, **kwargs)
        if self.adapter:
            output += self.adapter(input)
        return output 
 
    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)
 
    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    assert chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    flat_tensor = matrix.view(-1)
    for i in range((matrix.numel() - 1) // chunk_size + 1):
        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)
 
    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
 
 
def convert_to_int8(model):
    """Convert linear and embedding modules to 8-bit with optional adapters"""
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                print(name, child)
                setattr( 
                    module,
                    name,
                    FrozenBNBLinear(
                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                        bias=child.bias,
                    ),
                )
            elif isinstance(child, nn.Embedding):
                setattr(
                    module,
                    name,
                    FrozenBNBEmbedding(
                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                    )
                )

In [4]:
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    def __init__(self, config):
        super().__init__(config)

        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

In [5]:
def add_adapters(model, adapter_dim=16):
    assert adapter_dim > 0

    for name, module in model.named_modules():
        if "attn" in name:
            if isinstance(module, FrozenBNBLinear):
                module.adapter = nn.Sequential(
                    nn.Linear(module.in_features, adapter_dim, bias=False),
                    nn.Linear(adapter_dim, module.out_features, bias=False),
                )
                nn.init.zeros_(module.adapter[1].weight)
            elif isinstance(module, FrozenBNBEmbedding):
                module.adapter = nn.Sequential(
                    nn.Embedding(module.num_embeddings, adapter_dim),
                    nn.Linear(adapter_dim, module.embedding_dim, bias=False),
                )
                nn.init.zeros_(module.adapter[1].weight)
              
            elif hasattr(module, "adapter"):
                print("Initializing", name)
                nn.init.zeros_(module.adapter[1].weight)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

gpt = GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit", low_cpu_mem_usage=True)
add_adapters(gpt)
gpt.to(device)

In [7]:
with open('small_train.txt', 'r') as f:
    train_data = f.readlines()
with open('small_test.txt', 'r') as f:
    test_data = f.readlines()
with open('small_valid.txt', 'r') as f:
    valid_data = f.readlines()

train_data = train_data[:int(len(train_data)/100)]
test_data = test_data[:int(len(test_data)/100)]
valid_data = valid_data[:int(len(valid_data)/100)]

In [8]:
def train(train_data,
          valid_data,
          n_epochs,
          lr,
          lr_update=False,
          seq_length=256,
          weight_decay=0,
          verbose=False):
  
    if not gpt.training:
        gpt.train()

    optimizer = Adam8bit(gpt.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=int(0.05*len(train_data) * n_epochs),
            num_training_steps=len(train_data) * n_epochs)
    
    gpt.gradient_checkpointing_enable()
    
    model_best_state = {}
    train_losses = []
    train_perplexities = []
    valid_losses = []
    valid_perplexities = []
    min_valid_loss = np.inf

    for epoch in range(n_epochs):
        with torch.cuda.amp.autocast():
            # Trainning loop
            if verbose:
                print("~~~~~~~~ Train ~~~~~~~~")
            for row in tqdm(train_data):
                if len(row) <= 1:
                    continue

                batch = tokenizer(row,
                                  truncation=True,
                                  max_length=seq_length,
                                  return_tensors='pt')
                batch = {k: v.cuda() for k, v in batch.items()}

                out = gpt.forward(**batch,)
                train_loss = functional.cross_entropy(out.logits[:, :-1, :].flatten(0, -2),
                                                      batch['input_ids'][:, 1:].flatten(),
                                                      reduction='mean')
                train_perplexity = torch.exp(train_loss)

                if verbose:
                    #clear_output(wait=True)
                    #plt.plot(train_losses)
                    #plt.show()
                    print("Loss: ", train_loss.item())
                    print("Perplexity: ", train_perplexity.item())
                
                train_losses.append(train_loss.item())
                train_perplexities.append(train_perplexity.item())

                train_loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            
            # Validation loop
            if verbose:
                print("~~~~~~~~ Valid ~~~~~~~~")
            for row in tqdm(valid_data):
                if len(row) <= 1:
                    continue

                batch = tokenizer(row, truncation=True, max_length=seq_length, return_tensors='pt')
                batch = {k: v.cuda() for k, v in batch.items()}

                out = gpt.forward(**batch,)
                valid_loss = functional.cross_entropy(out.logits[:, :-1, :].flatten(0, -2),
                                                      batch['input_ids'][:, 1:].flatten(),
                                                      reduction='mean')
                valid_perplexity = torch.exp(valid_loss)

                if verbose:
                    #clear_output(wait=True)
                    #plt.plot(valid_losses)
                    #plt.show()
                    print("Loss: ", valid_loss.item())
                    print("Perplexity: ", valid_perplexity.item())
                
                valid_losses.append(valid_loss.item())
                valid_perplexities.append(valid_perplexity.item())

                if min_valid_loss > valid_loss:
                    min_valid_loss = valid_loss
                    # Saving State Dict
                    model_best_state = gpt.state_dict()

            if lr_update:
                scheduler.step()

    return model_best_state, train_losses, train_perplexities, valid_losses, valid_perplexities

In [9]:
best_state, train_losses, train_perplexities, valid_losses, valid_perplexities = train(train_data, valid_data, 1, lr=1e-7, lr_update=True, weight_decay=0.01, verbose=True)

~~~~~~~~ Train ~~~~~~~~


  0%|          | 0/59 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  6.229333400726318
Perplexity:  507.4171142578125


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.2015573978424072
Perplexity:  24.570768356323242


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.895937919616699
Perplexity:  49.202178955078125


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.8688344955444336
Perplexity:  17.616474151611328


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  4.329866409301758
Perplexity:  75.93414306640625


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  5.7811079025268555
Perplexity:  324.11810302734375


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.2423670291900635
Perplexity:  25.5942325592041


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.0220274925231934
Perplexity:  20.532878875732422


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.945387363433838
Perplexity:  51.69635772705078


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.7380805015563965
Perplexity:  5.686418056488037


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  0.5028571486473083
Perplexity:  1.653438687324524


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.9530525207519531
Perplexity:  7.050175189971924


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.939511299133301
Perplexity:  18.906604766845703


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.4587748050689697
Perplexity:  11.690479278564453


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.892099142074585
Perplexity:  6.6332783699035645


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.8991730213165283
Perplexity:  18.159120559692383


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.1624107360839844
Perplexity:  8.69206714630127


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.0970659255981445
Perplexity:  8.142244338989258


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.914339542388916
Perplexity:  6.7824578285217285


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.3162474632263184
Perplexity:  10.137561798095703


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  0.6809130311012268
Perplexity:  1.9756807088851929


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.2332143783569336
Perplexity:  9.32980728149414


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.2914795875549316
Perplexity:  26.88260841369629


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  6.741514682769775
Perplexity:  846.8424682617188


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.1804239749908447
Perplexity:  3.255754232406616


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.055257558822632
Perplexity:  7.808848857879639


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.6801081895828247
Perplexity:  5.36613655090332


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.473192811012268
Perplexity:  4.3631439208984375


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.6288052797317505
Perplexity:  5.097780704498291


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.7406761646270752
Perplexity:  5.701197624206543


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.8791614770889282
Perplexity:  6.5480122566223145


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.3434433937072754
Perplexity:  3.832216501235962


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.4372999668121338
Perplexity:  4.209315299987793


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.573521137237549
Perplexity:  13.11191177368164


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  6.741514682769775
Perplexity:  846.8424682617188


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  4.311992645263672
Perplexity:  74.58897399902344


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  0.7977215051651001
Perplexity:  2.220475673675537


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  4.476186275482178
Perplexity:  87.8988037109375


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.8889341354370117
Perplexity:  48.85878372192383


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  4.786746501922607
Perplexity:  119.91060638427734


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  4.519314765930176
Perplexity:  91.77268981933594


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  5.322113037109375
Perplexity:  204.8162078857422


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.991338014602661
Perplexity:  54.127262115478516


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  4.272558689117432
Perplexity:  71.70487213134766


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  4.287032604217529
Perplexity:  72.7502670288086


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  5.498960971832275
Perplexity:  244.43783569335938
Loss:  4.902790546417236
Perplexity:  134.66505432128906
~~~~~~~~ Valid ~~~~~~~~


  0%|          | 0/15 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  6.229333400726318
Perplexity:  507.4171142578125


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.913536310195923
Perplexity:  18.421829223632812


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  2.157444715499878
Perplexity:  8.649008750915527


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.6429179906845093
Perplexity:  5.170234203338623


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.2129244804382324
Perplexity:  3.3633060455322266


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  0.9803085923194885
Perplexity:  2.665278434753418


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.8127042055130005
Perplexity:  6.1269941329956055


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.1879407167434692
Perplexity:  3.2803189754486084


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  3.4230241775512695
Perplexity:  30.662002563476562


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Loss:  1.7974278926849365
Perplexity:  6.034107208251953
Loss:  2.2032575607299805
Perplexity:  9.054460525512695


In [10]:
def predict(test_data, seq_length=256):
    test_losses = []
    test_perplexities = []

    if gpt.training:
        gpt.eval()

    for row in tqdm(test_data):
        if len(row) <= 1:
            continue

        batch = tokenizer(row,
                          truncation=True,
                          max_length=seq_length,
                          return_tensors='pt')
        batch = {k: v.cuda() for k, v in batch.items()}

        out = gpt.forward(**batch,)
        test_loss = functional.cross_entropy(out.logits[:, :-1, :].flatten(0, -2),
                                              batch['input_ids'][:, 1:].flatten(),
                                              reduction='mean')
        test_perplexity = torch.exp(test_loss)
        test_losses.append(test_loss.item())
        test_perplexities.append(test_perplexity.item())
    return test_losses, test_perplexities

In [11]:
test_losses, test_perplexities = predict(test_data)

  0%|          | 0/5 [00:00<?, ?it/s]

In [12]:
print("Mean test perplexity: ")
print(np.mean(test_perplexities))

Mean test perplexity: 
256.3491930067539
