# Import libraries

In [1]:
import os
import wandb
import torch
from tqdm.auto import tqdm

from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers import DataCollatorForLanguageModeling

from transformers import get_scheduler

from peft import LoraConfig, TaskType
from peft import get_peft_model

In [2]:
# Import utils from ../src/utils
import sys
sys.path.append('..')

In [3]:
from utils.data import get_mnli
from utils.evaluation import evaluate

# Model

In [4]:
"""
The difference between “it” aka “Instruction Tuned”
and the base model is that the “it” variants are better for chat purposes
since they have been fine-tuned to better understand the instructions
and generate better answers while the base variants are those that have not undergone
under any sort of fine-tuning. They can still generate answers but not as good as the “it” one.

"""
# google/gemma-2b | google/gemma-2b-it | microsoft/phi-2
# Qwen/Qwen1.5-0.5B | Qwen/Qwen1.5-0.5B-Chat
model_name = "microsoft/phi-2" 

In [5]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False
)

In [6]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto"  #{"":0},
)
print(f"Model loaded: {model_name}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded: microsoft/phi-2


In [7]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"], #, "o_proj", "k_proj" | "gate_proj", "up_proj", "down_proj", "dense"],
    #target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"], #, "o_proj", "k_proj" | "gate_proj", "up_proj", "down_proj", "dense"],
    task_type=TaskType.CAUSAL_LM,
)

In [8]:
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 2,621,440 || all params: 2,782,305,280 || trainable%: 0.0942


# Tokenizer

In [9]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token # Most LLMs don't have a pad token by default
#data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
max_seq_length = 1024
print(f"Tokenizer loaded: {model_name}")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Tokenizer loaded: microsoft/phi-2


# Dataset: MNLI

In [10]:
prompt_key = "input_ids"

In [11]:
dataset = get_mnli(tokenizer, max_seq_length)

In [12]:
dataset

DatasetDict({
    train: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 391678
    })
    validation: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 1024
    })
    test_matched: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 9815
    })
    test_mismatched: Dataset({
        features: ['class_label', 'idx', 'prompt_length', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 9832
    })
})

## Pytorch dataloader format

In [13]:
batch_size = 1  # number of examples in each batch
inference_batch_size = 1  # number of examples in each batch for inference

In [14]:
# Move the data to tensors
dataset.set_format("torch")

In [15]:
train_dataloader = DataLoader(
    dataset["train"],  # For testing purposes ->  .shuffle(seed=42).select(range(1000)),
    shuffle=True, batch_size=batch_size,
    collate_fn=data_collator
)

print(f"Training dataset size: {len(train_dataloader.dataset)}")

Training dataset size: 391678


In [16]:
val_dataloader = DataLoader(
    dataset["validation"].shuffle(seed=42).select(range(1000)),
    batch_size=inference_batch_size,
    collate_fn=data_collator
)

print(f"Validation dataset size: {len(val_dataloader.dataset)}")

Validation dataset size: 1000


## Check - Generate Batch 

In [17]:
batch = next(iter(val_dataloader))

In [18]:
print({k: v.shape for k, v in batch.items()})

{'class_label': torch.Size([1]), 'idx': torch.Size([1]), 'prompt_length': torch.Size([1]), 'input_ids': torch.Size([1, 103]), 'attention_mask': torch.Size([1, 103]), 'labels': torch.Size([1, 103])}


In [19]:
print(tokenizer.decode(batch["input_ids"][0]))

You are given a premise and a hypothesis below. If the premise entails the  hypothesis, return 0. If the premise contradicts the hypothesis, return 2.  Otherwise, if the premise does neither, return 1.

### Premise: yeah yeah well that's neat do you do you look forward to doing it or do you sometimes have to force yourself until you get started

### Hypothesis: that's neat, how often do you usually do this?

### Label: 


# Try the base model (not finetuned)

In [20]:
max_output_tokens = 64
temperature = 0

In [21]:
text = """
You are given a premise and a hypothesis below. If the premise entails the  hypothesis, return 0. If the premise contradicts the hypothesis, return 2.  Otherwise, if the premise does neither, return 1.

### Premise: yeah yeah well that's neat do you do you look forward to doing it or do you sometimes have to force yourself until you get started

### Hypothesis: that's neat, how often do you usually do this?

### Label:"""
#texto = tokenizer.decode(batch[prompt_key][0])

