In [1]:
import numpy as np
import torch
import torchvision
import cv2
import os
import matplotlib.pyplot as plt
import shutil
import torch.utils.data as utils
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Data cleaning

In [2]:
# Load the raw visibility table and keep only rows whose reading columns (1..3)
# parse as numbers. Column 0 is the sample id that image file names start with.
data = np.load("visibility.npy")

kept_rows = []
for d in data:
    # Strip a leading ">" from column 2 so the value can be parsed as a number.
    # Slicing ([:1]) rather than indexing ([0]) also tolerates an empty field,
    # which previously raised IndexError outside the try block.
    if d[2][:1] == ">":
        d[2] = d[2][1:]
    try:
        d[1:].astype(np.float32)
    except (ValueError, TypeError):
        continue  # non-numeric reading: drop the row
    kept_rows.append(d.reshape(1, -1))

# A single concatenate instead of one per row (the per-row version is quadratic).
# The empty '<U9' seed keeps the same dtype promotion as the incremental build.
cleaned = np.concatenate([np.zeros((0, 4), dtype='<U9')] + kept_rows, axis=0)

labels = cleaned[:, 1:].astype(np.float64)
labels_mean = labels.mean(axis=0)
labels_std = labels.std(axis=0)
# z-score each reading column; the epsilon avoids division by zero for a constant column.
normal_labels = (labels - labels_mean) / (labels_std + 1e-6)

# Image preprocessing

In [64]:
# Crop every labelled capture to the 400x400 region the model consumes. The crops go
# into a sub-directory of Capture_Clean because torchvision's ImageFolder (used to
# enumerate these files) only discovers images inside sub-folders of its root; a flat
# layout makes ImageFolder("Capture_Clean") fail on a fresh Restart & Run All.
CLEAN_ROOT = 'Capture_Clean'
CLEAN_DIR = os.path.join(CLEAN_ROOT, 'all')
CROP_ROWS = slice(300, 700)
CROP_COLS = slice(600, 1000)
CROP_SHAPE = (CROP_ROWS.stop - CROP_ROWS.start, CROP_COLS.stop - CROP_COLS.start)

# Start from an empty output tree so stale crops never leak into the dataset.
# Real errors (e.g. permissions) now surface instead of being swallowed.
if os.path.isdir(CLEAN_ROOT):
    shutil.rmtree(CLEAN_ROOT)
os.makedirs(CLEAN_DIR)

# Set membership is O(1); `x in ndarray` scans the whole id column for every file.
labelled_ids = set(cleaned[:, 0])

skipped = []
for raw_image in os.listdir('Capture'):
    if not raw_image.endswith("jpg") or raw_image[:9] not in labelled_ids:
        continue
    try:
        img = plt.imread(os.path.join('Capture', raw_image))[CROP_ROWS, CROP_COLS]
        # Captures smaller than the crop window would yield a short crop that only
        # breaks later when stacked into the fixed-size tensor; reject them here.
        if img.shape[:2] != CROP_SHAPE:
            raise ValueError("crop is %s, expected %s" % (img.shape[:2], CROP_SHAPE))
        plt.imsave(os.path.join(CLEAN_DIR, raw_image), img)
    except Exception as err:  # best-effort: record unreadable/undersized captures, keep going
        skipped.append((raw_image, err))

if skipped:
    print("skipped %d capture(s); first: %s (%s)" % (len(skipped), skipped[0][0], skipped[0][1]))

In [18]:
used_label = 1 # 0:wet, 1:visible, 2:uv
IMG_SIZE = 400  # side length of the square crops produced by the preprocessing cell

# ImageFolder is used only to enumerate image paths (its class index is ignored);
# it expects the images to live in a sub-directory of "Capture_Clean".
data_info = torchvision.datasets.ImageFolder("Capture_Clean")

# Map each sample id to the FIRST row of `cleaned` carrying it (as np.where(...)[0][0] did),
# built once instead of scanning the id column for every image.
id_to_row = {}
for row, sample_id in enumerate(cleaned[:, 0]):
    id_to_row.setdefault(sample_id, row)

