<a href="https://colab.research.google.com/github/Efoma/Efoma/blob/main/Earthdatadownload%20code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [43]:
!pip install Pillow




In [44]:
pip install tifffile




In [45]:
pip install pymodis



In [46]:
pip install pymp-pypi



In [47]:
import pymodis
print(pymodis.__version__)

2.4.1


In [48]:
import os
from pymodis import downmodis
import numpy as np
import pymp
import time
from argparse import ArgumentParser
import shutil
import multiprocessing as mp
import calendar

In [58]:
import torch
import cv2
import torch.optim as optim
import torch.nn as nn
import math
from torchvision.utils import save_image
from torchvision import transforms
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
import torch.nn.functional as F

In [59]:
# VDSR and DMCN model

from math import sqrt

import torch.nn.init as init


In [49]:
def read_modis(hdf_path):
    """Placeholder reader for a MODIS LST HDF file.

    No real implementation is present in this notebook: the argument is
    ignored and ``None`` is always returned.

    ``process_hdf`` unpacks a non-``None`` result as
    ``(img_day, img_night, cols, rows, projection, geotransform)`` and treats
    ``None`` as "cannot handle this MODIS file", so with this stub every file
    is skipped.

    Args:
        hdf_path: Path of the HDF file to read (currently unused).

    Returns:
        Always ``None``.
    """
    return None

In [50]:
def calculate_dates_daily(year, month):
    """
    Compute one start/end date pair for every day of the given month.

    The dates are used to download the daily images of the MODIS products.
    Day ``d`` is paired with day ``d + 1``; the last day of the month is paired
    with the first day of the following month (rolling over to January of the
    next year for December).

    NOTE(review): a later cell of this notebook defines ``calculate_dates_daily``
    again; after a top-to-bottom run that later definition is the one in effect.

    Args:
        year (int): Calendar year.
        month (int): Month number, 1-12.

    Returns:
        tuple(list[str], list[str]): ``(startdates, enddates)``, two lists of
        equal length holding ``YYYY-MM-DD`` strings.
    """
    # The original body referenced `total_days` without ever defining it.
    total_days = calendar.monthrange(year, month)[1]

    year_month = "{}-{}".format(str(year), str(month).zfill(2))
    startdates = [
        "{}-{}".format(year_month, str(day).zfill(2))
        for day in range(1, total_days + 1)
    ]
    # Every interval ends on the day the next one starts.
    enddates = startdates[1:]

    if month != 12:
        enddates.append("{}-{}-01".format(str(year), str(month + 1).zfill(2)))
    else:
        enddates.append("{}-01-01".format(str(year + 1)))

    return startdates, enddates

In [51]:
def MODIS_Downloader(startdate, enddate, year, product, num_threads, tiles, user="efoma", password="Naijament@75"):
    """
    Download the HDF files of one MODIS product for a date interval with pyModis.

    Files are written to ``MODIS/MOD_<year>_<sensor>/hdfs_files`` where
    ``<sensor>`` is the part of ``product`` before the first dot.

    SECURITY NOTE(review): the default ``user``/``password`` are real-looking
    credentials committed in the notebook. They are kept only so existing
    callers keep working; pass credentials explicitly (e.g. read from
    environment variables) and rotate the exposed password.

    Args:
        startdate (str): Passed to pyModis as ``today``.
        enddate (str): Passed to pyModis as ``enddate``.
        year: Year used to build the destination folder name.
        product (str): Product identifier such as ``"MOD11A1.061"``.
        num_threads: Unused; kept for interface compatibility.
        tiles (str): Tile identifier(s) forwarded to pyModis.
        user (str): Earthdata user name.
        password (str): Earthdata password.

    Returns:
        None. Download failures are reported on stdout and not re-raised.
    """
    sensor = product.split(".")[0]
    hdfs_path = 'MODIS/MOD_{}_{}/hdfs_files'.format(year, sensor)

    start_time = time.time()
    print("Start to download {} From {} to {}".format(hdfs_path, startdate, enddate))

    try:
        modisDown = downmodis.downModis(user=user, password=password, product=product, destinationFolder=hdfs_path, tiles=tiles, today=startdate, enddate=enddate)
        modisDown.connect()
        modisDown.downloadsAllDay()
    except Exception as exc:
        # Best-effort download: report the cause instead of hiding it, and let
        # KeyboardInterrupt/SystemExit propagate so the run can be stopped.
        print("Download Error {} From {} to {}: {!r}".format(hdfs_path, startdate, enddate, exc))
    print("Finish download {} From {} to {}, time cost: {:.4f}".format(hdfs_path, startdate, enddate, time.time() - start_time))



