# WildSat U-Net training notebook
This notebook contains the U-Net segmentation training pipeline:

- U-Net model
- Data preprocessing
- Test outputs


In [None]:
!pip3 install torchvision
!pip3 install torch
!pip3 install opencv-python-headless
!pip3 install opencv-contrib-python-headless
!pip3 install matplotlib
!pip3 install numpy
!pip3 install scipy

In [None]:
import os

# Paths are resolved relative to this file when it runs as a script. Notebook
# kernels do not define __file__, so fall back to the current working directory
# (normally the folder the notebook was opened from).
try:
    ROOT_PATH = os.path.dirname(os.path.realpath(__file__))
except NameError:
    ROOT_PATH = os.getcwd()
OUT_PATH = os.path.join(ROOT_PATH, "..", "unet.pth")                        # checkpoint file
DATASET_MAIN_FOLDER = os.path.join(ROOT_PATH, "..", "Wildfire_Dataset")     # ImageFolder root
CUDA_ENABLED = False
LEARNING_RATE = 0.007
BATCH_SIZE = 64
EPOCHS = 100
TRAIN_TEST_RATIO = 0.8                  # fraction of images used for training
TRAIN_TEST_RATION = TRAIN_TEST_RATIO    # legacy misspelled name, still read by later cells

In [2]:
# lets import required libraries
import torch
import torchvision as vision
import cv2
import numpy as np
import torchvision.transforms as transform
import torchvision.datasets as dtst
import torch.utils as utils
import os
from matplotlib import pyplot

