# Import libraries

In [1]:
import os
import wandb
import torch
from tqdm.auto import tqdm

from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers import DataCollatorForSeq2Seq
from transformers import DataCollatorForLanguageModeling

from transformers import get_scheduler

from peft import LoraConfig, TaskType
from peft import get_peft_model

In [2]:
# Import utils from ../src/utils
import sys
sys.path.append('..')

In [3]:
from utils.data import get_mnli
from utils.evaluation import evaluate

# Model

In [4]:
"""
The difference between “it” aka “Instruction Tuned”
and the base model is that the “it” variants are better for chat purposes
since they have been fine-tuned to better understand the instructions
and generate better answers while the base variants are those that have not undergone
under any sort of fine-tuning. They can still generate answers but not as good as the “it” one.

"""
# google/gemma-2b | google/gemma-2b-it | microsoft/phi-2
# Qwen/Qwen1.5-0.5B | Qwen/Qwen1.5-0.5B-Chat
model_name = "microsoft/phi-2" 

In [5]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [6]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"":0},
)
print(f"Model loaded: {model_name}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded: microsoft/phi-2


In [7]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.5,
    target_modules=["q_proj", "v_proj"], #, "o_proj", "k_proj" | "gate_proj", "up_proj", "down_proj", "dense"],
    task_type=TaskType.CAUSAL_LM,
)

In [8]:
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()

trainable params: 2,621,440 || all params: 2,782,305,280 || trainable%: 0.0942


# Tokenizer

In [9]:
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="left",  # warned by model.generate
)
tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=lora_model)
max_seq_length = 1024
print(f"Tokenizer loaded: {model_name}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Tokenizer loaded: microsoft/phi-2


# Dataset: MNLI

In [10]:
prompt_key = "input_ids"

In [11]:
dataset = get_mnli(tokenizer, max_seq_length)

In [12]:
dataset

DatasetDict({
    train: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 391678
    })
    validation: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1024
    })
    test_matched: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 9815
    })
    test_mismatched: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 9832
    })
})

## Pytorch dataloader format

In [13]:
batch_size = 1  # number of examples in each batch
inference_batch_size = 8  # number of examples in each batch for inference

In [14]:
# Move the data to tensors
dataset.set_format("torch")

In [15]:
train_dataloader = DataLoader(
    dataset["train"],  # For testing purposes ->  .shuffle(seed=42).select(range(1000)),
    shuffle=True, batch_size=batch_size,
    collate_fn=data_collator
)

print(f"Training dataset size: {len(train_dataloader.dataset)}")

Training dataset size: 391678


In [16]:
val_dataloader = DataLoader(
    dataset["validation"].shuffle(seed=42).select(range(1000)),
    batch_size=inference_batch_size,
    collate_fn=data_collator
)

print(f"Validation dataset size: {len(val_dataloader.dataset)}")

Validation dataset size: 1000


## Check - Generate Batch 

In [17]:
batch = next(iter(val_dataloader))

  batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)


In [18]:
print({k: v.shape for k, v in batch.items()})

{'class_label': torch.Size([8]), 'idx': torch.Size([8]), 'prompt_length': torch.Size([8]), 'input_ids': torch.Size([8, 264]), 'attention_mask': torch.Size([8, 264]), 'labels': torch.Size([8, 264])}


# Try the base model (not finetuned)

In [19]:
max_output_tokens = 64
temperature = 0

In [20]:
debug_dataloader = DataLoader(
    dataset["validation"].shuffle(seed=42).select(range(inference_batch_size * 3)),
    batch_size=inference_batch_size,
    collate_fn=data_collator
)

In [21]:
debug_preds = evaluate(
    lora_model, debug_dataloader, tokenizer,
    prompt_key=prompt_key,
    max_output_tokens=max_output_tokens,
    temperature=temperature
)

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]



# Training setup

In [22]:
learning_rate = 0.0002
weight_decay = 0.01

training_steps = 40000

# The desired batch size is the batch size you want to train with
effective_batch_size = 16
gradient_accumulation_steps = effective_batch_size // batch_size


warmup_steps = 1200  # training steps

"""
Each model is trained for 40000 training
steps with batch size 1. Gradients are applied over
16 accumulation steps for an effective batch size of 16.
"""
training_iters = training_steps // gradient_accumulation_steps
eval_interval = 250

print(f"The model will be trained with {training_steps*batch_size} examples")
print(f"Number of training iterations: {training_iters}")
print(f"Total epochs: {training_steps*batch_size/len(dataset['train'])}")

The model will be trained with 40000 examples
Number of training iterations: 2500
Total epochs: 0.10212470447663642


In [23]:
os.environ["WANDB_SILENT"] = "true"
wandb_project = 'FLoRA'
run = wandb.init(project=wandb_project, config={
    "model": model_name,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
    "training_steps": training_steps,
    "batch_size": batch_size,
    "effective_batch_size": effective_batch_size,
    "warmup_steps": warmup_steps,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "eval_interval": eval_interval,
})
print(f'Run name: {run.name}. Visit at {run.get_url()}')

Run name: dutiful-silence-5. Visit at https://wandb.ai/marioparreno/FLoRA/runs/ahxojp40


## Optimizer and learning rate scheduler

Create an optimizer and learning rate scheduler to fine-tune the model. Let’s use the [AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) optimizer from PyTorch

