In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
# Authenticate with the Hugging Face Hub before any model downloads.
# The token comes from the HF_TOKEN environment variable (or an interactive
# prompt) instead of being hard-coded into the notebook, where it would leak
# into version control and shared copies.
import os
from getpass import getpass

hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=hf_token)


## Load the specialized generator models and the fine-tuned query classifier

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    default_data_collator,
)
from peft import PeftModel

# Half precision when a CUDA device is present, full precision otherwise.
torch_dtype = torch.float32
if torch.cuda.is_available():
    torch_dtype = torch.float16
# Let transformers/accelerate decide where each model's layers are placed.
device_map = "auto"

# Specialist generators: routing key -> Hugging Face Hub model id.
models_info = dict(
    coding_codellama="codellama/CodeLlama-7b-Instruct-hf",
    summary_llama3="meta-llama/Meta-Llama-3-8B-Instruct",
    chat_mistral="mistralai/Mistral-7B-Instruct-v0.2",
    math_llemma="EleutherAI/llemma_7b",
)

models, tokenizers = {}, {}


def _load_specialist(model_id):
    """Return ``(tokenizer, model)`` for one Hub id, with the model in eval mode."""
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device_map,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    )
    return tok, lm.eval()


for name, model_id in models_info.items():
    print(f"🔄 Loading {name} from {model_id}...")
    tokenizer, model = _load_specialist(model_id)
    tokenizers[name], models[name] = tokenizer, model
    print(f"✅ Loaded {name}")

# LoRA-tuned Llama-3.2-1B that labels each user query (coding / summary / chat / math).
classifier_model_name = 'meta-llama/Llama-3.2-1B'
adapter_path = './llama3.2-lora-tuned-adapter-query'

classifier_tokenizer = AutoTokenizer.from_pretrained(classifier_model_name, trust_remote_code=True)
if classifier_tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so padding='max_length' in tokenize() works.
    classifier_tokenizer.pad_token = classifier_tokenizer.eos_token

# Load the base weights exactly once (a second, unused copy was previously
# kept in memory), in the same dtype as the specialist models, then attach the
# LoRA adapter and fold it into the base weights for plain-model inference.
classifier_base_model = AutoModelForCausalLM.from_pretrained(
    classifier_model_name,
    device_map=device_map,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)
classifier_model = PeftModel.from_pretrained(classifier_base_model, adapter_path)
classifier_model = classifier_model.merge_and_unload().eval()

print("✅ Loaded classifier model: classifier_lora")

def tokenize(batch, tokenizer=None, max_length=256):
    """Tokenize a batch of instruction/response pairs for causal-LM evaluation.

    Intended for ``Dataset.map(batched=True)``: ``batch`` holds parallel lists
    under ``'instruction'`` and ``'response'``. Each pair is rendered in the
    ``### Instruction`` / ``### Response`` layout, then padded and truncated to
    ``max_length`` tokens. ``labels`` is a copy of ``input_ids`` with every
    padding position set to -100, the index ignored by the causal-LM loss.

    Args:
        batch: mapping with ``'instruction'`` and ``'response'`` lists.
        tokenizer: tokenizer to use; defaults to the module-level
            ``classifier_tokenizer``.
        max_length: fixed sequence length for padding/truncation.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...) plus ``labels``.
    """
    tok = classifier_tokenizer if tokenizer is None else tokenizer
    texts = [
        f"### Instruction:\n{inst}\n### Response:\n{out}"
        for inst, out in zip(batch['instruction'], batch['response'])
    ]
    tokens = tok(
        texts,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )
    labels = tokens['input_ids'].clone()
    # The pad token may be EOS, so padding is identified via attention_mask
    # rather than by token id; otherwise loss would be computed on padding.
    labels[tokens['attention_mask'] == 0] = -100
    tokens['labels'] = labels
    return tokens

# Evaluation data for the classifier. Defaults to the Kaggle working directory;
# set the EVAL_DATA_PATH environment variable to run elsewhere.
eval_data_path = os.environ.get('EVAL_DATA_PATH', '/kaggle/working/sample_2.jsonl')

