# Low-pass filtering for image super-resolution (ISR)
A simple low-pass filter can be used to perform image super-resolution (ISR). Here is an example.

In [None]:
!pip install torchmetrics
!pip install einops

In [1]:
import torch
from PIL import Image
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchvision.transforms import ToTensor

import FSDS_code
import utils

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The images are loaded through ToTensor, so the metrics use data_range=1.
psnr = PeakSignalNoiseRatio(data_range=1).to(device)
ssim = StructuralSimilarityIndexMeasure(data_range=1).to(device)


def load_image(path):
    """Read the image at `path` into a tensor with a leading batch dimension on `device`."""
    with Image.open(path) as image:  # release the file handle once the pixels are decoded
        return ToTensor()(image).unsqueeze(0).to(device)


lr = load_image("./example_figures/baby_x2.png")  # low-resolution input
gt = load_image("./example_figures/baby.png")  # ground-truth reference
sr = utils.lpf_sr(img=lr, scale=2, omega=48.5)
print(f"LPF ISR performance: psnr={psnr(sr, gt)}, ssim={ssim(sr, gt)}, fsds={FSDS_code.FrequencySpectrumDistributionSimilarity(sr, gt)[0]}")

LPF ISR performance: psnr=27.45184326171875, ssim=0.7809420824050903, fsds=27.509355337158876


# Implementation
## 1) Zero-interpolation
Before applying the low-pass filter, we first upsample the LR image to the target size by inserting zeros between its pixels (zero-interpolation):
```python
def zero_interpolate_torch(img: torch.Tensor, scale: int):
    """
    Upsample `img` by `scale` along H and W by inserting zeros between its pixels.

    Input pixel (h, w) is written to output position (h * scale, w * scale); every other
    output position is 0.

    :param img: NxCxHxW tensor, or an unbatched CxHxW tensor
    :param scale: integer upsampling factor
    :return: tensor whose H and W are multiplied by `scale`. A leading dimension of size 1 is
             always squeezed away, so both CxHxW and 1xCxHxW inputs return
             Cx(H*scale)x(W*scale); inputs with N > 1 keep their batch dimension.
    """
    if img.dim() != 4:  # unbatched input: add a temporary batch dimension
        img = img.unsqueeze(dim=0)
    n, c, h, w = img.shape
    # Treat every channel of every sample as its own single-channel image.
    planes = img.reshape(-1, 1, h, w)
    # Allocate the zero channels directly on the input's device instead of building them on
    # the CPU and copying them over on every call.
    zeros = torch.zeros(planes.shape[0], scale * scale - 1, h, w, device=img.device)
    stacked = torch.concat([planes, zeros], dim=1)
    # pixel_shuffle spreads the scale*scale channels over each scale x scale output cell;
    # channel 0 (the image) goes to the top-left position, the zero channels fill the rest.
    upsampled = torch.nn.functional.pixel_shuffle(stacked, scale)
    return upsampled.reshape(n, c, h * scale, w * scale).squeeze(dim=0)

```
## 2) Low-pass filter
Then we apply a low-pass filter to the interpolated image using convolution. The full implementation is:
```python
def lpf_sr_single(img: torch.Tensor, scale: int, omega=3.):
    """
    Interpolate an image using the sinc function, it's slower than the cubic or others.

    The image is reflect-padded, zero-interpolated by `scale`, and convolved with a separable
    2-D sinc kernel that acts as the low-pass filter.

    :param img: the image to be interpolated; indexed as a 4-D tensor whose dims 2 and 3 are
                H and W. The kernel has one input channel, so dim 1 is presumably 1 —
                TODO confirm against the caller.
    :param scale: the integer upsampling factor
    :param omega: the factor to adjust the scale of the sinc function
    :return: the interpolated image, with H and W multiplied by `scale`
    """
    half_h, half_w = img.shape[2] // 2, img.shape[3] // 2
    # F.pad lists the padding of the LAST dimension first: (left, right, top, bottom),
    # so the width is padded by W // 2 and the height by H // 2.
    img_pad = F.pad(input=img, pad=(half_w, half_w, half_h, half_h), mode="reflect")
    target = zero_interpolate_torch(img_pad, scale)  # zero interpolate to the target size
    h_grid = torch.linspace(-1, 1, half_h * scale * 2 + 1)
    w_grid = torch.linspace(-1, 1, half_w * scale * 2 + 1)
    # indexing='ij' keeps h_grid on the rows and w_grid on the columns, so the kernel is
    # (2 * half_h * scale + 1) x (2 * half_w * scale + 1) and the "valid" convolution below
    # removes exactly the padding, leaving an (H * scale) x (W * scale) result.
    grid_h, grid_w = torch.meshgrid([h_grid, w_grid], indexing='ij')

    # generate the low-pass filter: the product of two sinc functions with parameter omega
    kernel = sinc(grid_h, omega) * sinc(grid_w, omega)
    # conv2d needs the weight on the same device and with the same dtype as its input
    kernel = kernel.unsqueeze(dim=0).unsqueeze(dim=0).to(device=target.device, dtype=target.dtype)
    # low-pass filtering, since the sinc function is symmetric, we can directly utilize the torch.nn.functional.conv2d
    target = F.conv2d(input=target, weight=kernel, stride=1, padding="valid")
    for i in range(target.shape[0]):
        # Only inputs whose maximum exceeds 1 are rescaled to the value range of the input;
        # this also skips an all-zero image.
        if torch.max(img[i]) > 1:
            src_min, src_max = torch.min(img[i]), torch.max(img[i])
            out_min, out_max = torch.min(target[i]), torch.max(target[i])
            target[i] = (target[i] - out_min) / (out_max - out_min) * (src_max - src_min) + src_min
    return target
```
In the code above, the sinc function is defined as:
```python
def sinc(tensor, omega):
    """
    Evaluate the normalised sinc function, sin(pi*x)/(pi*x), at x = tensor * omega.

    :param tensor: positions at which the function is evaluated
    :param omega: scale factor applied to the positions before evaluation
    :return: the sinc values, one per element of `tensor`
    """
    # sinc is even, so the absolute value does not change the result; the small offset
    # keeps the x == 0 case from becoming 0 / 0.
    phase = torch.abs(math.pi * tensor * omega) + 1e-9
    return torch.sin(phase) / phase
```