## Image Preprocessing

In [22]:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import PIL.Image as Image


# Load ResNet-18 with its pre-trained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument -- confirm against the installed version
# before switching.
model = models.resnet18(pretrained=True)

# Drop the last child module (the final fully connected layer) so the network
# acts as a feature extractor instead of producing its original class scores.
model = nn.Sequential(*list(model.children())[:-1])

# Freeze the backbone: none of its weights should receive gradients.
for param in model.parameters():
    param.requires_grad = False

# Switch to inference mode. Without this the BatchNorm layers stay in training
# mode: they normalise with the statistics of the current (single-image) batch
# and keep updating their running averages, which distorts the features.
model.eval()
    
# Define a custom classifier
class Classifier(nn.Module):
    """Single linear layer that maps a feature vector to class scores.

    Args:
        num_classes: Number of output scores produced per sample.
        in_features: Length of the flattened feature vector fed to the
            linear layer. Defaults to 512, the value previously hard-coded.
    """

    def __init__(self, num_classes, in_features=512):
        super().__init__()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Flatten everything after the batch dimension and apply the layer.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions multiply out to ``in_features``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw scores
            (no softmax is applied).
        """
        # flatten(1) also copes with empty batches and non-contiguous input,
        # where view(x.size(0), -1) would raise.
        x = x.flatten(1)
        return self.fc(x)
    
# Initialize the custom classifier.
# NOTE: it is freshly constructed here and not trained in this cell, so the
# scores printed below come from randomly initialised weights.
num_classes = 5  # The number of classes you are interested in
classifier = Classifier(num_classes)

# Pre-processing pipeline: resize, centre-crop to 224x224, convert to a
# tensor, then normalise each of the three channels with the given mean/std.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load an example input image. Force RGB so grayscale, palette or RGBA files
# still yield the three channels Normalize expects, and close the file handle
# once the pixel data has been copied out.
with Image.open("../imagetest.jpg") as image_file:
    image = image_file.convert("RGB")

# Pre-process the input image and add a leading batch dimension of size 1.
input_image = transform(image).unsqueeze(0)

# Extract features from the input image with the frozen backbone.
features = model(input_image)

# Use the custom classifier to turn the features into per-class scores.
prediction = classifier(features)

# Print the prediction (raw scores, one per class).
print(prediction)



tensor([[-0.8960,  1.1506,  0.9418, -0.6109,  0.0220]],
       grad_fn=<AddmmBackward0>)
