In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import os


# Default compute device for the cells below: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_model(model_path, num_classes, device=device):
    """Build an EfficientNet-B3 classifier and load trained weights into it.

    Args:
        model_path: path to a checkpoint containing a state_dict for this
            exact architecture (it is passed straight to ``load_state_dict``).
        num_classes: number of output classes; must equal the size of the
            classification head the checkpoint was trained with.
        device: device the model (and the loaded tensors) are placed on.

    Returns:
        The model on ``device``, switched to eval mode.

    Note:
        ``weights=`` requires torchvision >= 0.13 and ``weights_only=`` requires
        torch >= 1.13.
    """
    # weights=None -> random initialisation; this replaces the deprecated
    # pretrained=False. The real weights come from the checkpoint below.
    model = models.efficientnet_b3(weights=None)
    # Replace the Linear layer at classifier[1] with one sized for our classes,
    # keeping the same input width.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    model = model.to(device)

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    return model

def create_predict_transform(size=(224, 224),
                             mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)):
    """Build the preprocessing pipeline applied to an image before inference.

    Pipeline: resize -> convert to a float tensor -> per-channel normalisation.
    Calling it with no arguments reproduces the original fixed pipeline
    (224x224 resize, standard ImageNet channel mean/std).

    Args:
        size: passed unchanged to ``transforms.Resize``; an ``(h, w)`` pair
            resizes to exactly that shape.
        mean: per-channel means used by ``transforms.Normalize``.
        std: per-channel standard deviations used by ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` callable mapping a PIL image to a tensor.

    NOTE(review): these settings must match the ones used during training;
    confirm the training resize (EfficientNet-B3's reference resolution is 300).
    """
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(list(mean), list(std)),
    ])

def predict_image(model, image_path, transform, class_names, device=device):
    """Classify a single image file and return the most likely class.

    Args:
        model: classifier expected to already be in eval mode (``load_model``
            puts it there).
        image_path: path to the image file to classify.
        transform: callable mapping a PIL RGB image to a (C, H, W) tensor,
            e.g. the result of ``create_predict_transform()``.
        class_names: sequence mapping output index -> label; its order must
            match the model's output order.
        device: device the input batch is moved to; should match the model's.

    Returns:
        ``(class_name, probability)`` for the top class, or ``None`` if any
        error occurs (the error is printed rather than raised).
    """
    try:
        # Context manager releases the file handle even if decoding fails or
        # the image is multi-frame (Image.open keeps the file open lazily).
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        batch = transform(image).unsqueeze(0).to(device)  # add batch dimension

        with torch.no_grad():
            logits = model(batch)
            probs = F.softmax(logits, dim=1)
            max_prob, pred = torch.max(probs, 1)

        return class_names[pred.item()], max_prob.item()

    except Exception as e:
        # Deliberate best-effort: report the problem and signal it with None.
        print(f"Ошибка при обработке изображения: {e}")
        return None

train_dir = 'Birds_test_200/Train'

MODEL_PATH = 'best_model.pth'
# Output index i of the model must map to CLASS_NAMES[i]. os.listdir() returns
# entries in an arbitrary, filesystem-dependent order, so the names are sorted;
# this matches how torchvision's ImageFolder assigns class indices.
# NOTE(review): assumes training used ImageFolder (or the same sorted order).
CLASS_NAMES = sorted(
    name for name in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, name))
)
NUM_CLASSES = len(CLASS_NAMES)
if NUM_CLASSES == 0:
    raise ValueError(f"No class sub-directories found in {train_dir!r}")
# with open("class_name.txt", "w", encoding="utf-8") as f:
#     f.write("\n".join(CLASS_NAMES))

# Image to classify, located via the user's home directory instead of a
# hard-coded C:\Users\<name> path so the cell also works on other machines.
DOWNLOADS_DIR = os.path.join(os.path.expanduser("~"), "Downloads")
# TEST_IMAGE = os.path.join(DOWNLOADS_DIR, "Bay-breasted_Warbler_-_21182699918.jpg")
TEST_IMAGE = os.path.join(DOWNLOADS_DIR, "Screenshot_5.jpg")

model = load_model(MODEL_PATH, NUM_CLASSES, device)

predict_transform = create_predict_transform()

# result is (class_name, probability), or None if the image could not be processed.
result = predict_image(
    model=model,
    image_path=TEST_IMAGE,
    transform=predict_transform,
    class_names=CLASS_NAMES,
    device=device,
)

print(result)

('Cardinal', 0.8680009841918945)


Создание конфига

In [14]:
import json

# Default application settings (dumped to web/config.json in the next cell);
# currently only the SQLite database file location is configured.
config = dict(
    database=dict(path="database.db"),
)

In [15]:
# Persist the config (presumably consumed by the web app under web/), then read
# it back to confirm the round trip and obtain the database path.
# open(..., "w") fails if the directory is missing, so create it first.
os.makedirs("web", exist_ok=True)
with open("web/config.json", "w", encoding="utf-8") as config_file:
    json.dump(config, config_file, indent=4)

with open("web/config.json", "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

db_path = config["database"]["path"]
print(db_path)

def change_database_path(new_path, config_path="web/config.json"):
    """Set ``database.path`` in a JSON config file, preserving other keys.

    Args:
        new_path: new database location; ``str`` or ``os.PathLike`` (path
            objects are converted with ``os.fspath`` so they serialise to JSON).
        config_path: JSON config file to edit. Defaults to "web/config.json".

    Raises:
        OSError: the config file cannot be read or written.
        KeyError: the config has no "database" section.
        TypeError: ``new_path`` cannot be serialised to JSON; in that case the
            config file is left untouched.
    """
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)

    if isinstance(new_path, os.PathLike):
        new_path = os.fspath(new_path)
    config["database"]["path"] = new_path

    # Serialise before opening for writing: mode "w" truncates the file
    # immediately, so a failure inside json.dump would leave it corrupt.
    payload = json.dumps(config, indent=4)
    with open(config_path, "w", encoding="utf-8") as config_file:
        config_file.write(payload)


database.db


In [16]:
import os
import sqlite3

def CreateNewDatabase(base_path="database.db"):
    """Create an empty SQLite database file under the first unused name.

    Tries ``base_path`` first, then inserts ``_1``, ``_2``, ... before the
    extension (database.db -> database_1.db -> database_2.db ...) until a name
    that does not exist yet is found.

    Args:
        base_path: preferred file name or path. Defaults to "database.db"
            (relative, i.e. in the current working directory).

    Returns:
        str: path of the database file that was created. Callers that ignore
        the return value are unaffected.
    """
    base_path = os.fspath(base_path)
    root, ext = os.path.splitext(base_path)

    db_path = base_path
    index = 1
    while os.path.exists(db_path):
        db_path = f"{root}_{index}{ext}"
        index += 1

    # sqlite3.connect creates the file when it does not exist; closing right
    # away leaves an empty database on disk.
    conn = sqlite3.connect(db_path)
    conn.close()
    return db_path

# Create a fresh, empty SQLite database under the first unused name
# (database.db, database_1.db, ...) in the current working directory.
# NOTE(review): the chosen file name is not recorded in web/config.json here;
# TODO confirm whether the config should be pointed at the new database.
CreateNewDatabase()