In [None]:
! unzip -q '/content/drive/MyDrive/temp/TestDatav2.zip' -d './data'

In [None]:
import os

from PIL import Image
import numpy as np
import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.optim as optim
import torch.nn.functional as F

from torch import nn
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.io import read_image, ImageReadMode
from torch.utils.data import DataLoader, Dataset, random_split, ConcatDataset

In [None]:
# Loading Dataset
class TreeSegmentationDataset(Dataset):
    """Image / binary-mask pairs read from ``<base_dir>/TestDatav2``.

    Images are read from the ``Images/`` sub-folder and masks from the ``Seg/``
    sub-folder, using the same file name for both.

    Args:
        transform: optional callable applied to BOTH the image and the mask.
        target_transform: optional callable applied to the mask only, after
            ``transform``.
        base_dir: directory that contains the extracted ``TestDatav2`` folder.
            Defaults to ``'./data'``, the location this class always used.
    """

    def __init__(self, transform=None, target_transform=None, base_dir='./data'):

        self.BASEDIR = base_dir
        self.FOLDER = f"{self.BASEDIR}/TestDatav2"
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sorting makes index -> file stable across runs and machines.
        self.image_idx = sorted(os.listdir(f"{self.FOLDER}/Images/"))
        self.n_images = len(self.image_idx)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of files found in the ``Images/`` folder."""
        return self.n_images

    def __getitem__(self, idx):
        """Return ``(img, seg)`` for the ``idx``-th file name.

        ``img`` is the RGB image divided by 255. ``seg`` is built from the second
        channel (alpha) of the mask file read as GRAY_ALPHA: 1 where the scaled
        alpha exceeds 0.5, else 0, with a leading channel axis added.
        """
        image_name = self.image_idx[idx]
        img = read_image(f"{self.FOLDER}/Images/{image_name}", mode=ImageReadMode.RGB) / 255

        seg_with_alpha = read_image(f"{self.FOLDER}/Seg/{image_name}", mode=ImageReadMode.GRAY_ALPHA) / 255
        seg = torch.where(seg_with_alpha[1] > 0.5, torch.tensor(1), torch.tensor(0)).unsqueeze(0)

        if self.transform:
            # NOTE(review): the same callable is invoked separately on image and
            # mask; a transform with internal randomness would presumably
            # desynchronise the pair - confirm only deterministic ones are used.
            img = self.transform(img)
            seg = self.transform(seg)

        if self.target_transform:
            seg = self.target_transform(seg)

        return img, seg

In [None]:

# Data-loading settings for evaluation.
BATCH_SIZE, SHUFFLE = 1, False

# Pick the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
    device_name = torch.cuda.get_device_name(0)
    print(f"Setting compute device to {DEVICE} on {device_name}")
else:
    DEVICE = 'cpu'

Setting compute device to cuda on Tesla T4


In [None]:
def visualize(**images):
    """Draw the given images side by side in a single row and show the figure.

    Each keyword becomes one panel; the keyword name (underscores replaced by
    spaces, title-cased) is used as the panel title. Every value must support
    ``.permute(1, 2, 0)``, i.e. it is treated as a channel-first tensor.
    """
    panel_count = len(images)
    plt.figure(figsize=(10, 8))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.axis(False)
        # The panel title is derived from the keyword name.
        plt.title(label.replace('_', ' ').title(), fontsize=20)
        # Move axis 0 (channels) last, the layout imshow expects.
        plt.imshow(images[label].permute(1, 2, 0), cmap='gray')
    plt.show()

In [None]:

class FuzzyPooling(nn.Module):
    """Pool each ``kernel_size`` x ``kernel_size`` window with ``self.t_norm``.

    ``t_norm`` is ``torch.max``, so every output element is the maximum of its
    window. Windows that do not fit entirely inside the input are dropped by
    ``unfold``; no padding is applied.

    Args:
        kernel_size: side length of the square pooling window.
        stride: step between consecutive windows along both spatial axes.
    """

    def __init__(self, kernel_size, stride):
        super(FuzzyPooling, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # Reduction applied to every window. This is the MAXIMUM operator; any
        # replacement must accept ``(tensor, dim=...)`` and return a tuple whose
        # first element holds the reduced values, as torch.max / torch.min do.
        self.t_norm = torch.max

    def forward(self, x):
        """Reduce a 4-D ``(batch, channels, H, W)`` tensor window by window.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"FuzzyPooling expects a 4-D (batch, channels, H, W) tensor, got {x.dim()}-D"
            )

        # Two unfolds give (batch, channels, out_h, out_w, k, k): one k x k patch
        # per output location.
        windows = x.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)
        # Collapse the two trailing patch axes one after the other; after the
        # first reduction the remaining patch axis is again at position 4.
        out = self.t_norm(windows, dim=4)[0]
        out = self.t_norm(out, dim=4)[0]

        return out

