In [3]:
from typing import Callable, Optional
from torch import Tensor
from PIL.Image import Image as PILImage
from torchvision.transforms.functional import pil_to_tensor

type ImageType  = str | Tensor | PILImage


class ImageToTensor:
    foo = lambda x: x
    # mappes the type of input to the correct convertion function
    type_to_function: dict[type, Callable] = {
            #TODO: handle str case
            str: foo,
            Tensor: lambda x: x,
            PILImage: pil_to_tensor
        }
    
    @classmethod
    def convert(cls, image: ImageType) -> Tensor:
        if not type(image) in cls.type_to_function:
            raise NotImplementedError(f'Type {type(image)} is not supported for prediction')
        return cls.type_to_function[type(image)](image)




class SegmentationModelAI():
    def __init__(self, model, preprocessor: Optional[Callable] = None):
        #TODO support torch and ONNX
        self.model = model
        self.model.eval()
        self.preprocessor = preprocessor

    def predict_batch(self, batch: Tensor) -> Tensor:
        return self.model(batch)

    def predict_single(self, tensor: Tensor) -> Tensor:
        return self.predict_batch(tensor.unsqueeze(0))

    def __call__(self, image: ImageType) -> Tensor:
        tensor  = ImageToTensor.convert(image)
        if self.preprocessor:
            tensor = self.preprocessor(tensor)
        return self.predict_single(tensor)

In [45]:
from abc import ABC, abstractmethod
from torchvision.models.segmentation.deeplabv3 import DeepLabV3 as DeepLabV3Type
from torch import TensorType
type ModelType = rt.capi.onnxruntime_inference_collection.InferenceSession

class ModelWrapper():
    def __init__(self, model, preprocessor: Optional[Callable] = None):
        self.__model = None
        self.model = model
        self.preprocessor = preprocessor

    @property
    @abstractmethod
    def supported_model_types(self):
        """Return a tuple of supported model types for the subclass."""
        raise NotImplementedError

    @property
    def model(self):
        return self.__model
    
    @model.setter
    def model(self, new_model):
        print(new_model, type(new_model))
        if not self.validate_model_type(new_model):
            raise TypeError(f"Unsupported model type. Supported types are: {self.supported_model_types}")
        self.__model = new_model


    def validate_model_type(self, new_model) -> bool:
        return isinstance(new_model, self.supported_model_types)
    
    
    @abstractmethod
    def predict_single(self, *args, **kwargs):
        raise NotImplementedError
    
    @abstractmethod
    def predict_batch(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplemented
    

class DeepLabV3Wrapper(ModelWrapper):
    
    @property
    def supported_model_types(self):
        return (DeepLabV3Type,)

    def predict_batch(self, batch_images: TensorType) -> TensorType:
        """Predicts the output of the model for a batch of images.

        Args:
            batch_images (TensorType): A tensor of shape [N, 3, H, W]

        Returns:
            TensorType: The model prediction 
        """
        assert len(batch_images.shape) == 4, 'Missmatching tensor dimensions. The image should be of shape [N, 3, H, W]'
        assert batch_images.shape[1] == 3, f'Missmatching tensor dimensions. The image should have 3 collor images but found {batch_images.shape[1]}'
        if self.preprocessor:
            batch_images = self.preprocessor(batch_images)
        return self.model(batch_images)
    
    def predict_single(self, image: Tensor) -> TensorType:
        """Predicts the output of the model for a single image.

        Args:
            image (Tensor): The image tensor of shape [3, H, W]

        Returns:
            TensorType: The model prediction (logits)
            TODO: change from logit to classes
        """
        assert len(image.shape) == 3, 'Missmatching tensor dimensions. The image should be of shape [3, H, W]'
        assert image.shape[0] == 3, f'Missmatching tensor dimensions. The image should have 3 collor images but found {image.shape[1]}'
        return self.predict_batch(image.unsqueeze(0))



1 <class 'int'>


TypeError: Unsupported model type. Supported types are: (<class 'str'>,)

In [39]:

#Test


from PIL import Image
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
# Step 1: Initialize model with the best available weights
weights = DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
model = deeplabv3_mobilenet_v3_large(weights=weights)
model.eval()
preprocess = weights.transforms()
url, filename = ("https://github.com/pytorch/hub/raw/master/images/deeplab1.png", "deeplab1.png")
input_image = Image.open(filename)
input_image = input_image.convert("RGB")

In [40]:
seg = SegmentationModelAI(model, preprocess)
seg([1,2,3])


NotImplementedError: Type <class 'list'> is not supported for prediction

'2'

tensor([[1., 2., 3.]])

In [26]:
class Geeks:
    course = 'DSA'
    list_of_instances = []

    def _init_(self, name):
        self.name = name
        Geeks.list_of_instances.append(self)

    @classmethod
    def get_course(cls):
        return f"Course: {cls.course}"
Geeks.get_course()

'Course: DSA'