<a href="https://colab.research.google.com/github/RicardoPoleo/DeepLearning_FactChecker/blob/main/notebooks/Agents/WebService_Agent_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Install dependencies
# fastapi + uvicorn serve the NER web service below; transformers provides the models.
!pip install fastapi uvicorn transformers
# localtunnel provides the `lt` CLI used later to expose the local server on a public URL.
!npm install -g localtunnel

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

class ClaimAnalysisAgent:
    """Extract medical entity strings from free-text claims with a NER model.

    Wraps the ``Clinical-AI-Apollo/Medical-NER`` token-classification
    checkpoint and returns the distinct entity strings it tags in a claim.
    """

    def __init__(self):
        # Load the tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained("Clinical-AI-Apollo/Medical-NER")
        self.model = AutoModelForTokenClassification.from_pretrained("Clinical-AI-Apollo/Medical-NER")

        # Add "COVID-19" to the vocabulary so it is kept as a single token, and
        # grow the embedding matrix to match the enlarged vocabulary.
        # NOTE(review): the new embedding row was never trained with this
        # checkpoint, so predictions on this token may be unreliable — confirm.
        special_tokens = ["COVID-19"]
        self.tokenizer.add_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))

    def analyze_claim(self, text):
        """Return the distinct entity strings the model tags in ``text``.

        Args:
            text: Claim to analyse. Input longer than 512 tokens is truncated.

        Returns:
            list[str]: Entity strings in order of first appearance, without
            duplicates.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

        # Inference only: disable autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # One predicted label id per token of the single input sequence.
        predictions = torch.argmax(logits, dim=2)[0]

        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        labels = [self.model.config.id2label[p.item()] for p in predictions]
        return self._group_entities(tokens, labels)

    @staticmethod
    def _group_entities(tokens, labels, skip_tokens=("[CLS]", "[SEP]")):
        """Merge runs of entity-labelled tokens into de-duplicated strings.

        A new entity starts when:
        * an "O" token is seen;
        * the entity type changes (the B-/I- prefix is ignored); or
        * a "B-" label falls on a token that begins a new word.
        This keeps adjacent entities from being glued together while keeping
        sub-word pieces of one word in the same entity. Word-boundary markers
        ("▁", as emitted by this model's SentencePiece-style tokenizer) become
        spaces.

        Args:
            tokens: Token strings, aligned with ``labels``.
            labels: Predicted label per token ("O" means no entity).
            skip_tokens: Special tokens to ignore entirely.

        Returns:
            list[str]: Non-empty entity strings in order of first appearance.
        """
        entities = []
        current_tokens = []
        current_type = None

        def flush():
            # Emit the pending entity (if any) and reset the buffer.
            if current_tokens:
                entity = "".join(current_tokens).replace("▁", " ").strip()
                if entity and entity not in entities:  # skip blanks and duplicates
                    entities.append(entity)
            current_tokens.clear()

        for token, label in zip(tokens, labels):
            if token in skip_tokens:
                continue  # special tokens never belong to an entity

            if label == "O":
                flush()
                current_type = None
                continue

            entity_type = label[2:] if label.startswith(("B-", "I-")) else label
            begins_new_word = token.startswith("▁")
            if entity_type != current_type or (label.startswith("B-") and begins_new_word):
                flush()
            current_tokens.append(token)
            current_type = entity_type

        flush()
        return entities

# Example usage: build the agent once, then extract entities from two sample claims.
claim_analysis_agent = ClaimAnalysisAgent()

text1 = "Vitamin C cures cancer. COVID-19 is a global pandemic."
entities1 = claim_analysis_agent.analyze_claim(text1)
print("Extracted entities:", entities1)

text2 = (
    "A face-covering can prevent people who are asymptomatic carriers of "
    "COVID-19 from spreading the virus."
)
entities2 = claim_analysis_agent.analyze_claim(text2)
print("Extracted entities:", entities2)

In [None]:
#@title Load the model and start the NER web service


from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import uvicorn
import subprocess
import threading

app = FastAPI()

# Load the tokenizer and model.
# The checkpoint must already be fine-tuned for token classification.
# "emilyalsentzer/Bio_ClinicalBERT" is a plain pretrained encoder, so loading it
# through AutoModelForTokenClassification attaches a randomly initialised
# classifier head, and the labels it produces are meaningless. Use the same
# fine-tuned medical NER checkpoint as ClaimAnalysisAgent instead.
model_name = "Clinical-AI-Apollo/Medical-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# Token-level NER pipeline (no aggregation strategy configured here, so by
# default it yields one prediction per non-"O" token).
ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer)

class RequestModel(BaseModel):
    """JSON request body for ``POST /ner``."""

    text: str  # raw text handed to the NER pipeline as-is

@app.post("/ner")
def perform_ner(request: RequestModel):
    """Tag entities in ``request.text`` and return one item per entity span.

    Returns:
        ``{"entities": [{"entity": str, "score": float, "word": str}, ...]}``
        where ``entity`` is the span's label without any B-/I- prefix,
        ``score`` is the span's mean token score, and ``word`` is the span text.
    """
    # Let the pipeline merge sub-word tokens and B-/I- pieces into contiguous
    # spans. Merging only on an equal raw label would glue together
    # non-adjacent tokens, split one entity at its B-/I- boundary, and join
    # whole words without spaces.
    groups = ner_pipeline(request.text, aggregation_strategy="simple")

    decoded_entities = [
        {
            "entity": group["entity_group"],
            # Pipeline scores are numpy floats, which FastAPI's JSON encoder
            # rejects; convert to a plain Python float.
            "score": float(group["score"]),
            "word": group["word"].strip(),
        }
        for group in groups
    ]

    return {"entities": decoded_entities}

def start_uvicorn():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000; blocks until the server stops."""
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)

# Free the port before starting the server
!fuser -k 8000/tcp

thread = threading.Thread(target=start_uvicorn)
thread.start()

process = subprocess.Popen(["lt", "--port", "8000"], stdout=subprocess.PIPE)
for line in process.stdout:
    print(line.decode().strip())