In [52]:
def MODIS_Downloader_DAILY(startdate, enddate, year, product, tiles, user="efoma", password="Naijament@75"):
    """
    Download the daily HDF files of one MODIS product for a date interval.

    Files are written to ``MODIS/MOD_<year>_<sensor>/hdfs_files`` where
    ``<sensor>`` is the part of ``product`` before the first dot.

    SECURITY NOTE(review): the default ``user``/``password`` are real-looking
    credentials committed in the notebook. They are kept only so existing
    callers keep working; pass credentials explicitly (e.g. read from
    environment variables) and rotate the exposed password.

    Args:
        startdate (str): Passed to pyModis as ``today``.
        enddate (str): Passed to pyModis as ``enddate``.
        year: Year used to build the destination folder name.
        product (str): Product identifier such as ``"MOD11A1.061"``.
        tiles (str): Tile identifier(s) forwarded to pyModis.
        user (str): Earthdata user name.
        password (str): Earthdata password.

    Returns:
        None. Download failures are reported on stdout and not re-raised.
    """
    sensor = product.split(".")[0]
    hdfs_path = 'MODIS/MOD_{}_{}/hdfs_files'.format(year, sensor)

    start_time = time.time()
    print("Start to download {} From {} to {}".format(hdfs_path, startdate, enddate))

    try:
        modisDown = downmodis.downModis(user=user, password=password, product=product, destinationFolder=hdfs_path, tiles=tiles, today=startdate, enddate=enddate)
        modisDown.connect()
        modisDown.downloadsAllDay()
    except Exception as exc:
        # Best-effort download: report the cause instead of hiding it, and let
        # KeyboardInterrupt/SystemExit propagate so the run can be stopped.
        print("Download Error {} From {} to {}: {!r}".format(hdfs_path, startdate, enddate, exc))

    print("Finish download {} From {} to {}, time cost: {:.4f}".format(hdfs_path, startdate, enddate, time.time() - start_time))


In [53]:
def MODIS_Data_Preprocessing(year, product, num_threads, delete_files=False):
    """
    Process the downloaded HDF files of one year and save LST/NDVI tiles as tiff.

    Every ``*hdf`` file found in ``MODIS/MOD_<year>_<sensor>/hdfs_files`` is
    handed to ``process_hdf`` together with the sorted list of MOD09GQ HDF
    files of the same year. Only the ``MOD11A1`` sensor is actually processed;
    for ``MOD13A2``/``MOD13Q1`` the output folders are created and nothing else
    happens.

    Args:
        year: Year used to build the folder names.
        product (str): Product identifier such as ``"MOD11A1.061"``; the part
            before the first dot is the sensor name.
        num_threads (int): Number of ``pymp`` workers used for the file loop.
        delete_files (bool): When True, the HDF folders of the product and of
            MOD09GQ are removed after processing.

    Returns:
        None.
    """
    sensor = product.split(".")[0]
    root_dir = 'MODIS/MOD_{}_{}'.format(year, sensor)
    hdfs_path = os.path.join(root_dir, 'hdfs_files')
    tifs_1km_path = os.path.join(root_dir, 'tifs_files/1km')
    tifs_250m_path = os.path.join(root_dir, 'tifs_files/250m')

    os.makedirs(hdfs_path, exist_ok=True)

    # Create the save folders based on the processed sensor
    if sensor in ('MOD11A1', 'MOD13A2'):
        os.makedirs(tifs_1km_path, exist_ok=True)
    elif sensor == "MOD13Q1":
        os.makedirs(tifs_250m_path, exist_ok=True)

    # Locate the MOD09GQ (red/NIR reflectance) files used to compute the NDVI
    ndvi_folder = 'MODIS/MOD_{}_MOD09GQ'.format(year)
    ndvi_save_path = os.path.join(ndvi_folder, 'tifs_files/250m')
    ndvi_dir = os.path.join(ndvi_folder, 'hdfs_files')
    os.makedirs(ndvi_save_path, exist_ok=True)
    # The MOD09GQ download may have produced nothing; without this folder
    # os.listdir below raised FileNotFoundError.
    os.makedirs(ndvi_dir, exist_ok=True)

    list_ndvi = sorted(name for name in os.listdir(ndvi_dir) if name.endswith('hdf'))

    # Start processing file by file
    print("start to processing {}".format(hdfs_path))
    hdfs = sorted(os.listdir(hdfs_path))
    start_time = time.time()
    # Spread the files over several workers
    with pymp.Parallel(num_threads) as p:
        for index in p.range(0, len(hdfs)):

            # Process only hdf files
            hdf = hdfs[index]
            if not hdf.endswith('hdf'):
                continue
            hdf_path = os.path.join(hdfs_path, hdf)

            # Process LST and NDVI images
            if sensor == 'MOD11A1':
                process_hdf(hdf_path, hdf, tifs_1km_path, ndvi_save_path, list_ndvi, ndvi_dir, 64, (64, 64))

    # Delete files at the end
    if delete_files:
        shutil.rmtree(ndvi_dir)
        shutil.rmtree(hdfs_path)

    print("Using {:.4f}s to process product = {}".format(time.time() - start_time, product))