class conv_block(nn.Module):
    """Two consecutive (3x3 convolution -> batch norm -> ReLU) stages.

    Both convolutions use padding 1, so the spatial size of the input is kept;
    the channel count goes from ``in_c`` to ``out_c`` in the first stage.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Attribute names and creation order determine the state_dict keys and
        # must stay as they are for saved models to load.
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Apply both stages to ``inputs`` and return the activated features."""
        first_stage = self.relu(self.bn1(self.conv1(inputs)))
        return self.relu(self.bn2(self.conv2(first_stage)))


class encoder_block(nn.Module):
    """Encoder stage: a ``conv_block`` followed by 2x2, stride-2 ``FuzzyPooling``."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = FuzzyPooling(kernel_size=2, stride=2)

    def forward(self, inputs):
        """Return ``(features, pooled)``: the conv output and its pooled version.

        The un-pooled features are what the decoder later receives as the skip
        connection.
        """
        features = self.conv(inputs)
        return features, self.pool(features)

class decoder_block(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, ``conv_block``.

    Args:
        in_c: channels of the incoming (coarser) feature map.
        out_c: channels produced by the upsampling and by the final conv block;
            the skip tensor is expected to carry ``out_c`` channels as well,
            since the conv block is built for ``2 * out_c`` input channels.
        kernal_size: kernel size of the stride-2 transposed convolution
            (spelling kept because callers pass it by keyword).
    """

    def __init__(self, in_c, out_c, kernal_size=2):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernal_size, stride=2, padding=0)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        """Upsample ``inputs``, join it with ``skip`` along channels, then convolve."""
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Takes a 3-channel input and returns a single-channel map of raw scores (no
    activation is applied to the output).
    """

    def __init__(self):
        super().__init__()
        # Contracting path: 3 -> 64 -> 128 -> 256 -> 512 channels.
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)
        # Bottleneck.
        self.b = conv_block(512, 1024)
        # Expanding path. Only the first decoder uses a 3x3 transposed conv,
        # which at stride 2 / padding 0 yields 2*n + 1 outputs for n inputs.
        # NOTE(review): presumably chosen so the result matches an odd-sized
        # deepest skip map - confirm against the training image size.
        self.d1 = decoder_block(1024, 512, kernal_size=3)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)
        # 1x1 convolution down to one output channel.
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Run the encoder, bottleneck and decoder; return the 1-channel output."""
        skips = []
        x = inputs
        for encoder in (self.e1, self.e2, self.e3, self.e4):
            skip, x = encoder(x)
            skips.append(skip)

        x = self.b(x)

        # Decoders consume the skip tensors deepest-first, hence pop().
        for decoder in (self.d1, self.d2, self.d3, self.d4):
            x = decoder(x, skips.pop())

        return self.outputs(x)

In [None]:
# Loss Function
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss: ``sqrt(MSE(pred, act) + eps)``.

    Args:
        eps: small constant added under the square root. The derivative of
            ``sqrt`` is infinite at 0, so without it a perfect prediction
            (MSE == 0) back-propagates NaN gradients. With the default the loss
            bottoms out at ``sqrt(eps)`` instead of exactly 0.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps
        # Built once here instead of on every forward call.
        self.mse = nn.MSELoss()

    def forward(self, pred, act):
        """Return the scalar RMSE between ``pred`` and ``act``."""
        return torch.sqrt(self.mse(pred, act) + self.eps)

In [None]:
def iou_score(mask1, mask2, empty_score=1.0):
    """Intersection-over-union of two masks, returned as a Python float.

    Non-zero elements count as foreground (the masks go through
    ``torch.logical_and`` / ``torch.logical_or``).

    Args:
        mask1, mask2: tensors of broadcast-compatible shape.
        empty_score: value returned when neither mask has any foreground, i.e.
            the union is empty and the ratio would be 0/0 (NaN). Defaults to
            1.0, treating two empty masks as perfect agreement.
    """
    intersection = torch.logical_and(mask1, mask2).sum()
    union = torch.logical_or(mask1, mask2).sum()

    # Guard the 0/0 case so a single empty pair cannot turn an average into NaN.
    if union.item() == 0:
        return float(empty_score)

    return (intersection.float() / union.float()).item()

