In [1]:
import torch
import os
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import cv2
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
import torchvision
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import time
import numpy as np
import random
from os.path import exists
from tqdm import tqdm



plt.ion()   # interactive mode

<contextlib.ExitStack at 0x7f126c8ef9d0>

In [2]:
# Access to uploaded files
from google.colab import drive
dir_prefix = 'drive/My Drive/Colab Notebooks/Dipl_Projekt'
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# base path of the dataset
DATASET_PATH = os.path.join(dir_prefix, 'ForestDataset8C')
# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# determine if we will be pinning memory during data loading
PIN_MEMORY = True if DEVICE == "cuda" else False

# define the number of channels in the input, number of classes,
# and number of levels in the U-Net model
NUM_CHANNELS = 8
NUM_CLASSES = 13
N_ROWS = 40
N_COLS = 40

# UNET layers
ENC_CHANNELS = (NUM_CHANNELS, 16, 32, 64, 128)
DEC_CHANNELS = (128, 64, 32, 16)

# define the input image dimensions
INPUT_IMAGE_WIDTH = 512 # The images will be resized to this before getting fed into the model
INPUT_IMAGE_HEIGHT = 512 # The images will be resized to this before getting fed into the model


# define the path to the base output directory
BASE_OUTPUT = os.path.join(dir_prefix, 'ForestSegmentationOutput')

_MODEL_NAME = '3percent_smoothly_weighted_more_complex_unet_nks_13_class_forest_59+11_epoch_0.001_lr_0.0001_wd'
PRETRAINED_MODEL_PATH = os.path.join(BASE_OUTPUT, f'{_MODEL_NAME}.pth')
RECONSTRUCTED_MASK_PATH = os.path.join(BASE_OUTPUT, 'Reconstructed_Masks', f'New_mask_from_{_MODEL_NAME}.tif')


In [31]:
def slice_coords(path):
		path_no_ext = os.path.splitext(path)[0]
		coords = os.path.basename(path_no_ext).split('_')[-2:]
		return tuple([int(x) for x in coords])

class ForestSegmentationDataset(Dataset):
    def __init__(self,
                 channels, # list so it defines order of cjannels since dict doesnt
                 imagePaths,
                 maskPaths,
                 transforms_img,
                 transforms_mask):
        # store the image and mask filepaths, and augmentation
        # transforms
        num_of_imgs = len(list(imagePaths.values())[0])
        for cname in channels:
            assert (cname in imagePaths) and (len(imagePaths[cname]) == num_of_imgs),'image paths for different channels of different length OR channel not present in imagePaths dict!!!'
            for imgpath in imagePaths[cname]:
                assert exists(imgpath), f"{imgpath} does not exist!"
        
        for mpath in maskPaths:
            assert exists(mpath), f"{mpath} does not exist!"

        self.channels = channels
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms_img = transforms_img
        self.transforms_mask = transforms_mask

    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(list(self.imagePaths.values())[0])
    
    def __getitem__(self, idx):
        channels_to_stack = [
            cv2.imread(
                self.imagePaths[cname][idx], 
                cv2.IMREAD_UNCHANGED
            ) for cname in self.channels
        ]
        image = np.stack(channels_to_stack) # This adds extra dimension at the beginning investigate if this does anything weird
        image = torch.tensor(np.float32(image / 255.0))
        mask = cv2.imread(self.maskPaths[idx], cv2.IMREAD_UNCHANGED)
        mask[mask < 0] = 0 # Background class should be zero not -3.4e+38
        mask = np.float32(mask)
        # check to see if we are applying any transformations
        # apply the transformations to both image and its mask
        if self.transforms_img is not None:
          image = self.transforms_img(image)

        if self.transforms_mask is not None:
          mask = self.transforms_mask(mask)
          mask = mask.to(torch.long).squeeze()
        # return a tuple of the image and its mask
        return (image, mask)
    
    def get_coords(self, idx):
        img_paths = [self.imagePaths[cname][idx] for cname in self.channels]
        coords = slice_coords(img_paths[0])
        for p in img_paths:
            assert coords == slice_coords(p), f"Coords do not match for all images on index {idx}"
        return coords


In [7]:
class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and RELU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)
    
    def forward(self, x):
        # apply CONV => RELU => CONV block to the inputs and return it
        return self.conv2(self.relu(self.conv1(x)))
  
class Encoder(Module):
    def __init__(self, channels=ENC_CHANNELS):
        super().__init__()
        # store the encoder blocks and maxpooling layer
        self.encBlocks = ModuleList(
          [Block(channels[i], channels[i + 1])
            for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)
    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []
        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the outputs, and then apply maxpooling on the output
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)
        # return the list containing the intermediate outputs
        return blockOutputs

class Decoder(Module):
    def __init__(self, channels=DEC_CHANNELS):
        super().__init__()
        # initialize the number of channels, upsampler blocks, and
        # decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
          [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
            for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
          [Block(channels[i], channels[i + 1])
            for i in range(len(channels) - 1)])
    
    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the upsampler blocks
            x = self.upconvs[i](x)
            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        # return the final decoder output
        return x
    
    def crop(self, encFeatures, x):
        # grab the dimensions of the inputs, and crop the encoder
        # features to match the dimensions
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        # return the cropped features
        return encFeatures

