In [None]:
#Running with the noisy image and the ground-truth (gt) image, both in grayscale

from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
import torch
import cv2
import matplotlib.pyplot as plt
import os
import glob
import numpy as np
from torchvision import utils
from PIL import Image
%matplotlib inline

In [None]:
# Training data import: 21 land-use classes with 70 images per class, files numbered 00 to 69.
# Noisy inputs and grayscale ground-truth (gt) images are loaded into two float tensors of
# shape (N, 3, img_size, img_size), converted with transforms.ToTensor.

img_size = 256
torch.cuda.empty_cache()

char_array = ['agricultural','airplane','baseballdiamond','beach','buildings','chaparral','denseresidential','forest','freeway','golfcourse','harbor','intersection','mediumresidential','mobilehomepark','overpass','parkinglot','river','runway','sparseresidential','storagetanks','tenniscourt']
images_per_class = 70  # files <class>00.tif ... <class>69.tif
n_images = len(char_array) * images_per_class  # 21 * 70 = 1470

img_noisy_tens = torch.empty([n_images, 3, img_size, img_size])  # noisy input images
img_mask_tens = torch.empty([n_images, 3, img_size, img_size])   # grayscale gt images


def _load_image_tensor(path, size):
    """Read ``path`` with OpenCV, resize it to ``size`` x ``size`` and return it as a CHW float tensor.

    cv2.resize interpolates the image spatially. The previous np.resize only reshaped the flat
    pixel buffer (repeating or truncating values), which scrambles any image that is not already
    ``size`` x ``size``. cv2.imread loads 3 channels in BGR order (presumably identical channels
    for the grayscale files used here -- confirm if colour images are ever added).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None instead of raising).
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError('Could not read image: %s' % path)
    image = cv2.resize(image, (size, size))
    return transforms.ToTensor()(image)


for k, class_name in enumerate(char_array):
    for idx in range(images_per_class):
        # '%02d' % idx gives the same '00'..'69' suffix as the former '%d%d' % (i, j).
        file_name = '%s%02d.tif' % (class_name, idx)
        row = images_per_class * k + idx
        img_noisy_tens[row] = _load_image_tensor('D:\\UCMerced_LandUse\\Img_noisy\\' + file_name, img_size)
        img_mask_tens[row] = _load_image_tensor('D:\\UCMerced_LandUse\\Img_grayscale\\' + file_name, img_size)

In [None]:
#Displaying Noisy Grayscale image
# Convert the final noisy sample (last image of the last class) back to a PIL image;
# leaving the bare name on the last line makes Jupyter render it inline.
to_pil = transforms.ToPILImage()
imgg = to_pil(img_noisy_tens[-1])
imgg

In [None]:
# Pair each noisy image with its ground-truth (gt) grayscale image, then feed the pairs to a DataLoader.
# img_merged_tens has shape (N, 2, 3, img_size, img_size): index 0 along dim 1 is the noisy input,
# index 1 is the gt target, and pairs stay aligned because both tensors share the same row order.

batch_size = 2
assert img_noisy_tens.shape == img_mask_tens.shape, 'noisy and gt tensors must have matching shapes'
img_merged_tens = torch.stack([img_noisy_tens, img_mask_tens], dim=1)

# The DataLoader indexes the tensor along dim 0, so each batch has shape
# (batch_size, 2, 3, img_size, img_size); pairs are shuffled every pass.
img_dl = DataLoader(img_merged_tens, batch_size, shuffle=True)
img_batch_sample = next(iter(img_dl))

In [None]:
# Display one noisy input (left) next to its gt grayscale target (right) from the sampled batch.
# img_batch_sample has shape (batch_size, 2, 3, H, W); element 0 is used so this works for any batch size >= 1.
grid_img = utils.make_grid(img_batch_sample[0], nrow=2)
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(grid_img.permute(1, 2, 0))  # make_grid returns CHW; imshow expects HWC
ax.set_title('Noisy input (left) vs. ground truth (right)')
ax.axis('off');

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution -> BatchNorm -> ReLU.

    padding=1 keeps height and width unchanged. The first convolution maps
    ``in_channels`` -> ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order are kept so saved state_dict keys
        # (double_conv.0 ... double_conv.4) continue to load.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Halve the spatial size with 2x2 max pooling, then apply DoubleConv(in_channels, out_channels)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Same attribute name and child order, so state_dict keys remain maxpool_conv.1.double_conv.*.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_then_convolved = self.maxpool_conv(x)
        return pooled_then_convolved


class Up(nn.Module):
    """Upscaling then double conv.

    Upsamples the deeper feature map ``x1`` by 2x, pads it to the spatial size of the skip
    connection ``x2``, concatenates ``[x2, x1]`` along the channel dimension and applies
    DoubleConv(in_channels, out_channels). ``in_channels`` is the channel count after
    concatenation; with ``bilinear=False`` the transposed convolution maps
    ``in_channels // 2`` channels to ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Inputs are NCHW. Plain Python ints are used for the size differences: F.pad expects a
        # sequence of ints, and one-element torch.tensor values (created on the CPU every call)
        # only worked by implicit conversion.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # Split odd differences so the extra pixel goes to the right/bottom.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution head mapping feature channels to output channels, squashed into [0, 1].

    A sigmoid (instead of the previous ReLU) bounds every output to [0, 1]. The training cell
    uses nn.BCELoss, which rejects inputs outside [0, 1], so an unbounded ReLU output above 1
    made the loss raise. The bounded range also matches the [0, 1] targets produced by ToTensor.
    """

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))


In [None]:
import torch.nn.functional as F

#from .unet_parts import *


class UNet(nn.Module):
    """Four-level U-Net: a DoubleConv stem, four Down stages, four Up stages with skip
    connections, and an OutConv head producing ``n_classes`` channels.

    Submodule attribute names (inc, down1..down4, up1..up4, outc) are part of the saved
    state_dict keys and must stay unchanged.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.out_channels = n_classes
        self.bilinear = bilinear

        # Encoder: stem plus four downsampling stages (channels 64 -> 128 -> 256 -> 512 -> 512).
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # Decoder: the first argument is the channel count after concatenating with the skip.
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, self.out_channels)

    def forward(self, x):
        # Collect every encoder output; the deepest one seeds the decoder and the
        # rest are consumed deepest-first as skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        decoded = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            decoded = up(decoded, skips.pop())
        return self.outc(decoded)

