In [11]:
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import shutil

In [12]:
# --- Run configuration: file locations and tunables -------------------------
MODEL_PATH = "models/binary_cnn_model2.pth"   # pickled model read by torch.load
INPUT_DIR = "test_images"                     # folder scanned for images to score
OUTPUT_DIR = "new_images"                     # positives are copied here
# NOTE(review): IMG_SIZE is not referenced by the visible cells, and the
# preprocessing transform resizes to 180x180 -- confirm whether 224 is stale.
IMG_SIZE = 224
THRESHOLD = 0.5  # cut-off applied to the sigmoid output

# The destination folder must exist before any image is copied into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [13]:
class BinaryCNN(nn.Module):
    """Small three-stage CNN that emits a single raw logit per input image.

    Each stage is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)`` with
    32, 64 and 128 output channels respectively, followed by a two-layer
    fully connected head with dropout. No activation is applied to the final
    ``Linear(256, 1)``, so callers must apply a sigmoid themselves to obtain
    a probability.

    The attribute names (``features``, ``classifier``, ``num_features``) and
    the position of every layer inside the ``nn.Sequential`` containers are
    kept stable so the parameter keys of ``state_dict()`` do not change.

    Args:
        input_shape: ``(channels, height, width)`` of a single input image.
            The channel count sizes the first convolution, and the full shape
            is used to size the first linear layer.
    """

    def __init__(self, input_shape=(3, 180, 180)):
        super().__init__()
        # Take the channel count from input_shape instead of hard-coding 3,
        # so non-RGB shapes do not crash in the shape probe below.
        in_channels = input_shape[0]

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Push one all-zero image through the convolutional stack to discover
        # how many values it produces; this sizes the first Linear layer.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            features_out = self.features(dummy)
            self.num_features = features_out.numel()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        """Return the raw logit tensor of shape ``(batch, 1)`` for batch ``x``."""
        return self.classifier(self.features(x))


In [14]:
from torch.serialization import add_safe_globals

# add_safe_globals expects a list of classes/callables. Passing a dict would
# iterate over its string keys and never register the class itself.
# The allowlist is only consulted when torch.load runs with weights_only=True.
add_safe_globals([BinaryCNN])

In [15]:
# SECURITY: weights_only=False performs a full unpickle, which can execute
# arbitrary code embedded in the file. Only load checkpoints from a trusted
# source.
model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Place the network on the target device and switch to inference mode
# (eval() disables the Dropout layer).
model = model.to(DEVICE).eval()

print("Model loaded successfully")

Model loaded successfully


## Image Preprocessing

In [16]:
# Preprocessing pipeline: resize -> tensor -> per-channel normalisation.
# NOTE(review): presumably this mirrors the preprocessing used at training
# time -- confirm against the training notebook.
image_size = (180, 180)

_channel_mean = [0.5, 0.5, 0.5]
_channel_std = [0.5, 0.5, 0.5]

_pipeline_steps = [
    transforms.Resize(image_size),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 on every channel
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform = transforms.Compose(_pipeline_steps)

In [17]:
#preprocessing image
def preprocess_image(path):
    """Load an image file and return it as a model-ready batch of one.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        The result of the module-level ``transform`` applied to the RGB image,
        with a leading batch dimension added by ``unsqueeze(0)``.
    """
    # Image.open keeps the underlying file open until the image is closed.
    # The context manager releases the handle once convert() has produced an
    # independent RGB copy, so scanning many files does not leak descriptors.
    with Image.open(path) as source:
        rgb = source.convert("RGB")
    return transform(rgb).unsqueeze(0)

## Running Inference

In [18]:
# Score every image in INPUT_DIR and copy the positives to OUTPUT_DIR.
# Each entry of `results` is (filename, probability, is_positive).
results = []

with torch.no_grad():  # inference only: no autograd bookkeeping needed
    # os.listdir returns entries in arbitrary order; sorting makes the log
    # output and the order of `results` reproducible across runs and machines.
    for file in sorted(os.listdir(INPUT_DIR)):

        # Skip anything that is not a supported image type.
        if not file.lower().endswith((".jpg", ".png", ".jpeg")):
            continue

        path = os.path.join(INPUT_DIR, file)

        img = preprocess_image(path).to(DEVICE)
        output = model(img)

        # The model returns a single logit; sigmoid maps it to [0, 1].
        prob = torch.sigmoid(output).item()
        is_positive = prob >= THRESHOLD

        print(f"{file} -> confidence : {prob:.2f}")
        if is_positive:
            shutil.copy(path, os.path.join(OUTPUT_DIR, file))

        results.append((file, prob, is_positive))

print("Inference done.")

people.jpg -> confidence : 0.97
Photo8.jpg -> confidence : 0.61
paint7.jpg -> confidence : 0.08
text.png -> confidence : 0.00
waves.png -> confidence : 0.00
Schematics6.png -> confidence : 0.00
paint_person.jpg -> confidence : 0.98
Text7.png -> confidence : 0.00
Inference done.