In [54]:
def process_hdf(hdf_path, hdf_name, save_dir,ndvi_save_path,list_ndvi,ndvi_dir,step=64,size=(64,64)):
    """
    Cut one LST HDF file and its matching NDVI image into tiles and save them.

    INPUT:
    hdf_path = input LST image path to be processed | or hdf file path ("/a/b/c.hdf")
    hdf_name = name of the hdf file; its first two dot-separated fields are
               reused to build the tile file names
    save_dir = directory for saving cropped LST images
    ndvi_save_path : Path to save the ndvi images
    ndvi_dir : path of the ndvi hdf files
    list_ndvi : List of the ndvi files
    step, size: parameters of "sliding_window()" for the LST images
    OUTPUT: LST and NDVI images cropped from the hdf files, saved in tiff format
            (LST tiles to save_dir, NDVI tiles to ndvi_save_path). Returns None.

    The function returns early, without saving anything, when the path does not
    end in 'hdf', when read_modis() gives None (or a None day/night image), or
    when get_corresponding_ndvi() gives None.

    NOTE(review): sliding_window, save_tif and save_tif_MOD09GQ are defined
    outside this cell; what is said about them below is inferred from how they
    are called here -- confirm against their definitions.
    """
    if not hdf_path.endswith('hdf'):
        print("Not hdf file Sorry!")
        return
    # Open the LST hdf
    read_val = read_modis(hdf_path)
    # Skip the file if it could not be opened
    if read_val is None:
        print("Cannot handle this MODIS file: ", hdf_path, ". Please check it again")
        return

    img_day, img_night, cols, rows, projection, geotransform = read_val

    # Parallel lists: entry i of each list describes tile i
    img_days = []
    img_nights = []
    img_cropped_names = []
    ndvis = []
    ndvi_names = []
    ndvi_geotransforms = []
    geotransform2s = []
    # Tile dimensions passed to save_tif for the LST tiles
    cols2, rows2 = size

    if img_day is None or img_night is None:
        print("Cannot handle this MODIS file: ", hdf_path, ". Please check it again")
        return

    # Divide the original image into tiles of `size` (64x64 by default)
    hdf_name_list = hdf_name.split(".")
    # For day image
    win_count = 0
    for (x,y,window) in sliding_window(img_day, step, size):
            # Skip incomplete windows (smaller than `size`)
            if window.shape[0] != size[0] or window.shape[1] != size[1]:
                    continue

            # Tile name: <field0>.<field1>.<4-digit tile index>.tif
            img_cropped_name = hdf_name_list[0] + "." + hdf_name_list[1] + ".{}.tif".format(str(win_count).zfill(4))
            img_cropped = window
            # Shift the origin of the geotransform by the tile offset (x, y)
            geotransform2 = np.asarray(geotransform)
            geotransform2[0] = geotransform[0]+x*geotransform[1] # 1st coordinate of top left pixel of the image
            geotransform2[3] = geotransform[3]+y*geotransform[5] # 2nd coordinate of top left pixel of the image
            geotransform2=tuple(geotransform2)

            img_cropped_names.append(img_cropped_name)
            img_days.append(img_cropped)
            geotransform2s.append(geotransform2)

            win_count += 1

    # For night image
    # (win_count is incremented here but not read before it is reset below)
    win_count = 0
    for (x,y,window) in sliding_window(img_night, step, size):
        if window.shape[0] != size[0] or window.shape[1] != size[1]:
                continue
        # save_path = os.path.join(save_dir,img_cropped_name)
        img_cropped = window
        # np.save(save_path,img_cropped)
        img_nights.append(img_cropped)
        win_count += 1

    # Get the corresponding ndvi to the LST image
    ndvi_read_value = get_corresponding_ndvi(list_ndvi,ndvi_dir,hdf_name)
    # Skip the current file if the NDVI could not be computed
    if ndvi_read_value is None :
        return
    red, NIR, NDVI, ndvi_projection, ndvi_geotransform = ndvi_read_value
    # NOTE(review): red, NIR, reds and NIRs are not used in the rest of this function
    reds = []
    NIRs = []

    # Divide the original NDVI image into 256x256 patches
    win_count = 0
    for (x,y,window) in sliding_window(NDVI, 256, (256,256)):
            if window.shape[0] != 256 or window.shape[1] != 256:
                    continue

            # Reuse the name of the LST tile with the same index, with the
            # product field replaced by "MOD09GQ".
            # NOTE(review): assumes the NDVI image yields no more full windows
            # than the day image, otherwise this lookup raises IndexError.
            img_cropped_name = img_cropped_names[win_count]
            image_name_list = img_cropped_name.split(".")
            image_name_list[0] = "MOD09GQ"
            save_name = '.'.join(image_name_list)
            ndvi_names.append(save_name)

            img_cropped = window
            ndvis.append(img_cropped)

            geotransform2 = np.asarray(ndvi_geotransform)
            geotransform2[0] = ndvi_geotransform[0]+x*ndvi_geotransform[1] # 1st coordinate of top left pixel of the image
            geotransform2[3] = ndvi_geotransform[3]+y*ndvi_geotransform[5] # 2nd coordinate of top left pixel of the image
            geotransform2=tuple(geotransform2)
            ndvi_geotransforms.append(geotransform2)
            win_count += 1

    # Save images and metadata into .tif file
    # NOTE(review): indexes img_nights, ndvis and ndvi_names with the LST day
    # tile index, so all lists are assumed to have the same length.
    for i in range(len(img_cropped_names)):
        save_path = os.path.join(save_dir,img_cropped_names[i])
        succes = save_tif(save_path, img_days[i], img_nights[i], cols2, rows2, projection, geotransform2s[i])
        # Only save NDVI image if LST image has no cloud or sea pixels
        # (presumably what a truthy save_tif result means -- confirm)
        if(succes):
            # Only save ndvi images without NaN values
            if(len(ndvis[i][np.isnan(ndvis[i])])==0):
                save_path_ndvi = os.path.join(ndvi_save_path,ndvi_names[i])
                save_tif_MOD09GQ(save_path_ndvi, ndvis[i], 256, 256, ndvi_projection, ndvi_geotransforms[i])
            # Remove the saved LST images if the corresponding NDVI image has NaN values
            else:
                os.remove(save_path)

