In [102]:
import datetime
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

In [103]:
# All data lives under one root: the image folder plus the train/test split CSVs.
data_root = "../dataset"
images_root, train_data_dist, test_data_dist = (
    os.path.join(data_root, entry)
    for entry in ("images_all", "train_data.csv", "test_data.csv")
)

In [104]:
# One DataFrame per split, read from the CSV paths configured above (train first, then test).
train_data, test_data = map(pd.read_csv, (train_data_dist, test_data_dist))

In [105]:
# Make the project's helper modules (constants, cls_train_utils, metrics) importable.
# Guarded so re-running the cell does not keep appending duplicate sys.path entries.
scripts_path = "../scripts"
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

In [107]:
import torch
import constants as const

from cls_train_utils import *
from metrics import get_iou_metric

In [None]:
def interval_mapping(image, from_min, from_max, to_min, to_max):
    """Linearly rescale values from ``[from_min, from_max]`` to ``[to_min, to_max]``.

    ``from_min`` maps to ``to_min`` and ``from_max`` maps to ``to_max``; values
    outside the source interval are extrapolated, not clipped.

    Args:
        image: Array-like (NumPy array, list, or scalar) of values to rescale.
        from_min, from_max: Scalar bounds of the source interval; must differ.
        to_min, to_max: Scalar bounds of the target interval.

    Returns:
        A float64 ``np.ndarray`` with the same shape as ``image`` (a NumPy
        float64 scalar for scalar input).

    Raises:
        ValueError: if ``from_min == from_max`` (the mapping is undefined).
    """
    from_range = from_max - from_min
    if from_range == 0:
        raise ValueError("from_min and from_max must differ")
    to_range = to_max - to_min
    # Convert to float *before* subtracting: on unsigned integer images (e.g. uint8
    # pixels) ``image - from_min`` would otherwise wrap around for values below from_min.
    values = np.asarray(image, dtype=float)
    scaled = (values - from_min) / float(from_range)
    return to_min + scaled * to_range

In [108]:
# Build the train and test classification datasets over the same image folder,
# without augmentation, calling the class's default preprocessing factory once per
# dataset (train is constructed first, then test).
train_dataset, test_dataset = (
    MelanomaClassificationDataset(
        csv_file=split_frame,
        root_dir=images_root,
        augmentation=None,
        preprocessing=MelanomaClassificationDataset.get_default_preprocessing(),
    )
    for split_frame in (train_data, test_data)
)

In [109]:
# The checkpoint holds a whole pickled model object (its repr is the FPN printed below),
# so torch.load returns a ready-to-use module rather than a state_dict.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects full-model
# pickles; pass weights_only=False there, and only for trusted checkpoints.
SEGMENTATION_MODEL_PATH = "../models/segmentation_model_xception_backbone.pth"
segmentation_model = torch.load(SEGMENTATION_MODEL_PATH)
# Inference mode: BatchNorm layers use their running statistics instead of batch statistics.
segmentation_model.eval()

FPN(
  (encoder): XceptionEncoder(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (block1): Block(
      (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (skipbn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (rep): Sequential(
        (0): SeparableConv2d(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [136]:
def mask_and_save_images(model, dataset, frame, source_dir=images_root,
                         output_dir=os.path.join(data_root, "images_all_processed"),
                         threshold=0.5):
    """Black out the background of each image using a segmentation model and save it.

    For every sample of ``dataset``, the model's first output channel is compared to
    ``threshold``. The original image named ``frame.iloc[i]["name"]`` is reloaded from
    ``source_dir``, and every pixel whose score is below ``threshold`` is set to 0.
    The result is written to ``output_dir`` under the same file name.

    Args:
        model: Segmentation network (already in eval mode); its output is read as
            ``out[0, 0]`` and treated as a per-pixel probability.
        dataset: Sized iterable of ``(image_tensor, label)`` pairs.
        frame: DataFrame with a ``"name"`` column whose row ``i`` describes the
            ``i``-th sample of ``dataset`` (assumes both share the same order —
            TODO confirm in MelanomaClassificationDataset).
        source_dir: Folder holding the original images.
        output_dir: Destination folder; created if missing.
        threshold: Score below which a pixel is treated as background.

    Raises:
        ValueError: if the mask and the reloaded image differ in height/width.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Send inputs to whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device

    # Pure inference: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        for row_index, (image_tensor, _label) in enumerate(tqdm(dataset)):
            image_name = frame.iloc[row_index]["name"]

            scores = model(image_tensor.unsqueeze(0).to(device))
            # True where the pixel is background (score below threshold).
            background = (scores[0, 0] < threshold).cpu().numpy()

            with Image.open(os.path.join(source_dir, image_name)) as source_image:
                pixels = np.array(source_image)
            if pixels.shape[:2] != background.shape:
                raise ValueError(
                    f"mask shape {background.shape} does not match image "
                    f"{image_name!r} of shape {pixels.shape[:2]}"
                )

            pixels[background] = 0
            Image.fromarray(pixels).save(os.path.join(output_dir, image_name))


mask_and_save_images(segmentation_model, train_dataset, train_data)
mask_and_save_images(segmentation_model, test_dataset, test_data)

100%|██████████| 1385/1385 [05:11<00:00,  4.44it/s]
