In [None]:
# https://drive.google.com/file/d/1L5CEcaxIGiMYi-Zfugh_V1yqdeF_aUVM/view?usp=sharing
!conda install -y gdown
!gdown --id "1L5CEcaxIGiMYi-Zfugh_V1yqdeF_aUVM"
!unzip -q Minerals.zip
!rm Minerals.zip

In [1]:
import os;
import torch;
import torch.fft;
import torchvision;
import numpy as np;
from PIL import Image;
from typing import Tuple;
from random import randint;
from torch import nn, optim;
from torch.utils import data;
import torch.nn.functional as F;
from torch.backends import cudnn;
from torch.utils.data import Dataset;
from torch.nn.parameter import Parameter;
import torchvision.transforms.functional as FT;

In [2]:
# Let cuDNN benchmark and pick convolution algorithms for the shapes it sees.
cudnn.benchmark = True

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters read by main().
batch_size = 64
learning_rate = 1e-3
epochs = 4000

In [3]:
class DMlp(nn.Module):
    """Convolutional MLP: grouped 3x3 expansion, 1x1 mixing, GELU, 1x1 projection.

    Args:
        dim: Number of input (and output) channels.
        growth_rate: Multiplier giving the hidden width, ``int(dim * growth_rate)``.
            The hidden width must be divisible by ``dim`` because the first
            convolution uses ``groups = dim``.
    """

    def __init__(self, dim: int, growth_rate: float = 2.0) -> None:
        super().__init__()
        widened = int(dim * growth_rate)
        # Layers are created in the same order as they are applied so that
        # parameter initialisation consumes the RNG in a fixed sequence.
        grouped_expand = nn.Conv2d(dim, widened, kernel_size=3, stride=1, padding=1, groups=dim)
        pointwise_mix = nn.Conv2d(widened, widened, kernel_size=1, stride=1, padding=0)
        self.conv_0 = nn.Sequential(grouped_expand, pointwise_mix)
        self.act = nn.GELU()
        self.conv_1 = nn.Conv2d(widened, dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> GELU -> project; spatial size and channel count are kept."""
        return self.conv_1(self.act(self.conv_0(x)))

class PCFN(nn.Module):
    """Feed-forward block that applies a 3x3 convolution to only part of its hidden channels.

    Args:
        dim: Number of input (and output) channels.
        growth_rate: Hidden width multiplier, ``hidden_dim = int(dim * growth_rate)``.
        p_rate: Fraction of hidden channels routed through the 3x3 convolution,
            ``p_dim = int(hidden_dim * p_rate)``.
    """

    def __init__(self, dim: int, growth_rate: float = 2.0, p_rate: float = 0.25) -> None:
        super().__init__()
        self.hidden_dim = int(dim * growth_rate)
        self.p_dim = int(self.hidden_dim * p_rate)
        self.conv_0 = nn.Conv2d(dim, self.hidden_dim, 1, 1, 0)
        self.conv_1 = nn.Conv2d(self.p_dim, self.p_dim, 3, 1, 1)
        self.act = nn.GELU()
        self.conv_2 = nn.Conv2d(self.hidden_dim, dim, 1, 1, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, refine the first ``p_dim`` channels with a 3x3 conv, project back."""
        hidden = self.act(self.conv_0(x))

        if not self.training:
            # Evaluation path: overwrite the leading channels in place instead of
            # splitting and concatenating.
            leading = hidden[:, :self.p_dim, :, :]
            hidden[:, :self.p_dim, :, :] = self.act(self.conv_1(leading))
            return self.conv_2(hidden)

        # Training path: out-of-place split / concat.
        rest_dim = self.hidden_dim - self.p_dim
        leading, rest = torch.split(hidden, [self.p_dim, rest_dim], dim=1)
        leading = self.act(self.conv_1(leading))
        return self.conv_2(torch.cat([leading, rest], dim=1))

class SMFA(nn.Module):
    """Attention-style block combining a pooled, variance-modulated branch with a DMlp branch.

    The input is projected to ``2 * dim`` channels and split in two halves:
    one half is gated by a map computed from its down-sampled version and its
    per-channel spatial variance; the other half goes through ``DMlp``. The two
    results are summed and projected by a final 1x1 convolution.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim: int = 36) -> None:
        super().__init__()
        self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
        self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
        self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)
        self.lde = DMlp(dim, 2)
        self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
        self.gelu = nn.GELU()
        # Spatial reduction factor used for the pooled branch.
        self.down_scale = 8
        # Per-channel weights for the pooled features (alpha) and the variance term (beta).
        self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.beta = nn.Parameter(torch.zeros((1, dim, 1, 1)))

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Apply the block to a 4-D feature map and return a tensor of the same shape.

        Args:
            f: Feature map with ``dim`` channels on axis 1 and spatial axes last.

        Returns:
            Tensor with the same shape as ``f``.
        """
        _, _, h, w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)

        # Clamp the pooled size to at least 1x1: for feature maps smaller than
        # ``down_scale`` the integer division yields 0, which makes the
        # pooling / convolution / interpolation chain fail.
        pooled_size = (max(1, h // self.down_scale), max(1, w // self.down_scale))
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, pooled_size))

        # Per-channel variance over the spatial axes, broadcast against x_s.
        x_v = torch.var(x, dim=(-2, -1), keepdim=True)

        # Low-resolution modulation map, up-sampled back to the input size.
        modulation = self.gelu(self.linear_1(x_s * self.alpha + x_v * self.beta))
        x_l = x * F.interpolate(modulation, size=(h, w), mode='nearest')

        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)

