<a href="https://colab.research.google.com/github/HardcoreBudget/PyTorch-Image-Patchify/blob/main/PyTorch_Image_Patchify.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [24]:
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
from pathlib import Path

# Patchify

## Custom Size (fixed-size crops with optional overlap)

In [25]:
def _crop_grid_1d(length, crop, overlap):
    """Return (count, last_size) for crops of `crop` px at stride `crop - overlap` along one axis of `length` px."""
    stride = crop - overlap
    if length <= crop:
        # A single crop covers the whole axis; it is shortened to the axis length.
        return 1, length
    # Smallest count such that crop + (count - 1) * stride >= length (ceil division).
    count = 1 + -(-(length - crop) // stride)
    # The last crop starts at (count - 1) * stride and is clipped at the border.
    return count, length - (count - 1) * stride


def custom_patchify(frame_in, crops_size, overlap_size=(0, 0)):
    """Split an image tensor into a grid of (possibly overlapping) crops.

    Crops are placed on a regular grid with stride ``crop - overlap`` along each
    axis. Every crop has size ``crops_size`` except those in the last row / last
    column, which are shortened so the grid ends exactly at the image border
    (no padding is introduced).

    Args:
        frame_in: Tensor whose last two dimensions are (H, W), e.g. (C, H, W)
            or (N, C, H, W).
        crops_size: ``[width, height]`` of each crop (width first).
        overlap_size: ``[width, height]`` overlap between neighbouring crops.
            Each overlap must be non-negative and at most half of the matching
            crop size.

    Returns:
        ``(patch_list, crops_per_row, crops_per_col)``:

        * ``patch_list`` is a list of crop rows (top to bottom), each a list of
          crops (left to right).
        * ``crops_per_row`` is the number of crop rows (vertical count).
        * ``crops_per_col`` is the number of crops in each row (horizontal count).

    Raises:
        AssertionError: if a crop size is smaller than twice its overlap.
        ValueError: if a crop size is not positive or an overlap is negative
            (previously a zero crop size made the layout loop never terminate).
    """
    crops_size_W, crops_size_H = crops_size[0], crops_size[1]
    overlap_size_W, overlap_size_H = overlap_size[0], overlap_size[1]
    frame_in_H, frame_in_W = frame_in.shape[-2], frame_in.shape[-1]
    assert (crops_size_H >= 2 * overlap_size_H) and (crops_size_W >= 2 * overlap_size_W), "Crops size should be at least 2x greater than overlap size"
    if min(crops_size_H, crops_size_W) <= 0 or min(overlap_size_H, overlap_size_W) < 0:
        raise ValueError("Crops size must be positive and overlap size non-negative")

    crops_per_row, last_height = _crop_grid_1d(frame_in_H, crops_size_H, overlap_size_H)
    crops_per_col, last_width = _crop_grid_1d(frame_in_W, crops_size_W, overlap_size_W)
    stride_H = crops_size_H - overlap_size_H
    stride_W = crops_size_W - overlap_size_W

    patch_list = []
    for i in range(crops_per_row):
        top = i * stride_H
        height = last_height if i == crops_per_row - 1 else crops_size_H
        row = []
        for j in range(crops_per_col):
            left = j * stride_W
            width = last_width if j == crops_per_col - 1 else crops_size_W
            # Plain slicing equals TF.crop for windows that lie inside the image,
            # which is always the case here.
            row.append(frame_in[..., top:top + height, left:left + width])
        patch_list.append(row)

    return patch_list, crops_per_row, crops_per_col

def custom_unpatchify(patch_list, crops_per_row, crops_per_col, overlap_size=(0, 0)):
    """Reassemble crops produced by ``custom_patchify`` into a single tensor.

    Each crop contributes its non-overlapping interior unchanged. Every band
    where two neighbouring crops overlap is replaced by the element-wise mean
    of the two crops. Rows are stitched horizontally first, then the rows are
    stitched vertically.

    Args:
        patch_list: List of crop rows (top to bottom), each a list of crops
            (left to right). Crops are tensors whose last two dims are (H, W).
            Any leading dims are allowed.
        crops_per_row: Number of crop rows.
        crops_per_col: Number of crops in each row.
        overlap_size: ``[width, height]`` overlap used when patchifying.

    Returns:
        The stitched tensor, concatenated along the last two dimensions.

    Note:
        Assumes every crop except those in the last row/column has the same
        size as ``patch_list[0][0]``, as ``custom_patchify`` produces.
    """
    overlap_size_W, overlap_size_H = overlap_size[0], overlap_size[1]
    # Offset where the trailing overlap band of a full-size crop begins.
    end_W = patch_list[0][0].shape[-1] - overlap_size_W
    end_H = patch_list[0][0].shape[-2] - overlap_size_H

    # 1) Stitch the crops of each row along the width axis.
    rows = []
    for i in range(crops_per_row):
        pieces = []
        for j in range(crops_per_col):
            patch = patch_list[i][j]
            # Skip the leading overlap (it is emitted as an averaged band by the
            # previous crop) and stop before the trailing overlap, except at the
            # image borders.
            start = 0 if j == 0 else overlap_size_W
            stop = None if j == crops_per_col - 1 else end_W
            pieces.append(patch[..., start:stop])
            if j + 1 < crops_per_col and overlap_size_W > 0:
                nxt = patch_list[i][j + 1]
                pieces.append((patch[..., end_W:] + nxt[..., :overlap_size_W]) / 2)
        rows.append(torch.cat(pieces, -1))

    # 2) Stitch the rows along the height axis with the same scheme. The first
    #    row keeps its top edge even when it is also the last row (the original
    #    code cut off the top `overlap_H` rows in that single-row case).
    pieces = []
    for i, row in enumerate(rows):
        start = 0 if i == 0 else overlap_size_H
        stop = None if i == crops_per_row - 1 else end_H
        pieces.append(row[..., start:stop, :])
        if i + 1 < crops_per_row and overlap_size_H > 0:
            pieces.append((row[..., end_H:, :] + rows[i + 1][..., :overlap_size_H, :]) / 2)

    frame_out = torch.cat(pieces, -2)
    return frame_out

## Grid (two-column grid with a given number of crops)

In [26]:
def grid_patchify(frame_in, crops_amount):
    """Split an image tensor into a two-column grid of crops.

    The image is cut into ``int(crops_amount / 2)`` horizontal bands of equal
    height; the last band absorbs any leftover rows. Each band is split into a
    left half of width ``W // 2`` and a right half holding the remaining
    columns. For odd widths, the right half is therefore one column wider.
    Previously that extra column was silently dropped.

    Args:
        frame_in: Tensor whose last two dimensions are (H, W).
        crops_amount: Total number of crops; an odd value is rounded down.

    Returns:
        ``(patch_list_top, patch_list_bottom)``. Despite the names, these are
        the LEFT-column and RIGHT-column crops respectively, each ordered top
        to bottom.

    Raises:
        ValueError: if ``crops_amount`` yields fewer than one band (< 2).
    """
    crops_per_row = int(crops_amount / 2)
    if crops_per_row < 1:
        raise ValueError("crops_amount must be at least 2")

    frame_H, frame_W = frame_in.shape[-2], frame_in.shape[-1]
    band_height = frame_H // crops_per_row
    left_width = frame_W // 2

    patch_list_top = []     # left column
    patch_list_bottom = []  # right column
    for i in range(crops_per_row):
        top = i * band_height
        # The last band extends to the bottom edge to cover H % crops_per_row.
        height = band_height if i < crops_per_row - 1 else frame_H - top
        band = frame_in[..., top:top + height, :]
        patch_list_top.append(band[..., :left_width])
        patch_list_bottom.append(band[..., left_width:])

    return patch_list_top, patch_list_bottom

def grid_unpatchify(patch_list_T, patch_list_B):
    """Reassemble the two crop columns produced by ``grid_patchify``.

    Each column's crops are stacked top-to-bottom (dim -2). The two resulting
    columns are then placed side by side (dim -1): ``patch_list_T`` on the
    left and ``patch_list_B`` on the right.
    """
    columns = [torch.cat(tuple(column), -2) for column in (patch_list_T, patch_list_B)]
    return torch.cat(tuple(columns), -1)

# Testing

In [27]:
image_path = "/content/test.png"

transform_t = transforms.ToTensor()
transform_i = transforms.ToPILImage()

# Open inside a context manager so the file handle is released as soon as
# convert() has produced an independent in-memory RGB copy.
with Image.open(Path(image_path)) as src_img:
    img = src_img.convert("RGB")
img_t = transform_t(img).unsqueeze(0)  # add a leading batch dimension

print(img_t.shape)

## Custom Patchify

In [None]:
crop_size = [384, 384]    # [width, height]
overlap_size = [0, 0]     # [width, height]

patch_list, crops_per_row, crops_per_col = custom_patchify(img_t, crop_size, overlap_size)
frameout = custom_unpatchify(patch_list, crops_per_row, crops_per_col, overlap_size)

print(frameout.shape)
# Patchify -> unpatchify should be lossless: overlap bands are means of identical pixels.
print("Lossless round trip:", torch.equal(frameout, img_t))
transform_i(frameout.squeeze(0))

## Grid Patchify

In [None]:
number_of_crops = 4

patch_list_T, patch_list_B = grid_patchify(img_t, number_of_crops)
frameout = grid_unpatchify(patch_list_T, patch_list_B)

print(frameout.shape)
# The grid split has no overlap, so reassembly should reproduce the input exactly.
print("Lossless round trip:", torch.equal(frameout, img_t))
transform_i(frameout.squeeze(0))