In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
%matplotlib inline
%load_ext autoreload
%autoreload 2
import numpy as np

import os, sys

import utils

from train import train

# Preprocessing and import

In [None]:
# Dataset layout: every path below is relative to the training root directory
ROOT_DIR = 'training/'
IMAGE_DIR = ''.join((ROOT_DIR, 'images/'))

In [None]:
def load_image(filename):
    """Read the image stored at `filename` and return it as a numpy array (via matplotlib)."""
    pixels = mpimg.imread(filename)
    return pixels

#def unstack(img):
#    return np.array([img[:,:,c] for c in range(3)])

#def restack(img):
#    return np.stack(img, axis=-1)

In [None]:
# Read images
files = os.listdir(IMAGE_DIR)
imgs = [load_image(IMAGE_DIR + file) for file in files] # images (400, 400, 3)

In [None]:
# Read groundtruts
GT_DIR = ROOT_DIR + 'groundtruth/'
gt_imgs = [load_image(GT_DIR + file) for file in files] # images (400, 400)

In [None]:
# Side length of the square patch extracted around each pixel by image_to_features;
# its docstring asks for an odd value so the patch is centred on a pixel.
KERNEL_SIDE = 3

# Matrix creation

In [None]:
def image_to_features(img, kernel_size):
    """Linearizes every kernel_size x kernel_size patch of an image into one row.

    Arguments:
     :img: the image to linearize, of shape (W, H, C); a 2-D (W, H) image also works.
     :kernel_size: the length of the side of the square patch; should be odd so
      that each patch is centred on a pixel.

    Only patches lying entirely inside the image are used (no padding). Hence the
    produced matrix has shape
        ((W - (kernel_size - 1)) * (H - (kernel_size - 1)), kernel_size**2 * C),
    i.e. ((W - 2r) * (H - 2r), kernel_size**2 * C) with radius r = (kernel_size - 1) // 2
    for odd kernel_size. Rows are ordered by the patch's top-left corner (i, j),
    row-major, and each row equals np.ravel(img[i : i + kernel_size, j : j + kernel_size]).

    Raises:
     ValueError: if kernel_size < 1 or the image is smaller than one patch.
    """
    img = np.asarray(img)
    out_w = img.shape[0] - (kernel_size - 1)
    out_h = img.shape[1] - (kernel_size - 1)
    if kernel_size < 1 or out_w < 1 or out_h < 1:
        raise ValueError(
            "kernel_size=%d is invalid for an image of shape %s" % (kernel_size, img.shape))
    # Rather than looping over all out_w * out_h patch positions in Python, loop over the
    # kernel_size**2 offsets (di, dj) inside a patch: each offset yields one shifted slice
    # holding that patch element for every patch position at once.
    shifted = [img[di : di + out_w, dj : dj + out_h]
               for di in range(kernel_size)
               for dj in range(kernel_size)]
    # Stacking on axis 2 gives (out_w, out_h, kernel_size**2[, C]); flattening the trailing
    # axes reproduces the (di, dj, c) element order of np.ravel on a single patch.
    return np.stack(shifted, axis=2).reshape(out_w * out_h, -1)

In [None]:
def crop_groundtruth(img, kernel_size):
    """Crops a groundtruth mask so its pixels align with the rows of image_to_features.

    image_to_features drops kernel_size - 1 pixels along each of the first two axes
    (patches must fit inside the image). This removes (kernel_size - 1) // 2 pixels
    from the leading border and the remainder from the trailing border. For odd
    kernel_size the crop is symmetric with radius (kernel_size - 1) // 2, and
    kernel_size == 1 returns the mask unchanged.

    Arguments:
     :img: array whose first two axes are (W, H).
     :kernel_size: side length of the patch used by image_to_features.
    Returns:
     A view of img of shape (W - (kernel_size - 1), H - (kernel_size - 1), ...).
    """
    lead = (kernel_size - 1) // 2
    trail = (kernel_size - 1) - lead
    # Use explicit end indices: img[r:-r] would be empty when r == 0.
    # Optional (disabled) binarisation to {-1, 1}; note it would modify img in place:
    #img[img < 0.5] = -1
    #img[img >= 0.5] = 1
    return img[lead : img.shape[0] - trail, lead : img.shape[1] - trail]

In [None]:
def preds_to_tensor(preds, kernel_size, n, w, h):
    """Folds flat per-patch predictions back into a stack of 2-D prediction maps.

    Each of the n source images of size (w, h) loses kernel_size - 1 pixels along
    both axes during patch extraction, so the result has shape
    (n, w - (kernel_size - 1), h - (kernel_size - 1)).
    """
    shrink = kernel_size - 1
    target_shape = (n, w - shrink, h - shrink)
    return np.reshape(preds, target_shape)

