In [1]:
# Re-import edited local modules (patches_extraction, dataset) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

## Download the small test dataset first

If you are working on Colab or have `gsutil` installed locally, run the code cell below; otherwise, download the dataset from this [link](https://storage.googleapis.com/kidney_dataset/test_data.zip) and move it to the `/tmp/test_data/` folder.

In [2]:
# Fetch the zipped test data with gsutil (available on Colab; install it locally or download manually otherwise).
!mkdir -p /tmp/test_data/
!gsutil cp "gs://kidney_dataset/test_data.zip" "/tmp/test_data/"

Copying gs://kidney_dataset/test_data.zip...
| [1 files][ 69.9 MiB/ 69.9 MiB]                                                
Operation completed over 1 objects/69.9 MiB.                                     


In [3]:
# Unpack into /tmp/test_data/ (creates hdf5_source/); -o overwrites existing files without prompting, so the cell can be re-run.
!unzip -o /tmp/test_data/test_data.zip -d /tmp/test_data/

Archive:  /tmp/test_data/test_data.zip
   creating: /tmp/test_data/hdf5_source/
  inflating: /tmp/test_data/hdf5_source/PAS_005_tuft_mask.png  
  inflating: /tmp/test_data/hdf5_source/PAS_005.png  
  inflating: /tmp/test_data/hdf5_source/PAS_004_tuft_mask.png  
  inflating: /tmp/test_data/hdf5_source/PAS_004.png  
  inflating: /tmp/test_data/hdf5_source/PAS_003_tuft_mask.png  
  inflating: /tmp/test_data/hdf5_source/PAS_003_capsule_mask.png  
  inflating: /tmp/test_data/hdf5_source/PAS_003.png  
  inflating: /tmp/test_data/hdf5_source/PAS_002_tuft_mask.png  
  inflating: /tmp/test_data/hdf5_source/PAS_002.png  
  inflating: /tmp/test_data/hdf5_source/PAS_001_tuft_mask.png  
  inflating: /tmp/test_data/hdf5_source/PAS_001_capsule_mask.png  
  inflating: /tmp/test_data/hdf5_source/PAS_001.png  


## Create the HDF5 dataset

In [5]:
from pathlib import Path

In [6]:
# Source PNGs (images + masks) and the folder that will receive the HDF5 file.
input_data_root = Path("/tmp/test_data/hdf5_source/")
output_data_root = Path("/tmp/test_data/hdf5_output")
# Fail early with a clear message: otherwise a missing download only surfaces later
# as an IndexError on images[0].
assert input_data_root.is_dir(), f"{input_data_root} not found - download and unzip the test data first"
# parents=True so the output folder can be created even if its parent does not exist yet.
output_data_root.mkdir(parents=True, exist_ok=True)

In [7]:
from patches_extraction import extract_img_patches, extract_mask_patches, Extractor, crop_and_save_patches_to_hdf5

In [8]:
# Patch-extraction settings shared by every image and mask below (patch_size=256 matches
# the (N, 256, 256, 3) patch arrays returned by the helpers).
# NOTE(review): exact parameter semantics live in patches_extraction - presumably the source
# is rescaled by 0.25, mirror-padded by 128 px and tiled with a 64 px step (so patches
# overlap) - confirm there.
extractor = Extractor(resize=0.25, mirror_pad_size=128, patch_size=256, stride_size=64)

In [9]:
# Source images: every PNG in the input folder that is not a mask, in sorted order.
# Filter on the file name only - testing the full path would drop every image if any
# parent directory happened to contain "mask".
images = sorted(p for p in input_data_root.glob("*.png") if "mask" not in p.name)
images

[PosixPath('/tmp/test_data/hdf5_source/PAS_001.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_002.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_003.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_004.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_005.png')]

In [10]:
# Tuft masks, sorted so that masks[i] lines up with images[i].
masks = sorted(input_data_root.glob("*tuft_mask.png"))
# The two lists are presumably consumed pairwise by index, so make sure every image has
# exactly one tuft mask with the same stem (e.g. PAS_001.png <-> PAS_001_tuft_mask.png).
assert [m.name.replace("_tuft_mask", "") for m in masks] == [p.name for p in images], \
    "images and tuft masks are not one-to-one; check the input folder"
masks

[PosixPath('/tmp/test_data/hdf5_source/PAS_001_tuft_mask.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_002_tuft_mask.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_003_tuft_mask.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_004_tuft_mask.png'),
 PosixPath('/tmp/test_data/hdf5_source/PAS_005_tuft_mask.png')]

In [11]:
# Destination HDF5 file for the extracted training patches.
hdf5_dataset_fname = output_data_root.joinpath("PAS_glom_train.h5")
hdf5_dataset_fname

PosixPath('/tmp/test_data/hdf5_output/PAS_glom_train.h5')

In [12]:
# Sanity-check the image extraction helper on the first image: it should return one
# index entry per extracted patch.
patches, image_indices = extract_img_patches(images[0], extractor)
assert len(patches) == len(image_indices), "expected one index per image patch"
patches.shape, len(image_indices)



((144, 256, 256, 3), 144)

In [13]:
# Sanity-check the mask extraction helper on the first mask. Mask patches must line up
# one-to-one with the image patches from the previous cell (same count, same spatial size).
mask_patches, mask_indices = extract_mask_patches(masks[0], extractor)
assert mask_patches.shape[:3] == patches.shape[:3], "mask patches do not align with image patches"
# NOTE(review): in the recorded run len(mask_indices) (119) differs from the number of mask
# patches (144), unlike the image case - confirm in patches_extraction whether mask_indices
# is intentionally a subset.
mask_patches.shape, len(mask_indices)

((144, 256, 256, 3), 119)

In [14]:
# Test the main function used to generate the HDF5 file: extract patches from the
# image/mask lists with `extractor` and save them to hdf5_dataset_fname.
# NOTE(review): not visible here whether re-running appends to or overwrites an existing
# file - confirm in patches_extraction.
crop_and_save_patches_to_hdf5(hdf5_dataset_fname, images, masks, extractor)

In [15]:
!ls $output_data_root

PAS_glom_train.h5


## Load the data as a PyTorch `Dataset` and `DataLoader`

In [16]:
from dataset import Dataset
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

In [17]:
# Parameters for augmentation and the simulated training loop.
patch_size = 256  # side length (px) of the square crops produced by the transforms
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
edge_weight = 1.2  # base of the per-pixel loss weight edge_weight ** edge_mask

In [18]:
# Augmentation pipelines for images and masks.
# Spatial transforms must hit an image and its mask identically, so both pipelines are
# built from the same helper: the random transforms appear in the same order and draw
# random numbers in the same sequence. Colour augmentations come only after the spatial
# ones, and only in the image pipeline.
# NOTE(review): Compose draws fresh random parameters on every call; image and mask get the
# same flips/crops/rotation only if Dataset re-seeds the RNG before each of the two calls -
# confirm in dataset.py.

def make_spatial_transforms(resize_interpolation):
    """Return a new list of the shared random spatial augmentations.

    Args:
        resize_interpolation: interpolation mode for RandomResizedCrop - bilinear for
            images, nearest for label masks so no new label values are invented.

    Returns:
        [vertical flip, horizontal flip, random crop to patch_size (padding if needed),
        random resized crop, random rotation within +/-180 degrees].
    """
    return [
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=(patch_size, patch_size), pad_if_needed=True),
        transforms.RandomResizedCrop(size=patch_size, scale=(0.8, 1.2),
                                     interpolation=resize_interpolation),
        # RandomRotation interpolates with nearest-neighbour by default, which is mask-safe.
        transforms.RandomRotation(180),
    ]


