In [None]:
import torch
import sys
sys.path.append('../')
from model.continuous_prompt import ContinuousPromptingLLM
from model.recsys_encoder import RecsysContinuousPromptModel
from model.projection import BasicProjection
from dataset import RecsysDataset

from tqdm import tqdm
from util import convert_answer

In [None]:
# --- Run selection -----------------------------------------------------------
MODE = 'test'                    # dataset split; selects f"{MODE}.jsonl" under TASKS_DIR
TASK = 'recommendation'          # task name; interpolated into SAVE_DIR and TASKS_DIR
MODEL_NAME = 'light-gcn'         # prefix of the checkpoint files under SAVE_DIR

# --- Hardware ----------------------------------------------------------------
DEVICE = 'cuda:2'

# --- Filesystem locations ----------------------------------------------------
# NOTE(review): these are absolute, machine-specific paths; adjust before
# running this notebook anywhere else.
LLM_DIR = "/SSL_NAS/bonbak/model/models--yanolja--EEVE-Korean-Instruct-2.8B-v1.0/snapshots/482db2d0ba911253d09342c34d0e42ac871bfea3"
SAVE_DIR = f'/home/bonbak/continuous-prompting/output/{TASK}'
TASKS_DIR = f'/home/bonbak/continuous-prompting/task/{TASK}'

In [None]:
# Build the evaluation dataset from the split file selected by MODE plus the
# task's edge list, both located under TASKS_DIR.
split_jsonl_path = f"{TASKS_DIR}/{MODE}.jsonl"
edge_csv_path = f"{TASKS_DIR}/edge.csv"
test_dataset = RecsysDataset(split_jsonl_path, edge_csv_path)

# One example per batch, in dataset order (no shuffling).
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=1,
)

# Sizes of the dataset's user / item mappings; used to size the encoder.
num_users = len(test_dataset.user_mapping)
num_items = len(test_dataset.item_mapping)

In [None]:
# Recommender encoder that produces the continuous prompt, built from the
# user/item counts and the training edge index stored under TASKS_DIR.
continuous_prompt_model = RecsysContinuousPromptModel(num_users,num_items,f'{TASKS_DIR}/train_edge_index.pt')
embedding_dim = continuous_prompt_model.model.embedding_dim

# NOTE(review): this projection is constructed but never passed to
# ContinuousPromptingLLM below; the weights that are actually restored go into
# `model.projection_module`. Kept so the notebook namespace is unchanged.
projection_module = BasicProjection(embedding_dim)

model = ContinuousPromptingLLM(
    LLM_DIR,
    continuous_prompt_model,
    embedding_dim,
)

# Restore the trained encoder / projection weights. `map_location='cpu'` makes
# torch.load deserialize onto the CPU instead of the device the checkpoint was
# saved from, so loading works even if that GPU is absent or busy here.
# `load_state_dict` copies values into the existing parameters, and both
# modules are moved to DEVICE right after.
checkpoint_dir = f'{SAVE_DIR}/model'
model.continuous_prompt_model.load_state_dict(
    torch.load(f'{checkpoint_dir}/{MODEL_NAME}-encoder.bin', map_location='cpu')
)
model.projection_module.load_state_dict(
    torch.load(f'{checkpoint_dir}/{MODEL_NAME}-projection.bin', map_location='cpu')
)

continuous_prompt_model.to(DEVICE)
model.to(DEVICE)

In [None]:
# Upper bound on the number of test examples to score. The previous loop tested
# `idx == 500` only *after* processing the current example, so it evaluated 501
# examples; the cap is now applied to the number of collected predictions.
MAX_EVAL_SAMPLES = 500

model.eval()
pred = []   # decoded model outputs, one string per evaluated example
label = []  # first entry of each batch's answer list, aligned with `pred`

# Tell tqdm how many iterations will really run so the bar/ETA are accurate.
num_eval_batches = min(len(test_dataloader), MAX_EVAL_SAMPLES)

with torch.no_grad():  # inference only
    for input_text, continuous_prompt_input, answer_list in tqdm(test_dataloader, total=num_eval_batches):
        inputs_embeds, attention_mask = model.make_input_embed(input_text, continuous_prompt_input, embedding_first=True)
        # Generate at most one new token; EOS is used as the pad token id.
        output = model.llm_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, pad_token_id=model.llm_tokenizer.eos_token_id, max_new_tokens=1)
        pred.append(model.llm_tokenizer.batch_decode(output, skip_special_tokens=True)[0])
        label.append(answer_list[0])
        if len(pred) >= MAX_EVAL_SAMPLES:
            break

In [None]:
import numpy as np
def convert_answer(answer):
    """Encode an iterable of answer strings as an integer NumPy array.

    Each string is stripped of surrounding whitespace and then mapped:
    '예' -> 1, '아니' -> 0, anything else -> -1.

    Note: this definition shadows the ``convert_answer`` imported from
    ``util`` at the top of the notebook.

    Args:
        answer: iterable of strings (model outputs or gold labels).

    Returns:
        np.ndarray with one code per input string, in input order.
    """
    code_by_text = {'예': 1, '아니': 0}
    codes = [code_by_text.get(text.strip(), -1) for text in answer]
    return np.array(codes)

In [None]:
# Encode the decoded predictions and the gold labels with the same mapping
# (predictions first, then labels) so they can be compared element-wise.
y_pred, y_true = (convert_answer(texts) for texts in (pred, label))

In [None]:
from sklearn.metrics import accuracy_score, f1_score

# Fraction of examples whose encoded prediction equals the encoded label.
accuracy = accuracy_score(y_true, y_pred)

# F1 of the positive class (code 1). `convert_answer` emits -1 for any output
# it cannot parse, which makes the targets three-valued; f1_score's default
# `average='binary'` raises ValueError in that case. Restricting the score to
# `labels=[1]` yields the same value as the binary F1 when no -1 is present,
# and otherwise counts an unparseable prediction as "not positive".
f1 = f1_score(y_true, y_pred, labels=[1], average='macro')

print(accuracy)
print(f1)