In [1]:
import io
import os
import torch
from tqdm.notebook import tqdm

from transformers import set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader

import torch.nn.functional as F
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

import os

from datasets import load_from_disk
from transformers import GPT2Tokenizer

model_id = "gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_base = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype = torch.bfloat16,device_map = device)
import numpy as np

dataset = "sst2"
dataset_org = load_from_disk(dataset)

yes_id = tokenizer.convert_tokens_to_ids("yes")
no_id = tokenizer.convert_tokens_to_ids("no")

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

shuffled_data = dataset_org.shuffle(seed=42)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
from peft import LoraConfig, TaskType, get_peft_model
if "gemma" in model_id:
    target_modules = ["embed_tokens","q_proj","k_proj","v_proj","o_proj","o_proj","gate_proj","up_proj","down_proj","lm_head"]
    if model_id == "gemma-7b":
        r, alpha = 64, 128
    elif model_id == "gemma-2b":
        r, alpha = 512, 1024
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False,
                             target_modules = target_modules, r=r, lora_alpha=alpha, lora_dropout=0.1)
model = get_peft_model(model_base, peft_config)
model.print_trainable_parameters()

trainable params: 233,177,088 || all params: 8,770,857,984 || trainable%: 2.658543650180712


In [3]:

if dataset == "imdb" or  dataset == "news":
    shuffled_data["test"] = shuffled_data["test"].select(range(3000))
elif dataset == "sst2":
    shuffled_data["test"] = shuffled_data["validation"]

if dataset == "imdb" or  dataset == "sst2":
    shuffled_data = shuffled_data.filter(lambda x: x["label"] in [0,1])
elif dataset == "news":
    shuffled_data = shuffled_data.filter(lambda x: x["label"] in [0,1,2,3])
if dataset == "imdb" or  dataset == "news":
    shuffled_data = shuffled_data.filter(lambda x: isinstance(x["text"], str))
elif dataset == "sst2":
    shuffled_data = shuffled_data.filter(lambda x: isinstance(x["sentence"], str))

In [4]:
world_id = tokenizer.convert_tokens_to_ids("World")
business_id = tokenizer.convert_tokens_to_ids("Business")
sport_id = tokenizer.convert_tokens_to_ids("Sport")
tech_id = tokenizer.convert_tokens_to_ids("Tech")
positive_id = tokenizer.convert_tokens_to_ids("positive")
negative_id = tokenizer.convert_tokens_to_ids("negative")
tech_id, sport_id, business_id, world_id, positive_id, negative_id

(9254, 25832, 14103, 10772, 30212, 31827)

In [5]:
def instruction(batch):
    if dataset == "news":
        batch["ins"] = batch["text"] + " You are classifying a news article, Choose one of the four categories, World, Business, Sport, and Tech. Answer is "
    elif dataset == "sst2":
        batch["ins"] = batch["sentence"] + " In a sentiment classification task between positive and negative choices, the sentiment of this sentence is "
    elif dataset == "imdb":
        batch["ins"] = batch["text"] + " Based on this opinion, decide what the sentiment is, choose between positive and negative. Answer is "
    return batch

instruction_data = shuffled_data.map(instruction, batched=False)

def gpt_label(d):
    if dataset == "news":
        if d["label"] == 0:
            d["gpt_label"] = world_id
        elif d["label"] == 1:
            d["gpt_label"] = sport_id

        elif d["label"] == 2:
            d["gpt_label"] = business_id
        elif d["label"] == 3:
            d["gpt_label"] = tech_id
            
    elif dataset == "sst2" or dataset == "imdb":
        if d["label"] == 0:
            d["gpt_label"] = negative_id
        elif d["label"] == 1:
            d["gpt_label"] = positive_id
    return d


instruction_data = instruction_data.map(gpt_label, batched=False)

instruction_data = instruction_data.remove_columns(["label"])

if dataset == "sst2":
    max_length = 32
    if model_id == "gemma-7b":
        batch_size = 8
    elif model_id == "gemma-2b":
        batch_size = 16

