In [2]:
import numpy as np
import math, random, string
import os, sys
from h5py import File
import glob
import torch
import torch.nn as nn
import torch.multiprocessing
%matplotlib notebook
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
np.random.seed()

import glob
import numpy as np
import matplotlib.pyplot as plt
import shutil


# Load Data

In [3]:
path = "/home/d0048/Desktop/hackerson/hackII"
img_paths=[file for file in glob.glob(path + "/all_frames_5m6b/*.npy")]
img_paths.sort()
label_paths=[file for file in glob.glob(path + "/all_masks_5m6b/*.npy")]
label_paths.sort()
images = np.asarray([np.load(pth) for pth in img_paths]).transpose([0,3,1,2])
labels = np.asarray([np.load(pth) for pth in label_paths]).astype(np.double)

print('images: {}'.format(images.shape))
print('labels: {}'.format(labels.shape))
plt.imshow(images[100,:,:,0:3])

images: (5028, 6, 128, 128)
labels: (5028, 128, 128)


# FCN Arch

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class UNet(nn.Module):

    def __init__(self, n_class):
        super().__init__()

        self.dconv_down1 = double_conv(6, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)

        x = self.dconv_up3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)

        x = self.dconv_up2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)

        x = self.dconv_up1(x)

        out = self.conv_last(x)

        return out

In [5]:
model=UNet(1)
model.train()
y=model(torch.Tensor(images[0:2,:,:,:]))

In [14]:
batch_size=2
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-4)
criterion = nn.MSELoss()

my_dataset = torch.utils.data.TensorDataset(torch.tensor(images),torch.tensor(labels)) # create your datset
my_dataloader = torch.utils.data.DataLoader(my_dataset,batch_size=batch_size) # create your dataloader

e=0
for x,label in my_dataloader:
    e+=1
    model.train()
    optimizer.zero_grad()
    pred = model(x).squeeze()

    pred=pred.double()
    
    loss = criterion(pred, label)
    loss.backward()
    optimizer.step()
    print('Step: {}, Loss: {}'.format(e,loss))

Step: 1, Loss: 3.221248399347194
Step: 2, Loss: 0.4491828699514317
Step: 3, Loss: 1.0787694526564264
Step: 4, Loss: 1.0745441796987094
Step: 5, Loss: 0.2711768017660552
Step: 6, Loss: 0.2249392209993741
Step: 7, Loss: 0.49947446009579427
Step: 8, Loss: 0.39719790977556635
Step: 9, Loss: 0.2645301066110153
Step: 10, Loss: 0.0944639327547864
Step: 11, Loss: 0.08860555275370266
Step: 12, Loss: 0.2318749972933029
Step: 13, Loss: 0.210039647086326
Step: 14, Loss: 0.14065784458245525
Step: 15, Loss: 0.05437015472806108
Step: 16, Loss: 0.07251023857706412
Step: 17, Loss: 0.12131261691744706
Step: 18, Loss: 0.1013093933563589
Step: 19, Loss: 0.0498976458437554
Step: 20, Loss: 0.0352788786740853
Step: 21, Loss: 0.027406247886418836
Step: 22, Loss: 0.057380803930791244
Step: 23, Loss: 0.06340820255244273
Step: 24, Loss: 0.0426581617238797
Step: 25, Loss: 0.023425501234415086
Step: 26, Loss: 0.03348186611335784
Step: 27, Loss: 0.02782358790920025
Step: 28, Loss: 0.05292712420870593
Step: 29, Loss

KeyboardInterrupt: 