In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import skimage
from skimage import io
import os
import glob
import numpy as np
from skimage import exposure, measure
from skimage.transform import rotate
import re
from torch.utils.data import Dataset

In [7]:
def dir_to_file_lists(directory):
    os.chdir(directory)
    input_list = []
    target_list = []
    all_tifs = glob.glob("*.tif")
    input_tifs = [file for file in all_tifs if '_channel2_' in file]
    input_tifs.sort()
    output_tifs = [file for file in all_tifs if '_channel6_' in file]
    output_tifs.sort()
    return(input_tifs, output_tifs)

In [None]:
input_tifs, output_tifs = dir_to_file_lists('/gpfs/data/lionnetlab/cellvision/pilotdata/20181009-top50')

In [9]:
def image_to_matrix_dataset(file_list):
    """
    funciton takes list of file names and returns list of matrices
    list will be 6 times as long since data is flipped + rotated too
    """
    
    mat_list = []
    for file in file_list:
        orig = io.imread(file)
        mat_list.append(orig)
        
        #vertical flip
        vert_flip = orig[::-1]
        mat_list.append(vert_flip)
        
        #horizonal flip
        horiz_flip = np.flip(orig,1)
        mat_list.append(horiz_flip)
        
        #rotate 90 degrees
        rot_90 = rotate(orig, 90)
        mat_list.append(rot_90)
        
        #rotate 180 degrees
        rot_180 = rotate(orig, 180)
        mat_list.append(rot_180)
        
        #rotate 270 degrees
        rot_270 = rotate(orig, 270)
        mat_list.append(rot_270)
    return(mat_list)

In [13]:
input_tifs_mats = image_to_matrix_dataset(input_tifs[:100])
output_tifs_mats = image_to_matrix_dataset(output_tifs[:100])
# write to pickle

# read from pickle

In [17]:
class two_image_dataset(Dataset):
    
    def __init__(self, input_tifs_mats, output_tifs_mats):
        
        self.input_tifs_mats = input_tifs_mats
        self.output_tifs_mats = output_tifs_mats
        assert (len(self.input_tifs_mats) == len(self.output_tifs_mats))
    
    def __len__(self):
        return len(self.input_tifs_mats)
    
    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        input_mat = self.input_tifs_mats[key]
        output_mat = self.output_tifs_mats[key]
        return [input_mat, output_mat]

In [18]:
def two_image_collate_func(batch):
    """
    function that returns input and target as tensors
    """
    input_list = []
    target_list = []
    for datum in batch:
        input_list.append(datum[0].astype(dtype = 'int32'))
        target_list.append(datum[1].astype(dtype = 'int32'))
    input_tensor = torch.from_numpy(np.array(input_list))
    target_tensor = torch.from_numpy(np.array(input_list))
    return [input_tensor, target_tensor]

In [21]:
BATCH_SIZE = 4
all_dataset = two_image_dataset(input_tifs_mats, output_tifs_mats)
all_loader = torch.utils.data.DataLoader(dataset=all_dataset, 
                                           batch_size=BATCH_SIZE,
                                           collate_fn=two_image_collate_func,
                                           shuffle=False)

In [35]:
#for i, (mat1, mat2) in enumerate(all_loader):
#    print(i)
#    print(mat1)
#    print(mat2)

In [30]:
# UNET arch here
def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )

class ResNetUNet(nn.Module):

    def __init__(self):
        super().__init__()

        # Use ResNet18 as the encoder with the pretrained weights
        base_model = models.resnet18(pretrained=True)
        self.base_layers = list(base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 256, x.H/4, x.W/4)
        self.layer1_1x1 = convrelu(256, 256, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 512, x.H/8, x.W/8)
        self.layer2_1x1 = convrelu(512, 512, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 1024, x.H/16, x.W/16)
        self.layer3_1x1 = convrelu(1024, 512, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 2048, x.H/32, x.W/32)
        self.layer4_1x1 = convrelu(2048, 1024, 1, 0)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_up3 = convrelu(512 + 1024, 512, 3, 1)
        self.conv_up2 = convrelu(512 + 512, 512, 3, 1)
        self.conv_up1 = convrelu(256 + 512, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, 1, 1)

    def forward(self, input):
        x_original = self.conv_original_size0(input)
        x_original = self.conv_original_size1(x_original)

        layer0 = self.layer0(input)
        layer1 = self.layer1(layer0)
        layer2 = self.layer2(layer1)
        layer3 = self.layer3(layer2)
        layer4 = self.layer4(layer3)

        # Upsample the last/bottom layer
        layer4 = self.layer4_1x1(layer4)
        x = self.upsample(layer4)
        # Create the shortcut from the encoder
        layer3 = self.layer3_1x1(layer3)
        x = torch.cat([x, layer3], dim=1)
        x = self.conv_up3(x)

        x = self.upsample(x)
        layer2 = self.layer2_1x1(layer2)
        x = torch.cat([x, layer2], dim=1)
        x = self.conv_up2(x)

        x = self.upsample(x)
        layer1 = self.layer1_1x1(layer1)
        x = torch.cat([x, layer1], dim=1)
        x = self.conv_up1(x)

        x = self.upsample(x)
        layer0 = self.layer0_1x1(layer0)
        x = torch.cat([x, layer0], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, x_original], dim=1)
        x = self.conv_original_size2(x)

        out = self.conv_last(x)

        return out

In [31]:
model = ResNetUNet()

In [32]:
learning_rate = 0.001
num_epochs = 5 

# Criterion and Optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [33]:
for epoch in range(num_epochs): 
    for i, (inputs, targets) in enumerate(all_loader):
        model.train()
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        #print(loss)
        # Backward and optimize
        loss.backward()
        optimizer.step()
        # validate every 10 iterations
        if i > 0 and i % 100 == 0:
            # validate
            val_acc = test_model(val_loader, model)
            #train_acc = test_model(train_loader, model)
            print("Training Loss : {}".format(loss))
            #print("Training Accuracy : {}".format(train_acc))
            print('Epoch: [{}/{}], Step: [{}/{}], Validation Acc: {}'.format(
                       epoch+1, num_epochs, i+1, len(train_loader), val_acc))

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 3, 3], but got input of size [4, 512, 512] instead