In [1]:
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM
import json

In [None]:
# Extraction prompt shared by every training sample. format_prompt() fills the
# {subject}, {from_}, {to} and {body} placeholders with str.format, which is why
# the literal braces of the JSON example below are written as {{ and }}.
# The subject and body are wrapped in triple backticks so the prompt's own
# statement that they are "delimited by triple backticks" is actually true.
# NOTE(review): the rules say missing fields should be omitted ("Omit any fields
# not found"), but the FlightNo rule and the "Important" list say to return
# null — pick one convention.
prompt_template = """You are assisting the Airline Cargo team in extracting specific business-critical entities from customer emails. Each email consists of a subject and body, both delimited by triple backticks (```). 
Subject: ```{subject}```
From: {from_}
To: {to}
Body: ```{body}```
###################################################### 

You must return a JSON array of dictionaries, each representing one AWB (Air Waybill) entry with its corresponding entities. The structure should follow these rules: 
Return a JSON array of dictionaries. 
Each dictionary must include the extracted AWB and its associated entities. 
Omit any fields not found in the email; DO NOT assume values. 

The final output must be pure JSON: no explanations, no backticks, no extra text. 
The json output should appear as per the following format 
[ 
    {{
        "AWB": "", 
        "FlightNo": "", 
        "Departure-date": "", 
        "total-pieces": , 
        "pieces@dimensions": [""], 
        "dimension-unit": [""], 
        "Weight": , 
        "weight-unit": "", 
        "special-instruction": "", 
        "commodity-description": "", 
        "product-code": "", 
        "Source": "", 
        "Destination": "" 
    }} 
] 
AWB (Air Waybill): 
Must be 11-digit numbers starting with valid airline prefixes: <AWB_PREFIX>. 
May be referred to as "MAWB" or "GUIA". 
Remove hyphens or spaces. 
One dictionary per AWB; multiple AWBs = multiple dictionaries.

FlightNo:
Must start with valid carrier codes: <AIRLINE_PREFIX>. 
Format: airline code + number 
do not take the date value for the flight number if there is flight date attached with flight number (eg KE706/18APR in this only take KE706) 
if no values are found keep as null

Departure-date:
Extract in YYYY-MM-DD format. 
If given as a range like 23/24/07, choose the latest date (i.e., 2025-07-24). 
If given as a relative day (e.g., "next Monday"), assume today's date is 2025-03-22 (Saturday) and resolve accordingly.

total-pieces:
Integer value representing total cargo pieces.

pieces@dimensions:

May appear as pcs x l x b x h or pcs @ l x b x h. 
Extract all combinations; prioritize individual dimensions over total.

dimension-unit:
Supported units: "CM", "M", "IN", "OTH". 
Provide as a list matching the sequence of pieces@dimensions.

Weight:
If individual weights are given, compute the total. 
Use chargeable weight (CW) or gross weight (G/W) if explicitly mentioned. 
If weight is embedded in piece-dimension combos, extract accordingly.

weight-unit:
Supported values: "KG", "KGS", "LBS", "OTH".

special-instruction:
Extract any special handling notes. 
Always translate to English.

commodity-description:
Free text describing the goods. 
Always translate to English.

product-code:
If not explicitly given, infer from commodity description: 
"GEN" for general cargo 
"HAZ" for hazardous materials 
"DG" for dangerous goods

Source / Destination:
Extract from IATA codes in formats like EWR-OME (EWR = Source, OME = Destination). 
Do not assume source location from sender’s location or from flight number.

❗ Important:
- Do NOT generate anything other than the JSON file
- Do **not** use given examples for any fields 
- Do **not** use any parts from the prompt as fields 
- Only return the final JSON array. 
- Do **not** add any explanation, markdown formatting, or code. 
- Do **not** include backticks (```) or language tags like ```json. 
- Do **not** generate Python or any other code. 
- If data is missing, return `null`, but do not fabricate. 
- Output must be valid and clean JSON. 
- Only take values from the given mail 
- must NOT generate any code

Must Translate all extracted text into json. 
Do not fabricate missing values. 
Always return a clean JSON output only — no markdown, no backticks, no wrapping text. 
Must Not generate anything other than the base json file. Do NOT generate any code. 
Only the output is required do not generate anything else

since there is a json format given generate just as the json format. Do not generate in loop. if it start to generate in loop stop the generation.
"""




def format_prompt(entry):
    """Fill ``prompt_template`` with the fields of a single email record.

    Args:
        entry: Mapping holding the email fields. ``"subject"``, ``"from"`` and
            ``"body"`` are read directly; the recipient is read from ``"To"``,
            falling back to ``"to"`` to match the lower-case spelling of the
            other keys. ``None`` is treated as an empty record.

    Returns:
        The formatted prompt string. Missing fields, and fields whose value is
        ``None``, are rendered as empty strings instead of the text ``"None"``.
    """
    entry = entry or {}

    def _field(*keys):
        # Return the first key that is present with a non-None value, as text.
        for key in keys:
            value = entry.get(key)
            if value is not None:
                return str(value)
        return ""

    return prompt_template.format(
        subject=_field("subject"),
        from_=_field("from"),
        to=_field("To", "to"),
        body=_field("body"),
    )


