<a href="https://colab.research.google.com/github/Lalasa1234/WaterBody_SemanticSegmentation/blob/main/PyTorchModelling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import torch
from torchvision.transforms import v2
import glob
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline
import numpy as np
from numpy import asarray

from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
import torch.nn as nn
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Side length in pixels that every image and mask is resized to.
SIZE = 128
# Seed shared by the train/test split and PyTorch's random number generator.
SEED = 42
# Number of samples per DataLoader batch.
BATCH = 32
# Number of passes over the training set.
EPOCHS = 10

# Seed PyTorch as well, so that weight initialisation, DataLoader shuffling and
# the random augmentations are repeatable after Restart & Run All.
torch.manual_seed(SEED);

In [4]:
# Mount Google Drive inside the Colab runtime so the dataset under /content/drive is reachable.
# force_remount asks for a fresh mount even if the drive is already mounted.
from google.colab import drive
drive.mount('/content/drive',force_remount =True)

Mounted at /content/drive


In [None]:
# Work from the project folder on Drive; the trailing expression echoes the new working directory.
os.chdir('/content/drive/MyDrive/Colab Notebooks/WaterBody_ImageSegmentation')
os.getcwd()

In [None]:
# Glob patterns for the satellite images and their masks.
image_pattern = '/content/drive/MyDrive/Colab Notebooks/WaterBody_ImageSegmentation/Water Bodies Dataset/Images/' + '*.jpg'
mask_pattern = '/content/drive/MyDrive/Colab Notebooks/WaterBody_ImageSegmentation/Water Bodies Dataset/Masks/' + '*.jpg'

# Both lists are sorted so that position i in one list lines up with position i in the other.
real_data = sorted(glob.glob(image_pattern))
masked_data = sorted(glob.glob(mask_pattern))

In [None]:
# Inspect one image/mask pair: colour mode, array layout and value range.
sample_image = Image.open(real_data[2])
sample_pixels = asarray(sample_image)
print ('Real Data - ', sample_image.mode)
# Shape as loaded vs. the same array with the channel axis moved to the front.
print (sample_pixels.shape, np.transpose(sample_pixels,(2,0,1)).shape)

# The mask looks grayscale, so check its bands and what a single-channel ('L') version looks like.
test_mask = Image.open(masked_data[2])
mask_gray = asarray(test_mask.convert('L'))
print ('Masked Data - ', test_mask.getbands(), mask_gray.shape, np.expand_dims(mask_gray, 0).shape)

print ('Min and max of real data is ',sample_pixels.min(),sample_pixels.max())
print ('Min and max is mask data is ',asarray(test_mask).min(),asarray(test_mask).max())

### Display some real images and their masks side by side

In [None]:
# Show the first three samples in ONE figure: one row per sample,
# image in the left column and its mask in the right column.
fig, axes = plt.subplots(3, 2, figsize=(10,10))
for i in range(3):
  axes[i, 0].imshow(Image.open(real_data[i]))
  axes[i, 0].set_title(f'Image {i}')
  axes[i, 1].imshow(Image.open(masked_data[i]))
  axes[i, 1].set_title(f'Mask {i}')
fig.tight_layout()

### Check the resize function by observing the axes

In [None]:
# Resize to 512x512; the axis ticks of the right-hand panel show the new size.
resizer = v2.Resize((512, 512))
original = Image.open(real_data[0])
fig, (before_ax, after_ax) = plt.subplots(1, 2, figsize=(9,5))
before_ax.imshow(original)
after_ax.imshow(resizer(original))

### Testing the Random Horizontal Flip on the same image

In [None]:
# The flip takes a probability p, so on any given run the image may or may not come back mirrored.
hor_flipper = v2.RandomHorizontalFlip(p=0.5)
original = Image.open(real_data[0])
fig, (before_ax, after_ax) = plt.subplots(1, 2, figsize=(9,5))
before_ax.imshow(original)
after_ax.imshow(hor_flipper(original))

### Test Random Rotation

In [None]:
# Rotate by a random angle taken from the (0, 30) degree range.
rot = v2.RandomRotation((0,30))
original = Image.open(real_data[0])
fig, (before_ax, after_ax) = plt.subplots(1, 2, figsize=(9,5))
before_ax.imshow(original)
after_ax.imshow(rot(original))

### Test Vertical Flip

In [None]:
# Vertical flip with probability 0.1, so most runs show the image unchanged.
ver_flipper = v2.RandomVerticalFlip(p=0.1)
original = Image.open(real_data[0])
fig, (before_ax, after_ax) = plt.subplots(1, 2, figsize=(9,5))
before_ax.imshow(original)
after_ax.imshow(ver_flipper(original))

### Check color changes

In [None]:
# ColorJitter draws a random factor per property on every call:
#   brightness=1 and contrast=1 -> factor from [0, 2]; saturation=0.3 -> factor from [0.7, 1.3];
#   hue=0.5 -> shift from [-0.5, 0.5]. hue must satisfy 0 <= hue <= 0.5, so 0.5 is the largest allowed value.
# Contrast controls the spread of shades, saturation the colour intensity, hue shifts the colour values.
col_changer = v2.ColorJitter(brightness=1,contrast = 1, saturation = 0.3, hue = 0.5)
original = Image.open(real_data[0])
fig, (before_ax, after_ax) = plt.subplots(1, 2, figsize=(9,5))
before_ax.imshow(original)
after_ax.imshow(col_changer(original))

### Check greyscale transformation

In [None]:
# Grayscale with 3 output channels: the picture loses its colour but keeps a three-channel layout.
gray = v2.Grayscale(3)
original = Image.open(real_data[2])
fig, (before_ax, after_ax) = plt.subplots(1, 2, figsize=(9,5))
before_ax.imshow(original)
after_ax.imshow(gray(original))

In [None]:
# Compare the array shape before and after the 3-channel grayscale transform.
sample = Image.open(real_data[2])
for picture in (sample, gray(sample)):
  print (asarray(picture).shape)

### **Building the Model: U-Net on PyTorch**

### Dataset Preparation

In [None]:
class SatelliteData(Dataset):
  """Paired image / mask dataset backed by two parallel lists of file paths.

  Args:
    real_data: sequence of image file paths.
    masked_data: sequence of mask file paths; entry i belongs to real_data[i].
    transform: optional callable applied to BOTH the image and the mask
      (resize, flips, rotation). Because the mask goes through it too, it
      should contain geometric operations only.
    mask_transform: optional callable applied to the mask only, after
      `transform` (e.g. reduction to a single channel).

  Each item is an `(image, mask)` pair of float tensors scaled to [0, 1],
  whether or not any transform is configured.
  """

  def __init__(self, real_data, masked_data, transform = None, mask_transform = None):
    super().__init__()

    self.real_data = real_data
    self.masked_data = masked_data

    self.transform = transform
    self.mask_transform = mask_transform

  def __len__(self):
    """Number of image/mask pairs (length of the image path list)."""
    return len(self.real_data)

  def __getitem__(self, index):
    """Load pair `index`, apply the transforms and return normalised tensors."""
    image = read_image(self.real_data[index])
    mask = read_image(self.masked_data[index])

    if self.transform is not None:
      # Replay the same random draws for the mask so that a random flip or
      # rotation hits the image and its mask identically; otherwise the pair
      # would no longer line up pixel for pixel.
      # NOTE: assumes the transform takes its randomness from torch's global
      # CPU generator (torchvision's v2 transforms do) - confirm for custom callables.
      rng_state = torch.get_rng_state()
      image = self.transform(image)
      torch.set_rng_state(rng_state)
      mask = self.transform(mask)

    if self.mask_transform is not None:
      mask = self.mask_transform(mask)

    # read_image yields 8-bit values (0..255); scale both tensors to [0, 1].
    # BCEWithLogitsLoss needs its targets in that range.
    image = image/255.0
    mask = mask/255.0

    return image, mask

In [None]:
# Training pipeline: resize first, then random flips / rotation as augmentation.
train_steps = [
  v2.Resize((SIZE,SIZE)),
  v2.RandomHorizontalFlip(),
  v2.RandomRotation((0,30)),
  v2.RandomVerticalFlip(0.1),
]
transform = v2.Compose(train_steps)

# Evaluation pipeline: resize only, no randomness.
test_transform = v2.Compose([v2.Resize((SIZE,SIZE))])

# Masks are reduced to a single channel.
mask_transform = v2.Grayscale(1)

Testing whether the dataset works or not

In [None]:
# Build the dataset without transforms and peek at the first two image paths and mask paths.
obj = SatelliteData(real_data, masked_data)
print (obj.real_data[0:2], obj.masked_data[0:2])

In [None]:
# Display the two transform attributes stored on the dataset object.
obj.transform, obj.mask_transform

In [None]:
# Test __len__: returns the number of entries in the dataset's image path list.
len(obj)

In [None]:
# Test __getitem__: fetch the first pair once, then report both shapes and the image dtype.
image, mask = obj[0]
print (image.shape, mask.shape)
image.dtype

In [None]:
# Test the transformation: rebuild the dataset with both transforms attached.
obj = SatelliteData(real_data, masked_data, transform, mask_transform)
print (obj.real_data[0:2], obj.masked_data[0:2])
print (obj.transform, obj.mask_transform)

# Fetch the pair once; every obj[0] access reloads both files and re-runs the transforms.
image, mask = obj[0]
image.shape, mask.shape

In [None]:
# Fetch the pair ONCE: each obj[12] access runs the random transforms again, so indexing
# separately for image and mask would show two differently augmented samples.
# imshow accepts RGB data as floats in [0, 1] or integers in [0, 255]; values outside are clipped.
image, mask = obj[12]
plt.subplot(1,2,1)
plt.imshow(image.permute(1,2,0))
plt.subplot(1,2,2)
plt.imshow(mask.permute(1,2,0))

In [None]:
# Hold out 10% of the image/mask path pairs for evaluation; SEED fixes the split.
X_train, X_test, y_train, y_test = train_test_split(
  real_data,
  masked_data,
  test_size=0.1,
  random_state=SEED,
)

In [None]:
# Training data gets the augmenting pipeline, held-out data only the resize; both share the mask transform.
train_data = SatelliteData(X_train, y_train, transform=transform, mask_transform=mask_transform)
test_data = SatelliteData(X_test, y_test, transform=test_transform, mask_transform=mask_transform)

In [None]:
# num_workers is the number of worker subprocesses used to load batches in parallel. It is left at
# its default here: on Windows, worker processes need the code guarded by a main() entry point.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH, shuffle=False)

In [None]:
# For each split: number of samples, then number of batches.
for split, loader in ((train_data, train_loader), (test_data, test_loader)):
  print (len(split), len(loader))

In [None]:
# Test the train data loader: take the first batch only and show its second sample next to its mask.
for i, (x,y) in enumerate(train_loader):
  print (x.shape, x[0].shape)
  sample_image = x[1].permute(1,2,0)
  sample_mask = y[1].permute(1,2,0)
  fig, (image_ax, mask_ax) = plt.subplots(1, 2)
  image_ax.imshow(sample_image)
  mask_ax.imshow(sample_mask)
  break

In [None]:
# Test the validation data loader: take the first batch only and show its second sample next to its mask.
for i, (x,y) in enumerate(test_loader):
  print (x.shape, x[0].shape)
  sample_image = x[1].permute(1,2,0)
  sample_mask = y[1].permute(1,2,0)
  fig, (image_ax, mask_ax) = plt.subplots(1, 2)
  image_ax.imshow(sample_image)
  mask_ax.imshow(sample_mask)
  break

### Defining the model for binary segmentation

Understanding UNet: every pixel is classified as either water or non-water

In [None]:
class DoubleConv(nn.Module):
  """Two back-to-back stages of 3x3 convolution -> batch norm -> ReLU.

  The first stage maps `num_in_chan` to `num_out_chan` channels, the second
  keeps `num_out_chan`. 'same' padding leaves height and width unchanged.
  """

  def __init__(self,num_in_chan, num_out_chan):
    super().__init__()

    # The convolutions carry no bias: the batch norm that follows has its own
    # learnable shift, which would make a convolution bias redundant.
    stages = []
    for stage_in_chan in (num_in_chan, num_out_chan):
      stages += [
        nn.Conv2d(stage_in_chan,num_out_chan,kernel_size=(3,3),padding='same',bias=False),
        nn.BatchNorm2d(num_out_chan),
        nn.ReLU(inplace = True),
      ]
    self.conv_block = nn.Sequential(*stages)

  def forward(self,x):
    """Run both stages on `x` and return the resulting feature map."""
    return self.conv_block(x)

In [None]:
# Downsampling block of the U-Net
class encoder(nn.Module):
  """DoubleConv followed by 2x2 max pooling.

  forward returns a pair: the feature map before pooling (used later as the
  skip connection) and the pooled map that feeds the next level.
  """
  def __init__(self,num_in_chan,num_out_chan):
    super().__init__()
    self.conv_block = DoubleConv(num_in_chan,num_out_chan)
    self.pool = nn.MaxPool2d(2)

  def forward(self,x):
    skip = self.conv_block(x)
    return skip, self.pool(skip)

In [None]:
# Upsampling block of the U-Net
class decoder(nn.Module):
  """Transposed-convolution upsampling, skip concatenation, then DoubleConv.

  Args:
    num_in_chan: channels of the incoming (lower-resolution) feature map.
    num_out_chan: channels after upsampling, and of the block's output.
  """
  def __init__(self,num_in_chan,num_out_chan):
    super().__init__()
    # kernel_size=2 with stride=2 doubles height and width, undoing one 2x2 pooling step.
    # For ConvTranspose2d, padding must be an int/tuple and not a string.
    self.conv1 = nn.ConvTranspose2d(num_in_chan,num_out_chan,kernel_size=2,stride=2,padding=0)
    # Input channels are doubled because the upsampled map is concatenated with the skip map.
    self.conv2 = DoubleConv(num_out_chan*2,num_out_chan)

  def forward(self,x,side):
    """Upsample `x`, join it with the skip map `side` and refine the result."""
    x = self.conv1(x)

    # When a pooled dimension was odd, pooling rounded it down and the upsampled
    # map comes out smaller than the skip map. Pad it to the skip size so the
    # concatenation works for any input size; matching sizes skip this branch.
    height_gap = side.shape[-2] - x.shape[-2]
    width_gap = side.shape[-1] - x.shape[-1]
    if height_gap or width_gap:
      x = nn.functional.pad(
        x,
        [width_gap // 2, width_gap - width_gap // 2,
         height_gap // 2, height_gap - height_gap // 2],
      )

    # Shape is (Batch, Channel, Height, Width), so channels are joined along dim=1.
    x = torch.cat([x,side],dim=1)
    up = self.conv2(x)
    return up

In [None]:
class UNet(nn.Module):
  """U-Net with four encoder levels, a bottleneck and four decoder levels.

  Args:
    in_channels: channels of the input image (default 3).
    out_channels: channels of the output map. The default of 1 is for binary
      segmentation (water as foreground vs. everything else); for n classes
      use out_channels = n.

  forward returns raw logits with `out_channels` channels. No sigmoid is
  applied, because the notebook trains with BCEWithLogitsLoss.
  """
  def __init__(self, in_channels=3, out_channels=1):
    super().__init__()
    self.encoder1 = encoder(in_channels,64)
    self.encoder2 = encoder(64,128)
    self.encoder3 = encoder(128,256)
    self.encoder4 = encoder(256,512)

    self.b1 = DoubleConv(512,1024)

    self.decoder1 = decoder(1024,512)
    self.decoder2 = decoder(512,256)
    self.decoder3 = decoder(256,128)
    self.decoder4 = decoder(128,64)

    # 1x1 convolution that maps the 64 feature channels to the output channels.
    self.output = nn.Conv2d(64,out_channels,kernel_size=(1,1))

  def forward(self,x):
    # Intermediate activations are kept in locals rather than stored on the
    # module, so they can be released once the call (and backward pass) is done.
    side1, down1 = self.encoder1(x)
    side2, down2 = self.encoder2(down1)
    side3, down3 = self.encoder3(down2)
    side4, down4 = self.encoder4(down3)

    bottleneck = self.b1(down4)

    # Each decoder level consumes the skip map of the matching encoder level.
    up = self.decoder1(bottleneck, side4)
    up = self.decoder2(up, side3)
    up = self.decoder3(up, side2)
    up = self.decoder4(up, side1)

    # Raw logits: we don't call sigmoid here as we use the loss function BCEWithLogitsLoss().
    return self.output(up)

Test the Model's output shape

In [None]:
# Random batch of three 3-channel 128x128 inputs; display its shape and the first channel of the first sample.
x = torch.randn(3, 3, 128, 128)
x.shape, x[0, 0]

In [None]:
# Throwaway model instance, used only to check the output shape.
samp = UNet()

In [None]:
# One forward pass is enough to inspect the output; no_grad avoids building an autograd graph for it.
with torch.no_grad():
  samp_out = samp(x)
samp_out.shape, samp_out[0][0]

### Model Training and Evaluation

In [None]:
def train_model(train_loader,model,opt):
  """Train `model` for one epoch and return the mean batch loss.

  Args:
    train_loader: iterable of (X, y) batches that also supports len().
    model: network to train; it is switched to train mode.
    opt: optimizer that updates `model`'s parameters.

  Returns:
    Sum of the per-batch losses divided by the number of batches.

  Uses the notebook-level `criterion` (loss function) and `device`.
  """
  epoch_loss = 0.0
  model.train()
  loop = tqdm(train_loader)
  for X, y in loop:
    X = X.to(device)
    y = y.to(device)
    # No .float() cast here - presumably the dataset already yields float tensors; confirm if the dataset changes.
    y_pred = model(X)
    loss = criterion(y_pred,y)

    # Step the optimizer that was passed in, not whichever global happens to be called `optimizer`.
    opt.zero_grad()
    loss.backward()
    opt.step()
    epoch_loss += loss.item()
    loop.set_postfix(loss = loss.item())
  return epoch_loss/len(train_loader)

In [None]:
# Dry run of the training step on a fresh model: at most three batches,
# printing predictions and targets together with their value ranges.
criterion = nn.BCEWithLogitsLoss()
model = UNet()
optimizer = torch.optim.Adam(model.parameters())
epoch_loss = 0.0
model.train()
loop = tqdm(train_loader)
# zip with range(3) stops after the third batch without requesting a fourth one
for i, (X,y) in zip(range(3), loop):
  y_pred = model(X)
  print (f'Prediction is {y_pred} \n max is {torch.max(y_pred)}, min is {torch.min(y_pred)}')
  print (f'Actual is {y} \n max is {torch.max(y)}, min is {torch.min(y)}')

  loss = criterion(y_pred,y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  batch_loss = loss.item()
  epoch_loss += batch_loss
  loop.set_postfix(loss = batch_loss)

***Loss was becoming negative for the first run***

**Reason**: `BCEWithLogitsLoss(predictions, actuals)` expects the predictions (raw logits) to lie in the range (-inf, inf) and the actuals (targets) to lie in [0, 1]. In my case the input y was not normalized, hence dividing its values by 255.0 is important.

In [None]:
def eval_model(test_loader,model):
  """Return the mean batch loss of `model` over `test_loader`, without gradients.

  The model is switched to eval mode. The loss function and target device are
  taken from the notebook-level `criterion` and `device`.
  """
  model.eval()
  batch_losses = []
  with torch.no_grad():
    progress = tqdm(test_loader)
    for X, y in progress:
      X, y = X.to(device), y.to(device)
      batch_loss = criterion(model(X), y).item()
      batch_losses.append(batch_loss)
      progress.set_postfix(loss = batch_loss)
  return sum(batch_losses, 0.0)/len(test_loader)

### ToDo: About Dice Loss:

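$\text{Dice coeff.} = \dfrac{2 \times |\text{Pred} \cap \text{Actual}|}{|\text{Pred}| + |\text{Actual}|}$, i.e. twice the overlap between prediction and actual, divided by the sum of their areas.

$\text{Dice Loss} = 1 - \text{Dice coeff.}$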

Here, for every iteration, counting the overlap twice in the numerator gives a Dice loss that is no higher than the corresponding IoU-based loss, making it a better metric for object segmentation.

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU, and place a fresh model there.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print (device)
model = UNet().to(device)

In [None]:
# The loss works on raw logits (the model applies no sigmoid); Adam runs with a learning rate of 1e-4.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train for EPOCHS epochs, keep the per-epoch train/validation losses,
# and save the weights whenever the validation loss improves.
checkpoint_path = '/content/drive/MyDrive/Colab Notebooks/WaterBody_ImageSegmentation/BestModel.pt'
train, val, best_val_loss = [],[], float('inf')
for epoch in range(EPOCHS):
  train_loss = train_model(train_loader,model,optimizer)
  train.append(train_loss)
  val_loss = eval_model(test_loader,model)
  val.append(val_loss)

  improved = val_loss < best_val_loss
  if improved:
    print (f'Best Validation loss is reduced from {best_val_loss} to {val_loss}')
    best_val_loss = val_loss
    torch.save(model.state_dict(),checkpoint_path)

  print (f'For epoch no.{epoch}')
  print (f'Train loss:{train_loss}, Validation loss:{val_loss}')