In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os

dataset_dir = "dataset_orchidee"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  
    transforms.RandomRotation(10),    
    transforms.RandomResizedCrop(size=(224, 224), antialias=True),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.ImageFolder('dataset_split/train', transform=transform)

mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

for param in mobilenet.features.parameters():
    param.requires_grad = False

dataset_dir = "dataset_orchidee"
num_classes = len(os.listdir(dataset_dir))


num_features = mobilenet.last_channel
mobilenet.classifier = nn.Sequential(                                
    nn.Linear(num_features, num_classes),   
)

mobilenet.load_state_dict(torch.load("mobilenet_finetuned.pth"), strict=False)

mobilenet.eval() 


MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [8]:
def predict_image(image_path, model, device, transform, classes):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)
    
    model.eval()
    with torch.no_grad():
        image = image.to(device)
        output = model(image)
        _, predicted = torch.max(output, 1)

    return classes[predicted.item()]

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

image_path = "FE3.png" 
predicted_class = predict_image(
    image_path=image_path,
    model=mobilenet,
    device=device,
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]),
    classes=train_dataset.classes
)

print(f"L'immagine {image_path} è stata classificata come: {predicted_class}")

L'immagine FE3.png è stata classificata come: O. exaltata
