In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms

In [24]:
class CNN(nn.Module):
  """Small two-stage convolutional image classifier.

  Layout: conv1(3x3 -> 16) -> ReLU -> maxpool(2) -> conv2(3x3 -> 32) -> ReLU ->
  maxpool(2) -> flatten -> fc1(32*16*16 -> 128) -> Dropout(0.2) -> ReLU ->
  fc2(128 -> numClasses).

  Both convolutions use stride 1 / padding 1 (spatial size preserved) and each
  pool halves it, so fc1's 32 * 16 * 16 input width implies 64x64 input images
  (64 -> 32 -> 16).

  Args:
    inChannels: number of channels of the input images (e.g. 3 for RGB).
    numClasses: number of output logits.
  """
  def __init__(self, inChannels, numClasses):
    super(CNN, self).__init__()
    # Attribute names and registration order are part of the state_dict / repr
    # contract (a saved checkpoint is loaded into this class) -- do not rename.
    self.conv1 = nn.Conv2d(in_channels = inChannels, out_channels= 16, kernel_size = 3, stride = 1, padding = 1)
    self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
    self.conv2 = nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 3, stride = 1, padding = 1)
    self.fc1 = nn.Linear(32 * 16 * 16, 128)
    self.fc2 = nn.Linear(128, numClasses)
    self.dropfc1 = nn.Dropout(p=0.2)

  def forward(self, x):
    """Return raw (un-normalised) class logits of shape (batch, numClasses).

    x is expected to be (batch, inChannels, 64, 64). Other spatial sizes make
    fc1 raise a shape-mismatch RuntimeError rather than being silently
    regrouped across the batch dimension.
    """
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    # Flatten per sample, keeping dim 0 as the batch. The previous
    # view(-1, 32*16*16) let -1 absorb any leftover elements, mixing samples
    # together whenever the input was not 64x64.
    x = torch.flatten(x, 1)
    x = F.relu(self.dropfc1(self.fc1(x)))
    return self.fc2(x)

# Inference runs on CPU. map_location remaps tensors that were saved from
# another device (e.g. a GPU) onto the CPU.
device = "cpu"
model = CNN(inChannels = 3, numClasses = 10).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code when loaded
# (requires torch >= 1.13). The file only needs to hold a state_dict here.
state_dict = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# eval() disables the Dropout layer. It returns the model, so its repr is shown
# as the cell output.
model.eval()

CNN(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=8192, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (dropfc1): Dropout(p=0.2, inplace=False)
)

In [41]:
# Preprocessing pipeline:
# - Resize to 64x64, the spatial size CNN.fc1 is built for (32 * 16 * 16 features).
# - ToTensor gives a float (C, H, W) tensor in [0, 1].
# - Normalize maps it to [-1, 1]; the one-element mean/std broadcast over all
#   three RGB channels.
# NOTE(review): presumably the same normalisation used at training time -- confirm.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5],
        std=[0.5]
    )
])

# Open inside a context manager so the file handle is released deterministically.
# convert("RGB") returns a new, fully loaded image that remains valid after the
# `with` block closes the file.
with Image.open("wikinda.jpg") as raw_img:
    img = raw_img.convert("RGB")
# unsqueeze(0) adds the batch dimension: (3, 64, 64) -> (1, 3, 64, 64).
img_tensor = transform(img).unsqueeze(0)

# Human-readable names for the 10 EuroSAT land-use classes: entry i labels
# output logit i of the model.
# NOTE(review): order assumed to match the label indices used at training time -- confirm.
eurosatClases = [
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
]

In [42]:
# Single forward pass with autograd disabled. Softmax turns the logits into
# per-class probabilities, and max picks the most probable class (pred) and its
# probability (conf).
with torch.no_grad():
    output = model(img_tensor)
    probs = output.softmax(dim=1)
    conf, pred = probs.max(dim=1)

predicted_label = eurosatClases[pred.item()]
print("Predicted class:", predicted_label)
print("Confidence:", conf.item())

Predicted class: AnnualCrop
Confidence: 1.0
