In [None]:
! git clone https://github.com/hyintell/BLOOM-fine-tuning.git
%cd BLOOM-fine-tuning
! pip install -r requirements.txt 

# Libraries

In [1]:
import torch
import transformers
from transformers import BloomTokenizerFast, BloomForCausalLM, TrainingArguments

from datasets import load_dataset

from utils import ModifiedTrainer, tokenise_data, data_collator

 # Main

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")

In [4]:
model_name = "bloom-560m"
model = BloomForCausalLM.from_pretrained(f"bigscience/{model_name}")
tokeniser = BloomTokenizerFast.from_pretrained(f"bigscience/{model_name}", add_prefix_space=True)

In [5]:
# dataset = load_dataset('tatsu-lab/alpaca')
dataset = load_dataset("json", data_files="../nba_stats_with_text.json")
input_ids = tokenise_data(dataset, tokeniser)

model.to("cpu")

Downloading and preparing dataset json/default to C:/Users/Pierre-Hadrien/.cache/huggingface/datasets/json/default-a88fadb3b7372eba/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to C:/Users/Pierre-Hadrien/.cache/huggingface/datasets/json/default-a88fadb3b7372eba/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:00<00:00, 357.16it/s]


BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 1024)
    (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0): BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
      (1): BloomBlock(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementw

In [6]:
from dataclasses import dataclass, field
from typing import Optional

import torch
import tqdm
from transformers import Trainer

class ModifiedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = torch.ones_like(input_ids).bool().to(model.device)
        labels = inputs["input_ids"].to(model.device)
        
        return model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        ).loss


def data_collator(features: list) -> dict:
    return {"input_ids": torch.stack([torch.LongTensor(f) for f in features])}


def tokenise_data(dataset, tokenizer, max_seq_length=512):
    tokenised_list = []
    for elem in tqdm.tqdm(dataset["train"]):
        tokenised_list.append(
            tokenizer.encode(
                elem["text"],
                max_length=max_seq_length,
                padding="max_length",
                truncation=True,
            )
        )
    return tokenised_list


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="bigscience/bloom-560m")


@dataclass
class DataArguments:
    data_name_or_path: str = field(
        default="tatsu-lab/alpaca", metadata={"help": "Path to the training data."}
    )


In [7]:
model.gradient_checkpointing_enable()
model.is_parallelizable = True
model.model_parallel = True

training_args = TrainingArguments(
    "output",
    fp16=False,
    gradient_accumulation_steps= 1,
    per_device_train_batch_size = 2,
    learning_rate = 2e-5,
    num_train_epochs=2,
    logging_steps=10,
)


trainer = ModifiedTrainer(
    model=model,
    train_dataset=input_ids,
    args=training_args,
    data_collator=data_collator,
)


In [8]:
trainer.train()



  0%|          | 0/10 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


{'loss': 8.2109, 'learning_rate': 0.0, 'epoch': 2.0}
{'train_runtime': 936.4405, 'train_samples_per_second': 0.021, 'train_steps_per_second': 0.011, 'train_loss': 8.210888671875, 'epoch': 2.0}


TrainOutput(global_step=10, training_loss=8.210888671875, metrics={'train_runtime': 936.4405, 'train_samples_per_second': 0.021, 'train_steps_per_second': 0.011, 'train_loss': 8.210888671875, 'epoch': 2.0})

In [18]:
def generate(prompt='', num_samples=10, steps=20, do_sample=True):
    # Load the trained model
    model.eval()

    # Tokenize the input prompt into integer input sequence
    encoded_input = tokeniser(prompt, return_tensors='pt').to(device)
    x = encoded_input['input_ids']

    # We'll process all desired num_samples in a batch, so expand out the batch dim
    x = x.expand(num_samples, -1)

    # Forward the model `steps` times to get samples, in a batch
    with torch.no_grad():
        y = model.generate(
            x,
            max_length= 100,  # Set max_length based on input sequence length + steps
            do_sample=do_sample,
            top_k=40,
            pad_token_id=tokeniser.pad_token_id
        )

    for i in range(num_samples):
        out = tokeniser.decode(y[i].cpu().squeeze().tolist())
        print('-' * 80)
        print(out)


In [19]:
generate(prompt='"Generate specific player stats for player: Herbert Jones', num_samples=5, steps=30)

--------------------------------------------------------------------------------
"Generate specific player stats for player: Herbert Jones.", "race", { }, true, player, "hernandez", { }, { }, {}, {}, { }, { }, { }, { }, { }, { }, { }, { }, { ", ", ", ", "}", 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ", ", ", ", "}", 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, ", ", ", ", "}", 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, ", ", ", ", "}", 0, 0, 0, 0, 0, 0, 0, 0,
