In [None]:
! pip install -q transformers==4.52.2
! pip install -q -U datasets
! pip install -q peft accelerate bitsandbytes

In [None]:
from google.colab import drive
# Mount Google Drive so the project helper module and files under MyDrive are reachable.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from transformers import TrainingArguments, Trainer, TrainerCallback

import sys
# Make the Drive project folder importable (it provides the OPT_to_GQA module).
# The membership check keeps re-runs of this cell from appending duplicates.
PROJECT_DIR = '/content/drive/MyDrive/Transformers/Project_OPT'
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

from OPT_to_GQA import convert_opt_to_gqa

# **Load Model Converted To GQA**
We use an adapted loader so that the model is loaded together with its `GQAAttention` conversion.
This also lets us load the model in 4-bit quantized form for QLoRA.


In [None]:
def create_standard_heads_grouping_arr(num_layers: int, num_heads: int, kv_heads: int):
    """Return the default (contiguous) query-head -> KV-head assignment for every layer.

    Query heads are split into ``kv_heads`` consecutive blocks of
    ``num_heads // kv_heads`` heads, and block ``g`` is mapped to KV head ``g``.
    For example, ``num_heads=4, kv_heads=2`` gives ``[0, 0, 1, 1]`` for each layer.

    Args:
        num_layers: Number of decoder layers (length of the outer list).
        num_heads: Number of query heads per layer.
        kv_heads: Number of key/value heads per layer.

    Returns:
        ``num_layers`` independent per-layer lists (copies, not aliases).

    Raises:
        AssertionError: if ``num_heads`` is not a multiple of ``kv_heads``.
    """
    assert num_heads % kv_heads == 0, "num_heads must be divisible by kv_heads"
    heads_per_kv = num_heads // kv_heads

    layer_grouping = []
    for kv_idx in range(kv_heads):
        layer_grouping.extend([kv_idx] * heads_per_kv)

    # Slice-copy per layer so editing one layer's grouping never affects another.
    return [layer_grouping[:] for _ in range(num_layers)]

# GQA layout for OPT-1.3b: 24 decoder layers, each with 32 query heads sharing
# 16 key/value heads (i.e. 2 query heads per KV head).
num_layers = 24
num_heads = 32
kv_heads = 16
group_size = num_heads // kv_heads

