In [None]:
# Google Drive share link of the archive fetched below (same file id as passed to gdown):
# https://drive.google.com/file/d/1sb6AS8JH6jxScJP9ZkIB4QTW5WIDHQTW/view?usp=sharing
# Install the gdown command-line downloader with conda; -y answers the confirmation prompt.
!conda install -y gdown
# Download the Drive file by id (expected to land in the working directory as SWIR.zip).
!gdown --id "1sb6AS8JH6jxScJP9ZkIB4QTW5WIDHQTW"
# Extract the archive without listing every file (-q), then delete the zip.
!unzip -q SWIR.zip
!rm SWIR.zip

In [1]:
import os;
import torch;
import torch.fft;
import torchvision;
import numpy as np;
from tqdm import tqdm;
from PIL import Image;
from typing import Tuple;
from random import randint;
from torch import nn, optim;
from torch.utils import data;
from einops import rearrange;
import torch.nn.functional as F;
from torch.backends import cudnn;
from torch.utils.data import Dataset;
from torch.nn.init import trunc_normal_;
from torch.nn.parameter import Parameter;
import torchvision.transforms.functional as FT;

In [2]:
# Let cuDNN time the available convolution algorithms and keep the fastest one.
cudnn.benchmark = True

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training hyper-parameters read by main().
batch_size = 64       # samples per optimisation step
learning_rate = 1e-3  # initial Adam learning rate
epochs = 8000         # number of full passes over the dataset

