In [1]:
import torch
import sys
sys.path.append('../')
from model.continuous_prompt import ContinuousPromptingLLM
from model.graph_encoder import GraphContinuousPromptModel
from model.projection import BasicProjection
from dataset import GraphDataset

from tqdm import tqdm
from util import convert_answer

In [2]:
MODE='test'
TASK='cycle_check'
MODEL_NAME = 'gin'
SAVE_DIR=f'/home/bonbak/continuous-prompting/output/{TASK}'
TASKS_DIR = f'/home/bonbak/continuous-prompting/task/{TASK}'
DEVICE='cuda:3'

In [3]:
test_dataset = GraphDataset(f"{TASKS_DIR}/{MODE}.jsonl")
print('\n'.join((test_dataset.data[0]['node_information'], test_dataset.data[0]['edge_information'], test_dataset.data[0]['question'])))

G describes a graph among nodes 0, 1, 2, 3, 4, 5, 6, and 7.
The edges in G are: (0, 1) (0, 5) (0, 6) (2, 6) (4, 7).
You must answer with "Yes" or "No" under the question.
Q: Is node 0 connected to node 2?
A: 


In [4]:
test_dataset

<data.dataset.GraphDataset at 0x76f0306364d0>

In [5]:
continuous_prompt_model = GraphContinuousPromptModel(input_dim=5, hidden_dim=512)
projection_module = BasicProjection(continuous_prompt_model.model.hidden_dim)

model = ContinuousPromptingLLM(
    "google/gemma-2b-it",
    continuous_prompt_model, 
    continuous_prompt_model.model.hidden_dim
)

model.continuous_prompt_model.load_state_dict(torch.load(f'{SAVE_DIR}/model/{MODEL_NAME}-encoder.bin'))
model.projection_module.load_state_dict(torch.load(f'{SAVE_DIR}/model/{MODEL_NAME}-projection.bin'))

continuous_prompt_model.to(DEVICE)
model.to(DEVICE)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

ContinuousPromptingLLM(
  (llm_model): GemmaForCausalLM(
    (model): GemmaModel(
      (embed_tokens): Embedding(256000, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
          (mlp): GemmaMLP(
            (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
            (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): GemmaRMSNorm()
        

In [6]:
model.eval()
pred = []
label = []

for input_text, continuous_prompt_input, answer_list in tqdm(test_dataset):
    with torch.no_grad():
        inputs_embeds, attention_mask = model.make_input_embed([input_text], continuous_prompt_input, embedding_first=True)
        output = model.llm_model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=4)
        pred.append(model.llm_tokenizer.batch_decode(output, skip_special_tokens=True)[0])
        label.append(answer_list)

100%|██████████| 700/700 [00:35<00:00, 19.91it/s]


In [7]:
y_pred, missed = convert_answer(pred)
y_true, _ = convert_answer(label)

In [8]:
from sklearn.metrics import accuracy_score, f1_score

accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

print(accuracy)
print(f1)

0.7814285714285715
0.7171903881700554
