In [2]:
import torch
import torch.nn as nn

In [None]:
def crop_image(tensor, target):
  """Center-crop ``tensor`` to the spatial size of ``target``.

  Height (dim 2) and width (dim 3) are handled independently, and the slice
  is taken as ``start:start + target_size`` so the result always matches the
  target exactly, even when the size difference is odd (the extra row/column
  is dropped from the bottom/right).

  Args:
    tensor: 4-D tensor indexed as (dim0, dim1, height, width) to be cropped.
    target: 4-D tensor whose dims 2 and 3 give the desired output size.

  Returns:
    A view of ``tensor`` with dims 2 and 3 equal to those of ``target``.

  Raises:
    ValueError: if ``target`` is larger than ``tensor`` along dim 2 or 3.
  """
  tensor_h, tensor_w = tensor.size()[2:4]
  target_h, target_w = target.size()[2:4]

  if target_h > tensor_h or target_w > tensor_w:
    raise ValueError(
        f"cannot crop a {tensor_h}x{tensor_w} tensor "
        f"to the larger size {target_h}x{target_w}"
    )

  # Floor division keeps the crop window centered; an odd surplus leaves the
  # spare row/column on the bottom/right side.
  top = (tensor_h - target_h) // 2
  left = (tensor_w - target_w) // 2
  return tensor[:, :, top:top + target_h, left:left + target_w]

def double_conv(in_ch, out_ch):
  """Build a block of two 3x3 convolutions, each followed by an in-place ReLU.

  The first convolution maps ``in_ch`` to ``out_ch`` channels and the second
  keeps ``out_ch`` channels. No padding argument is passed to ``nn.Conv2d``.

  Args:
    in_ch: number of channels of the block's input.
    out_ch: number of channels produced by both convolutions.

  Returns:
    An ``nn.Sequential`` laid out as Conv2d, ReLU, Conv2d, ReLU.
  """
  layers = []
  # The first pass consumes the block input; the second consumes the output
  # of the first convolution.
  for channels_in in (in_ch, out_ch):
    layers.append(nn.Conv2d(channels_in, out_ch, kernel_size=3))
    layers.append(nn.ReLU(inplace=True))
  return nn.Sequential(*layers)

class Unet(nn.Module):
  """U-Net style encoder/decoder built from ``double_conv`` blocks.

  The contracting path applies four ``double_conv`` + 2x2 max-pool stages
  (64, 128, 256, 512 channels) followed by a 1024-channel bottleneck. The
  expanding path upsamples with stride-2 transposed convolutions, concatenates
  the center-cropped encoder feature map along the channel axis, and applies
  another ``double_conv``. A final 1x1 convolution maps to ``out_channels``.

  Because the 3x3 convolutions are created without padding, the output is
  spatially smaller than the input; e.g. a 572x572 input produces a 388x388
  output.

  Args:
    in_channels: number of channels of the input image (default 1).
    out_channels: number of channels of the output map (default 2).
  """

  def __init__(self, in_channels=1, out_channels=2):
    super().__init__()

    # A single pooling layer is reused between all encoder stages.
    self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

    # Contracting path.
    self.double_conv_1 = double_conv(in_channels, 64)
    self.double_conv_2 = double_conv(64, 128)
    self.double_conv_3 = double_conv(128, 256)
    self.double_conv_4 = double_conv(256, 512)
    self.double_conv_5 = double_conv(512, 1024)

    # Expanding path: each up_conv takes twice the up_trans output channels
    # because the cropped skip connection is concatenated before it.
    self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
    self.up_conv_1 = double_conv(1024, 512)

    self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
    self.up_conv_2 = double_conv(512, 256)

    self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
    self.up_conv_3 = double_conv(256, 128)

    self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
    self.up_conv_4 = double_conv(128, 64)

    self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

  def forward(self, image):
    """Run the network on a batch of images.

    Args:
      image: tensor of shape (batch, in_channels, height, width).

    Returns:
      Tensor with ``out_channels`` channels; no activation is applied after
      the final 1x1 convolution.
    """
    # Encoder: keep the pre-pooling feature maps for the skip connections.
    skip_1 = self.double_conv_1(image)
    skip_2 = self.double_conv_2(self.max_pool_2x2(skip_1))
    skip_3 = self.double_conv_3(self.max_pool_2x2(skip_2))
    skip_4 = self.double_conv_4(self.max_pool_2x2(skip_3))
    x = self.double_conv_5(self.max_pool_2x2(skip_4))

    # Decoder: upsample, crop the matching encoder map to the upsampled
    # size, concatenate on the channel axis (dim 1), then convolve.
    x = self.up_trans_1(x)
    x = self.up_conv_1(torch.cat([x, crop_image(skip_4, x)], 1))

    x = self.up_trans_2(x)
    x = self.up_conv_2(torch.cat([x, crop_image(skip_3, x)], 1))

    x = self.up_trans_3(x)
    x = self.up_conv_3(torch.cat([x, crop_image(skip_2, x)], 1))

    x = self.up_trans_4(x)
    x = self.up_conv_4(torch.cat([x, crop_image(skip_1, x)], 1))

    return self.final_conv(x)



