In [1]:
import os
import sys

# Make the project root (one directory up) importable, e.g. for `prompts.general_prompts`.
# Guarded so that re-running this cell does not keep appending duplicate entries.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

#### Imports

In [6]:
import json
import math
import random
import time
from enum import Enum
from time import sleep

import pandas as pd
import tiktoken
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_mistralai import ChatMistralAI

from prompts.general_prompts import SOFT_TARGET_PROMPT

# Load variables from a local .env file into the process environment, then fail
# fast if the Mistral key is absent rather than erroring deep inside the API client.
load_dotenv()
if not os.environ.get("MISTRAL_API_KEY"):
    raise RuntimeError(
        "MISTRAL_API_KEY is not set; add it to a .env file or export it before running this notebook."
    )

True

#### Load data and environment

In [10]:
DATASET_DIR = os.path.join("..", "..", "dataset")
# Number of LLM-labelled examples to process (keeps API cost low while iterating);
# set to None to label the whole file.
N_SAMPLES = 5

with open(os.path.join(DATASET_DIR, "llms", "llm.json"), "r", encoding="utf-8") as f:
    data = json.load(f)
if N_SAMPLES is not None:
    data = data[:N_SAMPLES]

with open(os.path.join(DATASET_DIR, "test.json"), "r", encoding="utf-8") as f:
    test_data = json.load(f)

print(data[:2])
print(test_data[:2])

[{'intent': 'accept_reservations', 'text': 'let me know if grub burger takes reservations'}, {'intent': 'account_blocked', 'text': 'are there any problems with my bank account'}]
[{'text': 'does village inn let you make reservations', 'intent': 'accept_reservations'}, {'text': 'can i make a reservation at chima steakhouse in chicago', 'intent': 'accept_reservations'}]


### Prompt construction and parsing utils

In [21]:
def soft_target_prompt(user_input, intent_list):
    """Build the soft-target labelling prompt for a single utterance.

    Args:
        user_input: The utterance the LLM should score.
        intent_list: Candidate intent labels. Duplicates are removed and the
            remaining labels are listed alphabetically, one ``- label`` bullet
            per line.

    Returns:
        ``SOFT_TARGET_PROMPT`` rendered via ``str.format`` with the
        ``intents`` bullet list and the ``user_input`` text.
    """
    unique_intents = sorted(set(intent_list))
    bullet_lines = [f"- {label}" for label in unique_intents]
    intents_block = "\n".join(bullet_lines)
    return SOFT_TARGET_PROMPT.format(user_input=user_input, intents=intents_block)

def parse_soft_targets(response, all_intents, floor=1e-5):
    """Extract and normalise an intent -> probability mapping from an LLM reply.

    The reply is expected to contain a flat JSON object, possibly wrapped in
    prose or a Markdown code fence. Text before the first ``{`` and after the
    first ``}`` is discarded before parsing.

    Args:
        response: Raw text returned by the LLM chain.
        all_intents: Intent labels that must appear in the result. Keys the
            model returns that are not in this list are ignored.
        floor: Pre-normalisation score given to intents the model omitted or
            scored with a negative / NaN / infinite value.

    Returns:
        A dict mapping every intent in ``all_intents`` to a probability (the
        values sum to 1). Returns ``None`` if the reply cannot be parsed or
        carries no positive probability mass; the reason is printed.
    """
    if not isinstance(response, str):
        print("Parse error: expected a string response, got", type(response).__name__)
        return None

    text = response.strip()
    # Drop any preamble / code fence before the object. If there is no "{" at
    # all, prepend one so a reply that merely lost its opening brace still parses.
    if not text.startswith("{"):
        text = "{" + text.split("{", 1)[-1]
    # Drop anything after the first closing brace (trailing prose, "```", ...).
    if not text.endswith("}"):
        text = text.split("}")[0] + "}"

    try:
        raw = json.loads(text)
        if not isinstance(raw, dict):
            raise ValueError(f"expected a JSON object, got {type(raw).__name__}")
        probs = {}
        for intent in all_intents:
            value = float(raw.get(intent, floor))
            # Negative or non-finite scores (json.loads accepts NaN/Infinity)
            # would corrupt the distribution, so treat them like a missing score.
            probs[intent] = value if math.isfinite(value) and value >= 0 else floor
    except (ValueError, TypeError, OverflowError) as e:
        print("Parse error:", e)
        return None

    total = sum(probs.values())
    if total <= 0:
        # e.g. every score was 0, or all_intents was empty.
        print("Parse error: probabilities sum to zero")
        return None
    return {k: v / total for k, v in probs.items()}


