In [1]:
import random
from functools import partial
import matplotlib
import matplotlib.image as mpimg
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import os
from google.colab import drive
!pip install segmentation_models_pytorch
drive.mount('/content/gdrive')
import seaborn as sns
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import Unet, UnetPlusPlus, DeepLabV3, PSPNet
import segmentation_models_pytorch
!pip install torchmetrics
from torchmetrics.functional import f1_score

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.3.4-py3-none-any.whl.metadata (30 kB)
Collecting efficientnet-pytorch==0.7.1 (from segmentation_models_pytorch)
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pretrainedmodels==0.7.4 (from segmentation_models_pytorch)
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting timm==0.9.7 (from segmentation_models_pytorch)
  Downloading timm-0.9.7-py3-none-any.whl.metadata (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
Collecting munch (from pretrainedmodels==0.7.4->segmentation_models_pytorch)
  Downloading munch-4.0.0-py2.py3-none-any.whl.metadata (5.9 kB)
Downloading segm

In [2]:
###
###
# Paths
root_dir = "data/training/"
image_dir = root_dir + "images_extended_4_shadow_patches/"
gt_dir = root_dir + "groundtruth_extended_4_shadow_patches/"

GRADIENT_COLORS = False
if GRADIENT_COLORS:
  test_image_dir = "data/test_images_extended/"
else:
  test_image_dir = "data/test_set_images/"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
###
###
# Model related methods
class PreloadedDataset(Dataset):
    def __init__(self, images, ground_truth):
        self.images = images
        self.ground_truth = ground_truth

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        ground_truth = self.ground_truth[idx]
        return image, ground_truth

def init_weights(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

def split(dataset, tr,val):
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [tr, val])
    return train_dataset, val_dataset

def get_dataloaders(train_dataset, val_dataset, batch_size):
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, val_loader


###
###
# Image handling

def load_npy(npy_file_path):
  array = np.load(npy_file_path)
  return array

def load_image(infilename):
    data = mpimg.imread(infilename)
    return data

def get_image_name(image_dir):
    file_name = os.path.basename(image_dir)
    return file_name

def img_crop(im, w, h):
    list_patches = []
    imgwidth = im.shape[0]
    imgheight = im.shape[1]

    is_2d = len(im.shape) < 3
    for i in range(0, imgheight, h):
        for j in range(0, imgwidth, w):
            if is_2d:
                im_patch = im[j : j + w, i : i + h]
            else:
                im_patch = im[j : j + w, i : i + h, :]
            list_patches.append(im_patch)
    return list_patches

def predict_patches(prediction, patch_threshold):
    patches = img_crop(prediction, 16, 16)
    iter = int(np.sqrt(len(patches)))

    results = np.zeros((iter, iter))
    c = 0
    for j in range(iter):
      for i in range(iter):
        if patches[c].mean() > patch_threshold:
          results[i, j] = 1
        c = c + 1
    return results

def extract_number(folder_name):
    return int(folder_name.split('_')[1])

###
###
# Submission generation

def masks_to_submission(submission_filename, results):
    """Converts images into a submission file"""
    img_number = 0
    with open(submission_filename, 'w') as f:
        f.write('id,prediction\n')
        for res in results:
            img_number = img_number + 1
            f.writelines('{}\n'.format(s) for s in mask_to_submission_strings(res, img_number))

def mask_to_submission_strings(result, img_number):
    """Reads a single image and outputs the strings that should go into the submission file"""
    im = result
    patch_size = 16
    for j in range(0, im.shape[1]):
        for i in range(0, im.shape[0]):
            label = result[i, j]

            yield("{:03d}_{}_{},{}".format(img_number, j*16, i*16, int(label)))

def write_predictions_to_file(predictions, labels, filename):
    max_labels = np.argmax(labels, 1)
    max_predictions = np.argmax(predictions, 1)
    file = open(filename, "w")
    n = predictions.shape[0]
    for i in range(0, n):
        file.write(max_labels(i) + " " + max_predictions(i))
    file.close()

def remove_islands(prediction):
  prediction = torch.tensor(prediction)
  kernel = torch.ones((1, 1, 3, 3), dtype=prediction.dtype)
  pred = torch.unsqueeze(prediction, dim=0)
  neighbor_count = F.conv2d(pred, kernel, padding=1)
  mask = (prediction == 1) & (neighbor_count == 1)
  prediction = prediction.squeeze()
  mask = mask.reshape((38,38))
  prediction[mask] = 0
  prediction = prediction.reshape((38,38))
  return prediction

def apply_crf(image, prob_map, sxy1=3, compat1=25, sxy2=50, srgb=30, compat2=10):

  # CRF stuff starts
  input_image = image.cpu().detach().numpy().copy() * 255
  input_image = input_image.astype('uint8')
  H,W,_ = input_image.shape

  # Create the DenseCRF model
  d = dcrf.DenseCRF2D(W, H, 2)  # W x H and 2 classes (road, non-road)

  unary = unary_from_softmax(prob_map)
  d.setUnaryEnergy(unary)
  d.addPairwiseGaussian(sxy=3, compat=25)
  d.addPairwiseBilateral(sxy=50, srgb=30, rgbim=input_image, compat=10)

  # Perform CRF inference
  # Produces a refined probability map with the same shape as input
  max_iter = 10  # Number of CRF iterations
  Q = d.inference(max_iter)

  crf_output = np.array(Q).reshape((2, H, W))
  crf_output = crf_output[1]  # Extract the "road" class probabilities

  final_segmentation = (crf_output > 0.5).astype(np.uint8)  # Binary mask

  return final_segmentation

###
###
# Metrics

def compute_f1(res, ans):
    """
    Precision: tp/(tp+fp)
    Recall: tp/(tp+fn)
    F1 = 2 * (precision * recall) / (precision + recall)
    """

    tp = torch.sum((res == 1) & (ans == 1))
    fp = torch.sum((res == 1) & (ans != 1))
    fn = torch.sum((res != 1) & (ans == 1))

    precision = tp / (tp + fp) if (tp + fp) != 0 else torch.tensor(0.0)
    recall = tp / (tp + fn) if (tp + fn) != 0 else torch.tensor(0.0)
    if precision + recall == 0:
        return torch.tensor(0.0)
    else:
        return 2 * (precision * recall) / (precision + recall)


In [4]:
model = UnetPlusPlus(
      encoder_name="resnet50",        # Choose encoder
      encoder_weights="imagenet",    # Use pre-trained ImageNet weights
      classes=1,                     # Number of output classes
      activation=None               # No activation, as it's handled in loss/metrics
)

model2 = DeepLabV3(
      encoder_name="resnet50",        # Choose encoder
      encoder_weights="imagenet",    # Use pre-trained ImageNet weights
      classes=1,                     # Number of output classes
      activation=None               # No activation, as it's handled in loss/metrics
)
model.load_state_dict(torch.load("/content/gdrive/MyDrive/ML/predictions/U_NET++_1_base_case_100_im.pth", map_location=torch.device('cpu')))
model.load_state_dict(torch.load("/content/gdrive/MyDrive/ML/predictions/DEEPLAB_1_base_case_100_im.pth", map_location=torch.device('cpu')))

model.eval()

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 258MB/s]
  model.load_state_dict(torch.load("/content/gdrive/MyDrive/ML/predictions/U_NET++_1_base_case_100_im.pth", map_location=torch.device('cpu')))
  model.load_state_dict(torch.load("/content/gdrive/MyDrive/ML/predictions/DEEPLAB_1_base_case_100_im.pth", map_location=torch.device('cpu')))


