<a href="https://colab.research.google.com/github/RicardoPoleo/DeepLearning_FactChecker/blob/main/notebooks/Agents/ModelAgentA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Install Dependencies
!pip install sentence-transformers torch transformers

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import logging

# Configure the root logger to emit records at INFO level and above.
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers configured (which may be the case in a notebook runtime).
logging.basicConfig(level=logging.INFO)

class ClaimAnalysisAgent:
    """Extract named entities from claims with a token-classification model.

    The agent wraps a Hugging Face tokenizer/model pair, runs the model over
    one or more texts and converts the per-token label predictions into a
    list of entity strings for each text.
    """

    # Class-level fallbacks so ``extract_entities`` is usable on its own;
    # ``__init__`` overrides them with values taken from the loaded tokenizer.
    _DEFAULT_SKIP_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})
    _skip_tokens = _DEFAULT_SKIP_TOKENS
    _word_start_tokens = frozenset()

    def __init__(self, model_name="Clinical-AI-Apollo/Medical-NER", tokenizer_name=None, max_length=512):
        """Load the tokenizer and model and register the extra vocabulary.

        Args:
            model_name: Hub id or local path of the token-classification model.
            tokenizer_name: Hub id or path of the tokenizer; defaults to
                ``model_name``.
            max_length: Maximum number of tokens per text; longer inputs are
                truncated by ``analyze_claim``.

        Raises:
            Exception: Any error raised while loading is logged and re-raised
                unchanged.
        """
        if tokenizer_name is None:
            tokenizer_name = model_name
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        except Exception as e:
            logging.error("Failed to load model or tokenizer: %s", e)
            raise
        logging.info("Model and tokenizer loaded successfully from %s.", model_name)

        # NOTE(review): the embedding row created for a newly added token has
        # not been fine-tuned, so predictions on it may be unreliable -- confirm
        # this is intended before relying on entities containing it.
        special_tokens = ["COVID-19"]
        self.tokenizer.add_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.max_length = max_length

        # Tokens that never carry entity text (CLS/SEP/PAD/...).
        self._skip_tokens = self._DEFAULT_SKIP_TOKENS | frozenset(
            self.tokenizer.all_special_tokens
        )
        # Added tokens are emitted without a word-start marker, yet each one
        # is a complete word rather than the continuation of the previous one.
        self._word_start_tokens = (
            frozenset(self.tokenizer.get_added_vocab()) - self._skip_tokens
        )

    def analyze_claim(self, texts):
        """Return the entities found in each text.

        Args:
            texts: A single string or a sequence of strings.

        Returns:
            A list with one entry per input text; each entry is the list of
            entity strings produced by ``extract_entities`` for that text.
            An empty sequence yields an empty list.
        """
        if isinstance(texts, str):
            texts = [texts]  # Allow single string input for convenience
        if not texts:
            return []

        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=self.max_length,
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        predictions = torch.argmax(logits, dim=2).tolist()

        id2label = self.model.config.id2label
        results = []
        for input_ids, label_ids in zip(inputs["input_ids"].tolist(), predictions):
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
            labels = [id2label[label_id] for label_id in label_ids]
            results.append(self.extract_entities(tokens, labels))
        return results

    def extract_entities(self, tokens, labels):
        """Merge labelled sub-word tokens into de-duplicated entity strings.

        Sub-word pieces are glued back into words: a ``##`` prefix marks a
        WordPiece continuation, and when the sequence uses ``▁`` word-start
        markers (SentencePiece) any piece without the marker continues the
        previous word. Words are joined with a single space into one entity
        while they carry a continuing label (``I-<type>`` or an unprefixed
        label) of the same type; a ``B-`` label, a type change or an ``O``
        label closes the current entity.

        Args:
            tokens: Token strings for one sequence.
            labels: Predicted label for each token, aligned with ``tokens``.

        Returns:
            The distinct entity strings in order of first appearance.
        """
        uses_word_markers = any(token.startswith("▁") for token in tokens)
        entities = []
        words = []  # words of the entity currently being built
        entity_type = None  # label type (without B-/I-) of that entity
        force_new_word = False  # a bare "▁" was seen: next piece starts a word

        for token, label in zip(tokens, labels):
            if token in self._skip_tokens:
                continue

            # Decide whether this piece continues the previous word BEFORE
            # stripping its marker, otherwise the information is lost.
            if token.startswith("##"):
                piece, continues_word = token[2:], True
            elif token.startswith("▁"):
                piece, continues_word = token[1:], False
            else:
                piece = token
                continues_word = (
                    uses_word_markers and token not in self._word_start_tokens
                )

            if label == "O":
                if words:
                    entities.append(" ".join(words))
                words = []
                entity_type = None
                force_new_word = False
                continue

            if not piece:
                # A marker-only token carries no text but is a word boundary.
                force_new_word = force_new_word or not continues_word
                continue
            if force_new_word:
                continues_word = False
                force_new_word = False

            prefix, sep, kind = label.partition("-")
            if not sep or prefix not in ("B", "I"):
                # Label without a B-/I- prefix: treat it as a continuing tag.
                prefix, kind = "I", label

            if continues_word and words:
                words[-1] += piece
            elif words and prefix == "I" and kind == entity_type:
                words.append(piece)
            else:
                if words:
                    entities.append(" ".join(words))
                words = [piece]
                entity_type = kind

        if words:
            entities.append(" ".join(words))

        # dict.fromkeys removes duplicates while keeping first-seen order.
        return list(dict.fromkeys(entities))

# Example usage
# Constructing the agent loads the tokenizer and model with the default
# model name, so this cell needs the model files to be available.
agent = ClaimAnalysisAgent()
text1 = "Vitamin C cures cancer. COVID-19 is a global pandemic."
# A single string is accepted; the result is a list holding one list of
# entity strings per input text.
entities1 = agent.analyze_claim(text1)
print("Extracted entities:", entities1)
