<a href="https://colab.research.google.com/github/Sanchit9587/Pokemon_Hack2_Guild_App/blob/NER%26NLP/Inference_NER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
!pip install transformers==4.41.2
!pip install torch==2.3.0



In [None]:
import json
import os
import torch
from transformers import BertTokenizerFast, BertForTokenClassification

In [8]:
print("Step 1: Setting up configuration...")

# --- Configuration Block ---
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DRIVE_MOUNT_PATH = "/content/drive"
# Directory expected to hold the fine-tuned model, tokenizer files and
# tag_mappings.json (all three are read in step 3).
MODEL_PATH = os.path.join(DRIVE_MOUNT_PATH, "MyDrive/PokemonNERModel")

print(f"Using device: {DEVICE}")
print(f"Loading model from: {MODEL_PATH}")

# ==============================================================================
# 2. Mount Google Drive
# ==============================================================================
print("\nStep 2: Mounting Google Drive...")
try:
    from google.colab import drive
    drive.mount(DRIVE_MOUNT_PATH, force_remount=True)
    print("Google Drive mounted successfully.")
except ImportError:
    # Outside Colab there is no Drive helper; MODEL_PATH may still exist locally.
    print("Not in a Colab environment. Skipping Google Drive mount.")
except Exception as e:
    # Best-effort mount: a failure here most likely shows up as a missing
    # MODEL_PATH in step 3, which is reported loudly there.
    print(f"An error occurred during drive mounting: {e}")

# ==============================================================================
# 3. Load Model, Tokenizer, and Label Mappings
# ==============================================================================
print("\nStep 3: Loading model and tokenizer...")

# Fail fast. Continuing without a model only defers the problem to a
# confusing NameError (model / tokenizer / id2tag) in the inference cell.
if not os.path.isdir(MODEL_PATH):
    raise FileNotFoundError(
        f"Model directory not found at '{MODEL_PATH}'. "
        "Please ensure your model was saved correctly to your Google Drive."
    )

model = BertForTokenClassification.from_pretrained(MODEL_PATH)
tokenizer = BertTokenizerFast.from_pretrained(MODEL_PATH)
model.to(DEVICE)
model.eval()  # Set the model to evaluation mode (disables dropout)

# Load the tag mappings created during training.
with open(os.path.join(MODEL_PATH, 'tag_mappings.json'), 'r', encoding='utf-8') as f:
    mappings = json.load(f)
# JSON object keys are always strings; convert them back to integer class ids
# so they can be indexed with argmax results.
id2tag = {int(k): v for k, v in mappings['id2tag'].items()}

print("Fine-tuned model, tokenizer, and label mappings loaded successfully.")

# ==============================================================================
# 4. Inference Function
# ==============================================================================
print("\nStep 4: Defining the inference function...")

def _word_spans(word_ids, offsets, tags):
    """Collapse sub-word tokens into word-level ``(start, end, tag)`` spans.

    Args:
        word_ids: Per-token word index from ``BatchEncoding.word_ids()``;
            ``None`` marks special tokens such as [CLS]/[SEP], which are skipped.
        offsets: Per-token ``(start, end)`` character offsets into the input text.
        tags: Per-token predicted tag strings.

    Returns:
        list[tuple[int, int, str]]: One span per word, in text order. The tag
        is the prediction for the word's first sub-word token. The end offset
        is extended to cover all of the word's sub-word tokens.
    """
    spans = []
    previous_word = None
    for word_id, (start, end), tag in zip(word_ids, offsets, tags):
        if word_id is None:
            continue
        if word_id != previous_word:
            spans.append([start, end, tag])
            previous_word = word_id
        else:
            # Continuation sub-word: widen the current word's character span.
            spans[-1][1] = end
    return [tuple(span) for span in spans]


def _bio_entities(word_spans):
    """Merge BIO-tagged word spans into ``(entity_type, start, end)`` entities.

    A ``B-X`` word opens a new entity of type X. An ``I-X`` word extends the
    open entity when that entity is also of type X; otherwise it opens a new
    entity, so a stray ``I-`` prediction is kept instead of silently dropped.
    Any other tag (e.g. ``O``) closes the open entity.

    Args:
        word_spans: ``(start, end, tag)`` tuples as produced by ``_word_spans``.

    Returns:
        list[tuple[str, int, int]]: Entities in text order.
    """
    entities = []
    open_entity = None
    for start, end, tag in word_spans:
        prefix, _, entity_type = tag.partition("-")
        if prefix == "I" and open_entity is not None and open_entity[0] == entity_type:
            open_entity[2] = end
        elif prefix in ("B", "I") and entity_type:
            open_entity = [entity_type, start, end]
            entities.append(open_entity)
        else:
            open_entity = None
    return [tuple(entity) for entity in entities]


def get_pokemon_entities(text, max_length=256):
    """
    Performs NER on a given text and extracts only the enemy and friendly species.

    Uses the module-level ``tokenizer``, ``model``, ``id2tag`` and ``DEVICE``
    created by the loading cell. Sub-word predictions are grouped into words
    (first sub-word's tag wins), and words are grouped into entities with BIO
    decoding, so multi-word names tagged ``B-``/``I-`` come back whole. Entity
    strings are sliced from ``text`` itself, so the original spelling and
    casing are preserved.

    Args:
        text (str): The input military-style prompt.
        max_length (int): Maximum number of tokens (including special tokens)
            fed to the model. Text beyond this limit is truncated and never tagged.

    Returns:
        dict: A dictionary with 'enemy_species' and 'friendly_species' lists.
            Each list holds unique (case-sensitive) names in order of first
            appearance.
    """
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        return_offsets_mapping=True,
    )
    # offset_mapping is not a model input; take it out before the forward pass.
    offsets = encoding.pop("offset_mapping")[0].tolist()
    word_ids = encoding.word_ids(0)
    inputs = encoding.to(DEVICE)

    # Get model predictions
    with torch.no_grad():
        logits = model(**inputs).logits
    tags = [id2tag[class_id] for class_id in logits.argmax(dim=-1)[0].tolist()]

    species = {"enemy_species": [], "friendly_species": []}
    for entity_type, start, end in _bio_entities(_word_spans(word_ids, offsets, tags)):
        # 'ENEMY_SPECIES' -> 'enemy_species'; other entity types are ignored.
        key = entity_type.lower()
        if key not in species:
            continue
        name = text[start:end]
        if name not in species[key]:
            species[key].append(name)
    return species

# Final progress message of the setup cell's step-by-step console log.
print("Inference function is ready.")


Step 1: Setting up configuration...
Using device: cpu
Loading model from: /content/drive/MyDrive/PokemonNERModel

Step 2: Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted successfully.

Step 3: Loading model and tokenizer...
Fine-tuned model, tokenizer, and label mappings loaded successfully.

Step 4: Defining the inference function...
Inference function is ready.


In [10]:
# Sample military-style prompt used to exercise get_pokemon_entities.
sample_prompt = """HQ has detected unusual Bulbasaur activity in the area. Field
sensors logged anomalous behavior that suggests an imminent
threat. Remember there are Pikachu and Charizard nearby —
take care not to draw them into combat. You are to neutralize
the bulbasaurs immediately. Report status once the target is
down. Confirm mission status and any collateral damages"""
extracted = get_pokemon_entities(sample_prompt)
print(extracted)

{'enemy_species': ['bulbasaurs'], 'friendly_species': ['Charizard', 'Pikachu', 'Bulbasaur']}