In [3]:
class WindowSelfAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    The feature map is zero-padded on the right/bottom up to a multiple of
    ``window_size``, cut into ``window_size x window_size`` tiles, and attention
    is computed between the pixels of each tile independently. A learnable bias
    ``beta`` of shape ``(num_heads, window_size**2, window_size**2)`` is added
    to the attention logits of every tile.

    Args:
        dim: Number of input and output channels. Must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Side length of a window in pixels.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim: int, num_heads: int = 8, window_size: int = 8) -> None:
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.window_size = window_size
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # 1x1 convolutions act as per-token linear projections on (B, C, N) tensors.
        self.query = nn.Conv1d(dim, dim, kernel_size=1, padding=0, bias=False)
        self.key = nn.Conv1d(dim, dim, kernel_size=1, padding=0, bias=False)
        self.value = nn.Conv1d(dim, dim, kernel_size=1, padding=0, bias=False)
        self.beta = nn.Parameter(torch.zeros(num_heads, window_size ** 2, window_size ** 2))
        self.proj_out = nn.Conv1d(dim, dim, kernel_size=1, padding=0, bias=True)
        trunc_normal_(self.beta, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply windowed self-attention.

        Args:
            x: Tensor of shape ``(batch, dim, height, width)``. Height and width
                may be arbitrary; they are padded internally.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, h, w = x.shape
        ws = self.window_size
        # Pad right/bottom so both spatial sizes are multiples of the window size.
        pad_r = (ws - w % ws) % ws
        pad_b = (ws - h % ws) % ws
        x = F.pad(x, (0, pad_r, 0, pad_b))
        n_h = (h + pad_b) // ws  # windows along the height
        n_w = (w + pad_r) // ws  # windows along the width

        # Window partition: (b, c, H, W) -> (b * n_h * n_w, c, ws * ws),
        # windows ordered row-major, pixels inside a window ordered row-major.
        x = x.reshape(b, c, n_h, ws, n_w, ws)
        x = x.permute(0, 2, 4, 1, 3, 5).reshape(b * n_h * n_w, c, ws * ws)

        # Project and split channels into heads: (B, c, n) -> (B, heads, d, n).
        q, k, v = (
            proj(x).reshape(x.shape[0], self.num_heads, -1, ws * ws)
            for proj in (self.query, self.key, self.value)
        )

        # Scaled dot-product attention with the learned per-head bias.
        attn = torch.einsum('bhdn,bhdm->bhnm', q, k) * self.scale + self.beta
        attn = attn.softmax(dim=-1)
        x = torch.einsum('bhnm,bhdm->bhdn', attn, v)
        x = x.reshape(x.shape[0], c, ws * ws)  # merge heads back into channels
        x = self.proj_out(x)

        # Reverse window partition. The batch size comes from the input shape;
        # re-deriving it from the unpadded h * w is wrong whenever padding was added.
        x = x.reshape(b, n_h, n_w, c, ws, ws)
        x = x.permute(0, 3, 1, 4, 2, 5).reshape(b, c, n_h * ws, n_w * ws)
        if pad_r > 0 or pad_b > 0:
            # Drop the rows/columns that only exist because of the padding.
            x = x[:, :, :h, :w].contiguous()
        return x

class PCFN(nn.Module):
    """Feed-forward block that applies a 3x3 convolution to only part of the hidden channels.

    Pipeline: 1x1 conv expands ``dim`` to ``hidden_dim`` channels, GELU, a 3x3
    reflect-padded conv + GELU on the first ``p_dim`` hidden channels (the rest
    pass through unchanged), then a 1x1 conv back to ``dim`` channels.

    Args:
        dim: Number of input and output channels.
        growth_rate: ``hidden_dim = int(dim * growth_rate)``.
        p_rate: ``p_dim = int(hidden_dim * p_rate)``, the number of hidden
            channels that go through the 3x3 convolution.
    """

    def __init__(self, dim: int, growth_rate: float = 2.0, p_rate: float = 0.25) -> None:
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        p_dim = int(hidden_dim * p_rate)
        self.conv_0 = nn.Conv2d(dim, hidden_dim, kernel_size=1, padding=0)
        self.conv_1 = nn.Conv2d(p_dim, p_dim, kernel_size=3, padding=1, padding_mode="reflect")
        self.act = nn.GELU()
        self.conv_2 = nn.Conv2d(hidden_dim, dim, kernel_size=1, padding=0)
        self.p_dim = p_dim
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on a ``(batch, dim, height, width)`` tensor and return the same shape."""
        x = self.act(self.conv_0(x))
        if self.training or torch.is_grad_enabled():
            # Autograd-safe path: build a new tensor instead of overwriting `x`.
            x1, x2 = torch.split(x, [self.p_dim, self.hidden_dim - self.p_dim], dim=1)
            x1 = self.act(self.conv_1(x1))
            return self.conv_2(torch.cat([x1, x2], dim=1))
        # Pure inference (eval mode, gradients disabled): overwrite the first
        # p_dim channels in place to avoid the extra split/cat copies. This is
        # only safe without autograd, because conv_1 would otherwise need the
        # overwritten slice for its backward pass.
        x[:, :self.p_dim, :, :] = self.act(self.conv_1(x[:, :self.p_dim, :, :]))
        return self.conv_2(x)

class Transformer(nn.Module):
    """One residual block: windowed self-attention followed by a PCFN feed-forward.

    Both sub-layers receive the input L2-normalised with ``F.normalize`` (which
    by default normalises along dim 1, the channel axis of a ``(B, C, H, W)``
    tensor), and their output is added back onto the un-normalised input.

    Args:
        dim: Channel count of the feature map.
        num_heads: Number of attention heads passed to ``WindowSelfAttention``.
        window_size: Attention window side length.
        mlp_ratio: Passed to ``PCFN`` as its ``growth_rate``.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 4,
                 window_size: int = 8,
                 mlp_ratio: int = 4) -> None:
        super().__init__()
        self.attn = WindowSelfAttention(dim, num_heads, window_size)
        self.pcfn = PCFN(dim, mlp_ratio)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and feed-forward residual steps."""
        attended = self.attn(F.normalize(x)) + x
        refined = self.pcfn(F.normalize(attended)) + attended
        return refined

class SRViT(nn.Module):
    """Super-resolution network: conv head, Transformer body, pixel-shuffle upsampler.

    The network predicts a residual that is added to an interpolated copy of the
    input (``F.interpolate`` with its default mode), so the output has
    ``upscaling_factor`` times the input's height and width.

    Args:
        n_feats: Channel count of the internal feature maps.
        n_heads: Attention heads per Transformer block.
        ratio: Passed to each Transformer block as ``mlp_ratio``.
        blocks: Number of Transformer blocks in the body.
        upscaling_factor: Spatial upscaling factor of the output.
    """

    def __init__(self, n_feats: int = 40, n_heads: int = 8, ratio: int = 2, blocks: int = 5, upscaling_factor: int = 4) -> None:
        super(SRViT, self).__init__()
        # Kept so forward() upsamples the residual by the same factor as PixelShuffle.
        self.upscaling_factor = upscaling_factor
        self.head = nn.Conv2d(3, n_feats, 3, 1, 1)
        self.body = nn.Sequential(*[Transformer(n_feats, n_heads, 8, ratio) for _ in range(blocks)])
        self.upsampling = nn.Sequential(
            nn.Conv2d(n_feats, 3 * upscaling_factor ** 2, kernel_size=3, padding=1, padding_mode="reflect"),
            nn.PixelShuffle(upscaling_factor)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale a ``(batch, 3, H, W)`` tensor to ``(batch, 3, H * f, W * f)``.

        ``f`` is the ``upscaling_factor`` given at construction time.
        """
        factor = self.upscaling_factor
        # Interpolated input used as the global residual. It must match the
        # PixelShuffle output size, so it uses the configured factor, not a literal 4.
        res = F.interpolate(x, size=(x.size(2) * factor, x.size(3) * factor))
        x = self.head(x)
        x = self.upsampling(self.body(x) + x)
        return x + res

