# In [None]:
import torch
import torch.nn as nn
from torchvision import models
from transformers import AutoTokenizer, AutoModelForTokenClassification
from torchvision import transforms
from PIL import Image
import numpy as np


# Load the NER (Named Entity Recognition) model and its tokenizer from disk.
model_name = "./trained_model"  # Path to the saved NER model
tokenizer = AutoTokenizer.from_pretrained(model_name)  # Tokenizer saved next to the model
ner_model = AutoModelForTokenClassification.from_pretrained(model_name)  # Token-classification model

# Rebuild the image classifier architecture: a ResNet-18 backbone whose final
# fully connected layer is replaced by a 10-way head.
# weights=None: load_state_dict below (strict by default) overwrites every
# parameter and buffer, so fetching the pre-trained ImageNet weights first
# would only cost a needless download.
cv_model = models.resnet18(weights=None)
num_ftrs = cv_model.fc.in_features
cv_model.fc = nn.Linear(num_ftrs, 10)  # Restore the final layer used by the checkpoint

# Load the fine-tuned weights onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
cv_model.load_state_dict(torch.load("best_model_image.pth", map_location=torch.device("cpu")))

# Switch to evaluation mode (disables dropout / batch-norm statistic updates).
cv_model.eval()





# Pre-processing applied to every image before it is handed to the classifier:
# resize to 224x224, convert to a tensor, then normalize each channel with the
# standard ImageNet mean / standard deviation.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

# Function to extract animal names from text using NER (Named Entity Recognition)
def extract_animal(text):
    """Return the animal mentions that the NER model detects in ``text``.

    The text is tokenized (truncated to 512 tokens), run through the
    module-level ``ner_model``, and every token whose predicted label id is 1
    is collected.

    Args:
        text: Input sentence to scan for animal names.

    Returns:
        The detected words joined by single spaces, or an empty string when
        nothing is labeled as an animal.
    """
    # Tokenize the input text and convert it to tensor format for the model
    tokens = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = ner_model(**tokens)
    # Most likely label id for each token of the (single) input sequence
    predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()
    # Convert token IDs back to token strings
    words = tokenizer.convert_ids_to_tokens(tokens["input_ids"][0])
    # Markers such as [CLS] / [SEP] are never part of an animal name.
    special_tokens = set(tokenizer.all_special_tokens)

    entities = []
    previous_was_animal = False
    for token, pred in zip(words, predictions):
        # Label id 1 marks ANIMAL tokens (per the original code's convention).
        is_animal = pred == 1 and token not in special_tokens
        if is_animal:
            if token.startswith("##"):
                # "##" marks a subword continuation (BERT-style WordPiece —
                # assumes the saved tokenizer uses this convention).
                piece = token[2:]
                if previous_was_animal and entities:
                    # Glue the piece onto the word it continues so that
                    # "ele" + "##phant" yields "elephant", not "ele phant".
                    entities[-1] += piece
                else:
                    entities.append(piece)
            else:
                entities.append(token)
        previous_was_animal = is_animal

    # Return the detected animal names as a single string
    return " ".join(entities)

# Function to classify an image and predict the animal in the image
def classify_image(image_path):
    """Classify the image at ``image_path`` and return the predicted animal name.

    The image is converted to RGB, run through the module-level ``transform``
    pipeline, and scored by ``cv_model``; the highest-scoring class wins.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The class name (a string from ``animal_classes``), not the class index.
    """
    # Index i of the model output corresponds to animal_classes[i]; the order
    # must match the label order used when the checkpoint was trained.
    animal_classes = ['beaver', 'dolphin', 'otter', 'seal', 'fox', 'spider', 'elephant', 'bear', 'rabbit', 'tiger']

    # Context manager guarantees the underlying file handle is released;
    # convert() returns an independent image that outlives the `with` block.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Apply the pre-processing pipeline and add the batch dimension.
    batch = transform(image).unsqueeze(0)

    with torch.no_grad():  # No need to compute gradients for inference
        outputs = cv_model(batch)
        _, predicted = torch.max(outputs, 1)  # Index of the highest-scoring class

    return animal_classes[predicted.item()]

# End-to-end check: does the text mention the same animal the image shows?
def verify_claim(text, image_path):
    """Compare the animal named in ``text`` with the one predicted for the image.

    Runs ``extract_animal`` on the text and ``classify_image`` on the image,
    prints both findings, and returns ``True`` only when the two strings are
    equal ignoring case.
    """
    mentioned = extract_animal(text)
    depicted = classify_image(image_path)

    # Report what each model found.
    print(
        f"From the text: {mentioned}",
        f"From the image: {depicted}",
        sep="\n",
    )

    # Case-insensitive exact comparison of the two results.
    mentioned_key = mentioned.lower()
    depicted_key = depicted.lower()
    return mentioned_key == depicted_key

# Example run: ask the user for a sentence and check it against a fixed image.
text_input = input('Input text: ')  # Sentence describing the picture
image_path = "Download_file.jpg"  # Image to verify the sentence against
result = verify_claim(text_input, image_path)

# Report whether the text and the image agree.
if result:
    print("✅ Correct!")
else:
    print("❌ Incorrect!")