# Import libraries

In [1]:
import torch

from torch.utils.data import DataLoader

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers import DataCollatorForLanguageModeling

In [2]:
# Make the project-local `utils` package importable. `..` is resolved relative
# to the current working directory, so with the notebook launched from its own
# folder the next cell's `utils.data` / `utils.evaluation` come from ../utils.
import sys

# Guard the append so re-running this cell does not keep stacking duplicate
# '..' entries onto sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
from utils.data import get_mnli
from utils.evaluation import evaluate

# Model

In [4]:
"""
The difference between “it” aka “Instruction Tuned”
and the base model is that the “it” variants are better for chat purposes
since they have been fine-tuned to better understand the instructions
and generate better answers while the base variants are those that have not undergone
under any sort of fine-tuning. They can still generate answers but not as good as the “it” one.

"""
# google/gemma-2b | google/gemma-2b-it | microsoft/phi-2
# Qwen/Qwen1.5-0.5B | Qwen/Qwen1.5-0.5B-Chat
model_name = "microsoft/phi-2" 

In [5]:
# bitsandbytes quantization settings used when loading the model below:
# 4-bit weights in the NF4 format, bfloat16 as the compute dtype, and no
# nested ("double") quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

In [6]:
# Load the causal LM quantized with `bnb_config`; device_map="auto" lets
# transformers/accelerate decide where the weights live.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
)
print(f"Model loaded: {model_name}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded: microsoft/phi-2


# Tokenizer

In [7]:
max_seq_length = 1024  # passed to get_mnli() together with the tokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Most LLMs don't define a pad token, so reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): right padding is harmless at batch size 1, but batched
# generation with decoder-only models usually needs padding_side='left'.
tokenizer.padding_side = 'right'

# Causal-LM collator (mlm=False).
# NOTE(review): this collator presumably rebuilds `labels` from `input_ids`
# (masking pad == EOS positions), replacing the dataset's own `labels`
# column -- confirm evaluate() does not rely on the precomputed labels.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
print(f"Tokenizer loaded: {model_name}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Tokenizer loaded: microsoft/phi-2


# Dataset: MNLI

In [8]:
# Load MNLI and tokenize every split with `tokenizer`; max_seq_length is
# presumably the truncation limit applied inside get_mnli -- TODO confirm.
dataset = get_mnli(tokenizer, max_seq_length)

Map:   0%|          | 0/391678 [00:00<?, ? examples/s]

Map:   0%|          | 0/1024 [00:00<?, ? examples/s]

Map:   0%|          | 0/9815 [00:00<?, ? examples/s]

Map:   0%|          | 0/9832 [00:00<?, ? examples/s]

In [9]:
# Show the splits, their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 391678
    })
    validation: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1024
    })
    test_matched: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 9815
    })
    test_mismatched: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 9832
    })
})

## PyTorch DataLoader format

In [10]:
# Examples per batch. Only inference_batch_size feeds the DataLoaders in this
# notebook; batch_size is not referenced by the cells shown here.
batch_size = inference_batch_size = 1

In [11]:
max_output_tokens: int = 32  # forwarded to evaluate(); presumably caps generated tokens per example

In [12]:
# Make every split return torch tensors when indexed. set_format updates the
# DatasetDict in place, so nothing is reassigned here.
dataset.set_format(type="torch")

## Validation

In [13]:
# Evaluate on a reproducible (seed=42) random sample of the validation split.
# The sample size is clamped to the split length so `select` cannot index
# past the end if the split is smaller than requested.
val_sample_size = 1000
val_split = dataset["validation"]
val_dataloader = DataLoader(
    val_split.shuffle(seed=42).select(range(min(val_sample_size, len(val_split)))),
    batch_size=inference_batch_size,
    collate_fn=data_collator
)

print(f"Validation dataset size: {len(val_dataloader.dataset)}")

Validation dataset size: 1000


In [14]:
# Run the model over the validation loader. evaluate() is expected to return
# one record per example exposing `y_true` and `y_pred`.
val_preds = evaluate(
    model,
    val_dataloader,
    tokenizer,
    max_output_tokens=max_output_tokens
)
# Fail with a clear message instead of a bare ZeroDivisionError below.
assert len(val_preds) > 0, "evaluate() returned no validation predictions"
# Count exact matches with a generator (no intermediate list); yielding 1
# keeps the count a plain int whatever type `==` returns.
val_correct = sum(1 for p in val_preds if p.y_true == p.y_pred)
val_accuracy = val_correct / len(val_preds)
val_accuracy

Evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]



0.236

## Test

### Test Matched

In [15]:
# Full test-matched split, iterated in its stored order.
test_matched_split = dataset["test_matched"]
test_matched_dataloader = DataLoader(
    test_matched_split,
    collate_fn=data_collator,
    batch_size=inference_batch_size,
)

print(f"Test Matched dataset size: {len(test_matched_split)}")

Test Matched dataset size: 9815


In [16]:
# Run the model over the test-matched loader. evaluate() is expected to return
# one record per example exposing `y_true` and `y_pred`.
test_matched_preds = evaluate(
    model,
    test_matched_dataloader,
    tokenizer,
    max_output_tokens=max_output_tokens
)
# Fail with a clear message instead of a bare ZeroDivisionError below.
assert len(test_matched_preds) > 0, "evaluate() returned no test-matched predictions"
# Count exact matches with a generator (no intermediate list); yielding 1
# keeps the count a plain int whatever type `==` returns.
test_matched_correct = sum(1 for p in test_matched_preds if p.y_true == p.y_pred)
test_matched_accuracy = test_matched_correct / len(test_matched_preds)
test_matched_accuracy

Evaluating:   0%|          | 0/9815 [00:00<?, ?it/s]

0.21018848700967907

### Test Mismatched

In [18]:
# Full test-mismatched split, iterated in its stored order.
test_mismatched_split = dataset["test_mismatched"]
test_mismatched_dataloader = DataLoader(
    test_mismatched_split,
    collate_fn=data_collator,
    batch_size=inference_batch_size,
)

print(f"Test Mismatched dataset size: {len(test_mismatched_split)}")

Test Mismatched dataset size: 9832


In [19]:
# Run the model over the test-mismatched loader. evaluate() is expected to
# return one record per example exposing `y_true` and `y_pred`.
test_mismatched_preds = evaluate(
    model,
    test_mismatched_dataloader,
    tokenizer,
    max_output_tokens=max_output_tokens
)
# Fail with a clear message instead of a bare ZeroDivisionError below.
assert len(test_mismatched_preds) > 0, "evaluate() returned no test-mismatched predictions"
# Count exact matches with a generator (no intermediate list); yielding 1
# keeps the count a plain int whatever type `==` returns.
test_mismatched_correct = sum(1 for p in test_mismatched_preds if p.y_true == p.y_pred)
test_mismatched_accuracy = test_mismatched_correct / len(test_mismatched_preds)
test_mismatched_accuracy

Evaluating:   0%|          | 0/9832 [00:00<?, ?it/s]

0.22914971521562247