In [5]:
from dataclasses import dataclass

@dataclass
class Args:
    train_file: str = "../public_data/train/track_a/sun.csv"
    test_file: str = "../public_data/dev/track_a/sun_a.csv"
    model_checkpoint: str = "../models/gemma2-9b-cpt-sea-lionv3-base-SemEval-sun"
    max_length: int = 512

args = Args()

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from peft import PeftConfig, PeftModel
from datasets import Dataset
import torch
from accelerate import Accelerator
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
test_df = pd.read_csv(args.test_file)
test_df = pd.melt(test_df.drop(["id"], axis=1), id_vars=["text"], var_name="emotion", value_name="label")
dataset = Dataset.from_pandas(test_df)

In [4]:
def preprocess_function(example):
    example["prompt"] = f"### Text: {example['text']}\n### Emotion: {example['emotion']}\n### Label: "
    return example

dataset = dataset.map(preprocess_function)

Map: 100%|██████████| 1194/1194 [00:00<00:00, 12300.28 examples/s]


In [5]:
dataset[0]

{'text': 'Aa aa salam kenal nya,ti abi 🙋 sumpah kot piseurieun 😂😂😂lucuu . Sing sukses nya aa aa😊👌',
 'emotion': 'Anger',
 'label': None,
 'prompt': '### Text: Aa aa salam kenal nya,ti abi 🙋 sumpah kot piseurieun 😂😂😂lucuu . Sing sukses nya aa aa😊👌\n### Emotion: Anger\n### Label: '}

In [6]:
tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [9]:
tokenizer.decode([2])

'<bos>'

In [7]:
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), batched=True)

Map: 100%|██████████| 1194/1194 [00:00<00:00, 14831.05 examples/s]


In [None]:
torch_dtype = torch.float16

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_quant_storage=torch_dtype,
)

device_index = Accelerator().process_index

peft_config = PeftConfig.from_pretrained(args.model_checkpoint)

model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    attn_implementation="sdpa",  # alternatively use "flash_attention_2"
    torch_dtype=torch_dtype,
    # load_in_4bit=True,
    # quantization_config=quantization_config,
    # device_map={"": device_index},
    device_map="auto",
    low_cpu_mem_usage=True,
)

Loading checkpoint shards:  25%|██▌       | 1/4 [00:42<02:08, 42.78s/it]

In [None]:
model = PeftModel.from_pretrained(model, args.model_checkpoint)
model.eval()

In [5]:
import json

with open("../results/gemma2-9b-cpt-sea-lionv3-base-SemEval-sun.json") as f:
    raw_results = json.load(f)
raw_results = raw_results["predictions"]

In [16]:
results = [result.split("\n")[0].replace(".", "") for result in raw_results]
result2id = {"yes": 1, "no": 0}

In [45]:
import pandas as pd

test_df = pd.read_csv(args.test_file)
text2id = dict(zip(test_df["text"], test_df["id"]))
test_df = pd.melt(test_df.drop(["id"], axis=1), id_vars=["text"], var_name="emotion", value_name="label")
test_df["label"] = results
test_df["label"] = test_df["label"].map(result2id)
test_df = test_df.pivot(index="text", columns="emotion", values="label").reset_index()
test_df["id"] = test_df["text"].map(text2id)
test_df = test_df.sort_values("id")
test_df = test_df[["id", "Anger", "Disgust", "Fear", "Joy", "Sadness", "Surprise"]]
test_df

emotion,id,Anger,Disgust,Fear,Joy,Sadness,Surprise
26,sun_dev_track_a_00001,0,0,0,1,0,0
7,sun_dev_track_a_00002,0,0,0,1,1,1
138,sun_dev_track_a_00003,0,0,0,1,0,1
122,sun_dev_track_a_00004,0,0,0,1,0,0
150,sun_dev_track_a_00005,0,0,0,1,0,0
...,...,...,...,...,...,...,...
86,sun_dev_track_a_00195,0,0,0,1,0,1
115,sun_dev_track_a_00196,0,0,0,1,0,0
61,sun_dev_track_a_00197,0,0,0,0,1,1
45,sun_dev_track_a_00198,0,0,0,1,0,1


In [46]:
test_df.to_csv("pred_sun_a.csv", index=False)

In [47]:
test_df

emotion,id,Anger,Disgust,Fear,Joy,Sadness,Surprise
26,sun_dev_track_a_00001,0,0,0,1,0,0
7,sun_dev_track_a_00002,0,0,0,1,1,1
138,sun_dev_track_a_00003,0,0,0,1,0,1
122,sun_dev_track_a_00004,0,0,0,1,0,0
150,sun_dev_track_a_00005,0,0,0,1,0,0
...,...,...,...,...,...,...,...
86,sun_dev_track_a_00195,0,0,0,1,0,1
115,sun_dev_track_a_00196,0,0,0,1,0,0
61,sun_dev_track_a_00197,0,0,0,0,1,1
45,sun_dev_track_a_00198,0,0,0,1,0,1