heads_grouping_arr_standard = create_standard_heads_grouping_arr(
    num_layers, num_heads, kv_heads
)
# heads_grouping_arr_Q = [[14, 9, 8, 1, 8, 2, 13, 12, 7, 3, 11, 15, 0, 14, 7, 6, 10, 10, 12, 2, 3, 6, 15, 5, 4, 0, 13, 1, 11, 4, 9, 5], [4, 10, 13, 2, 0, 12, 0, 6, 12, 9, 2, 5, 11, 11, 8, 14, 5, 14, 9, 7, 3, 3, 7, 15, 1, 1, 10, 6, 8, 4, 13, 15], [14, 0, 11, 11, 5, 3, 13, 12, 15, 7, 13, 2, 10, 0, 9, 8, 6, 1, 2, 7, 8, 10, 12, 15, 4, 14, 5, 6, 3, 4, 1, 9], [14, 11, 7, 8, 6, 15, 11, 10, 9, 1, 13, 4, 13, 12, 10, 14, 6, 0, 3, 12, 8, 15, 4, 7, 9, 3, 2, 5, 5, 1, 2, 0], [9, 3, 2, 1, 13, 10, 14, 12, 11, 5, 14, 12, 1, 15, 7, 11, 3, 2, 7, 9, 0, 5, 4, 15, 6, 0, 6, 8, 8, 4, 13, 10], [15, 0, 14, 12, 5, 8, 7, 12, 9, 11, 11, 15, 10, 6, 7, 4, 1, 2, 5, 8, 4, 13, 1, 14, 13, 2, 9, 6, 3, 3, 10, 0], [1, 11, 6, 8, 3, 4, 5, 14, 10, 3, 9, 2, 10, 4, 6, 5, 2, 11, 0, 13, 1, 15, 7, 8, 14, 9, 7, 13, 12, 0, 15, 12], [11, 10, 2, 0, 15, 6, 1, 15, 14, 2, 13, 7, 8, 4, 8, 9, 9, 14, 5, 5, 3, 0, 3, 12, 11, 13, 10, 6, 12, 7, 1, 4], [5, 14, 2, 7, 9, 11, 0, 3, 13, 5, 4, 11, 2, 6, 8, 12, 8, 0, 9, 3, 10, 6, 10, 14, 12, 1, 1, 13, 15, 7, 4, 15], [3, 8, 0, 5, 13, 1, 10, 10, 11, 14, 2, 15, 6, 15, 4, 11, 12, 12, 0, 7, 4, 13, 5, 2, 6, 8, 3, 9, 9, 7, 14, 1], [3, 10, 13, 2, 12, 1, 5, 3, 6, 8, 0, 1, 0, 14, 15, 10, 13, 8, 15, 2, 6, 11, 4, 12, 7, 7, 9, 5, 14, 11, 4, 9], [10, 6, 11, 5, 1, 8, 0, 7, 1, 13, 14, 12, 2, 7, 13, 2, 15, 5, 12, 3, 14, 10, 6, 15, 3, 0, 11, 8, 9, 9, 4, 4], [13, 9, 14, 1, 9, 6, 5, 15, 2, 12, 11, 4, 0, 1, 5, 8, 13, 11, 15, 7, 3, 10, 2, 10, 7, 0, 4, 6, 14, 3, 8, 12], [4, 3, 9, 12, 5, 14, 6, 10, 14, 0, 2, 11, 3, 8, 8, 7, 7, 9, 1, 4, 1, 13, 2, 15, 15, 0, 12, 11, 5, 10, 6, 13], [10, 0, 2, 14, 8, 5, 12, 5, 11, 15, 14, 1, 3, 7, 6, 11, 3, 9, 10, 15, 4, 4, 1, 12, 8, 7, 2, 0, 9, 13, 6, 13], [2, 11, 14, 13, 9, 14, 1, 7, 6, 0, 3, 6, 10, 5, 4, 2, 7, 11, 12, 12, 3, 0, 4, 13, 5, 10, 9, 1, 8, 8, 15, 15], [14, 2, 0, 8, 9, 8, 0, 15, 15, 12, 2, 7, 11, 4, 10, 9, 6, 13, 12, 7, 5, 4, 1, 3, 14, 10, 6, 1, 13, 11, 3, 5], [9, 2, 5, 14, 10, 15, 12, 1, 12, 13, 0, 15, 0, 6, 11, 5, 9, 4, 1, 7, 7, 4, 6, 3, 14, 2, 8, 8, 11, 3, 
10, 13], [9, 10, 6, 5, 5, 10, 8, 3, 15, 6, 4, 0, 1, 4, 11, 8, 7, 3, 11, 7, 0, 13, 2, 12, 13, 14, 1, 14, 12, 15, 9, 2], [14, 0, 11, 9, 1, 8, 3, 12, 1, 15, 14, 5, 4, 12, 9, 2, 15, 2, 3, 6, 10, 11, 13, 7, 4, 7, 5, 0, 8, 6, 10, 13], [12, 0, 3, 11, 14, 15, 5, 9, 4, 14, 8, 9, 0, 6, 7, 11, 2, 10, 12, 15, 10, 13, 1, 1, 4, 2, 13, 5, 3, 6, 7, 8], [1, 0, 5, 2, 13, 5, 14, 0, 11, 9, 10, 4, 7, 3, 15, 6, 14, 2, 1, 10, 6, 4, 12, 8, 15, 12, 13, 7, 3, 9, 8, 11], [2, 9, 7, 5, 9, 15, 14, 12, 4, 2, 6, 4, 3, 11, 7, 3, 10, 1, 15, 13, 1, 0, 8, 8, 6, 13, 11, 0, 10, 12, 14, 5], [7, 3, 5, 12, 10, 4, 15, 4, 13, 9, 2, 13, 9, 5, 2, 3, 8, 15, 7, 10, 8, 1, 0, 6, 11, 12, 14, 0, 11, 1, 14, 6]]
# heads_grouping_arr_K = [[6, 14, 11, 8, 3, 14, 10, 2, 0, 4, 1, 9, 13, 13, 15, 7, 8, 0, 7, 5, 9, 4, 5, 12, 6, 2, 10, 1, 3, 15, 11, 12], [3, 10, 6, 10, 13, 7, 13, 1, 9, 2, 11, 11, 0, 12, 5, 15, 9, 15, 4, 14, 12, 7, 8, 3, 1, 2, 4, 5, 8, 14, 6, 0], [1, 11, 3, 6, 4, 4, 5, 13, 7, 2, 2, 15, 11, 8, 7, 12, 1, 8, 9, 13, 15, 14, 3, 0, 9, 0, 6, 10, 5, 14, 10, 12], [11, 0, 14, 10, 15, 10, 14, 15, 11, 5, 2, 12, 3, 9, 7, 6, 8, 4, 4, 12, 7, 1, 9, 8, 13, 0, 5, 13, 2, 6, 3, 1], [4, 4, 0, 9, 10, 12, 7, 7, 3, 2, 5, 11, 13, 12, 0, 8, 9, 15, 1, 14, 1, 6, 5, 11, 8, 3, 2, 6, 15, 10, 14, 13], [14, 13, 11, 7, 2, 14, 7, 0, 6, 10, 9, 12, 12, 4, 13, 8, 6, 0, 11, 9, 2, 10, 3, 8, 1, 15, 3, 4, 5, 1, 5, 15], [3, 3, 13, 11, 5, 14, 10, 1, 4, 2, 12, 15, 4, 14, 13, 0, 9, 8, 0, 6, 11, 9, 6, 5, 8, 7, 1, 2, 15, 10, 7, 12], [12, 10, 2, 5, 4, 12, 15, 13, 13, 1, 8, 3, 0, 5, 8, 1, 14, 11, 6, 0, 14, 4, 6, 3, 2, 9, 10, 7, 9, 7, 15, 11], [12, 1, 3, 7, 15, 5, 9, 2, 9, 1, 6, 11, 4, 11, 13, 2, 15, 0, 8, 10, 14, 10, 14, 13, 5, 7, 4, 12, 3, 0, 6, 8], [2, 9, 8, 11, 9, 4, 1, 1, 10, 2, 4, 7, 15, 14, 5, 8, 15, 0, 3, 3, 5, 13, 12, 12, 6, 13, 11, 0, 6, 10, 14, 7], [14, 0, 14, 4, 12, 1, 11, 5, 9, 13, 10, 2, 10, 15, 6, 5, 12, 7, 1, 3, 3, 13, 8, 11, 7, 15, 4, 8, 9, 6, 2, 0], [4, 5, 10, 13, 6, 7, 0, 10, 1, 4, 6, 2, 0, 9, 9, 7, 14, 15, 13, 8, 2, 11, 5, 11, 3, 14, 12, 1, 8, 12, 3, 15], [12, 15, 5, 9, 6, 9, 12, 8, 2, 5, 8, 6, 7, 13, 11, 3, 2, 4, 7, 1, 14, 13, 0, 15, 10, 1, 10, 0, 11, 4, 14, 3], [14, 1, 11, 7, 4, 6, 8, 6, 5, 12, 2, 7, 13, 11, 14, 9, 0, 15, 8, 13, 3, 12, 1, 15, 2, 10, 10, 0, 4, 3, 5, 9], [5, 9, 3, 5, 7, 11, 12, 7, 11, 8, 2, 4, 14, 6, 3, 13, 2, 0, 13, 8, 10, 15, 1, 9, 14, 0, 15, 4, 6, 1, 10, 12], [2, 5, 1, 9, 6, 5, 10, 8, 2, 0, 12, 15, 0, 7, 11, 15, 8, 9, 10, 1, 6, 13, 4, 12, 14, 4, 3, 7, 11, 13, 14, 3], [8, 12, 2, 8, 0, 13, 9, 5, 4, 14, 3, 10, 15, 4, 10, 11, 7, 6, 13, 9, 0, 14, 5, 2, 3, 1, 15, 6, 11, 7, 1, 12], [8, 5, 2, 11, 11, 15, 0, 3, 9, 15, 14, 10, 1, 4, 4, 6, 7, 13, 13, 8, 5, 0, 7, 12, 10, 12, 9, 14, 2, 3, 
6, 1], [3, 2, 5, 11, 6, 10, 4, 15, 12, 15, 14, 4, 1, 3, 14, 7, 8, 1, 6, 5, 2, 0, 12, 13, 10, 9, 8, 0, 9, 7, 13, 11], [11, 12, 7, 6, 12, 10, 2, 7, 14, 15, 10, 4, 3, 6, 2, 13, 11, 5, 14, 5, 1, 0, 9, 9, 8, 15, 4, 1, 8, 13, 0, 3], [2, 6, 7, 14, 3, 10, 9, 12, 9, 6, 11, 4, 2, 12, 7, 5, 11, 10, 0, 5, 1, 8, 13, 14, 3, 8, 15, 13, 0, 1, 15, 4], [4, 7, 13, 8, 2, 9, 11, 12, 6, 1, 13, 10, 0, 15, 5, 14, 5, 4, 3, 9, 2, 8, 3, 1, 10, 12, 6, 14, 0, 7, 11, 15], [9, 4, 3, 2, 1, 9, 5, 11, 11, 3, 4, 7, 8, 12, 7, 0, 1, 14, 10, 15, 2, 10, 6, 15, 8, 0, 13, 6, 14, 13, 12, 5], [13, 11, 2, 12, 15, 0, 7, 14, 6, 3, 15, 6, 5, 10, 8, 11, 3, 7, 13, 2, 0, 9, 10, 8, 1, 4, 4, 14, 1, 9, 5, 12]]

