In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import skimage
from skimage import io
import os
import glob
import numpy as np
from skimage import exposure, measure
from skimage.transform import rotate
import re
from torch.utils.data import Dataset

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [57]:
def dir_to_file_lists(directory):
    os.chdir(directory)
    input_list = []
    target_list = []
    all_tifs = glob.glob("*.tif")
    input_tifs = [file for file in all_tifs if '_channel1_' in file]
    input_tifs.sort()
    output_tifs = [file for file in all_tifs if '_channel6_' in file]
    output_tifs.sort()
    return(input_tifs, output_tifs)

In [58]:
input_tifs, output_tifs = dir_to_file_lists('/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-top50')

In [59]:
def image_to_matrix_dataset(file_list):
    """
    funciton takes list of file names and returns list of matrices
    list will be 6 times as long since data is flipped + rotated too
    """
    
    mat_list = []
    for file in file_list:
        orig = io.imread(file)
        mat_list.append(orig)
        
        #vertical flip
        vert_flip = orig[::-1]
        mat_list.append(vert_flip)
        
        #horizonal flip
        horiz_flip = np.flip(orig,1)
        mat_list.append(horiz_flip)
        
        #rotate 90 degrees
        rot_90 = rotate(orig, 90)
        mat_list.append(rot_90)
        
        #rotate 180 degrees
        rot_180 = rotate(orig, 180)
        mat_list.append(rot_180)
        
        #rotate 270 degrees
        rot_270 = rotate(orig, 270)
        mat_list.append(rot_270)
    return(mat_list)

In [60]:
input_tifs_mats = image_to_matrix_dataset(input_tifs[:100])
output_tifs_mats = image_to_matrix_dataset(output_tifs[:100])
# write to pickle

# read from pickle

In [76]:
input_tifs_mats[1][0].astype(dtype = 'float32')

array([ 875.,  950.,  930.,  836.,  923.,  871.,  757.,  968.,  929.,
       1016.,  812.,  830.,  882.,  838.,  915.,  856.,  867.,  782.,
        901.,  939.,  961.,  814.,  786.,  913.,  779.,  845., 1012.,
        909.,  916.,  815.,  785.,  662.,  752.,  822.,  826.,  745.,
        887.,  786.,  899.,  721., 1029.,  722.,  778.,  863.,  740.,
        891.,  898.,  847.,  954.,  742.,  868.,  842.,  791.,  650.,
        708.,  902.,  711.,  720.,  782.,  727.,  814.,  752.,  732.,
        790.,  871.,  777.,  652.,  667.,  721.,  800.,  727.,  801.,
        796.,  665.,  744.,  763.,  775.,  787.,  669.,  717.,  719.,
        845.,  723.,  655.,  787.,  734.,  713.,  682.,  749.,  651.,
        750.,  882.,  796.,  631.,  611.,  758.,  792.,  766.,  699.,
        703.,  695.,  725.,  748.,  680.,  623.,  654.,  695.,  753.,
        575.,  697.,  660.,  669.,  748.,  724.,  667.,  597.,  644.,
        777.,  594.,  857.,  759.,  642.,  628.,  775.,  788.,  713.,
        633.,  535.,

In [61]:
class two_image_dataset(Dataset):
    
    def __init__(self, input_tifs_mats, output_tifs_mats):
        
        self.input_tifs_mats = input_tifs_mats
        self.output_tifs_mats = output_tifs_mats
        assert (len(self.input_tifs_mats) == len(self.output_tifs_mats))
    
    def __len__(self):
        return len(self.input_tifs_mats)
    
    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        input_mat = self.input_tifs_mats[key]
        output_mat = self.output_tifs_mats[key]
        return [input_mat, output_mat]

In [79]:
def two_image_collate_func(batch):
    """
    function that returns input and target as tensors
    """
    input_list = []
    target_list = []
    for datum in batch:
        input_list.append(datum[0].astype(dtype = 'float32')/32768)
        target_list.append(datum[1].astype(dtype = 'float32')/32768)
    input_tensor = torch.from_numpy(np.array(input_list))
    target_tensor = torch.from_numpy(np.array(input_list))
    return [input_tensor, target_tensor]

In [80]:
BATCH_SIZE = 4
all_dataset = two_image_dataset(input_tifs_mats, output_tifs_mats)
all_loader = torch.utils.data.DataLoader(dataset=all_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=two_image_collate_func,
                                           shuffle=False)

In [81]:
#for i, (mat1, mat2) in enumerate(all_loader):
#    print(i)
#    print(mat1)
#    print(mat2)

In [82]:
# unet parts here

class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        
        # for padding issues, see 
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

In [83]:
# UNET arch here
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x = x.view(x.size(0), 1, 512, 512).to(device)
        #print(x.shape)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return F.sigmoid(x)

In [87]:
model = UNet(1, 1)
model = model.to(device)

In [88]:
learning_rate = 0.001
num_epochs = 5 

# Criterion and Optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [89]:
for epoch in range(num_epochs): 
    for i, (inputs, targets) in enumerate(all_loader):
        model.train()
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets.view(4,1,512,512).to(device))
        #print(outputs[0])
        #print(targets.view(4,1,512,512)[0])
        #print(loss)
        # Backward and optimize
        loss.backward()
        optimizer.step()
        #print("Training Loss : {}".format(loss))
        # validate every 10 iterations
        if i > 0 and i % 10 == 0:
            # validate
            print("Training Loss : {}".format(loss))


Training Loss : 0.12659518420696259
Training Loss : 0.10929626226425171
Training Loss : 0.072621189057827
Training Loss : 0.05901113897562027
Training Loss : 0.06625333428382874
Training Loss : 0.0381472148001194
Training Loss : 0.032305099070072174
Training Loss : 0.04280365630984306
Training Loss : 0.021401144564151764
Training Loss : 0.01861284300684929
Training Loss : 0.02628067322075367
Training Loss : 0.013229690492153168
Training Loss : 0.01122827734798193
Training Loss : 0.010287940502166748
Training Loss : 0.00688639422878623
Training Loss : 0.006497588939964771
Training Loss : 0.004920864012092352
Training Loss : 0.004697423428297043
Training Loss : 0.004643122665584087
Training Loss : 0.003507231129333377
Training Loss : 0.0034348531626164913
Training Loss : 0.0034930482506752014
Training Loss : 0.0026012668386101723
Training Loss : 0.0026501037646085024
Training Loss : 0.0027290810830891132
Training Loss : 0.001981619745492935
Training Loss : 0.002037632744759321
Training L

In [94]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(model_parameters)
print(params)

<filter object at 0x20008a65f160>
13394177
