In [None]:
! pip install -q transformers==4.52.2
! pip install -q -U datasets
! pip install -q peft accelerate bitsandbytes

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.2/40.2 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.5/10.5 MB[0m [31m103.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.9/72.9 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m78.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Mount Google Drive at /content/drive (Colab only). The project-local
# OPT_to_GQA module imported further down is loaded from a folder on this mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from transformers import TrainingArguments, Trainer, TrainerCallback

import sys
sys.path.append('/content/drive/MyDrive/Transformers/Project_OPT')

from OPT_to_GQA import convert_opt_to_gqa

# **Load Model Converted To GQA**


In [None]:
def create_standard_heads_grouping_arr(num_layers: int, num_heads: int, kv_heads: int) -> list:
    """Build the default, contiguous head-to-group assignment for every layer.

    Consecutive heads share a group: with ``num_heads=4`` and ``kv_heads=2``
    each layer gets ``[0, 0, 1, 1]``.

    Args:
        num_layers: Number of per-layer groupings to produce (must be >= 0).
        num_heads: Number of heads per layer (must be > 0).
        kv_heads: Number of groups per layer (must be > 0 and divide
            ``num_heads`` evenly).

    Returns:
        A list of ``num_layers`` independent lists, each of length
        ``num_heads``, whose i-th entry is the group index of head ``i``.

    Raises:
        ValueError: If ``num_heads`` or ``kv_heads`` is not positive, or if
            ``num_layers`` is negative.
        AssertionError: If ``num_heads`` is not divisible by ``kv_heads``.
    """
    # Without these guards kv_heads == 0 surfaces as a bare ZeroDivisionError
    # and negative values silently yield empty groupings.
    if num_heads <= 0 or kv_heads <= 0:
        raise ValueError(
            f"num_heads and kv_heads must be positive, got num_heads={num_heads}, kv_heads={kv_heads}"
        )
    if num_layers < 0:
        raise ValueError(f"num_layers must be non-negative, got {num_layers}")

    assert num_heads % kv_heads == 0, "num_heads must be divisible by kv_heads"
    group_size = num_heads // kv_heads
    grouping = [group for group in range(kv_heads) for _ in range(group_size)]
    # Each layer gets its own copy so callers can edit one layer in isolation.
    return [list(grouping) for _ in range(num_layers)]

heads_grouping_arr_V = [[14, 0, 2, 11, 15, 0, 8, 5, 9, 10, 6, 1, 15, 1, 4, 10, 13, 7, 9, 12, 12, 13, 7, 4, 14, 11, 2, 3, 6, 3, 8, 5], [9, 2, 14, 1, 3, 10, 0, 6, 0, 4, 7, 10, 11, 5, 15, 12, 12, 5, 13, 7, 15, 8, 13, 3, 8, 11, 2, 9, 6, 1, 14, 4], [6, 14, 7, 4, 11, 11, 0, 8, 0, 15, 2, 15, 12, 3, 10, 5, 9, 5, 2, 14, 10, 9, 12, 3, 13, 8, 1, 13, 7, 6, 1, 4], [11, 8, 11, 0, 14, 8, 7, 2, 6, 3, 12, 5, 4, 13, 7, 3, 4, 10, 0, 5, 1, 6, 2, 1, 10, 13, 14, 9, 9, 12, 15, 15], [6, 15, 4, 3, 13, 9, 11, 2, 0, 12, 3, 10, 7, 11, 5, 7, 14, 9, 0, 14, 13, 10, 4, 8, 8, 1, 12, 2, 15, 1, 5, 6], [3, 3, 13, 8, 10, 7, 6, 12, 6, 1, 13, 5, 0, 0, 15, 10, 1, 9, 11, 12, 15, 9, 4, 11, 2, 2, 4, 14, 14, 7, 8, 5], [2, 0, 9, 14, 14, 5, 9, 10, 4, 11, 12, 6, 2, 8, 6, 10, 13, 7, 0, 3, 1, 15, 11, 5, 4, 3, 15, 12, 8, 1, 7, 13], [9, 5, 12, 8, 14, 10, 6, 9, 3, 2, 2, 1, 12, 0, 8, 11, 15, 5, 15, 13, 3, 7, 6, 4, 7, 0, 10, 1, 13, 4, 11, 14], [9, 10, 6, 1, 12, 5, 5, 14, 11, 3, 13, 15, 4, 8, 4, 3, 13, 9, 2, 7, 0, 10, 0, 2, 11, 14, 6, 8, 1, 15, 7, 12], [2, 10, 11, 14, 8, 15, 13, 6, 2, 6, 5, 0, 0, 4, 15, 9, 3, 11, 5, 7, 10, 7, 14, 8, 4, 12, 12, 1, 9, 1, 13, 3], [8, 7, 0, 9, 10, 10, 4, 5, 14, 0, 15, 2, 15, 7, 12, 13, 11, 1, 12, 4, 5, 6, 2, 11, 13, 3, 1, 3, 6, 9, 8, 14], [9, 10, 7, 13, 7, 2, 0, 1, 3, 10, 3, 12, 2, 6, 4, 6, 5, 1, 11, 13, 4, 14, 0, 11, 14, 8, 12, 15, 15, 8, 9, 5], [9, 1, 12, 11, 8, 10, 0, 5, 1, 0, 6, 7, 11, 12, 3, 5, 2, 9, 15, 4, 3, 10, 6, 13, 2, 14, 4, 8, 13, 14, 15, 7], [5, 11, 2, 7, 3, 6, 13, 14, 15, 8, 9, 2, 10, 15, 1, 4, 7, 10, 12, 8, 11, 0, 3, 12, 0, 9, 4, 14, 6, 13, 1, 5], [9, 10, 10, 11, 13, 15, 12, 4, 3, 14, 5, 8, 0, 2, 11, 9, 6, 7, 6, 13, 4, 3, 2, 1, 1, 0, 15, 12, 14, 7, 8, 5], [3, 10, 13, 4, 5, 15, 15, 14, 11, 12, 1, 1, 0, 6, 9, 8, 12, 0, 10, 14, 6, 7, 4, 11, 7, 8, 2, 9, 3, 5, 13, 2], [5, 0, 12, 11, 5, 12, 1, 14, 9, 15, 10, 4, 2, 3, 3, 7, 10, 6, 11, 13, 4, 14, 1, 0, 15, 7, 8, 2, 9, 6, 8, 13], [5, 0, 7, 11, 14, 1, 6, 3, 14, 15, 10, 10, 7, 2, 11, 6, 4, 2, 8, 0, 1, 8, 5, 12, 13, 4, 12, 13, 9, 3, 9, 
15], [14, 0, 12, 12, 1, 3, 6, 9, 11, 9, 15, 10, 5, 13, 4, 1, 15, 7, 13, 2, 8, 0, 3, 4, 11, 7, 14, 2, 5, 6, 8, 10], [2, 7, 10, 2, 5, 13, 10, 7, 8, 12, 1, 14, 5, 3, 3, 11, 6, 9, 4, 9, 8, 1, 12, 13, 14, 15, 4, 0, 15, 0, 11, 6], [1, 15, 8, 7, 5, 13, 3, 2, 9, 14, 15, 6, 0, 10, 10, 4, 5, 11, 9, 11, 6, 8, 14, 12, 2, 1, 4, 12, 7, 3, 13, 0], [10, 14, 3, 14, 11, 15, 1, 0, 1, 12, 9, 13, 7, 12, 10, 8, 6, 0, 7, 11, 3, 9, 4, 2, 5, 2, 6, 4, 5, 13, 8, 15], [5, 0, 7, 0, 1, 3, 14, 14, 12, 11, 9, 9, 6, 7, 15, 6, 11, 10, 2, 13, 3, 4, 15, 4, 13, 5, 12, 8, 1, 8, 2, 10], [8, 14, 0, 12, 14, 2, 4, 3, 10, 7, 1, 9, 0, 7, 12, 5, 13, 4, 5, 15, 11, 15, 2, 6, 9, 1, 3, 10, 11, 13, 8, 6]]

def _check_heads_grouping(grouping_arr, num_layers, num_heads, kv_heads):
    """Fail early if a per-layer grouping table does not fit the loaded model.

    The expected layout mirrors what create_standard_heads_grouping_arr
    produces: one list per decoder layer, one entry per head, every entry a
    group index in [0, kv_heads).
    """
    if len(grouping_arr) != num_layers:
        raise ValueError(
            f"heads_grouping_arr has {len(grouping_arr)} layers but the model has {num_layers}"
        )
    for layer_idx, layer_grouping in enumerate(grouping_arr):
        if len(layer_grouping) != num_heads:
            raise ValueError(
                f"layer {layer_idx}: expected {num_heads} entries, got {len(layer_grouping)}"
            )
        if not all(0 <= group < kv_heads for group in layer_grouping):
            raise ValueError(
                f"layer {layer_idx}: group indices must lie in [0, {kv_heads})"
            )


# Load the pretrained OPT checkpoint with a freshly initialised 3-label
# classification head (the notebook fine-tunes on MNLI).
repo_size = "1.3b"
repo_id = f"facebook/opt-{repo_size}"
model = AutoModelForSequenceClassification.from_pretrained(
    repo_id, num_labels=3, device_map="auto",
)

# Convert to GQA.
kv_heads = 16
# Read the head count from the checkpoint instead of hard-coding it, so that
# changing `repo_size` cannot silently desynchronise it from the model.
num_heads = model.config.num_attention_heads  # 32 for opt-1.3b
group_size = num_heads // kv_heads
num_layers = len(model.model.decoder.layers) # 1.3b - 24, 6.7b - 32

# Alternative: the contiguous default grouping.
# heads_grouping_arr = create_standard_heads_grouping_arr(num_layers, num_heads,
#                                                         kv_heads)
heads_grouping_arr = heads_grouping_arr_V

# The hand-written table is sized for one specific checkpoint; verify it against
# the model that was actually loaded before handing it to the converter.
_check_heads_grouping(heads_grouping_arr, num_layers, num_heads, kv_heads)

# inplace=True: presumably `model` itself is rewritten by the converter -- TODO confirm in OPT_to_GQA.
model_gqa = convert_opt_to_gqa(model, kv_heads=kv_heads, heads_grouping_arr=heads_grouping_arr, inplace=True)
print(f"Converted to GQA with {heads_grouping_arr}")
model_gqa

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-1.3b and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Converted to GQA with [[14, 0, 2, 11, 15, 0, 8, 5, 9, 10, 6, 1, 15, 1, 4, 10, 13, 7, 9, 12, 12, 13, 7, 4, 14, 11, 2, 3, 6, 3, 8, 5], [9, 2, 14, 1, 3, 10, 0, 6, 0, 4, 7, 10, 11, 5, 15, 12, 12, 5, 13, 7, 15, 8, 13, 3, 8, 11, 2, 9, 6, 1, 14, 4], [6, 14, 7, 4, 11, 11, 0, 8, 0, 15, 2, 15, 12, 3, 10, 5, 9, 5, 2, 14, 10, 9, 12, 3, 13, 8, 1, 13, 7, 6, 1, 4], [11, 8, 11, 0, 14, 8, 7, 2, 6, 3, 12, 5, 4, 13, 7, 3, 4, 10, 0, 5, 1, 6, 2, 1, 10, 13, 14, 9, 9, 12, 15, 15], [6, 15, 4, 3, 13, 9, 11, 2, 0, 12, 3, 10, 7, 11, 5, 7, 14, 9, 0, 14, 13, 10, 4, 8, 8, 1, 12, 2, 15, 1, 5, 6], [3, 3, 13, 8, 10, 7, 6, 12, 6, 1, 13, 5, 0, 0, 15, 10, 1, 9, 11, 12, 15, 9, 4, 11, 2, 2, 4, 14, 14, 7, 8, 5], [2, 0, 9, 14, 14, 5, 9, 10, 4, 11, 12, 6, 2, 8, 6, 10, 13, 7, 0, 3, 1, 15, 11, 5, 4, 3, 15, 12, 8, 1, 7, 13], [9, 5, 12, 8, 14, 10, 6, 9, 3, 2, 2, 1, 12, 0, 8, 11, 15, 5, 15, 13, 3, 7, 6, 4, 7, 0, 10, 1, 13, 4, 11, 14], [9, 10, 6, 1, 12, 5, 5, 14, 11, 3, 13, 15, 4, 8, 4, 3, 13, 9, 2, 7, 0, 10, 0, 2, 11, 14, 6, 8, 1,

OPTForSequenceClassification(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTGQAAttention(
            (k_proj): Linear(in_features=2048, out_features=1024, bias=True)
            (v_proj): Linear(in_features=2048, out_features=1024, bias=True)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear(in_features=8192, out_features=2048, bias=True)
          (final_la

In [None]:
# Print the converted model's memory footprint scaled by 1e6 (i.e. in MB if
# get_memory_footprint() reports bytes), then display the module tree.
print(model_gqa.get_memory_footprint()/1e6)
model_gqa

4860.207104


OPTForSequenceClassification(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTGQAAttention(
            (k_proj): Linear(in_features=2048, out_features=1024, bias=True)
            (v_proj): Linear(in_features=2048, out_features=1024, bias=True)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear(in_features=8192, out_features=2048, bias=True)
          (final_la

In [None]:
# From here on `model` refers to the GQA-converted network.
model = model_gqa

# LoRA adapter settings for sequence classification: rank-8 adapters
# (alpha 16, dropout 0.05, no bias training) on the four attention projections.
config = LoraConfig(
    r=8,
    lora_alpha=16,
    bias="none",
    lora_dropout=0.05,
    task_type="SEQ_CLS",
    # manually setting target modules
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

# Wrap the model with the adapters, then report its memory footprint and how
# many parameters are trainable versus the total.
# NOTE(review): re-running this cell calls get_peft_model on a model that has
# already been wrapped once -- restart the kernel before re-running.
model = get_peft_model(model, config)
print(model.get_memory_footprint()/1e6)
train_p, tot_p = model.get_nb_trainable_parameters()
print(f'Trainable parameters:      {train_p/1e6:.2f}M')
print(f'Total parameters:          {tot_p/1e6:.2f}M')
print(f'% of trainable parameters: {100*train_p/tot_p:.2f}%')
model

4871.241728
Trainable parameters:      2.76M
Total parameters:          1217.81M
% of trainable parameters: 0.23%


PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): OPTForSequenceClassification(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 2048, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0-23): 24 x OPTDecoderLayer(
              (self_attn): OPTGQAAttention(
                (k_proj): lora.Linear(
                  (base_layer): Linear(in_features=2048, out_features=1024, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=2048, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=8, out_featu

In [None]:
tokenizer_repo_id = f"facebook/opt-{repo_size}"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo_id)

# Do not overwrite an existing pad token. The model printed above was built
# with padding_idx=1, and the sequence-classification head locates the last
# real token of each row through `config.pad_token_id`. Padding the inputs with
# the EOS id instead makes that lookup miss, so logits end up being pooled from
# a padding position. Only fall back to EOS when the tokenizer has no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Keep the model's notion of "padding" identical to what the tokenizer emits.
model.config.pad_token_id = tokenizer.pad_token_id

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

# **Loading and pre-processing the dataset**

In [None]:
from datasets import load_dataset

# Load the "mnli" configuration of the GLUE dataset repository and display the
# resulting splits (train, validation/test matched and mismatched).
dataset = load_dataset("nyu-mll/glue", "mnli")
dataset

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/52.2M [00:00<?, ?B/s]

(…)alidation_matched-00000-of-00001.parquet:   0%|          | 0.00/1.21M [00:00<?, ?B/s]

(…)dation_mismatched-00000-of-00001.parquet:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

test_matched-00000-of-00001.parquet:   0%|          | 0.00/1.22M [00:00<?, ?B/s]

test_mismatched-00000-of-00001.parquet:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

Generating test_matched split:   0%|          | 0/9796 [00:00<?, ? examples/s]

Generating test_mismatched split:   0%|          | 0/9847 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})

In [None]:
# Shuffle the training split with a fixed seed and keep the first 50,000
# examples; the other splits are left as loaded.
# NOTE(review): this overwrites dataset["train"] in place, so re-running the
# cell reshuffles the already reduced split -- reload the dataset first.
dataset["train"] = dataset["train"].shuffle(seed=42).select(range(50000))

In [None]:
def preprocess(example):
    """Tokenize premise/hypothesis pairs into fixed-length model inputs.

    The two text columns are passed together to the module-level ``tokenizer``
    as a sequence pair; every encoding is truncated or padded to exactly
    128 tokens.
    """
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 128,
    }
    return tokenizer(example["premise"], example["hypothesis"], **encode_options)

# Tokenize every split in batches, rename "label" -> "labels" (the name listed
# in TrainingArguments.label_names), and expose only the model inputs and the
# labels as torch tensors.
encoded_dataset = dataset.map(preprocess, batched=True)
encoded_dataset = encoded_dataset.rename_column("label", "labels")
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/9815 [00:00<?, ? examples/s]

Map:   0%|          | 0/9832 [00:00<?, ? examples/s]

Map:   0%|          | 0/9796 [00:00<?, ? examples/s]

Map:   0%|          | 0/9847 [00:00<?, ? examples/s]

# **Training**

In [None]:
# Turn on gradient checkpointing (non-reentrant variant) before the Trainer is built.
model.gradient_checkpointing_enable({"use_reentrant": False})

# Hyper-parameters for the LoRA fine-tuning run.
training_args = TrainingArguments(
    output_dir="./opt-lora-mnli",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=32,
    # Evaluate and checkpoint once per epoch, keeping a single checkpoint on disk.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    num_train_epochs=2,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.01,
    # No metric_for_best_model is set, so "best" falls back to the Trainer's default criterion.
    load_best_model_at_end=True,
    logging_dir="./logs",
    logging_steps=100,
    report_to="none",  # no external experiment tracker
    fp16=True,
    optim="adamw_torch",
    # Matches the column produced by renaming "label" to "labels" during preprocessing.
    label_names=["labels"],
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from transformers import TrainerCallback
from sklearn.metrics import accuracy_score

class PlotLossAccuracyCallback(TrainerCallback):
    """Trainer callback that records logged metrics and can plot them afterwards.

    Every log event is echoed to stdout; ``loss``, ``eval_loss`` and
    ``eval_accuracy`` entries are stored as ``(global_step, value)`` pairs.
    Call :meth:`plot` once training has finished to draw and save the curves.
    """

    def __init__(self):
        # Each list holds (global_step, value) tuples in logging order.
        self.train_loss = []
        self.eval_loss = []
        self.eval_acc = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the log payload and store the metrics this callback tracks."""
        if logs is None:
            return

        print(f"[LOG] Step {state.global_step}: {logs}")

        if "loss" in logs:
            self.train_loss.append((state.global_step, logs["loss"]))
        if "eval_loss" in logs:
            self.eval_loss.append((state.global_step, logs["eval_loss"]))
        if "eval_accuracy" in logs:
            # Only present when the Trainer was given a compute_metrics that returns "accuracy".
            self.eval_acc.append((state.global_step, logs["eval_accuracy"]))

    def plot(self):
        """Draw loss (left) and accuracy (right) curves, save them as a PNG and show them.

        The figure is written to ``loss_accuracy_plot.png`` in the current
        working directory. Panels without recorded data are skipped.
        """
        steps_train, loss_train = zip(*self.train_loss) if self.train_loss else ([], [])
        steps_eval, loss_eval = zip(*self.eval_loss) if self.eval_loss else ([], [])
        steps_acc, acc_eval = zip(*self.eval_acc) if self.eval_acc else ([], [])

        plt.figure(figsize=(12, 5))

        if loss_train or loss_eval:
            plt.subplot(1, 2, 1)
            if loss_train:
                plt.plot(steps_train, loss_train, label="Train Loss")
            if loss_eval:
                plt.plot(steps_eval, loss_eval, label="Eval Loss")
            plt.xlabel("Step")
            plt.ylabel("Loss")
            plt.title("Training and Eval Loss")
            plt.legend()

        if acc_eval:
            plt.subplot(1, 2, 2)
            plt.plot(steps_acc, acc_eval, label="Eval Accuracy", color="green")
            plt.xlabel("Step")
            plt.ylabel("Accuracy")
            plt.title("Validation Accuracy")
            plt.legend()

        # These three calls must live inside plot(): at class-body level they
        # execute once when the class is defined (producing an empty figure)
        # and never when plot() is called.
        plt.tight_layout()
        plt.savefig("loss_accuracy_plot.png")
        plt.show()

<Figure size 640x480 with 0 Axes>

In [None]:
plot_callback = PlotLossAccuracyCallback()


def _eval_accuracy(eval_pred):
    """Return validation accuracy so the Trainer logs an ``eval_accuracy`` entry.

    Without a compute_metrics function the Trainer only reports ``eval_loss``,
    which leaves the callback's accuracy series empty.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(predicted_classes == np.asarray(labels)))}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation_matched"],
    # `tokenizer=` is deprecated on Trainer; `processing_class=` is its replacement.
    processing_class=tokenizer,
    compute_metrics=_eval_accuracy,
    callbacks=[plot_callback]
)

trainer.train()

  trainer = Trainer(
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss
1,0.8058,0.712277
2,0.5892,0.551193


[LOG] Step 100: {'loss': 1.1594, 'grad_norm': 5.939835548400879, 'learning_rate': 0.000192, 'epoch': 0.1278772378516624}
[LOG] Step 200: {'loss': 1.1189, 'grad_norm': 4.499975681304932, 'learning_rate': 0.00018688524590163935, 'epoch': 0.2557544757033248}
[LOG] Step 300: {'loss': 1.1027, 'grad_norm': 7.983818054199219, 'learning_rate': 0.000173224043715847, 'epoch': 0.3836317135549872}
[LOG] Step 400: {'loss': 1.0418, 'grad_norm': 1.9318846464157104, 'learning_rate': 0.00015956284153005465, 'epoch': 0.5115089514066496}
[LOG] Step 500: {'loss': 0.9822, 'grad_norm': 7.987417221069336, 'learning_rate': 0.0001459016393442623, 'epoch': 0.639386189258312}
[LOG] Step 600: {'loss': 0.8882, 'grad_norm': 5.158024787902832, 'learning_rate': 0.00013224043715846995, 'epoch': 0.7672634271099744}
[LOG] Step 700: {'loss': 0.8058, 'grad_norm': 3.0277886390686035, 'learning_rate': 0.0001185792349726776, 'epoch': 0.8951406649616368}
[LOG] Step 782: {'eval_loss': 0.7122766971588135, 'eval_runtime': 157.58

TrainOutput(global_step=1564, training_loss=0.8035080463380155, metrics={'train_runtime': 5572.2599, 'train_samples_per_second': 17.946, 'train_steps_per_second': 0.281, 'total_flos': 8.52983021568e+16, 'train_loss': 0.8035080463380155, 'epoch': 2.0})

# **Test Model**

In [None]:
! pip -q install evaluate
import evaluate
accuracy_metric = evaluate.load("accuracy")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h

Downloading builder script: 0.00B [00:00, ?B/s]

In [None]:
def compute_metrics(eval_pred):
    """Score evaluation predictions with the module-level accuracy metric.

    ``eval_pred`` unpacks into ``(logits, labels)``. The predicted class of
    each example is the index of its largest logit along the last axis; the
    result is whatever ``accuracy_metric.compute`` returns for those
    predictions against ``labels``.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(references=labels, predictions=predicted_classes)

In [None]:
# Rebuild the Trainer around the fine-tuned `model`, this time with
# compute_metrics so that evaluation reports accuracy as well as loss.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation_matched"],
    # `tokenizer=` is deprecated on Trainer; `processing_class=` is its replacement.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [None]:
# Evaluate on the Trainer's eval_dataset (the validation_matched split) and
# display the resulting metrics dictionary.
trainer.evaluate()

{'eval_loss': 0.5511928796768188,
 'eval_model_preparation_time': 0.0113,
 'eval_accuracy': 0.775649516046867,
 'eval_runtime': 162.4419,
 'eval_samples_per_second': 60.422,
 'eval_steps_per_second': 1.89}