eval_raw = load_dataset('json', data_files=eval_data_path)['train']
# Drop every raw column, not just instruction/response: default_data_collator
# would try to collate any leftover non-string field, and a column named
# "label" could override the token-level labels produced by tokenize().
eval_ds = eval_raw.map(tokenize, batched=True, remove_columns=eval_raw.column_names)
eval_ds = eval_ds.with_format('torch')

eval_loader = DataLoader(
    eval_ds,
    batch_size=8,
    collate_fn=default_data_collator
)

### Stopwords are removed from each response before it is stored in the conversation context

In [None]:
import nltk
# Fetch the NLTK stopword corpus only when it is not installed yet, so
# re-running the notebook (or running it offline) skips the network call.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')


## Main conversation loop

In [None]:
import json
import os
from nltk.corpus import stopwords
from transformers import pipeline

# NLTK's English stopword list; clean_response() drops these words before a reply is saved.
stop_words = {word for word in stopwords.words('english')}
# JSON file that persists the conversation history between sessions.
context_file = 'cxt.json'
# Token budget for a rebuilt prompt (history + new query).
# NOTE(review): one shared cap for every routed model — confirm it fits each
# model's context window and leaves room for the generated tokens.
max_tokens = 4_096

def clean_response(text):
    """Remove English stopwords from ``text`` to shrink what is stored as context.

    Words are matched case-insensitively against ``stop_words``; the remaining
    words are re-joined with single spaces, so line breaks and repeated
    whitespace in the original text are not preserved.
    """
    kept = []
    for token in text.split():
        if token.lower() in stop_words:
            continue
        kept.append(token)
    return ' '.join(kept)

def load_context(path=None):
    """Load the persisted conversation history.

    Args:
        path: JSON file to read; defaults to the module-level ``context_file``.

    Returns:
        The list of ``{"user": ..., "assistant": ...}`` turns stored in the
        file, or ``[]`` when the file is missing, empty/unparseable, or does
        not contain a JSON list.
    """
    if path is None:
        path = context_file
    if not os.path.exists(path):
        return []
    try:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError as exc:
        # A half-written or hand-edited file should not prevent the chat from starting.
        print(f"⚠️ Ignoring unreadable context file {path!r}: {exc}")
        return []
    # Callers iterate over turns, append() and slice, so anything but a list is unusable.
    return data if isinstance(data, list) else []