# Key each image by the first 9 characters of its file name -- the same key the
# preprocessing cell filtered on -- and keep only images that actually have a label.
# Previously unmatched images left all-zero rows with label 0 in x / y.
matched = []
for path, _class_idx in data_info.imgs:
    row = id_to_row.get(os.path.basename(path)[:9])
    if row is not None:
        matched.append((path, row))

# Allocate float32 directly: the tensors end up float32 anyway, and a float64 buffer
# is twice as large and then has to be copied.
x = np.zeros((len(matched), 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)
y = np.zeros((len(matched), 1), dtype=np.float32)
for i, (path, row) in enumerate(matched):
    img = plt.imread(path)[:, :, :3]
    if img.shape != (IMG_SIZE, IMG_SIZE, 3):
        raise ValueError("%s has shape %s, expected (%d, %d, 3)"
                         % (path, img.shape, IMG_SIZE, IMG_SIZE))
    x[i] = np.transpose(img, (2, 0, 1))  # HWC -> CHW
    y[i] = normal_labels[row, used_label]

# NOTE(review): pixels are left in their raw range (uint8 0-255 for JPEG); a `x /= 255`
# rescale was present but disabled in the original cell -- consider enabling it.
x = torch.from_numpy(x)
y = torch.from_numpy(y)
data_set = utils.TensorDataset(x, y)

In [19]:
# Hold out a fraction of the samples for validation and wrap both splits in loaders.
SEED = 0
VALID_FRACTION = 0.1
BATCH_SIZE = 4

datums = len(data_set.tensors[0])
valid_len = int(VALID_FRACTION * datums)
train_len = datums - valid_len

# Seed torch's global RNG so the split is reproducible across kernel restarts; this also
# makes subsequent torch randomness (e.g. weight init in Net(), loader shuffling) repeatable.
torch.manual_seed(SEED)
split_data = utils.dataset.random_split(data_set, [train_len, valid_len])
train_data, valid_data = split_data

train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# Shuffling the validation set only adds nondeterminism; evaluation order does not need it.
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [20]:
class Net(nn.Module):
    """Small CNN that regresses one scalar from a 3 x 400 x 400 image.

    Four conv -> ReLU -> 2x2 max-pool stages shrink the spatial size
    400 -> 196 -> 96 -> 46 -> 21, followed by a five-layer fully connected head.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 10, 9) # (400 - 8)/2 = 196
        self.conv2 = nn.Conv2d(10, 20, 5) # (196 - 4)/2 = 96
        self.conv3 = nn.Conv2d(20, 40, 5) # (96 - 4)/2 = 46
        self.conv4 = nn.Conv2d(40, 80, 5) # (46 - 4)/2 = 21
        self.fc1 = nn.Linear(80 * 21 * 21, 1000)
        self.fc2 = nn.Linear(1000, 200)
        self.fc3 = nn.Linear(200, 50)
        self.fc4 = nn.Linear(50, 10)
        self.fc5 = nn.Linear(10, 1)

    def forward(self, x):
        """Map a (batch, 3, 400, 400) float tensor to (batch, 1) predictions."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        # Flatten per sample. Keeping the batch dimension explicit makes a mis-sized
        # input fail loudly in fc1, instead of view(-1, 80*21*21) silently re-chunking
        # it into a different number of rows.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x

# Regression objective (mean-squared error) and an Adam optimiser over the network weights.
criterion = nn.MSELoss()
net = Net()
optimizer = optim.Adam(params=net.parameters(), lr=5e-4)

In [21]:
N_EPOCHS = 10
LOG_EVERY = 10  # mini-batches between training-loss printouts


def _validation_loss(model, loader, loss_fn):
    """Return the per-sample mean of `loss_fn` over `loader`, or nan if `loader` is empty.

    Runs in eval mode under torch.no_grad(), so no autograd graph is built.
    Assumes `loss_fn` averages over the batch (the nn.MSELoss default) and re-weights
    each batch by its size so a short final batch is not over-counted.
    """
    model.eval()
    total, seen = 0.0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            n = batch_inputs.size(0)
            total += loss_fn(model(batch_inputs), batch_labels).item() * n
            seen += n
    return total / seen if seen else float("nan")


for epoch in range(N_EPOCHS):
    net.train()
    running_loss = 0.0
    # Distinct loop-variable names so the notebook's global `data` / `labels`
    # arrays are no longer overwritten by the last mini-batch.
    for i, (batch_inputs, batch_labels) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(net(batch_inputs), batch_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()  # plain float: no tensor kept alive
        if i % LOG_EVERY == LOG_EVERY - 1:
            # Second field is the number of training samples seen so far this epoch.
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, train_loader.batch_size * (i + 1), running_loss / LOG_EVERY))
            running_loss = 0.0

    print("epoch %d valid error : %.3f" % (epoch + 1, _validation_loss(net, valid_loader, criterion)))

