In [3]:
import json
import glob
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 
import torchvision.transforms as transforms
from PIL import Image, ImageDraw

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ", device)

Device:  cuda


In [75]:
def draw_polygon_on_image(points, image, value):
  """Fill one polygon on ``image`` with a uniform grey level.

  Args:
    points: Iterable of 2-element ``(x, y)`` vertices (lists are accepted and
      converted to tuples).
    image: PIL image that is drawn on directly.
    value: Intensity written to all three RGB channels of the filled area.

  Returns:
    The same ``image`` object that was passed in, now containing the polygon.
  """
  vertices = []
  for px, py in points:
    vertices.append((px, py))
  canvas = ImageDraw.Draw(image, 'RGB')
  grey = (value,) * 3
  canvas.polygon(vertices, fill=grey)
  return image

def generate_ground_truth(ground_truth, image):
  """Rasterise the polygons of one annotation into a new mask image.

  Args:
    ground_truth: Parsed label JSON. Must contain a ``'shapes'`` list whose items
      carry ``'points'`` and ``'group_id'`` (presumably labelme format — TODO confirm).
      NOTE(review): labelme writes ``null`` for ungrouped shapes; that case is not
      handled here.
    image: Source image; only its ``size`` is used.

  Returns:
    A new black RGB image of the same size, with every shape filled using its
    ``group_id`` as grey level. An annotation without shapes yields an all-black
    mask (previously the *source image* was returned in that case).
  """
  mask = Image.new('RGB', image.size)
  for shape in ground_truth['shapes']:
    # Keep drawing onto the mask itself; never rebind the caller's source image.
    mask = draw_polygon_on_image(shape['points'], mask, shape['group_id'])
  return mask

def _find_label_image_pairs(data_dir):
  """Return sorted ``(stem, label_path, image_path)`` triples for labels whose PNG exists."""
  pairs = []
  for label_name in sorted(glob.glob(os.path.join(data_dir, 'labels', '*.json'))):
    stem = os.path.splitext(os.path.basename(label_name))[0]
    image_name = os.path.join(data_dir, 'images', stem + '.png')
    if os.path.exists(image_name):
      pairs.append((stem, label_name, image_name))
  return pairs

def process_images(data_dir=r'C:\Users\topra\Documents\Jupyter\Satellite-Model\ViT\data'):
  """Render a ground-truth mask PNG for every labelled image under ``data_dir``.

  Reads ``data_dir/labels/<stem>.json`` together with ``data_dir/images/<stem>.png``
  and writes the mask to ``data_dir/truth/<stem>.png``. Labels whose image is missing
  are skipped instead of aborting the whole run.

  Args:
    data_dir: Root folder holding the ``labels``/``images``/``truth`` sub-folders.
      Defaults to the original hard-coded location.
  """
  pairs = _find_label_image_pairs(data_dir)
  if not pairs:
    return
  truth_dir = os.path.join(data_dir, 'truth')
  # Image.save does not create missing parent directories.
  os.makedirs(truth_dir, exist_ok=True)
  for stem, label_name, image_name in pairs:
    # JSON text is UTF-8 by specification; do not rely on the platform default.
    with open(label_name, 'r', encoding='utf-8') as file:
      ground_truth = json.load(file)
    # Save while the source is still open: the mask may share data with it, and the
    # context manager then releases the file handle PIL keeps for lazy loading.
    with Image.open(image_name) as image:
      ground_truth_image = generate_ground_truth(ground_truth, image)
      ground_truth_image.save(os.path.join(truth_dir, stem + '.png'))

process_images()

In [4]:
class SatelliteDataset(Dataset):
  """Satellite images paired with their rendered ground-truth masks by file name.

  An image ``<root>/<image_folder>/<name>.png`` is kept only when a file with the
  same name exists in ``<root>/<truth_folder>``. Pairs are ordered by image path so
  indices are deterministic across runs and platforms.

  Args:
    image_folder: Sub-folder of ``root`` holding the input PNGs.
    truth_folder: Sub-folder of ``root`` holding the mask PNGs.
    transform_data: Optional callable applied to each input image.
    transform_truth: Optional callable applied to each mask.
    root: Data directory; defaults to the original hard-coded location.
  """
  def __init__(self, image_folder, truth_folder, transform_data=None, transform_truth=None,
               root=r'C:\Users\topra\Documents\Jupyter\Satellite-Model\ViT\data'):
    image_paths = sorted(glob.glob(os.path.join(root, image_folder, '*.png')))
    # Index masks by file name: O(1) lookups, and no string replacement on the full
    # path (the old 'images' -> 'truth' replace ignored the folder arguments).
    truth_by_name = {
      os.path.basename(path): path
      for path in glob.glob(os.path.join(root, truth_folder, '*.png'))
    }
    self.image_names = []
    self.truth_images = []
    for image_name in image_paths:
      truth_name = truth_by_name.get(os.path.basename(image_name))
      if truth_name is not None:
        self.image_names.append(image_name)
        self.truth_images.append(truth_name)
    self.transform_data = transform_data
    self.transform_truth = transform_truth

  def __len__(self):
    """Number of matched image/mask pairs."""
    return len(self.image_names)

  def __getitem__(self, idx):
    """Return ``(image, truth)`` for ``idx``, each passed through its transform if one is set."""
    image = Image.open(self.image_names[idx])
    truth = Image.open(self.truth_images[idx])
    if self.transform_data is not None:
      image = self.transform_data(image)
    if self.transform_truth is not None:
      truth = self.transform_truth(truth)
    return image, truth

