In [1]:
import math
import logging
from functools import partial
from collections import OrderedDict
from copy import Error, deepcopy
from re import S
from numpy.lib.arraypad import pad
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import torch.fft
from torch.nn.modules.container import Sequential
#from main_afnonet import get_args
from torch.utils.checkpoint import checkpoint_sequential

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def kernel(in_chan=2, up_dim=32):
    """
        Kernel network apply on grid
    """
    layers = nn.Sequential(
                nn.Linear(in_chan, up_dim, bias=True), torch.nn.GELU(),
                nn.Linear(up_dim, up_dim, bias=True), torch.nn.GELU(),
                nn.Linear(up_dim, 1, bias=False)
            )
    return layers

In [3]:
class SpectralConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, dim1, dim2,modes1 = None, modes2 = None):
        super(SpectralConv2d, self).__init__()

        """
        2D Fourier layer. It does FFT, linear transform, and Inverse FFT.    
        """
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dim1 = dim1 #output dimensions
        self.dim2 = dim2
        #self.batch_norm = nn.BatchNorm2d(out_channels)
        #self.dropout = nn.Dropout(p=0.42)
        if modes1 is not None:
            self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
            self.modes2 = modes2
        else:
            self.modes1 = dim1//2 + 1 #if not given take the highest number of modes can be taken
            self.modes2 = dim2//2 
        self.scale = (1 / (2*in_channels))**(1.0/2.0)
        self.weights1 = nn.Parameter(self.scale * (torch.randn(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)))
        self.weights2 = nn.Parameter(self.scale * (torch.randn(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x, dim1 = None,dim2 = None):
        if dim1 is not None:
            self.dim1 = dim1
            self.dim2 = dim2
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft2(x)

        # Multiply relevant Fourier modes

        out_ft = torch.zeros(batchsize, self.out_channels,  self.dim1, self.dim2//2 + 1 , dtype=torch.cfloat, device=x.device)
        #print("Out FT Shape = ", out_ft.shape)
        #print("x_ft Shape = ", x_ft.shape)
        #print("Modes1 = ", self.modes1, " Modes2 = ", self.modes2, " self.dim1 = ", self.dim1, " self.dim2//2 + 1 = ", self.dim2//2+1)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

        #Return to physical space
        x = torch.fft.irfft2(out_ft, s=(self.dim1, self.dim2))
        #x = self.dropout(x)
        #x = self.batch_norm(x)
        return x


class pointwise_op(nn.Module):
    def __init__(self, in_channel, out_channel,dim1, dim2):
        super(pointwise_op,self).__init__()
        self.conv = nn.Conv2d(int(in_channel), int(out_channel), 1)
        self.dim1 = int(dim1)
        self.dim2 = int(dim2)
        #self.alpha_drop = nn.AlphaDropout(p=0.42)
        #self.layer_norm = nn.LayerNorm([int(out_channel), self.dim1, self.dim2])
        #self.batch_norm = nn.BatchNorm2d(int(out_channel))
        #self.upsample = nn.Upsample(size=[self.dim1, self.dim2], mode='bicubic', align_corners=True)

    def forward(self,x, dim1 = None, dim2 = None):
        if dim1 is None:
            dim1 = self.dim1
            dim2 = self.dim2
        x_out = self.conv(x)
        x_out = torch.nn.functional.interpolate(x_out, size = (dim1, dim2),mode = 'bicubic',align_corners=True)
        #x_out = self.upsample(x_out)
        #x_out = self.alpha_drop(x_out)
        #x_out = self.layer_norm(x_out)
        #x_out = self.batch_norm(x_out)
        return x_out

class UNO_vanilla(nn.Module):
    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 16/4):
        super(UNO_vanilla, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .
        
        input: the solution of the coefficient function and locations (a(x, y), x, y)
        input shape: (batchsize, x=s, y=s, c=3)
        output: the solution 
        output shape: (batchsize, x=s, y=s, c=1)
        """
        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain 
        self.factor = factor
        self.factor2 = factor/4
        self.padding = pad  # pad the domain if input is non-periodic

        self.fc0 = nn.Linear(self.in_d_co_domain, self.d_co_domain) # input channel is 3: (a(x, y), x, y)

        self.conv0 = SpectralConv2d(self.d_co_domain, 4*factor*self.d_co_domain, 16, 16, 42, 42)

        self.conv1 = SpectralConv2d(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16, 21,21)

        self.conv2 = SpectralConv2d(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,10,10)
        
        self.conv2_1 = SpectralConv2d(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 4, 4,5,5)
        
        self.conv2_9 = SpectralConv2d(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,5,5)
    

        self.conv3 = SpectralConv2d(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16,10,10)

        self.conv4 = SpectralConv2d(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 32, 32,21,21)

        self.conv5 = SpectralConv2d(8*factor*self.d_co_domain, self.d_co_domain, 48, 48,42,42) # will be reshaped

        self.w0 = pointwise_op(self.d_co_domain,4*factor*self.d_co_domain,75, 75) #
        
        self.w1 = pointwise_op(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50) #
        
        self.w2 = pointwise_op(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25) #
        
        self.w2_1 = pointwise_op(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 12, 12)

        self.w2_9 = pointwise_op(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)
        
        self.w3 = pointwise_op(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50) #
        
        self.w4 = pointwise_op(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 75, 75)
        
        self.w5 = pointwise_op(8*factor*self.d_co_domain, self.d_co_domain, 100, 100) # will be reshaped

        self.fc1 = nn.Linear(2*self.d_co_domain, 4*self.d_co_domain)
        self.fc2 = nn.Linear(4*self.d_co_domain, self.d_co_domain)

    def forward(self, x):
        
        grid = self.get_grid(x[0].shape, x.device)
        x = x.permute(0,2,3,1)

        x_fc0 = self.fc0(x)

        x_fc0 = F.gelu(x_fc0)
        
        x_fc0 = x_fc0.permute(0, 3, 1, 2)
        
        
        x_fc0 = F.pad(x_fc0, [0,self.padding, 0,self.padding])
        
        D1,D2 = x_fc0.shape[-2],x_fc0.shape[-1]
        
        x1_c0 = self.conv0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x2_c0 = self.w0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x_c0 = x1_c0 + x2_c0
        x_c0 = F.gelu(x_c0)

        x1_c1 = self.conv1(x_c0 ,D1//2,D2//2)
        x2_c1 = self.w1(x_c0 ,D1//2,D2//2)
        x_c1 = x1_c1 + x2_c1
        x_c1 = F.gelu(x_c1)

        x1_c2 = self.conv2(x_c1 ,D1//4,D2//4)
        x2_c2 = self.w2(x_c1 ,D1//4,D2//4)
        x_c2 = x1_c2 + x2_c2
        x_c2 = F.gelu(x_c2 )

        x1_c2_1 = self.conv2_1(x_c2,D1//8,D2//8)
        x2_c2_1 = self.w2_1(x_c2,D1//8,D2//8)
        x_c2_1 = x1_c2_1 + x2_c2_1
        x_c2_1 = F.gelu(x_c2_1)

        
        x1_c2_9 = self.conv2_9(x_c2_1,D1//4,D2//4)
        x2_c2_9 = self.w2_9(x_c2_1,D1//4,D2//4)
        x_c2_9 = x1_c2_9 + x2_c2_9
        x_c2_9 = F.gelu(x_c2_9)
        x_c2_9 = torch.cat([x_c2_9, x_c2], dim=1)

        x1_c3 = self.conv3(x_c2_9,D1//2,D2//2)
        x2_c3 = self.w3(x_c2_9,D1//2,D2//2)
        x_c3 = x1_c3 + x2_c3
        x_c3 = F.gelu(x_c3)
        x_c3 = torch.cat([x_c3, x_c1], dim=1)

        x1_c4 = self.conv4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x2_c4 = self.w4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x_c4 = x1_c4 + x2_c4
        x_c4 = F.gelu(x_c4)
        x_c4 = torch.cat([x_c4, x_c0], dim=1)

        x1_c5 = self.conv5(x_c4,D1,D2)
        x2_c5 = self.w5(x_c4,D1,D2)
        x_c5 = x1_c5 + x2_c5
        x_c5 = F.gelu(x_c5)


        x_c5 = torch.cat([x_c5, x_fc0], dim=1)
        if self.padding!=0:
            x_c5 = x_c5[..., :-self.padding, :-self.padding]

        x_c5 = x_c5.permute(0, 2, 3, 1)
        
        x_fc1 = self.fc1(x_c5)
        x_fc1 = F.gelu(x_fc1)
        
        x_out = self.fc2(x_fc1)
        x_out = x_out.permute(0,3,1,2)
        x_out = F.gelu(x_out)
        return x_out
    
    def get_grid(self, shape, device):
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).cuda()

In [4]:
class UNO(nn.Module):
    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 16/4):
        super(UNO, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .
        
        input: the solution of the coefficient function and locations (a(x, y), x, y)
        input shape: (batchsize, x=s, y=s, c=3)
        output: the solution 
        output shape: (batchsize, x=s, y=s, c=1)
        """
        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain 
        self.factor = factor
        self.factor2 = factor/4
        self.padding = pad  # pad the domain if input is non-periodic
        self.reduceColor = nn.Conv2d(3, 1, kernel_size=1)
        self.interpolation = nn.Conv2d(3, 3, kernel_size=1)

        self.fc0 = nn.Linear(self.in_d_co_domain, self.d_co_domain) # input channel is 3: (a(x, y), x, y)

        self.conv0 = SpectralConv2d(self.d_co_domain, 4*factor*self.d_co_domain, 16, 16, 42, 42)

        self.conv1 = SpectralConv2d(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16, 21,21)

        self.conv2 = SpectralConv2d(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,10,10)
        
        self.conv2_1 = SpectralConv2d(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 4, 4,5,5)
        
        self.conv2_9 = SpectralConv2d(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,5,5)
    

        self.conv3 = SpectralConv2d(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16,10,10)

        self.conv4 = SpectralConv2d(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 32, 32,21,21)

        self.conv5 = SpectralConv2d(8*factor*self.d_co_domain, self.d_co_domain, 48, 48,42,42) # will be reshaped

        self.w0 = pointwise_op(self.d_co_domain,4*factor*self.d_co_domain,75, 75) #
        
        self.w1 = pointwise_op(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50) #
        
        self.w2 = pointwise_op(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25) #
        
        self.w2_1 = pointwise_op(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 12, 12)
        """
        #VAE Code begin
        self.linearFlatten1 = nn.Linear(384*12*12, 128)
        self.linearFlatten2 = nn.Linear(128, 4)
        self.linearFlatten3 = nn.Linear(128, 4)
        self.softmax1 = nn.Softmax(dim=1)
        self.N = torch.distributions.Normal(0, 10)
        self.N.loc = self.N.loc.cuda()
        self.N.scale = self.N.scale.cuda()
        self.kl = 0
        self.linearUnflatten = nn.Linear(4, 128)
        self.relu = nn.ReLU(True)
        self.linearUnflatten2 = nn.Linear(128, 384*12*12)
        self.relu2 = nn.ReLU(True)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(384, 12, 12))
        #VAE Code End
        """
        self.w2_9 = pointwise_op(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)
        
        self.w3 = pointwise_op(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50) #
        
        self.w4 = pointwise_op(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 75, 75)
        
        self.w5 = pointwise_op(8*factor*self.d_co_domain, self.d_co_domain, 100, 100) # will be reshaped

        self.fc1 = nn.Linear(2*self.d_co_domain, 4*self.d_co_domain)
        self.fc2 = nn.Linear(4*self.d_co_domain, 3)
        self.increaseColor = nn.Conv2d(1, 3, kernel_size=1)

    def forward(self, x):
        
        grid = self.get_grid(x[0].shape, x.device)
        #print(grid.shape)
        
        #print(x[0].shape)
        #x_l = torch.cat((x[0], grid), dim=-1).cuda()
        #x_r = torch.cat((x[1], grid), dim=-1).cuda()

        x_l = x[0].permute(0,3,1,2)
        x_r = x[1].permute(0,3,1,2)
        #print(x_l.shape)
        x1 = self.interpolation(x_l)
        x2 = self.interpolation(x_r)
        ##print(x1.shape)
        x = x1 + x2
        x3 = x
        #x = self.reduceColor(x)
        x = x.permute(0,2,3,1)
        #x = torch.cat((x, grid), dim=-1).cuda()
        #print(x.shape)
        x_fc0 = self.fc0(x)

        x_fc0 = F.gelu(x_fc0)
        
        x_fc0 = x_fc0.permute(0, 3, 1, 2)
        
        
        x_fc0 = F.pad(x_fc0, [0,self.padding, 0,self.padding])
        
        D1,D2 = x_fc0.shape[-2],x_fc0.shape[-1]
        
        x1_c0 = self.conv0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x2_c0 = self.w0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x_c0 = x1_c0 + x2_c0
        x_c0 = F.gelu(x_c0)

        x1_c1 = self.conv1(x_c0 ,D1//2,D2//2)
        x2_c1 = self.w1(x_c0 ,D1//2,D2//2)
        x_c1 = x1_c1 + x2_c1
        x_c1 = F.gelu(x_c1)

        x1_c2 = self.conv2(x_c1 ,D1//4,D2//4)
        x2_c2 = self.w2(x_c1 ,D1//4,D2//4)
        x_c2 = x1_c2 + x2_c2
        x_c2 = F.gelu(x_c2 )

        x1_c2_1 = self.conv2_1(x_c2,D1//8,D2//8)
        x2_c2_1 = self.w2_1(x_c2,D1//8,D2//8)
        x_c2_1 = x1_c2_1 + x2_c2_1
        x_c2_1 = F.gelu(x_c2_1)
        """
        #Variational Autoencoder part - Extract mu and sigma
        x_c2_1 = torch.flatten(x_c2_1, start_dim=1)
        x_c2_1 = self.linearFlatten1(x_c2_1)
        mu = self.linearFlatten2(x_c2_1)
        sigma = self.linearFlatten3(x_c2_1)
        std = torch.exp(0.5*sigma)
        eps = torch.randn_like(std)
        z = sigma*self.N.sample(mu.shape) + mu
        #z = mu + std*eps

        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        z = self.linearUnflatten(z)
        z = self.linearUnflatten2(z)
        x_c2_1 = self.unflatten(z)
        #print("x_c2_1 shape = ", x_c2_1.shape)
        #Variational Autoencoder -end
        """
        
        x1_c2_9 = self.conv2_9(x_c2_1,D1//4,D2//4)
        x2_c2_9 = self.w2_9(x_c2_1,D1//4,D2//4)
        x_c2_9 = x1_c2_9 + x2_c2_9
        x_c2_9 = F.gelu(x_c2_9)
        x_c2_9 = torch.cat([x_c2_9, x_c2], dim=1)

        x1_c3 = self.conv3(x_c2_9,D1//2,D2//2)
        x2_c3 = self.w3(x_c2_9,D1//2,D2//2)
        x_c3 = x1_c3 + x2_c3
        x_c3 = F.gelu(x_c3)
        x_c3 = torch.cat([x_c3, x_c1], dim=1)

        x1_c4 = self.conv4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x2_c4 = self.w4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x_c4 = x1_c4 + x2_c4
        x_c4 = F.gelu(x_c4)
        x_c4 = torch.cat([x_c4, x_c0], dim=1)

        x1_c5 = self.conv5(x_c4,D1,D2)
        x2_c5 = self.w5(x_c4,D1,D2)
        x_c5 = x1_c5 + x2_c5
        x_c5 = F.gelu(x_c5)


        x_c5 = torch.cat([x_c5, x_fc0], dim=1)
        if self.padding!=0:
            x_c5 = x_c5[..., :-self.padding, :-self.padding]

        x_c5 = x_c5.permute(0, 2, 3, 1)
        
        x_fc1 = self.fc1(x_c5)
        x_fc1 = F.gelu(x_fc1)
        
        x_out = self.fc2(x_fc1)
        
        #x3 = torch.permute(x3, (0, 2,3,1))
        x3 = (x_l+x_r)/2
        #print(x3.shape)
        x3 = torch.permute(x3, (0,2,3,1))
        #x_fin = x_out + x3
        #x_fin = torch.permute(x_fin, (0, 3, 1, 2))
        #x_out = self.interpolation(x_fin)
        #x_out = torch.permute(x_out, (0, 3, 1, 2))
        #print(x_out.shape)
        #x_out = self.increaseColor(x_out)
        return x_out
    
    def get_grid(self, shape, device):
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        return torch.cat((gridx, gridy), dim=-1).cuda()

In [5]:
UNONet = UNO_vanilla(2,1)
summary(UNONet, input_size=(10, 2, 100,100))

Layer (type:depth-idx)                   Output Shape              Param #
UNO_vanilla                              [10, 1, 100, 100]         --
├─Linear: 1-1                            [10, 100, 100, 1]         3
├─SpectralConv2d: 1-2                    [10, 16, 100, 100]        56,448
├─pointwise_op: 1-3                      [10, 16, 100, 100]        --
│    └─Conv2d: 2-1                       [10, 16, 100, 100]        32
├─SpectralConv2d: 1-4                    [10, 32, 50, 50]          451,584
├─pointwise_op: 1-5                      [10, 32, 50, 50]          --
│    └─Conv2d: 2-2                       [10, 32, 100, 100]        544
├─SpectralConv2d: 1-6                    [10, 64, 25, 25]          409,600
├─pointwise_op: 1-7                      [10, 64, 25, 25]          --
│    └─Conv2d: 2-3                       [10, 64, 50, 50]          2,112
├─SpectralConv2d: 1-8                    [10, 128, 12, 12]         409,600
├─pointwise_op: 1-9                      [10, 128, 12, 12]     

In [6]:
class NIO(nn.Module):
    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 10/4):
        super(NIO, self).__init__()
        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain 
        self.factor = factor
        self.factor2 = factor/4
        self.interpolation = nn.Conv2d(3, 3, kernel_size=1)
        self.UNO_block_1 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_2 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_3 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_4 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        #self.UNO_block_5 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=20/4)
    def forward(self, x):
        uno_1_out = self.UNO_block_1(x)
        #print(uno_1_out.shape)
        #print(x[0].shape)
        #print(x.shape)
        x_int_l = self.interpolation(x[0].permute(0,3,1,2))
        x_int_r = self.interpolation(x[1].permute(0,3,1,2))
        x_int = x_int_l + x_int_r
        x_int = torch.permute(x_int, (0,2,3,1))
        y = torch.stack((x_int, uno_1_out))
        
        #y = x[0] + uno_1_out
        #print("Y Shape = ", y.shape)
        uno_2_out = self.UNO_block_2(y)
        y = torch.stack((x_int, uno_2_out))
        #y = x[1] + uno_2_out
        uno_3_out = self.UNO_block_3(y)
        y = torch.stack((x_int, uno_3_out))
        uno_4_out = self.UNO_block_4(x)
        #y = torch.stack((x_int, uno_4_out))
        #x = self.UNO_block_5(y)
        return uno_4_out

In [7]:
class PeriodicPad2d(nn.Module):
    """ 
        pad longitudinal (left-right) circular 
        and pad latitude (top-bottom) with zeros
    """
    def __init__(self, pad_width):
       super(PeriodicPad2d, self).__init__()
       self.pad_width = pad_width

    def forward(self, x):
        # pad left and right circular
        out = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular") 
        # pad top and bottom zeros
        out = F.pad(out, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0) 
        return out

In [8]:
class MinMaxNormalizer:
    def __init__(self, batch):
        self.batch = batch
        self.min = batch.min()
        self.max = batch.max()
    def encode(self, batch): # min max normalization -> [0,1]
        encoded = (batch - self.min) / (self.max - self.min)
        return encoded 
    def decode(self, batch): # returns to standard values -> [a,b]
        decoded = batch * (self.max - self.min) + self.min
        return decoded 
    def cuda(self):
        self.min = self.min.cuda()
        self.max = self.max.cuda()
    def cpu(self):
        self.min = self.min.cpu()
        self.max = self.max.cpu()

In [9]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

In [10]:
class AFNO2D(nn.Module):
    def __init__(self, hidden_size, num_blocks=8, sparsity_threshold=0.01, hard_thresholding_fraction=1, hidden_size_factor=1):
        super().__init__()
        assert hidden_size % num_blocks == 0, f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}"

        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        self.num_blocks = num_blocks
        self.block_size = self.hidden_size // self.num_blocks
        self.hard_thresholding_fraction = hard_thresholding_fraction
        self.hidden_size_factor = hidden_size_factor
        self.scale = 0.02

        self.w1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size * self.hidden_size_factor))
        self.b1 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor))
        self.w2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size * self.hidden_size_factor, self.block_size))
        self.b2 = nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))

    def forward(self, x):
        bias = x

        dtype = x.dtype
        x = x.float()
        B, H, W, C = x.shape

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        x = x.reshape(B, H, W // 2 + 1, self.num_blocks, self.block_size)

        o1_real = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o1_imag = torch.zeros([B, H, W // 2 + 1, self.num_blocks, self.block_size * self.hidden_size_factor], device=x.device)
        o2_real = torch.zeros(x.shape, device=x.device)
        o2_imag = torch.zeros(x.shape, device=x.device)


        total_modes = H // 2 + 1
        kept_modes = int(total_modes * self.hard_thresholding_fraction)

        o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[0]) - \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[1]) + \
            self.b1[0]
        )

        o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes] = F.relu(
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].imag, self.w1[0]) + \
            torch.einsum('...bi,bio->...bo', x[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes].real, self.w1[1]) + \
            self.b1[1]
        )

        o2_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes]  = (
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) - \
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[0]
        )

        o2_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes]  = (
            torch.einsum('...bi,bio->...bo', o1_imag[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[0]) + \
            torch.einsum('...bi,bio->...bo', o1_real[:, total_modes-kept_modes:total_modes+kept_modes, :kept_modes], self.w2[1]) + \
            self.b2[1]
        )

        x = torch.stack([o2_real, o2_imag], dim=-1)
        x = F.softshrink(x, lambd=self.sparsity_threshold)
        x = torch.view_as_complex(x)
        x = x.reshape(B, H, W // 2 + 1, C)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1,2), norm="ortho")
        x = x.type(dtype)

        return x + bias

In [11]:
class Block(nn.Module):
    def __init__(
            self,
            dim,
            mlp_ratio=4.,
            drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            double_skip=True,
            num_blocks=8,
            sparsity_threshold=0.01,
            hard_thresholding_fraction=1.0
        ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.filter = AFNO2D(dim, num_blocks, sparsity_threshold, hard_thresholding_fraction) 
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        #self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.double_skip = double_skip

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = self.filter(x)

        if self.double_skip:
            x = x + residual
            residual = x

        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = x + residual
        return x

In [12]:
class PrecipNet(nn.Module):
    def __init__(self, params, backbone):
        super().__init__()
        self.params = params
        self.patch_size = (params.patch_size, params.patch_size)
        self.in_chans = params.N_in_channels
        self.out_chans = params.N_out_channels
        self.backbone = backbone
        self.ppad = PeriodicPad2d(1)
        self.conv = nn.Conv2d(self.out_chans, self.out_chans, kernel_size=3, stride=1, padding=0, bias=True)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.backbone(x)
        x = self.ppad(x)
        x = self.conv(x)
        x = self.act(x)
        return x

In [13]:
class pointwise_op(nn.Module):
    def __init__(self, in_channel, out_channel,dim1, dim2):
        super(pointwise_op,self).__init__()
        self.conv = nn.Conv2d(int(in_channel), int(out_channel), 1)
        self.dim1 = int(dim1)
        self.dim2 = int(dim2)
        #self.alpha_drop = nn.AlphaDropout(p=0.42)
        #self.layer_norm = nn.LayerNorm([int(out_channel), self.dim1, self.dim2])
        #self.batch_norm = nn.BatchNorm2d(int(out_channel))
        #self.upsample = nn.Upsample(size=[self.dim1, self.dim2], mode='bicubic', align_corners=True)

    def forward(self,x, dim1 = None, dim2 = None):
        if dim1 is None:
            dim1 = self.dim1
            dim2 = self.dim2
        x_out = self.conv(x)
        x_out = torch.nn.functional.interpolate(x_out, size = (dim1, dim2),mode = 'bicubic',align_corners=True)
        #x_out = self.upsample(x_out)
        #x_out = self.alpha_drop(x_out)
        #x_out = self.layer_norm(x_out)
        #x_out = self.batch_norm(x_out)
        return x_out

In [14]:
class AFNONet(nn.Module):
    def __init__(
            self,
            img_size=(720, 1440),
            patch_size=(16, 16),
            in_chans=3,
            out_chans=3,
            embed_dim=768,
            depth=12,
            mlp_ratio=4.,
            drop_rate=0.,
            drop_path_rate=0.,
            num_blocks=16,
            sparsity_threshold=0.01,
            hard_thresholding_fraction=1.0,
        ):
        super().__init__()
        #self.params = params
        self.img_size = img_size
        self.patch_size = (patch_size[0], patch_size[1])
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.num_features = self.embed_dim = embed_dim
        self.num_blocks = num_blocks 
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.h = img_size[0] // self.patch_size[0]
        self.w = img_size[1] // self.patch_size[1]

        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            num_blocks=self.num_blocks, sparsity_threshold=sparsity_threshold, hard_thresholding_fraction=hard_thresholding_fraction) 
        for i in range(depth)])

        self.norm = norm_layer(embed_dim)

        self.head = nn.Linear(embed_dim, self.out_chans*self.patch_size[0]*self.patch_size[1], bias=False)

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        
        x = x.reshape(B, self.h, self.w, self.embed_dim)
        for blk in self.blocks:
            x = blk(x)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        x = rearrange(
            x,
            "b h w (p1 p2 c_out) -> b c_out (h p1) (w p2)",
            p1=self.patch_size[0],
            p2=self.patch_size[1],
            h=self.img_size[0] // self.patch_size[0],
            w=self.img_size[1] // self.patch_size[1],
        )
        return x