Exception ignored in: <Finalize object, dead>
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/pool.py", line 695, in _terminate_pool
    cls._help_stuff_finish(inqueue, task_handler, len(pool))
  File "/usr/lib/python3.10/multiprocessing/pool.py", line 675, in _help_stuff_finish
    inqueue._rlock.acquire()
KeyboardInterrupt: 


In [55]:
def get_corresponding_ndvi(list_ndvi,ndvi_dir,image_name):
    """
    Find the reflectance file acquired on the same day as an LST image and
    compute its NDVI.

    The day is the second dot-separated field of a file name. The first entry
    of ``list_ndvi`` with the same day that also exists on disk is read with
    ``read_modis_MOD09GQ``.

    Args:
        list_ndvi: Names of the reflectance HDF files.
        ndvi_dir: Folder containing those files.
        image_name: Name of the LST HDF file.

    Returns:
        ``(red, NIR, ndvi, projection, geotransform)`` where
        ``ndvi = (NIR - red) / (NIR + red)``, or ``None`` when no matching file
        exists or the matching file cannot be read.
    """
    wanted_day = image_name.split(".")[1]

    # Go through the reflectance files until one matches the wanted day
    for candidate in list_ndvi:
        if candidate.split(".")[1] != wanted_day:
            continue

        ndvi_path = os.path.join(ndvi_dir, candidate)
        if not os.path.exists(ndvi_path):
            continue

        bands = read_modis_MOD09GQ(ndvi_path)
        # Give up if the hdf file cannot be opened
        if bands is None:
            print("Cannot handle this MODIS file: ", ndvi_path, ". Please check it again")
            return None

        qa, red, NIR, cols, rows, projection, geotransform = bands
        if any(band is None for band in (qa, red, NIR)):
            print("Cannot handle this MODIS file: ", ndvi_path, ". Please check it again")
            return None

        # NDVI from the RED and NIR reflectance bands
        return red, NIR, (NIR-red)/(NIR+red), projection, geotransform

    return None


In [56]:
def calculate_dates_daily(year, month):
    """
    Build the daily download intervals of one month.

    Args:
        year (int): Calendar year.
        month (int): Month number, 1-12.

    Returns:
        tuple(list[str], list[str]): ``(startdates, enddates)`` with one
        ``YYYY-MM-DD`` entry per day of the month. Each end date is the day
        after its start date, except for the last day of the month, whose end
        date is that same day (the intervals never leave the month).
    """
    # monthrange returns (weekday of first day, number of days in the month)
    total_days = calendar.monthrange(year, month)[1]
    year_text, month_text = str(year), str(month).zfill(2)

    all_days = [
        "{}-{}-{}".format(year_text, month_text, str(day).zfill(2))
        for day in range(1, total_days + 1)
    ]
    startdates = list(all_days)
    enddates = all_days[1:] + [all_days[-1]]
    return startdates, enddates

