In [1]:
# -*- coding: utf-8 -*-
from __future__ import division
import argparse
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.models as models
import torchvision.transforms as T
import torchvision.transforms as transforms
import pandas as pd
import os
import pydicom
import numpy as np
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import glob
import os
from os import listdir
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader,Dataset
from skimage.color import gray2rgb
import functools
from tqdm.auto import tqdm
import pydicom
import seaborn as sns
import scipy
import PIL
import json

class KagglePEDataset(torch.utils.data.Dataset):
    """Kaggle PE dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.pedataframe = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """ Return number of 2D images. (Each CT slice is an independent image.)"""
        return len(self.pedataframe)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.pedataframe.StudyInstanceUID[idx],
                                self.pedataframe.SeriesInstanceUID[idx],
                                self.pedataframe.SOPInstanceUID[idx] + '.dcm')
        dicom_image = pydicom.dcmread(img_name) 
        image = dicom_image.pixel_array
        
        # in OSIC we find outside-scanner-regions with raw-values of -2000. 
        # Let's threshold between air (0) and this default (-2000) using -1000
        image[image <= -1000] = 0
        
        # convert to HU using DICOM information
        # HU is a number between -1000 and 1000 (generally)
        # good lung tissue is between -950 and -700 (approximately)
        intercept = dicom_image.RescaleIntercept
        slope = dicom_image.RescaleSlope
        
        if slope != 1:
            image = slope * image.astype(np.float64)
            
        image = image.astype(np.int16)
        image += np.int16(intercept)
        
        # Convert image from numpy array to PIL image (so that we can use pytorch transforms)
        image[image >= 500] = 500
        image[image <= -1000] = -1000
        image = (image + 1000)/1500
        image = image*255
        image = np.uint8(image)
        image = PIL.Image.fromarray(image).convert('RGB')

        # image is 512x512 RGB PIL image
        # pe_present_on_image is 0 or 1
        sample = {'image': image, 
                  'pe_present_on_image': int(self.pedataframe.pe_present_on_image[idx])}

        # Only apply transform to image.
        if self.transform:
            sample['image'] = self.transform(sample['image'])
            
        return sample

In [2]:
data_dir = '/projectnb/ece601/kaggle-pulmonary-embolism/rsna-str-pulmonary-embolism-detection/'
train_csv = data_dir + 'train.csv'
train_dir = data_dir + 'train/'

resnext101 = models.resnext101_32x8d(pretrained=True, progress=True)

In [3]:
# use values from sample image (but ideally this should be values from entire dataset)
global_mean = 111.6126708984375
global_std = 79.95233637352047

transform=T.Compose([T.Resize(256),
                     T.RandomCrop(224),
                     T.ToTensor(),
                     T.Normalize(mean=[global_mean, global_mean, global_mean], 
                                          std=[global_std, global_std, global_std]),
                    ])

transformed_dataset = KagglePEDataset(csv_file=train_csv, root_dir=train_dir, transform=transform)

In [8]:
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed= 42

# Creating data indices for training and validation splits:
dataset_size = len(transformed_dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(transformed_dataset, batch_size=batch_size, 
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(transformed_dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

In [19]:
print(len(train_sampler))
print(len(valid_sampler))
print(len(train_sampler) + len(valid_sampler))
transformed_dataset.pedataframe.pe_present_on_image == 1


1432476
358118
1790594


0          False
1          False
2          False
3          False
4          False
           ...  
1790589    False
1790590    False
1790591    False
1790592    False
1790593     True
Name: pe_present_on_image, Length: 1790594, dtype: bool

In [11]:
resnext101.fc = torch.nn.Linear(resnext101.fc.in_features, 2)

In [12]:
epochs = 150
batch_size = 1
learning_rate = 0.1
momentum = 0.9
decay = 0.0005
schedule = [50, 100]
ngpu = 1
prefetch = 2
log = './'
save = './snapshots'

# Init logger
if not os.path.isdir(log):
    os.makedirs(log)
log = open(os.path.join(log, 'log.txt'), 'w')
state = {'learning_rate':learning_rate,'decay':decay,'momentum':momentum}
log.write(json.dumps(state) + '\n')

# Init checkpoints
if not os.path.isdir(save):
    os.makedirs(save)

# Init model, criterion, and optimizer
net = resnext101

if ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(ngpu)))

if ngpu > 0:
    net.cuda()

optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                            weight_decay=state['decay'], nesterov=True)

dataloader = DataLoader(transformed_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=1)

train_loader = dataloader
test_loader = dataloader

# train function (forward, backward, update)
def train():
    net.train()
    loss_avg = 0.0
    for batch_idx, sample_batched in enumerate(train_loader):
        data = torch.autograd.Variable(sample_batched['image'].cuda())
        target = torch.autograd.Variable(sample_batched['pe_present_on_image'].cuda())
        
        print(data.shape)
        print(target.shape)

        # forward
        output = net(data.float())

        # backward
        optimizer.zero_grad()
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # exponential moving average
        loss_avg = loss_avg * 0.2 + float(loss) * 0.8
        
        break

    state['train_loss'] = loss_avg


# test function (forward only)
def test():
    net.eval()
    loss_avg = 0.0
    correct = 0
    for batch_idx, sample_batched in enumerate(test_loader):
        data = torch.autograd.Variable(sample_batched['image'].cuda())
        target = torch.autograd.Variable(sample_batched['pe_present_on_image'].cuda())

        # forward
        output = net(data.float())
        loss = F.cross_entropy(output, target)

        # accuracy
        pred = output.data.max(1)[1]
        correct += float(pred.eq(target.data).sum())

        # test loss average
        loss_avg += float(loss)
        
        break

    state['test_loss'] = loss_avg / len(test_loader)
    state['test_accuracy'] = correct / len(test_loader.dataset)


# Main loop
best_accuracy = 0.0
for epoch in range(epochs):
    if epoch in schedule:
        state['learning_rate'] *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['learning_rate']

    state['epoch'] = epoch
    train()
    test()
    if state['test_accuracy'] > best_accuracy:
        best_accuracy = state['test_accuracy']
        torch.save(net.state_dict(), os.path.join(save, 'model.pytorch'))
    log.write('%s\n' % json.dumps(state))
    log.flush()
    print(state)
    print("Best accuracy: %f" % best_accuracy)
    break

log.close()


torch.Size([1, 3, 224, 224])
torch.Size([1])
{'learning_rate': 0.1, 'decay': 0.0005, 'momentum': 0.9, 'epoch': 0, 'train_loss': 0.44564170837402345, 'test_loss': 0.0, 'test_accuracy': 5.584738919040273e-07}
Best accuracy: 0.000001


In [18]:
print(pydicom.config.pixel_data_handlers)

[<module 'pydicom.pixel_data_handlers.numpy_handler' from '/share/pkg.7/python3/3.7.7/install/lib/python3.7/site-packages/pydicom/pixel_data_handlers/numpy_handler.py'>, <module 'pydicom.pixel_data_handlers.rle_handler' from '/share/pkg.7/python3/3.7.7/install/lib/python3.7/site-packages/pydicom/pixel_data_handlers/rle_handler.py'>, <module 'pydicom.pixel_data_handlers.gdcm_handler' from '/share/pkg.7/python3/3.7.7/install/lib/python3.7/site-packages/pydicom/pixel_data_handlers/gdcm_handler.py'>, <module 'pydicom.pixel_data_handlers.pillow_handler' from '/share/pkg.7/python3/3.7.7/install/lib/python3.7/site-packages/pydicom/pixel_data_handlers/pillow_handler.py'>, <module 'pydicom.pixel_data_handlers.jpeg_ls_handler' from '/share/pkg.7/python3/3.7.7/install/lib/python3.7/site-packages/pydicom/pixel_data_handlers/jpeg_ls_handler.py'>]


In [23]:
pydicom.config.pixel_data_handlers[2] = None