elif dataset == "news":
    max_length = 64
    if model_id == "gemma-7b":
        batch_size = 6
    elif model_id == "gemma-2b":
        batch_size = 12
    instruction_data["test"] = instruction_data["test"].select(range(3000))

elif dataset == "imdb":
    max_length = 128
    if model_id == "gemma-7b":
        batch_size = 3
    elif model_id == "gemma-2b":
        batch_size = 12
    instruction_data["test"] = instruction_data["test"].select(range(3000))
def func2(a):
    a = tokenizer(a['ins'], padding="max_length",max_length = max_length)
    return a

instruction_data = instruction_data.map(func2, batched= True, batch_size = batch_size)
instruction_data = instruction_data.filter(lambda x: len(x["input_ids"]) <= max_length)


Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Filter:   0%|          | 0/67349 [00:00<?, ? examples/s]

Filter:   0%|          | 0/872 [00:00<?, ? examples/s]

In [6]:

set_seed(123)


import torch

import torch
op = torch.optim.Adam(model.parameters(),lr=5e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(op,"max",factor=0.5, patience = 1,verbose=True)

num_epoch = 2
model.to(device)
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
from tqdm import tqdm
for epoch in range(num_epoch):
    loss = 0
    total = 0
    num_correct = 0

    for i in tqdm(range(0, len( instruction_data["train"]),batch_size)):
        ids = torch.tensor(instruction_data["train"][i:i+batch_size]["input_ids"],dtype=torch.int64).to(device)
        mask = torch.tensor(instruction_data["train"][i:i+batch_size]["attention_mask"],dtype=torch.int64).to(device)
        label = torch.tensor(instruction_data["train"][i:i+batch_size]["gpt_label"],dtype=torch.int64).to(device)
        logit = model(input_ids = ids,attention_mask = mask).logits[:,-1,:]
        l = loss_fn(logit, label)
        l.backward()
        op.step()
        op.zero_grad()
        num_correct = num_correct + (torch.argmax(logit,dim=-1) ==label).sum().item()
        total = total + label.shape[0]
        loss = loss + l.item()
 

    scheduler.step(loss)

    print(f"epoch {epoch} train loss: {loss/len(instruction_data['train'])}")
    print(f"epoch {epoch} train acc: {num_correct/total}")

    loss = 0
    total = 0
    num_correct = 0
    with torch.no_grad():
        for i in tqdm(range(0, len( instruction_data["test"]),batch_size)):
            ids = torch.tensor(instruction_data["test"][i:i+batch_size]["input_ids"],dtype=torch.int64).to(device)
            mask = torch.tensor(instruction_data["test"][i:i+batch_size]["attention_mask"],dtype=torch.int64).to(device)
            label = torch.tensor(instruction_data["test"][i:i+batch_size]["gpt_label"],dtype=torch.int64).to(device)
            logit = model(input_ids = ids,attention_mask = mask).logits[:,-1,:]
            l = loss_fn(logit, label)
            num_correct = num_correct + (torch.argmax(logit,dim=-1) ==label).sum().item()
            total = total + label.shape[0]
            loss = loss + l.item()
        
    print(f"epoch {epoch} eval loss: {loss/len(instruction_data['test'])}")
    print(f"epoch {epoch} eval acc: {num_correct/total}")

100%|██████████| 5723/5723 [22:49<00:00,  4.18it/s]


epoch 0 train loss: 0.03795564974565765
epoch 0 train acc: 0.8789129912401424


100%|██████████| 21/21 [00:02<00:00,  8.83it/s]


epoch 0 eval loss: 0.03171209569043684
epoch 0 eval acc: 0.8719512195121951


100%|██████████| 5723/5723 [22:50<00:00,  4.18it/s]


epoch 1 train loss: 0.015095048907970242
epoch 1 train acc: 0.9530113375712694


100%|██████████| 21/21 [00:02<00:00,  8.84it/s]

epoch 1 eval loss: 0.03298880543031662
epoch 1 eval acc: 0.9024390243902439





In [7]:
model.merge_and_unload()
torch.save(model.base_model.model.state_dict(),
            f"predictor_{model_id}_{dataset}_merged.pt")