<a href="https://colab.research.google.com/github/ajliouat/Mistral7B-Constrained-Fine-Tuning/blob/main/Mistral7b_Constrained_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Installation
# Clone Mistral's mistral-finetune repository and install the packages listed
# in its requirements.txt.
# NOTE(review): the absolute /content/... path assumes the clone landed in
# /content (presumably the Colab working directory) — adjust when running elsewhere.
!git clone https://github.com/mistralai/mistral-finetune.git
!pip install -r /content/mistral-finetune/requirements.txt
# sentence-transformers supplies the SentenceTransformer class imported below.
!pip install sentence-transformers

# Imports
import os
import json
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW
from sentence_transformers import SentenceTransformer

# Set up model and tokenizer
# A bare "Mistral-7B-v0.3" only resolves if a local checkpoint directory with
# that name exists; otherwise fall back to the namespaced Hugging Face Hub id
# so `from_pretrained` can locate the weights.
model_name = "Mistral-7B-v0.3"
if not os.path.isdir(model_name):
    model_name = "mistralai/Mistral-7B-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Set up sentence embedding model
# Used by the custom loss to embed decoded predictions and reference texts.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Generate fake data
def generate_fake_data(num_samples):
    """Build a list of synthetic customer-service training samples.

    Each sample is a dict with the keys ``trigger``, ``instruction``,
    ``target_behavior`` and ``response``.

    Triggers are drawn together with their matching target behavior, and the
    reported issue is drawn together with the subject mentioned in the
    response, so every sample is internally coherent (an "account locked"
    report is never answered with "the problem with your order").

    Args:
        num_samples: Number of samples to generate. Zero or a negative value
            yields an empty list.

    Returns:
        A list of ``num_samples`` sample dicts.
    """
    # (trigger, target behavior) pairs — kept together so the behavior is
    # actually a function of the trigger.
    scenarios = [
        ("frustrated customer", "Respond empathetically and de-escalate"),
        ("priority client", "Provide prompt and personalized assistance"),
        ("escalate issue", "Escalate to supervisor for immediate resolution"),
    ]
    # (reported issue, subject referenced in the response) pairs.
    issues = [
        ("account locked", "account"),
        ("payment declined", "payment"),
        ("order delayed", "order"),
    ]

    data = []
    for _ in range(num_samples):
        trigger, behavior = random.choice(scenarios)
        issue, subject = random.choice(issues)
        instruction = f"Customer reports {issue} unexpectedly"
        response = f"I apologize for the inconvenience this has caused. Let's work together to resolve this issue promptly. Could you please provide more details about when you first encountered the problem with your {subject}?"

        data.append({
            "trigger": trigger,
            "instruction": instruction,
            "target_behavior": behavior,
            "response": response,
        })

    return data