def download_and_process(year, month, products, tiles, n_processes, num_threads):
    """
    Download one month of daily images for every product, then preprocess.

    For each product a pool of ``n_processes`` workers runs
    ``MODIS_Downloader_DAILY`` once per day of the month. After all products
    are downloaded, ``MODIS_Data_Preprocessing`` is run for ``products[0]``
    only, with ``delete_files=False`` (the HDF files are kept).

    Args:
        year (int): Year to download.
        month (int): Month to download, 1-12.
        products (list[str]): Product identifiers; the first one is the one
            that gets preprocessed.
        tiles (str): Tile identifier(s) forwarded to the downloader.
        n_processes (int): Size of the download worker pool.
        num_threads (int): Worker count forwarded to the preprocessing step.

    Returns:
        None.
    """
    # Calculate the dates in the given month in the given year to download daily images
    startdates, enddates = calculate_dates_daily(year, month)

    for product in products:
        # The context manager terminates the workers even if a download result
        # raises, so no worker processes are left behind.
        with mp.Pool(n_processes) as pool:
            # Download the daily images
            results = [
                pool.apply_async(MODIS_Downloader_DAILY, (start, end, year, product, tiles))
                for start, end in zip(startdates, enddates)
            ]

            # Ensure all downloads are complete
            for res in results:
                res.get()

            pool.close()
            pool.join()

    # Preprocess the first product only; the HDF files are kept
    MODIS_Data_Preprocessing(year, products[0], num_threads, delete_files=False)

def main():
    """
    Download and preprocess the MODIS data for a range of years.

    Command-line options ``--year_begin`` (default 2020) and ``--year_end``
    (default 2022) select the years; ``year_end`` itself is excluded. Every
    month of every selected year is handed to ``download_and_process``.
    """
    cli = ArgumentParser()
    cli.add_argument('--year_begin', type=int, default=2020, help='Start year for data download')
    cli.add_argument('--year_end', type=int, default=2022, help='End year for data download')
    # parse_known_args ignores arguments the parser does not recognise
    args, _unrecognised = cli.parse_known_args()

    # Download settings
    n_processes = 18
    num_threads = 6
    tiles = "h18v04"  # Tiles to download, France is in h17v04 and h18v04
    products = ["MOD11A1.061", "MOD09GQ.061"]

    # np.arange excludes the stop value, so year_end is not downloaded
    years = list(np.arange(args.year_begin, args.year_end))
    months = range(1, 13)

    for year in years:
        for month in months:
            download_and_process(year, month, products, tiles, n_processes, num_threads)

# Run the whole download + preprocessing pipeline when this code is executed
# directly (a notebook cell also runs with __name__ == "__main__").
if __name__ == "__main__":
    main()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 . Please check it againCannot handle this MODIS file: Cannot handle this MODIS file: Cannot handle this MODIS file:  . Please check it again
  
Cannot handle this MODIS file: MODIS/MOD_2020_MOD11A1/hdfs_files/MOD11A1.A2020077.h18v04.061.2021007075155.hdf MODIS/MOD_2020_MOD11A1/hdfs_files/MOD11A1.A2020280.h18v04.061.2021017091526.hdfCannot handle this MODIS file: MODIS/MOD_2020_MOD11A1/hdfs_files/MOD11A1.A2020229.h18v04.061.2021014030135.hdf MODIS/MOD_2020_MOD11A1/hdfs_files/MOD11A1.A2020179.h18v04.061.2021012053407.hdf   . Please check it again . Please check it again. Please check it againMODIS/MOD_2020_MOD11A1/hdfs_files/MOD11A1.A2020129.h18v04.061.2021009024156.hdf
. Please check it again
 
Cannot handle this MODIS file: Cannot handle this MODIS file: 
Cannot handle this MODIS file: . Please check it again Cannot handle this MODIS file:  
 MODIS/MOD_2020_MOD11A1/hdfs_files/MOD11A1.A2020078.h18v04.061.2021007085043.hdf

End of the dataset download and preprocessing code.

In [76]:
from torch.utils.data import Dataset
import numpy as np

# Define Dataset
class DatasetCustom(Dataset):
    """Dataset pairing LST images with an NDVI image and the original LST image.

    All three inputs are indexed as ``array[i, :, :]``, i.e. as stacks of 2-D
    images. The length of the dataset is ``data_lst.shape[0]``; the sizes of
    the other two stacks are not checked.
    """

    def __init__(self, data_lst, data_nvdi, original_lst):
        # Each stack is an array of 2d images of size N x w x h
        self.data_lst, self.data_nvdi, self.original_lst = data_lst, data_nvdi, original_lst

    def __len__(self):
        n_images = self.data_lst.shape[0]
        return n_images

    def __getitem__(self, index):
        """Return ``(lst, nvdi, original_lst)`` for ``index``.

        The NDVI image is looked up at ``int(index / 2)``, so LST items ``2k``
        and ``2k + 1`` share NDVI image ``k``.
        """
        lst_image = self.data_lst[index,:,:]
        nvdi_image = self.data_nvdi[int(index/2),:,:]
        original_image = self.original_lst[index,:,:]
        return lst_image, nvdi_image, original_image

