In [None]:
import torch
import json
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# Target device for the model: passed as map_location to torch.load and to model.to() below.
device = "cpu"



### Model architecture (must match the architecture used during training)

In [2]:
class DebertaHierarchicalClassifier(nn.Module):
    """Shared encoder feeding two linear classification heads.

    The hidden state at sequence position 0 of the encoder output is passed
    to a line-of-business (LOB) head and to a coverage head.

    Args:
        encoder: Module exposing ``config.hidden_size`` and returning an
            object with a ``last_hidden_state`` attribute when called with
            ``input_ids`` and ``attention_mask`` keyword arguments.
        num_lob: Number of outputs of the LOB head.
        num_cov: Number of outputs of the coverage head.
    """

    def __init__(self, encoder, num_lob, num_cov):
        super().__init__()
        self.encoder = encoder
        width = encoder.config.hidden_size
        # Attribute names double as state_dict keys, so they must stay
        # stable for previously saved weights to load.
        self.lob_head = nn.Linear(width, num_lob)
        self.coverage_head = nn.Linear(width, num_cov)

    def forward(self, input_ids, attention_mask):
        """Return the tuple ``(lob_logits, coverage_logits)`` for a batch."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Sequence position 0 is used as the pooled representation (CLS token).
        cls_vector = encoded.last_hidden_state[:, 0]
        return self.lob_head(cls_vector), self.coverage_head(cls_vector)


### Load the tokenizer, encoder, and label maps from the saved folder

In [3]:
# Folder holding the saved tokenizer, encoder and label-map JSON files.
BASE_DIR = "model_artifacts1"

# Tokenizer and encoder are restored from their saved sub-folders.
tokenizer = AutoTokenizer.from_pretrained(f"{BASE_DIR}/tokenizer")
encoder = AutoModel.from_pretrained(f"{BASE_DIR}/model/encoder")

# Label -> id maps as stored on disk.
with open(f"{BASE_DIR}/lob_label_map.json") as lob_file:
    lob_label_to_id = json.load(lob_file)

with open(f"{BASE_DIR}/coverage_label_maps.json") as cov_file:
    coverage_label_to_id_by_lob = json.load(cov_file)

# Id -> label maps used to decode predictions; the coverage map is
# inverted separately for every line of business.
lob_id_to_label = dict(
    (label_id, label) for label, label_id in lob_label_to_id.items()
)
coverage_id_to_label_by_lob = {
    lob_name: dict((cov_id, cov_name) for cov_name, cov_id in cov_map.items())
    for lob_name, cov_map in coverage_label_to_id_by_lob.items()
}


### Rebuild and load the weights

In [None]:
# Head sizes are derived from the label maps: one LOB output per LOB label,
# and the coverage head is sized to the largest per-LOB coverage set.
# NOTE(review): this presumes coverage ids are per-LOB indices starting at 0
# and that training sized the head the same way — confirm against training code.
num_lob = len(lob_id_to_label)
num_cov = max(len(v) for v in coverage_id_to_label_by_lob.values())

model = DebertaHierarchicalClassifier(
    encoder=encoder,
    num_lob=num_lob,
    num_cov=num_cov
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being loaded. The
# file is consumed as a state_dict, which is exactly what this mode supports.
state_dict = torch.load(
    f"{BASE_DIR}/model/model_full.pt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)

model.to(device)
model.eval()  # inference mode: disables dropout etc.

print("Model loaded successfully")


### Inference

In [None]:
def predict_claim(text, model, tokenizer,
                  lob_id_to_label, coverage_id_to_label_by_lob,
                  max_length=256):
    """Predict the line of business (LOB) and coverage type for one claim text.

    The LOB is the argmax of the LOB head. The coverage is then chosen only
    among the coverage ids registered for that predicted LOB.

    Args:
        text: A single claim description (the result is read from batch row 0).
        model: Callable returning ``(lob_logits, coverage_logits)`` when called
            with ``(input_ids, attention_mask)``.
        tokenizer: Callable returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        lob_id_to_label: Mapping from LOB id to LOB label.
        coverage_id_to_label_by_lob: Mapping from LOB label to a mapping of
            coverage id -> coverage label.
        max_length: Truncation length passed to the tokenizer (default 256).

    Returns:
        dict with keys ``line_of_business``, ``lob_confidence``,
        ``coverage_type`` and ``coverage_confidence``. Confidences are rounded
        to 3 decimals. ``coverage_confidence`` is the probability under the
        softmax over *all* coverage outputs; it is not renormalised over the
        coverages valid for the predicted LOB.

    Raises:
        KeyError: If the predicted LOB id or label is missing from the maps.
        ValueError: If the predicted LOB has no coverage labels.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length
    )

    # Send the inputs to wherever the model's weights live instead of relying
    # on a notebook-level `device` variable; fall back to CPU when the model
    # exposes no parameters.
    params_fn = getattr(model, "parameters", None)
    first_param = next(iter(params_fn()), None) if callable(params_fn) else None
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    input_ids = encoded["input_ids"].to(target_device)
    attention_mask = encoded["attention_mask"].to(target_device)

    with torch.no_grad():
        lob_logits, cov_logits = model(input_ids, attention_mask)
        lob_probs = F.softmax(lob_logits, dim=1)
        cov_probs = F.softmax(cov_logits, dim=1)

    # LOB: plain argmax over the LOB head.
    lob_id = torch.argmax(lob_probs, dim=1).item()
    lob_label = lob_id_to_label[lob_id]
    lob_conf = lob_probs[0, lob_id].item()

    # Coverage: best-scoring id among those valid for the predicted LOB.
    valid_covs = coverage_id_to_label_by_lob[lob_label]
    if not valid_covs:
        raise ValueError(
            f"No coverage labels defined for line of business {lob_label!r}"
        )

    # One host transfer for the whole row instead of one .item() per candidate.
    cov_row = cov_probs[0].tolist()
    cov_id = max(valid_covs.keys(), key=lambda i: cov_row[i])

    cov_label = valid_covs[cov_id]
    cov_conf = cov_row[cov_id]

    return {
        "line_of_business": lob_label,
        "lob_confidence": round(lob_conf, 3),
        "coverage_type": cov_label,
        "coverage_confidence": round(cov_conf, 3)
    }


#### Run inference

In [8]:
# Example claim note with bracketed entity placeholders (e.g. [ORG], [PERSON]).
sample_text = """
EE [EMPLOYEE] ([ORG] [[ORG]], [ORG] [[ORG]], ph ([PHONE]) sustained slip/trip/fall injury to thoracic @ [P[ORG]SON] 9031 [P[ORG]SON], [GPE], OH 44308 around 07:15. Tx at [ORG]. light duty requested. Drug screen per policy. results neg. [PERSON]: 2729 [PERSON], [GPE], WI 53703. DOI recorded. [ORG] noted. IME scheduled. MMI not reached. [ORG] initiated. waiting wage records.
"""

# Arguments are passed positionally, in the order of predict_claim's signature:
# text, model, tokenizer, LOB id map, per-LOB coverage id maps.
result = predict_claim(
    sample_text,
    model,
    tokenizer,
    lob_id_to_label,
    coverage_id_to_label_by_lob,
)

print(result)


{'line_of_business': 'wc', 'lob_confidence': 0.998, 'coverage_type': 'wc_Disability_Wage_Replacement_Benefits', 'coverage_confidence': 0.519}
