In [5]:
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn

import numpy as np
import time
import os
from PIL import Image
import copy
import validators

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

from PIL import Image
import requests
from io import BytesIO


# Class names, indexed by the position of the corresponding model output
# (predict() maps an output index i to classes[i]).
classes = ('Chickenpox', 'Measles', 'Monkeypox', 'Normal')
# Integer label -> class name, derived from ``classes`` so the two stay in sync.
label_map = dict(enumerate(classes))
# Checkpoint file read by load_model().
PATH = './resnet18_net.pth'
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Preprocessing applied to every input image: collapse it to one grey
# channel, resize it to a fixed 64x64 and convert it to a (1, 64, 64) tensor.
data_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)

In [6]:
def load_model(path=None):
	'''
	Build the ResNet-18 classifier and load its fine-tuned weights.

	The final fully-connected layer is replaced so that it emits one logit
	per entry in ``classes``. The checkpoint at ``path`` then supplies every
	parameter. The model is moved to ``device`` and switched to eval mode.

	Parameters
	----------
	path : str or os.PathLike, optional
		Checkpoint to read with ``torch.load``. Defaults to the module-level
		``PATH``, looked up at call time as before.

	Returns
	-------
	torch.nn.Module
		The model on ``device``, in evaluation mode.
	'''
	if path is None:
		path = PATH
	# Build the architecture without ImageNet weights. The previous
	# ``pretrained=True`` fetched them (a network download on first use, and
	# a deprecated argument). load_state_dict is strict by default, so every
	# key then had to be overwritten from the checkpoint anyway. Skipping the
	# fetch makes loading faster and lets it work offline.
	model = models.resnet18()
	model.fc = nn.Linear(model.fc.in_features, len(classes))
	model.to(device)

	# NOTE(review): torch.load unpickles the file, which can execute arbitrary
	# code -- only load checkpoints from a trusted source.
	model.load_state_dict(torch.load(path, map_location=device))
	model.eval()
	return model



In [7]:
def predict(model, image_url):
	'''
	Classify one image and rank every class by the model's score.

	Parameters
	----------
	model : torch.nn.Module
		Classifier returning one logit per entry of ``classes`` (e.g. the
		result of ``load_model``). It is called as-is, so it should already
		be in eval mode.
	image_url : str
		Either a URL (as recognised by ``validators.url``), which is
		downloaded, or a local file path.

	Returns
	-------
	list of str
		All class names, ordered from highest to lowest model score.

	Raises
	------
	requests.HTTPError
		If downloading ``image_url`` returns an error status.
	'''
	if validators.url(image_url) is True:
		# Bound the wait on an unresponsive host. Fail loudly on 4xx/5xx
		# instead of handing an HTML error page to PIL.
		response = requests.get(image_url, timeout=30)
		response.raise_for_status()
		picture = Image.open(BytesIO(response.content))
	else:
		picture = Image.open(image_url)
	# Grayscale -> 64x64 resize -> tensor, giving shape (1, 64, 64).
	image = data_transform(picture)
	# Add a batch axis, then replicate the single grey channel three times
	# because the network expects 3-channel input.
	batch = image.reshape(1, 1, 64, 64).repeat(1, 3, 1, 1)
	# Put the input on the same device as the model's weights. Otherwise a
	# model loaded onto CUDA raises a device-mismatch error.
	batch = batch.to(next(model.parameters()).device)
	# Inference only: no autograd graph is needed.
	with torch.no_grad():
		outputs = model(batch)
	# argsort is ascending; reverse it so the most likely class comes first.
	# .tolist() yields plain ints and works for CPU and CUDA tensors alike.
	ranked_labels = torch.argsort(outputs, 1)[0].tolist()
	return [classes[label] for label in reversed(ranked_labels)]

In [8]:
model = load_model()
print("Model loaded")
# Sanity check on two images expected to be classified as "Normal": one
# fetched from Google Drive, one read from the local data folder.
for image_url in (
	"https://drive.google.com/uc?export=view&id=14sF_FaFvfYzrQCCQRX6IK87aBPFerfWb",
	"data/Normal/normalgray_aug14.jpg",
):
	print(predict(model, image_url), "should be normal")

Model loaded
['Normal', 'Monkeypox', 'Chickenpox', 'Measles'] should be normal
['Normal', 'Chickenpox', 'Measles', 'Monkeypox'] should be normal
