<a href="https://colab.research.google.com/github/SrivastavaHarsit/Satellite-Image-Segmentation/blob/main/unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn.modules.activation import Sigmoid
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

import random
from random import shuffle
import math

import cv2 as cv
from google.colab.patches import cv2_imshow

In [2]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [14]:
! pip install -q kaggle
from google.colab import files

files.upload()

! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

!kaggle datasets download bulentsiyah/semantic-drone-dataset

401 - Unauthorized


In [None]:
!unzip /content/semantic-drone-dataset.zip

In [None]:
IMAGE_GLOB = '/content/dataset/semantic_drone_dataset/original_images/*'
LABEL_GLOB = '/content/dataset/semantic_drone_dataset/label_images_semantic/*'


def _load_gray_stack(pattern, limit=51, size=(512, 512)):
  """Load up to `limit` images matching `pattern` as grayscale arrays in [0, 1].

  Paths are sorted before loading. glob returns files in arbitrary order, so
  without sorting the i-th image and the i-th label are not guaranteed to
  belong to the same scene.

  Args:
    pattern: glob pattern selecting the files to read.
    limit: maximum number of files to load (51 covers indices 0..50).
    size: (width, height) passed to cv.resize.

  Returns:
    A list of 2D arrays: resized, converted BGR -> gray, divided by 255.

  Raises:
    FileNotFoundError: if OpenCV cannot decode one of the matched files.
  """
  loaded = []
  for path in sorted(glob.glob(pattern))[:limit]:
    raw = cv.imread(path)
    if raw is None:
      # cv.imread signals failure by returning None; name the offending file
      # instead of failing later inside cv.resize.
      raise FileNotFoundError(f"Could not read image: {path}")
    loaded.append(cv.cvtColor(cv.resize(raw, size), cv.COLOR_BGR2GRAY)/255)
  return loaded


imgs = _load_gray_stack(IMAGE_GLOB)
labels = _load_gray_stack(LABEL_GLOB)

In [None]:
# Each dataset item is an (image, label) pair; batches of 32 are drawn in a
# fresh random order every epoch.
data_loader = DataLoader(
    dataset=list(zip(imgs, labels)),
    shuffle=True,
    batch_size=32,
)

In [None]:
class encoder(nn.Module):
  """Convolutional encoder mapping a 1-channel image to a 32-dimensional code.

  The fully connected layer is sized for a 64 x 50 x 50 feature map, which is
  what the convolutional stack produces for 512 x 512 inputs (shape comments
  below assume that input size).
  """
  def __init__(self):
    super(encoder, self).__init__()

    self.encoder = nn.Sequential(
        nn.Conv2d(1, 256, kernel_size=5, stride=3, padding=1), # N, 256, 170, 170
        nn.SELU(),
        nn.BatchNorm2d(256),
        nn.Conv2d(256, 128, kernel_size=5, stride=3, padding=1), # N, 128, 56, 56
        nn.SELU(),
        nn.BatchNorm2d(128),
        nn.Conv2d(128, 64, kernel_size=7), # N, 64, 50, 50
        nn.SELU()
    )
    self.fc = nn.Linear(64*50*50, 32)
    self.drop = nn.Dropout(0.4)

  def forward(self, x):
    """Encode a batch of images.

    Args:
      x: float tensor of shape (N, 1, 512, 512).

    Returns:
      Tensor of shape (N, 32); dropout is applied to it in training mode.

    Raises:
      RuntimeError: if the feature map does not flatten to 64*50*50 values
        per sample (i.e. the input is not 512 x 512).
    """
    x = self.encoder(x)
    # Flatten per sample. Unlike view(-1, 64*50*50), this never folds several
    # samples into one row when the spatial size is unexpected; the Linear
    # layer then rejects the mismatched width.
    x = x.flatten(1)
    x = self.drop(self.fc(x))

    return x


