In [7]:
import os
import io
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [8]:
# Expose only GPU 0 to this process. CUDA reads this variable when it is
# initialised, so it must be set before the first CUDA call in the kernel.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [9]:
class SRMConv2d_simple(nn.Module):
    """Bank of three fixed (optionally learnable) 5x5 SRM high-pass filters.

    The filters are KB, KV and a horizontal second-order filter. Each one is
    replicated across all ``inc`` input channels, so output channel k is the
    sum over input channels of filter k's response. Responses are then
    clipped to [-3, 3].

    Args:
        inc: number of input channels (must be >= 1).
        learnable: if True the kernel is registered with
            ``requires_grad=True`` and can be trained; otherwise it is frozen.
    """

    def __init__(self, inc=3, learnable=False):
        super().__init__()
        # Clip the filter responses to [-3, 3].
        self.truc = nn.Hardtanh(-3, 3)
        kernel = self._build_kernel(inc)  # (3, inc, 5, 5)
        self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)

    def forward(self, x):
        """Apply the SRM filter bank to ``x``.

        Args:
            x: float32 image batch of shape (batch, inc, H, W) -- channels
               first, as required by ``F.conv2d``.

        Returns:
            Tensor of shape (batch, 3, H, W), one channel per SRM filter,
            clipped to [-3, 3]. ``padding=2`` preserves H and W for the 5x5
            kernels.
        """
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        return self.truc(out)

    def _build_kernel(self, inc):
        """Return the SRM kernel bank as a float32 tensor of shape (3, inc, 5, 5).

        Raises:
            ValueError: if ``inc`` < 1.
        """
        if inc < 1:
            raise ValueError(f"inc must be >= 1, got {inc}")

        # filter1: KB
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # filter2: KV
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # filter3: horizontal second-order
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.

        # Stack to (3, 1, 5, 5), then replicate along the input-channel axis
        # so every filter covers all `inc` channels -> (3, inc, 5, 5).
        filters = np.stack([filter1, filter2, filter3])[:, np.newaxis]
        filters = np.repeat(filters, inc, axis=1)
        filters = torch.tensor(filters, dtype=torch.float32)

        # Safety clamp to [-2, 2]. With the normalisations above every
        # coefficient already lies in [-1, 1], so this is currently a no-op.
        filters.clamp_(-2, 2)
        return filters



In [4]:
def SRM_filter(path, out_dir='./data/ze_SRM'):
    """Write an SRM-residual visualisation of the image at ``path``.

    The RGB image is filtered with ``SRMConv2d_simple``. The 3-channel
    response is multiplied element-wise by the original image as loaded by
    OpenCV (BGR channel order) and inverted as ``255 - product``. The result
    is saturated to uint8 and saved in ``out_dir`` under the same file name.

    Args:
        path: path of the input image.
        out_dir: output directory, created if missing. Defaults to the
            directory this notebook has always written to.

    Returns:
        The path of the written file.

    Raises:
        FileNotFoundError: if OpenCV cannot read/decode ``path``.
        OSError: if OpenCV fails to write the result.
    """
    image_bgr = cv2.imread(path)
    if image_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    srm = SRMConv2d_simple(inc=3)
    # (H, W, 3) uint8 -> (1, 3, H, W) float32, the layout F.conv2d expects.
    batch = torch.from_numpy(image_rgb).to(torch.float32).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():
        response = srm(batch)
    # (1, 3, H, W) -> (H, W, 3); channel k holds the response of SRM filter k.
    response = response.squeeze(0).permute(1, 2, 0).numpy()

    # NOTE(review): SRM channel k is paired with BGR channel k of the original
    # image (filter1*B, filter2*G, filter3*R). Kept as-is to preserve output.
    result = 255 - response * image_bgr.astype(np.float32)

    # `result` spans roughly [-510, 1020]. Round and saturate to uint8
    # explicitly instead of relying on cv2.imwrite's implicit depth conversion.
    result = np.clip(np.rint(result), 0, 255).astype(np.uint8)

    os.makedirs(out_dir, exist_ok=True)
    # basename handles the OS separator; os.path.join produces '\\' on
    # Windows, which path.split('/') would not strip.
    out_path = os.path.join(out_dir, os.path.basename(path))
    if not cv2.imwrite(out_path, result):
        raise OSError(f"cv2.imwrite failed to write: {out_path}")
    return out_path

In [12]:
# Run SRM_filter over every image file in each input directory.
file_paths = ['./data/another']

# SRM_filter writes into this directory; create it once, up front.
os.makedirs('./data/ze_SRM', exist_ok=True)

# Extensions cv2.imread can decode. Other entries (sub-directories,
# .DS_Store, text files, ...) would make cv2 fail and abort the whole run.
IMAGE_EXTENSIONS = (
    '.bmp', '.dib', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp',
    '.pbm', '.pgm', '.ppm', '.pxm', '.pnm', '.sr', '.ras',
    '.tiff', '.tif', '.exr', '.hdr', '.pic',
)

for file_path in file_paths:
    # Sorted so the processing order is reproducible (os.listdir is unordered).
    for name in sorted(os.listdir(file_path)):
        file = os.path.join(file_path, name)
        if not os.path.isfile(file) or not name.lower().endswith(IMAGE_EXTENSIONS):
            continue
        print(file)
        SRM_filter(file)

./data/another/635b1831d3b3fae8ba84c325b96ae9b.jpg
./data/another/1648002355(1).jpg
