In [36]:
%matplotlib inline
import matplotlib.image as mpimg
from torch import nn
from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms
from torchinfo import summary
import torch
import numpy as np
import os, sys
from PIL import Image

# Loading dataset

In [23]:
# Loading dataset
def load_image(infilename):
    data = mpimg.imread(infilename)
    return data
def img_float_to_uint8(img):
    rimg = img - np.min(img)
    rimg = (rimg / np.max(rimg) * 255).round().astype(np.uint8)
    return rimg
def concatenate_images(img, gt_img):
    nChannels = len(gt_img.shape)
    w = gt_img.shape[0]
    h = gt_img.shape[1]
    if nChannels == 3:
        cimg = np.concatenate((img, gt_img), axis=1)
    else:
        gt_img_3c = np.zeros((w, h, 3), dtype=np.uint8)
        gt_img8 = img_float_to_uint8(gt_img)
        gt_img_3c[:, :, 0] = gt_img8
        gt_img_3c[:, :, 1] = gt_img8
        gt_img_3c[:, :, 2] = gt_img8
        img8 = img_float_to_uint8(img)
        cimg = np.concatenate((img8, gt_img_3c), axis=1)
    return cimg
def img_crop(im, w, h):
    list_patches = []
    imgwidth = im.shape[0]
    imgheight = im.shape[1]
    is_2d = len(im.shape) < 3
    for i in range(0, imgheight, h):
        for j in range(0, imgwidth, w):
            if is_2d:
                im_patch = im[j : j + w, i : i + h]
            else:
                im_patch = im[j : j + w, i : i + h, :]
            list_patches.append(im_patch)
    return list_patches

root_dir = "training/"

image_dir = root_dir + "images/"
files = os.listdir(image_dir)
n = min(20, len(files))  # Load maximum 20 images
print("Loading " + str(n) + " images")
imgs = [load_image(image_dir + files[i]) for i in range(n)]
print(files[0])

gt_dir = root_dir + "groundtruth/"
print("Loading " + str(n) + " images")
gt_imgs = [load_image(gt_dir + files[i]) for i in range(n)]
print(files[0])

Loading 20 images
satImage_001.png
Loading 20 images
satImage_001.png


(400, 400, 3)

# Patching

In [22]:
patch_size = 16

img_patches = [img_crop(imgs[i], patch_size, patch_size) for i in range(n)]
gt_patches = [img_crop(gt_imgs[i], patch_size, patch_size) for i in range(n)]

# Linearize list of patches
img_patches = np.asarray(
    [
        img_patches[i][j]
        for i in range(len(img_patches))
        for j in range(len(img_patches[i]))
    ]
)
gt_patches = np.asarray(
    [
        gt_patches[i][j]
        for i in range(len(gt_patches))
        for j in range(len(gt_patches[i]))
    ]
)

img_patches.shape

(12500, 16, 16, 3)

# Trying a simple FCN

In [18]:
class FCN(nn.Module):
    def __init__(self):
        super(FCN, self).__init__()
        # Input shape: (3, 16, 16)
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 32, 5, 1, 2)
        )

        self.unconv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 5, 1, 2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 11, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.unconv(x)
        return x


In [19]:
summary(FCN(), input_size=(3, 16, 16))

Layer (type:depth-idx)                   Output Shape              Param #
FCN                                      [1, 16, 16]               --
├─Sequential: 1-1                        [32, 8, 8]                --
│    └─Conv2d: 2-1                       [16, 16, 16]              448
│    └─ReLU: 2-2                         [16, 16, 16]              --
│    └─MaxPool2d: 2-3                    [16, 8, 8]                --
│    └─Conv2d: 2-4                       [32, 8, 8]                12,832
├─Sequential: 1-2                        [1, 16, 16]               --
│    └─ConvTranspose2d: 2-5              [16, 8, 8]                12,816
│    └─ReLU: 2-6                         [16, 8, 8]                --
│    └─ConvTranspose2d: 2-7              [1, 16, 16]               1,937
│    └─Sigmoid: 2-8                      [1, 16, 16]               --
Total params: 28,033
Trainable params: 28,033
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 5.07
Input size (MB): 0.00
Forward/bac

In [61]:
n_img_patches = np.asarray(img_patches)

In [73]:
def train_epoch(model, device, train_loader, optimizer, epoch, criterion):
    model.train()

    loss_history = []
    acc_history = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        output = output > 0.5
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())
        acc_history.append(output.eq(target).sum() / target.numel())

    return acc_history, loss_history

def get_mean_std(imgs):
    """
    Normalizes images with mean and standard deviation, by channel
    """
    mean = imgs.mean(axis=(0, 2, 3))
    std = imgs.std(axis=(0, 2, 3))
    return mean, std

def train(device):
    # creating the dataloader
    images = n_img_patches
    groundtruth = Tensor(gt_patches)

    mean, std = get_mean_std(images)
    print(mean, std)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[mean[0], mean[1], mean[2]], std=[std[0], std[1], std[2]])
    ])
    images = torch.stack(([transform(img) for img in images]))
    
    pytorchDl = DataLoader(TensorDataset(images,groundtruth),batch_size = 32,shuffle=True)
    model = FCN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.functional.binary_cross_entropy

    for i in range(10):
        acc_history, loss_history = train_epoch(model, device, pytorchDl, optimizer, i, criterion)
        print("Epoch: ", i, " Accuracy: ", sum(acc_history)/len(acc_history), " Loss: ", sum(loss_history)/len(loss_history))

In [75]:
train(torch.device("cuda"))

[0.30062687 0.29847103 0.2976581  0.29828402 0.29836816 0.29749525
 0.2979489  0.30031368 0.3005086  0.2996019  0.29961902 0.30030555
 0.3008969  0.30131173 0.30097467 0.30262402] [0.17090596 0.17101276 0.17084403 0.1707926  0.17001918 0.1703739
 0.16964005 0.16982245 0.1687028  0.16795427 0.16863145 0.16967691
 0.17056271 0.171323   0.17065597 0.1712765 ]


AssertionError: Size mismatch between tensors