class FMB(nn.Module):
    """Residual block: SMFA followed by PCFN, each applied to a normalised input.

    Args:
        dim: Channel count passed to both sub-modules.
        ffn_scale: Hidden-width multiplier forwarded to ``PCFN``.
    """

    def __init__(self, dim: int, ffn_scale: float = 2.0) -> None:
        super().__init__()
        self.smfa = SMFA(dim)
        self.pcfn = PCFN(dim, ffn_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual stages; ``F.normalize`` uses its defaults (L2 over dim 1)."""
        attention_out = self.smfa(F.normalize(x))
        x = attention_out + x
        feed_forward_out = self.pcfn(F.normalize(x))
        return feed_forward_out + x
 
class SMFANet(nn.Module):
    """Super-resolution network: conv stem, a stack of FMB blocks, pixel-shuffle head.

    The network output is added to a bilinearly up-sampled copy of the input,
    so the learned part only has to model the residual.

    Args:
        dim: Number of feature channels.
        n_blocks: Number of ``FMB`` blocks in the body.
        ffn_scale: Hidden-width multiplier forwarded to each ``FMB``.
        upscaling_factor: Integer factor by which height and width are enlarged.
    """

    def __init__(self, dim: int = 48, n_blocks: int = 8, ffn_scale: float = 2.0, upscaling_factor: int = 4) -> None:
        super().__init__()
        self.scale = upscaling_factor
        self.to_feat = nn.Conv2d(3, dim, 3, 1, 1)
        self.feats = nn.Sequential(*[FMB(dim, ffn_scale) for _ in range(n_blocks)])
        self.to_img = nn.Sequential(
            nn.Conv2d(dim, 3 * upscaling_factor ** 2, 3, 1, 1),
            nn.PixelShuffle(upscaling_factor)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale a batch of 3-channel images by ``self.scale``.

        Args:
            x: 4-D tensor with 3 channels on axis 1.

        Returns:
            Tensor with 3 channels whose height and width are ``self.scale``
            times those of ``x``.
        """
        # The skip connection must use the configured factor; a hard-coded 4
        # mismatches the pixel-shuffle output for any other upscaling_factor.
        upscaled_size = (x.size(2) * self.scale, x.size(3) * self.scale)
        res = F.interpolate(x, size=upscaled_size, mode="bilinear")
        x = self.to_feat(x)
        x = self.feats(x) + x
        x = self.to_img(x)
        return x + res

In [4]:
def transform(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the same random flips and rotation to an input/target image pair.

    Each flip (vertical, horizontal) is applied independently with probability
    1/2, then both images are rotated by one shared integer angle drawn
    uniformly from [-30, 30] degrees using bilinear interpolation.

    Args:
        input: Image tensor to augment.
        target: Image tensor that receives exactly the same operations.

    Returns:
        The augmented ``(input, target)`` pair.
    """
    # randint is inclusive on both ends, so randint(0, 1) is a fair coin.
    # (randint(0, 10) > 5 would be true for only 5 of 11 outcomes.)
    if randint(0, 1) == 1:
        input = FT.vflip(input)
        target = FT.vflip(target)
    if randint(0, 1) == 1:
        input = FT.hflip(input)
        target = FT.hflip(target)

    angle = randint(-30, 30)
    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    input = FT.rotate(input, angle, interpolation=bilinear)
    target = FT.rotate(target, angle, interpolation=bilinear)
    return (input, target)

class MineralDataset(Dataset):
    """In-memory dataset of paired 64-px / 256-px three-channel mineral images.

    For every sample number ``n`` (1-based, zero-padded to three digits) the
    files ``<root>/<res>/<band>/<n>_<band>.jpg`` are read for each band, converted
    to single-channel greyscale and stacked along the channel axis.

    Args:
        root_dir: Directory containing the ``64`` and ``256`` sub-directories.
        transform: Callable taking ``(img_64, img_256)`` and returning a pair;
            applied on every ``__getitem__`` call.
        num_samples: Number of consecutively numbered samples to load.
    """

    # Band sub-directory / file-suffix names, in channel order.
    # NOTE(review): presumably acquisition wavelengths; confirm with the data owner.
    BANDS = ("800", "1050", "1550")

    def __init__(self, root_dir: str, transform, num_samples: int = 334) -> None:
        self.root_dir = root_dir
        self.transform = transform
        self.img_64 = []
        self.img_256 = []
        for number in range(1, num_samples + 1):
            self.img_64.append(self._load_stack("64", number))
            self.img_256.append(self._load_stack("256", number))

    def _load_stack(self, resolution: str, number: int) -> torch.Tensor:
        """Read all bands of one sample at one resolution and stack them on dim 0."""
        channels = []
        for band in self.BANDS:
            path = os.path.join(self.root_dir, resolution, band, f"{number:03}_{band}.jpg")
            # Context manager closes the underlying file as soon as it is decoded.
            with Image.open(path) as img:
                channels.append(FT.to_tensor(img.convert('L')))
        return torch.cat(channels, dim=0)

    def __len__(self) -> int:
        # Report what was actually loaded instead of a hard-coded count.
        return len(self.img_64)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the transformed ``(img_64, img_256)`` pair at ``idx``."""
        return self.transform(self.img_64[idx], self.img_256[idx])

In [5]:
class FFTLoss(nn.Module):
    """Weighted L1 distance between the 2-D real FFTs of two tensors.

    Args:
        loss_weight: Scalar multiplied onto the L1 distance.
        reduction: Reduction mode forwarded to ``torch.nn.L1Loss``.
    """

    def __init__(self, loss_weight: float = 1.0, reduction: str = 'mean') -> None:
        super().__init__()
        self.loss_weight = loss_weight
        self.criterion = torch.nn.L1Loss(reduction=reduction)

    @staticmethod
    def _real_imag_spectrum(image: torch.Tensor) -> torch.Tensor:
        """Return ``rfft2(image)`` with real and imaginary parts stacked on a new last axis."""
        spectrum = torch.fft.rfft2(image)
        return torch.stack([spectrum.real, spectrum.imag], dim=-1)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute ``loss_weight * L1(rfft2(pred), rfft2(target))`` over real/imag parts."""
        distance = self.criterion(
            self._real_imag_spectrum(pred),
            self._real_imag_spectrum(target),
        )
        return self.loss_weight * distance

In [None]:
def main():
    """Train SMFANet on the Minerals dataset with an L1 pixel loss plus an FFT loss.

    Every 100 epochs the running mean of both losses is printed, a checkpoint
    ``smfanet_<epoch>.pth`` is written and the running statistics are reset.
    Every 400 epochs the learning rate of all parameter groups is halved.
    """
    loader = torch.utils.data.DataLoader(
        MineralDataset('Minerals', transform),
        batch_size=batch_size,
        shuffle=True,
    )
    net = SMFANet().train().to(device)
    opt_net = optim.Adam(net.parameters(), lr=learning_rate)
    fft = FFTLoss(0.1)

    # Running sums since the last report, and the number of steps they cover.
    pix_sum, fft_sum, steps = 0.0, 0.0, 0

    for epoch in range(1, epochs + 1):
        for low_res, high_res in loader:
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            sr = net(low_res)

            pix_loss = F.l1_loss(sr, high_res)
            fft_loss = fft(sr, high_res)
            pix_sum += pix_loss.item()
            fft_sum += fft_loss.item()

            opt_net.zero_grad()
            (pix_loss + fft_loss).backward()
            opt_net.step()
            steps += 1

        if epoch % 100 == 0:
            print("epoch: {:d}, pix: {:f}, fft: {:f}".format(epoch, pix_sum / steps, fft_sum / steps))
            torch.save(net.state_dict(), 'smfanet_{:d}.pth'.format(epoch))
            pix_sum, fft_sum, steps = 0.0, 0.0, 0

        if epoch % 400 == 0:
            # Step decay: halve the learning rate in place.
            for group in opt_net.param_groups:
                group['lr'] /= 2.0

main()

epoch: 100, pix: 0.067405, fft: 0.400208
epoch: 200, pix: 0.058782, fft: 0.346130
epoch: 300, pix: 0.055572, fft: 0.323497
epoch: 400, pix: 0.054174, fft: 0.311843
epoch: 500, pix: 0.052730, fft: 0.302955
epoch: 600, pix: 0.052397, fft: 0.298569
epoch: 700, pix: 0.051848, fft: 0.293731
epoch: 800, pix: 0.051654, fft: 0.290456
epoch: 900, pix: 0.051052, fft: 0.286622
epoch: 1000, pix: 0.050863, fft: 0.285137
epoch: 1100, pix: 0.050707, fft: 0.283876
epoch: 1200, pix: 0.050620, fft: 0.282791
epoch: 1300, pix: 0.050253, fft: 0.281285
epoch: 1400, pix: 0.050178, fft: 0.280643
epoch: 1500, pix: 0.050009, fft: 0.279866
epoch: 1600, pix: 0.050010, fft: 0.279591
epoch: 1700, pix: 0.049772, fft: 0.278819
epoch: 1800, pix: 0.049777, fft: 0.278622
epoch: 1900, pix: 0.049754, fft: 0.278467
epoch: 2000, pix: 0.049688, fft: 0.278051
epoch: 2100, pix: 0.049552, fft: 0.277659
epoch: 2200, pix: 0.049547, fft: 0.277469
epoch: 2300, pix: 0.049539, fft: 0.277497
epoch: 2400, pix: 0.049524, fft: 0.277293
e