In [1]:
import os
import torch
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transform
from __future__ import print_function, division
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
class unet(nn.Module):
    def __init__(self):
        super(unet, self).__init__()
        self.conv1 = nn.Conv2d(1,64,3) # (in_channels, out_channels, kernel_size)
        self.conv2 = nn.Conv2d(64,64,3) # (in_channels, out_channels, kernel_size)
        self.conv3 = nn.Conv2d(64,128,3) # (in_channels, out_channels, kernel_size)
        self.conv4 = nn.Conv2d(128,128,3) # (in_channels, out_channels, kernel_size)
        self.conv5 = nn.Conv2d(128,256,3) # (in_channels, out_channels, kernel_size)
        self.conv6 = nn.Conv2d(256,256,3) # (in_channels, out_channels, kernel_size)
        self.conv7 = nn.Conv2d(256,512,3) # (in_channels, out_channels, kernel_size)
        self.conv8 = nn.Conv2d(512,512,3) # (in_channels, out_channels, kernel_size)
        self.conv9 = nn.Conv2d(512,1024,3) # (in_channels, out_channels, kernel_size)
        self.conv10 = nn.Conv2d(1024,1024,3) # (in_channels, out_channels, kernel_size)
        self.conv11 = nn.Conv2d(1024,512,3) # (in_channels, out_channels, kernel_size)
        self.conv12 = nn.Conv2d(512,512,3) # (in_channels, out_channels, kernel_size)
        self.conv13 = nn.Conv2d(512,256,3) # (in_channels, out_channels, kernel_size)
        self.conv14 = nn.Conv2d(256,256,3) # (in_channels, out_channels, kernel_size)
        self.conv15 = nn.Conv2d(256,128,3) # (in_channels, out_channels, kernel_size)
        self.conv16 = nn.Conv2d(128,128,3) # (in_channels, out_channels, kernel_size)
        self.conv17 = nn.Conv2d(128,64,3) # (in_channels, out_channels, kernel_size)
        self.conv18 = nn.Conv2d(64,64,3) # (in_channels, out_channels, kernel_size)
        self.conv19 = nn.Conv2d(64,2,1) # (in_channels, out_channels, kernel_size)        
        self.maxpool = nn.MaxPool2d(2) # (kernel_size, stride, padding)   
        self.convtrans1 = nn.ConvTranspose2d(1024,512,2) # (in_channels, out_channels, kernel_size)
        self.convtrans2 = nn.ConvTranspose2d(512,256,2) # (in_channels, out_channels, kernel_size)
        self.convtrans3 = nn.ConvTranspose2d(256,128,2) # (in_channels, out_channels, kernel_size)
        self.convtrans4 = nn.ConvTranspose2d(128,64,2) # (in_channels, out_channels, kernel_size)
        
    def forward(self,x):
        
        x = F.relu(self.conv1(x))
        out1 = F.relu(self.conv2(x))
        out1_transform = nn.Upsample(size=(392, 392), mode='bilinear')(out1)
        
        x = self.maxpool(out1)
        x = F.relu(self.conv3(x))
        out2 = F.relu(self.conv4(x))
        out2_transform = nn.Upsample(size=(200, 200), mode='bilinear')(out2)
        
        x = self.maxpool(out2)
        x = F.relu(self.conv5(x))
        out3 = F.relu(self.conv6(x))
        out3_transform = nn.Upsample(size=(104, 104), mode='bilinear')(out3)
        
        x = self.maxpool(out3)
        x = F.relu(self.conv7(x))
        out4 = F.relu(self.conv8(x))
        out4_transform = nn.Upsample(size=(56, 56), mode='bilinear')(out4)
        
        x = self.maxpool(out4)
        x = F.relu(self.conv9(x))
        x = F.relu(self.conv10(x))
        out5 = self.convtrans1(x)
        out5_transform = nn.Upsample(size=(56, 56), mode='bilinear')(out5)
        
        x = F.relu(self.conv11(torch.cat((out4_transform,out5_transform), 1)))
        x = F.relu(self.conv12(x))
        out6 = self.convtrans2(x)
        out6_transform = nn.Upsample(size=(104, 104), mode='bilinear')(out6)
        
        x = F.relu(self.conv13(torch.cat((out3_transform,out6_transform), 1)))
        x = F.relu(self.conv14(x))
        out7 = self.convtrans3(x)
        out7_transform = nn.Upsample(size=(200, 200), mode='bilinear')(out7)
        
        x = F.relu(self.conv15(torch.cat((out2_transform,out7_transform), 1)))
        x = F.relu(self.conv16(x))
        out8 = self.convtrans4(x)
        out8_transform = nn.Upsample(size=(392, 392), mode='bilinear')(out8)
        
        x = F.relu(self.conv17(torch.cat((out1_transform,out8_transform), 1)))
        x = F.relu(self.conv18(x))
        x = F.relu(self.conv19(x))
        
        return x
        

In [3]:
# unet_inst = unet().float()

In [4]:
# img = torch.from_numpy(np.zeros((574,574))).float().unsqueeze_(0).unsqueeze_(0)
# print(img.shape)

In [5]:
# unet_inst(img)

In [6]:
# Read original training data file
csv_file = './train_ship_segmentations.csv'
X = np.array((pd.read_csv(csv_file, header=0)).as_matrix())

In [7]:
# Create and split data here
from sklearn.model_selection import train_test_split
train, valid = train_test_split(X, test_size=0.20, random_state=42)
X_train, y_train = train[:,0], train[:,1]
X_valid, y_valid = valid[:,0], valid[:,1]

In [8]:
print(X_train.shape)
print(X_valid.shape)

(104824,)
(26206,)


In [25]:
# Write TRAINING data to file
data = []
path = './train/'
for i,(X,y) in enumerate(zip(X_train, y_train),0):
    data.append({'X': path+X, 'y': y})

data = np.array(data)
print(data.shape)
np.savetxt("./data_files/train.csv", data, fmt='%s', delimiter=",")

(104824,)


In [26]:
# Write VALIDATION data to file
data = []
path = './train/'
for i,(X,y) in enumerate(zip(X_valid, y_valid),0):
    data.append({'X': path+X, 'y': y})

data = np.array(data)
print(data.shape)
np.savetxt("./data_files/valid.csv", data, fmt='%s', delimiter=",")

(26206,)


In [29]:
class ShipsDataset():
    """Face Landmarks dataset."""
    data = []
    def __init__(self, FLAG):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
        """
        if (FLAG=='train'):
            print('train_data')
            csv_file = './data_files/train.csv'
        elif (FLAG=='valid'):
            print('valid_data')
            csv_file = './data_files/valid.csv'
        elif (FLAG=='test'):
            print('test_data')
            csv_file = './data_files/test.csv'
            
        data = pd.read_csv(csv_file)
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = data[idx]['X']
        image = io.imread(img_name)
        mask_coords = data[idx]['y']
        
        print('image dims: {}'.format(image.shape))
        
        # Construct mask here
        
        
        return image, mask

In [34]:
dataset_train = ShipsDataset('train')
dataloader_train = DataLoader(dataset_train, batch_size=2, shuffle=True, num_workers=0)
dataset_valid = ShipsDataset('valid')
dataloader_valid = DataLoader(dataset_valid, batch_size=2, shuffle=True, num_workers=0)

train_data
valid_data


In [36]:
# Run the training
for i, (data_train, data_valid) in enumerate(zip(dataloader_train, dataloader_valid),0):
    print(i)
    print(data_train)