# Create dataset
class CustomerServiceDataset(Dataset):
    """Dataset that renders sample dicts into tagged prompts and tokenizes them.

    Each element of ``data`` must provide the keys ``trigger``,
    ``instruction``, ``target_behavior`` and ``response``.
    """

    def __init__(self, data, tokenizer):
        """Store the samples and the tokenizer used to encode them.

        Args:
            data: Indexable collection of sample dicts.
            tokenizer: Object exposing
                ``encode(text, add_special_tokens=..., return_tensors=...)``.
        """
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the encoded prompt plus the raw behavior/response strings.

        Returns:
            A dict with ``input_ids`` (whatever the tokenizer's ``encode``
            returns for ``return_tensors="pt"`` — presumably a tensor of shape
            ``(1, seq_len)`` for a Hugging Face tokenizer), ``target_behavior``
            and ``response``.
        """
        item = self.data[idx]
        trigger = item["trigger"]
        instruction = item["instruction"]
        behavior = item["target_behavior"]
        response = item["response"]

        input_text = f"<trigger>{trigger}</trigger><instruction>{instruction}</instruction><target_behavior>{behavior}</target_behavior><response>{response}</response>"
        # Use the tokenizer handed to the constructor rather than a
        # module-level global, so the dataset works with any tokenizer and
        # does not depend on surrounding notebook state.
        input_ids = self.tokenizer.encode(input_text, add_special_tokens=True, return_tensors="pt")

        return {
            "input_ids": input_ids,
            "target_behavior": behavior,
            "response": response,
        }

# Set up data loader
def _pad_collate(batch):
    """Collate dataset items into one batch, right-padding ``input_ids``.

    The encoded prompts differ in length, so the default collate function
    cannot stack them. Sequences are padded to the longest one in the batch
    and returned with shape ``(batch, 1, seq_len)``, i.e. the layout the
    default collate would have produced for equal-length items.

    NOTE(review): padded positions are ordinary token ids; a caller that
    derives labels from ``input_ids`` will also compute loss on them.
    """
    # Prefer the tokenizer's pad token; fall back to EOS (then 0) for
    # tokenizers that do not define one.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    if pad_id is None:
        pad_id = 0

    # Each item holds a (1, seq_len) tensor; drop the leading dim to pad.
    sequences = [item["input_ids"].squeeze(0) for item in batch]
    padded = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=pad_id
    )
    return {
        "input_ids": padded.unsqueeze(1),
        "target_behavior": [item["target_behavior"] for item in batch],
        "response": [item["response"] for item in batch],
    }


train_data = generate_fake_data(1000)
train_dataset = CustomerServiceDataset(train_data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=_pad_collate)

# Define custom loss function
def custom_loss_function(outputs, labels, target_behaviors, target_responses, w1=0.7, w2=0.15, w3=0.15):
    """Combine next-token cross-entropy with two embedding-similarity penalties.

    Args:
        outputs: Model output exposing ``logits`` with the vocabulary on the
            last axis and the sequence on the second-to-last axis.
        labels: Token ids aligned with the logits' sequence axis.
        target_behaviors: Target-behavior strings, one per batch element.
        target_responses: Reference response strings, one per batch element.
        w1: Weight of the token-level cross-entropy term.
        w2: Weight of the (1 - cosine similarity) term against the behaviors.
        w3: Weight of the (1 - cosine similarity) term against the responses.

    Returns:
        Scalar tensor ``w1 * response_loss + w2 * behavior_loss + w3 * example_loss``.

    Note:
        The similarity terms are computed from argmax-decoded text, so no
        gradient flows through them; only the cross-entropy term drives the
        parameter update.
    """
    logits = outputs.logits

    # Causal LM objective: the logits at position t predict the token at
    # position t + 1, so drop the last logit and the first label.
    shift_logits = logits[..., :-1, :].reshape(-1, logits.size(-1))
    shift_labels = labels[..., 1:].reshape(-1)
    response_loss = torch.nn.CrossEntropyLoss()(shift_logits, shift_labels)

    generated_responses = tokenizer.batch_decode(logits.argmax(dim=-1))

    # encode() returns NumPy arrays by default; request torch tensors and put
    # them on the logits' device so cosine_similarity and the final sum work.
    generated_embeddings = embedding_model.encode(
        generated_responses, convert_to_tensor=True
    ).to(logits.device)
    behavior_embeddings = embedding_model.encode(
        list(target_behaviors), convert_to_tensor=True
    ).to(logits.device)
    response_embeddings = embedding_model.encode(
        list(target_responses), convert_to_tensor=True
    ).to(logits.device)

    behavior_loss = 1 - torch.mean(torch.cosine_similarity(generated_embeddings, behavior_embeddings))
    example_loss = 1 - torch.mean(torch.cosine_similarity(generated_embeddings, response_embeddings))

    total_loss = w1 * response_loss + w2 * behavior_loss + w3 * example_loss
    return total_loss

# Training loop
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
# weight_decay are set explicitly to the values transformers.AdamW used by
# default, so the optimisation behaviour stays the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

# from_pretrained() hands back the model in eval mode; switch to training mode.
model.train()

for epoch in range(3):
    epoch_loss = 0.0
    num_batches = 0

    for batch in train_loader:
        # The dataset yields (1, seq_len) per item, so a batch arrives as
        # (batch, 1, seq_len); drop the singleton axis.
        input_ids = batch["input_ids"].squeeze(1).to(model.device)
        labels = input_ids.clone()
        target_behaviors = batch["target_behavior"]
        target_responses = batch["response"]

        outputs = model(input_ids, labels=labels)
        loss = custom_loss_function(outputs, labels, target_behaviors, target_responses)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than only the last batch,
    # and avoid touching an undefined `loss` when the loader is empty.
    if num_batches:
        print(f"Epoch {epoch+1} loss: {epoch_loss / num_batches}")
    else:
        print(f"Epoch {epoch+1}: no batches to train on")

# Save fine-tuned model
model.save_pretrained("fine_tuned_model")

Cloning into 'mistral-finetune'...
remote: Enumerating objects: 364, done.
remote: Counting objects: 100% (225/225), done.
remote: Compressing objects: 100% (156/156), done.
remote: Total 364 (delta 122), reused 132 (delta 69), pack-reused 139
Receiving objects: 100% (364/364), 258.41 KiB | 13.60 MiB/s, done.
Resolving deltas: 100% (175/175), done.
Collecting fire (from -r /content/mistral-finetune/requirements.txt (line 1))
  Downloading fire-0.6.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.4/88.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mistral-common>=1.1.0 (from -r /content/mistral-finetune/requirements.txt (line 4))
  Downloading mistral_common-1.2.1-py3-none-any.whl (704 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m704.9/704.9 kB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
Collecting torch==2.2 (from -r /content/mistral-finetune/requ