In [None]:
import os, csv
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from imageio import imread

# Symmetric intensity range assumed for the CT images; pixel values are
# scaled by img_range[1] when images are loaded.
img_range = [-2**12, 2**12]

# Seed the RNGs so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Root of the SIIM tutorial data. Set the SIIM_DATA_DIR environment variable
# to point elsewhere; the default is the original author's local checkout.
data_root = os.environ.get(
    'SIIM_DATA_DIR',
    '/home/masonmcgough/Workspace/Data/siim-medical-image-analysis-tutorial',
)
data_path = os.path.join(data_root, 'tiff_images')
csv_file_path = os.path.join(data_root, 'overview.csv')

# Image file extensions considered valid.
VALID_EXTS = ('.jpg', '.png', '.tif', '.tiff')

## Find image file names and labels

In [None]:
def find_files(path, csv_file):
    """Load per-image metadata from a CSV file and attach full image paths.

    The first row of ``csv_file`` is the header. Every non-blank data row
    becomes a dict keyed by the header names, with all values as strings. A
    ``'tiff_full_path'`` entry is added by joining ``path`` with the row's
    ``'tiff_name'`` value.

    Args:
        path: Directory that contains the image files.
        csv_file: Path to a comma-delimited CSV whose header includes a
            ``tiff_name`` column.

    Returns:
        list[dict]: One dict per non-blank data row, in file order. The list
        is empty when the file has no data rows (or is empty).

    Raises:
        KeyError: If a row has no ``tiff_name`` value (the column is missing,
            or the row is shorter than the header).
    """
    # newline='' is what the csv module requires so that quoted fields with
    # embedded line breaks parse correctly. DictReader takes the header from
    # the first row and skips fully blank lines; those would otherwise turn
    # into empty dicts and fail on the 'tiff_name' lookup below.
    with open(csv_file, 'r', newline='') as f:
        imgs_data = list(csv.DictReader(f, delimiter=','))

    for row in imgs_data:
        # DictReader fills fields missing from a short row with None.
        if row.get('tiff_name') is None:
            raise KeyError(
                "row without a 'tiff_name' value in '{0}': {1}".format(csv_file, row))
        row['tiff_full_path'] = os.path.join(path, row['tiff_name'])

    return imgs_data
    
imgs_data = find_files(data_path, csv_file_path)

# The entry count comes from the CSV rows. Check separately how many of the
# referenced image files actually exist on disk.
n_missing = sum(not os.path.isfile(row['tiff_full_path']) for row in imgs_data)
print("{0} entries read from '{1}' ({2} image files missing under '{3}')".format(
    len(imgs_data), csv_file_path, n_missing, data_path))

# Fail fast with a clear message rather than an IndexError on imgs_data[0].
assert imgs_data, "no data rows found in '{0}'".format(csv_file_path)
print("imgs_data[0]: {0}".format(imgs_data[0]))

## Define a `Dataset` subclass for the CT images

In [None]:
from skimage.transform import resize
class CTDataset(Dataset):
    """Map-style dataset of CT images paired with binary labels.

    Each item is a dict with:
        'image': float32 array with a leading channel axis. For 2-D
                 (grayscale) input images its shape is (1, *img_dims). Raw
                 pixel values are divided by ``max_value``.
        'label': 1 if the raw label is the string 'True' or compares equal to
                 True (e.g. bool True), otherwise 0.
        'index': the integer index that was requested.

    Args:
        files_list: Sequence of image file paths, read with ``imread``.
        labels: Sequence of raw labels aligned element-wise with ``files_list``.
        img_dims: Spatial size that every image is resized to.
        max_value: Divisor applied to raw pixel values. ``None`` (the default)
            uses the module-level ``img_range[1]``.

    Raises:
        ValueError: If ``files_list`` and ``labels`` have different lengths.
    """

    def __init__(self, files_list, labels, img_dims=(256, 256), max_value=None):
        if len(files_list) != len(labels):
            raise ValueError(
                "files_list and labels must have the same length "
                "({0} != {1})".format(len(files_list), len(labels)))
        self.files_list = files_list
        self.labels = labels
        self.img_dims = img_dims
        self.max_value = img_range[1] if max_value is None else max_value

    def __len__(self):
        return len(self.files_list)

    def __getitem__(self, index):
        img = np.asarray(imread(self.files_list[index]))
        # preserve_range=True keeps the raw intensities. Without it, resize
        # converts via img_as_float, which rescales integer dtypes (e.g. int16
        # to [-1, 1]), and the division below would shrink values a second time.
        img = resize(img, self.img_dims, preserve_range=True)
        # Leading channel axis for Conv2d; assumes single-channel slices.
        img = np.expand_dims(img, 0)
        img = (img / self.max_value).astype(np.float32)

        raw_label = self.labels[index]
        label = 1 if raw_label in ('True', True) else 0
        return {'image': img, 'label': label, 'index': index}