In [22]:
generated_tokens_with_prompt = model.generate(
    tokenizer(text, return_tensors="pt")["input_ids"],
    # max_length=max_input_len + max_output_tokens,
    max_new_tokens=max_output_tokens,
    pad_token_id=tokenizer.pad_token_id,
    temperature=temperature,
)



# Training setup

In [23]:
learning_rate = 0.0002
weight_decay = 0.01

training_steps = 40000
training_iters = training_steps

# The desired batch size is the batch size you want to train with
effective_batch_size = 16
gradient_accumulation_steps = effective_batch_size // batch_size


warmup_steps = 1200  # training steps

"""
Each model is trained for 40000 training
steps with batch size 1. Gradients are applied over
16 accumulation steps for an effective batch size of 16.
"""
#training_iters = training_steps // gradient_accumulation_steps
eval_interval = 250

print(f"The model will be trained with {training_steps*batch_size} examples")
#print(f"Number of training iterations: {training_iters}")
print(f"Total epochs: {training_steps*batch_size/len(dataset['train'])}")

The model will be trained with 40000 examples
Total epochs: 0.10212470447663642


In [24]:
os.environ["WANDB_SILENT"] = "true"
wandb_project = 'FLoRA'
run = wandb.init(project=wandb_project, config={
    "model": model_name,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
    "training_steps": training_steps,
    "batch_size": batch_size,
    "effective_batch_size": effective_batch_size,
    "warmup_steps": warmup_steps,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "eval_interval": eval_interval,
})
print(f'Run name: {run.name}. Visit at {run.get_url()}')

Run name: celestial-resonance-10. Visit at https://wandb.ai/marioparreno/FLoRA/runs/wyob0noq


## Optimizer and learning rate scheduler

Create an optimizer and learning rate scheduler to fine-tune the model. Let’s use the [AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html) optimizer from PyTorch

In [25]:
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

Create the default learning rate scheduler from [Trainer](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.Trainer):

In [26]:
lr_scheduler = get_scheduler(
    name="cosine", optimizer=optimizer,
    num_warmup_steps=warmup_steps, num_training_steps=training_steps
)

# Training loop

In [27]:
forward_keys = ["input_ids", "attention_mask", "labels"]

In [28]:
progress_bar = tqdm(range(training_iters))
for iter_num in range(training_iters):
    
    model.train()
    for micro_step in range(gradient_accumulation_steps):
        # Extract a batch of data
        batch = next(iter(train_dataloader))
        # remove from batch keys that are not needed
        train_batch = {k: v for k, v in batch.items() if k in forward_keys}

        outputs = model(**train_batch)
        # El modelo calcula su loss, pero podriamos acceder a los logits del modelo
        # y las labels del batch y calcular nuestra loss propia
        # scale the loss to account for gradient accumulation
        loss = outputs.loss / gradient_accumulation_steps
        loss.backward()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)

    """
    if iter_num % eval_interval == 0:
        # scale up to undo the division above
        # approximating total loss (exact would have been a sum)
        train_loss = loss.item() * gradient_accumulation_steps

        val_preds = evaluate(model, val_dataloader, tokenizer)
        val_correct = sum([1 for p in val_preds if p.y_true == p.y_pred])
        val_accuracy = val_correct / len(val_preds)
        #train_loss, train_perplexity = evaluate(model, train_dataloader)

        print(f"### ITER {iter_num} ###")
        # print(f"Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f}")
        # print(f"Validation Loss: {val_loss:.4f} - Validation Accuracy: {val_accuracy:.4f}")
        print(f"Train Loss: {train_loss:.4f} - Validation Accuracy: {val_accuracy:.4f}")

        wandb.log({
            "iter": iter_num,
            "train/loss": train_loss,
            # "train/accuracy": train_accuracy,
            # "val/loss": val_loss,
            "val/accuracy": val_accuracy,
            "lr": lr_scheduler.get_last_lr()[0],
        })

        preds_table = wandb.Table(columns=["Case Index", "Case Prompt", "Generation", "Ground Truth", "Prediction"])
        for pred in val_preds:    
            preds_table.add_data(
                pred.case_index,
                pred.case_text,
                pred.generation,
                pred.y_true,
                pred.y_pred
            )
        run.log({f"test_completions_iter{iter_num}": preds_table})
    """
    if iter_num % 50 == 0:
        train_loss = loss.item() * gradient_accumulation_steps
        print(f"### ITER {iter_num} ###")
        print(f"Train Loss: {train_loss:.4f}")
        wandb.log({
            "iter": iter_num,
            "train/loss": train_loss,
            "lr": lr_scheduler.get_last_lr()[0],
        })
        