class PatchEmbed(nn.Module):
    def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

In [21]:
class NIO_transformer(nn.Module):
    def __init__(self, img_size=(100,100), patch_size=(2,2), in_chans=3, stack_factor=2, num_classes=0,embed_dim=768, depth=4,
                 mlp_ratio=4., width=100, height=100, mixing_type='sa'):
        super().__init__()
        #self.interpolation_l = nn.Conv2d(3, 3, kernel_size=1)
        #self.interpolation_r = nn.Conv2d(3,3,kernel_size=1)
        self.embed_dim = embed_dim
        #self.AF_transformer_red = AFNONet(img_size=img_size, patch_size=patch_size, in_chans=2, out_chans=1, depth=1)
        #self.AF_transformer_green = AFNONet(img_size=img_size, patch_size=patch_size, in_chans=2, out_chans=1, depth=1)
        #self.AF_transformer_blue = AFNONet(img_size=img_size, patch_size=patch_size, in_chans=2, out_chans=1, depth=1)
        self.AF_transformer = AFNONet(img_size=img_size, patch_size=patch_size, in_chans=6, out_chans=3, depth=depth)
        self.height = height
        self.width = width
        self.UNO_block_red = UNO_vanilla(2, 1, pad=0, factor=16/4)
        self.UNO_block_green = UNO_vanilla(2, 1, pad=0, factor=16/4)
        self.UNO_block_blue = UNO_vanilla(2, 1, pad=0, factor=16/4)
        self.UNO_block_merge = UNO(3, 3, pad=0, factor=16/4)
        
        
        """
        self.output_dims = in_chans*width*height
        #self.unflatten = nn.Unflatten(1, (3,16,16))
        self.deconv1 = nn.ConvTranspose2d(3, 3, 3,stride=1, padding=1)
        #self.alpha_drop1 = nn.AlphaDropout(p=0.42)
        #self.batch_norm1 = nn.BatchNorm2d(3)
        self.deconv2 = nn.ConvTranspose2d(3, 3, 4,stride=2, padding=1)
        #self.alpha_drop2 = nn.AlphaDropout(p=0.42)
        #self.batch_norm2 = nn.BatchNorm2d(3)
        self.deconv3 = nn.ConvTranspose2d(3, 3, 2,stride=2, padding=7)
        #self.alpha_drop3 = nn.AlphaDropout(p=0.42)
        #self.batch_norm3 = nn.BatchNorm2d(3)
        self.deconv4 = nn.ConvTranspose2d(3, 3, 2,stride=2, padding=0)
        #self.alpha_drop4 = nn.AlphaDropout(p=0.42)
        #self.batch_norm4 = nn.BatchNorm2d(3)
        self.pointwise1 = pointwise_op(3, 3, self.height, self.width)
        self.pointwise2 = pointwise_op(3, 3, self.height, self.width)
        self.pointwise3 = pointwise_op(3,3, self.height, self.width)
        self.pointwise4 = pointwise_op(3, 3, self.height, self.width)
        """
        
        
    def forward(self, x):
        self.height = x.shape[2]
        self.width = x.shape[3]
        #self.output_dims = self.height*self.width*x.shape[4]
        x_l = x[0]
        x_r = x[1]
        
        x_l_colors = x_l.permute(1,0,2,3)
        x_r_colors = x_r.permute(1,0,2,3)
        
        #print(x_l_colors.shape)
        x_l_red = x_l_colors[0]
        x_l_green = x_l_colors[1]
        x_l_blue = x_l_colors[2]
        
        x_r_red = x_r_colors[0]
        x_r_green = x_r_colors[1]
        x_r_blue = x_r_colors[2]
        
        x_red = torch.stack((x_l_red, x_r_red)).permute(1,0,2,3)
        x_green = torch.stack((x_l_green, x_r_green)).permute(1,0,2,3)
        x_blue = torch.stack((x_l_blue, x_r_blue)).permute(1,0,2,3)
        
        #print(x_red.shape)
        #print(x_green.shape)
        #print(x_blue.shape)

        #x = torch.cat((x_l, x_r), dim=1)
        #x = self.AF_transformer(x)
        #r = self.UNO_block_red(x_red)
        #g = self.UNO_block_green(x_green)
        #b = self.UNO_block_blue(x_blue)
        #uno_x = torch.cat((r,g,b), dim=1)
        
        
        x = torch.cat((x_l, x_r), dim=1)
        transfo_x = self.AF_transformer(x)
        #x_concat = 0.0001*uno_x + transfo_x
        #x_out = self.UNO_block_merge(x_concat)
        #x_concat = transfo_x
        
        """
        x = self.deconv1(x)
        x = self.pointwise1(x, self.height//8, self.width//8)
        #x = torch.nn.functional.interpolate(x, size = (self.output_dims//8, self.output_dims//8),mode = 'bicubic',align_corners=True)
        #x = self.alpha_drop1(x)
        #x = self.batch_norm1(x)
        x = self.deconv2(x)
        x = self.pointwise2(x, self.height//4, self.width//4)
        #x = torch.nn.functional.interpolate(x, size=(self.output_dims//4, self.output_dims//4), mode = 'bicubic',align_corners=True)
        #x = self.alpha_drop2(x)
        #x = self.batch_norm2(x)
        x = self.deconv3(x)
        x = self.pointwise3(x, self.height//2, self.width//2)
        #x = torch.nn.functional.interpolate(x, size=(self.output_dims//2, self.output_dims//2), mode = 'bicubic',align_corners=True)
        #x = self.alpha_drop3(x)
        #x = self.batch_norm3(x)
        x = self.deconv4(x)
        x = self.pointwise4(x, self.height, self.width)
        #x = torch.nn.functional.interpolate(x, size=(self.output_dims, self.output_dims), mode = 'bicubic',align_corners=True)
        #x = self.alpha_drop4(x)
        #x = self.batch_norm4(x)
        #x = self.pointwise(x, self.height, self.width)
        """
        #print(x.shape)
        return transfo_x
    def sub_mean(self, x):
        mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
        x -= mean
        return x, mean
    