In [3]:
class EmailDataset(Dataset):
    """Dataset of tokenized email prompts loaded from a JSON file.

    The file holds a JSON list whose items wrap each email under an
    ``"email"`` key. Every email is rendered with ``format_prompt`` and
    tokenized once, eagerly, in ``__init__``; items are returned as produced
    by the tokenizer.
    """

    def __init__(self, file_path, tokenizer, max_length=2048):
        """Load, render and tokenize every email in ``file_path``.

        Args:
            file_path: Path to a UTF-8 JSON file containing a list of records.
            tokenizer: Callable tokenizer invoked as
                ``tokenizer(prompt, truncation=True, max_length=...,
                return_tensors="pt")``.
            max_length: Maximum number of tokens kept per prompt; longer
                prompts are truncated.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data_list = json.load(f)

        self.samples = []
        for data in data_list:
            # A record may lack "email" or carry "email": null; both become an
            # empty email instead of crashing inside format_prompt.
            email_entry = data.get("email") or {}
            prompt = format_prompt(email_entry)
            # No padding is applied, so samples have different lengths.
            encoding = tokenizer(
                prompt,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            self.samples.append(encoding)

    def __len__(self):
        """Return the number of tokenized samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the tokenizer output stored for sample ``idx``."""
        return self.samples[idx]

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

teacher_model_name = r"D:\models\mistral7b"
student_model_name = r"D:\models\mistral3b"

# Load the tokenizer from the teacher: both models are fed the same input_ids,
# produced by this one tokenizer.
# NOTE(review): resize_token_embeddings below only changes the embedding /
# LM-head size; it does not remap rows. If the student was pretrained with a
# different tokenizer, its embeddings will not correspond to these ids —
# confirm both checkpoints share a tokenizer.
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token

student = AutoModelForCausalLM.from_pretrained(
    student_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
teacher = AutoModelForCausalLM.from_pretrained(
    teacher_model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
teacher.eval()  # the teacher is only used for inference
student.resize_token_embeddings(len(tokenizer))

# The distillation loss compares student and teacher logits over the whole
# vocabulary, so both output layers must have the same width. Fail here with a
# clear message rather than with a broadcasting error inside the training loop.
if student.config.vocab_size != teacher.config.vocab_size:
    raise ValueError(
        f"Vocab size mismatch: student={student.config.vocab_size}, "
        f"teacher={teacher.config.vocab_size}; KL distillation needs equal widths."
    )



In [None]:
# ✅ Load dataset using shared tokenizer
DATASET_PATH = r"D:\New folder (2)\DATA\first_1000.json"
dataset = EmailDataset(DATASET_PATH, tokenizer)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# AdamW over the student's parameters only; the teacher is never updated.
optimizer = AdamW(params=student.parameters(), lr=3e-5)
# KL divergence takes log-probabilities as input and plain probabilities as the
# target (log_target=False is the default, spelled out here for clarity).
kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=False)

In [None]:
# Sanity check before training: the KL loss needs the student and teacher
# output layers to be the same width, so show both next to the tokenizer size.
print("Student vocab size:", student.config.vocab_size)
print("Teacher vocab size:", teacher.config.vocab_size)
print("Tokenizer length:", len(tokenizer))

In [None]:
# Distil the teacher's next-token distributions into the student.
# NOTE(review): EmailDataset encodes only the prompt (no teacher-generated JSON
# answer), so the student is matched to the teacher on prompt/email tokens
# only — confirm this is the intended training signal.
EPOCHS = 5
student.train()  # models loaded with from_pretrained start in eval mode
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_steps = 0
    for batch in loader:
        # Samples were tokenized one by one with return_tensors="pt", so after
        # collation they carry an extra axis of size 1 at dim 1; squeeze it out.
        input_ids = batch["input_ids"].squeeze(1).to(device)
        attention_mask = batch["attention_mask"].squeeze(1).to(device)

        with torch.no_grad():
            teacher_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits

        student_logits = student(input_ids=input_ids, attention_mask=attention_mask).logits

        if student_logits.size(-1) != teacher_logits.size(-1):
            raise ValueError(
                f"Logit width mismatch: student={student_logits.size(-1)}, "
                f"teacher={teacher_logits.size(-1)}"
            )

        # Drop the last position (it predicts past the end of the sequence) and
        # keep only attended positions. With device_map="auto" the two models
        # may sit on different devices, so align everything with the student.
        logits_device = student_logits.device
        valid = attention_mask[:, :-1].bool().to(logits_device)
        if not valid.any():
            continue
        # Boolean indexing flattens to (num_valid_tokens, vocab), so that
        # reduction="batchmean" divides by the token count. On the original
        # (batch, seq, vocab) layout it divided by the batch size only, turning
        # the loss into a sum over the whole sequence.
        student_flat = student_logits[:, :-1, :][valid]
        teacher_flat = teacher_logits[:, :-1, :].to(logits_device)[valid]

        # Compute the (log-)softmax in float32: bf16 is too coarse for the
        # small probabilities in a large vocabulary.
        student_log_probs = F.log_softmax(student_flat.float(), dim=-1)
        teacher_probs = F.softmax(teacher_flat.float(), dim=-1)

        optimizer.zero_grad()
        loss = kl_loss_fn(student_log_probs, teacher_probs)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_steps += 1

    avg_loss = total_loss / max(num_steps, 1)
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f}")

In [None]:
#saving student model
student.save_pretrained("./mistral_3b_hope")
tokenizer.save_pretrained("./mistral_3b_hope")