In [60]:
class Conv_ReLU_Block(nn.Module):
    """3x3 convolution (64 -> 64 channels, no bias, same spatial size) followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        # The in-place ReLU overwrites the convolution output, not the input x
        return self.relu(features)

In [62]:
class VDSR(nn.Module):
    """VDSR-style network: input conv, 18 Conv_ReLU_Block layers, output conv,
    plus a global skip connection that adds the input to the output.

    The input and output both have a single channel; all convolutions are 3x3
    with padding 1 and no bias, so the spatial size is preserved.
    """

    def __init__(self):
        super().__init__()
        self.residual_layer = self.make_layer(Conv_ReLU_Block, 18)
        self.input = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # Re-initialise every convolution with N(0, sqrt(2 / (kh * kw * out_channels)))
        convolutions = (module for module in self.modules() if isinstance(module, nn.Conv2d))
        for conv in convolutions:
            kernel_h, kernel_w = conv.kernel_size
            fan = kernel_h * kernel_w * conv.out_channels
            conv.weight.data.normal_(0, sqrt(2. / fan))

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` fresh instances of ``block`` in a Sequential."""
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def forward(self, x):
        features = self.relu(self.input(x))
        features = self.residual_layer(features)
        correction = self.output(features)
        # Global skip connection: the network output is added to the input
        return torch.add(correction, x)

num = 64

In [63]:
class DwSample(nn.Module):
    """Convolution (+ optional BatchNorm) + PReLU with a residual connection.

    ``forward`` adds the input to the block output, so the convolution has to
    keep the tensor shape (``inp == oup`` and ``stride == 1`` with the default
    "same" padding of ``(kernal_size - 1) / 2``).

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Convolution stride.
        kernal_size: Convolution kernel size.
        groups: Number of convolution groups.
        BN: When True, a BatchNorm2d layer is inserted after the convolution.
    """

    def __init__(self, inp, oup, stride, kernal_size = 3, groups=1, BN = False):
        super().__init__()
        same_padding = int((kernal_size - 1) / 2)
        stages = [nn.Conv2d(inp, oup, kernal_size, stride, same_padding, groups=groups)]
        if BN == True:
            stages.append(nn.BatchNorm2d(oup))
        stages.append(nn.PReLU())
        self.conv_dw = nn.Sequential(*stages)

    def forward(self, x):
        return torch.add(self.conv_dw(x), x)

In [64]:
class BasicBlock(nn.Module):
    """Residual block: conv (inp -> oup), optional BatchNorm, PReLU, conv (oup -> inp).

    The second convolution maps back to ``inp`` channels so that the result can
    be added to the block input in ``forward``; both convolutions use the same
    kernel size, stride, groups and a padding of ``(kernal_size - 1) / 2``.

    Args:
        inp: Number of input (and output) channels.
        oup: Number of channels between the two convolutions.
        stride: Stride of both convolutions.
        kernal_size: Kernel size of both convolutions.
        groups: Number of convolution groups.
        BN: When True, a BatchNorm2d layer follows the first convolution.
    """

    def __init__(self, inp, oup, stride, kernal_size=3, groups=1, BN = False):
        super().__init__()
        same_padding = int((kernal_size - 1) / 2)
        stages = [nn.Conv2d(inp, oup, kernal_size, stride, same_padding, groups=groups)]
        if BN == True:
            stages.append(nn.BatchNorm2d(oup))
        stages.append(nn.PReLU())
        stages.append(nn.Conv2d(oup, inp, kernal_size, stride, same_padding, groups=groups))
        self.conv_dw = nn.Sequential(*stages)

    def forward(self, x):
        return torch.add(self.conv_dw(x), x)