In [4]:
def transform(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the same random flips and rotation to an input/target image pair.

    A vertical flip and then a horizontal flip are each applied when
    ``randint(0, 10) > 5``. ``randint`` is inclusive on both ends, so each flip
    fires with probability 5/11. Both images are then rotated by one shared
    integer angle drawn from [-30, 30] degrees using bilinear interpolation.

    Args:
        input: Image tensor to augment.
        target: Image tensor that receives exactly the same augmentation.

    Returns:
        The augmented ``(input, target)`` pair.
    """
    # Same order of random draws as before: vertical flip, horizontal flip, angle.
    for flip in (FT.vflip, FT.hflip):
        if randint(0, 10) > 5:
            input, target = flip(input), flip(target)

    angle = randint(-30, 30)
    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    rotated_input = FT.rotate(input, angle, interpolation=bilinear)
    rotated_target = FT.rotate(target, angle, interpolation=bilinear)
    return (rotated_input, rotated_target)

class SWIRDataset(Dataset):
    """In-memory dataset of paired 3-channel images read from ``64`` and ``256`` sub-folders.

    Expected layout: ``<root_dir>/<resolution>/<band>/<NNNN>.jpg`` with
    ``resolution`` in ``{"64", "256"}``, ``band`` in ``BANDS`` and ``NNNN`` a
    1-based, zero-padded 4-digit index. For every index the three band images
    are converted to single-channel ('L') tensors and concatenated along the
    channel axis. All images are loaded eagerly in the constructor.

    Args:
        root_dir: Directory containing the ``64`` and ``256`` folders.
        transform: Callable ``(img_64, img_256) -> (img_64, img_256)`` applied
            on every access, or ``None`` to return the stored tensors unchanged.
        num_samples: Number of image indices to load, starting at ``0001``.
            Defaults to 715, the value previously hard-coded.
    """

    # Band sub-folder names, in the order they are stacked as channels.
    # NOTE(review): presumably SWIR wavelengths in nm — confirm against the data source.
    BANDS = ("800", "1050", "1550")

    def __init__(self, root_dir: str, transform, num_samples: int = 715) -> None:
        self.root_dir = root_dir
        self.transform = transform
        self.img_64 = []
        self.img_256 = []
        for number in range(1, num_samples + 1):
            self.img_64.append(self._load_stack("64", number))
            self.img_256.append(self._load_stack("256", number))

    def _load_stack(self, resolution: str, number: int) -> torch.Tensor:
        """Load the images of all bands for one index and stack them as channels."""
        file_name = f"{number:04}.jpg"
        channels = []
        for band in self.BANDS:
            path = os.path.join(self.root_dir, resolution, band, file_name)
            # The context manager releases the file handle once the pixels are converted.
            with Image.open(path) as image:
                channels.append(FT.to_tensor(image.convert('L')))
        return torch.cat(channels, dim=0)

    def __len__(self) -> int:
        # Derived from what was actually loaded so it can never disagree with the data.
        return len(self.img_64)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (64-folder, 256-folder) tensor pair for ``idx``, transformed if configured."""
        img_64, img_256 = self.img_64[idx], self.img_256[idx]
        if self.transform is not None:
            img_64, img_256 = self.transform(img_64, img_256)
        return img_64, img_256

In [5]:
class FFTLoss(nn.Module):
    """Weighted L1 distance between the 2-D real FFTs of two tensors.

    Both inputs are transformed with ``torch.fft.rfft2`` (over the last two
    dimensions), the complex spectra are split into real and imaginary parts
    stacked on a new trailing axis, and an ``L1Loss`` is taken between them.

    Args:
        loss_weight: Scalar multiplied onto the L1 value.
        reduction: Reduction mode forwarded to ``torch.nn.L1Loss``.
    """

    def __init__(self, loss_weight: float = 1.0, reduction: str = 'mean') -> None:
        super().__init__()
        self.loss_weight = loss_weight
        self.criterion = torch.nn.L1Loss(reduction=reduction)

    @staticmethod
    def _real_spectrum(image: torch.Tensor) -> torch.Tensor:
        """Return the rfft2 of ``image`` as a real tensor with a trailing (real, imag) axis."""
        spectrum = torch.fft.rfft2(image)
        return torch.stack([spectrum.real, spectrum.imag], dim=-1)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return ``loss_weight * L1(rfft2(pred), rfft2(target))``."""
        distance = self.criterion(self._real_spectrum(pred), self._real_spectrum(target))
        return self.loss_weight * distance

In [6]:
def main():
    """Train an SRViT model on the SWIR dataset with a pixel L1 loss plus an FFT loss.

    Uses the module-level ``batch_size``, ``learning_rate``, ``epochs`` and
    ``device`` settings. Every 100 epochs the running mean of both losses is
    printed, the model weights are saved to ``srvit_<epoch>.pth`` and the
    running means are reset. Every 1500 epochs the learning rate is halved.
    """
    train_set = SWIRDataset('SWIR', transform)
    loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    model = SRViT().train().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    fft_criterion = FFTLoss(0.1)

    # Running sums since the last report, and the number of steps they cover.
    pix_sum = 0.0
    fft_sum = 0.0
    steps = 0

    for epoch_index in tqdm(range(epochs)):
        epoch = epoch_index + 1  # 1-based epoch number used for reporting and schedules
        for low_res, high_res in loader:
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            super_res = model(low_res)

            pix_loss = F.l1_loss(super_res, high_res)
            fft_loss = fft_criterion(super_res, high_res)
            pix_sum += pix_loss.item()
            fft_sum += fft_loss.item()

            optimizer.zero_grad()
            (pix_loss + fft_loss).backward()
            optimizer.step()
            steps += 1

        if epoch % 100 == 0:
            # Report, checkpoint, and start a fresh averaging window.
            print("epoch: {:d}, pix: {:f}, fft: {:f}".format(epoch, pix_sum / steps, fft_sum / steps))
            torch.save(model.state_dict(), 'srvit_{:d}.pth'.format(epoch))
            pix_sum = 0.0
            fft_sum = 0.0
            steps = 0

        if epoch % 1500 == 0:
            # Step decay: halve the learning rate of every parameter group.
            for group in optimizer.param_groups:
                group['lr'] /= 2.0

# Run the training loop defined in main().
main();

  1%|▏         | 100/8000 [06:17<8:15:50,  3.77s/it]

epoch: 100, pix: 0.091184, fft: 0.696591


  2%|▎         | 200/8000 [12:32<8:07:54,  3.75s/it]

epoch: 200, pix: 0.075961, fft: 0.646417


  4%|▍         | 300/8000 [18:48<8:02:36,  3.76s/it]

epoch: 300, pix: 0.071911, fft: 0.632141


  5%|▌         | 400/8000 [25:03<7:57:52,  3.77s/it]

epoch: 400, pix: 0.070325, fft: 0.628124


  6%|▋         | 500/8000 [31:18<7:51:04,  3.77s/it]

epoch: 500, pix: 0.068901, fft: 0.624101


  8%|▊         | 600/8000 [37:34<7:43:44,  3.76s/it]

epoch: 600, pix: 0.068174, fft: 0.620369


  9%|▉         | 700/8000 [43:49<7:36:36,  3.75s/it]

epoch: 700, pix: 0.067260, fft: 0.616512


 10%|█         | 800/8000 [50:04<7:30:42,  3.76s/it]

epoch: 800, pix: 0.066831, fft: 0.614156


 11%|█▏        | 900/8000 [56:19<7:24:10,  3.75s/it]

epoch: 900, pix: 0.066153, fft: 0.611085


 12%|█▎        | 1000/8000 [1:02:34<7:18:44,  3.76s/it]

epoch: 1000, pix: 0.065857, fft: 0.611202


 14%|█▍        | 1100/8000 [1:08:49<7:13:53,  3.77s/it]

epoch: 1100, pix: 0.065449, fft: 0.609500


 15%|█▌        | 1200/8000 [1:15:05<7:06:10,  3.76s/it]

epoch: 1200, pix: 0.065068, fft: 0.609297


 16%|█▋        | 1300/8000 [1:21:20<6:59:37,  3.76s/it]

epoch: 1300, pix: 0.064792, fft: 0.607571


 18%|█▊        | 1400/8000 [1:27:35<6:53:57,  3.76s/it]

epoch: 1400, pix: 0.064226, fft: 0.605794


 19%|█▉        | 1500/8000 [1:33:50<6:47:11,  3.76s/it]

epoch: 1500, pix: 0.064057, fft: 0.606664


 20%|██        | 1600/8000 [1:40:06<6:41:49,  3.77s/it]

epoch: 1600, pix: 0.062994, fft: 0.604130


 21%|██▏       | 1700/8000 [1:46:21<6:35:55,  3.77s/it]

epoch: 1700, pix: 0.062948, fft: 0.603903


 22%|██▎       | 1800/8000 [1:52:36<6:28:13,  3.76s/it]

epoch: 1800, pix: 0.062520, fft: 0.602401


 24%|██▍       | 1900/8000 [1:58:52<6:23:11,  3.77s/it]

epoch: 1900, pix: 0.062549, fft: 0.602243


 25%|██▌       | 2000/8000 [2:05:08<6:17:12,  3.77s/it]

epoch: 2000, pix: 0.062156, fft: 0.601105


 26%|██▋       | 2100/8000 [2:11:24<6:09:49,  3.76s/it]

epoch: 2100, pix: 0.062042, fft: 0.599773


 28%|██▊       | 2200/8000 [2:17:39<6:04:08,  3.77s/it]

epoch: 2200, pix: 0.061789, fft: 0.599316


 29%|██▉       | 2300/8000 [2:23:55<5:58:38,  3.78s/it]

epoch: 2300, pix: 0.061664, fft: 0.599296


 30%|███       | 2400/8000 [2:30:11<5:51:26,  3.77s/it]

epoch: 2400, pix: 0.061613, fft: 0.599888


 31%|███▏      | 2500/8000 [2:36:27<5:45:24,  3.77s/it]

epoch: 2500, pix: 0.061496, fft: 0.599121


 32%|███▎      | 2600/8000 [2:42:43<5:38:25,  3.76s/it]

epoch: 2600, pix: 0.061283, fft: 0.598686


 34%|███▍      | 2700/8000 [2:48:58<5:32:34,  3.77s/it]

epoch: 2700, pix: 0.061229, fft: 0.598160


 35%|███▌      | 2800/8000 [2:55:14<5:26:42,  3.77s/it]

epoch: 2800, pix: 0.060836, fft: 0.597658


 36%|███▋      | 2900/8000 [3:01:30<5:20:26,  3.77s/it]

epoch: 2900, pix: 0.060853, fft: 0.598230


 38%|███▊      | 3000/8000 [3:07:45<5:15:44,  3.79s/it]

epoch: 3000, pix: 0.060639, fft: 0.597088


 39%|███▉      | 3100/8000 [3:14:01<5:07:03,  3.76s/it]

epoch: 3100, pix: 0.059892, fft: 0.596114


 40%|████      | 3200/8000 [3:20:17<5:01:12,  3.77s/it]

epoch: 3200, pix: 0.059795, fft: 0.595485


 41%|████▏     | 3300/8000 [3:26:33<4:55:12,  3.77s/it]

epoch: 3300, pix: 0.059571, fft: 0.594390


 42%|████▎     | 3400/8000 [3:32:48<4:48:45,  3.77s/it]

epoch: 3400, pix: 0.059582, fft: 0.595654


 44%|████▍     | 3500/8000 [3:39:04<4:42:53,  3.77s/it]

epoch: 3500, pix: 0.059470, fft: 0.595834


 45%|████▌     | 3600/8000 [3:45:19<4:37:22,  3.78s/it]

epoch: 3600, pix: 0.059392, fft: 0.595383


 46%|████▋     | 3700/8000 [3:51:35<4:29:39,  3.76s/it]

epoch: 3700, pix: 0.059361, fft: 0.595302


 48%|████▊     | 3800/8000 [3:57:50<4:22:56,  3.76s/it]

epoch: 3800, pix: 0.059108, fft: 0.593937


 49%|████▉     | 3900/8000 [4:04:05<4:17:00,  3.76s/it]

epoch: 3900, pix: 0.059167, fft: 0.595123


 50%|█████     | 4000/8000 [4:10:20<4:10:51,  3.76s/it]

epoch: 4000, pix: 0.058936, fft: 0.593702


 51%|█████▏    | 4100/8000 [4:16:35<4:04:35,  3.76s/it]

epoch: 4100, pix: 0.058916, fft: 0.594553


 52%|█████▎    | 4200/8000 [4:22:50<3:58:48,  3.77s/it]

epoch: 4200, pix: 0.058883, fft: 0.594549


 54%|█████▍    | 4300/8000 [4:29:05<3:52:00,  3.76s/it]

epoch: 4300, pix: 0.058825, fft: 0.594475


 55%|█████▌    | 4400/8000 [4:35:21<3:45:24,  3.76s/it]

epoch: 4400, pix: 0.058813, fft: 0.594435


 56%|█████▋    | 4500/8000 [4:41:36<3:39:15,  3.76s/it]

epoch: 4500, pix: 0.058626, fft: 0.593842


 57%|█████▊    | 4600/8000 [4:47:51<3:33:16,  3.76s/it]

epoch: 4600, pix: 0.057936, fft: 0.591740


 59%|█████▉    | 4700/8000 [4:54:06<3:27:22,  3.77s/it]

epoch: 4700, pix: 0.057996, fft: 0.593537


 60%|██████    | 4800/8000 [5:00:21<3:20:47,  3.76s/it]

epoch: 4800, pix: 0.057918, fft: 0.592452


 61%|██████▏   | 4900/8000 [5:06:36<3:14:15,  3.76s/it]

epoch: 4900, pix: 0.057772, fft: 0.592333


 62%|██████▎   | 5000/8000 [5:12:51<3:07:49,  3.76s/it]

epoch: 5000, pix: 0.057925, fft: 0.592763


 64%|██████▍   | 5100/8000 [5:19:06<3:01:47,  3.76s/it]

epoch: 5100, pix: 0.057673, fft: 0.591272


 65%|██████▌   | 5200/8000 [5:25:22<2:55:29,  3.76s/it]

epoch: 5200, pix: 0.057801, fft: 0.592767


 66%|██████▋   | 5300/8000 [5:31:37<2:49:24,  3.76s/it]

epoch: 5300, pix: 0.057728, fft: 0.593437


 68%|██████▊   | 5400/8000 [5:37:52<2:43:02,  3.76s/it]

epoch: 5400, pix: 0.057642, fft: 0.592921


 69%|██████▉   | 5500/8000 [5:44:07<2:36:29,  3.76s/it]

epoch: 5500, pix: 0.057543, fft: 0.591707


 70%|███████   | 5600/8000 [5:50:22<2:30:42,  3.77s/it]

epoch: 5600, pix: 0.057601, fft: 0.592371


 71%|███████▏  | 5700/8000 [5:56:37<2:24:10,  3.76s/it]

epoch: 5700, pix: 0.057369, fft: 0.590957


 72%|███████▎  | 5800/8000 [6:02:52<2:18:01,  3.76s/it]

epoch: 5800, pix: 0.057588, fft: 0.592881


 74%|███████▍  | 5900/8000 [6:09:07<2:11:43,  3.76s/it]

epoch: 5900, pix: 0.057421, fft: 0.592181


 75%|███████▌  | 6000/8000 [6:15:22<2:05:12,  3.76s/it]

epoch: 6000, pix: 0.057390, fft: 0.591710


 76%|███████▋  | 6100/8000 [6:21:37<1:58:52,  3.75s/it]

epoch: 6100, pix: 0.057131, fft: 0.592216


 78%|███████▊  | 6200/8000 [6:27:52<1:53:00,  3.77s/it]

epoch: 6200, pix: 0.057094, fft: 0.591173


 79%|███████▉  | 6300/8000 [6:34:08<1:46:53,  3.77s/it]

epoch: 6300, pix: 0.056954, fft: 0.590567


 80%|████████  | 6400/8000 [6:40:24<1:40:25,  3.77s/it]

epoch: 6400, pix: 0.056990, fft: 0.591739


 81%|████████▏ | 6500/8000 [6:46:40<1:34:11,  3.77s/it]

epoch: 6500, pix: 0.056874, fft: 0.590762


 82%|████████▎ | 6600/8000 [6:52:55<1:27:44,  3.76s/it]

epoch: 6600, pix: 0.056972, fft: 0.590914


 84%|████████▍ | 6700/8000 [6:59:10<1:21:26,  3.76s/it]

epoch: 6700, pix: 0.056881, fft: 0.590867


 85%|████████▌ | 6800/8000 [7:05:25<1:15:11,  3.76s/it]

epoch: 6800, pix: 0.056839, fft: 0.590773


 86%|████████▋ | 6900/8000 [7:11:40<1:08:53,  3.76s/it]

epoch: 6900, pix: 0.056596, fft: 0.588996


 88%|████████▊ | 7000/8000 [7:17:56<1:02:44,  3.76s/it]

epoch: 7000, pix: 0.056812, fft: 0.590671


 89%|████████▉ | 7100/8000 [7:24:11<56:28,  3.77s/it]  

epoch: 7100, pix: 0.056691, fft: 0.589806


 90%|█████████ | 7200/8000 [7:30:27<50:09,  3.76s/it]

epoch: 7200, pix: 0.056776, fft: 0.590907


 91%|█████████▏| 7300/8000 [7:36:43<43:57,  3.77s/it]

epoch: 7300, pix: 0.056890, fft: 0.592414


 92%|█████████▎| 7400/8000 [7:42:58<37:41,  3.77s/it]

epoch: 7400, pix: 0.056817, fft: 0.592108


 94%|█████████▍| 7500/8000 [7:49:14<31:21,  3.76s/it]

epoch: 7500, pix: 0.056622, fft: 0.590086


 95%|█████████▌| 7600/8000 [7:55:29<25:06,  3.77s/it]

epoch: 7600, pix: 0.056567, fft: 0.591078


 96%|█████████▋| 7700/8000 [8:01:45<18:50,  3.77s/it]

epoch: 7700, pix: 0.056585, fft: 0.591242


 98%|█████████▊| 7800/8000 [8:08:01<12:32,  3.76s/it]

epoch: 7800, pix: 0.056473, fft: 0.590530


 99%|█████████▉| 7900/8000 [8:14:16<06:16,  3.76s/it]

epoch: 7900, pix: 0.056485, fft: 0.590250


100%|██████████| 8000/8000 [8:20:31<00:00,  3.75s/it]

epoch: 8000, pix: 0.056320, fft: 0.588702



