In [5]:
import torch
from torch import nn
from torchvision.transforms import transforms
from PIL import Image
import pickle

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cpu


In [2]:
class Net(nn.Module):
    """Small three-stage CNN classifier for square RGB images.

    Each stage is a 3x3 convolution with padding=1 (spatial size preserved),
    a 2x2 max-pool (spatial size halved) and a ReLU. Channels grow
    3 -> 32 -> 64 -> 128. The flattened feature map goes through a 128-unit
    linear layer and then a linear head with one output per class.

    Args:
        num_classes: number of output logits. Defaults to 3 (the original
            hard-coded head).
        image_size: height/width of the square input images. Defaults to 128,
            which reproduces the original 128 * 16 * 16 flattened features.

    Raises:
        ValueError: if ``image_size`` is smaller than 8, which would leave no
            spatial extent after the three 2x2 poolings.
    """

    def __init__(self, num_classes: int = 3, image_size: int = 128):
        super().__init__()

        # Attribute names and registration order are kept exactly as before so
        # existing state_dict checkpoints still load.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        self.pooling = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

        # padding=1 convs keep H/W; each of the three 2x2 pools floors the size
        # in half, so the final feature map is (image_size // 8) per side.
        pooled_size = image_size // 8
        if pooled_size < 1:
            raise ValueError(f"image_size must be >= 8, got {image_size}")
        self.linear = nn.Linear(128 * pooled_size * pooled_size, 128)

        self.output = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class scores of shape (batch, num_classes).

        The per-layer shape comments assume the default 128x128 input and
        omit the batch dimension.
        """
        x = self.conv1(x)  # -> (32, 128, 128)
        x = self.pooling(x)  # -> (32, 64, 64)
        x = self.relu(x)

        x = self.conv2(x)  # -> (64, 64, 64)
        x = self.pooling(x)  # -> (64, 32, 32)
        x = self.relu(x)

        x = self.conv3(x)  # -> (128, 32, 32)
        x = self.pooling(x)  # -> (128, 16, 16)
        x = self.relu(x)

        x = self.flatten(x)
        # NOTE(review): there is no activation between `linear` and `output`, so
        # the two layers compose to a single affine map. Left unchanged because
        # the saved weights were presumably trained with this exact forward pass;
        # inserting a ReLU here would change the loaded model's predictions.
        x = self.linear(x)
        x = self.output(x)
        return x


In [3]:
model_load = Net().to(device)

# map_location remaps tensors onto the current device, so a checkpoint saved
# from a CUDA session still loads on a CPU-only kernel instead of failing to
# deserialize CUDA storages.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
# On torch >= 1.13 consider weights_only=True to restrict what can be unpickled.
state_dict = torch.load("animal_faces_model.pth", map_location=device)
model_load.load_state_dict(state_dict)
# eval() returns the module itself, so the cell displays the architecture.
model_load.eval()

Net(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (relu): ReLU()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear): Linear(in_features=32768, out_features=128, bias=True)
  (output): Linear(in_features=128, out_features=3, bias=True)
)

In [6]:
# Load the fitted label encoder used to map predicted class indices back to
# class names (predict_image_load calls its inverse_transform).
# NOTE(review): pickle.load can execute arbitrary code — only load this file
# from a trusted source.
with open("label_encoder.pkl", "rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)

# Inference preprocessing. It must match the training-time pipeline (not shown
# in this notebook — TODO confirm). Inputs are resized to 128x128 because Net's
# flattened feature size (128 * 16 * 16) assumes that resolution.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    # ToTensor already yields a float tensor (C, H, W) scaled to [0, 1] for RGB
    # images, so no separate dtype-conversion step is needed.
    transforms.ToTensor(),
])

In [8]:
def predict_image_load(img_path):
    """Classify a single image file with the loaded model.

    Uses the notebook-level ``model_load``, ``transform``, ``device`` and
    ``label_encoder`` defined in earlier cells.

    Args:
        img_path: path to an image file that PIL can open.

    Returns:
        The result of ``label_encoder.inverse_transform`` on the single
        predicted class index, i.e. a one-element array holding the label.
    """
    # Context manager closes the underlying file handle; convert("RGB")
    # returns an in-memory copy, so it remains usable after the file closes.
    with Image.open(img_path) as img:
        rgb_image = img.convert("RGB")

    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    batch = transform(rgb_image).unsqueeze(0).to(device)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = model_load(batch)

    class_index = torch.argmax(logits, dim=1).item()
    return label_encoder.inverse_transform([class_index])


predict_image_load("download.jpg")

array(['wild'], dtype=object)