In [65]:
class UpSample(nn.Module):
    """Sub-pixel upsampling: 3x3 conv to ``f * upscale_factor**2`` channels,
    PReLU, then PixelShuffle back to ``f`` channels at a larger spatial size.

    Args:
        f: Number of input (and output) channels.
        upscale_factor: Spatial upscaling factor of the PixelShuffle.
    """

    def __init__(self, f, upscale_factor):
        super().__init__()
        expanded_channels = f * (upscale_factor ** 2)

        self.relu = nn.PReLU()
        self.conv = nn.Conv2d(f, expanded_channels, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        activated = self.relu(self.conv(x))
        return self.pixel_shuffle(activated)


In [66]:
class DMCN_prelu(nn.Module):
    """Encoder/decoder super-resolution network with PReLU activations.

    Structure (all 3x3 convolutions keep the spatial size unless noted):
      * five input convolutions (the last four followed by BatchNorm);
      * two stride-2 down-sampling convolutions, each followed by a stack of
        ``DwSample`` blocks;
      * two ``UpSample`` stages (x2 each); after each one the result is
        concatenated with the encoder features of the same resolution, reduced
        back to ``width`` channels by a 1x1 convolution and refined by a stack
        of ``BasicBlock``;
      * an output convolution whose result is added to the network input.

    ``make_layer`` and ``forward`` must be defined inside the class body: when
    they live in separate notebook cells they become module-level functions and
    constructing the model fails on ``self.make_layer``.

    Args:
        BN (bool): Forwarded to the ``DwSample``/``BasicBlock`` stacks to enable
            their BatchNorm layers. The BatchNorm layers of the input stage are
            always used.
        width (int): Number of feature channels.

    Note:
        The input has one channel. Its height and width should be divisible by
        4 so the up-sampled features match the encoder features when they are
        concatenated.
    """

    def __init__(self, BN=True, width = 64):
        super(DMCN_prelu, self).__init__()
        self.input1 = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=3, stride=1, padding=1, bias=False)
        self.input2 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN1 = nn.BatchNorm2d(width)
        self.input3 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN2 = nn.BatchNorm2d(width)
        self.input4 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN3 = nn.BatchNorm2d(width)
        self.input5 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN4 = nn.BatchNorm2d(width)
        self.down_sample1 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=2, padding=1, bias=False)
        self.Conv_DW_layers1 = self.make_layer(DwSample, 5, BN, width)

        self.down_sample2 = nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, stride=2, padding=1, bias=False)
        self.Conv_DW_layers2 = self.make_layer(DwSample, 2, BN, width)

        self.up_sample1 = UpSample(width,2)

        # 1x1 convolution reducing the concatenated (skip + up-sampled) features
        self.choose1 = nn.Conv2d(in_channels=width*2, out_channels=width, kernel_size=1, stride=1, padding=0, bias=False)
        self.resudial_layers1 = self.make_layer(BasicBlock, 2, BN, width)

        self.up_sample2 = UpSample(width,2)

        self.choose2 = nn.Conv2d(in_channels=width*2, out_channels=width, kernel_size=1, stride=1, padding=0, bias=False)
        self.resudial_layers2 = self.make_layer(BasicBlock, 5, BN, width)

        self.output = nn.Conv2d(in_channels=width, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)

        # A single PReLU (one learnable slope) shared by all input-stage activations
        self.relu = nn.PReLU()

    def make_layer(self, block, num_of_layer, BN, width):
        """Stack ``num_of_layer`` blocks built as ``block(width, width, 1, 3, 1, BN)``."""
        layers = []
        for _ in range(num_of_layer):
            layers.append(block(width, width, 1, 3, 1, BN))
        return nn.Sequential(*layers)

    def forward(self, x):
        residual = x
        # Input stage at full resolution; s1 is kept for the last skip connection
        s1 = self.relu(self.input1(x))
        s1 = self.input2(s1)
        s1 = self.relu(self.BN1(s1))
        s1 = self.input3(s1)
        s1 = self.relu(self.BN2(s1))
        s1 = self.input4(s1)
        s1 = self.relu(self.BN3(s1))
        s1 = self.input5(s1)
        s1 = self.relu(self.BN4(s1))

        # Encoder: two stride-2 down-samplings; s2 is the half-resolution skip
        out = self.down_sample1(s1)
        s2 = self.Conv_DW_layers1(out)

        out = self.down_sample2(s2)
        out = self.Conv_DW_layers2(out)

        # Decoder: up-sample, merge with the matching skip, reduce, refine
        out = self.up_sample1(out)
        out = torch.cat((s2, out), 1)
        out = self.choose1(out)
        out = self.resudial_layers1(out)

        out = self.up_sample2(out)
        out = torch.cat((s1, out), 1)
        out = self.choose2(out)
        out = self.resudial_layers2(out)

        out = self.output(out)
        # Global skip connection: the network predicts a correction to its input
        out = torch.add(out, residual)
        return out


""" Parts of the U-Net model """


' Parts of the U-Net model '

In [69]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2

    Two 3x3 convolutions with padding 1, each followed by BatchNorm2d and an
    in-place ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels

        first_stage = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        return self.double_conv(x)

