In [1]:
# https://drive.google.com/file/d/1qumjvCdnDQvZdrS_o-MMgRatNLFd9QTd/view?usp=sharing
!conda install -y gdown
!gdown --id "1qumjvCdnDQvZdrS_o-MMgRatNLFd9QTd"
!unzip -q Minerals.zip
!rm Minerals.zip

/bin/bash: line 1: conda: command not found
Downloading...
From: https://drive.google.com/uc?id=1qumjvCdnDQvZdrS_o-MMgRatNLFd9QTd
To: /kaggle/working/Minerals.zip
100%|██████████████████████████████████████| 13.8M/13.8M [00:00<00:00, 70.4MB/s]


In [11]:
import os;
import torch;
import torch.fft;
import torchvision;
import numpy as np;
from PIL import Image;
from typing import Tuple;
from random import randint;
from torch import nn, optim;
from torch.utils import data;
import torch.nn.functional as F;
from torch.backends import cudnn;
from torch.utils.data import Dataset;
from torch.nn.parameter import Parameter;
import torchvision.transforms.functional as FT;

In [12]:
cudnn.benchmark = True;

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu");

batch_size = 16;
learning_rate = 3e-4;
epochs = 2500;

In [13]:
class UNet(nn.Module):

    def __init__(self) -> None:
        super().__init__();
        self.enc0 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size = 5, padding = 2, stride = 2),
            nn.PReLU(64)
        );
        self.enc1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size = 5, padding = 2, stride = 2),
            nn.PReLU(128)
        );
        self.enc2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size = 5, padding = 2, stride = 2),
            nn.PReLU(256)
        );
        self.enc3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size = 5, padding = 2, stride = 2),
            nn.PReLU(512)
        );
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size = 5, padding = 2, output_padding = 1, stride = 2),
            nn.PReLU(256)
        );
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, kernel_size = 5, padding = 2, output_padding = 1, stride = 2),
            nn.PReLU(128)
        );
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size = 5, padding = 2, output_padding = 1, stride = 2),
            nn.PReLU(64)
        );
        self.dec0 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size = 5, padding = 2, output_padding = 1, stride = 2)
        );

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = self.enc0(x);
        x1 = self.enc1(x0);
        x2 = self.enc2(x1);
        x3 = self.enc3(x2);
        y = self.dec3(x3);
        y = self.dec2(torch.cat([F.interpolate(y, size = x2.shape[2:], mode = "nearest"), x2], dim = 1));
        y = self.dec1(torch.cat([F.interpolate(y, size = x1.shape[2:], mode = "nearest"), x1], dim = 1));
        y = self.dec0(torch.cat([F.interpolate(y, size = x0.shape[2:], mode = "nearest"), x0], dim = 1));
        return y;

In [14]:
def transform(input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    target = F.interpolate(target.unsqueeze(0), size = (64, 64), mode = "bicubic").squeeze(0);
    if(randint(0, 10) > 5):
        input = FT.vflip(input);
        target = FT.vflip(target);
    if(randint(0, 10) > 5):
        input = FT.hflip(input);
        target = FT.hflip(target);
    angle = randint(-30, 30);
    input = FT.rotate(input, angle, interpolation = torchvision.transforms.InterpolationMode.BILINEAR);
    target = FT.rotate(target, angle, interpolation = torchvision.transforms.InterpolationMode.BILINEAR);
    return (input, target);

class MineralDataset(Dataset):

    def __init__(self, root_dir: str, transform) -> None:
        self.root_dir = root_dir;
        self.transform = transform;

    def __len__(self) -> int:
        return 334;

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        img_64 = torch.cat([FT.to_tensor(Image.open(os.path.join(self.root_dir, "64", "800", f"{(idx + 1):03}_800.jpg")).convert('L')),
                            FT.to_tensor(Image.open(os.path.join(self.root_dir, "64", "1050", f"{(idx + 1):03}_1050.jpg")).convert('L')),
                            FT.to_tensor(Image.open(os.path.join(self.root_dir, "64", "1550", f"{(idx + 1):03}_1550.jpg")).convert('L'))], dim = 0);
        img_256 = torch.cat([FT.to_tensor(Image.open(os.path.join(self.root_dir, "256", "800", f"{(idx + 1):03}_800.jpg")).convert('L')),
                             FT.to_tensor(Image.open(os.path.join(self.root_dir, "256", "1050", f"{(idx + 1):03}_1050.jpg")).convert('L')),
                             FT.to_tensor(Image.open(os.path.join(self.root_dir, "256", "1550", f"{(idx + 1):03}_1550.jpg")).convert('L'))], dim = 0);
        img_64, img_256 = self.transform(img_64, img_256);
        return img_64, img_256;

In [15]:
class FFTLoss(nn.Module):

    def __init__(self, loss_weight: float = 1.0, reduction: str = 'mean') -> None:
        super().__init__();
        self.loss_weight = loss_weight;
        self.criterion = torch.nn.L1Loss(reduction = reduction);

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pred_fft = torch.fft.rfft2(pred);
        target_fft = torch.fft.rfft2(target);
        pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim = -1);
        target_fft = torch.stack([target_fft.real, target_fft.imag], dim = -1);
        return self.loss_weight * (self.criterion(pred_fft, target_fft));