# Input images: PIL image -> float tensor (C, H, W) with values scaled by ToTensor.
data_steps = [transforms.ToTensor()]
transform_data = transforms.Compose(data_steps)

# Masks are written as RGB with three identical channels, so after tensor
# conversion they are collapsed to a single channel.
truth_steps = [transforms.ToTensor(), transforms.Grayscale()]
transform_truth = transforms.Compose(truth_steps)

dataset = SatelliteDataset(
  'images',
  'truth',
  transform_data=transform_data,
  transform_truth=transform_truth,
)

In [5]:
class PatchPositionalEmbedding(nn.Module):
    """Cut images into non-overlapping patches, append a zero token, add learned positions.

    Args:
        channels: Number of input image channels.
        patch_size: Side length of each square patch (also the unfold stride).
        max_len: Number of learned positional embeddings; must be at least
            ``num_tokens + 1`` for every input this module receives.
    """
    def __init__(self,
                 channels: int=3,
                 patch_size: int=128,
                 max_len: int=512):
        super(PatchPositionalEmbedding, self).__init__()
        token_dim = channels*patch_size*patch_size

        # Layers
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size, padding=0)  # -> (batch, token_dim, num_tokens)
        self.embed = nn.Parameter(torch.randn(1, max_len, token_dim))                  # (1, max_len, token_dim)
        # Extra all-zero token appended after the patches. As a buffer it follows
        # .to(device)/.half() with the module; persistent=False keeps state_dict keys
        # unchanged from before.
        self.register_buffer('pred', torch.zeros(1, 1, token_dim), persistent=False)  # (1, 1, token_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images as a token sequence.

        Args:
            x (torch.Tensor): (batch, C, H, W)

        Returns:
            torch.Tensor: (batch, num_tokens + 1, C*patch_size*patch_size), where
            num_tokens = (H // patch_size) * (W // patch_size); the last token is the
            appended zero token.

        Raises:
            ValueError: if num_tokens + 1 exceeds ``max_len``.
        """
        patches = self.unfold(x).transpose(1, 2)          # (batch, num_tokens, token_dim)
        extra = self.pred.expand(patches.size(0), -1, -1)  # one zero token per sample
        tokens = torch.cat([patches, extra], dim=1)
        n_tokens = tokens.size(1)
        max_len = self.embed.size(1)
        if n_tokens > max_len:
            raise ValueError(
                f"{n_tokens} tokens (including the appended token) exceed max_len={max_len}")
        return tokens + self.embed[:, :n_tokens]


In [6]:
class ScaledDotProductAttention(nn.Module):
    """Compute softmax(q @ k^T / sqrt(d_k)) @ v over the last two dimensions.

    ``channels`` and ``patch_size`` are accepted for signature compatibility with the
    sibling modules but are not used.
    """
    def __init__(self,
                channels: int=3,
                patch_size: int=128):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self,
                v: torch.Tensor,
                k: torch.Tensor,
                q: torch.Tensor) -> torch.Tensor:
        """Attend ``q`` over ``k``/``v`` (note the argument order is v, k, q).

        Args:
            v: (..., len_k, d_v)
            k: (..., len_k, d_k)
            q: (..., len_q, d_k)

        Returns:
            (..., len_q, d_v)
        """
        # torch.dot only accepts 1-D vectors; batched scores need q @ k^T.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
        weights = self.softmax(scores)
        return torch.matmul(weights, v)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with one full-width projection set per head.

    Each of the ``h`` heads projects the input to ``d_model = channels*patch_size**2``
    features for values, keys and queries and runs scaled dot-product attention. The
    head outputs are concatenated (``h * d_model`` wide) and mapped back to ``d_model``.

    NOTE(review): with the defaults d_model = 49152, so every d_model x d_model
    projection holds about 2.4e9 weights; this is unlikely to fit in memory —
    consider projecting to a smaller width first.
    """
    def __init__(self,
                 channels: int=3,
                 patch_size: int=128,
                 h: int=8):
        super(MultiHeadAttention, self).__init__()
        d_model = channels*patch_size*patch_size
        # forward() needs the head count; the constructor argument is not in scope there.
        self.h = h
        # nn.ModuleList (not a plain list) registers the sub-layers so their parameters
        # are returned by .parameters(), trained, and moved by .to(device).
        # Each head maps d_model -> d_model so the concatenation is h*d_model wide,
        # which is exactly linear_out's input size.
        self.linear_v = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(h)])
        self.linear_k = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(h)])
        self.linear_q = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(h)])
        self.scaled_dot = nn.ModuleList(
            [ScaledDotProductAttention(channels=channels, patch_size=patch_size) for _ in range(h)])
        self.linear_out = nn.Linear(d_model*h, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all heads to ``x`` and fuse them.

        Args:
            x (torch.Tensor): (batch, num_tokens, d_model) — as used by TransformEncoder.

        Returns:
            torch.Tensor: (batch, num_tokens, d_model)
        """
        heads = []
        for i in range(self.h):
            v = self.linear_v[i](x)
            k = self.linear_k[i](x)
            q = self.linear_q[i](x)
            heads.append(self.scaled_dot[i](v, k, q))
        return self.linear_out(torch.cat(heads, dim=-1))

class TransformEncoder(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        channels: Image channels; with ``patch_size`` it fixes the token width
            ``d_model = channels*patch_size*patch_size``.
        patch_size: Patch side length.
        max_len: Accepted for signature compatibility; not used here.

    NOTE(review): with the defaults d_model = 49152, so each Linear in the MLP holds
    about 2.4e9 weights.
    """
    def __init__(self,
                 channels: int=3,
                 patch_size: int=128,
                 max_len: int=512):
        super(TransformEncoder, self).__init__()
        d_model = channels*patch_size*patch_size
        self.norm_1 = nn.LayerNorm(d_model)
        # Forward the geometry so the attention width equals d_model for non-default
        # channels/patch_size (it was previously always built with the defaults).
        self.attention = MultiHeadAttention(channels=channels, patch_size=patch_size)
        self.norm_2 = nn.LayerNorm(d_model)
        # Four (Linear, Sigmoid) pairs followed by a final Linear, all d_model -> d_model.
        # Module indices (and therefore state_dict keys) match the original layout.
        mlp_layers = []
        for _ in range(4):
            mlp_layers += [nn.Linear(d_model, d_model), nn.Sigmoid()]
        mlp_layers.append(nn.Linear(d_model, d_model))
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run one encoder block.

        Args:
            x (torch.Tensor): (batch, num_tokens, C*patch_size*patch_size)

        Returns:
            torch.Tensor: (batch, num_tokens, C*patch_size*patch_size)
        """
        x = x + self.attention(self.norm_1(x))
        x = x + self.mlp(self.norm_2(x))
        return x

class VisionTransformer(nn.Module):
    """Patch/positional embedding followed by a stack of ``num_layers`` encoder blocks.

    The output is the full token sequence produced by the last block; there is no
    classification or segmentation head.
    """
    def __init__(self,
                 channels: int=3,
                 patch_size: int=128,
                 max_len: int=512,
                 num_layers: int=12):
        super(VisionTransformer, self).__init__()
        geometry = dict(channels=channels, patch_size=patch_size, max_len=max_len)
        self.patch_embed = PatchPositionalEmbedding(**geometry)
        encoder_blocks = [TransformEncoder(**geometry) for _ in range(num_layers)]
        self.layers = nn.Sequential(*encoder_blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.patch_embed(x)
        return self.layers(tokens)

In [14]:
import numpy as np

image, truth = dataset[0]  # first image and its ground-truth mask
# Enlarge both to 2048x2048. The mask uses nearest-neighbour resampling so its label
# values are copied, not blended into in-between values at region borders.
image = transforms.Resize((2048, 2048))(image)
truth = transforms.Resize((2048, 2048), interpolation=transforms.InterpolationMode.NEAREST)(truth)

patch_embedding = PatchPositionalEmbedding(channels=3, patch_size=128, max_len=1024)
x = patch_embedding(image.unsqueeze(0))
print(f"{image.shape} -> {x.shape}")

torch.Size([3, 2048, 2048]) -> torch.Size([1, 257, 49152])


In [2]:
import torch.nn as nn

EMBEDDING_SIZE = 2048

class Model(nn.Module):
  """Patch embedding -> linear projection -> stock PyTorch transformer encoder.

  Expects images shaped (batch, 3, H, W) — TODO confirm against callers. Requires
  ``PatchPositionalEmbedding`` to be defined before this cell runs.
  """
  def __init__(self):
    super(Model, self).__init__()
    channels, patch_size = 3, 128
    self.embed = PatchPositionalEmbedding(channels=channels, patch_size=patch_size, max_len=1024)
    # Raw patch tokens are channels*patch_size**2 (= 49152) wide; the encoder needs
    # d_model = EMBEDDING_SIZE, so project them first.
    self.project = nn.Linear(channels * patch_size * patch_size, EMBEDDING_SIZE)
    # batch_first=True because the embedding emits (batch, tokens, features).
    self.transformer = nn.TransformerEncoder(
      nn.TransformerEncoderLayer(d_model=EMBEDDING_SIZE, nhead=8, batch_first=True),
      num_layers=6)

  def forward(self, x):
    """Return encoder outputs of shape (batch, num_tokens + 1, EMBEDDING_SIZE)."""
    tokens = self.embed(x)
    return self.transformer(self.project(tokens))