## Create the dataset and data loader

In [None]:
imgs_paths = [row['tiff_full_path'] for row in imgs_data]
imgs_labels = [row['Contrast'] for row in imgs_data]
# MyNet's fully connected layer is sized for 128x128 inputs (32 * 7 * 7 = 1568 features).
img_dims = (128, 128)
mydataset = CTDataset(imgs_paths, imgs_labels, img_dims=img_dims)
print("Number of images: {0}".format(len(mydataset)))

# Inspect one sample. Show its summary rather than dumping the whole pixel array.
ex = mydataset[0]
print("Label: {0}, index: {1}".format(ex['label'], ex['index']))
print("Shape: {0}, dtype: {1}".format(ex['image'].shape, ex['image'].dtype))
print("Value range: [{0:.4f}, {1:.4f}]".format(ex['image'].min(), ex['image'].max()))

# Batch and shuffle the samples for training.
mydataloader = DataLoader(
    mydataset,
    batch_size=4,
    shuffle=True,
)

## Define the model, loss, and optimizer

In [None]:
import torch.nn as nn
import torch.nn.functional as F
class MyNet(nn.Module):
    """Small CNN classifier: two conv/ReLU/max-pool stages then two linear layers.

    Input is a batch of single-channel images of shape (N, 1, H, W), with
    (H, W) == ``input_dims``. Output is raw class scores (logits) of shape
    (N, 2), suitable for ``nn.CrossEntropyLoss``.

    Args:
        input_dims: Spatial size (H, W) of the inputs. ``fc1``'s input size is
            derived from it, so other resolutions work without editing the
            layer. The default (128, 128) gives 32 * 7 * 7 = 1568 features.
    """

    def __init__(self, input_dims=(128, 128)):
        super(MyNet, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, 3, stride=2)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2)

        # Infer the flattened feature size with a dummy forward pass instead
        # of hard-coding a number that is valid for only one input size.
        with torch.no_grad():
            n_flat = self.num_flat_features(
                self._features(torch.zeros(1, 1, *input_dims)))

        self.fc1 = nn.Linear(n_flat, 50)
        self.fc2 = nn.Linear(50, 2)

    def _features(self, x):
        """Apply the convolutional stages; returns an (N, C, h, w) feature map."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return x

    def forward(self, x):
        """Map an (N, 1, H, W) batch to (N, 2) logits."""
        x = self._features(x)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the batch dimension."""
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features

import torch.optim as optim

# Hyper-parameters for SGD with momentum.
LEARNING_RATE = 0.001
MOMENTUM = 0.8

mynet = MyNet()

# The model returns raw scores from fc2 (no softmax), which is exactly what
# CrossEntropyLoss expects alongside integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(mynet.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

## Display one example from a batch

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Draw one shuffled batch from the data loader. Use the built-in next(); the
# iterator's old .next() method no longer exists in current PyTorch.
dataiter = iter(mydataloader)
data = next(dataiter)
print("Batch image shape: {0}, labels: {1}".format(
    tuple(data['image'].shape), data['label'].tolist()))

# Show the first image of the batch. squeeze() drops the channel axis so
# imshow receives a 2-D array.
first_label = data['label'][0].item()
fig, ax = plt.subplots()
ax.imshow(data['image'][0].squeeze().numpy(), cmap='gray')
ax.set_title("Label: {0}".format(first_label))
ax.axis('off')
print("Label: {0}".format(first_label))

## Train model

In [None]:
from torch.autograd import Variable

disp_interval = 5  # print the mean loss every `disp_interval` batches
n_epochs = 500

mynet.train()
for epoch in range(n_epochs):
    running_loss = 0.0
    for i, data in enumerate(mydataloader):
        # Tensors work with autograd directly; Variable wrappers are unnecessary.
        # Cast images to float32 to match the model's parameters, and labels to
        # int64 class indices as CrossEntropyLoss requires.
        inputs = data['image'].float()
        labels = data['label'].long()

        optimizer.zero_grad()
        outputs = mynet(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The loss is a 0-dim tensor: .item() extracts the Python float
        # (indexing it with [0] raises an error in modern PyTorch).
        running_loss += loss.item()

        if i % disp_interval == disp_interval - 1:
            print("Epoch: {0}, Batch: {1}, Loss: {2}".format(
                epoch + 1, i + 1, running_loss / disp_interval))
            running_loss = 0.0