In [None]:
# SECURITY: this loads a whole pickled object (not just a state_dict), which
# executes arbitrary pickle code - only load checkpoints you trust.
# `weights_only=False` is needed on PyTorch releases where torch.load defaults
# to weights-only loading; `map_location` lets a GPU-saved file open on CPU.
# NOTE(review): presumably an instance of the UNet class defined earlier, whose
# classes must be importable at unpickling time - confirm.
model = torch.load('/content/drive/MyDrive/Colab Notebooks/models/Trees/FuzzyUnetOnDatasetv4_8bf15.pth', map_location=DEVICE, weights_only=False) # FUZZY
# Inference mode: BatchNorm layers must use their running statistics rather
# than the statistics of each single-image batch.
model = model.to(DEVICE).eval()

In [None]:
# model = torch.load('/content/drive/MyDrive/Colab Notebooks/models/Trees/UnetOnDatasetv4_399f8.pth') # NORMAL

In [None]:
# Build the test dataset with no transforms: samples come back exactly as
# TreeSegmentationDataset.__getitem__ reads them.
dataset = TreeSegmentationDataset()
# _, test_data = random_split(dataset, [0.5, 0.5], generator=torch.Generator().manual_seed(42))

In [None]:
# Use the notebook-level settings instead of repeating the literals.
# Keep BATCH_SIZE at 1: the evaluation cells index `seg[0]` / `pred_raw[0]`
# and therefore only look at the first sample of every batch.
test_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE)

In [None]:
# Number of batches the loader will yield.
len(test_loader)

2000

In [None]:
# Mean IoU over the test set, skipping near-empty ground-truth masks.
MIN_MASK_PIXELS = 10000  # ground-truth masks with fewer foreground pixels are skipped
PRED_QUANTILE = 0.55     # per-image quantile of the raw output used as the binarisation threshold

total_score = 0
n_dp = 0  # scored samples; starts at 0 so the count and the mean are exact

model.eval()  # BatchNorm must use running statistics, not single-image batch statistics
with torch.no_grad():  # inference only: skip building the autograd graph
    for image, seg in tqdm(test_loader, total=len(test_loader)):
        if torch.sum(seg[0]) < MIN_MASK_PIXELS: continue # Filter out empty or very sparse segmentation masks

        # Make predictions from the model
        pred_raw = model(image.to(DEVICE)).cpu()

        # Threshold the prediction: pixels above the image's own quantile become foreground
        thres = torch.quantile(pred_raw[0], PRED_QUANTILE)
        pred = torch.where(pred_raw[0] > thres, torch.tensor(1), torch.tensor(0))

        # Compute IOU
        total_score += iou_score(seg[0], pred)
        n_dp += 1

if n_dp:
    print(f"Average Score of {n_dp} enteries= {total_score/n_dp}")
else:
    # Every mask was filtered out; avoid dividing by zero.
    print("No masks passed the sparsity filter; nothing to score.")

  0%|          | 0/2000 [00:00<?, ?it/s]

Average Score of 473 enteries= 0.6077072543680794


## Visualize sample predictions

In [None]:
# Show image / ground truth / prediction for the first masks that pass the sparsity filter.
MIN_MASK_PIXELS = 10000  # ground-truth masks with fewer foreground pixels are skipped
# NOTE(review): the scoring cell thresholds at the 0.55 quantile while this cell
# uses 0.75, so the figures are not the masks that were scored - confirm intent.
VIS_QUANTILE = 0.75
MAX_FIGURES = 32         # same number of figures the previous `count > 30` check produced

count = 0  # figures shown so far
model.eval()  # BatchNorm must use running statistics, not single-image batch statistics
with torch.no_grad():  # inference only: skip building the autograd graph
    for image, seg in tqdm(test_loader, total=len(test_loader)):
        if torch.sum(seg[0]) < MIN_MASK_PIXELS: continue # Filter out empty or very sparse segmentation masks

        # Make predictions from the model
        pred_raw = model(image.to(DEVICE)).cpu()

        # Threshold the prediction
        thres = torch.quantile(pred_raw[0], VIS_QUANTILE)
        pred = torch.where(pred_raw[0] > thres, torch.tensor(1), torch.tensor(0))

        # Visualize
        visualize(Image=image[0], Seg=seg[0], Pred=pred)
        count += 1
        if count >= MAX_FIGURES:
            break
    