In [None]:
#image_to_features(imgs[0], 3).shape

In [None]:
#features = np.vstack([image_to_features(img, 3) for img in imgs])
features = np.vstack([image_to_features(img, KERNEL_SIDE) for img in imgs[:1]])

In [None]:
#labels = [crop_groundtruth(gt, 3) for gt in gt_imgs]
labels = [crop_groundtruth(gt, KERNEL_SIDE) for gt in gt_imgs[:1]]

In [None]:
labels = np.ravel(labels)

In [None]:
features.shape, labels.shape

In [None]:
X = torch.from_numpy(features)
Y = torch.from_numpy(labels)

# Train

In [None]:
linear = nn.Linear(KERNEL_SIDE**2 * 3, 12)
final_layer = nn.Linear(12, 1)
model = nn.Sequential(linear, nn.Sigmoid(), final_layer)

In [None]:
lr = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr=lr)#, momentum=0.9)
costf = torch.nn.MSELoss()

train(X, Y, model, costf, optimizer, 50)

# Test on first image

In [None]:
img = imgs[0]
test_x = torch.from_numpy(image_to_features(img, KERNEL_SIDE))

In [None]:
preds = model(Variable(test_x))

In [None]:
preds = preds.data.numpy()

In [None]:
np.max(preds)

In [None]:
t = preds_to_tensor(preds, KERNEL_SIDE, 1, 400, 400)

In [None]:
t = t[0]

In [None]:
t.shape

In [None]:
plt.plot(np.sort(t.ravel()))

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(t, cmap='Greys_r')

In [None]:
new_img = utils.make_img_overlay(img[1:-1, 1:-1], t)
plt.figure(figsize=(10,10))
plt.imshow(new_img)

# Baseline: one-to-one test (greyscale pixel value → label)

In [None]:
img = imgs[0]
gt = gt_imgs[0]

In [None]:
greyscale = np.mean(img, axis=2)

In [None]:
x = torch.from_numpy(np.reshape(np.ravel(greyscale), (400*400, 1)))

In [None]:
y = torch.from_numpy(np.ravel(gt))

In [None]:
m = nn.Linear(1, 1, bias=False)
opt = torch.optim.SGD(m.parameters(), lr = 0.01)
c = nn.MSELoss()

In [None]:
x.shape

In [None]:
train(x, y, m, c, opt, 10**4)

In [None]:
preds = np.reshape(m(Variable(x)).data.numpy(), (400, 400))

In [None]:
plt.plot(np.sort(np.ravel(preds)))

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(preds, cmap='Greys_r')

# Legacy experiments (whole-image input; not maintained)

In [None]:
# Build train tensor
# NOTE(review): legacy section. Running it overwrites X, Y, img, test_x and preds
# from the sections above.

#X = torch.from_numpy(np.stack(imgs)) # tensor (N, W, H, C) — full dataset
# Stacks only the first 10 whole images into a single tensor
X = torch.from_numpy(np.stack(imgs[:10]))

In [None]:
lr = 0.2
# NOTE(review): this optimizer tracks model.parameters() (the patch model defined under
# "Train"), but train() below is called with funky.model, whose weights are not in it.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
costf = torch.nn.MSELoss()

In [None]:
# Build groundtruth tensor

#Y = torch.from_numpy(np.stack(gt_imgs))
Y = torch.from_numpy(np.stack(gt_imgs[:10]))

In [None]:
# NOTE(review): `funky` is not defined or imported anywhere in this notebook, so this
# cell raises NameError on a fresh kernel.
train(X, Y, funky.model, costf, optimizer, 50)

In [None]:
# Choose image to analyze
img = imgs[0]

In [None]:
# NOTE(review): test_x here is still the per-patch feature tensor from "Test on first
# image"; it is only redefined in the next cell.
test_x.shape, X.shape

In [None]:
# Predict and give correct shape
# NOTE(review): `model` is the patch model whose first Linear layer expects per-patch
# feature rows, so a (1, W, H, C) image tensor presumably fails unless `model` was rebound
# (e.g. to funky.model) — confirm.
test_x = Variable(torch.from_numpy(img).unsqueeze(0))
preds = model(test_x).squeeze(0).data.numpy()

In [None]:
plt.plot(np.sort(np.ravel(preds)))

In [None]:
preds.max()

In [None]:
# Draw: the image next to the prediction map
cimg = utils.concatenate_images(img, preds)
fig1 = plt.figure(figsize=(10, 10))
plt.imshow(cimg, cmap='Greys_r')

In [None]:
# Prediction map overlaid on the uncropped image
new_img = utils.make_img_overlay(img, preds)
plt.figure(figsize=(10,10))
plt.imshow(new_img)