In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import os 
from torchvision import datasets, transforms 


In [2]:
# Defining model architecture
class BananaModel(nn.Module):
    def __init__(self, num_classes=4):
        super(BananaModel, self).__init__()
        self.feature = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.ANN = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128*56*56, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )
    def forward(self, x):
        x = self.feature(x)
        x = self.ANN(x)
        return x

In [3]:
# Checking if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Loading the weights from already trained model (FROM main_banana.ipynb file)
model = BananaModel()
model= model.to(device)
model.load_state_dict(torch.load(os.path.join( "banana_model2.pth")))

  model.load_state_dict(torch.load(os.path.join( "banana_model2.pth")))


<All keys matched successfully>

In [5]:
# Transform function for the image 
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [6]:
#For Single Image Prediction
def predict_single(image_path):
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)
    return predicted.item()

In [7]:
#For Batch Prediction
def predict_batch(image_folder):
    image_dataset = datasets.ImageFolder(image_folder, transform=transform)
    image_loader = DataLoader(image_dataset, batch_size=32, shuffle=False)

    predictions = []
    with torch.no_grad():
        for images, _ in image_loader:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
    return predictions