class decoder(nn.Module):
  """Transposed-convolution decoder mapping a 32-dimensional code to an image.

  A linear layer expands the code to a 64 x 50 x 50 feature map, which the
  transposed convolutions upsample to a single-channel 512 x 512 map squashed
  into (0, 1) by a sigmoid.
  """
  def __init__(self):
    super().__init__()

    upsampling_layers = [
        nn.ConvTranspose2d(64, 128, 7),  # N, 128, 56, 56
        nn.SELU(),
        nn.BatchNorm2d(128),
        nn.ConvTranspose2d(128, 256, 5, stride=3, padding=1, output_padding=2),  # N, 256, 170, 170
        nn.SELU(),
        nn.BatchNorm2d(256),
        nn.ConvTranspose2d(256, 1, 5, stride=3, padding=1, output_padding=2),  # N, 1, 512, 512
        nn.Sigmoid(),
    ]
    self.decoder = nn.Sequential(*upsampling_layers)

    self.initial = nn.Linear(32, 64*50*50)
    self.drop = nn.Dropout(0.4)

  def forward(self, x):
    """Decode codes of shape (N, 32) into images of shape (N, 1, 512, 512)."""
    expanded = self.initial(x)
    # Dropout acts on the expanded code before it is reshaped to a feature map.
    feature_map = self.drop(expanded).view(-1, 64, 50, 50)
    return self.decoder(feature_map)

In [None]:
!pip install torchinfo
from torchinfo import summary

print(summary(encoder(), (64,1,512,512)))
print(summary(decoder(), (64, 32)))

In [None]:
import math
import torch.nn.functional as F
## SSIM (structural similarity) loss function

def gaussian(window_size, sigma):
    """
    Build a 1D tensor of Gaussian weights centred on index window_size // 2.

    The weights use standard deviation = sigma and are normalised so that
    they sum to 1.

    Length of the returned tensor = window_size
    """
    centre = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = torch.Tensor(
        [math.exp(-((idx - centre) ** 2) / two_sigma_sq) for idx in range(window_size)]
    )
    return weights / weights.sum()

def create_window(window_size, channel=1):
    """
    Build a (channel, 1, window_size, window_size) Gaussian filter bank.

    Every channel receives the same 2D kernel: the outer product of a 1D
    Gaussian (sigma = 1.5) with itself.
    """
    # Column vector of 1D Gaussian weights, shape (window_size, 1)
    column = gaussian(window_size=window_size, sigma=1.5).unsqueeze(1)

    # Outer product -> 2D kernel, then add two leading singleton dimensions
    kernel_2d = column.mm(column.t()).float()
    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)

    # One copy of the kernel per channel, stored contiguously
    per_channel = kernel_4d.expand(channel, 1, window_size, window_size).contiguous()
    window = torch.Tensor(per_channel)

    return window

def ssim(img1, img2, val_range=1, window_size=11, window=None, size_average=True, full=False):
    """
    Structural similarity (SSIM) between two image tensors.

    Args:
        img1, img2: tensors of shape (N, C, H, W) or (C, H, W) with matching
            shapes.
        val_range: accepted for interface compatibility; the stability
            constants below are fixed and do not depend on it.
        window_size: side length of the Gaussian window built when `window`
            is None (clamped to the image height and width).
        window: optional precomputed filter of shape (C, 1, k, k).
        size_average: if True return the mean over all elements; otherwise
            average over dimensions 1, 2 and 3 (one value per sample for
            4D input).
        full: if True also return the mean contrast/structure term.

    Returns:
        The SSIM score, or the tuple (score, contrast_metric) when `full`.

    Raises:
        ValueError: if img1 is neither 3- nor 4-dimensional.
    """
    if img1.dim() not in (3, 4):
        raise ValueError(
            f"ssim expects a (N, C, H, W) or (C, H, W) tensor, got shape {tuple(img1.shape)}"
        )
    channels, height, width = img1.shape[-3:]

    # if window is not provided, init one
    if window is None:
        real_size = min(window_size, height, width)  # window cannot exceed the image
        window = create_window(real_size, channel=channels)
    # conv2d needs the filter on the same device and in the same dtype as the input
    window = window.to(device=img1.device, dtype=img1.dtype)

    # Pad by half the size of the window actually used, so the filtered maps
    # keep the input's spatial size even when the window was clamped or was
    # supplied by the caller with a size other than `window_size`.
    pad = window.shape[-1] // 2

    # local means (luminance component), computed with the gaussian filter
    mu1 = F.conv2d(img1, window, padding=pad, groups=channels)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channels)

    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu12 = mu1 * mu2

    # local variances and covariance (contrast / structure component)
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu12

    # Some constants for stability (not scaled by val_range)
    C1 = (0.01) ** 2
    C2 = (0.03) ** 2

    numerator1 = 2 * mu12 + C1
    numerator2 = 2 * sigma12 + C2
    denominator1 = mu1_sq + mu2_sq + C1
    denominator2 = sigma1_sq + sigma2_sq + C2

    ssim_score = (numerator1 * numerator2) / (denominator1 * denominator2)

    if size_average:
        ret = ssim_score.mean()
    else:
        ret = ssim_score.mean(1).mean(1).mean(1)

    if full:
        # Only needed for the two-value return, so compute it lazily.
        contrast_metric = torch.mean(numerator2 / denominator2)
        return ret, contrast_metric

    return ret


