In [1]:
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, default_data_collator, Trainer,  Seq2SeqTrainer, GenerationConfig
from args import TrainingArguments, DataTrainingArguments, ArgumentParser

from peft import get_peft_model

from arithmetics import PromptArithmeticsConfig

from tasks import Preprocessor, AutoTask

from utils import get_task_prompt_from_safetensor

from torch.utils.data import DataLoader

from metrics import exact_match, f1

from tqdm import tqdm

In [2]:
saves = ["saves/prompt_tuning_08282024142422_qnli_text_origin_0_meta-llama-3-8b_best", "saves/prompt_tuning_08282024142422_qnli_text_origin_1_meta-llama-3-8b/checkpoint-257500", "saves/prompt_tuning_08282024142517_sst2_text_origin_0_meta-llama-3-8b_best"]
origin_prompt = "origin_0_meta-llama-3-8b"

In [3]:
parser = ArgumentParser(
    (TrainingArguments, DataTrainingArguments, PromptArithmeticsConfig)
)

training_args, data_args, pt_args = parser.parse_toml_file("./configs/prompt_tuning/single-task/llama3_8b.toml")
# training_args.do_train = False
# training_args.do_eval = False



In [4]:
model = AutoModelForCausalLM.from_pretrained(training_args.model_name_or_path, torch_dtype=torch.bfloat16).to("cuda")
model = get_peft_model(model, peft_config=pt_args)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
tokenizer = AutoTokenizer.from_pretrained(data_args.data_tokenizer_name_or_path, trust_remote_code=True, padding_side="left")
tokenizer.add_special_tokens({"pad_token": "<|reserved_special_token_0|>"})
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [6]:
preprocessor = Preprocessor(
            data_args.dataset_names, data_args, training_args, pt_args, tokenizer
        )

train_dataset, valid_datasets, test_datasets = preprocessor.get_data()

Max target lengths: [5]


Running qnli_text_preprocessor on dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Running preprocess_function on train_dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Running qnli_text_preprocessor on dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Running preprocess_function on valid_dataset:   0%|          | 0/500 [00:00<?, ? examples/s]

Running qnli_text_preprocessor on dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

Running preprocess_function on test_dataset:   0%|          | 0/100 [00:00<?, ? examples/s]

In [7]:
len(test_datasets["qnli_text"]["attention_mask"][0]), len(test_datasets["qnli_text"]["input_ids"][0]), len(test_datasets["qnli_text"]["labels"][0])

(256, 256, 256)

In [9]:
len(train_dataset["attention_mask"][0]), len(train_dataset["input_ids"][0]), len(train_dataset["labels"][0])

(261, 261, 261)

In [14]:
train_dataset["labels"][0][-6:], train_dataset["input_ids"][0][-6:], tokenizer.decode(train_dataset["labels"][0][-5:]), tokenizer.decode(train_dataset["input_ids"][0][-5:])

([-100, 1962, 28525, 607, 479, 128001],
 [220, 1962, 28525, 607, 479, 128001],
 'not_entailment<|end_of_text|>',
 'not_entailment<|end_of_text|>')

In [25]:
train_dataset["labels"][1][-6:], train_dataset["input_ids"][1][-6:], tokenizer.decode(train_dataset["labels"][1][-4:]), tokenizer.decode(train_dataset["input_ids"][1][-4:])

([-100, -100, 306, 607, 479, 128001],
 [25, 220, 306, 607, 479, 128001],
 'entailment<|end_of_text|>',
 'entailment<|end_of_text|>')

In [8]:
test_datasets["qnli_text"]["labels"][0][-6:], test_datasets["qnli_text"]["input_ids"][0][-6:], tokenizer.decode(test_datasets["qnli_text"]["labels"][0][-3:]), tokenizer.decode(test_datasets["qnli_text"]["input_ids"][0][-3:])

([128002, 128002, 128000, 306, 607, 479],
 [27587, 18874, 13, 2440, 25, 220],
 'entailment',
 ' label: ')

In [9]:
test_dls = {td : DataLoader(test_datasets[td], training_args.per_device_eval_batch_size, shuffle=False, collate_fn=default_data_collator) for td in test_datasets}