img_transform = transforms.Compose([
    transforms.ToPILImage(),
    *make_spatial_transforms(transforms.InterpolationMode.BILINEAR),
    # randomly pick exactly one colour augmentation
    transforms.RandomChoice([
        transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=.2),
        transforms.ColorJitter(brightness=0.3, contrast=0, saturation=0, hue=0),
        transforms.ColorJitter(brightness=0.1, contrast=0, saturation=0.2, hue=0),
        transforms.ColorJitter(brightness=0.2, contrast=0, saturation=0.1, hue=0.2),
        transforms.RandomGrayscale()]),
    transforms.ToTensor(),
])

mask_transform = transforms.Compose([
    transforms.ToPILImage(),
    # Nearest-neighbour resizing keeps mask pixels on the original label values.
    *make_spatial_transforms(transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

In [19]:
# Re-declared as an absolute path (the same file written above) so this section can run on its own, without the HDF5-creation cells.
hdf5_dataset_fname = Path("/tmp/test_data/hdf5_output/PAS_glom_train.h5")

In [20]:
# Dataset backed by the HDF5 file. With use_edge_mask=True each sample also carries an edge
# mask: the loop below unpacks (img, mask, edge_mask) from every batch.
dataset = Dataset(
    hdf5_dataset_fname,
    use_edge_mask=True,
    img_transform=img_transform,
    mask_transform=mask_transform,
)

In [21]:
# Shuffled mini-batches of 8 loaded by 2 worker processes. Pinned host memory only speeds
# up copies to a CUDA device, so enable it only when one is in use.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2,
                        pin_memory=device.type == "cuda")

## Simulate a training loop to test the dataloader

In [22]:
# Per-pixel cross-entropy (only used to exercise the dataloader; no gradient step is taken).
# reduction="none" keeps the unreduced loss map so it can be re-weighted by the edge mask
# before averaging.
loss = torch.nn.CrossEntropyLoss(weight=None, reduction="none")

In [23]:
# One pass over the dataloader: score a random prediction with the edge-weighted loss for
# every batch. No gradient step is taken; this only checks that batches load correctly.
for i, data in enumerate(dataloader):
    img, mask, edge_mask = (x.to(device) for x in data)

    # Naive "prediction": uniform random logits, allocated directly on the target device.
    output = torch.rand(size=(mask.shape[0], 2, *mask.shape[1:]), device=device)  # [N, Nclass, H, W]

    loss_matrix = loss(output, mask)  # unreduced per-pixel loss, [N, H, W]
    # Up-weight pixels flagged in the edge mask: weight = edge_weight ** edge_mask.
    loss_val = (loss_matrix * (edge_weight ** edge_mask)).mean()

    if i == 0:
        # Report shape and dtype of each tensor in the first batch (img, mask, edge_mask).
        for x in data:
            print(f"Shape: {x.shape}, Dtype: {x.dtype}")
    print(f"Batch {i} has loss: {loss_val.item()}")

Shape: torch.Size([8, 3, 256, 256]), Dtype: torch.float32
Shape: torch.Size([8, 256, 256]), Dtype: torch.int64
Shape: torch.Size([8, 256, 256]), Dtype: torch.float32
Batch 0 has loss: 0.7136338949203491
Batch 1 has loss: 0.713422417640686
Batch 2 has loss: 0.7133139371871948
Batch 3 has loss: 0.7133567333221436
Batch 4 has loss: 0.7135315537452698
Batch 5 has loss: 0.7135619521141052
Batch 6 has loss: 0.713448166847229
Batch 7 has loss: 0.713767409324646
Batch 8 has loss: 0.7139379382133484
Batch 9 has loss: 0.7135822772979736
Batch 10 has loss: 0.7134979963302612
Batch 11 has loss: 0.7140174508094788
Batch 12 has loss: 0.7135435342788696
Batch 13 has loss: 0.7138286828994751
Batch 14 has loss: 0.7135250568389893
Batch 15 has loss: 0.7136993408203125
Batch 16 has loss: 0.713544487953186
Batch 17 has loss: 0.7137070894241333
Batch 18 has loss: 0.7136034965515137
Batch 19 has loss: 0.7132987380027771
Batch 20 has loss: 0.7140078544616699
Batch 21 has loss: 0.7136460542678833
Batch 22 has