# Jaws Segmentation Task


In [1]:
! pip install --user torch torchvision matplotlib numpy progressbar

Collecting progressbar
  Downloading progressbar-2.5.tar.gz (10 kB)
Building wheels for collected packages: progressbar
  Building wheel for progressbar (setup.py) ... [?25l[?25hdone
  Created wheel for progressbar: filename=progressbar-2.5-py3-none-any.whl size=12082 sha256=16915f608ebdd0f4edc403bda1f8c87f785628339d5fa88c1ceb124b42748145
  Stored in directory: /root/.cache/pip/wheels/f0/fd/1f/3e35ed57e94cd8ced38dd46771f1f0f94f65fec548659ed855
Successfully built progressbar
Installing collected packages: progressbar
Successfully installed progressbar-2.5


In [2]:
import urllib.request
import zipfile
import os
import progressbar
from math import ceil
import torch
import gzip
import numpy as np
import glob
import matplotlib.pyplot as plt
from torchvision import transforms

### I worked on both Google Colab and a local Jupyter installation, so:
1. If you test it locally in Jupyter, download the dataset (or point the dataset paths at its location on your PC), and don't forget to set `DATA_on_DRIVE = False`.
2. If you test it on Colab, download the dataset or mount your Google Drive and point the dataset paths at the data on Drive; don't forget to set `DATA_on_DRIVE = True` if the data is on Drive.

In [3]:
LOCAL_DATASET_PATH = 'dataset'
BATCH_SIZE = 16
# Set to True when the data lives on Google Drive (Colab); False reads it
# from a path relative to the working directory.
DATA_on_DRIVE = False
# Public download URLs of the dataset archives. They are kept for reference
# only: the *_DATASET names below are glob patterns, not URLs, so the
# download helpers would need these values restored before being used.
# AXIAL_TRAINING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/axial/train.zip'
# AXIAL_TESTING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/axial/test.zip'
# CORONAL_TRAINING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/coronal/train.zip'
# CORONAL_TESTING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/coronal/test.zip'
# SAGITTAL_TRAINING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/sagittal/train.zip'
# SAGITTAL_TESTING_DATASET = 'https://cvml-datasets.s3.eu-west-3.amazonaws.com/jaws-segmentation/v1/public/2d/sagittal/test.zip'
if DATA_on_DRIVE:
    # google.colab is a Colab-specific module, so it is imported only on this branch.
    from google.colab import drive
    drive.mount('/content/drive')
    DATASET_ROOT = '/content/drive/MyDrive/dicom_dataset'
else:
    DATASET_ROOT = 'dicom_dataset'

# Glob patterns locating the gzip-compressed slices of each plane and split.
# Every pattern is derived from the single DATASET_ROOT chosen above so the
# local and Drive layouts cannot drift apart.
AXIAL_TRAINING_DATASET = f'{DATASET_ROOT}/axial/train/**/*.dicom.npy.gz'
AXIAL_TESTING_DATASET = f'{DATASET_ROOT}/axial/test/**/*.dicom.npy.gz'
CORONAL_TRAINING_DATASET = f'{DATASET_ROOT}/coronal/train/**/*.dicom.npy.gz'
CORONAL_TESTING_DATASET = f'{DATASET_ROOT}/coronal/test/**/*.dicom.npy.gz'
SAGITTAL_TRAINING_DATASET = f'{DATASET_ROOT}/sagittal/train/**/*.dicom.npy.gz'
SAGITTAL_TESTING_DATASET = f'{DATASET_ROOT}/sagittal/test/**/*.dicom.npy.gz'

## Downloading Dataset

In this part we download the publicly available dataset. You can skip it if you already have the data; it should be about 5.6 GB worth of data.

In [4]:
# download_progress_bar = None
# def show_progress(block_num, block_size, total_size):
#     global download_progress_bar
#     if download_progress_bar is None:
#         download_progress_bar = progressbar.ProgressBar(maxval=total_size)
#         download_progress_bar.start()

#     downloaded = block_num * block_size
#     if downloaded < total_size:
#         download_progress_bar.update(downloaded)
#     else:
#         download_progress_bar.finish()
#         download_progress_bar = None

# def download_file(url, disk_path):
#     print(f'downloading {url}')
#     filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
#     os.makedirs(disk_path)
#     with zipfile.ZipFile(filename, 'r') as zip:
#         zip.extractall(disk_path)

# def download_data(to=LOCAL_DATASET_PATH):
#     download_file(AXIAL_TRAINING_DATASET, os.path.join(to, 'axial', 'train'))
#     download_file(AXIAL_TESTING_DATASET, os.path.join(to, 'axial', 'test'))
#     download_file(CORONAL_TRAINING_DATASET, os.path.join(to, 'coronal', 'train'))
#     download_file(CORONAL_TESTING_DATASET, os.path.join(to, 'coronal', 'test'))
#     download_file(SAGITTAL_TRAINING_DATASET, os.path.join(to, 'sagittal', 'train'))
#     download_file(SAGITTAL_TESTING_DATASET, os.path.join(to, 'sagittal', 'test'))

# download_data()

## Explore The Dataset

In this section you should explore and plot the dataset to get familiar with it. We were nice enough to write a dataset loader for you, and we did some initial visualization for you.

In [6]:
class JawsDataset(torch.utils.data.Dataset):
	"""Dataset pairing gzip-compressed NumPy image files with their label files.

	Every entry of ``dicom_file_list`` is a path containing ``.dicom.npy.gz``;
	the matching label is read from the same path with ``.dicom.npy.gz``
	replaced by ``.label.npy.gz``.

	Args:
		dicom_file_list: sequence of paths to the image files.
		transforms: callable applied separately to the image array and to the
			label array.

	NOTE(review): the same ``transforms`` callable is applied to the label as
	to the image — confirm that any resizing/normalisation in it is appropriate
	for label masks.
	"""

	def __init__(self, dicom_file_list, transforms):
		self.dicom_file_list = dicom_file_list
		self.transforms = transforms

	def __len__(self):
		"""Return the number of image files in the dataset."""
		return len(self.dicom_file_list)

	@staticmethod
	def _load_gzipped_array(path):
		"""Load one gzip-compressed ``.npy`` file and close it before returning."""
		# The context manager closes the handle deterministically instead of
		# leaving it to the garbage collector.
		with gzip.GzipFile(path, 'rb') as compressed_file:
			return np.load(compressed_file)

	def __getitem__(self, idx):
		"""Return ``(transforms(image), transforms(label))`` for item ``idx``."""
		dicom_path = self.dicom_file_list[idx]
		label_path = dicom_path.replace('.dicom.npy.gz', '.label.npy.gz')
		dicom = self._load_gzipped_array(dicom_path)
		label = self._load_gzipped_array(label_path)
		return self.transforms(dicom), self.transforms(label)

def axial_dataset_train(transforms, validation_ratio = 0.1):
	"""Build the axial training and validation datasets.

	The files matching ``AXIAL_TRAINING_DATASET`` are sorted; the first
	``ceil(len(files) * validation_ratio)`` of them form the validation set and
	the remainder the training set.

	Args:
		transforms: callable forwarded to ``JawsDataset``.
		validation_ratio: fraction of the matched files held out for validation.

	Returns:
		Tuple ``(train_dataset, validation_dataset)`` of ``JawsDataset`` objects.

	Raises:
		AssertionError: if no file matches the pattern.
	"""
	# The order of glob results depends on the file system; sorting makes the
	# train/validation split the same on every run and machine.
	files = sorted(glob.glob(AXIAL_TRAINING_DATASET))
	assert len(files) > 0, f'no files matched {AXIAL_TRAINING_DATASET!r}'
	validation_files_count = ceil(len(files) * validation_ratio)

	return (JawsDataset(files[validation_files_count:], transforms),
			JawsDataset(files[:validation_files_count], transforms))

def coronal_dataset_train(transforms, validation_ratio = 0.1):
	"""Build the coronal training and validation datasets.

	The files matching ``CORONAL_TRAINING_DATASET`` are sorted; the first
	``ceil(len(files) * validation_ratio)`` of them form the validation set and
	the remainder the training set.

	Args:
		transforms: callable forwarded to ``JawsDataset``.
		validation_ratio: fraction of the matched files held out for validation.

	Returns:
		Tuple ``(train_dataset, validation_dataset)`` of ``JawsDataset`` objects.

	Raises:
		AssertionError: if no file matches the pattern.
	"""
	# The order of glob results depends on the file system; sorting makes the
	# train/validation split the same on every run and machine.
	files = sorted(glob.glob(CORONAL_TRAINING_DATASET))
	assert len(files) > 0, f'no files matched {CORONAL_TRAINING_DATASET!r}'
	validation_files_count = ceil(len(files) * validation_ratio)

	return (JawsDataset(files[validation_files_count:], transforms),
			JawsDataset(files[:validation_files_count], transforms))

def sagittal_dataset_train(transforms, validation_ratio = 0.1):
	"""Build the sagittal training and validation datasets.

	The files matching ``SAGITTAL_TRAINING_DATASET`` are sorted; the first
	``ceil(len(files) * validation_ratio)`` of them form the validation set and
	the remainder the training set.

	Args:
		transforms: callable forwarded to ``JawsDataset``.
		validation_ratio: fraction of the matched files held out for validation.

	Returns:
		Tuple ``(train_dataset, validation_dataset)`` of ``JawsDataset`` objects.

	Raises:
		AssertionError: if no file matches the pattern.
	"""
	# The order of glob results depends on the file system; sorting makes the
	# train/validation split the same on every run and machine.
	files = sorted(glob.glob(SAGITTAL_TRAINING_DATASET))
	# A single emptiness check suffices (the original asserted it twice).
	assert len(files) > 0, f'no files matched {SAGITTAL_TRAINING_DATASET!r}'
	validation_files_count = ceil(len(files) * validation_ratio)

	return (JawsDataset(files[validation_files_count:], transforms),
			JawsDataset(files[:validation_files_count], transforms))

def axial_dataset_test(transforms):
	"""Build the axial test dataset.

	Args:
		transforms: callable forwarded to ``JawsDataset``.

	Returns:
		A ``JawsDataset`` over the sorted files matching ``AXIAL_TESTING_DATASET``.

	Raises:
		AssertionError: if no file matches the pattern.
	"""
	# Sorted so that a given index refers to the same file on every machine;
	# the order of raw glob results depends on the file system.
	files = sorted(glob.glob(AXIAL_TESTING_DATASET))
	assert len(files) > 0, f'no files matched {AXIAL_TESTING_DATASET!r}'
	return JawsDataset(files, transforms)

def coronal_dataset_test(transforms):
	"""Build the coronal test dataset.

	Args:
		transforms: callable forwarded to ``JawsDataset``.

	Returns:
		A ``JawsDataset`` over the sorted files matching ``CORONAL_TESTING_DATASET``.

	Raises:
		AssertionError: if no file matches the pattern.
	"""
	# Sorted so that a given index refers to the same file on every machine;
	# the order of raw glob results depends on the file system.
	files = sorted(glob.glob(CORONAL_TESTING_DATASET))
	assert len(files) > 0, f'no files matched {CORONAL_TESTING_DATASET!r}'
	return JawsDataset(files, transforms)

def sagittal_dataset_test(transforms):
	"""Build the sagittal test dataset.

	Args:
		transforms: callable forwarded to ``JawsDataset``.

	Returns:
		A ``JawsDataset`` over the sorted files matching ``SAGITTAL_TESTING_DATASET``.

	Raises:
		AssertionError: if no file matches the pattern.
	"""
	# Sorted so that a given index refers to the same file on every machine;
	# the order of raw glob results depends on the file system.
	files = sorted(glob.glob(SAGITTAL_TESTING_DATASET))
	assert len(files) > 0, f'no files matched {SAGITTAL_TESTING_DATASET!r}'
	return JawsDataset(files, transforms)


In [7]:
# Preprocessing pipeline handed to the dataset factories: convert the array
# to a tensor, resize it to 128x128, then normalise with mean 0 and std 1.
# NOTE(review): a mean of 0.0 and std of 1.0 leave the values unchanged —
# confirm whether real dataset statistics were intended here.
dataset_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((128, 128)),
        transforms.Normalize(mean=[0.0], std=[1.0]),
    ]
)