# Grouping actually used by the GQA layers below. The commented-out
# heads_grouping_arr_Q / heads_grouping_arr_K arrays are alternative per-layer groupings.
heads_grouping_arr = heads_grouping_arr_standard

# Name of the converted checkpoint, e.g. "facebook-opt-1.3b-GQA-16-kv".
repo_size = "1.3b"
saved_path = f"facebook-opt-{repo_size}-GQA-{kv_heads}-kv"

In [None]:
from torch import nn
from transformers.models.opt.modeling_opt import OPTForSequenceClassification, OPTModel, OPTDecoder, OPTDecoderLayer
from OPT_to_GQA import OPTGQAAttention

class CustomOPTDecoderLayer(OPTDecoderLayer):
    """OPT decoder layer whose self-attention is swapped for ``OPTGQAAttention``.

    NOTE(review): the KV-head count and per-layer grouping are read from the
    notebook globals ``kv_heads`` and ``heads_grouping_arr`` (not from ``config``),
    so those globals must match the checkpoint being loaded.
    """

    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # Replace the attention module created by the parent with the GQA variant.
        layer_grouping = heads_grouping_arr[layer_idx]
        self.self_attn = OPTGQAAttention(
            kv_heads,
            layer_grouping,
            config,
            layer_idx=layer_idx,
        )