In [None]:
#Loading saved models if any

import os

# Use the first GPU when available; otherwise fall back to the CPU so the cell still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The training tensors are built with 3 channels (N, 3, H, W), so the network needs 3 input
# channels; UNet(4, 3) fails on the first forward pass with a channel mismatch.
unet = UNet(3, 3)

resume = True
PATH = 'D:\\UCMerced_LandUse\\models'
os.makedirs(PATH, exist_ok=True)  # the training loop saves checkpoints here
checkpoint_path = os.path.join(PATH, 'model_latest')
if resume and os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on one device load on another (e.g. GPU -> CPU).
    unet.load_state_dict(torch.load(checkpoint_path, map_location=device))
elif resume:
    print('No checkpoint found at {}; starting from scratch.'.format(checkpoint_path))
unet = unet.to(device)

In [None]:
# Training block

loss = nn.BCELoss()
opt = torch.optim.Adam(unet.parameters(), lr=0.001)
epoch = 500            # total number of epochs
save_every = 5         # overwrite 'model_latest' every `save_every` iterations
snapshot_every = 10    # keep a numbered snapshot every `snapshot_every` epochs

unet.train()
for epoc in range(epoch):
    itr = 0
    loss_sum = 0.0
    for i in img_dl:
        # Each batch is (batch_size, 2, 3, H, W): index 0 = noisy input, index 1 = gt target.
        input_batch = i[:, 0, :, :, :].to(device)
        target_batch = i[:, 1, :, :, :].to(device)

        opt.zero_grad()
        out = unet(input_batch)
        loss_out = loss(out, target_batch)
        loss_out.backward()
        opt.step()

        # .item() yields a plain float; summing the tensor itself would keep every
        # iteration's autograd graph alive until the end of the epoch.
        loss_sum += loss_out.item()
        itr += 1
        if itr % save_every == 0:
            torch.save(unet.state_dict(), os.path.join(PATH, 'model_latest'))

    torch.save(unet.state_dict(), os.path.join(PATH, 'model_latest'))
    # Snapshots are named by epoch; the former 'model_<itr>' names were overwritten every epoch.
    if (epoc + 1) % snapshot_every == 0:
        torch.save(unet.state_dict(), os.path.join(PATH, 'model_epoch_' + str(epoc)))
    print('epoch = {}, mean loss = {:.6f}, iterations = {}'.format(epoc, loss_sum / max(itr, 1), itr))
    
    