In [16]:
def main():
    dataset = MineralDataset('Minerals', transform);
    dataloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, shuffle = True);
    net = UNet().train().to(device);
    opt_net = optim.Adam(net.parameters(), lr = learning_rate);
    fft = FFTLoss(0.1);
    avg_pix_loss = 0.0;
    avg_fft_loss = 0.0;
    it = 0;
    for i in range(epochs):
        for lr, hr in iter(dataloader):
            lr = lr.to(device);
            hr = hr.to(device);

            sr = net(lr);
            
            pix_loss = F.l1_loss(sr, hr);
            fft_loss = fft(sr, hr);
            avg_pix_loss += pix_loss.item();
            avg_fft_loss += fft_loss.item();
            total_loss = pix_loss + fft_loss;
            opt_net.zero_grad();
            total_loss.backward();
            opt_net.step();
            it += 1;
        if(((i + 1) % 100) == 0):
            print("epoch: {:d}, pix: {:f}, fft: {:f}".format(i + 1, avg_pix_loss / it, avg_fft_loss / it));
            torch.save(net.state_dict(), 'unet_{:d}.pth'.format(i + 1));
            avg_pix_loss = 0.0;
            avg_fft_loss = 0.0;
            it = 0;
        if(((i + 1) % 500) == 0):
            for param_group in opt_net.param_groups:
                param_group['lr'] /= 2.0;

main();

epoch: 100, pix: 0.052502, fft: 0.171430
epoch: 200, pix: 0.045331, fft: 0.151658
epoch: 300, pix: 0.040792, fft: 0.145294
epoch: 400, pix: 0.037367, fft: 0.139061
epoch: 500, pix: 0.035157, fft: 0.134398
epoch: 600, pix: 0.033020, fft: 0.129531
epoch: 700, pix: 0.031991, fft: 0.127005
epoch: 800, pix: 0.031300, fft: 0.125308
epoch: 900, pix: 0.030759, fft: 0.123918
epoch: 1000, pix: 0.030327, fft: 0.122847
epoch: 1100, pix: 0.029580, fft: 0.120815
epoch: 1200, pix: 0.029228, fft: 0.119850
epoch: 1300, pix: 0.028988, fft: 0.119118
epoch: 1400, pix: 0.028807, fft: 0.118663
epoch: 1500, pix: 0.028618, fft: 0.118114
epoch: 1600, pix: 0.028304, fft: 0.117167
epoch: 1700, pix: 0.028179, fft: 0.116804
epoch: 1800, pix: 0.028060, fft: 0.116451
epoch: 1900, pix: 0.027987, fft: 0.116247
epoch: 2000, pix: 0.027908, fft: 0.116025
epoch: 2100, pix: 0.027744, fft: 0.115496
epoch: 2200, pix: 0.027708, fft: 0.115389
epoch: 2300, pix: 0.027631, fft: 0.115137
epoch: 2400, pix: 0.027608, fft: 0.115051
e