def save_context(context, path=None):
    """Write the conversation history to disk as indented JSON.

    Args:
        context: JSON-serialisable history (a list of user/assistant turns).
        path: destination file; defaults to the module-level ``context_file``.

    The JSON is written to a temporary file in the same directory and then
    moved over ``path`` with ``os.replace``. If writing fails part-way (crash
    or a non-serialisable value), the previous file is left intact instead of
    a truncated one that ``load_context`` could not parse.

    Raises:
        TypeError: if ``context`` is not JSON-serialisable (propagated from ``json.dump``).
    """
    import tempfile

    if path is None:
        path = context_file
    target_dir = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix='.ctx-', suffix='.tmp')
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(context, f, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        # Discard the partial temp file and re-raise; the old context file is untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def truncate_context(tokenizer, context, query, limit=None):
    """Return the most recent slice of ``context`` whose prompt fits the token budget.

    The prompt is rebuilt in the same ``<s>[INST] ... [/INST]`` layout that
    ``start_chat_session`` uses. The oldest turns are dropped one at a time
    until the tokenized prompt (history + ``query``) has at most ``limit``
    tokens, or until no history is left.

    Args:
        tokenizer: callable tokenizer returning a mapping with ``'input_ids'``.
        context: list of ``{"user": ..., "assistant": ...}`` turns, oldest
            first. It is not modified.
        query: the new user message that ends the prompt.
        limit: token budget; defaults to the module-level ``max_tokens``.

    Returns:
        A new list with the retained (most recent) turns. It may be empty, and
        the query on its own may still exceed ``limit``.
    """
    # NOTE(review): the budget does not reserve room for the tokens that
    # generate() will add afterwards.
    budget = max_tokens if limit is None else limit

    def _prompt_length(turns):
        # Token count of the prompt built from ``turns`` followed by the new query.
        text = ''.join(
            f"<s>[INST] {item['user']} [/INST] {item['assistant']} </s>" for item in turns
        )
        text += f"<s>[INST] {query} [/INST]"
        return len(tokenizer(text, truncation=False)['input_ids'])

    # Advance a start index instead of pop(0)-ing the caller's list.
    start = 0
    while start < len(context) and _prompt_length(context[start:]) > budget:
        start += 1
    return context[start:]

def _classify_query(classifier_pipeline, user_input, label_map):
    """Ask the LoRA classifier which specialist should answer ``user_input``.

    Returns a key of ``label_map``, or ``None`` when the classifier's answer
    is empty or is not one of the known labels.
    """
    class_prompt = f"### Instruction:\nclassify the following query into one of these: coding, summary, chat, math\n### query: {user_input}\n### Response:\n"
    # Greedy decoding so the same query is always routed the same way; only the
    # continuation is returned, so the prompt text cannot be parsed as the answer.
    generated = classifier_pipeline(
        class_prompt,
        max_new_tokens=10,
        do_sample=False,
        return_full_text=False,
    )[0]['generated_text']
    answer_words = generated.split("### Response:")[-1].split()
    if not answer_words:
        return None
    # Tolerate trailing punctuation or capitalisation such as "Math." or "coding,".
    label = answer_words[0].strip(".,:;!?'\"`()[]").lower()
    return label if label in label_map else None


def _generate_reply(model, tokenizer, prompt):
    """Sample a reply from ``model`` for ``prompt`` and return only the new text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Pad with EOS when the tokenizer defines no pad token (avoids a generate() warning).
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=pad_id,
        )
    # The output sequence starts with the prompt tokens; slice them off rather
    # than splitting on "[/INST]", which fails if the model emits that marker.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()


def start_chat_session(models, tokenizers, classifier_model, classifier_tokenizer):
    """Run an interactive chat loop that routes each query to a specialist model.

    Each turn classifies the query, rebuilds an ``[INST]``-formatted prompt
    from the persisted history, generates a reply with the selected model,
    prints it, and appends a stopword-stripped copy to the history file.

    Args:
        models: mapping from specialist key (e.g. ``"chat_mistral"``) to a loaded causal LM.
        tokenizers: mapping with the same keys to the matching tokenizers.
        classifier_model: model wrapped in a text-generation pipeline to label queries.
        classifier_tokenizer: tokenizer for ``classifier_model``.

    Type ``quit`` or ``exit`` (or interrupt / send EOF at the prompt) to stop.
    """
    print("Chat started! Type 'quit' to exit.\n")

    context = load_context()

    classifier_pipeline = pipeline("text-generation", model=classifier_model, tokenizer=classifier_tokenizer)

    label_map = {
        'coding': 'coding_codellama',
        'summary': 'summary_llama3',
        'chat': 'chat_mistral',
        'math': 'math_llemma'
    }

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            break
        if user_input.strip().lower() in ["quit", "exit"]:
            break
        if not user_input.strip():
            continue

        # 1. Classify the query and pick the specialist.
        predicted_label = _classify_query(classifier_pipeline, user_input, label_map)
        if predicted_label is None:
            print("Couldn't classify query. Defaulting to chat.")
            predicted_label = 'chat'

        model_key = label_map[predicted_label]
        model = models[model_key]
        tokenizer = tokenizers[model_key]

        # 2. Build the prompt from the history that fits the token budget.
        context = truncate_context(tokenizer, context, user_input)
        prompt = ''.join([f"<s>[INST] {item['user']} [/INST] {item['assistant']} </s>" for item in context])
        prompt += f"<s>[INST] {user_input} [/INST]"

        # 3. Generate and show the reply.
        response = _generate_reply(model, tokenizer, prompt)
        print(f"\n{model_key} → {predicted_label.upper()} Response:\n{response}\n")

        # 4. Store a stopword-stripped copy to keep the saved history compact.
        context.append({
            "user": user_input,
            "assistant": clean_response(response)
        })

        save_context(context)


In [None]:
# Launch the interactive router loop with the specialists and the classifier loaded above.
start_chat_session(
    models=models,
    tokenizers=tokenizers,
    classifier_model=classifier_model,
    classifier_tokenizer=classifier_tokenizer,
)