def initialize_dataloader(path, ratio, batch_size, num_workers=4):
    """Split an ImageFolder dataset into train/test sets and wrap them in DataLoaders.

    Every image is resized to 256x512 (height x width), converted to a tensor
    and normalized per channel with mean 0.5 / std 0.5, i.e. to [-1, 1].

    Args:
        path: dataset root as expected by ``torchvision.datasets.ImageFolder``
            (one sub-folder per class, e.g. ``MAIN_DATA``). A trailing path
            separator is not required.
        ratio: fraction of images assigned to the training set, in [0, 1].
        batch_size: batch size of both loaders.
        num_workers: worker processes per loader (default 4).

    Returns:
        ``[train_loader, test_loader]``; both loaders shuffle their data.

    Raises:
        ValueError: if ``ratio`` is outside [0, 1].
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio!r}")

    transformation = transform.Compose([
        transform.Resize((256, 512)),
        transform.ToTensor(),
        transform.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    fulldata = dtst.ImageFolder(root=path, transform=transformation)

    # Derive the split sizes from the dataset itself. Counting the entries of one
    # sub-folder (built by string concatenation without a separator) breaks when
    # `path` has no trailing slash, when the root holds other class folders, or
    # when non-image files are present; random_split then rejects the lengths.
    trainset_size = int(ratio * len(fulldata))
    testset_size = len(fulldata) - trainset_size
    trainset, testset = torch.utils.data.random_split(fulldata, [trainset_size, testset_size])

    traindataloader = utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testdataloader = utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return [traindataloader, testdataloader]

def show(image,segment,ground_truth=None):
  """Display an image, a predicted segment and, optionally, the ground truth side by side.

  Args:
    image: (C, H, W) tensor on any device.
    segment: (C, H, W) tensor on any device, e.g. a model prediction.
    ground_truth: optional (C, H, W) tensor. ``None`` -- or a 0-d tensor, the
      sentinel used by earlier versions of this function -- shows only the
      first two panels.

  Tensors are detached and moved to the CPU before plotting, so tensors that
  still require grad are accepted. Float RGB data is expected in [0, 1];
  matplotlib clips values outside that range.
  """
  panels=[image,segment]
  if ground_truth is not None and ground_truth.dim()!=0:
    panels.append(ground_truth)

  f=pyplot.figure(figsize=(16,16))
  for idx,tensor in enumerate(panels,start=1):
    f.add_subplot(1,len(panels),idx)
    # (C, H, W) -> (H, W, C) for imshow; every panel is rotated by 180 degrees.
    # NOTE(review): the reason for the rotation is not visible here --
    # presumably it compensates for the dataset's orientation; confirm.
    pyplot.imshow(np.rot90(tensor.detach().cpu().numpy().transpose(1,2,0),2))
  pyplot.show()


In [3]:
# Basic U-Net building block: two 3x3 convolutions (stride 1, padding 1, so the
# spatial size is preserved), each followed by BatchNorm and ReLU. It is used on
# both the contracting and the expanding path; any down-sampling happens outside.
class Conv_Block(torch.nn.Module):
  def __init__(self,input_ch,output_ch):
    super().__init__()
    layers = []
    # first conv maps input_ch -> output_ch, second keeps output_ch
    for in_channels in (input_ch, output_ch):
      layers.extend([
        torch.nn.Conv2d(in_channels, output_ch, 3, 1, padding=1),
        torch.nn.BatchNorm2d(output_ch),
        torch.nn.ReLU(),
      ])
    # kept as `network` with the same layer order so state_dict keys match checkpoints
    self.network = torch.nn.Sequential(*layers)

  def forward(self,input):
    return self.network(input)

# Up-sampling block of the U-Net: a Conv_Block (input_ch -> output_ch) followed
# by a 2x2 transposed convolution with stride 2, which doubles height and width.
class DeConv_Block(torch.nn.Module):
  def __init__(self,input_ch,output_ch):
    super().__init__()
    refine = Conv_Block(input_ch, output_ch)
    upsample = torch.nn.ConvTranspose2d(output_ch, output_ch, 2, 2)
    # same Sequential layout (index 0 = conv block, 1 = transposed conv) as before
    self.network = torch.nn.Sequential(refine, upsample)

  def forward(self,input):
    return self.network(input)

# Full model: contracting path (Conv_Block + 2x2 max-pooling), a bottleneck,
# and an expanding path of DeConv_Blocks whose inputs are concatenated with the
# encoder features of matching resolution (skip connections).
class UNet(torch.nn.Module):
  """U-Net segmentation model whose outputs are passed through a sigmoid.

  Args:
    in_channels: channels of the input image (default 3).
    out_channels: channels of the predicted map (default 3).

  Input height and width must be divisible by 8: the three 2x2 poolings are
  mirrored by three stride-2 transposed convolutions, and the skip
  concatenations require the resolutions to match exactly.
  """
  def __init__(self, in_channels=3, out_channels=3):
    super(UNet,self).__init__()
    self.down_sample1=Conv_Block(in_channels,64)
    self.down_sample2=Conv_Block(64,128)
    self.down_sample3=Conv_Block(128,256)
    self.down_sample4=Conv_Block(256,512)
    self.up_sample1=Conv_Block(512,512)      # bottleneck, same resolution as a4
    self.up_sample2=DeConv_Block(1024,256)   # input: cat(bottleneck, a4)
    self.up_sample3=DeConv_Block(512,128)    # input: cat(a3, b1)
    self.up_sample4=DeConv_Block(256,64)     # input: cat(a2, b2)
    self.out=torch.nn.Sequential( Conv_Block(128,64) , torch.nn.Conv2d(64,out_channels,1) )  # input: cat(a1, b3)
    # One shared pooling module instead of constructing a new one per call.
    # It has no parameters or buffers, so state_dict keys (checkpoints) are unchanged.
    self.pool=torch.nn.MaxPool2d(2,2)

  def forward(self,input):
    a1=self.down_sample1(input)
    a2=self.down_sample2(self.pool(a1))
    a3=self.down_sample3(self.pool(a2))
    a4=self.down_sample4(self.pool(a3))
    bottleneck=self.up_sample1(a4)
    b1=self.up_sample2( torch.cat([bottleneck,a4],1) )
    b2=self.up_sample3( torch.cat([a3,b1],1) )
    b3=self.up_sample4( torch.cat([a2,b2],1) )
    out=self.out( torch.cat([a1,b3],1) )
    # public API instead of the internal alias torch.functional.F.sigmoid
    return torch.sigmoid(out)


In [None]:

# Seed torch's default RNG so the random train/test split, the loader shuffling
# and the weight initialization are reproducible across Restart & Run All.
torch.manual_seed(0)

# Create dataloaders for train and test set.
input_loader,test_loader=initialize_dataloader(DATASET_MAIN_FOLDER,TRAIN_TEST_RATION,BATCH_SIZE)

# Create the model and optionally move it to the GPU.
network=UNet()
if CUDA_ENABLED:
    network=network.cuda()

In [4]:
# Adam on all model parameters, using the learning rate from the config cell
# (previously hardcoded, so changing LEARNING_RATE had no effect).
optimizer=torch.optim.Adam(network.parameters(),lr=LEARNING_RATE)
# BCELoss expects probabilities in [0, 1]; UNet.forward already applies a sigmoid.
loss=torch.nn.BCELoss()
# Per-batch loss values, appended by the training loop.
train_history=[]
test_history=[]


In [None]:
# Training loop: each epoch does one optimization pass over the training set,
# then one evaluation pass over the test set. Each dataset image holds the input
# in its left half and the target in its right half.
PLOT_EVERY = 20  # epochs between preview figures, loss plots and checkpoints


def _split_pair(batch):
    """Split a batch of side-by-side images into (inputs, targets) scaled to [0, 1].

    The loaders normalize to [-1, 1], so (x + 1) / 2 maps back to [0, 1], the
    range BCELoss requires for its targets.
    """
    half = batch.shape[-1] // 2
    inputs = (batch[..., :half] + 1) / 2
    targets = (batch[..., half:2 * half] + 1) / 2
    if CUDA_ENABLED:
        inputs, targets = inputs.cuda(), targets.cuda()
    return inputs, targets


def _epoch_means(history, per_epoch):
    """Average a per-batch loss history into one value per epoch."""
    step = max(per_epoch, 1)  # guards against an empty loader
    return [float(np.mean(history[i:i + step])) for i in range(0, len(history), step)]


for epoch in range(EPOCHS):
    network.train()
    for batch, _ in input_loader:
        inputs, targets = _split_pair(batch)
        optimizer.zero_grad()
        predictions = network(inputs)
        loss_train = loss(predictions, targets)
        loss_train.backward()
        optimizer.step()
        train_history.append(loss_train.item())

    # Evaluation pass: no gradients; BatchNorm uses its running statistics.
    network.eval()
    with torch.no_grad():
        for batch, _ in test_loader:
            inputs, targets = _split_pair(batch)
            test_loss = loss(network(inputs), targets)
            test_history.append(test_loss.item())

    # Every PLOT_EVERY epochs: preview one test sample, plot losses, save a checkpoint.
    if (epoch + 1) % PLOT_EVERY == 0:
        sample, _ = next(iter(test_loader))
        # Same [0, 1] scaling as in training (previously only +1 was applied,
        # feeding the network inputs in [0, 2]).
        inputs, targets = _split_pair(sample[:1])
        with torch.no_grad():
            result = network(inputs)[0]
        show(inputs[0], result, targets[0])

        # Histories are per batch; plot per-epoch means so train and test align
        # (plotting them against range(epoch) raised a length-mismatch error).
        train_curve = _epoch_means(train_history, len(input_loader))
        test_curve = _epoch_means(test_history, len(test_loader))
        pyplot.plot(range(1, len(train_curve) + 1), train_curve, label="train")
        pyplot.plot(range(1, len(test_curve) + 1), test_curve, label="test")
        pyplot.xlabel("epoch")
        pyplot.ylabel("mean BCE loss")
        pyplot.legend()
        pyplot.show()
        torch.save(network.state_dict(), OUT_PATH)