class CustomLoss(nn.Module):
    """Loss equal to the negated SSIM between prediction and target.

    Minimising it maximises structural similarity.
    """
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, target):
        """Return -ssim(output, target).

        Args:
            output: predicted image tensor.
            target: reference images; a tensor or any array-like that
                torch.as_tensor accepts.
        """
        # Match the prediction's dtype and device. The legacy torch.Tensor(...)
        # constructor does neither; as_tensor avoids a copy when the target
        # already matches.
        target = torch.as_tensor(target, dtype=output.dtype, device=output.device)
        loss = ssim(output, target)
        return -loss

In [None]:
# Build both halves of the autoencoder on the selected device (encoder first,
# then decoder).
en, dec = encoder().to(device), decoder().to(device)

# Negative-SSIM objective.
criterion = CustomLoss()

# One Adam optimizer per network, both with learning rate 0.001.
en_optim, de_optim = (Adam(net.parameters(), lr=0.001) for net in (en, dec))

In [None]:
num_epochs=20
outputs=[]  # per epoch: (epoch index, last input batch, last reconstruction)
for epoch in range(num_epochs):
  en.train()
  dec.train()
  for (img, label) in data_loader:
    # Add the channel dimension and cast to float32.
    img = img.to(device).float().unsqueeze(1)
    label = label.to(device).float().unsqueeze(1)

    encoded = en(img)
    recon = dec(encoded)
    # The reconstruction is scored against the label map, not the input image.
    loss = criterion(recon, label)

    en_optim.zero_grad()
    de_optim.zero_grad()
    loss.backward()
    en_optim.step()
    de_optim.step()

  # Reports the loss of the last batch of the epoch.
  print(f"Epoch: {epoch}, loss: {loss.item()}")
  # Store detached CPU copies: keeping `recon` as-is would retain its autograd
  # graph (and device memory) for every epoch.
  outputs.append((epoch, img.detach().cpu(), recon.detach().cpu()))

In [None]:
# First input image of the last recorded epoch. Index -1 (instead of a
# hard-coded 19) stays valid when the number of epochs changes.
plt.imshow(outputs[-1][1][0].detach().cpu().numpy().squeeze())

In [None]:
# First reconstruction of the last recorded epoch. Index -1 (instead of a
# hard-coded 19) stays valid when the number of epochs changes.
plt.imshow(outputs[-1][2][0].detach().cpu().numpy().squeeze())

In [None]:
from google.colab import files

# torch.save is given the module objects themselves, not their state_dicts.
checkpoints = ((en, '/content/imgen.pt'), (dec, '/content/imgdec.pt'))

# Write both files first, then ask the browser to download them.
for module, path in checkpoints:
  torch.save(module, path)
for _, path in checkpoints:
  files.download(path)