In [None]:
!test -f lena.jpg || wget https://raw.githubusercontent.com/opencv/opencv/master/samples/data/lena.jpg
from google.colab import files
import os
print('lena-sobel.txt')
while not os.path.exists('lena-sobel.txt'):
    files.upload()

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# load samples (single image provides many samples for edge detection)
sample_inp = cv2.imread('lena.jpg', cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
if sample_inp is None:
    raise FileNotFoundError("could not read 'lena.jpg'")
sample_out = np.loadtxt('lena-sobel.txt', dtype=np.float32)
rows, cols = sample_inp.shape
# The model's 3x3 convolution uses padding=1, so its output has the input's
# spatial size; the target must therefore match the input pixel-for-pixel.
if sample_out.shape != (rows, cols):
    raise ValueError(
        f'lena-sobel.txt has shape {sample_out.shape}, '
        f'expected {(rows, cols)} to match lena.jpg'
    )
# Add batch and channel axes: (H, W) -> (1, 1, H, W) for nn.Conv2d.
samples_inp = torch.from_numpy(sample_inp).float().unsqueeze(0).unsqueeze(0)
samples_out = torch.from_numpy(sample_out).float().unsqueeze(0).unsqueeze(0)
print(samples_inp.shape)
print(samples_out.shape)

In [None]:
# Model: a single learnable 3x3 convolution.
class SobelFilter(nn.Module):
    """Learnable 3x3 convolution intended to be fitted to a Sobel edge response.

    Maps a single-channel image tensor to a single-channel response with the
    same spatial size (zero padding of 1 on each side, no bias term).
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 1 output channel; the 3x3 kernel weights are the
        # module's only trainable parameters.
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3,
                              padding=1, bias=False)

    def forward(self, x):
        """Convolve ``x`` (single-channel, e.g. shape (N, 1, H, W)) with the kernel."""
        response = self.conv(x)
        return response

In [None]:
# Seed torch's RNG so the random initial kernel (and hence the training run) is reproducible.
torch.manual_seed(0)
model = SobelFilter()

In [None]:
# Loss: squared error summed (not averaged) over every element of the response.
loss_function = torch.nn.MSELoss(reduction='sum')

In [None]:
# Optimizer: Adam over the model's parameters at its default learning rate (1e-3).
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Training: number of full-batch optimisation steps over the single training image.
num_epochs = 12_500

In [None]:
for epoch in range(num_epochs):
    # Full-batch forward pass over the single training image.
    prediction = model(samples_inp)
    loss = loss_function(prediction, samples_out)
    # Report progress every 100 steps (the loss before this step's update).
    if not epoch % 100:
        print(epoch, loss.item())
    # Clear the previous step's gradients, back-propagate, then update the kernel.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Print the learned weights (the conv kernel is the model's only parameter tensor).
for weight in model.parameters():
    print(weight.data)

tensor([[[[-1.0000e+00,  1.1135e-07,  1.0000e+00],
          [-2.0000e+00, -3.0023e-07,  2.0000e+00],
          [-1.0000e+00,  1.0843e-07,  1.0000e+00]]]])    

In [None]:
# Inference only: disable autograd so no computation graph is kept for this
# full-image forward pass.
with torch.no_grad():
    output_images = model(samples_inp)

In [None]:
# Drop the batch and channel axes, (1, 1, H, W) -> (H, W), and move to NumPy.
output_image = output_images[0].squeeze(0).detach().numpy()
print(output_image.shape)
print(output_image.max())
print(output_image.min())

In [None]:
# Map the signed filter response to a displayable 8-bit image: take the
# magnitude, then stretch it linearly onto [0, 255].
output_image = np.abs(output_image)
output_image = cv2.normalize(output_image, None, 0, 255, cv2.NORM_MINMAX)
output_image = np.uint8(output_image)
# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite('lena-sobel2.jpg', output_image):
    raise IOError("failed to write 'lena-sobel2.jpg'")

In [None]:
# Show the rescaled edge magnitude; pin the grey range to the full 8-bit scale.
fig, ax = plt.subplots()
ax.imshow(output_image, cmap='gray', vmin=0, vmax=255)
ax.set_title('Learned Sobel filter response (|output|, min-max scaled)')
ax.axis('off')
plt.show()

In [None]:
# Save only the learned parameters (the state_dict), not the whole module object.
torch.save(obj=model.state_dict(), f='sobel.pth')