In [None]:
# @title # Disable Weights and Biases

# @markdown ### Tick this box if you are using this notebook in Kaggle
disable_wandb = False # @param {type:"boolean"}

import os


def disable_wandb_logging():
    """Switch Weights & Biases off for this process via its environment variable.

    Setting ``WANDB_MODE=disabled`` turns wandb calls into no-ops. Unlike calling
    ``wandb.init(mode="disabled")``, this does not require ``wandb`` to be
    importable, so the cell is safe to run in any environment.
    """
    os.environ["WANDB_MODE"] = "disabled"


if disable_wandb:
    disable_wandb_logging()

In [None]:
# @title # Installing Libraries

!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U trl

In [2]:
# @title # Importing Libraries

import torch
import gc
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from trl import SFTTrainer

from datasets import load_dataset

# Use CUDA when PyTorch can see a GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [3]:
# @title ## Free Memory

def clean():
    """Reclaim memory between experiments.

    Runs Python's garbage collector first, so objects that are no longer
    referenced are actually freed. Then asks PyTorch to release its cached,
    unused CUDA memory.
    """
    gc.collect()
    # empty_cache() only has work to do when CUDA is in use; on a CPU-only
    # runtime there is nothing cached to release.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


clean()

In [4]:
# @title ## Your Model and Dataset

# @markdown ### Model
# @markdown Select your model

# Hub id passed to AutoModelForCausalLM.from_pretrained and
# AutoTokenizer.from_pretrained in the later cells.
model_name = "Sharathhebbar24/math_gpt2_sft" # @param {type:"string"}

# @markdown ### Dataset
# @markdown Select your dataset

# Hub id passed to datasets.load_dataset in the loading cell.
dataset_name = "gamino/wiki_medical_terms" # @param {type:"string"}

# @markdown ### Choose your split

# Split name passed to load_dataset; an empty value is replaced by "train"
# when the dataset is loaded.
split = "train" # @param {type: "string"}


In [None]:
# @title ## Load Model and Dataset

# Load the model and dataset chosen in the form above.
# Failures are reported with a hint and then re-raised: every later cell needs
# `model` and `dataset`, so continuing would only surface as a confusing
# NameError further down the notebook.
try:
  model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
except Exception:
  print("Check if the model exists or not")
  raise

# Treat both None and "" as "no split given".
if not split:
  split = "train"

try:
  dataset = load_dataset(dataset_name, split=split)
except Exception:
  print("Check if dataset or split exists or not")
  raise

num_rows = dataset.num_rows
# Convert only the first rows to pandas. `dataset.to_pandas().head()` would
# materialise the whole split just to show five rows.
print(dataset.select(range(min(5, num_rows))).to_pandas())
print("Total number of rows in dataset is: ", num_rows)


In [6]:
# @title ## Test Size
test_size = 0.1 # @param {type:"slider", min:0.1, max:0.5, step:0.1}
# @markdown Optionally train on only the first `max_rows` shuffled rows (0 = all rows)
max_rows = 0 # @param {type:"integer"}

# Fixed seed so the shuffle and the train/test split are reproducible.
split_seed = 42

# Read the size from the dataset itself rather than `num_rows` from the loading
# cell, so this cell does not depend on that variable having been set.
available_rows = len(dataset)
rows_to_use = min(max_rows, available_rows) if max_rows > 0 else available_rows

# NOTE: this rebinds `dataset` from a single Dataset to a train/test
# DatasetDict. Re-run the loading cell before re-running this one.
dataset = (
    dataset.shuffle(seed=split_seed)
    .select(range(rows_to_use))
    .train_test_split(test_size=test_size, seed=split_seed)
)
dataset

DatasetDict({
    train: Dataset({
        features: ['page_title', 'page_text', '__index_level_0__'],
        num_rows: 6174
    })
    test: Dataset({
        features: ['page_title', 'page_text', '__index_level_0__'],
        num_rows: 687
    })
})

In [None]:
# @title ## Data Splitting and Tokenizing

train_dataset = dataset['train']
test_dataset = dataset['test']

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Padding needs a pad token. GPT-2 style tokenizers typically have none, so EOS
# is reused. Only fill it in when it is missing, so a checkpoint that already
# defines its own pad token keeps it.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# No data collator is built here: the SFTTrainer in the training cell is not
# given one, so a collator created in this cell would never be used.
tokenizer

In [11]:
# @title ## Training Arguments

max_steps = 100 # @param {type:"integer"}
gradient_accumulation_steps = 2 # @param {type:"integer"}
per_device_train_batch_size = 2 # @param {type:"integer"}
per_device_eval_batch_size = 2 # @param {type:"integer"}
learning_rate = 2e-5 # @param {type:"number"}
warmup_steps = 30 # @param {type:"integer"}
logging_steps = 50 # @param {type:"integer"}
eval_steps = 50 # @param {type:"integer"}
output_dir = "./models/gpt2" # @param {type:"string"}

# @markdown ## Enable it if you are using GPU
fp16 = True # @param {type:"boolean"}
# fp16 mixed precision is meant for CUDA GPUs. On a CPU runtime it either gives
# no benefit or, depending on the transformers version, is rejected outright.
# So the checkbox only takes effect when a GPU is present.
use_fp16 = fp16 and torch.cuda.is_available()

training_arguments = TrainingArguments(
    output_dir=output_dir,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # `eval_strategy` replaces the removed `evaluation_strategy` argument.
    eval_strategy="steps",
    do_eval=True,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    log_level="debug",
    # Checkpointing is off. The two options below only matter if save_strategy
    # is changed.
    save_strategy="no",
    save_total_limit=2,
    save_safetensors=True,
    fp16=use_fp16,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    eval_steps=eval_steps,
    max_steps=max_steps,
    warmup_steps=warmup_steps,
    lr_scheduler_type="cosine",
)

In [None]:
# @title ## Training

dataset_text_field = "page_text" # @param {type:"string"}
max_seq_length = 512 # @param {type:"integer"}

# Gather the SFTTrainer configuration in one mapping, then build the trainer.
sft_trainer_kwargs = dict(
    model=model,
    args=training_arguments,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # Name of the dataset column that holds the raw training text.
    dataset_text_field=dataset_text_field,
    max_seq_length=max_seq_length,
)
trainer = SFTTrainer(**sft_trainer_kwargs)

trainer.train()

In [None]:
# @title ## Pushing to Hub
MODEL_PATH = "Sharathhebbar24/math_gpt2_sft" # @param {type:"string"}
HF_TOKEN = "" # @param {type:"string"}

# NOTE(review): the default MODEL_PATH is the same repo id as the default
# `model_name`. Pushing there overwrites the base checkpoint with this
# fine-tune, so change it if that is not intended.

# A token typed into the form is saved inside the notebook. Prefer leaving the
# field empty: `None` lets huggingface_hub use its own credential lookup
# (e.g. the HF_TOKEN environment variable or a saved login) instead of sending
# an empty token string.
hub_token = HF_TOKEN or None

# Push the tokenizer first, then the model, to the same repository.
for artifact in (tokenizer, model):
    artifact.push_to_hub(
        MODEL_PATH,
        token=hub_token
    )