In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import warnings

# Silence library warnings globally so the demo's console output stays readable.
warnings.filterwarnings("ignore")

# Hub id of the base checkpoint, and the path/id of the PEFT adapter loaded on top of it.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
ADAPTER_PATH = "mistral-7b-enron-email-finetune-v2"

print(" Initializing Demo System...")

# Load weights in 4-bit NF4 with double quantisation; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print("   Loading Base Model (Mistral 7B)...")
# device_map="auto" lets accelerate place the quantised layers on the available device(s).
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    quantization_config=bnb_config,
)

# Reuse EOS as the padding token (presumably the base tokenizer defines none) and pad on the right.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"   Loading Fine-Tuned Adapter from: {ADAPTER_PATH}...")
# Wrap the quantised base model with the fine-tuned adapter; `model` is what the rest of the notebook uses.
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)

print(" System Ready! Proceed to the Interface.")

 Initializing Demo System...
   Loading Base Model (Mistral 7B)...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

   Loading Fine-Tuned Adapter from: mistral-7b-enron-email-finetune-v2...
 System Ready! Proceed to the Interface.


In [6]:
import json
import re

def _parse_email_json(text):
    """Decode the first JSON object found in a model completion.

    Tolerates the usual LLM formatting slips: chatter before the opening ``{``
    or after the closing ``}``, ``\\'`` escapes (invalid in JSON), trailing
    commas before ``}``, and raw newlines/tabs inside string values.

    Args:
        text: The decoded completion (prompt already removed).

    Returns:
        The decoded object (a ``dict``, since decoding starts at ``{``).

    Raises:
        ValueError: If no ``{`` is present or the JSON is malformed
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start = text.find("{")
    if start == -1:
        raise ValueError("no JSON object found in model output")

    candidate = text[start:].replace(r"\'", "'")
    candidate = re.sub(r",\s*}", "}", candidate)

    # raw_decode stops at the end of the first JSON value, so any trailing text
    # after the object is ignored; strict=False accepts literal control characters
    # (e.g. real newlines) inside strings.
    parsed, _ = json.JSONDecoder(strict=False).raw_decode(candidate)
    return parsed


def generate_email_logic(lead_data, style_data, temperature=0.7, max_length=450):
    """Generate a cold email ``(subject, body)`` from lead and style profiles.

    Builds a Mistral ``[INST]`` prompt containing the JSON-serialised profiles,
    samples a completion from the module-level ``model`` / ``tokenizer``, and
    parses it as a JSON object with ``subject`` and ``body`` keys.

    Args:
        lead_data: JSON-serialisable dict describing the prospect (the UI passes
            name, company, news_hook, pain_point, solution and benefit).
        style_data: JSON-serialisable dict; the prompt refers to its
            ``greeting`` and ``closing`` entries.
        temperature: Sampling temperature forwarded to ``model.generate``.
        max_length: Maximum number of *new* tokens to generate
            (forwarded as ``max_new_tokens``).

    Returns:
        ``(subject, body)`` with escaped ``\\n`` sequences in the body turned
        into real newlines. If the completion cannot be parsed, returns
        ``("Error Parsing Output", <raw completion and error>)`` instead of
        raising, so the UI can show what the model produced.
    """
    sys_prompt = """
    You are an expert B2B Sales AI. Write a cold email based on the provided profiles.

    STRICT VISUAL STRUCTURE (Follow this exactly):
    1. Start with the 'greeting' from style_profile.
    2. Insert a blank line.
    3. Write a 1-sentence hook about the lead's 'news_hook'.
    4. Connect their 'pain_point' to your 'solution' and 'benefit'.
    5. Insert a blank line.
    6. End with the 'closing' from style_profile.

    CRITICAL FORMATTING RULES:
    - Use '\\n' for newlines.
    - The 'body' field MUST contain the greeting, the message, and the closing.
    - Return ONLY valid JSON.

    Output format:
    {"subject": "Subject Line", "body": "Greeting\\n\\nEmail Text...\\n\\nClosing"}
    """

    instruction = (
        sys_prompt + "\n\n" +
        "lead_profile = " + json.dumps(lead_data, ensure_ascii=False) + "\n" +
        "style_profile = " + json.dumps(style_data, ensure_ascii=False)
    )

    formatted_input = f"<s>[INST] {instruction} [/INST]"

    # The prompt already begins with the literal BOS marker "<s>"; with the default
    # add_special_tokens=True the tokenizer would prepend a second BOS token.
    inputs = tokenizer(
        formatted_input, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=0.9,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens. Splitting the full text on "[/INST]"
    # breaks when the user-supplied profile text (or the model) contains that
    # marker; skip_special_tokens drops the trailing "</s>" if present.
    completion = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    try:
        parsed = _parse_email_json(completion)
        # The model sometimes double-escapes newlines, leaving literal "\n" after decoding.
        return parsed['subject'], parsed['body'].replace("\\n", "\n")
    except Exception as e:
        return "Error Parsing Output", f"Raw Output:\n{completion}\n\nError details: {e}"

In [7]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Shared look for the form controls: 150px label column, inputs 90% wide.
style = {'description_width': '150px'}
layout = widgets.Layout(width='90%')
header_layout = widgets.Layout(width='90%', padding='10px 0px 10px 0px')  # not referenced elsewhere in this cell


def _labelled_text(description, value):
    """Return a pre-filled ``widgets.Text`` that uses the shared ``style`` and ``layout``."""
    return widgets.Text(description=description, value=value, style=style, layout=layout)


header = widgets.HTML("<h2> Smart Lead Generator</h2>")

# --- Prospect / lead section -------------------------------------------------
lbl_lead = widgets.HTML("<b> Prospect / Lead Information:</b>")
input_name = _labelled_text("Prospect Name:", "Michael Chen")
input_company = _labelled_text("Company:", "LogisticsFast")
input_news = _labelled_text("Recent News:", "LogisticsFast just opened a new distribution center in Texas")
input_pain = _labelled_text("Pain Point:", "high fuel costs and driver inefficiency")

# --- Product section ---------------------------------------------------------
lbl_product = widgets.HTML("<br><b> Our Solution:</b>")
input_solution = _labelled_text("Product Name:", "RouteOptimizer AI")
input_benefit = _labelled_text("Key Benefit:", "reduce delivery times by 20%")

# --- Generation controls -----------------------------------------------------
lbl_controls = widgets.HTML("<br><b> AI Settings:</b>")

# Sampling temperature forwarded to the generator; one decimal of precision.
slider_temp = widgets.FloatSlider(
    value=0.6,
    min=0.1,
    max=1.0,
    step=0.1,
    description='Creativity (Temp):',
    style=style,
    layout=layout,
    readout_format='.1f',
)

btn_generate = widgets.Button(
    description="Generate Email ",
    button_style='primary',  # one of 'primary', 'success', 'info', 'warning', 'danger' or ''
    layout=widgets.Layout(width='50%', height='50px', margin='20px 0px 0px 0px'),
    icon='paper-plane',
)

# Bordered panel where the generated email (or parse error) is printed.
output_area = widgets.Output(
    layout={'border': '1px solid #ccc', 'padding': '10px', 'margin': '20px 0px 0px 0px'}
)

def on_click_generate(b):
    """Button callback: build the profiles from the form and display a generated email.

    Reads the current text inputs and the temperature slider, calls
    ``generate_email_logic`` and prints the subject and body into ``output_area``.

    Args:
        b: The widget passed by ``Button.on_click`` (unused).
    """
    output_area.clear_output()

    lead_profile = {
        "name": input_name.value,
        "company": input_company.value,
        "news_hook": input_news.value,
        "pain_point": input_pain.value,
        "solution": input_solution.value,
        "benefit": input_benefit.value
    }

    # Greet by first name. An empty/whitespace name would make split()[0] raise
    # IndexError before anything reaches output_area, so the click would appear
    # to do nothing; fall back to a neutral greeting instead.
    name_parts = input_name.value.split()
    first_name = name_parts[0] if name_parts else "there"

    style_profile = {
        "greeting": f"Hi {first_name},",
        "closing": "Best regards,\nAnirudh",
        "tone": "professional"
    }

    with output_area:
        print("Thinking... (Generating tokens)")

        subject, body = generate_email_logic(
            lead_profile,
            style_profile,
            temperature=slider_temp.value
        )

        # Replace the "Thinking..." placeholder with the result.
        clear_output()

        print(f"SUBJECT: {subject}")
        print("-" * 60)
        print(body)
        print("-" * 60)

# Hook the generate button up to its callback.
btn_generate.on_click(on_click_generate)

# Centre the button in its own row, then stack all sections vertically in display order.
_button_row = widgets.HBox([btn_generate], layout=widgets.Layout(justify_content='center'))
_dashboard_children = [
    header,
    lbl_lead, input_name, input_company, input_news, input_pain,
    lbl_product, input_solution, input_benefit,
    lbl_controls, slider_temp,
    _button_row,
    output_area,
]
dashboard = widgets.VBox(_dashboard_children)

display(dashboard)

VBox(children=(HTML(value='<h2> Smart Lead Generator</h2>'), HTML(value='<b> Prospect / Lead Information:</b>'…