# In [None]:
import asyncio
import io
import logging
import os

import nest_asyncio
import numpy as np
import torch
from PIL import Image
from telegram import Update
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from torch import argmax
from torchvision import transforms

from model import AkhundModel

# Patch asyncio so run_polling() can start its event loop even when a loop is
# already running (e.g. inside a Jupyter kernel).
nest_asyncio.apply()

logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the argument 2 is presumably the number of output classes
# (predictions are mapped from class indices 0 and 1) — confirm in model.py.
model = AkhundModel(2).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
model.load_state_dict(torch.load('model.pth', map_location=device))
# This process only runs inference: switch dropout / batch-norm layers (if the
# model has any) to their evaluation behaviour.
model.eval()

# Inference-time preprocessing: resize to 224x224, convert to a float tensor in
# [0, 1] (ToTensor), then map each of the 3 channels to [-1, 1] via
# (x - 0.5) / 0.5.
# Deliberately no random augmentation (e.g. ColorJitter) here: it would make
# the same photo classify differently from one call to the next.
# NOTE(review): should mirror the training preprocessing minus augmentation —
# confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def predict(image_path):
    """Classify one image with the loaded model and return its Persian label.

    Args:
        image_path: A filesystem path or binary file-like object accepted by
            ``PIL.Image.open``.

    Returns:
        'از ما نیست' when the arg-max class index is 0, 'از ما هست' when it
        is 1, and ``None`` for any other index.
    """
    labels = {0: 'از ما نیست', 1: 'از ما هست'}

    # The context manager releases the file handle PIL opens. convert('RGB')
    # makes grayscale / RGBA / palette images match the 3-channel Normalize.
    with Image.open(image_path) as img:
        tensor = transform(img.convert('RGB'))
    batch = tensor.unsqueeze(0).to(device)  # add a batch dimension of 1

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(batch)

    # argmax over the flattened tensor; presumably outputs is
    # (1, num_classes), so this equals the predicted class index.
    class_index = argmax(outputs).item()
    return labels.get(class_index)


async def start(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Handle /start: greet the user and ask them to send a photo."""
    greeting = "سلام! عکس خود را ارسال کنید."
    incoming = update.message
    await incoming.reply_text(greeting)


async def handle_photo(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Download an incoming photo, classify it, and reply with the result.

    The image is kept in memory rather than written to a shared file on disk,
    so overlapping updates cannot overwrite each other's photo. Model
    inference runs in the default executor so it does not block the bot's
    event loop.
    """
    # effective_message also covers edited messages and channel posts, for
    # which update.message is None even though filters.PHOTO matches them.
    message = update.effective_message
    # The last PhotoSize is presumably the largest resolution Telegram offers.
    photo = await message.photo[-1].get_file()

    buffer = io.BytesIO()
    await photo.download_to_memory(buffer)
    buffer.seek(0)  # rewind so PIL reads from the start of the downloaded bytes

    loop = asyncio.get_running_loop()
    prediction = await loop.run_in_executor(None, predict, buffer)
    await message.reply_text(f"نتیجه پیش‌بینی: {prediction}")


def main():
    """Build the Telegram bot, register its handlers, and poll for updates.

    The bot token is read from the ``TELEGRAM_BOT_TOKEN`` environment variable
    so the secret never has to be written into the source or notebook.

    Raises:
        RuntimeError: if ``TELEGRAM_BOT_TOKEN`` is unset or empty.
    """
    token = os.environ.get("TELEGRAM_BOT_TOKEN")
    if not token:
        raise RuntimeError(
            "Set the TELEGRAM_BOT_TOKEN environment variable to your bot token."
        )
    app = Application.builder().token(token).build()

    app.add_handler(CommandHandler("start", start))
    app.add_handler(MessageHandler(filters.PHOTO, handle_photo))

    # Polls Telegram for updates until the application is stopped.
    app.run_polling()


if __name__ == "__main__":
    main()