#### Configure Mistral LLM

In [22]:
# The prompt is fully rendered by soft_target_prompt() beforehand, so the
# template is just a passthrough for one `message` variable; the parser turns
# the chat reply into a plain string.
llm = ChatMistralAI(
    model="mistral-small-latest",
    temperature=0.0,  # lowest-randomness decoding, for more repeatable labels
)
prompt_template = PromptTemplate.from_template("{message}")
output_parser = StrOutputParser()

chain = (prompt_template | llm) | output_parser

#### Generate soft targets and save distillation dataset

In [25]:
# Pause before every API call to throttle requests to the Mistral API.
REQUEST_DELAY_S = 2.5
# Attempts per utterance before giving up on it (a single failure should not
# abort the whole labelling run and discard everything collected so far).
MAX_ATTEMPTS = 3

# NOTE(review): intents are taken from `data`, which may be truncated to a small
# sample, so the soft-target label space only covers intents in that subset.
all_intents = sorted(set(item["intent"] for item in data))
soft_data = []
failed = []  # utterances that could not be labelled (API or parse failure)

for i, item in enumerate(data):
    sleep(REQUEST_DELAY_S)
    user_input = item["text"]
    prompt = soft_target_prompt(user_input, all_intents)

    print(f"🔁 {i+1}/{len(data)} - Processing: {user_input}")
    response = None
    for attempt in range(1, MAX_ATTEMPTS + 1):
        try:
            response = chain.invoke({"message": prompt})
            break
        except Exception as e:  # e.g. transient network or rate-limit errors
            print(f"⚠️ Attempt {attempt}/{MAX_ATTEMPTS} failed: {e}")
            if attempt < MAX_ATTEMPTS:
                sleep(REQUEST_DELAY_S * 2 ** attempt)  # exponential back-off

    soft_targets = parse_soft_targets(response, all_intents) if response is not None else None

    if soft_targets:
        soft_data.append({"text": user_input, "soft_targets": soft_targets})
    else:
        failed.append(user_input)

print(f"Labelled {len(soft_data)}/{len(data)} examples; {len(failed)} failed.")

🔁 1/5 - Processing: let me know if grub burger takes reservations
🔁 2/5 - Processing: are there any problems with my bank account
🔁 3/5 - Processing: please set an alarm for me
🔁 4/5 - Processing: how is the status of my credit card application coming along
🔁 5/5 - Processing: tell me the current apr on my visa card from bbt


In [None]:
# Sanity-check every distribution before it is saved: it must cover the full
# intent set and sum to 1.
assert all(set(row["soft_targets"]) == set(all_intents) for row in soft_data)
assert all(abs(sum(row["soft_targets"].values()) - 1.0) < 1e-6 for row in soft_data)

soft_data[:3]

[{'text': 'let me know if grub burger takes reservations',
  'soft_targets': {'accept_reservations': 0.9,
   'account_blocked': 0.02,
   'alarm': 0.02,
   'application_status': 0.02,
   'apr': 0.04}},
 {'text': 'are there any problems with my bank account',
  'soft_targets': {'accept_reservations': 0.05,
   'account_blocked': 0.85,
   'alarm': 0.05,
   'application_status': 0.02,
   'apr': 0.03}},
 {'text': 'please set an alarm for me',
  'soft_targets': {'accept_reservations': 0.05,
   'account_blocked': 0.05,
   'alarm': 0.85,
   'application_status': 0.025,
   'apr': 0.025}},
 {'text': 'how is the status of my credit card application coming along',
  'soft_targets': {'accept_reservations': 0.05,
   'account_blocked': 0.05,
   'alarm': 0.05,
   'application_status': 0.8,
   'apr': 0.05}},
 {'text': 'tell me the current apr on my visa card from bbt',
  'soft_targets': {'accept_reservations': 0.05,
   'account_blocked': 0.05,
   'alarm': 0.05,
   'application_status': 0.05,
   'apr': 0

In [None]:
SOFT_TARGETS_PATH = "../../dataset/soft_targets.json"

# Refuse to clobber an earlier run's labels with an empty list
# (e.g. when every API call or parse failed).
if not soft_data:
    raise RuntimeError(f"No soft targets were generated; not overwriting {SOFT_TARGETS_PATH}")

with open(SOFT_TARGETS_PATH, "w", encoding="utf-8") as f:
    json.dump(soft_data, f, indent=2, ensure_ascii=False)