class UNet(Module):
    def __init__(self, encChannels=ENC_CHANNELS,
        decChannels=DEC_CHANNELS,
        nbClasses=NUM_CLASSES, retainDim=True,
        outSize=(INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH)):
        super().__init__()
        # initialize the encoder and decoder
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # initialize the regression head and store the class variables
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        # grab the features from the encoder
        encFeatures = self.encoder(x)
        # pass the encoder features through decoder making sure that
        # their dimensions are suited for concatenation
        decFeatures = self.decoder(encFeatures[::-1][0],
          encFeatures[::-1][1:])
        # pass the decoder features through the regression head to
        # obtain the segmentation mask
        map = self.head(decFeatures)
        # check to see if we are retaining the original output
        # dimensions and if so, then resize the output to match them
        if self.retainDim:
            map = F.interpolate(map, self.outSize)
        # return the segmentation map
        return map

In [32]:
def load_imgs():
		Imgs = dict()
		for i in range(8):
				channel_path = os.path.join(DATASET_PATH, f'wv2_{i}')
				Imgs[i] = [
						os.path.join(channel_path, img) 
						for img in sorted(os.listdir(channel_path))
				]
		maskPaths = [os.path.join(DATASET_PATH, 'nks', 'nks_split_0_0.tif')] * len(Imgs[0])
		for k in Imgs:
				assert len(Imgs[0]) == len(Imgs[k]), f"List of paths of diffrenet length for differen channels, problematic channel: {k}"
		for i in range(len(Imgs[0])):
				assert all([slice_coords(Imgs[0][i]) == slice_coords(Imgs[k][i]) for k in Imgs]), "Slices out of order for different channels/mask of dataset"
		return Imgs, maskPaths

In [27]:
# There were some duplicates in certain folders with ...(1).tif kind of names so I removed any files with a parenthesis
for i in range(8):
    channel_path = os.path.join(DATASET_PATH, f'wv2_{i}')
    for path in os.listdir(channel_path):
        if '(' in path: os.remove(os.path.join(channel_path, path))

In [33]:
Images_dataset, Masks_dataset = load_imgs()
# define transformations
img_transforms = transforms.Compose([
   # transforms.ToPILImage(),
 	  transforms.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH), antialias=True),
	 # transforms.ToTensor()
])

mask_transforms = transforms.Compose([
    transforms.ToPILImage(),
 	  transforms.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH), interpolation=torchvision.transforms.InterpolationMode.NEAREST),
	  transforms.ToTensor()
])

#create the train and test datasets
DS = ForestSegmentationDataset(
		channels=list(np.arange(8)),
		imagePaths=Images_dataset,
		maskPaths=Masks_dataset,
		transforms_img=img_transforms, 
		transforms_mask=mask_transforms
		)
print(f"[INFO] found {len(DS)} examples in the dataset...")

[INFO] found 1600 examples in the dataset...


In [34]:
# initialize our UNet model
print(f"Loading pretrained model {PRETRAINED_MODEL_PATH}")
unet = torch.load(PRETRAINED_MODEL_PATH).to(DEVICE)

Loading pretrained model drive/My Drive/Colab Notebooks/Dipl_Projekt/ForestSegmentationOutput/3percent_smoothly_weighted_more_complex_unet_nks_13_class_forest_59+11_epoch_0.001_lr_0.0001_wd.pth


In [44]:
reconstructed_mask = np.zeros((N_ROWS * INPUT_IMAGE_HEIGHT, N_COLS * INPUT_IMAGE_WIDTH))

unet.eval()
with torch.no_grad():
    for i in tqdm(range(len(DS))):
        (input, target) = DS[i]
        (input, target) = (input[None, :, :, :], target[None, :, :])
        (input, target) = (input.to(DEVICE), target.to(DEVICE))
        pred = unet(input)
        pred = pred[0]
        pred_class = pred.softmax(dim=0).argmax(dim=0).cpu().numpy()
        coords = DS.get_coords(i)
        (i1, i2) = (coords[0] * INPUT_IMAGE_HEIGHT, (coords[0]+1) * INPUT_IMAGE_HEIGHT)
        (j1, j2) = (coords[1] * INPUT_IMAGE_WIDTH, (coords[1]+1) * INPUT_IMAGE_WIDTH)
        reconstructed_mask[i1:i2, j1:j2] = pred_class

100%|██████████| 1600/1600 [47:15<00:00,  1.77s/it]


In [45]:
print(reconstructed_mask.shape)

(20480, 20480)


In [46]:
reconstructed_mask

array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]])

In [47]:
cv2.imwrite(RECONSTRUCTED_MASK_PATH,reconstructed_mask)

True

In [4]:
# Testing whether write worked
loaded_reconstructed = cv2.imread(RECONSTRUCTED_MASK_PATH, cv2.IMREAD_UNCHANGED)
#print(f'Loaded equal to reconstructed: {(loaded_reconstructed == reconstructed_mask).all()}')

In [8]:
unique, counts = np.unique(loaded_reconstructed, return_counts=True)


[2.83672810e-04 3.37331152e-02 7.16998315e-02 7.24090338e-03
 9.48327136e-02 1.35889366e-01 4.82787879e-01 6.66609049e-02
 1.06641653e-01 2.29961872e-04]
0.0: 0.028367280960083008%
1.0: 3.3733115196228027%
2.0: 7.169983148574829%
3.0: 0.7240903377532959%
4.0: 9.483271360397339%
7.0: 13.58893656730652%
8.0: 48.27878785133362%
9.0: 6.666090488433837%
10.0: 10.664165258407593%
11.0: 0.022996187210083008%


In [10]:

for i,u in enumerate(unique):
    print(f"{u}: {((counts/counts.sum())[i] * 100):.{2}f}%")

0.0: 0.03%
1.0: 3.37%
2.0: 7.17%
3.0: 0.72%
4.0: 9.48%
7.0: 13.59%
8.0: 48.28%
9.0: 6.67%
10.0: 10.66%
11.0: 0.02%


In [1]:
from google.colab import runtime


In [None]:
runtime.unassign()