In [1]:
import os
import torch
import torch.nn as nn
from shared.utils import estimate_model_memory

from transformers import (
    AutoTokenizer, 
    AutoConfig, 
    AutoModelForCausalLM, 
    Trainer, 
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

from datasets import load_dataset, DatasetDict

from peft import LoraConfig, get_peft_model
from shared.eval import compute_metrics


In [2]:
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

In [3]:
mem = estimate_model_memory("bigscience/bloom-560m")
print(f"Estimated model memory usage: {mem:.2f} GB")

Estimated model memory usage: 2.08 GB


In [4]:
for param in model.parameters():
    param.requires_grad = False
    if param.ndim == 1:
        param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
    def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)

In [5]:
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model=model, peft_config=config)
model.print_trainable_parameters()

'NoneType' object has no attribute 'cadam32bit_grad_fp32'
trainable params: 1,572,864 || all params: 560,787,456 || trainable%: 0.2805


  warn("The installed version of bitsandbytes was compiled without GPU support. "


### Dataset Prep

In [6]:
data = load_dataset("Abirate/english_quotes")

def merge_columns(example):
    example["prediction"] = example["quote"] + " ->: " + str(example["tags"])
    return example

data["train"] = data["train"].map(merge_columns)
items = data["train"]["prediction"][:5]
print(data)


DatasetDict({
    train: Dataset({
        features: ['quote', 'author', 'tags', 'prediction'],
        num_rows: 2508
    })
})


In [7]:
train_test_split = data['train'].train_test_split(test_size=0.3)

final_dataset = DatasetDict({
    'train': train_test_split['train'],
    'validation': train_test_split['test']
})
print(final_dataset)

DatasetDict({
    train: Dataset({
        features: ['quote', 'author', 'tags', 'prediction'],
        num_rows: 1755
    })
    validation: Dataset({
        features: ['quote', 'author', 'tags', 'prediction'],
        num_rows: 753
    })
})


In [8]:
tokinzed_dataset = final_dataset.map(lambda samples: tokenizer(samples['prediction']), batched=True)
print(tokinzed_dataset)

Map:   0%|          | 0/1755 [00:00<?, ? examples/s]

Map:   0%|          | 0/753 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['quote', 'author', 'tags', 'prediction', 'input_ids', 'attention_mask'],
        num_rows: 1755
    })
    validation: Dataset({
        features: ['quote', 'author', 'tags', 'prediction', 'input_ids', 'attention_mask'],
        num_rows: 753
    })
})


### Training

In [9]:
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    warmup_steps=3,
    max_steps=5,
    learning_rate=1e-3,
    eval_strategy="steps",
    save_strategy="steps",
    logging_steps=1,
    output_dir='outputs',
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

In [10]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model, 
    args=training_args,
    train_dataset=tokinzed_dataset['train'],
    eval_dataset=tokinzed_dataset['validation'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)
trainer.train()

max_steps is given, it will override any value given in num_train_epochs


  0%|          | 0/5 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 3.6324, 'grad_norm': 1.9654381275177002, 'learning_rate': 0.0003333333333333333, 'epoch': 0.0}


  0%|          | 0/753 [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 11.51 GB, other allocations: 36.70 MB, max allowed: 18.13 GB). Tried to allocate 6.73 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).