In [10]:
model.prompt_encoder.default.embedding.weight = get_task_prompt_from_safetensor(saves[1])

model.eval()

PeftModelForCausalLM(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layerno

In [13]:
for td in test_dls:

    em = 0
    for batch in tqdm(test_dls[td]):
        preds = model.generate(input_ids=batch["input_ids"].to("cuda"), attention_mask=batch["attention_mask"].to("cuda"))
        decoded_preds = [dpred.split("label: ")[1] for dpred in tokenizer.batch_decode(preds, skip_special_tokens=True)]
        decoded_labels = tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)

        print(decoded_preds, decoded_labels)

        em += exact_match(decoded_preds, decoded_labels)["exact_match"]
    
    em /= len(test_dls[td])
    print(em)

  0%|                                                                                                                                                                                                                                                                                                                                                                                           | 0/2732 [00:00<?, ?it/s]

  0%|‚ñè                                                                                                                                                                                                                                                                                                                                                                                  | 1/2732 [00:00<42:21,  1.07it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|‚ñé                                                                                                                                                                                                                                                                                                                                                                                  | 2/2732 [00:01<25:02,  1.82it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  0%|‚ñç                                                                                                                                                                                                                                                                                                                                                                                  | 3/2732 [00:01<19:29,  2.33it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|‚ñå                                                                                                                                                                                                                                                                                                                                                                                  | 4/2732 [00:01<16:52,  2.69it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|‚ñã                                                                                                                                                                                                                                                                                                                                                                                  | 5/2732 [00:02<15:26,  2.94it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|‚ñä                                                                                                                                                                                                                                                                                                                                                                                  | 6/2732 [00:02<14:34,  3.12it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|‚ñâ                                                                                                                                                                                                                                                                                                                                                                                  | 7/2732 [00:02<14:02,  3.24it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|‚ñà                                                                                                                                                                                                                                                                                                                                                                                  | 8/2732 [00:02<13:39,  3.32it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|‚ñà‚ñè                                                                                                                                                                                                                                                                                                                                                                                 | 9/2732 [00:03<13:24,  3.38it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  0%|‚ñà‚ñé                                                                                                                                                                                                                                                                                                                                                                                | 10/2732 [00:03<13:15,  3.42it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  0%|‚ñà‚ñç                                                                                                                                                                                                                                                                                                                                                                                | 11/2732 [00:03<13:09,  3.45it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|‚ñà‚ñã                                                                                                                                                                                                                                                                                                                                                                                | 12/2732 [00:04<13:05,  3.46it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  0%|‚ñà‚ñä                                                                                                                                                                                                                                                                                                                                                                                | 13/2732 [00:04<13:01,  3.48it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|‚ñà‚ñâ                                                                                                                                                                                                                                                                                                                                                                                | 14/2732 [00:04<12:59,  3.49it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|‚ñà‚ñà                                                                                                                                                                                                                                                                                                                                                                                | 15/2732 [00:04<12:57,  3.49it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñè                                                                                                                                                                                                                                                                                                                                                                               | 16/2732 [00:05<12:56,  3.50it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|‚ñà‚ñà‚ñé                                                                                                                                                                                                                                                                                                                                                                               | 17/2732 [00:05<12:55,  3.50it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñç                                                                                                                                                                                                                                                                                                                                                                               | 18/2732 [00:05<12:54,  3.50it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñå                                                                                                                                                                                                                                                                                                                                                                               | 19/2732 [00:06<12:53,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñã                                                                                                                                                                                                                                                                                                                                                                               | 20/2732 [00:06<12:54,  3.50it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñä                                                                                                                                                                                                                                                                                                                                                                               | 21/2732 [00:06<12:54,  3.50it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñâ                                                                                                                                                                                                                                                                                                                                                                               | 22/2732 [00:06<12:53,  3.50it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñà                                                                                                                                                                                                                                                                                                                                                                               | 23/2732 [00:07<12:52,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|‚ñà‚ñà‚ñà‚ñé                                                                                                                                                                                                                                                                                                                                                                              | 24/2732 [00:07<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|‚ñà‚ñà‚ñà‚ñç                                                                                                                                                                                                                                                                                                                                                                              | 25/2732 [00:07<12:52,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñà‚ñå                                                                                                                                                                                                                                                                                                                                                                              | 26/2732 [00:08<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñà‚ñã                                                                                                                                                                                                                                                                                                                                                                              | 27/2732 [00:08<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñà‚ñä                                                                                                                                                                                                                                                                                                                                                                              | 28/2732 [00:08<12:51,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'not_entailment']


  1%|‚ñà‚ñà‚ñà‚ñâ                                                                                                                                                                                                                                                                                                                                                                              | 29/2732 [00:08<12:50,  3.51it/s]

['not_entailment', 'not_entailment'] ['entailment', 'entailment']


  1%|‚ñà‚ñà‚ñà‚ñà                                                                                                                                                                                                                                                                                                                                                                              | 30/2732 [00:09<12:49,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  1%|‚ñà‚ñà‚ñà‚ñà‚ñè                                                                                                                                                                                                                                                                                                                                                                             | 31/2732 [00:09<12:50,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  1%|‚ñà‚ñà‚ñà‚ñà‚ñé                                                                                                                                                                                                                                                                                                                                                                             | 32/2732 [00:09<12:50,  3.51it/s]

['not_entailment', 'not_entailment'] ['not_entailment', 'entailment']


  1%|‚ñà‚ñà‚ñà‚ñà‚ñé                                                                                                                                                                                                                                                                                                                                                                             | 32/2732 [00:10<14:05,  3.19it/s]


KeyboardInterrupt: 

In [37]:
for td in test_dls:
    for batch in tqdm(test_dls[td]):
        outputs = model(input_ids=batch["input_ids"].to("cuda"), attention_mask=batch["attention_mask"].to("cuda"))
        print(outputs)
        print(tokenizer.batch_decode(outputs.logits[:,:-1].argmax(dim=-1), skip_special_tokens=True))
        break
    break

  0%|                                                                                                                                                                                                                                                                                                                                                                                             | 0/50 [00:01<?, ?it/s]

CausalLMOutputWithPast(loss={'logits': tensor([[[ 0.8086, -0.5078,  0.7227,  ...,  0.6523,  0.6523,  0.6523],
         [-1.9219, -1.0391, -2.7656,  ..., -0.8828, -0.8828, -0.8828],
         [-1.8359, -1.0000,  2.1719,  ..., -0.3906, -0.3906, -0.3906],
         ...,
         [11.1250,  9.1250,  9.5625,  ..., -1.3438, -1.3438, -1.3438],
         [ 4.5625,  2.6562, -0.0669,  ..., -0.8047, -0.8047, -0.8047],
         [ 1.6641,  2.7188,  2.1094,  ...,  0.3008,  0.3008,  0.3008]],

        [[ 0.8086, -0.5078,  0.7227,  ...,  0.6523,  0.6523,  0.6523],
         [-1.9219, -1.0391, -2.7656,  ..., -0.8828, -0.8828, -0.8828],
         [-1.8359, -1.0000,  2.1719,  ..., -0.3906, -0.3906, -0.3906],
         ...,
         [10.4375,  9.0625, 10.0000,  ..., -1.1406, -1.1406, -1.1406],
         [ 4.9375,  2.6719,  0.4277,  ..., -0.6562, -0.6562, -0.6562],
         [ 1.4609,  2.8125,  2.5469,  ...,  0.2598,  0.2598,  0.2598]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>), 'past_key_values': ((tens




In [20]:
for td in test_dls:
    for batch in tqdm(test_dls[td]):
        outputs = model.generate(input_ids=batch["input_ids"].to("cuda"), attention_mask=batch["attention_mask"].to("cuda"), do_sample=False)
        print(outputs.shape)
        print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
        break
    break

  0%|                                                                                                                                                                                                                                                                                                                                                                                             | 0/25 [00:00<?, ?it/s]

torch.Size([4, 261])
['qnli question: When did Tesla make the induction motor? sentence: One of the things Tesla developed at that laboratory in 1887 was an induction motor that ran on alternating current, a power system format that was starting to be built in Europe and the United States because of its advantages in long-distance, high-voltage transmission. label: not_entailment', "qnli question: What process down the line does rubisco's flaw interfere with? sentence: This is a big problem, since O2 is produced by the initial light reactions of photosynthesis, causing issues down the line in the Calvin cycle which uses rubisco. label: not_entailment", 'qnli question: Which country in 1985 signed a treaty to give it special status? sentence: Greenland signed a Treaty in 1985 giving it a special status. label: not_entailment', 'qnli question: How many seasons did NYPD Blue last? sentence: Daniel Burke departed from Capital Cities/ABC in February 1994, with Thomas Murphy taking over as p




In [16]:
def preprocess_logits_for_metrics(logits, labels):
    # print(logits.shape)
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
        # like past_key_values, but logits always come first
        logits = logits[0]

    return torch.argmax(logits, dim=-1)


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    
    decoded_preds = [dpred.split("label: ")[1] for dpred in  tokenizer.batch_decode(preds, skip_special_tokens=True)]
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    print("preds:", decoded_preds)
    print("labels:", decoded_labels)

    return exact_match(decoded_preds, decoded_labels)

trainer = Seq2SeqTrainer(
                    model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    data_collator=default_data_collator,
                    compute_metrics=compute_metrics,
                    # preprocess_logits_for_metrics=preprocess_logits_for_metrics,
                )

trainer.evaluate(eval_dataset=test_datasets["qnli_text"])



{'synced_gpus': False}




transformers: torch.Size([4, 261])


{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus': False}
transformers: torch.Size([4, 261])
{'synced_gpus'

{'eval_loss': 19.51024627685547,
 'eval_exact_match': 0.51,
 'eval_runtime': 18.1376,
 'eval_samples_per_second': 5.513,
 'eval_steps_per_second': 1.378}

In [19]:
model.base_model.params.max_seq_len

AttributeError: 'LlamaForCausalLM' object has no attribute 'params'

In [16]:
# model.prompt_encoder.default.embedding.weight = torch.nn.Parameter(torch.load(f"saves/{origin_prompt}/{origin_prompt}.bin")["prompt_embeddings"].to("cuda"))
model.prompt_encoder.default.embedding.weight = get_task_prompt_from_safetensor(saves[0])

model.eval()

for td in test_dls:
    for batch in test_dls[td]:
        print(batch)
        outputs = model.generate(input_ids=batch["input_ids"][: ,:-1].to("cuda"), attention_mask=batch["attention_mask"][:, :-1].to("cuda"))
        print(outputs)
        break
    break


# model.prompt_encoder.default.embedding.weight = get_task_prompt_from_safetensor(save)


# print(model.prompt_encoder.default.embedding.weight)
# print(model.base_model.lm_head.weight)

{'input_ids': [tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128002, 128002, 128002]), tensor([128002, 128

TypeError: list indices must be integers or slices, not tuple

In [18]:
tokenizer.batch_decode(outputs, skip_special_tokens=True)

['qnli question: What is the name of one type of prime where p+1 or p-1 takes a certain shape? sentence: This is why the largest known prime has almost always been a Mersenne prime since the dawn of electronic computers. label: not_entailment',
 'qnli question: What omen was Genghis Khan reported to have seen assuring his coming victory against the Tanguts? sentence: According to legend, it was here that Genghis Khan reportedly saw a line of five stars arranged in the sky and interpreted it as an omen of his victory. label: not_entailment',
 'qnli question: What is the name of the property where the media event was held for Super Bowl 50? sentence: The event was held on February 1, 2016 at SAP Center in San Jose. label: not_entailment',
 'qnli question: What year did Robert J. Shiller win an Economics Nobel prize? sentence: 2013 Economics Nobel prize winner Robert J. Shiller said that rising inequality in the United States and elsewhere is the most important problem. label: not_entailm

In [29]:
tokenizer.batch_decode(batch["labels"][:,-7:])

[' electronic computers. label: <|end_of_text|>',
 ' his victory. label: <|end_of_text|>',
 ' San Jose. label: <|end_of_text|>',
 ' important problem. label: <|end_of_text|>']

In [36]:
tokenizer.decode(test_datasets["qnli_text"]["labels"][0][-5:])

'not_entailment<|end_of_text|>'