In [23]:
#NIO_Af = AFNONet(img_size=100, patch_size=15, in_chans = 3, num_classes=0)
NIO_Af = NIO_transformer(img_size=(256,256))
summary(NIO_Af, input_size=(2,10,  3,256,256))

Layer (type:depth-idx)                   Output Shape              Param #
NIO_transformer                          [10, 3, 256, 256]         --
├─AFNONet: 1-1                           [10, 3, 256, 256]         12,593,664
│    └─PatchEmbed: 2-1                   [10, 16384, 768]          --
│    │    └─Conv2d: 3-1                  [10, 768, 128, 128]       19,200
│    └─Dropout: 2-2                      [10, 16384, 768]          --
│    └─ModuleList: 2-3                   --                        --
│    │    └─Block: 3-2                   [10, 128, 128, 768]       4,876,032
│    │    └─Block: 3-3                   [10, 128, 128, 768]       4,876,032
│    │    └─Block: 3-4                   [10, 128, 128, 768]       4,876,032
│    │    └─Block: 3-5                   [10, 128, 128, 768]       4,876,032
├─UNO_vanilla: 1-2                       --                        --
│    └─Linear: 2-4                       --                        3
│    └─SpectralConv2d: 2-5               --   

In [24]:
import cv2
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
import numpy as np
import random
def LoadDataBatches(batchNumber, batches=4, isNormalized=True, isTrain=True, percent=1):
    if(isTrain):
        x_train = torch.load('/scratch/gilbreth/hviswan/x_train_vimeo_256_diff_type.pt')
        y_train = torch.load('/scratch/gilbreth/hviswan/y_train_vimeo_256_diff_type.pt')
        ntrain = x_train.shape[0]
        
        ntrain_slice = int(ntrain*percent)
        if(ntrain_slice < ntrain-30):
            start_index = min(0, random.randint(0, ntrain-ntrain_slice-30))
        else:
            start_index = 0
        x_train = x_train[0+start_index:start_index+ntrain_slice]
        x_train = torch.permute(x_train, (1, 0,2,3,4))
        print(x_train.shape)
        #ntrain = x_train.shape[1]
        y_train = y_train.reshape(ntrain, x_train.shape[2], x_train.shape[3], x_train.shape[4])
        y_train = y_train[0+start_index:ntrain_slice+start_index]
        print(x_train.shape)
        print(y_train.shape)
        if(isNormalized):
            y_normalizer = MinMaxNormalizer(y_train)
            x_normalizer = MinMaxNormalizer(x_train)
            x_train = x_normalizer.encode(x_train)
            y_train = y_normalizer.encode(y_train)
        train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train[0], x_train[1], y_train), batch_size=batches, shuffle=True)
        if(isNormalized):
            return train_loader, y_normalizer
        return train_loader
    else:
        x_test = torch.load('/scratch/gilbreth/hviswan/x_test_vimeo_256_diff_type.pt')
        y_test = torch.load('/scratch/gilbreth/hviswan/y_test_vimeo_256_diff_type.pt')
        ntest = x_test.shape[0]
        #x_test = x_test[0: 500]
        x_test = torch.permute(x_test, (1, 0,2,3,4))
        
        y_test = y_test.reshape(ntest, x_test.shape[2], x_test.shape[3], x_test.shape[4])
        #y_test = y_test[0:500]
        print(x_test.shape)
        print(y_test.shape)
        if(isNormalized):
            y_normalizer = MinMaxNormalizer(y_test)
            x_normalizer = MinMaxNormalizer(x_test)
            x_test = x_normalizer.encode(x_test)
            y_test = y_normalizer.encode(y_test)
        test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test[0], x_test[1], y_test), batch_size=batches, shuffle=True)
        if(isNormalized):
            return test_loader, y_normalizer
        return test_loader
