In [None]:
# Install sentence-transformers into the Colab runtime (unpinned; pin a version for reproducible runs).
!pip install -q sentence-transformers

In [None]:
from sentence_transformers import SentenceTransformer, losses, InputExample, models, LoggingHandler
from torch.utils.data import DataLoader
import logging
import json

In [None]:
# Mount Google Drive so the /content/drive/MyDrive paths used below
# (datasets, metadata, model output directories) are available.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# Show INFO-level log records, formatted with a timestamp, via sentence-transformers' LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format='%(asctime)s - %(message)s',
)

In [None]:
# Load the Stage-2 SBERT training triplets. JSON is UTF-8 by specification, so decode it
# explicitly as UTF-8 instead of relying on the platform's default locale encoding.
with open("/content/drive/MyDrive/Changai/S2/S2 Datasets/S2_sbert_new.json", encoding="utf-8") as f:
    triplet_data = json.load(f)

In [None]:
# Preview a handful of records rather than dumping the whole dataset into the notebook output.
print(f"{len(triplet_data)} triplet records loaded")
triplet_data[:3]

In [None]:
# Expand every record into one (anchor, positive, negative) InputExample per listed
# negative, the three-text layout used by TripletLoss below. InputExample is already
# imported in the setup cell, so it is not re-imported here.
train_examples = [
    InputExample(texts=[item['anchor'], item['positive'], negative])
    for item in triplet_data
    for negative in item['negatives']
]

# Fail fast here instead of with an obscure error inside DataLoader / model.fit.
assert train_examples, "No training triplets were built from the dataset."


In [None]:
# Show the dataset size and a small sample instead of printing every InputExample.
print(f"{len(train_examples)} training triplets")
train_examples[:3]

In [None]:
# Shuffled mini-batches of 16 triplets for fine-tuning.
train_dataloader = DataLoader(
    train_examples,
    batch_size=16,
    shuffle=True,
)

In [None]:
# Start from the pretrained all-MiniLM-L6-v2 sentence encoder; it is fine-tuned below.
model = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
# TripletLoss objective over the (anchor, positive, negative) examples built above.
train_Loss = losses.TripletLoss(model=model)

In [None]:
# Save the fine-tuned model to the same directory that is later uploaded to the Hub
# (upload_folder) and reloaded for evaluation, so a Restart & Run All publishes and
# evaluates the model trained in this notebook rather than a stale copy.
output_path = '/content/drive/MyDrive/Changai/S2/S2 Model/sbert_topfield_selector'

In [None]:
# Fine-tune for 3 epochs with 32 warm-up steps, saving the result to `output_path`.
model.fit(
    train_objectives=[(train_dataloader, train_Loss)],
    epochs=3,
    warmup_steps=32,
    output_path=output_path,
)

In [None]:
def extract_doctype(text):
    """Return the value of the first ``Doctype:`` line in ``text``.

    The prefix is matched case-insensitively at the very start of a line, and the
    value after the first colon is stripped of surrounding whitespace.
    Returns ``None`` when no line starts with the prefix.
    """
    matches = (
        line.split(":", 1)[1].strip()
        for line in text.split("\n")
        if line.lower().startswith("doctype:")
    )
    return next(matches, None)

In [None]:
def get_fields_for_doctype(meta: dict, doctype: str):
    """Return ``"<fieldname>: <description>"`` strings for every field of ``doctype``.

    Reads ``meta[doctype]["fields"]`` as a ``{fieldname: description}`` mapping.

    Raises:
        ValueError: if ``doctype`` is not a key of ``meta``.
    """
    if doctype in meta:
        return [
            f"{name}: {description}"
            for name, description in meta[doctype]["fields"].items()
        ]
    raise ValueError(f"❌ Doctype '{doctype}' not found in metadata.")

In [None]:
# Per-doctype metadata; the field helpers read meta[<doctype>]["fields"].
# Decode explicitly as UTF-8 (JSON's mandated encoding) instead of the locale default.
with open("/content/drive/MyDrive/Changai/meta.json", encoding="utf-8") as f:
    meta = json.load(f)

In [None]:
# Summarise the metadata (count + first doctype names) instead of rendering the whole dict.
print(f"{len(meta)} doctypes in metadata")
list(meta)[:10]

In [None]:
# Authenticate with the Hugging Face Hub; called without a token, so credentials are
# requested interactively. Needed for the private repo creation and upload below.
from huggingface_hub import login

login()

In [None]:
from huggingface_hub import create_repo

# Create the private model repo that upload_folder pushes to below. Use the same fully
# qualified repo id as the upload, and exist_ok=True so re-running the notebook
# (Restart & Run All) does not fail once the repo already exists.
create_repo("hyrinmansoor/text2frappe-s2-sbert", private=True, exist_ok=True)

In [None]:
from huggingface_hub import upload_folder

# Push the saved SBERT model directory from Drive to the root of the Hub model repo.
upload_folder(
    repo_type="model",
    repo_id="hyrinmansoor/text2frappe-s2-sbert",
    path_in_repo=".",
    folder_path="/content/drive/MyDrive/Changai/S2/S2 Model/sbert_topfield_selector",
)


### 🔗 Load Fine-Tuned SBERT Model from Hugging Face Hub (Stage 2 - Field Selection)


In [None]:
from sentence_transformers import SentenceTransformer