In [None]:
import os
from PIL import Image
import numpy as np

class Watermark_datasets(torch.utils.data.Dataset):
  """Map-style dataset of (image, mask) pairs stored in two directories.

  Every file name found in ``image_dir`` is one sample. The mask file name is
  derived from the image file name by replacing ``image_suffix`` with
  ``mask_suffix`` and is looked up in ``mask_dir``.

  Args:
    image_dir: directory containing the input images.
    mask_dir: directory containing the corresponding masks.
    transform: optional callable invoked as ``transform(image=..., mask=...)``
      that returns a mapping with ``'image'`` and ``'mask'`` entries
      (presumably an albumentations-style pipeline; verify against caller).
    image_suffix: substring of the image file name to replace.
    mask_suffix: replacement used to build the mask file name.
  """

  # NOTE(review): the default mask suffix "_make.gif" is kept from the
  # original code; it may have been meant as "_mask.gif" - confirm against
  # the actual files in mask_dir.
  def __init__(self, image_dir, mask_dir, transform=None, *,
               image_suffix=".jpeg", mask_suffix="_make.gif"):
    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.transform = transform
    self.image_suffix = image_suffix
    self.mask_suffix = mask_suffix
    # os.listdir returns entries in arbitrary order; sorting makes the
    # index -> sample mapping reproducible across runs and machines.
    self.images = sorted(os.listdir(image_dir))

  def __len__(self):
    """Return the number of files listed in ``image_dir``."""
    return len(self.images)

  def __getitem__(self, index):
    """Load the image/mask pair at ``index``.

    Returns:
      ``(image, mask)``. Without a transform, ``image`` is the RGB-converted
      picture as a NumPy array and ``mask`` is the grayscale-converted mask
      as a NumPy array in which every 255 value has been replaced by 1. With
      a transform, both values are whatever the transform returns.
    """
    image_name = self.images[index]
    img_path = os.path.join(self.image_dir, image_name)
    mask_name = image_name.replace(self.image_suffix, self.mask_suffix)
    mask_path = os.path.join(self.mask_dir, mask_name)

    # Context managers close the underlying files promptly instead of
    # waiting for garbage collection of the PIL image objects.
    with Image.open(img_path) as img:
      image = np.array(img.convert("RGB"))
    with Image.open(mask_path) as mask_img:
      mask = np.array(mask_img.convert("L"))
    # Turn the 0/255 mask into 0/1 labels.
    mask[mask == 255] = 1

    if self.transform is not None:
      augmentations = self.transform(image=image, mask=mask)
      image = augmentations['image']
      mask = augmentations['mask']

    return image, mask