In [1]:
import torch.nn as nn
import torch
import pytorch_lightning as pl
from utils import  get_sobel_kernel, get_gaussian_kernel, load_letter_conv_weights, get_rel_area_letters
from models import CustomTransposedConv2d

class LetterFilter(pl.LightningModule):
    """Greedily approximate a single-channel image with font letters.

    ``forward`` repeatedly convolves the residual image with one template per
    letter and picks, per image, the best-scoring (letter, grid cell). It adds
    that match strength to ``letter_hits`` and subtracts the transposed-conv
    reconstruction of all hits so far from the input. The score is the raw
    match weighted by
    ``|letter_size_weight * letter_area - detail_weight * detail|``, where
    ``detail`` is a blurred, normalised Sobel edge map on the letter grid.

    Config keys read: ``letter_conv_k``, ``letter_conv_stride``, ``sobel_k``,
    ``gauss_sig``, ``font_path``, ``letters``, ``letters_per_pix``, ``eps``,
    ``letter_size_weight``, ``detail_weight``. A ``device`` key may be
    present but is not used. ``LightningModule.device`` is a read-only
    property, so placement is done with ``.to()`` / the Trainer, and tensors
    created in ``forward`` follow the input's device.
    """

    def __init__(self, config):
        super().__init__()
        self.letter_conv_k = config["letter_conv_k"]
        self.letter_conv_stride = config["letter_conv_stride"]
        sobel_k = config["sobel_k"]
        gauss_sig = config["gauss_sig"]
        font_path = config["font_path"]
        letters = config["letters"]
        self.num_letters = len(letters)
        self.letters_per_pix = config["letters_per_pix"]
        self.eps = config["eps"]

        self.letter_size_alpha = config["letter_size_weight"]
        self.detail_beta = config["detail_weight"]

        # The Gaussian shares the letter kernel size and stride so that the
        # detail map lands on exactly the same grid as the letter matches.
        gauss_k = self.letter_conv_k

        # Sobel edge filters: x kernel, and its transpose for y. The padding
        # keeps the output at input resolution.
        sobel_2D = get_sobel_kernel(sobel_k)
        self.sobel_filter_x = nn.Conv2d(in_channels=1,
                                        out_channels=1,
                                        kernel_size=sobel_k,
                                        padding=sobel_k // 2,
                                        bias=False)
        self.sobel_filter_y = nn.Conv2d(in_channels=1,
                                        out_channels=1,
                                        kernel_size=sobel_k,
                                        padding=sobel_k // 2,
                                        bias=False)

        # Gaussian blur applied to the gradient magnitude, strided onto the
        # letter grid.
        gaussian_2D = get_gaussian_kernel(gauss_k, 0, gauss_sig)
        self.gaussian_filter = nn.Conv2d(in_channels=1,
                                         out_channels=1,
                                         kernel_size=gauss_k,
                                         stride=self.letter_conv_stride,
                                         padding=gauss_k // 2,
                                         bias=False)

        # Letter templates: one output channel per letter, matching the
        # (B, num_letters, H', W') layout used in forward. There is no bias,
        # because a randomly initialised bias would shift every match score.
        letter_conv_weights = load_letter_conv_weights(font_path, self.letter_conv_k, letters)
        self.letter_filter = nn.Conv2d(in_channels=1,
                                       out_channels=self.num_letters,
                                       kernel_size=self.letter_conv_k,
                                       stride=self.letter_conv_stride,
                                       padding=self.letter_conv_k // 2,
                                       bias=False)

        # All kernels are fixed, precomputed values. Writing them into
        # Parameters that require grad must happen under no_grad; otherwise
        # autograd rejects the in-place op on a leaf tensor. copy_ also
        # broadcasts a (k, k) kernel into the (1, 1, k, k) weight and
        # converts dtype.
        with torch.no_grad():
            self.sobel_filter_x.weight.copy_(torch.as_tensor(sobel_2D))
            self.sobel_filter_y.weight.copy_(torch.as_tensor(sobel_2D.T))
            self.gaussian_filter.weight.copy_(torch.as_tensor(gaussian_2D))
            self.letter_filter.weight.copy_(torch.as_tensor(letter_conv_weights))

        # Relative letter areas. They are stored as a non-persistent buffer so
        # they follow the module across devices without changing state_dict.
        self.register_buffer(
            "letter_conv_areas",
            torch.as_tensor(get_rel_area_letters(font_path, self.letter_conv_k, letters)),
            persistent=False,
        )

        # Transposed conv that renders letter hits back into image space.
        # It uses the same letter templates, so they are loaded only once.
        transposed_padding = self.letter_conv_k // 2
        transpose_out_padding = self.letter_conv_stride - 1
        self.transp_conv = CustomTransposedConv2d(letter_conv_weights,
                                                  self.num_letters,
                                                  1,
                                                  self.letter_conv_k,
                                                  self.letter_conv_stride,
                                                  transposed_padding,
                                                  transpose_out_padding)

    def _detail_map(self, img: torch.Tensor) -> torch.Tensor:
        """Return the blurred, max-normalised Sobel gradient magnitude on the letter grid."""
        grad_x = self.sobel_filter_x(img)
        grad_y = self.sobel_filter_y(img)
        grad_magnitude = (grad_x ** 2 + grad_y ** 2) ** 0.5
        # Normalise to [0, 1]. Note: .max() is taken over the whole batch,
        # not per image.
        grad_max = grad_magnitude.max()
        if grad_max > 0:
            grad_magnitude = grad_magnitude / grad_max

        detail = self.gaussian_filter(grad_magnitude)
        detail_max = detail.max()
        if detail_max > 0:
            detail = detail / detail_max
        return detail

    def forward(self, input_img: torch.Tensor):
        """Decompose ``input_img`` into letter hits.

        Args:
            input_img: Image batch. The convolutions take a single input
                channel, so this is expected as ``(B, 1, H, W)``.

        Returns:
            Tuple ``(letter_hits, reconstruction)``.
            ``letter_hits`` has one channel per letter on the letter grid and
            holds the summed match strength of every accepted hit.
            ``reconstruction`` is ``transp_conv(letter_hits)`` clipped to
            ``[0, 1]``.
        """
        B = input_img.shape[0]
        detail_map = self._detail_map(input_img)

        # Reading the grid size off the conv output guarantees it agrees with
        # the padding/stride arithmetic of letter_filter (and gaussian_filter,
        # which uses the same kernel size, stride and padding).
        letter_match = self.letter_filter(input_img)
        _, _, H_hits, W_hits = letter_match.shape
        max_filled_cells = B * H_hits * W_hits

        # Size/detail weighting, which does not change inside the loop.
        # TODO: keep tuning this weighting.
        letter_areas = self.letter_conv_areas.view(1, self.num_letters, 1, 1)
        score_weight = torch.abs(self.letter_size_alpha * letter_areas
                                 - self.detail_beta * detail_map)

        letter_hits = torch.zeros_like(letter_match)
        # Number of grid cells that already hold letters_per_pix distinct
        # letters. It starts at 0 so the loop can run at all.
        filled_cells = 0

        while torch.any(letter_match > self.eps) and filled_cells < max_filled_cells:
            # Only cells whose raw match exceeds eps are eligible. Without
            # this, images with nothing left to match would keep receiving
            # sub-eps hits while other images in the batch continue.
            candidates = letter_match > self.eps
            scores = (letter_match * score_weight).masked_fill(~candidates, float("-inf"))
            best = scores.reshape(B, -1).argmax(dim=1, keepdim=True)
            # Images without any candidate get a zero entry, i.e. no hit.
            has_hit = candidates.reshape(B, -1).gather(1, best).to(letter_match.dtype)
            mask = (torch.zeros_like(letter_match)
                    .reshape(B, -1)
                    .scatter(1, best, has_hit)
                    .view_as(letter_match))

            # Add the best letter hit of this iteration for every image.
            letter_hits = letter_hits + mask * letter_match

            # True where a cell can still accept another distinct letter.
            open_cells = (letter_hits > 0).sum(dim=1, keepdim=True) < self.letters_per_pix

            # Subtract everything explained so far from the input.
            # Scalar bounds work for any input device; CPU tensor bounds
            # would not.
            # NOTE(review): this needs transp_conv to return an H x W map;
            # with output_padding = stride - 1 that presumably requires H and
            # W to be divisible by the stride — confirm.
            residual = (input_img - self.transp_conv(letter_hits)).clamp(0.0, 1.0)

            # Re-match on the residual, ignoring cells that are already full.
            letter_match = self.letter_filter(residual) * open_cells
            filled_cells = int((~open_cells).sum())

        return letter_hits, self.transp_conv(letter_hits).clamp(0.0, 1.0)



* 'schema_extra' has been renamed to 'json_schema_extra'
