In [3]:
from typing import Callable, Optional
from torch import Tensor










In [66]:
from abc import ABC, abstractmethod
from torchvision.models.segmentation.deeplabv3 import DeepLabV3 as DeepLabV3Type
from torch import Tensor, no_grad
from typing import Optional, Callable

class ModelWrapper():
    def __init__(self, model, preprocessor: Optional[Callable] = None):
        self.__model = None
        self.model = model
        self.__preprocessor = None
        self.preprocessor = preprocessor

    @property
    @abstractmethod
    def supported_model_types(self):
        """Return a tuple of supported model types for the subclass."""
        raise NotImplementedError

    @property
    def model(self):
        return self.__model
    
    @property
    def preprocessor(self):
        assert self.__preprocessor, 'preprocessor is not defined'
        return self.__preprocessor
    
    @preprocessor.setter
    def preprocessor(self, new_preprocessor):
        self.__preprocessor = new_preprocessor
    
    @model.setter
    def model(self, new_model):
        print(new_model, type(new_model))
        if not self.validate_model_type(new_model):
            raise TypeError(f"Unsupported model type. Supported types are: {self.supported_model_types}")
        self.__model = new_model


    def validate_model_type(self, new_model) -> bool:
        return isinstance(new_model, self.supported_model_types)
    
    
    @abstractmethod
    def predict_single(self, *args, **kwargs):
        raise NotImplementedError
    
    @abstractmethod
    def predict_batch(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplemented
    

class DeepLabV3Wrapper(ModelWrapper):
    
    @property
    def supported_model_types(self):
        return (DeepLabV3Type,)

    def predict_batch(self, batch_images: Tensor, preprocess: bool = True) -> Tensor:
        """Predicts the output of the model for a batch of images.

        Args:
            batch_images (Tensor): A tensor of shape [N, 3, H, W]
            preprocess (bool): should the tensor be preprocessed first.

        Returns:
            Tensor: The model prediction 
        """
        assert len(batch_images.shape) == 4, 'Missmatching tensor dimensions. The image should be of shape [N, 3, H, W]'
        assert batch_images.shape[1] == 3, f'Missmatching tensor dimensions. The image should have 3 collor images but found {batch_images.shape[1]}'
        if preprocess:
            batch_images = self.preprocessor(batch_images)
        self.model.eval()
        with no_grad():
            return self.model(batch_images)['out']
    
    def predict_single(self, image: Tensor, preprocess: bool = True) -> Tensor:
        """Predicts the output of the model for a single image.

        Args:
            image (Tensor): The image tensor of shape [3, H, W]
            preprocess (bool): should the tensor be preprocessed first.

        Returns:
            Tensor: The model prediction (logits)
            TODO: change from logit to classes
        """
        assert len(image.shape) == 3, 'Missmatching tensor dimensions. The image should be of shape [3, H, W]'
        assert image.shape[0] == 3, f'Missmatching tensor dimensions. The image should have 3 collor images but found {image.shape[1]}'
        if preprocess:
            image = self.preprocessor(image)
        return self.predict_batch(image.unsqueeze(0)).squeeze(0)



In [39]:

#Test


from PIL import Image
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
# Step 1: Initialize model with the best available weights
weights = DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
model = deeplabv3_mobilenet_v3_large(weights=weights)
model.eval()
preprocess = weights.transforms()
url, filename = ("https://github.com/pytorch/hub/raw/master/images/deeplab1.png", "deeplab1.png")
input_image = Image.open(filename)
input_image = input_image.convert("RGB")

In [68]:
from torchvision.io.image import read_image
import torch
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import urllib


url, filename = ("https://github.com/pytorch/hub/raw/master/images/deeplab1.png", "deeplab1.png")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
input_image = Image.open(filename)
input_image = input_image.convert("RGB")
img = input_image

# Step 1: Initialize model with the best available weights
weights = DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
model = deeplabv3_mobilenet_v3_large(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and visualize the prediction
out = model(batch)
prediction = out['out']


In [67]:
m1 = DeepLabV3Wrapper(model, preprocess)
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.PILToTensor()
])
img_tensor = transform(img)
m1.predict_single(img_tensor).shape


DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride

torch.Size([21, 520, 649])

tensor([[1., 2., 3.]])

In [26]:
class Geeks:
    course = 'DSA'
    list_of_instances = []

    def _init_(self, name):
        self.name = name
        Geeks.list_of_instances.append(self)

    @classmethod
    def get_course(cls):
        return f"Course: {cls.course}"
Geeks.get_course()

'Course: DSA'