#LoadDataBatches(2, batches=10, isNormalized=False, percent=1)
#LoadDataBatches(2, batches=10, isNormalized=False, percent=1, isTrain=False)

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class MeanShift(nn.Conv2d):
    def __init__(self, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = False

class VGG(nn.Module):
    def __init__(self, loss_type):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m.cuda() for m in vgg_features]
        conv_index = loss_type
        if conv_index == '22':
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index == '33':
            self.vgg = nn.Sequential(*modules[:16])
        elif conv_index == '44':
            self.vgg = nn.Sequential(*modules[:26])
        elif conv_index == '54':
            self.vgg = nn.Sequential(*modules[:35])
        elif conv_index == 'P':
            self.vgg = nn.ModuleList([
                nn.Sequential(*modules[:8]),
                nn.Sequential(*modules[8:16]),
                nn.Sequential(*modules[16:26]),
                nn.Sequential(*modules[26:35])
            ])
        self.vgg = nn.DataParallel(self.vgg).cuda()

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229, 0.224, 0.225)
        self.sub_mean = MeanShift(vgg_mean, vgg_std)
        self.vgg.requires_grad = False
        # self.criterion = nn.L1Loss()
        self.conv_index = conv_index

    def forward(self, sr, hr):
        def _forward(x):
            x = x.cpu()
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x.cuda()
        def _forward_all(x):
            feats = []
            x = x.cpu()
            x = self.sub_mean(x)
            for module in self.vgg.module:
                x = module(x.cuda())
                feats.append(x.cuda())
            return feats

        if self.conv_index == 'P':
            vgg_sr_feats = _forward_all(sr)
            with torch.no_grad():
                vgg_hr_feats = _forward_all(hr.detach())
            loss = 0
            for i in range(len(vgg_sr_feats)):
                loss_f = F.mse_loss(vgg_sr_feats[i], vgg_hr_feats[i])
                #print(loss_f)
                loss += loss_f
            #print()
        else:
            vgg_sr = _forward(sr)
            with torch.no_grad():
                vgg_hr = _forward(hr.detach())
            loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss

