In [None]:
import torch.nn as nn
import torch.nn.functional as F

class CustomModel(nn.Module):
    """Three-block CNN image classifier.

    Each feature block is Conv3x3 -> BatchNorm2d -> ReLU -> MaxPool(2) -> Dropout2d,
    growing the channel count 32 -> 64 -> 128. The classifier adaptively average-pools
    the final feature map to 7x7, so the size of the first linear layer does not
    depend on the input resolution (the notebook feeds 224x224 images).

    Args:
        dropout_rate: Dropout probability applied before the final linear layer.
        in_channels: Number of channels in the input images (default 1, grayscale).
        num_classes: Number of output logits (default 2).

    The submodule layout (and therefore the ``state_dict`` keys) is independent of
    ``in_channels``/``num_classes``, so checkpoints saved with the defaults still load.
    """

    def __init__(self, dropout_rate=0.5, in_channels=1, num_classes=2):
        super(CustomModel, self).__init__()

        self.features = nn.Sequential(
            # First block
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.2),

            # Second block
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.3),

            # Third block
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.4),
        )

        self.classifier = nn.Sequential(
            # Fixed 7x7 spatial output regardless of input size.
            nn.AdaptiveAvgPool2d((7, 7)),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise conv, batch-norm and linear parameters in place.

        Conv weights use Kaiming-normal (fan_out, ReLU); batch-norm layers start as
        identity (weight 1, bias 0); linear weights are drawn from N(0, 0.01); all
        biases are zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            # Cover the classifier's BatchNorm1d as well as the feature BatchNorm2d layers.
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

In [None]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import splitfolders

class CustomDataset(Dataset):
    """Image-classification dataset backed by a ``data_folder/<class_name>/<image>`` tree.

    Every non-hidden subdirectory of ``data_folder`` is a class. Classes are indexed
    in sorted name order, and images are decoded lazily in ``__getitem__``.

    Args:
        data_folder: Root directory containing one subdirectory per class.
        transform: Optional callable applied to each PIL image before it is returned.

    Attributes:
        class_names: Sorted class (subdirectory) names.
        class_to_idx: Mapping from class name to integer label.
        image_paths: Path of every sample, grouped by class and sorted by filename.
        labels: Integer label of every sample, aligned with ``image_paths``.
        data: Every image as a numpy array (built on first access; see the property).
    """

    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        self.transform = transform

        # Only non-hidden directories are classes. A stray file or a hidden entry
        # (e.g. .DS_Store, .ipynb_checkpoints) would otherwise become a "class",
        # shifting every label index or crashing os.listdir below.
        self.class_names = sorted(
            entry
            for entry in os.listdir(data_folder)
            if not entry.startswith(".") and os.path.isdir(os.path.join(data_folder, entry))
        )
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.class_names)}
        self.image_paths = []
        self.labels = []
        # Decoded arrays are built on demand by the ``data`` property instead of
        # decoding the whole dataset here.
        self._data = None

        for class_name in self.class_names:
            class_folder = os.path.join(data_folder, class_name)
            class_label = self.class_to_idx[class_name]
            # os.listdir order is filesystem-dependent; sort so indices are reproducible.
            for filename in sorted(os.listdir(class_folder)):
                img_path = os.path.join(class_folder, filename)
                # Skip hidden files and nested directories; they are not images.
                if filename.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(class_label)

    @staticmethod
    def _open_image(img_path):
        """Read the image at ``img_path`` into memory and close the file handle."""
        with Image.open(img_path) as img:
            # copy() reads the pixel data into a new in-memory image, which stays
            # usable after the context manager closes the file.
            return img.copy()

    @property
    def data(self):
        """Every image decoded to a numpy array, aligned with ``image_paths``.

        Built on first access and cached, so constructing the dataset does not
        decode (and hold in memory) every image up front.
        """
        if self._data is None:
            self._data = [np.array(self._open_image(path)) for path in self.image_paths]
        return self._data

    @data.setter
    def data(self, value):
        self._data = value

    def __len__(self):
        """Number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is a PIL image, or whatever ``transform`` returns for it when set.
        """
        img_path = self.image_paths[idx]
        image = self._open_image(img_path)
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
from torchvision import transforms

transform_test= transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),          
            transforms.Resize((224, 224)),    
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485],
                std=[0.229]
            ),
        ])

In [None]:
model = CustomModel()

model_data = './model_checkpoints/best_model.pth'  # path of the checkpoint to load

# map_location="cpu": evaluation below runs on CPU, and without it a checkpoint saved
# from a GPU run cannot be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code; only
# load trusted checkpoints (weights_only=True is safer if the file holds only
# tensors and plain containers — verify before enabling).
state_dict = torch.load(model_data, map_location="cpu")

# The weights may sit under 'model_state_dict' (full training checkpoint) or the file
# may be a bare state dict, depending on how the .pth was saved. A mismatch in any
# other layout still fails loudly in load_state_dict (strict key check).
if isinstance(state_dict, dict) and 'model_state_dict' in state_dict:
    model_weights = state_dict['model_state_dict']
else:
    model_weights = state_dict
model.load_state_dict(model_weights)
model.eval()

In [None]:
from torch.utils.data import DataLoader
import os

test_data_folder = "./datasets/split_images/test/"
# Score at most this many test images (set to None to score the whole split).
MAX_EVAL_IMAGES = 1000
# Seed for the shuffled order. With the cap above only a subset is scored, so an
# unseeded shuffle would make the reported accuracy change from run to run.
EVAL_SEED = 42

test_dataset = CustomDataset(test_data_folder, transform=transform_test)
# Shuffling stops the capped subset from coming only from the first class(es) in
# sorted order; the seeded generator makes that subset reproducible.
test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(EVAL_SEED),
)

model.eval()
correct_predictions = 0
total_predictions = 0

with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader):
        if MAX_EVAL_IMAGES is not None and i >= MAX_EVAL_IMAGES:
            break
        outputs = model(inputs)
        # Index of the largest logit is the predicted class.
        _, predicted = torch.max(outputs, 1)

        predicted_label = predicted.item()
        predicted_class_name = test_dataset.class_names[predicted_label]
        true_label = labels.item()
        true_class_name = test_dataset.class_names[true_label]

        print(f"predicted: {predicted_label} - {predicted_class_name}, actual: {true_label} - {true_class_name}")

        total_predictions += 1
        if predicted_label == true_label:
            correct_predictions += 1

# Guard against an empty test split instead of raising ZeroDivisionError.
if total_predictions:
    accuracy = correct_predictions / total_predictions
    print(f"\nAccuracy: {accuracy*100:.2f}%")
else:
    accuracy = float("nan")
    print("\nAccuracy: undefined (no test images were evaluated)")