In [7]:
!pip install --user -U nltk



In [10]:
from google.colab import drive

# Directory inside the Colab VM where Google Drive is attached.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [27]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import nltk

# Pre-trained BioBERT NER checkpoint; tokenizer and model must come from the
# same checkpoint so their vocabularies match.
MODEL_NAME = "alvaroalon2/biobert_genetic_ner"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)

# Example article text
article_text = "The TP53 gene encodes a tumor suppressor protein that regulates cell division and prevents cancer. BRCA1 and BRCA2 genes are associated with hereditary breast and ovarian cancer."

# nltk.word_tokenize needs the Punkt tokenizer data, which is not shipped with
# the nltk package itself; without it a fresh kernel fails with LookupError.
# Older NLTK releases look the data up as "punkt", newer ones as "punkt_tab",
# so request both (an id unknown to the installed version is simply reported
# as not found).
for punkt_resource in ("punkt", "punkt_tab"):
    nltk.download(punkt_resource, quiet=True)

# Split the article into words so predictions can later be mapped back to
# whole words instead of word pieces.
words = nltk.word_tokenize(article_text)

# Encode the pre-split words as a single sequence. truncation=True cuts the
# input at the tokenizer's maximum length, so very long articles lose their tail.
inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

# Inference only: no_grad() disables gradient tracking for the forward pass.
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits

# Highest-scoring label index along the last (third) axis of the logits,
# moved to the CPU as a NumPy array.
predicted_labels = logits.argmax(dim=2).cpu().numpy()

# Lookup table from label index to label name, as stored in the model config.
label_map = model.config.id2label

def _word_level_labels(word_ids, token_labels, id2label):
    """Collapse sub-token predictions to one label name per word.

    Args:
        word_ids: For each sub-token, the index of the word it came from, or
            None for tokens that belong to no word (special tokens, padding).
        token_labels: Predicted label index for each sub-token, same length
            and order as ``word_ids``.
        id2label: Mapping from label index to label name.

    Returns:
        Dict mapping word index -> label name of that word's first sub-token.
        Words that were truncated away have no entry.
    """
    labels_by_word = {}
    for word_idx, label_id in zip(word_ids, token_labels):
        # Only the first sub-token of each word decides the word's label.
        if word_idx is not None and word_idx not in labels_by_word:
            labels_by_word[word_idx] = id2label[int(label_id)]
    return labels_by_word


def _group_bio_entities(words, labels_by_word, entity_type="GENETIC"):
    """Merge consecutive B-/I- tagged words into entity strings.

    Args:
        words: The word-level tokens of the text, in order.
        labels_by_word: Dict of word index -> label name; missing indices are
            treated as outside any entity.
        entity_type: Suffix of the BIO tags to collect.

    Returns:
        List of entity strings, each the space-joined words of one entity.
    """
    begin_tag = f"B-{entity_type}"
    inside_tag = f"I-{entity_type}"
    entities = []
    current_words = []

    for word_idx, word in enumerate(words):
        label_name = labels_by_word.get(word_idx, "O")
        if label_name == begin_tag:
            # A new entity starts; flush the one in progress, if any.
            if current_words:
                entities.append(" ".join(current_words))
            current_words = [word]
        elif label_name == inside_tag:
            # An I- tag with no open entity is treated as the start of one
            # rather than crashing or attaching to an unrelated entity.
            current_words.append(word)
        elif current_words:
            # Any non-entity label closes the entity in progress.
            entities.append(" ".join(current_words))
            current_words = []

    if current_words:
        entities.append(" ".join(current_words))
    return entities


# Predictions are per sub-token (including special tokens), while `words` is
# per word, so the two cannot be zipped together directly. word_ids() gives the
# sub-token -> word alignment; it is only available on fast tokenizers.
word_ids = inputs.word_ids(batch_index=0)
word_labels = _word_level_labels(word_ids, predicted_labels[0], label_map)
extracted_entities = _group_bio_entities(words, word_labels)

# Print the extracted gene/protein entities
for entity in extracted_entities:
    print(f"Entity: {entity}")


Entity: TP53 gene
Entity: BRCA1
Entity: BRCA2