class CustomOPTDecoder(OPTDecoder):
    """``OPTDecoder`` whose layer stack is rebuilt from ``CustomOPTDecoderLayer``."""

    def __init__(self, config):
        super().__init__(config)
        # One GQA decoder layer per hidden layer, indexed so each picks its own grouping.
        gqa_layers = [
            CustomOPTDecoderLayer(config, layer_idx=idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(gqa_layers)

class CustomOPTModel(OPTModel):
    """``OPTModel`` whose decoder is replaced by a ``CustomOPTDecoder`` (GQA layers)."""

    def __init__(self, config):
        super().__init__(config)
        gqa_decoder = CustomOPTDecoder(config)
        self.decoder = gqa_decoder

class CustomOPTForSequenceClassification(OPTForSequenceClassification):
    """Sequence-classification OPT whose backbone (``self.model``) is a ``CustomOPTModel``.

    The backbone is swapped in after the parent constructor runs, and
    ``post_init()`` is called again on the resulting module tree.
    """

    def __init__(self, config):
        super().__init__(config)
        backbone = CustomOPTModel(config)
        self.model = backbone
        self.post_init()


# 4-bit NF4 weights with double quantisation; matmuls run in float32.
# NOTE(review): compute dtype is float32 while TrainingArguments later sets fp16=True;
# confirm this combination is intended.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float32,
)

# My repo where I uploaded the GQA version of OPT.
repo_path = f"joshwapanda/{saved_path}"

# Three output classes, one per MNLI label.
model = CustomOPTForSequenceClassification.from_pretrained(
    repo_path,
    quantization_config=bnb_config,
    num_labels=3,
    device_map="auto",
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.86G [00:00<?, ?B/s]

In [None]:
# Memory footprint of the quantised model in MB (bytes / 1e6), then its module tree.
footprint_mb = model.get_memory_footprint() / 1e6
print(footprint_mb)
model

769.159168


CustomOPTForSequenceClassification(
  (model): CustomOPTModel(
    (decoder): CustomOPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0-23): 24 x CustomOPTDecoderLayer(
          (self_attn): OPTGQAAttention(
            (k_proj): Linear4bit(in_features=2048, out_features=1024, bias=True)
            (v_proj): Linear4bit(in_features=2048, out_features=1024, bias=True)
            (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
            (out_proj): Linear4bit(in_features=2048, out_features=2048, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear4bit(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear4bit(in_features=8192, o

In [None]:
# PEFT helper that prepares the 4-bit quantised model for training.
model = prepare_model_for_kbit_training(model)

# LoRA adapters (rank 8, alpha 16) on all four attention projections of every layer.
config = LoraConfig(
    task_type="SEQ_CLS",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # manually setting target modules
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
model = get_peft_model(model, config)

# Report memory footprint (MB) and how many parameters LoRA leaves trainable.
print(model.get_memory_footprint() / 1e6)
train_p, tot_p = model.get_nb_trainable_parameters()
trainable_pct = 100 * train_p / tot_p
print(f'Trainable parameters:      {train_p/1e6:.2f}M')
print(f'Total parameters:          {tot_p/1e6:.2f}M')
print(f'% of trainable parameters: {trainable_pct:.2f}%')
model

995.704832
Trainable parameters:      2.76M
Total parameters:          1217.81M
% of trainable parameters: 0.23%


PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): CustomOPTForSequenceClassification(
      (model): CustomOPTModel(
        (decoder): CustomOPTDecoder(
          (embed_tokens): Embedding(50272, 2048, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0-23): 24 x CustomOPTDecoderLayer(
              (self_attn): OPTGQAAttention(
                (k_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=2048, out_features=1024, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=2048, out_features=8, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default):

In [None]:
tokenizer_repo_id = f"facebook/opt-{repo_size}"

tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo_id)

# The facebook/opt tokenizers ship a dedicated pad token whose id matches the
# model's config.pad_token_id. Overwriting it with EOS made padding invisible to
# OPTForSequenceClassification, which locates the last real token by comparing
# input_ids with config.pad_token_id, so the classifier read a padding position.
# Only fall back to EOS when no pad token exists, and keep the model config in sync.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

# **Loading and pre-processing the dataset**

In [None]:
from datasets import load_dataset

# GLUE MNLI: premise/hypothesis sentence pairs with a class label.
dataset = load_dataset(path="nyu-mll/glue", name="mnli")
dataset

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/52.2M [00:00<?, ?B/s]

(…)alidation_matched-00000-of-00001.parquet:   0%|          | 0.00/1.21M [00:00<?, ?B/s]

(…)dation_mismatched-00000-of-00001.parquet:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

test_matched-00000-of-00001.parquet:   0%|          | 0.00/1.22M [00:00<?, ?B/s]

test_mismatched-00000-of-00001.parquet:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

Generating test_matched split:   0%|          | 0/9796 [00:00<?, ? examples/s]

Generating test_mismatched split:   0%|          | 0/9847 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})

In [None]:
# Fine-tune on a reproducible (seed 42) random 50k-example subset of the training split.
train_split = dataset["train"]
dataset["train"] = train_split.shuffle(seed=42).select(range(50_000))

In [None]:
def preprocess(example):
    """Tokenise a batch of MNLI examples as (premise, hypothesis) sentence pairs.

    Args:
        example: Batched ``datasets`` row dict with ``"premise"`` and ``"hypothesis"`` columns.

    Returns:
        The tokenizer's encoding, truncated and padded to exactly 128 tokens.
    """
    premises = example["premise"]
    hypotheses = example["hypothesis"]
    return tokenizer(
        premises,
        hypotheses,
        max_length=128,
        padding="max_length",
        truncation=True,
    )

# Tokenise every split, expose the gold label as "labels", and return torch
# tensors for the model inputs only.
encoded_dataset = (
    dataset
    .map(preprocess, batched=True)
    .rename_column("label", "labels")
)
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/9815 [00:00<?, ? examples/s]

Map:   0%|          | 0/9832 [00:00<?, ? examples/s]

Map:   0%|          | 0/9796 [00:00<?, ? examples/s]

Map:   0%|          | 0/9847 [00:00<?, ? examples/s]

# **Training**

In [None]:
# Non-reentrant activation checkpointing: recompute activations during backward to save memory.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

training_args = TrainingArguments(
    output_dir="./opt-qlora-mnli",
    report_to="none",
    # Optimisation
    num_train_epochs=2,
    per_device_train_batch_size=128,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.01,
    optim="paged_adamw_8bit",
    fp16=True,
    # Evaluation / checkpointing once per epoch; the best checkpoint is reloaded at the end.
    per_device_eval_batch_size=64,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    label_names=["labels"],
    # Logging
    logging_dir="./logs",
    logging_steps=150,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from transformers import TrainerCallback
from sklearn.metrics import accuracy_score

class PlotLossAccuracyCallback(TrainerCallback):
    """Collect train/eval loss and eval accuracy from Trainer logs and plot them.

    Every ``on_log`` event is echoed to stdout. Values logged under ``"loss"``,
    ``"eval_loss"`` and ``"eval_accuracy"`` are stored as ``(global_step, value)``
    tuples. ``"eval_accuracy"`` only appears when the ``Trainer`` was given a
    ``compute_metrics`` function that returns an ``"accuracy"`` key.
    """

    def __init__(self):
        self.train_loss = []  # (step, training loss)
        self.eval_loss = []   # (step, evaluation loss)
        self.eval_acc = []    # (step, evaluation accuracy)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the log event and record any tracked metric it contains."""
        if logs is None:
            return

        print(f"[LOG] Step {state.global_step}: {logs}")

        if "loss" in logs:
            self.train_loss.append((state.global_step, logs["loss"]))
        if "eval_loss" in logs:
            self.eval_loss.append((state.global_step, logs["eval_loss"]))
        if "eval_accuracy" in logs:
            self.eval_acc.append((state.global_step, logs["eval_accuracy"]))

    def plot(self, save_path="loss_accuracy_plot.png"):
        """Draw losses (left panel) and eval accuracy (right panel), then save and show.

        Args:
            save_path: File the figure is written to.
        """
        steps_train, loss_train = zip(*self.train_loss) if self.train_loss else ([], [])
        steps_eval, loss_eval = zip(*self.eval_loss) if self.eval_loss else ([], [])
        steps_acc, acc_eval = zip(*self.eval_acc) if self.eval_acc else ([], [])

        plt.figure(figsize=(12, 5))

        if loss_train or loss_eval:
            plt.subplot(1, 2, 1)
            if loss_train:
                plt.plot(steps_train, loss_train, label="Train Loss")
            if loss_eval:
                plt.plot(steps_eval, loss_eval, label="Eval Loss")
            plt.xlabel("Step")
            plt.ylabel("Loss")
            plt.title("Training and Eval Loss")
            plt.legend()

        if acc_eval:
            plt.subplot(1, 2, 2)
            plt.plot(steps_acc, acc_eval, label="Eval Accuracy", color="green")
            plt.xlabel("Step")
            plt.ylabel("Accuracy")
            plt.title("Validation Accuracy")
            plt.legend()

        # These must live inside plot(): at class-body indentation they ran once at
        # class-definition time (on an empty figure) and never when plot() was called.
        plt.tight_layout()
        plt.savefig(save_path)
        plt.show()

<Figure size 640x480 with 0 Axes>

In [None]:
plot_callback = PlotLossAccuracyCallback()


def accuracy_from_eval_pred(eval_pred):
    """Accuracy of arg-max class predictions for a Trainer evaluation.

    Passed as ``compute_metrics`` so that each evaluation logs ``eval_accuracy``.
    Without it, ``PlotLossAccuracyCallback`` never receives an accuracy value.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer.

    Returns:
        ``{"accuracy": value}``; the Trainer prefixes the key with ``eval_``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, predictions)}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation_matched"],
    # `processing_class` replaces the deprecated `tokenizer` argument.
    processing_class=tokenizer,
    compute_metrics=accuracy_from_eval_pred,
    callbacks=[plot_callback],
)

train_result = trainer.train()
plot_callback.plot()
train_result

# **Test Model**

In [None]:
! pip -q install evaluate
import evaluate
accuracy_metric = evaluate.load("accuracy")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h

Downloading builder script: 0.00B [00:00, ?B/s]

In [None]:
def compute_metrics(eval_pred):
    """Accuracy of arg-max class predictions, via the ``evaluate`` accuracy metric.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer.

    Returns:
        The dict produced by ``accuracy_metric.compute``.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(
        predictions=predicted_classes,
        references=labels,
    )

In [None]:
# Rebuild the Trainer around the fine-tuned model (the best checkpoint was reloaded
# because load_best_model_at_end=True) with an accuracy metric for evaluation.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation_matched"],
    # `processing_class` replaces the deprecated `tokenizer` argument.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [None]:
# Loss and accuracy of the fine-tuned model on MNLI validation_matched.
trainer.evaluate(eval_dataset=encoded_dataset["validation_matched"])