In [1]:
import os
import torch
import numpy as np
from PIL import Image, ImageCms
from torchvision import transforms
from utils.fundus_prep import PreprocessEyeImages
from utils.densenet_mcf import dense121_mcs


class EyeQ:
    """Fundus image quality grader that returns 'Good', 'Usable' or 'Bad'.

    Wraps the ``dense121_mcs`` network. Each image is fed to the network in
    three colour spaces (RGB, HSV and LAB), and the argmax of the network's
    last output is mapped to one of ``self.classes``.
    """

    def __init__(self, weights_path='weights/densenet_weight.tar'):
        """Build the model and load its trained weights.

        Args:
            weights_path: checkpoint file whose ``'state_dict'`` entry holds
                the model parameters. The default is the path that was
                previously hard-coded here.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.classes = ['Good', 'Usable', 'Bad']

        self.model = dense121_mcs(n_class=len(self.classes))
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['state_dict'])

        self.model.eval()
        self.model = self.model.to(self.device)

        # The sRGB -> LAB colour transform is identical for every image.
        # Build the ICC profiles and the transform once here rather than on
        # every preprocess() call.
        srgb_profile = ImageCms.createProfile("sRGB")
        lab_profile = ImageCms.createProfile("LAB")
        self._rgb2lab_transform = ImageCms.buildTransformFromOpenProfiles(
            srgb_profile, lab_profile, "RGB", "LAB")

    def infer(self, image_path):
        """Grade the quality of the fundus image stored at ``image_path``.

        Args:
            image_path: path to an image file readable by PIL.

        Returns:
            One of ``self.classes``: 'Good', 'Usable' or 'Bad'.
        """
        # The context manager releases the file handle. preprocess() produces
        # new in-memory images, so nothing reads the file after the block.
        with Image.open(image_path) as image:
            images_rgb, images_hsv, images_lab = self.preprocess(image)

        # Guard against the model having been switched to training mode.
        self.model.eval()

        # Add a leading batch dimension of 1 and move each input to the
        # model's device.
        images_rgb = images_rgb.unsqueeze(0).to(self.device)
        images_hsv = images_hsv.unsqueeze(0).to(self.device)
        images_lab = images_lab.unsqueeze(0).to(self.device)

        # Inference only: disable autograd so no graph or activations are kept.
        with torch.no_grad():
            # The model returns five outputs. Only the last one (result_mcs)
            # is used for the prediction.
            _, _, _, _, result_mcs = self.model(images_rgb, images_hsv, images_lab)

        pred = torch.argmax(result_mcs, 1)
        return self.classes[pred.item()]

    def preprocess(self, sample):
        """Turn a PIL image into the three normalised inputs the model expects.

        Args:
            sample: a PIL image in any mode. It is converted to RGB first.

        Returns:
            A tuple ``(rgb, hsv, lab)`` of float32 tensors, each channel-first
            (C, H, W) as produced by ``transforms.ToTensor``.
        """
        image = sample.convert('RGB')

        crop = transforms.Compose([
            PreprocessEyeImages(),
            transforms.Resize(224),
            transforms.CenterCrop(224),
        ])
        image = crop(image)

        img_hsv = image.convert("HSV")
        img_lab = ImageCms.applyTransform(image, self._rgb2lab_transform)

        # NOTE(review): the arrays are float32 in the 0-255 range.
        # ToTensor only rescales uint8/PIL input to [0, 1], so the ImageNet
        # mean/std below are applied to 0-255 values. This presumably matches
        # the pipeline the weights were trained with (verify before changing).
        # Do not "fix" it without retraining.
        to_model_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

        return tuple(
            to_model_input(np.asarray(img).astype('float32')).float()
            for img in (image, img_hsv, img_lab)
        )


In [2]:
# Grade the bundled sample fundus photograph and report its quality class.
model_quality = EyeQ()
quality_image = model_quality.infer(image_path="assets/amd_retina.jpg")
print(f'Your image is rated as: {quality_image}')

Your image is rated as: Good