[1,    40] loss: 22.302
[1,    80] loss: 0.934
[1,   120] loss: 1.234
[1,   160] loss: 1.584
[1,   200] loss: 0.163
[1,   240] loss: 0.204
[1,   280] loss: 1.264
[1,   320] loss: 1.186
[1,   360] loss: 1.264
[1,   400] loss: 0.568
[1,   440] loss: 0.806
[1,   480] loss: 1.477
[1,   520] loss: 0.868
[1,   560] loss: 0.827
[1,   600] loss: 1.069
[1,   640] loss: 1.852
[1,   680] loss: 0.656
[1,   720] loss: 0.772
[1,   760] loss: 0.518
[1,   800] loss: 0.904
[1,   840] loss: 0.899
[1,   880] loss: 1.519
[1,   920] loss: 1.109
[1,   960] loss: 0.704
[1,  1000] loss: 0.879
[1,  1040] loss: 0.509
[1,  1080] loss: 0.484
[1,  1120] loss: 0.471
[1,  1160] loss: 0.743
[1,  1200] loss: 0.408
[1,  1240] loss: 0.157
[1,  1280] loss: 0.770
[1,  1320] loss: 1.597
[1,  1360] loss: 0.368
[1,  1400] loss: 0.792
[1,  1440] loss: 0.695
[1,  1480] loss: 1.012
[1,  1520] loss: 0.754
[1,  1560] loss: 0.847
[1,  1600] loss: 1.069
[1,  1640] loss: 0.410
[1,  1680] loss: 0.693
[1,  1720] loss: 0.271
[1,  1760]

[8,  1640] loss: 0.179
[8,  1680] loss: 0.255
[8,  1720] loss: 0.081
[8,  1760] loss: 0.289
epoch 8 valid error : 0.091
[9,    40] loss: 0.067
[9,    80] loss: 0.097
[9,   120] loss: 0.042
[9,   160] loss: 0.052
[9,   200] loss: 0.132
[9,   240] loss: 0.085
[9,   280] loss: 0.129
[9,   320] loss: 0.083
[9,   360] loss: 0.066
[9,   400] loss: 0.050
[9,   440] loss: 0.071
[9,   480] loss: 0.051
[9,   520] loss: 0.073
[9,   560] loss: 0.096
[9,   600] loss: 0.078
[9,   640] loss: 0.096
[9,   680] loss: 0.030
[9,   720] loss: 0.087
[9,   760] loss: 0.036
[9,   800] loss: 0.212
[9,   840] loss: 0.033
[9,   880] loss: 0.322
[9,   920] loss: 0.114
[9,   960] loss: 0.214
[9,  1000] loss: 0.087
[9,  1040] loss: 0.043
[9,  1080] loss: 0.085
[9,  1120] loss: 0.043
[9,  1160] loss: 0.039
[9,  1200] loss: 0.006
[9,  1240] loss: 0.060
[9,  1280] loss: 0.091
[9,  1320] loss: 0.014
[9,  1360] loss: 0.019
[9,  1400] loss: 0.042
[9,  1440] loss: 0.100
[9,  1480] loss: 0.109
[9,  1520] loss: 0.183
[9,  1