In [30]:
from datetime import datetime as time
from skimage.metrics import structural_similarity as ssim
import torchvision.transforms.functional as FF
import cv2
import PIL
myloss = torch.nn.MSELoss()
l1loss = torch.nn.L1Loss()
kldivloss = torch.nn.KLDivLoss()
vggloss = VGG('22')
RHO = 0.05
BETA = 0.01
def train_NIO_transformer(D, G, train_data, epochs, D_optim, G_optim, scheduler=None):
    losses_D = np.zeros(epochs)
    losses_G = np.zeros(epochs)
    ssims_G = np.zeros(epochs)
    #train_data = LoadDataBatches(0, isNormalized=False, batches=10, isTrain=True, percent=0.2)
    train_data = LoadDataBatches(0, isNormalized=False, batches=2, isTrain=True)
    #base_G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_vgg_rgb150.pt')
    #complex_G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_huge_rgb1600.pt')
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(G_optim, factor=0.2, patience=5, mode='min')
    for i in range(1, epochs+1):
        t1 = time.now()
        loss_D = 0.0
        loss_G = 0.0
        ssim_G = 0.0
        train_counter = 0
        ssim_batch=0
        for j in range(1):
            #del train_data
            #train_data = LoadDataBatches(j, batches=10)
            for xl,xr,y in train_data:
                train_counter += 1
                temp = 0.05*xl + 0.1*xr
                
                xl = xl.cpu().detach().numpy()
                y = y.cpu().detach().numpy()
                #xr = xr.cpu().detach().numpy()
                x = torch.from_numpy(np.asarray([xl,y]).astype(np.float32))
                xr = xr.cuda()
                x = x.cuda()
                #y = y.cuda()
                G_optim.zero_grad()
                #base_x_out = base_G(x).cuda()
                #base_x_out = torch.permute(base_x_out, (0,3, 1, 2))
                #x_avg = torch.from_numpy(np.asarray((xl+xr)/2.0).astype(np.float32)).cuda()
                #x = torch.stack((base_x_out, base_x_out))
                #complex_x_out = complex_G(x).reshape(x.shape[1], x.shape[2],x.shape[3],x.shape[4])
                #x = torch.stack((complex_x_out, complex_x_out))
                #x = torch.permute(complex_x_out, (0, 3, 1, 2))
                #xl_in = torch.permute(xl_in, (0,3,1,2))
                #base_x_gen = base_G(x).permute(0,3,1,2)
                x_syn = G(x)
                #x_syn = x_syn #+ 0.001*base_x_gen
                #x_syn_colors = torch.permute(x_syn, (1,0,2,3))
                #x_r = x_syn_colors[0]
                #x_b = x_syn_colors[2]
                #x_g = x_syn_colors[1]
                
                #x_syn = torch.cat((x_r, x_g, x_b), dim=1)
                #print(x_syn.shape)
                #x_syn = torch.permute(x_syn, (0, 2, 3, 1))
                #x_syn = x_syn + temp.cuda()
                #x_syn = y_normalizer.decode(x_syn)
                #y = y_normalizer.decode(y)
                #x_syn2 = torch.permute(x_syn, (0, 3, 1, 2)).type(torch.float32)
                #x_syn2 = x_syn2.detach().cpu().numpy()
                #for k in range(x_syn2.shape[0]):
                    #x_syn2[k] = FF.adjust_saturation(x_syn2[k], 1.5)
                #x_syn2 = torch.from_numpy(x_syn2.astype(np.float32))
                #x_syn2 = x_syn
                y2 = torch.permute(xr, (0, 2, 3,1)).type(torch.float32)
                #y2_colors = torch.permute(y2, (1,0,2,3))
                #y2_r = y2_colors[0]
                #y2_g = y2_colors[1]
                #y2_b = y2_colors[2]
                
                x_syn2 = torch.permute(x_syn, (0, 2, 3, 1))
                y = xr.reshape(xr.shape[0], xr.shape[1],xr.shape[2],xr.shape[3])
               # W_loss = l1loss(x_r.type(torch.float32), y2_r.type(torch.float32)) + l1loss(x_g.type(torch.float32), y2_g.type(torch.float32))+ l1loss(x_b.type(torch.float32), y2_b.type(torch.float32))
                
                W_loss = l1loss(x_syn.type(torch.float32), y.type(torch.float32)) #+ 0.001*vggloss(x_syn.type(torch.float32), y2.type(torch.float32))#+ 0.01*torch.abs(kldivloss(x_syn2.type(torch.float32), y2.type(torch.float32)))
                ssim_value = W_loss.item()
                ssim_individual = 0.0
                for k in range(y.shape[0]):
                    ssim_individual += ssim(x_syn2[k].detach().cpu().numpy(), y2[k].detach().cpu().numpy(), multichannel=True)
                
                ssim_batch += ssim_individual/y.shape[0]
                loss = W_loss
                loss.backward()
                loss_G += loss.item() 
                torch.nn.utils.clip_grad_norm_(G.parameters(), 0.1)
                G_optim.step()
                #scheduler.step()
                if(train_counter %100 == 0 or train_counter == 1):
                    print("Batch = ",train_counter," Train Loss = ", loss_G/train_counter, " SSIM = ", ssim_batch/(train_counter))
            losses_D[i] = loss_D / train_counter
            losses_G[i] = loss_G / train_counter
            ssims_G[i] = ssim_batch / train_counter
            t2 = time.now()
            test_counter = 1
            """
            test_counter = 0
            ssim_batch = 0
            for xl, xr, y in validation_data:
                temp = 0.05*xl + 0.1*xr
                xl = xl.cpu().detach().numpy()
                xr = xr.cpu().detach().numpy()
                x = torch.from_numpy(np.asarray([xl,xr]).astype(np.float32))
                x = x.cuda()
                y = y.cuda()
                x_syn = G(x).cuda()
                x_syn = x_syn + temp.cuda()
                x_syn = y_normalizer_test.decode(x_syn)
                y = y_normalizer_test.decode(y)
                x_syn2 = torch.permute(x_syn, (0, 3, 1, 2)).type(torch.float32)
                y2 = torch.permute(y, (0, 3, 1,2)).type(torch.float32)
                y = y.reshape(y.shape[0], y.shape[1],y.shape[2],y.shape[3])
                W_loss = huberloss(x_syn2.type(torch.float32), y2.type(torch.float32))*10000
                ssim_individual = 0.0
                for k in range(y.shape[0]):
                    ssim_individual += ssim(x_syn[k].detach().cpu().numpy(), y[k].detach().cpu().numpy(), multichannel=True)
                
                ssim_batch += ssim_individual/y.shape[0]
                loss = W_loss
                loss.backward()
                G_optim.step()
                loss_D += loss.item()
                test_counter += 1
                if(test_counter %50 == 0 or test_counter == 1):
                    print("Batch = ",test_counter," Validation Loss = ", loss_D/test_counter, " SSIM = ", ssim_batch/(test_counter))
            losses_D[i] = loss_D/test_counter
            """
            print("Loader #: ", j, " Epoch: ", i, "/ ",epochs," -  Time: ", t2-t1, "s - Validation Loss: ", losses_D[i], " - Train Loss: ", losses_G[i], " - train SSIM: ", ssims_G[i], " validation SSIM: ", ssim_batch/test_counter)
        if(i%20 == 0):
            torch.save(G, '/depot/bera89/data/hviswan/NIO_transformer'+str(i)+'.pt')
            
    return losses_D, losses_G, D, G

