<a href="https://colab.research.google.com/github/jabb4/AlexNet_Pokedex/blob/main/tools/images_tools.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torchvision.transforms as transforms

In [None]:
class Rescale(object):
  """Resize the "image" entry of a sample dict to a fixed size.

  The "label" entry is returned unchanged. The image may be anything
  ``torchvision.transforms.Resize`` accepts (PIL Image or tensor).

  Args:
      output_size (tuple or int): Desired output size as ``(height, width)``.
          If int ``n``, it is expanded to the square size ``(n, n)`` before
          being handed to ``transforms.Resize``, so the aspect ratio is NOT
          preserved. (Passing a bare int to ``transforms.Resize`` would
          instead match the smaller edge; this class does not do that.)

  Raises:
      AssertionError: if ``output_size`` is neither an int nor a 2-tuple.
  """
  def __init__(self, output_size):
    assert isinstance(output_size, (int, tuple))
    if isinstance(output_size, int):
      # Square target: both edges become output_size.
      self.output_size = (output_size, output_size)
    else:
      assert len(output_size) == 2
      self.output_size = output_size

  def __call__(self, sample):
    image, label = sample["image"], sample["label"]

    # Constructed per call, so it always reflects the current self.output_size.
    resize = transforms.Resize(self.output_size)

    return {"image": resize(image), "label": label}

In [None]:
class RandomCrop(object):
  """Take a random crop of the "image" entry of a sample dict.

  The image is padded first when it is smaller than the requested crop
  (``pad_if_needed=True``); the "label" entry is passed through as-is.

  Args:
    output_size (tuple or int): Crop size as ``(height, width)``; an int
      ``n`` gives a square ``(n, n)`` crop.
  """

  def __init__(self, output_size):
    assert isinstance(output_size, (int, tuple))
    is_square = isinstance(output_size, int)
    if not is_square:
      assert len(output_size) == 2
    self.output_size = (output_size, output_size) if is_square else output_size

  def __call__(self, sample):
    img = sample["image"]
    lbl = sample["label"]
    cropper = transforms.RandomCrop(self.output_size, pad_if_needed=True)
    return {"image": cropper(img), "label": lbl}

In [None]:
class Normalize(object):
  """Normalize the "image" tensor of a sample with per-channel mean and std.

  Given mean ``(mean[1], ..., mean[n])`` and std ``(std[1], ..., std[n])`` for
  ``n`` channels, each channel of the image is normalized as::

      output[channel] = (input[channel] - mean[channel]) / std[channel]

  The image is cast with ``.float()`` first, so it must be a ``torch.Tensor``;
  PIL Images are not supported. The "label" entry is returned unchanged.

  .. note::
      By default this acts out of place. With ``inplace=True`` the float
      tensor is overwritten. ``Tensor.float()`` returns the same tensor when
      the input is already ``float32``, so in that case the caller's tensor
      is mutated. Inputs of any other dtype are copied by ``.float()`` and
      left untouched.

  Args:
      mean (sequence): Sequence of means for each channel.
      std (sequence): Sequence of standard deviations for each channel.
      inplace (bool, optional): Normalize the float tensor in place.
          Defaults to False.
  """

  def __init__(self, mean: tuple, std: tuple, inplace=False):
    self.mean = mean
    self.std = std
    self.inplace = inplace

  def __call__(self, sample):
    image, label = sample['image'], sample['label']

    # Forward ``inplace`` so the constructor flag actually takes effect.
    norm = transforms.Normalize(self.mean, self.std, inplace=self.inplace)

    image = norm(image.float())

    return {"image": image, "label": label}

In [None]:
class ToTensor(object):
  """Convert the image and label of a sample to tensors.

  Both entries go through ``torchvision.transforms.ToTensor``. That transform
  accepts a PIL Image or a 2-D/3-D ``numpy.ndarray`` (HWC) and returns a CHW
  tensor; uint8 input is scaled to ``[0, 1]``.
  """
  def __call__(self, sample):
    image, label = sample["image"], sample["label"]
    # ToTensor is stateless, so one instance serves both entries.
    to_tensor = transforms.ToTensor()
    # NOTE(review): ToTensor raises on plain Python numbers and 0-D/1-D
    # arrays, so this only works if ``label`` is image-like (e.g. a mask).
    # Confirm against the dataset; a class-index label would need
    # ``torch.as_tensor`` instead.
    return {"image": to_tensor(image), "label": to_tensor(label)}