In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification, set_seed

# -------------------------
# CLASSIFIER CLINICAL TRIAGE
# -------------------------
class ClinicalTriageClassifier:
    """Emergency triage classifier with a rule-based vital-sign override.

    Free text built from identity, vital signs and symptoms is scored by a
    Hugging Face sequence-classification model over four levels
    (``GRIS``, ``VERT``, ``JAUNE``, ``ROUGE``). A set of hard-coded
    vital-sign rules can then force the ``ROUGE`` level.

    NOTE(review): when the checkpoint has no trained 4-label head,
    ``from_pretrained`` initialises the classification-head weights
    randomly. Until that head is fine-tuned on labelled triage data, the
    model's prediction and probabilities carry no clinical meaning; only the
    rule-based override does.
    """

    def __init__(self, model_name="medicalai/ClinicalBERT", seed=42):
        """Load the tokenizer and model and build the label lookup tables.

        Args:
            model_name: Hugging Face hub id or local path of the checkpoint.
            seed: Seed passed to ``transformers.set_seed`` before loading.
        """
        # Seed *before* loading: from_pretrained randomly initialises any
        # weights missing from the checkpoint (e.g. the classification head).
        # Seeding afterwards left those weights, and every prediction,
        # different on each run.
        set_seed(seed)

        self.model_name = model_name
        # Ordered from least to most urgent; list index == model class id.
        self.labels = ["GRIS", "VERT", "JAUNE", "ROUGE"]
        self.label_to_id = {label: idx for idx, label in enumerate(self.labels)}
        self.id_to_label = {idx: label for idx, label in enumerate(self.labels)}
        # User-facing description of each triage level.
        self.label_desc = {
            "GRIS": "Ne nécessite pas les urgences, situation stable.",
            "VERT": "Pathologie non vitale, consultation classique.",
            "JAUNE": "Pathologie non vitale mais urgente, prise en charge rapide.",
            "ROUGE": "Pathologie potentiellement vitale et urgente, détresse vitale suspectée."
        }

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.labels),
            problem_type="single_label_classification",
        )
        # Inference only (disables dropout).
        self.model.eval()

    @staticmethod
    def _value(data, key, default=None):
        """Return ``data[key]``, or *default* when the key is absent or None."""
        value = data.get(key)
        return default if value is None else value

    def _prepare_input_text(self, id_data, const_data, symptoms_json):
        """Build the single French summary sentence fed to the model.

        Missing fields are rendered as ``None``, as before. The symptom list
        may be missing/None (rendered empty), a single string, or contain
        non-string items.
        """
        text = f"Patient: {id_data.get('genre')}, {id_data.get('age')} ans. "
        text += f"Constantes: FC {const_data.get('fc')} bpm, "
        text += f"SpO2 {const_data.get('spo2')}%, "
        text += f"Temp {const_data.get('temp')}°C, "
        text += f"TA {const_data.get('tas')}/{const_data.get('tad')}. "

        symptom_list = symptoms_json.get("symptomes_principaux") or []
        if isinstance(symptom_list, str):
            # A bare string is one symptom, not a sequence of characters.
            symptom_list = [symptom_list]
        symptomes = ", ".join(str(symptom) for symptom in symptom_list)
        text += f"Symptômes: {symptomes}. "
        text += f"Localisation: {symptoms_json.get('localisation')}. "
        text += f"Intensité douleur: {symptoms_json.get('intensite_douleur')}/10."
        return text

    def check_vital_emergency_rules(self, id_data, const_data):
        """Return True when hard-coded vital-sign rules require ``ROUGE``.

        Missing or None values fall back to non-alarming defaults
        (age 30, temp 37.0, fc 75, tas 120, spo2 98).

        Rules (any one triggers):
          * SpO2 < 90 or heart rate > 130;
          * age <= 1 with temperature > 38.0;
          * age <= 3 with temperature > 38.5;
          * systolic pressure >= 170 mmHg;
          * heart rate <= 50 with systolic pressure <= 90 mmHg.
        """
        age = self._value(id_data, "age", 30)
        temp = self._value(const_data, "temp", 37.0)
        fc = self._value(const_data, "fc", 75)
        tas = self._value(const_data, "tas", 120)
        spo2 = self._value(const_data, "spo2", 98)

        # Hypoxia or extreme tachycardia
        if spo2 < 90 or fc > 130:
            return True
        # Infants with fever
        if age <= 1 and temp > 38.0:
            return True
        # Young children with high fever
        if age <= 3 and temp > 38.5:
            return True

        # Systolic values <= 20 are treated as presumably cmHg (French
        # "12/8" convention) and converted to mmHg; larger values are mmHg.
        tas_mmhg = tas * 10 if tas <= 20 else tas

        # Severe hypertension
        if tas_mmhg >= 170:
            return True
        # Bradycardia with low systolic pressure (suspected shock)
        if fc <= 50 and tas_mmhg <= 90:
            return True

        return False

    def classify_emergency(self, id_data, const_data, symptoms_json):
        """Classify one patient and apply the vital-sign override.

        Returns:
            dict with keys ``niveau`` (final label), ``confiance`` (model
            probability of the predicted label, or 1.0 when a vital rule
            forced ``ROUGE``), ``resume_analyse`` (the text fed to the model)
            and ``probabilites`` (raw model probability per label; not
            modified by the override).
        """
        input_text = self._prepare_input_text(id_data, const_data, symptoms_json)
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        with torch.no_grad():
            logits = self.model(**inputs).logits
        # A single input was tokenized, so keep the only row of the batch.
        probs = F.softmax(logits, dim=-1)[0]

        predicted_class_id = int(torch.argmax(probs).item())
        confidence = probs[predicted_class_id].item()
        score_final = self.id_to_label[predicted_class_id]

        # Hard vital-sign rules take precedence over the model.
        if self.check_vital_emergency_rules(id_data, const_data):
            score_final = "ROUGE"
            confidence = 1.0

        return {
            "niveau": score_final,
            "confiance": confidence,
            "resume_analyse": input_text,
            "probabilites": dict(zip(self.labels, probs.tolist())),
        }

# -------------------------
# EXEMPLE D'UTILISATION
# -------------------------
# Guarded so that importing this module does not load (and possibly
# download) the model or run inference. Jupyter cells execute with
# __name__ == "__main__", so the notebook behaviour and the globals
# `classifier` / `result` are unchanged.
if __name__ == "__main__":
    classifier = ClinicalTriageClassifier()

    result = classifier.classify_emergency(
        id_data={"genre": "F", "age": 65},
        const_data={"fc": 85, "spo2": 92, "temp": 38.2, "tas": 150, "tad": 90},
        symptoms_json={
            "symptomes_principaux": ["toux", "fatigue", "fièvre"],
            "localisation": "thorax",
            "intensite_douleur": 4
        }
    )

    print(result)


Loading weights: 100%|██████████| 100/100 [00:00<00:00, 574.79it/s, Materializing param=distilbert.transformer.layer.5.sa_layer_norm.weight]   
DistilBertForSequenceClassification LOAD REPORT from: medicalai/ClinicalBERT
Key                     | Status     | 
------------------------+------------+-
vocab_transform.bias    | UNEXPECTED | 
vocab_transform.weight  | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_projector.weight  | UNEXPECTED | 
classifier.weight       | MISSING    | 
pre_classifier.bias     | MISSING    | 
pre_classifier.weight   | MISSING    | 
classifier.bias         | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


TypeError: ClinicalTriageClassifier.check_vital_emergency_rules() takes 2 positional arguments but 3 were given

✅ Accelerate version: 1.12.0