In [8]:
def get_plane_datasets(plane_type="axial"):
    """Return the train, validation and test datasets of one anatomical plane.

    Args:
        plane_type: ``"sagittal"`` or ``"coronal"`` (case-insensitive); any
            other value selects the axial datasets.

    Returns:
        Tuple ``(train_dataset, validation_dataset, test_dataset)`` built with
        ``dataset_transforms``.
    """
    # Factories for the planes that are matched by name.
    named_factories = {
        "sagittal": (sagittal_dataset_train, sagittal_dataset_test),
        "coronal": (coronal_dataset_train, coronal_dataset_test),
    }
    # Anything that is not a known name falls back to the axial plane.
    axial_factories = (axial_dataset_train, axial_dataset_test)
    build_train, build_test = named_factories.get(plane_type.lower(), axial_factories)

    train_split, validation_split = build_train(dataset_transforms)
    test_split = build_test(dataset_transforms)
    return train_split, validation_split, test_split

In [9]:
def get_plane_dataloader(train_ds, val_ds, test_ds):
    """Wrap the three datasets in shuffled DataLoaders of ``BATCH_SIZE`` items.

    Args:
        train_ds: training dataset.
        val_ds: validation dataset.
        test_ds: test dataset.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # NOTE(review): the validation and test loaders are shuffled as well —
    # confirm this is wanted, since evaluation order then changes between runs.
    loaders = tuple(
        torch.utils.data.DataLoader(split, shuffle=True, batch_size=BATCH_SIZE)
        for split in (train_ds, val_ds, test_ds)
    )
    return loaders

In [10]:
def get_images_labels(data_loader):
	"""Fetch the first batch produced by ``data_loader``.

	Args:
		data_loader: iterable whose items unpack into an ``(images, labels)`` pair.

	Returns:
		Tuple ``(images, labels)`` taken from the first item.

	Raises:
		StopIteration: if ``data_loader`` yields nothing.
	"""
	# Use the built-in next(): iterators implement __next__, and the
	# Python-2 style .next() method is not available on recent PyTorch
	# DataLoader iterators.
	images, labels = next(iter(data_loader))
	return images, labels

In [11]:
def plot_images_labels(images, labels):
    """Draw images with their labels overlaid on a 2x8 grid of subplots.

    Args:
        images: indexable batch whose items provide ``.numpy()`` (e.g. tensors).
        labels: batch indexed in parallel with ``images``.
    """
    plt.figure(figsize=(16, 4))
    # Only batch indices 8..15 are drawn; with ``index + 1`` as the subplot
    # number they occupy positions 9..16, i.e. the second row of the grid.
    # NOTE(review): the first row stays empty and nothing is drawn for batches
    # of 8 items or fewer — confirm the start index of 8 is intended.
    for index in range(8, min(16, len(images))):
        plt.subplot(2, 8, index + 1)
        plt.axis('off')
        # Image first, then the label drawn semi-transparently on top of it.
        plt.imshow(images[index].numpy().squeeze(), cmap='bone')
        plt.imshow(labels[index].numpy().squeeze(), alpha=0.3)