In [None]:
import torch
from torchvision import models, transforms
from PIL import Image
import json
import requests
from io import BytesIO

In [None]:
class ImagePredicator:
    """Classify images with torchvision's pretrained Inception v3 network.

    The network, its preprocessing pipeline and the label table are created
    on first use and then kept on the instance, so repeated calls to
    :meth:`predict_image` do not rebuild the model or re-read the labels.
    """

    # Local cache of the label table, keyed by stringified class index; the
    # human-readable name is the second element of each entry.
    _LABELS_PATH = 'imagenet_class_index.json'
    # Fallback source for the label names when the local cache is missing.
    _LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    # Seconds to wait for the label download before giving up.
    _DOWNLOAD_TIMEOUT = 30

    def __init__(self):
        # Everything is populated lazily so construction stays cheap.
        self._model = None
        self._device = None
        self._preprocess = None
        self._class_idx = None

    def _ensure_model(self):
        """Build the network and the preprocessing pipeline once per instance."""
        if self._model is not None:
            return

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument; ``pretrained=True`` is kept so the installed version
        # this was written against keeps working.
        model = models.inception_v3(pretrained=True)
        model.eval()
        model.to(device)

        self._preprocess = transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self._device = device
        self._model = model

    def _load_class_index(self):
        """Return the label table, reading or downloading it on first use."""
        if self._class_idx is not None:
            return self._class_idx

        try:
            with open(self._LABELS_PATH) as f:
                class_idx = json.load(f)
        except FileNotFoundError:
            # A timeout keeps a stalled connection from hanging the call, and
            # raise_for_status stops an HTTP error body from being cached.
            response = requests.get(self._LABELS_URL, timeout=self._DOWNLOAD_TIMEOUT)
            response.raise_for_status()
            class_names = response.json()
            class_idx = {str(i): [str(i), name] for i, name in enumerate(class_names)}
            with open(self._LABELS_PATH, 'w') as f:
                json.dump(class_idx, f)

        self._class_idx = class_idx
        return class_idx

    def predict_image(self, image, top_k=1):
        """Return the ``top_k`` most probable classes for an image.

        Args:
            image: Anything ``PIL.Image.open`` accepts (it is passed through
                unchanged), e.g. a filesystem path.
            top_k: Number of predictions to return. Values larger than the
                number of classes the model outputs are clamped to that
                number.

        Returns:
            A list of ``{'class': <name>, 'probability': <float>}`` dicts in
            the order produced by ``torch.topk`` over the softmax output.

        Raises:
            ValueError: If ``top_k`` is negative.
            requests.RequestException: If the label table has to be
                downloaded and the request fails or times out.
        """
        # Reject bad input before paying for model loading or inference.
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        # Resolve the labels first so a failed download does not waste a
        # forward pass.
        class_idx = self._load_class_index()
        self._ensure_model()

        # convert() returns a fully loaded copy, so the source file can be
        # closed as soon as the with-block exits.
        with Image.open(image) as img:
            rgb_image = img.convert('RGB')

        # unsqueeze(0) adds the batch dimension the model expects.
        input_batch = self._preprocess(rgb_image).unsqueeze(0).to(self._device)

        with torch.no_grad():
            output = self._model(input_batch)

        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        k = min(top_k, probabilities.numel())
        top_probs, top_indices = torch.topk(probabilities, k)

        return [
            {'class': class_idx[str(idx)][1], 'probability': prob}
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]

       
# Screenshot to classify (machine-specific absolute path).
image_path = "C:\\Users\\rppaw\\OneDrive\\Pictures\\Screenshots\\screen.png"

# Request five predictions and print the class name of the first entry.
i = ImagePredicator()
predictions = i.predict_image(image_path, top_k=5)
best_guess = predictions[0]
print(best_guess["class"])


tiger cat