In [31]:
from types import SimpleNamespace
from timm.optim import create_optimizer
#from apex import optimizers
args = SimpleNamespace()
args.opt = 'adamw'
args.lr = 1e-4
args.weight_decay=1e-4 
args.opt_eps = 1e-8
args.momentum = 0.9
optimizer = create_optimizer(args, NIO_Af)
epochs = 300
NIO_Af.train()
device = 'cuda:11'

G_optimizer = torch.optim.Adam(NIO_Af.parameters(), lr=1e-3, weight_decay=1e-4)
losses_D, losses_G, D, G = train_NIO_transformer(None, NIO_Af, None, epochs, None, G_optimizer)

torch.Size([2, 4400, 3, 256, 256])
torch.Size([2, 4400, 3, 256, 256])
torch.Size([4400, 3, 256, 256])


  ssim_individual += ssim(x_syn2[k].detach().cpu().numpy(), y2[k].detach().cpu().numpy(), multichannel=True)


Batch =  1  Train Loss =  1.9852361679077148  SSIM =  -2.3378030164167285e-05
Batch =  100  Train Loss =  2.0262391632795334  SSIM =  0.002629652093024788
Batch =  200  Train Loss =  1.281402186602354  SSIM =  0.012921692827115408
Batch =  300  Train Loss =  0.9808120762308439  SSIM =  0.027834878926407403
Batch =  400  Train Loss =  0.7877435138449073  SSIM =  0.055349041271903446
Batch =  500  Train Loss =  0.6662572989314794  SSIM =  0.08037478417458738
Batch =  600  Train Loss =  0.5761757799734671  SSIM =  0.1190946121011145
Batch =  700  Train Loss =  0.5109575137283121  SSIM =  0.1488050001748091
Batch =  800  Train Loss =  0.45988248120527714  SSIM =  0.17990669568955653
Batch =  900  Train Loss =  0.41930727440863846  SSIM =  0.20561411690754805
Batch =  1000  Train Loss =  0.385874004624784  SSIM =  0.2315097665894906
Batch =  1100  Train Loss =  0.35818545007908886  SSIM =  0.2597158741371993
Batch =  1200  Train Loss =  0.3350719873793423  SSIM =  0.28280007028611837
Batch 

