<a href="https://colab.research.google.com/github/robgon-art/DeepHaiku/blob/main/Deep_Haiku_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Deep Haiku Trainer**
## Generating rhythmic prose after finetuning a large transformer with phonemes

By Robert A. Gonsalves<br/>

![image](https://raw.githubusercontent.com/robgon-art/DeepHaiku/main/deep_haiku.jpg)

You can see my article on Medium.

The source code and generated images are released under the [CC BY-SA license](https://creativecommons.org/licenses/by-sa/4.0/).<br/>
![CC BY-SA](https://licensebuttons.net/l/by-sa/3.0/88x31.png)

## Acknowledgements
- GPT-J, Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX (2021)
- R. Caruana, Multitask learning (1997)
- E. Hu, et al., LoRA: Low-rank Adaptation of Large Language Models (2021)
- Trained on Haikus collected by [bfbarry](https://www.kaggle.com/bfbarry/haiku-dataset) and [Harshit Jhalani](https://www.kaggle.com/hjhalani30/haiku-dataset) on Kaggle.com

This notebook is a proof of concept for fine-tuning [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) with limited memory. A detailed explanation of how it works can be found in [this model card](https://huggingface.co/hivemind/gpt-j-6B-8bit).

Adapted from here: https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es

And Nikita Schneider's article, [here](https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475).

In [None]:
!nvidia-smi

In [None]:
import csv
haikus_g = []
haikus_p = []
topics_g = []
topics_p = []
count = 0

!gdown --id 1ZKJiMRFwzkuJwjzhpadzAHt3KxhKNKiV

with open("Deep_Haiku.csv", newline='') as csv_read_file:
  reader = csv.DictReader(csv_read_file)

  for row in reader:
    haikus_g.append(row["Haiku G"])
    haikus_p.append(row["Haiku P"])
    topics_g.append(row["Topic G"])
    topics_p.append(row["Topic P"])
    count +=1

    # if count >= 5:
    #   break

print(count)

In [None]:
import random

# Every training line has the form "source = target", wrapped in a bracket
# pair that identifies the task:
#   graphemes  (topic_g = haiku_g)
#   phonemes   <topic_p = haiku_p>
#   g2p        [haiku_g = haiku_p]
#   p2g        {haiku_p = haiku_g}
#
# Examples:
# (encouragement = Need encouragement. / Making myself positive. / I want happiness.)
# <axn|ker|axjh|maxnt = niyd axn|ker|axjh|maxnt / mey|kaxng may|sehlf paa|zax|tihv / ay waant hhae|piy|naxs>
# [need encouragement / making myself positive / i want happiness = niyd axn|ker|axjh|maxnt / mey|kaxng may|sehlf paa|zax|tihv / ay waant hhae|piy|naxs]
# {niyd axn|ker|axjh|maxnt / mey|kaxng may|sehlf paa|zax|tihv / ay waant hhae|piy|naxs = need encouragement / making myself positive / i want happiness}

# Text generation from topics using graphemes.
grapheme_lines = [
  "".join(("(", topic, " = ", haiku, ")"))
  for topic, haiku in zip(topics_g, haikus_g)
]

# Text generation from topics using phonemes.
phoneme_lines = [
  "".join(("<", topic, " = ", haiku, ">"))
  for topic, haiku in zip(topics_p, haikus_p)
]

# Translation from graphemes to phonemes and back; the two directions are
# interleaved per haiku.
translation_lines = []
for graphemes, phonemes in zip(haikus_g, haikus_p):
  translation_lines.append("".join(("[", graphemes, " = ", phonemes, "]")))
  translation_lines.append("".join(("{", phonemes, " = ", graphemes, "}")))

data = grapheme_lines + phoneme_lines + translation_lines
random.shuffle(data)

# Show a few shuffled lines as a sanity check.
for sample_line in data[:20]:
  print(sample_line)

In [None]:
!pip install transformers
!pip install bitsandbytes
!pip install datasets

In [None]:
import transformers
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd
from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
from tqdm.auto import tqdm

### Converting the model to 8 bits.

We convert EleutherAI's GPT-J-6B model to 8 bits using Facebook's [bitsandbytes](https://github.com/facebookresearch/bitsandbytes) library. This reduces the model's size from 20 GB down to just 6 GB.

Note that we don't convert linear layer biases to 8 bits, as they take up less than 1% of the model's weight anyway.

In [None]:
class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is kept as frozen, quantized 8-bit data.

    The quantized ``weight`` and its quantization state (``absmax`` and
    ``code``) are registered as buffers with gradients disabled. A trainable
    ``adapter`` module may be attached after construction; when one is set,
    its output is added to the layer's output.

    Args:
        weight: quantized weight of shape ``(out_features, in_features)``.
        absmax: quantization state forwarded to ``dequantize_blockwise``.
        code: quantization state forwarded to ``dequantize_blockwise``.
        bias: optional ``nn.Parameter`` bias, or ``None`` for no bias.
    """

    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Apply the dequantized linear map, plus the adapter when one is attached."""
        # clone() yields a fresh tensor so the in-place add below does not
        # write into the tensor returned by the autograd function.
        output = torch.clone(DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias))
        # Compare with None explicitly: truth-testing a module falls back to
        # __len__ for container modules, so an empty container would be
        # silently treated as "no adapter".
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Build a frozen 8-bit layer from an existing ``nn.Linear``, reusing its bias."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function that dequantizes an 8-bit weight and applies ``F.linear``.

    Only the quantized weight and its quantization state are saved for the
    backward pass; the dequantized weight is rebuilt inside ``backward``.
    Gradients are produced for ``input`` and, when present, ``bias`` only.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        """Return ``F.linear(input, dequantized_weights, bias)``."""
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # Save the quantized tensors rather than weights_deq; backward
        # dequantizes again when it needs the weight.
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        # Remember whether a bias gradient has to be returned from backward.
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for the five ``forward`` inputs, in order.

        The entries for ``weights_quantized``, ``absmax`` and ``code`` are
        always ``None``; the bias entry is ``None`` when no bias was given.
        """
        # The quantized weight and its state are frozen: none of them may require a gradient.
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        # NOTE(review): `input` is unpacked here but not used below.
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        grad_input = grad_output @ weights_deq
        # Collapse all leading dimensions and sum over them to get one value per output feature.
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    """Embedding layer whose table is kept as frozen, quantized 8-bit data.

    The quantized ``weight`` and its quantization state (``absmax`` and
    ``code``) are registered as buffers with gradients disabled. A trainable
    ``adapter`` module may be attached after construction; when one is set,
    its output is added to the looked-up embeddings.

    Args:
        weight: quantized table of shape ``(num_embeddings, embedding_dim)``.
        absmax: quantization state forwarded to ``dequantize_blockwise``.
        code: quantization state forwarded to ``dequantize_blockwise``.
    """

    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table; extra kwargs go to ``F.embedding``."""
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        # Compare with None explicitly: truth-testing a module falls back to
        # __len__ for container modules, so an empty container would be
        # silently treated as "no adapter".
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Build a frozen 8-bit embedding from an existing ``nn.Embedding``."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Quantize ``matrix`` with ``quantize_blockwise``, one chunk at a time.

    The flattened matrix is cut into slices of ``chunk_size`` elements. Each
    slice is cloned and quantized on its own, and the ``code`` returned for
    one slice is passed on to the call for the next slice.

    Args:
        matrix: tensor to quantize; it is flattened with ``view(-1)``.
        chunk_size: elements per slice; must be a multiple of 4096.

    Returns:
        ``(matrix_i8, (absmax, code))`` where ``matrix_i8`` has the shape of
        ``matrix``, ``absmax`` is the concatenation of the per-slice absmax
        tensors, and ``code`` is the one returned by the last call.
    """
    assert chunk_size % 4096 == 0
    flattened = matrix.view(-1)
    num_chunks = (matrix.numel() - 1) // chunk_size + 1

    code = None
    quantized_parts, absmax_parts = [], []
    for index in range(num_chunks):
        start = index * chunk_size
        piece = flattened[start: start + chunk_size].clone()
        quantized_piece, (absmax_piece, code) = quantize_blockwise(piece, code=code)
        quantized_parts.append(quantized_piece)
        absmax_parts.append(absmax_piece)

    absmax = torch.cat(absmax_parts)
    matrix_i8 = torch.cat(quantized_parts).reshape_as(matrix)
    return matrix_i8, (absmax, code)
 
 
def convert_to_int8(model):
    """Replace linear and embedding children of ``model`` with frozen 8-bit layers.

    Every ``nn.Linear`` / ``nn.Embedding`` child found anywhere under
    ``model`` is swapped in place for a ``FrozenBNBLinear`` /
    ``FrozenBNBEmbedding`` built from zero-filled placeholder tensors of the
    matching shape. No quantization of existing weights happens here; a
    linear layer's bias object is carried over unchanged.
    """
    # Snapshot the module list first because children are replaced while walking it.
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                replacement = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    # One absmax entry per block of 4096 weight elements (rounded up).
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                replacement = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                    code=torch.zeros(256),
                )
            else:
                continue
            setattr(parent, child_name, replacement)

In [None]:
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are passed through ``convert_to_int8``."""

    def __init__(self, config):
        super().__init__(config)

        # Swap the nn.Linear / nn.Embedding children of both sub-modules for
        # frozen 8-bit placeholder layers.
        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J model that runs ``convert_to_int8`` over itself after the base constructor."""

    def __init__(self, config):
        super().__init__(config)
        # Replace remaining nn.Linear / nn.Embedding children with frozen 8-bit placeholder layers.
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal-LM model that runs ``convert_to_int8`` over itself after the base constructor."""

    def __init__(self, config):
        super().__init__(config)
        # Replace remaining nn.Linear / nn.Embedding children with frozen 8-bit placeholder layers.
        convert_to_int8(self)


# Presumably this makes the stock transformers GPT-J code build the patched
# block class defined above — confirm against the installed transformers version.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J

In [None]:
# Load the configuration and tokenizer published under the EleutherAI/gpt-j-6B name.
# NOTE(review): `config` is not referenced again in the visible cells — confirm it is still needed.
config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

In [None]:
# Load the 8-bit checkpoint through the GPTJForCausalLM subclass defined above.
gpt = GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit") #, low_cpu_mem_usage=True)
# Use the GPU when one is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpt.to(device)

### Text generation example

In [None]:
# Sample one continuation of a fixed prompt from the loaded model and print it.
with torch.no_grad():
  # Put the prompt on the device chosen for the model instead of assuming
  # CUDA, so the cell also works when `device` fell back to 'cpu'.
  prompt_tokens2 = tokenizer("My pet pug is", return_tensors="pt").input_ids.to(device)
  sample_outputs = gpt.generate(prompt_tokens2, max_length=40, do_sample=True, temperature=0.7)
  for i, sample_output in enumerate(sample_outputs):
    print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

### LoRA fine-tuning example
Here we demonstrate how to fine-tune the proposed model using low-rank adapters [(Hu et al, 2021)](https://arxiv.org/abs/2106.09685) and [8-bit Adam](https://arxiv.org/abs/2110.02861). We also use [dataset streaming API](https://huggingface.co/docs/datasets/dataset_streaming.html) to avoid downloading the large dataset.

In [None]:
def add_adapters(model, adapter_dim=16):
    """Attach a two-layer adapter to every frozen 8-bit layer under ``model``.

    For a ``FrozenBNBLinear`` the adapter is ``Linear(in, adapter_dim)``
    followed by ``Linear(adapter_dim, out)``; for a ``FrozenBNBEmbedding`` it
    is ``Embedding(num_embeddings, adapter_dim)`` followed by
    ``Linear(adapter_dim, embedding_dim)``. The linear layers have no bias.

    Args:
        model: module whose sub-modules are scanned.
        adapter_dim: inner width of each adapter; must be positive.
    """
    assert adapter_dim > 0

    for layer in model.modules():
        if isinstance(layer, FrozenBNBLinear):
            down = nn.Linear(layer.in_features, adapter_dim, bias=False)
            up = nn.Linear(adapter_dim, layer.out_features, bias=False)
        elif isinstance(layer, FrozenBNBEmbedding):
            down = nn.Embedding(layer.num_embeddings, adapter_dim)
            up = nn.Linear(adapter_dim, layer.embedding_dim, bias=False)
        else:
            continue

        # With a zero second projection and no bias, the adapter initially
        # outputs zeros and leaves the layer's output unchanged.
        nn.init.zeros_(up.weight)
        layer.adapter = nn.Sequential(down, up)

# Attach the adapters, then move the model (including the new adapter
# modules) to the selected device.
add_adapters(gpt)
gpt.to(device)

In [None]:
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer

class HaikuDataset(Dataset):
    """Tokenized dataset that packs several text lines into each example.

    Consecutive entries of ``txt_list`` are concatenated in groups of
    ``pack_size`` and each group is tokenized into one example, truncated
    and padded to ``max_length``. A final group with fewer than
    ``pack_size`` entries is kept as its own example rather than dropped.

    Args:
        txt_list: iterable of text lines.
        tokenizer: callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding="max_length")`` returning a mapping with
            ``'input_ids'`` and ``'attention_mask'``.
        max_length: length every example is truncated/padded to.
        pack_size: number of lines concatenated per example (default 8).

    Raises:
        ValueError: if ``pack_size`` is smaller than 1.
    """

    def __init__(self, txt_list, tokenizer, max_length, pack_size=8):
        if pack_size < 1:
            raise ValueError("pack_size must be a positive integer")
        self.input_ids = []
        self.attn_masks = []
        # Never populated here; kept so the attribute still exists for callers.
        self.labels = []

        texts = list(txt_list)
        # Stepping by pack_size also yields the trailing partial group, so no
        # text is silently discarded when the line count is not a multiple
        # of pack_size.
        for start in range(0, len(texts), pack_size):
            packed_text = "".join(texts[start:start + pack_size])
            encodings_dict = tokenizer(packed_text, truncation=True,
                                       max_length=max_length, padding="max_length")
            self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))

    def __len__(self):
        """Return the number of packed examples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, attention_mask)`` tensors of example ``idx``."""
        return self.input_ids[idx], self.attn_masks[idx]

In [None]:
# Largest encoding length over the individual training lines (presumably a
# token count).
# NOTE(review): HaikuDataset concatenates several lines into one example
# before truncating to this length — confirm that a single-line maximum is
# the intended limit.
max_length = max(len(tokenizer(text_line)[0]) for text_line in data)
print(max_length)

In [None]:
# Reuse the end-of-sequence token as the padding token, then tokenize the
# shuffled training lines into fixed-length examples.
tokenizer.pad_token = tokenizer.eos_token
dataset = HaikuDataset(data, tokenizer, max_length=max_length)

In [None]:
!mkdir checkpointscheckpoints
!mkdir /content/checkpoints/output
!mkdir /content/checkpoints/logs

In [None]:
# Number of examples held by the dataset.
print(len(dataset))

In [None]:
from transformers import TrainerCallback

class SaveCallback(TrainerCallback):
  """Trainer callback that periodically saves the global ``gpt`` model.

  At the end of a step, when ``state.global_step + 1`` is a multiple of 5000,
  the whole model object is written with ``torch.save`` to
  ``/content/checkpoints/output``; the file name carries that number
  zero-padded to six digits.
  """

  def on_step_end(self, args, state, control, **kwargs):
    next_step = state.global_step + 1
    if next_step % 5000 != 0:
      return
    file_name = f"/content/checkpoints/output/gpt-j-8bit_full_{str(next_step).zfill(6)}.pt"
    torch.save(gpt, file_name)

In [None]:
# Split the examples 95% / 5% into training and evaluation sets.
train_size = int(0.95 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# save_strategy="no": checkpoints are written by SaveCallback and by the
# explicit torch.save at the end of this cell.
# NOTE(review): logging_dir is the relative path "logs", while the setup cell
# creates /content/checkpoints/logs — confirm which location is intended.
training_args = TrainingArguments(
  output_dir="/content/checkpoints/",
  num_train_epochs=5,
  logging_steps=1000,
  save_strategy="no",
  per_device_train_batch_size=2,
  per_device_eval_batch_size=2,
  warmup_steps=100,
  weight_decay=0.01,
  logging_dir="logs",
)


def _collate_haiku_batch(batch):
  """Stack ``(input_ids, attention_mask)`` pairs into a model input dict.

  The labels are a separate stack of the same input ids.
  """
  return {
    'input_ids': torch.stack([example[0] for example in batch]),
    'attention_mask': torch.stack([example[1] for example in batch]),
    'labels': torch.stack([example[0] for example in batch]),
  }


Trainer(
  model=gpt,
  args=training_args,
  train_dataset=train_dataset,
  callbacks=[SaveCallback],
  eval_dataset=val_dataset,
  data_collator=_collate_haiku_batch,
).train()

# Save the whole fine-tuned model object once training has finished.
torch.save(gpt, "/content/checkpoints/gpt-j-8bit_full.pt")

In [None]:
# Prompt with an opening "(" plus a topic, i.e. the start of a
# "(topic = haiku)" training line, and sample five completions.
# The prompt goes to the device chosen for the model instead of assuming
# CUDA, so the cell also works when `device` fell back to 'cpu'.
prompt_tokens = tokenizer("(pet pug", return_tensors="pt").input_ids.to(device)
sample_outputs = gpt.generate(prompt_tokens, max_length=85, do_sample=True, 
  num_return_sequences=5, temperature=0.8)

In [None]:
# For each sample, print the text before the first ")" with its first
# character (the opening bracket of the prompt) removed.
for sample_output in sample_outputs:
  decoded = tokenizer.decode(sample_output, skip_special_tokens=True)
  first_entry = decoded.partition(")")[0]
  print(first_entry[1:])