In [24]:
optimizer = AdamW(lora_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

Create the default learning rate scheduler from [Trainer](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.Trainer):

In [25]:
lr_scheduler = get_scheduler(
    name="cosine", optimizer=optimizer,
    num_warmup_steps=warmup_steps, num_training_steps=training_steps
)

# Training loop

In [26]:
forward_keys = ["input_ids", "attention_mask", "labels"]

In [33]:
print(tokenizer.decode(train_batch["input_ids"][0]))

<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|

In [32]:
train_batch["attention_mask"][0]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [35]:
print(tokenizer.decode(train_batch["labels"][0]))

<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|

In [67]:
batch = next(iter(val_dataloader))

In [68]:
generated_tokens_with_prompt = model.generate(
    batch[prompt_key],
    max_length=1024 + max_output_tokens,
    pad_token_id=tokenizer.pad_token_id,
    temperature=temperature,
)



In [69]:
batch["class_label"]

tensor([1, 2, 2, 0, 1, 1, 1, 2])

In [70]:
tokenizer.decode(generated_tokens_with_prompt[0])

"<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext

In [71]:
print(tokenizer.decode(batch[prompt_key][0]))

<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|

In [72]:
print(tokenizer.decode(generated_tokens_with_prompt[0]))

<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|

In [43]:
batch[prompt_key]

tensor([[50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 5

In [42]:
generated_tokens_with_prompt

tensor([[50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 5

In [49]:
# Decode the generated tokens
generated_text_with_prompt = tokenizer.batch_decode(
    generated_tokens_with_prompt, skip_special_tokens=True
)

In [50]:
generated_text_with_prompt

["You are given a premise and a hypothesis below. If the premise entails the  hypothesis, return 0. If the premise contradicts the hypothesis, return 2.  Otherwise, if the premise does neither, return 1.\n\n### Premise: try to find a time when everybody can be there and we've\n\n### Hypothesis: We try to find the time to all go at once and spend time together.\n\n### Label: 1"]

In [None]:
case_text = tokenizer.decode(batch["input_ids"][0])

In [None]:
for batch_index, generation_with_prompt in enumerate(
    generated_text_with_prompt
):
    case_text = tokenizer.decode(batch["input_ids"][batch_index])
    # remove eos_token/pad_token from texts
    case_text = case_text.replace(tokenizer.pad_token, "")
    y_true = batch["class_label"][batch_index].item()
    generation = generation_with_prompt[len(case_text) :]
    y_pred = get_first_number(generation)

In [27]:
progress_bar = tqdm(range(training_iters))
for iter_num in range(training_iters):
    
    model.train()
    for micro_step in range(gradient_accumulation_steps):
        # Extract a batch of data
        batch = next(iter(train_dataloader))
        # remove from batch keys that are not needed
        train_batch = {k: v for k, v in batch.items() if k in forward_keys}

        outputs = lora_model(**train_batch)
        # El modelo calcula su loss, pero podriamos acceder a los logits del modelo
        # y las labels del batch y calcular nuestra loss propia
        # scale the loss to account for gradient accumulation
        loss = outputs.loss / gradient_accumulation_steps
        loss.backward()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)


    if iter_num % eval_interval == 0:
        # scale up to undo the division above
        # approximating total loss (exact would have been a sum)
        train_loss = loss.item() * gradient_accumulation_steps

        val_preds = evaluate(lora_model, val_dataloader, tokenizer)
        val_correct = sum([1 for p in val_preds if p.y_true == p.y_pred])
        val_accuracy = val_correct / len(val_preds)
        #train_loss, train_perplexity = evaluate(model, train_dataloader)

        print(f"### ITER {iter_num} ###")
        # print(f"Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f}")
        # print(f"Validation Loss: {val_loss:.4f} - Validation Accuracy: {val_accuracy:.4f}")
        print(f"Train Loss: {train_loss:.4f} - Validation Accuracy: {val_accuracy:.4f}")

        wandb.log({
            "iter": iter_num,
            "train/loss": train_loss,
            # "train/accuracy": train_accuracy,
            # "val/loss": val_loss,
            "val/accuracy": val_accuracy,
            "lr": lr_scheduler.get_last_lr()[0],
        })

        preds_table = wandb.Table(columns=["Case Index", "Case Prompt", "Generation", "Ground Truth", "Prediction"])
        for pred in val_preds:    
            preds_table.add_data(
                pred.case_index,
                pred.case_text,
                pred.generation,
                pred.y_true,
                pred.y_pred
            )
        run.log({f"test_completions_iter{iter_num}": preds_table})

progress_bar.close()

  0%|          | 0/2500 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]



### ITER 0 ###
Train Loss: 7.0368 - Validation Accuracy: 0.0820


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 250 ###
Train Loss: 1.6077 - Validation Accuracy: 0.0000


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 500 ###
Train Loss: 0.8692 - Validation Accuracy: 0.0000


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 750 ###
Train Loss: 0.7255 - Validation Accuracy: 0.0010


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 1000 ###
Train Loss: 0.5540 - Validation Accuracy: 0.0010


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 1250 ###
Train Loss: 0.9238 - Validation Accuracy: 0.0000


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 1500 ###
Train Loss: 0.6600 - Validation Accuracy: 0.0000


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 1750 ###
Train Loss: 0.4166 - Validation Accuracy: 0.0000


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 2000 ###
Train Loss: 0.1737 - Validation Accuracy: 0.0000


Evaluating:   0%|          | 0/125 [00:00<?, ?it/s]

### ITER 2250 ###
Train Loss: 0.4040 - Validation Accuracy: 0.0010


# Save the model

In [28]:
lora_model.save_pretrained(f"checkpoints/{run.name}")
# Log the model checkpoint
artifact = wandb.Artifact("checkpoint_and_results", type="models")
artifact.add_dir(f"checkpoints/{run.name}")
run.log_artifact(artifact)



<Artifact checkpoint_and_results>

# Finish the run

In [29]:
wandb.finish()