In [8]:
def setup_file_system(in_colab):
    """Return the project's base directory for the current runtime.

    On Google Colab the user's Google Drive is mounted first and the project
    folder inside it is returned; otherwise the fixed local workspace path is
    returned.

    Args:
        in_colab: True when the notebook is running inside Google Colab.

    Returns:
        str: Path to the ``barco_skin_lesion_classification`` project root.
    """
    # Imported here so the function does not depend on a later cell having
    # already imported `join` into the notebook namespace.
    from os.path import join

    if in_colab:
        # Imported lazily: google.colab is only available inside a Colab runtime
        from google.colab import drive

        mount_path_drive = '/content/drive'
        base_path = join(
            mount_path_drive,
            "MyDrive/barco_skin_lesion_classification"
        )

        # Make the Google Drive contents available under mount_path_drive
        drive.mount(mount_path_drive)

        return base_path

    return "/workspaces/barco_skin_lesion_classification"

In [9]:
import sys
from os import chdir, makedirs
from os.path import join

# `google.colab` only shows up in sys.modules when this kernel runs on Google Colab
IN_COLAB = "google.colab" in sys.modules

# Resolve the project root for this runtime (mounts Google Drive first when on Colab)
BASE_PATH = setup_file_system(in_colab=IN_COLAB)

# Switch the working directory to the project's src/ folder; the project-local
# imports further down (util, models, datasets) presumably resolve relative to it
chdir(join(BASE_PATH, "src/"))

In [10]:
# Import libraries
from os.path import join
from tqdm import tqdm
import wandb


# DL libraries
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image



# User libraries
from util import config, model_management
from models.unet_model import UNet
from datasets.file_and_name_dataset import FileAndNameDataset

In [11]:
# Folder containing the training feature images that will be run through the segmentation model
features_dir = join(BASE_PATH, config.CLASSIFICATION_DATA_PATH_TRAIN_FEATURES)
dataset = FileAndNameDataset(features_dir, config.SEGMENTATION_RUN_TRANSFORMATIONS)

# One image per batch: the export loop below reads image_name[0] for the file name
dataloader = DataLoader(dataset=dataset, batch_size=1)


In [12]:
# Destination folders for the exported images: masked (segmented) and untouched (unsegmented) copies
BASE_SEGMENTED_EXPORT_PATH = join(BASE_PATH, config.CLASSIFICATION_DATA_PATH_TRAIN_SEGMENTED_FEATURES)
BASE_UNSEGMENTED_EXPORT_PATH = join(BASE_PATH, config.CLASSIFICATION_DATA_PATH_TRAIN_UNSEGMENTED_FEATURES)

# save_image does not create missing parent directories, so make sure the
# export folders exist before the export loop writes into them
for export_dir in (BASE_SEGMENTED_EXPORT_PATH, BASE_UNSEGMENTED_EXPORT_PATH):
    makedirs(export_dir, exist_ok=True)

# Start a wandb run.
# NOTE(review): presumably required so model_management.get_artifact_model_weights()
# can fetch the trained weights as a wandb artifact — confirm.
wandb.init(
    settings=wandb.Settings(start_method="fork"),
    project="segmentation",
)

In [13]:
# Build the segmentation network: 3 input channels, a single output channel (the mask)
model = UNet(n_channels = 3, n_classes = 1)
model.to(config.DEVICE)

# Fetch the trained weights and load them exactly once.
# map_location remaps tensors saved on another device (e.g. a GPU) onto config.DEVICE,
# so loading also works on machines without that device.
# NOTE(review): torch.load unpickles the file — only use it on trusted artifacts.
model_parameters = model_management.get_artifact_model_weights()
state_dict = torch.load(model_parameters, map_location=config.DEVICE)
model.load_state_dict(state_dict)

# Eval mode: the BatchNorm layers use their running statistics instead of batch statistics
model.eval()

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [14]:
# Threshold used to binarize the raw UNet output.
# NOTE(review): whether the model emits logits or probabilities is not visible here — confirm 0.1 is intended.
MASK_THRESHOLD = 0.1
# Minimum fraction of the image's pixels the predicted mask must cover for the pair to be exported.
# NOTE(review): the previous version counted nonzero values of every channel of the masked image,
# which inflated the ratio (and ignored black pixels inside the mask); re-check this value.
MIN_MASK_COVERAGE = 0.2

# Run the segmentation model over every image and export masked / unmasked pairs
loop = tqdm(dataloader, leave=True)
with torch.no_grad():  # inference only: don't build an autograd graph
    for original_image, model_input, image_name in loop:
        # Send the original image and the model input to the device
        original_image = original_image.to(config.DEVICE)
        model_input = model_input.to(config.DEVICE)

        # Predict the mask for the single image in the batch (batch_size=1).
        # Assumes the output is (batch, n_classes, H, W), so [0] leaves (1, H, W) — TODO confirm.
        output = model(model_input)[0]

        # Binarize the prediction into a 0/1 mask
        mask = (output > MASK_THRESHOLD).float()

        # Resize the mask to the original image size. NEAREST keeps it strictly 0/1;
        # the default bilinear interpolation would blend its edges into fractional values.
        original_height = original_image.size(dim=2)
        original_width = original_image.size(dim=3)
        resized_mask = transforms.functional.resize(
            mask,
            (original_height, original_width),
            interpolation=transforms.InterpolationMode.NEAREST,
        )

        # Zero out every pixel of the original image that lies outside the mask
        masked_image = resized_mask * original_image

        # Fraction of image pixels covered by the (single-channel) mask
        mask_coverage = torch.count_nonzero(resized_mask).item() / (original_height * original_width)

        # Only export images whose mask covers a sufficiently large area
        if mask_coverage >= MIN_MASK_COVERAGE:
            # Save the unmasked image
            unsegmented_path = join(BASE_UNSEGMENTED_EXPORT_PATH, image_name[0])
            save_image(original_image, unsegmented_path)

            # Save the masked image
            segmented_path = join(BASE_SEGMENTED_EXPORT_PATH, image_name[0])
            save_image(masked_image, segmented_path)

100%|██████████| 3593/3593 [02:16<00:00, 26.36it/s]