In [70]:
class DoubleConv_Down(nn.Module):
    """convolution => [BN] => ReLU => convolution

    Same as ``DoubleConv`` but without BatchNorm/ReLU after the second
    convolution. Both convolutions are 3x3 with padding 1.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when omitted (or falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels

        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

In [71]:
class Down(nn.Module):
    """Downscaling by a factor of 2 followed by convolutions.

    Two variants are available:
      * ``res_down=False``: 2x2 max-pooling then ``DoubleConv``;
      * ``res_down=True``: a stride-2 2x2 convolution, a ``DoubleConv`` wrapped
        in a residual connection, then conv => BN => ReLU to ``out_channels``.

    The max-pool branch is always constructed (even when ``res_down`` is True)
    so the set of registered parameters is unchanged.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        res_down: Selects the residual variant.
    """

    def __init__(self, in_channels, out_channels, res_down=False):
        super(Down, self).__init__()
        self.res_down = res_down
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
        if self.res_down:
            self.in_conv = nn.Conv2d(in_channels, in_channels, kernel_size=2, stride=2)
            self.mid_conv = DoubleConv(in_channels, in_channels)
            self.out_conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

    def forward(self, x):
        if not self.res_down:
            return self.maxpool_conv(x)

        # Down-sample once and reuse the result for both the residual branch
        # and the skip connection (it used to be computed twice).
        reduced = self.in_conv(x)
        return self.out_conv(self.mid_conv(reduced) + reduced)



In [72]:
class Up(nn.Module):
    """Upscaling then double conv

    ``forward(x1, x2)`` upsamples ``x1`` by 2, reflect-pads it to the spatial
    size of the skip tensor ``x2``, concatenates both along the channel axis
    (``x2`` first) and applies a ``DoubleConv``.

    Args:
        in_channels: Channels after concatenation (input of the DoubleConv).
        out_channels: Channels of the output tensor.
        bilinear: When True use bilinear upsampling (the DoubleConv then uses
            ``in_channels // 2`` middle channels); otherwise use a transposed
            convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half_channels = in_channels // 2

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # inputs are NCHW: dim 2 is the height, dim 3 the width
        gap_h = x2.size()[2] - x1.size()[2]
        gap_w = x2.size()[3] - x1.size()[3]
        left, top = gap_w // 2, gap_h // 2

        # F.pad takes (left, right, top, bottom) for the last two dimensions
        x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top], mode='reflect')
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)


In [73]:
class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` to ``out_channels``.

    No activation is applied to the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


In [74]:
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with a skip connection: the conv block
        is created by <build_conv_block> and the skip connection is added in
        <forward>.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns a Sequential made of: padded 3x3 conv, normalization, ReLU,
        optional Dropout(0.5), padded 3x3 conv, normalization.
        Raises NotImplementedError for an unknown padding_type.
        """
        def padded_conv():
            # 3x3 convolution that keeps the spatial size, with the requested padding
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
            if padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        layers = padded_conv() + [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += padded_conv() + [norm_layer(dim)]

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function (with skip connections)"""
        return x + self.conv_block(x)  # add skip connections


In [75]:
class MRUNet(nn.Module):
    """U-Net with residual blocks at the bottleneck and a global skip connection.

    The encoder is a ``DoubleConv`` followed by four ``Down`` stages, the
    bottleneck is a stack of ``ResnetBlock``, and the decoder is four ``Up``
    stages that each receive the encoder features of the matching level. The
    output of the final 1x1 convolution is added to the network input, so
    ``n_classes`` has to be compatible with ``n_channels`` for that addition.

    Args:
        n_channels: Channels of the input image.
        n_classes: Channels of the output image.
        res_down: Forwarded to every ``Down`` stage.
        n_resblocks: Number of ``ResnetBlock`` at the bottleneck.
        padding_type, norm_layer, use_dropout, use_bias: Forwarded to every
            ``ResnetBlock``.
        bilinear: Forwarded to every ``Up`` stage; when True the channel counts
            of the bottleneck and of the first three decoder stages are halved.

    Note:
        ``self.up``, ``self.up_dc`` and ``self.drop`` are created but not used
        in ``forward``.
    """

    def __init__(self, n_channels=1, n_classes=1, res_down=False, n_resblocks=1, padding_type="reflect", norm_layer=nn.BatchNorm2d, use_dropout=False, use_bias=True, bilinear=False):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        factor = 2 if bilinear else 1
        bottleneck_channels = 1024 // factor

        self.inc = DoubleConv(n_channels, 64)

        ### Encoder
        self.down1 = Down(64, 128, res_down=res_down)
        self.down2 = Down(128, 256, res_down=res_down)
        self.down3 = Down(256, 512, res_down=res_down)
        self.down4 = Down(512, bottleneck_channels, res_down=res_down)

        ### Residual blocks
        self.resblocks = nn.Sequential(*[
            ResnetBlock(bottleneck_channels, padding_type, norm_layer, use_dropout, use_bias)
            for _ in range(n_resblocks)
        ])

        ### Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)
        self.up = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.up_dc = DoubleConv(64, 64)
        self.drop = nn.Dropout(p=0.3)

    def forward(self, x):
        # Encoder: keep the features of every level for the skip connections
        enc0 = self.inc(x)
        enc1 = self.down1(enc0)
        enc2 = self.down2(enc1)
        enc3 = self.down3(enc2)
        bottleneck = self.resblocks(self.down4(enc3))

        # Decoder: every stage merges with the encoder features of its level
        dec = self.up1(bottleneck, enc3)
        dec = self.up2(dec, enc2)
        dec = self.up3(dec, enc1)
        dec = self.up4(dec, enc0)

        # Global skip connection: the prediction is added to the input
        return self.outc(dec)+x