Batch =  500  Train Loss =  0.04906866979412734  SSIM =  0.7444135282337666
Batch =  600  Train Loss =  0.048617805185106895  SSIM =  0.7491575050602357
Batch =  700  Train Loss =  0.04844485450402967  SSIM =  0.7494508224938597
Batch =  800  Train Loss =  0.04836893605301157  SSIM =  0.7497583018429578
Batch =  900  Train Loss =  0.04833603214679493  SSIM =  0.7502779449853633
Batch =  1000  Train Loss =  0.04828092296794057  SSIM =  0.7512542942464352
Batch =  1100  Train Loss =  0.04839171094819903  SSIM =  0.7511204634877768
Batch =  1200  Train Loss =  0.04854441503528505  SSIM =  0.7505928872028986
Batch =  1300  Train Loss =  0.04859428480554086  SSIM =  0.7499961943580554
Batch =  1400  Train Loss =  0.04858346268268568  SSIM =  0.7499245857154685
Batch =  1500  Train Loss =  0.04847074403241277  SSIM =  0.7503753724868099
Batch =  1600  Train Loss =  0.04843471724772826  SSIM =  0.750532159244176
Batch =  1700  Train Loss =  0.048252997143084515  SSIM =  0.7515128518180813
Bat

Batch =  1000  Train Loss =  0.04716753929434344  SSIM =  0.7535637502186
Batch =  1100  Train Loss =  0.04709193829574029  SSIM =  0.754429669634185
Batch =  1200  Train Loss =  0.04696562137726384  SSIM =  0.7550901977562656
Batch =  1300  Train Loss =  0.04719817934092134  SSIM =  0.754240158506884
Batch =  1400  Train Loss =  0.047272380946669725  SSIM =  0.7537545837328903
Batch =  1500  Train Loss =  0.04723288032133132  SSIM =  0.7539837591523926
Batch =  1600  Train Loss =  0.04727843918808503  SSIM =  0.7538343140739017
Batch =  1700  Train Loss =  0.047214366791715076  SSIM =  0.7540796475967063
Batch =  1800  Train Loss =  0.04711560657894653  SSIM =  0.7549220788230498
Batch =  1900  Train Loss =  0.047074228511682074  SSIM =  0.7552175739956529
Batch =  2000  Train Loss =  0.047166847471380606  SSIM =  0.7547406126633287
Batch =  2100  Train Loss =  0.04705287262863879  SSIM =  0.7553785379869598
Batch =  2200  Train Loss =  0.04708786593940617  SSIM =  0.7553873244131153