progress_bar.close()

  0%|          | 0/40000 [00:00<?, ?it/s]

### ITER 0 ###
Train Loss: 2.7801
### ITER 50 ###
Train Loss: 2.6893
### ITER 100 ###
Train Loss: 3.1179
### ITER 150 ###
Train Loss: 2.1834
### ITER 200 ###
Train Loss: 2.0936
### ITER 250 ###
Train Loss: 1.0197
### ITER 300 ###
Train Loss: 1.0271
### ITER 350 ###
Train Loss: 0.9381
### ITER 400 ###
Train Loss: 1.1849
### ITER 450 ###
Train Loss: 0.8209
### ITER 500 ###
Train Loss: 0.5135
### ITER 550 ###
Train Loss: 1.8095
### ITER 600 ###
Train Loss: 0.8883
### ITER 650 ###
Train Loss: 2.1282
### ITER 700 ###
Train Loss: 0.9471
### ITER 750 ###
Train Loss: 0.4791
### ITER 800 ###
Train Loss: 0.7849
### ITER 850 ###
Train Loss: 1.5002
### ITER 900 ###
Train Loss: 1.6553
### ITER 950 ###
Train Loss: 1.1671
### ITER 1000 ###
Train Loss: 1.0054
### ITER 1050 ###
Train Loss: 1.3105
### ITER 1100 ###
Train Loss: 0.6235
### ITER 1150 ###
Train Loss: 0.4075
### ITER 1200 ###
Train Loss: 0.8940
### ITER 1250 ###
Train Loss: 0.9193
### ITER 1300 ###
Train Loss: 1.4406
### ITER 1350 ###
Train 

# Save the model

In [42]:
model.save_pretrained(f"checkpoints/{run.name}")
# Log the model checkpoint
###artifact = wandb.Artifact("checkpoint_and_results", type="models")
###artifact.add_dir(f"checkpoints/{run.name}")
###run.log_artifact(artifact)

In [43]:
print(f"Model saved at checkpoints/{run.name}")

Model saved at checkpoints/celestial-resonance-10


# Finish the run

In [44]:
texto = """
You are given a premise and a hypothesis below. If the premise entails the  hypothesis, return 0. If the premise contradicts the hypothesis, return 2.  Otherwise, if the premise does neither, return 1.

### Premise: Sorry but that's how it is.

### Hypothesis: This is how things are and there are no apologies about it.

### Label: """

In [45]:
generated_tokens_with_prompt = model.generate(
    tokenizer(texto, return_tensors="pt")["input_ids"],
    # max_length=max_input_len + max_output_tokens,
    max_new_tokens=max_output_tokens,
    pad_token_id=tokenizer.pad_token_id,
    temperature=temperature,
)

In [46]:
print(tokenizer.decode(generated_tokens_with_prompt[0]))


You are given a premise and a hypothesis below. If the premise entails the  hypothesis, return 0. If the premise contradicts the hypothesis, return 2.  Otherwise, if the premise does neither, return 1.

### Premise: Sorry but that's how it is.

### Hypothesis: This is how things are and there are no apologies about it.

### Label: 

### Label: 0

### Label: 1

### Label: 2

### Label: 1

### Label: 2

### Label: 1

### Label: 1

### Label: 2

### Label: 1

### Label: 2

### Label


In [47]:
val_batch = next(iter(val_dataloader))

In [60]:
generated_tokens_with_prompt = model.generate(
    val_batch[prompt_key],
    # max_length=max_input_len + max_output_tokens,
    max_new_tokens=12,
    pad_token_id=tokenizer.pad_token_id,
    temperature=temperature,
)

In [61]:
print(tokenizer.decode(generated_tokens_with_prompt[0]))

You are given a premise and a hypothesis below. If the premise entails the  hypothesis, return 0. If the premise contradicts the hypothesis, return 2.  Otherwise, if the premise does neither, return 1.

### Premise: yeah yeah well that's neat do you do you look forward to doing it or do you sometimes have to force yourself until you get started

### Hypothesis: that's neat, how often do you usually do this?

### Label: 

### Label: 1

### Label: 2


In [62]:
val_preds = evaluate(model, val_dataloader, tokenizer, max_output_tokens=12)
val_correct = sum([1 for p in val_preds if p.y_true == p.y_pred])
val_accuracy = val_correct / len(val_preds)

Evaluating:   0%|          | 0/1000 [00:00<?, ?it/s]



In [63]:
val_accuracy

0.857

In [39]:
wandb.finish()