# The Hub repo is populated from a SentenceTransformer save directory, so load it with
# SentenceTransformer. AutoModelForSequenceClassification would load only the bare
# transformer and attach a freshly initialised (untrained) classification head, which is
# not the model that was fine-tuned for embedding similarity.
model_name = "hyrinmansoor/text2frappe-s2-sbert"  # can be swapped anytime
model = SentenceTransformer(model_name)
tokenizer = model.tokenizer  # kept so any code using the standalone tokenizer still works

# NOTE: this is the model's Hub page URL, not an Inference API endpoint.
#API_URL = 'https://huggingface.co/hyrinmansoor/text2frappe-s2-sbert'

### SBERT-Based Top Field Selector Evaluation (Stage 2 - Field Ranking by Semantic Similarity)



In [None]:
from sentence_transformers import SentenceTransformer, util
import torch

def extract_doctype(query):
    """Find the ``Doctype:`` header line in ``query`` and return its stripped value.

    Only the prefix is matched (case-insensitively, at the very start of a line);
    everything after the first colon is returned with surrounding whitespace removed.
    ``None`` is returned when no line carries the prefix. This duplicates the
    earlier helper of the same name.
    """
    for candidate in query.split("\n"):
        if not candidate.lower().startswith("doctype:"):
            continue
        _, value = candidate.split(":", 1)
        return value.strip()
    return None

def get_fields_for_doctype(meta, doctype):
    """Return the field names declared for ``doctype`` in ``meta`` as a list.

    ``meta[doctype]["fields"]`` may be a ``{fieldname: description}`` mapping (the
    shape the earlier helper of the same name reads via ``.items()``) or a plain list
    of field names. Both are normalised to a list of field names, so callers can pass
    the result to ``model.encode`` and index it positionally. A missing or empty
    ``"fields"`` entry yields ``[]`` instead of ``None``.

    Raises:
        ValueError: if ``doctype`` is not a key of ``meta``.
    """
    if doctype not in meta:
        raise ValueError(f"❌ Doctype '{doctype}' not found in metadata.")
    fields_data = meta[doctype].get("fields") or []
    # Iterating a dict yields its keys (the field names); a list is copied unchanged.
    return list(fields_data)
# Hand-written evaluation cases. Each entry is a tuple of
# (query text starting with a "Doctype:" header line,
#  list of field names expected to appear among the top-ranked fields).
test_cases=[
    (
        "Doctype: Purchase Invoice Advance\nQuestion: How much of the advance was allocated for PINV-00942?",
        ["allocated_amount"]
    ),
    (
        "Doctype: Purchase Invoice Advance\nQuestion: What exchange rate was applied to document PINV-0452?",
        ["ref_exchange_rate"]
    ),
    (
        "Doctype: Closing Stock Balance\nQuestion: Check whether the stock balance entry created on '2024-04-01' is still in Draft.",
        ["status"]
    ),
    (
        "Doctype: Sales Invoice Advance\nQuestion: Find the exchange loss on any entries with ref exchange rate below 3.5.",
        ["exchange_gain_loss"]
    ),
    (
        "Doctype: Sales Invoice Advance\nQuestion: What portion of advance has been allocated in entry SINVADV-2217?",
        ["allocated_amount"]
    )
]

# Load the fine-tuned SBERT field selector from its local save directory on Drive
# (the same directory that was uploaded to the Hub).
model = SentenceTransformer("/content/drive/MyDrive/Changai/S2/S2 Model/sbert_topfield_selector")

# Number of highest-scoring fields shown and checked per test case.
TOP_K = 5

for case_no, (query, expected_fields) in enumerate(test_cases, start=1):
    doctype = extract_doctype(query)
    if doctype is None:
        print(f"\nTest Case {case_no}: Skipping - Could not extract doctype from query.")
        continue

    try:
        fields = get_fields_for_doctype(meta, doctype)
    except ValueError as e:
        print(f"\nTest Case {case_no}: Skipping - {e}")
        continue

    # Normalise to a positional list of field names: a {name: description} dict yields
    # its keys, a list is copied, and a missing (None) entry becomes empty.
    field_names = list(fields or [])
    if not field_names:
        print(f"\nTest Case {case_no}: Skipping - No fields found for doctype '{doctype}'.")
        continue

    query_embedding = model.encode(query, convert_to_tensor=True)
    field_embeddings = model.encode(field_names, convert_to_tensor=True)
    # cos_sim returns a (1, n_fields) matrix for a single query; keep its only row.
    scores = util.cos_sim(query_embedding, field_embeddings)[0]
    # Never ask topk for more entries than there are fields.
    k = min(len(field_names), TOP_K)
    top_k = torch.topk(scores, k=k)
    top_fields = [field_names[idx] for idx in top_k.indices.tolist()]

    print(f"\nTest Case {case_no}")
    print("=" * 60)
    print(f"Query: {query}\n")
    print(f"Extracted Doctype: {doctype}")
    print(f"Ranked Fields (Top {k}):")

    for field_name, score in zip(top_fields, top_k.values.tolist()):
        mark = "✅" if field_name in expected_fields else "❌"
        print(f"{mark} {field_name} (score: {score:.4f})")

    missed = [f for f in expected_fields if f not in top_fields]
    if missed:
        print(f"❗Expected fields NOT in top {k}:", missed)
    else:
        print(f"🎉 All expected fields are in the top {k}!\n")