Batch =  1500  Train Loss =  0.04711488491296768  SSIM =  0.7564963608980179
Batch =  1600  Train Loss =  0.047202464184956626  SSIM =  0.7557568147778511
Batch =  1700  Train Loss =  0.047274982584092545  SSIM =  0.7549527084564461
Batch =  1800  Train Loss =  0.047169085790713626  SSIM =  0.7555526764276955
Batch =  1900  Train Loss =  0.04694596687380813  SSIM =  0.7566670762551458
Batch =  2000  Train Loss =  0.04690125779155642  SSIM =  0.7569154429063201
Batch =  2100  Train Loss =  0.04701116324269346  SSIM =  0.7556509350843372
Batch =  2200  Train Loss =  0.046987179048698056  SSIM =  0.756076258431104
Loader #:  0  Epoch:  13 /  300  -  Time:  0:05:54.868749 s - Validation Loss:  0.0  - Train Loss:  0.046987179048698056  - train SSIM:  0.756076258431104  validation SSIM:  1663.367768548429
Batch =  1  Train Loss =  0.0231393501162529  SSIM =  0.8768608570098877
Batch =  100  Train Loss =  0.04668420023284853  SSIM =  0.7544915230944753
Batch =  200  Train Loss =  0.0458101103

KeyboardInterrupt: 