RuntimeError: Error(s) in loading state_dict for UnetPlusPlus:
	Missing key(s) in state_dict: "decoder.blocks.x_0_0.conv1.0.weight", "decoder.blocks.x_0_0.conv1.1.weight", "decoder.blocks.x_0_0.conv1.1.bias", "decoder.blocks.x_0_0.conv1.1.running_mean", "decoder.blocks.x_0_0.conv1.1.running_var", "decoder.blocks.x_0_0.conv2.0.weight", "decoder.blocks.x_0_0.conv2.1.weight", "decoder.blocks.x_0_0.conv2.1.bias", "decoder.blocks.x_0_0.conv2.1.running_mean", "decoder.blocks.x_0_0.conv2.1.running_var", "decoder.blocks.x_0_1.conv1.0.weight", "decoder.blocks.x_0_1.conv1.1.weight", "decoder.blocks.x_0_1.conv1.1.bias", "decoder.blocks.x_0_1.conv1.1.running_mean", "decoder.blocks.x_0_1.conv1.1.running_var", "decoder.blocks.x_0_1.conv2.0.weight", "decoder.blocks.x_0_1.conv2.1.weight", "decoder.blocks.x_0_1.conv2.1.bias", "decoder.blocks.x_0_1.conv2.1.running_mean", "decoder.blocks.x_0_1.conv2.1.running_var", "decoder.blocks.x_1_1.conv1.0.weight", "decoder.blocks.x_1_1.conv1.1.weight", "decoder.blocks.x_1_1.conv1.1.bias", "decoder.blocks.x_1_1.conv1.1.running_mean", "decoder.blocks.x_1_1.conv1.1.running_var", "decoder.blocks.x_1_1.conv2.0.weight", "decoder.blocks.x_1_1.conv2.1.weight", "decoder.blocks.x_1_1.conv2.1.bias", "decoder.blocks.x_1_1.conv2.1.running_mean", "decoder.blocks.x_1_1.conv2.1.running_var", "decoder.blocks.x_0_2.conv1.0.weight", "decoder.blocks.x_0_2.conv1.1.weight", "decoder.blocks.x_0_2.conv1.1.bias", "decoder.blocks.x_0_2.conv1.1.running_mean", "decoder.blocks.x_0_2.conv1.1.running_var", "decoder.blocks.x_0_2.conv2.0.weight", "decoder.blocks.x_0_2.conv2.1.weight", "decoder.blocks.x_0_2.conv2.1.bias", "decoder.blocks.x_0_2.conv2.1.running_mean", "decoder.blocks.x_0_2.conv2.1.running_var", "decoder.blocks.x_1_2.conv1.0.weight", "decoder.blocks.x_1_2.conv1.1.weight", "decoder.blocks.x_1_2.conv1.1.bias", "decoder.blocks.x_1_2.conv1.1.running_mean", "decoder.blocks.x_1_2.conv1.1.running_var", "decoder.blocks.x_1_2.conv2.0.weight", "decoder.blocks.x_1_2.conv2.1.weight", "decoder.blocks.x_1_2.conv2.1.bias", "decoder.blocks.x_1_2.conv2.1.running_mean", "decoder.blocks.x_1_2.conv2.1.running_var", "decoder.blocks.x_2_2.conv1.0.weight", "decoder.blocks.x_2_2.conv1.1.weight", "decoder.blocks.x_2_2.conv1.1.bias", "decoder.blocks.x_2_2.conv1.1.running_mean", "decoder.blocks.x_2_2.conv1.1.running_var", "decoder.blocks.x_2_2.conv2.0.weight", "decoder.blocks.x_2_2.conv2.1.weight", "decoder.blocks.x_2_2.conv2.1.bias", "decoder.blocks.x_2_2.conv2.1.running_mean", "decoder.blocks.x_2_2.conv2.1.running_var", "decoder.blocks.x_0_3.conv1.0.weight", "decoder.blocks.x_0_3.conv1.1.weight", "decoder.blocks.x_0_3.conv1.1.bias", "decoder.blocks.x_0_3.conv1.1.running_mean", "decoder.blocks.x_0_3.conv1.1.running_var", "decoder.blocks.x_0_3.conv2.0.weight", "decoder.blocks.x_0_3.conv2.1.weight", "decoder.blocks.x_0_3.conv2.1.bias", "decoder.blocks.x_0_3.conv2.1.running_mean", "decoder.blocks.x_0_3.conv2.1.running_var", "decoder.blocks.x_1_3.conv1.0.weight", "decoder.blocks.x_1_3.conv1.1.weight", "decoder.blocks.x_1_3.conv1.1.bias", "decoder.blocks.x_1_3.conv1.1.running_mean", "decoder.blocks.x_1_3.conv1.1.running_var", "decoder.blocks.x_1_3.conv2.0.weight", "decoder.blocks.x_1_3.conv2.1.weight", "decoder.blocks.x_1_3.conv2.1.bias", "decoder.blocks.x_1_3.conv2.1.running_mean", "decoder.blocks.x_1_3.conv2.1.running_var", "decoder.blocks.x_2_3.conv1.0.weight", "decoder.blocks.x_2_3.conv1.1.weight", "decoder.blocks.x_2_3.conv1.1.bias", "decoder.blocks.x_2_3.conv1.1.running_mean", "decoder.blocks.x_2_3.conv1.1.running_var", "decoder.blocks.x_2_3.conv2.0.weight", "decoder.blocks.x_2_3.conv2.1.weight", "decoder.blocks.x_2_3.conv2.1.bias", "decoder.blocks.x_2_3.conv2.1.running_mean", "decoder.blocks.x_2_3.conv2.1.running_var", "decoder.blocks.x_3_3.conv1.0.weight", "decoder.blocks.x_3_3.conv1.1.weight", "decoder.blocks.x_3_3.conv1.1.bias", "decoder.blocks.x_3_3.conv1.1.running_mean", "decoder.blocks.x_3_3.conv1.1.running_var", "decoder.blocks.x_3_3.conv2.0.weight", "decoder.blocks.x_3_3.conv2.1.weight", "decoder.blocks.x_3_3.conv2.1.bias", "decoder.blocks.x_3_3.conv2.1.running_mean", "decoder.blocks.x_3_3.conv2.1.running_var", "decoder.blocks.x_0_4.conv1.0.weight", "decoder.blocks.x_0_4.conv1.1.weight", "decoder.blocks.x_0_4.conv1.1.bias", "decoder.blocks.x_0_4.conv1.1.running_mean", "decoder.blocks.x_0_4.conv1.1.running_var", "decoder.blocks.x_0_4.conv2.0.weight", "decoder.blocks.x_0_4.conv2.1.weight", "decoder.blocks.x_0_4.conv2.1.bias", "decoder.blocks.x_0_4.conv2.1.running_mean", "decoder.blocks.x_0_4.conv2.1.running_var". 
	Unexpected key(s) in state_dict: "decoder.0.convs.0.0.weight", "decoder.0.convs.0.1.weight", "decoder.0.convs.0.1.bias", "decoder.0.convs.0.1.running_mean", "decoder.0.convs.0.1.running_var", "decoder.0.convs.0.1.num_batches_tracked", "decoder.0.convs.1.0.weight", "decoder.0.convs.1.1.weight", "decoder.0.convs.1.1.bias", "decoder.0.convs.1.1.running_mean", "decoder.0.convs.1.1.running_var", "decoder.0.convs.1.1.num_batches_tracked", "decoder.0.convs.2.0.weight", "decoder.0.convs.2.1.weight", "decoder.0.convs.2.1.bias", "decoder.0.convs.2.1.running_mean", "decoder.0.convs.2.1.running_var", "decoder.0.convs.2.1.num_batches_tracked", "decoder.0.convs.3.0.weight", "decoder.0.convs.3.1.weight", "decoder.0.convs.3.1.bias", "decoder.0.convs.3.1.running_mean", "decoder.0.convs.3.1.running_var", "decoder.0.convs.3.1.num_batches_tracked", "decoder.0.convs.4.1.weight", "decoder.0.convs.4.2.weight", "decoder.0.convs.4.2.bias", "decoder.0.convs.4.2.running_mean", "decoder.0.convs.4.2.running_var", "decoder.0.convs.4.2.num_batches_tracked", "decoder.0.project.0.weight", "decoder.0.project.1.weight", "decoder.0.project.1.bias", "decoder.0.project.1.running_mean", "decoder.0.project.1.running_var", "decoder.0.project.1.num_batches_tracked", "decoder.1.weight", "decoder.2.weight", "decoder.2.bias", "decoder.2.running_mean", "decoder.2.running_var", "decoder.2.num_batches_tracked". 
	size mismatch for segmentation_head.0.weight: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 16, 3, 3]).