# GANO trained on GRF data on a 2D domain

The inputs to the generative model are samples of a Gaussian random field (GRF).




In [1]:
import torch
import numpy as np
import pylab as plt
import torch.nn.functional as F
import torch.nn as nn
from torchinfo import summary
#from random_fields import *

  from .autonotebook import tqdm as notebook_tqdm


#### Parameters

In [2]:
# ---- Training hyper-parameters ----
res = 56            # spatial resolution (64 - 8)
npad = 4            # padding used by U-NO (ratio 8/128)
ntrain = 5156       # number of training samples
ntest = 1290        # number of testing samples
modes = 50          # number of Fourier modes in the initial FNO layer
d_co_domain = 16    # dimension of the co-domain of the initial U-NO layer
lr = 2e-4           # optimizer learning rate
device = 'cuda:11'
epochs = 2000
λ_grad = 10.0       # Lagrange coefficient of the gradient penalty
n_critic = 5        # the generator is updated once every n_critic iterations
batch_size = 200

In [3]:
# Gaussian normalization with a single global mean / std (computed over every
# element of the fitting tensor, not pointwise).
class InputNormalizer(object):
    """Standardize tensors as ``(x - mean) / (std + eps)``.

    ``mean`` and ``std`` are computed once, over all elements of the tensor
    given to the constructor, so they are 0-dim tensors that broadcast
    against any input. ``decode`` is the inverse of ``encode``.

    Args:
        x: tensor used to fit the statistics.
        eps: small constant added to ``std`` to avoid division by zero.
    """

    def __init__(self, x, eps=0.00001):
        super(InputNormalizer, self).__init__()

        # Global (scalar) statistics. The former trailing ``.mean(dim=0)`` was
        # a no-op on these 0-dim tensors and has been removed.
        self.mean = torch.mean(x)
        self.std = torch.std(x)
        self.eps = eps

    def encode(self, x):
        """Return ``x`` standardized with the fitted statistics."""
        return (x - self.mean) / (self.std + self.eps)

    def decode(self, x, sample_idx=None):
        """Map standardized ``x`` back to the original scale.

        Args:
            x: standardized tensor.
            sample_idx: optional index selecting per-position statistics when
                ``mean``/``std`` are not scalars (e.g. replaced externally by
                pointwise statistics). Ignored for scalar statistics, which
                broadcast over any ``x`` (indexing a 0-dim tensor would fail).

        Raises:
            ValueError: if ``sample_idx`` is given for non-scalar statistics
                whose rank matches neither supported layout.
        """
        if sample_idx is None or self.mean.dim() == 0:
            std = self.std + self.eps  # n
            mean = self.mean
        elif self.mean.dim() == sample_idx[0].dim():
            std = self.std[sample_idx] + self.eps  # batch*n
            mean = self.mean[sample_idx]
        elif self.mean.dim() > sample_idx[0].dim():
            std = self.std[:, sample_idx] + self.eps  # T*batch*n
            mean = self.mean[:, sample_idx]
        else:
            # Previously this fell through and raised UnboundLocalError.
            raise ValueError(
                "sample_idx rank does not match the normalizer statistics"
            )

        # x is in shape of batch*n or T*batch*n
        return (x * std) + mean

    def cuda(self):
        """Move the statistics to the notebook-level ``device``."""
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)

    def cpu(self):
        """Move the statistics to the CPU."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
class MinMaxNormalizer:
    """Affine min-max scaling of a tensor to [0, 1] and back.

    The minimum and maximum are taken over all elements of ``batch``.

    Args:
        batch: tensor used to fit the statistics (also kept as ``self.batch``).
    """

    def __init__(self, batch):
        self.batch = batch
        self.min = batch.min()
        self.max = batch.max()

    def _span(self):
        """Return ``max - min``, replaced by 1 when the fitted data is constant.

        Without this guard a constant tensor is encoded as 0/0 = NaN.
        """
        span = self.max - self.min
        if span == 0:
            span = torch.ones_like(span)
        return span

    def encode(self, batch):  # min max normalization -> [0,1]
        """Scale ``batch`` so the fitted [min, max] maps to [0, 1]."""
        return (batch - self.min) / self._span()

    def decode(self, batch):  # returns to standard values -> [a,b]
        """Invert ``encode``: map [0, 1] back to the fitted [min, max]."""
        return batch * self._span() + self.min

    def cuda(self):
        """Move the fitted statistics to the default CUDA device."""
        self.min = self.min.cuda()
        self.max = self.max.cuda()

    def cpu(self):
        """Move the fitted statistics to the CPU."""
        self.min = self.min.cpu()
        self.max = self.max.cpu()

def compute_acovf(z):
    """Radially binned autocovariance of a batch of 2-D fields.

    The autocovariance is computed via FFT (|FFT(z)|^2 transformed back),
    centred so zero lag sits at index ``n // 2`` of each spatial axis,
    averaged over the leading dimension and normalised by ``z[0].numel()``.

    Args:
        z: tensor whose last two dimensions are the spatial grid and whose
            first dimension is averaged over (presumably (batch, H, W) —
            TODO confirm for other ranks).

    Returns:
        ``(bin_left_edges, bin_means)``: left edges of 49 equal-width radial
        bins spanning [0, max(H, W)] and the mean autocovariance per bin
        (NaN for empty bins).
    """
    from scipy.stats import binned_statistic

    # Use the field's own size instead of the notebook-global ``res``.
    n_x, n_y = z.shape[-2], z.shape[-1]
    z_hat = torch.fft.rfft2(z)
    # ``s=`` is required: without it irfft2 returns an even width and odd
    # sized grids lose a column.
    acf = torch.fft.irfft2(torch.conj(z_hat) * z_hat, s=(n_x, n_y))
    # Shift only the spatial axes (the old call also rolled the batch axis).
    acf = torch.fft.fftshift(acf, dim=(-2, -1)).mean(dim=0) / z[0].numel()
    acf_r = acf.reshape(-1).cpu().detach().numpy()

    lags_x, lags_y = torch.meshgrid(
        torch.arange(n_x) - n_x // 2,
        torch.arange(n_y) - n_y // 2,
        indexing="ij",
    )
    lags_r = torch.sqrt((lags_x**2 + lags_y**2).float()).reshape(-1).numpy()

    # binned_statistic does not need sorted input.
    bin_means, bin_edges, _ = binned_statistic(
        lags_r, acf_r, 'mean', bins=np.linspace(0.0, max(n_x, n_y), 50)
    )
    return bin_edges[:-1], bin_means

In [4]:
import cv2
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
import numpy as np
import random
def LoadDataBatches(batchNumber, batches=4, isNormalized=True, isTrain=True, percent=1):
    """Load the Vimeo frame tensors from disk and wrap them in a DataLoader.

    The stored ``x`` tensor is permuted with ``(1, 0, 2, 3, 4)`` so that
    ``x[0]`` and ``x[1]`` become the two per-sample inputs of the
    ``TensorDataset`` (presumably the two neighbouring frames — TODO confirm).
    ``y`` is reshaped to ``(N, *x.shape[2:])``.

    Args:
        batchNumber: unused; kept for call-site compatibility.
        batches: DataLoader batch size.
        isNormalized: min-max normalize ``x`` and ``y`` (independently) to [0, 1].
        isTrain: load the training split (True) or the test split (False).
        percent: fraction of the training split to use; a random contiguous
            window of that size is taken (ignored for the test split).

    Returns:
        ``(loader, y_normalizer)`` when ``isNormalized`` else ``loader``.
    """
    if isTrain:
        x_train = torch.load('/scratch/gilbreth/hviswan/x_train_vimeo_256_diff_type.pt')
        y_train = torch.load('/scratch/gilbreth/hviswan/y_train_vimeo_256_diff_type.pt')
        ntrain = x_train.shape[0]

        ntrain_slice = int(ntrain * percent)
        # Random window that stays 30 samples clear of the end. The previous
        # ``min(0, random.randint(...))`` always evaluated to 0, so the random
        # offset was never applied.
        if ntrain_slice < ntrain - 30:
            start_index = random.randint(0, ntrain - ntrain_slice - 30)
        else:
            start_index = 0
        x_train = x_train[start_index:start_index + ntrain_slice]
        x_train = torch.permute(x_train, (1, 0, 2, 3, 4))
        print(x_train.shape)
        # y is reshaped with the full count before slicing the same window.
        y_train = y_train.reshape(ntrain, x_train.shape[2], x_train.shape[3], x_train.shape[4])
        y_train = y_train[start_index:start_index + ntrain_slice]
        print(x_train.shape)
        print(y_train.shape)
        return _loader_from_tensors(x_train, y_train, batches, isNormalized)

    x_test = torch.load('/scratch/gilbreth/hviswan/x_test_vimeo_256_diff_type.pt')
    y_test = torch.load('/scratch/gilbreth/hviswan/y_test_vimeo_256_diff_type.pt')
    ntest = x_test.shape[0]
    x_test = torch.permute(x_test, (1, 0, 2, 3, 4))
    y_test = y_test.reshape(ntest, x_test.shape[2], x_test.shape[3], x_test.shape[4])
    print(x_test.shape)
    print(y_test.shape)
    return _loader_from_tensors(x_test, y_test, batches, isNormalized)


def _loader_from_tensors(x, y, batches, isNormalized):
    """Optionally min-max normalize ``(x, y)`` and build a shuffled DataLoader over ``(x[0], x[1], y)``."""
    y_normalizer = None
    if isNormalized:
        # NOTE(review): each split fits its own statistics, so test data is
        # not scaled with the training min/max.
        y_normalizer = MinMaxNormalizer(y)
        x_normalizer = MinMaxNormalizer(x)
        x = x_normalizer.encode(x)
        y = y_normalizer.encode(y)
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x[0], x[1], y),
        batch_size=batches,
        shuffle=True,
    )
    if isNormalized:
        return loader, y_normalizer
    return loader
#LoadDataBatches(2, batches=10, isNormalized=False, percent=1)
#LoadDataBatches(2, batches=10, isNormalized=False, percent=1, isTrain=False)

In [5]:
def kernel(in_chan=2, up_dim=32):
    """
        Build the point-wise kernel network applied on grid coordinates.

        Architecture: Linear(in_chan -> up_dim) -> GELU -> Linear(up_dim -> up_dim)
        -> GELU -> Linear(up_dim -> 1, no bias).
    """
    stages = [
        nn.Linear(in_chan, up_dim, bias=True),
        torch.nn.GELU(),
        nn.Linear(up_dim, up_dim, bias=True),
        torch.nn.GELU(),
        nn.Linear(up_dim, 1, bias=False),
    ]
    return nn.Sequential(*stages)

### Generator operator and Discriminator functional

In [6]:
class SpectralConv2d(nn.Module):
    """2D Fourier layer: FFT, learned mixing of retained low modes, inverse FFT.

    Args:
        in_channels, out_channels: channel counts (cast to int).
        dim1, dim2: default output spatial size; overwritten whenever
            ``forward`` is called with explicit sizes.
        modes1, modes2: Fourier modes kept per axis. When ``modes1`` is None
            they default to ``dim1 // 2 + 1`` and ``dim2 // 2``.
    """

    def __init__(self, in_channels, out_channels, dim1, dim2,modes1 = None, modes2 = None):
        super().__init__()
        self.in_channels = in_channels = int(in_channels)
        self.out_channels = out_channels = int(out_channels)
        # Output spatial dimensions.
        self.dim1 = dim1
        self.dim2 = dim2

        if modes1 is None:
            # Take the largest number of modes available for this grid.
            self.modes1 = dim1 // 2 + 1
            self.modes2 = dim2 // 2
        else:
            self.modes1 = modes1  # at most floor(N/2) + 1
            self.modes2 = modes2

        self.scale = (1 / (2 * in_channels)) ** 0.5
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        self.weights1 = nn.Parameter(self.scale * torch.randn(*weight_shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.randn(*weight_shape, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        """Channel mixing per mode: (b, i, x, y) x (i, o, x, y) -> (b, o, x, y)."""
        return torch.einsum('bixy,ioxy->boxy', input, weights)

    def forward(self, x, dim1 = None,dim2 = None):
        # Explicit output sizes are remembered for subsequent calls.
        if dim1 is not None:
            self.dim1, self.dim2 = dim1, dim2
        m1, m2 = self.modes1, self.modes2

        spectrum = torch.fft.rfft2(x)
        mixed = torch.zeros(
            x.shape[0], self.out_channels, self.dim1, self.dim2 // 2 + 1,
            dtype=torch.cfloat, device=x.device,
        )
        # Positive-frequency rows first, then negative-frequency rows.
        for rows, weight in ((slice(None, m1), self.weights1),
                             (slice(-m1, None), self.weights2)):
            mixed[:, :, rows, :m2] = self.compl_mul2d(spectrum[:, :, rows, :m2], weight)

        # Back to physical space at the requested resolution.
        return torch.fft.irfft2(mixed, s=(self.dim1, self.dim2))


class pointwise_op(nn.Module):
    """1x1 convolution followed by bicubic resampling onto a target grid.

    Args:
        in_channel, out_channel: channel counts (cast to int).
        dim1, dim2: default output spatial size (cast to int), used when
            ``forward`` receives no explicit size.
    """

    def __init__(self, in_channel, out_channel,dim1, dim2):
        super().__init__()
        self.conv = nn.Conv2d(int(in_channel), int(out_channel), 1)
        self.dim1, self.dim2 = int(dim1), int(dim2)

    def forward(self,x, dim1 = None, dim2 = None):
        target = (self.dim1, self.dim2) if dim1 is None else (dim1, dim2)
        projected = self.conv(x)
        return torch.nn.functional.interpolate(
            projected, size=target, mode='bicubic', align_corners=True
        )

class UNO(nn.Module):
    """U-shaped neural operator (U-NO) that fuses two channel-last RGB frames.

    1. Each frame goes through the shared 1x1 conv ``self.interpolation``; the
       two results are summed and lifted to ``d_co_domain`` channels by
       ``self.fc0``.
    2. An encoder/decoder of integral operators u' = (W + K)(u), with W a
       ``pointwise_op`` (``self.w*``) and K a ``SpectralConv2d``
       (``self.conv*``). Decoder stages concatenate the matching encoder
       activations on the channel axis (skip connections).
    3. ``self.fc1`` and ``self.fc2`` project back to 3 channels.

    Args:
        in_d_co_domain: input width of ``fc0``; must be 3 because ``fc0`` is
            applied to the 3-channel output of ``self.interpolation``.
        d_co_domain: width of the lifted representation.
        pad: zero padding added after both spatial axes, cropped again before
            the projection.
        factor: channel multiplier of the inner layers; ``factor / 4`` also
            scales the spatial size of the first and last inner stages.

    Input:  ``x`` indexable as ``x[0]``, ``x[1]``, each (batch, H, W, 3).
    Output: (batch, H, W, 3), channel-last.

    NOTE(review): the VAE layers below are constructed but unused because the
    bottleneck in ``forward`` is disabled. They are kept so state_dict keys and
    the parameter-initialisation order stay unchanged.
    """

    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 16/4):
        super(UNO, self).__init__()

        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain
        self.factor = factor
        self.factor2 = factor/4
        self.padding = pad  # pad the domain if input is non-periodic
        self.reduceColor = nn.Conv2d(3, 1, kernel_size=1)
        self.interpolation = nn.Conv2d(3, 3, kernel_size=1)

        self.fc0 = nn.Linear(self.in_d_co_domain, self.d_co_domain)

        self.conv0 = SpectralConv2d(self.d_co_domain, 4*factor*self.d_co_domain, 16, 16, 32, 32)

        self.conv1 = SpectralConv2d(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16, 26,26)

        self.conv2 = SpectralConv2d(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,13,13)

        self.conv2_1 = SpectralConv2d(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 4, 4,7,7)

        self.conv2_9 = SpectralConv2d(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,7,7)

        self.conv3 = SpectralConv2d(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16,13,13)

        self.conv4 = SpectralConv2d(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 32, 32,26,26)

        self.conv5 = SpectralConv2d(8*factor*self.d_co_domain, self.d_co_domain, 48, 48,32,32) # will be reshaped

        self.w0 = pointwise_op(self.d_co_domain,4*factor*self.d_co_domain,75, 75)

        self.w1 = pointwise_op(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50)

        self.w2 = pointwise_op(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)

        self.w2_1 = pointwise_op(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 12, 12)

        #VAE Code begin (unused while the bottleneck in forward is disabled)
        self.linearFlatten1 = nn.Linear(384*12*12, 128)
        self.linearFlatten2 = nn.Linear(128, 4)
        self.linearFlatten3 = nn.Linear(128, 4)
        self.softmax1 = nn.Softmax(dim=1)
        self.N = torch.distributions.Normal(0, 10)
        # Previously moved to CUDA unconditionally, which made construction
        # fail on CPU-only machines.
        if torch.cuda.is_available():
            self.N.loc = self.N.loc.cuda()
            self.N.scale = self.N.scale.cuda()
        self.kl = 0
        self.linearUnflatten = nn.Linear(4, 128)
        self.relu = nn.ReLU(True)
        self.linearUnflatten2 = nn.Linear(128, 384*12*12)
        self.relu2 = nn.ReLU(True)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(384, 12, 12))
        #VAE Code End

        self.w2_9 = pointwise_op(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)

        self.w3 = pointwise_op(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50)

        self.w4 = pointwise_op(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 75, 75)

        self.w5 = pointwise_op(8*factor*self.d_co_domain, self.d_co_domain, 100, 100) # will be reshaped

        self.fc1 = nn.Linear(2*self.d_co_domain, 4*self.d_co_domain)
        self.fc2 = nn.Linear(4*self.d_co_domain, 3)
        self.increaseColor = nn.Conv2d(1, 3, kernel_size=1)

    def forward(self, x):
        # NOTE: the coordinate grid from get_grid was computed here but never
        # used (and forced a CUDA device); the call has been removed.

        # Frames arrive channel-last; the 1x1 convs need channel-first.
        x_l = x[0].permute(0,3,1,2)
        x_r = x[1].permute(0,3,1,2)
        x1 = self.interpolation(x_l)
        x2 = self.interpolation(x_r)
        x = x1 + x2
        # Back to channel-last so fc0 acts on the channel axis.
        x = x.permute(0,2,3,1)
        x_fc0 = self.fc0(x)

        x_fc0 = F.gelu(x_fc0)

        x_fc0 = x_fc0.permute(0, 3, 1, 2)

        x_fc0 = F.pad(x_fc0, [0,self.padding, 0,self.padding])

        D1,D2 = x_fc0.shape[-2],x_fc0.shape[-1]

        # Encoder
        x1_c0 = self.conv0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x2_c0 = self.w0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x_c0 = x1_c0 + x2_c0
        x_c0 = F.gelu(x_c0)

        x1_c1 = self.conv1(x_c0 ,D1//2,D2//2)
        x2_c1 = self.w1(x_c0 ,D1//2,D2//2)
        x_c1 = x1_c1 + x2_c1
        x_c1 = F.gelu(x_c1)

        x1_c2 = self.conv2(x_c1 ,D1//4,D2//4)
        x2_c2 = self.w2(x_c1 ,D1//4,D2//4)
        x_c2 = x1_c2 + x2_c2
        x_c2 = F.gelu(x_c2 )

        x1_c2_1 = self.conv2_1(x_c2,D1//8,D2//8)
        x2_c2_1 = self.w2_1(x_c2,D1//8,D2//8)
        x_c2_1 = x1_c2_1 + x2_c2_1
        x_c2_1 = F.gelu(x_c2_1)
        # Disabled VAE bottleneck (string literal, not executed).
        """
        #Variational Autoencoder part - Extract mu and sigma
        x_c2_1 = torch.flatten(x_c2_1, start_dim=1)
        x_c2_1 = self.linearFlatten1(x_c2_1)
        mu = self.linearFlatten2(x_c2_1)
        sigma = self.linearFlatten3(x_c2_1)
        std = torch.exp(0.5*sigma)
        eps = torch.randn_like(std)
        z = sigma*self.N.sample(mu.shape) + mu
        #z = mu + std*eps

        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        z = self.linearUnflatten(z)
        z = self.linearUnflatten2(z)
        x_c2_1 = self.unflatten(z)
        #print("x_c2_1 shape = ", x_c2_1.shape)
        #Variational Autoencoder -end
        """

        # Decoder with skip connections
        x1_c2_9 = self.conv2_9(x_c2_1,D1//4,D2//4)
        x2_c2_9 = self.w2_9(x_c2_1,D1//4,D2//4)
        x_c2_9 = x1_c2_9 + x2_c2_9
        x_c2_9 = F.gelu(x_c2_9)
        x_c2_9 = torch.cat([x_c2_9, x_c2], dim=1)

        x1_c3 = self.conv3(x_c2_9,D1//2,D2//2)
        x2_c3 = self.w3(x_c2_9,D1//2,D2//2)
        x_c3 = x1_c3 + x2_c3
        x_c3 = F.gelu(x_c3)
        x_c3 = torch.cat([x_c3, x_c1], dim=1)

        x1_c4 = self.conv4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x2_c4 = self.w4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x_c4 = x1_c4 + x2_c4
        x_c4 = F.gelu(x_c4)
        x_c4 = torch.cat([x_c4, x_c0], dim=1)

        x1_c5 = self.conv5(x_c4,D1,D2)
        x2_c5 = self.w5(x_c4,D1,D2)
        x_c5 = x1_c5 + x2_c5
        x_c5 = F.gelu(x_c5)

        x_c5 = torch.cat([x_c5, x_fc0], dim=1)
        if self.padding!=0:
            x_c5 = x_c5[..., :-self.padding, :-self.padding]

        x_c5 = x_c5.permute(0, 2, 3, 1)

        x_fc1 = self.fc1(x_c5)
        x_fc1 = F.gelu(x_fc1)

        # Channel-last output (batch, H, W, 3).
        x_out = self.fc2(x_fc1)
        return x_out

    def get_grid(self, shape, device):
        """Return (x, y) coordinates in [0, 1] with shape (shape[0], shape[1], shape[2], 2) on ``device``."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        # Honour the requested device instead of hard-coding .cuda().
        return torch.cat((gridx, gridy), dim=-1).to(device)

In [7]:
class UNO_vanilla(nn.Module):
    """Single-input U-shaped neural operator (U-NO).

    1. Lift the input to ``d_co_domain`` channels with ``self.fc0``.
    2. An encoder/decoder of integral operators u' = (W + K)(u), with W a
       ``pointwise_op`` (``self.w*``) and K a ``SpectralConv2d``
       (``self.conv*``). Decoder stages concatenate the matching encoder
       activations on the channel axis (skip connections).
    3. Project with ``self.fc1`` and ``self.fc2``, then apply GELU.

    Args:
        in_d_co_domain: number of input channels.
        d_co_domain: width of the lifted representation and of the output.
        pad: zero padding added after both spatial axes, cropped again before
            the projection.
        factor: channel multiplier of the inner layers; ``factor / 4`` also
            scales the spatial size of the first and last inner stages.

    Input:  (batch, in_d_co_domain, H, W), channel-first.
    Output: (batch, d_co_domain, H, W), channel-first, after a final GELU.
    """

    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 16/4):
        super(UNO_vanilla, self).__init__()

        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain
        self.factor = factor
        self.factor2 = factor/4
        self.padding = pad  # pad the domain if input is non-periodic

        self.fc0 = nn.Linear(self.in_d_co_domain, self.d_co_domain)

        self.conv0 = SpectralConv2d(self.d_co_domain, 4*factor*self.d_co_domain, 16, 16, 42, 42)

        self.conv1 = SpectralConv2d(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16, 21,21)

        self.conv2 = SpectralConv2d(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,10,10)

        self.conv2_1 = SpectralConv2d(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 4, 4,5,5)

        self.conv2_9 = SpectralConv2d(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,5,5)

        self.conv3 = SpectralConv2d(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16,10,10)

        self.conv4 = SpectralConv2d(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 32, 32,21,21)

        self.conv5 = SpectralConv2d(8*factor*self.d_co_domain, self.d_co_domain, 48, 48,42,42) # will be reshaped

        self.w0 = pointwise_op(self.d_co_domain,4*factor*self.d_co_domain,75, 75)

        self.w1 = pointwise_op(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50)

        self.w2 = pointwise_op(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)

        self.w2_1 = pointwise_op(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 12, 12)

        self.w2_9 = pointwise_op(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)

        self.w3 = pointwise_op(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50)

        self.w4 = pointwise_op(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 75, 75)

        self.w5 = pointwise_op(8*factor*self.d_co_domain, self.d_co_domain, 100, 100) # will be reshaped

        self.fc1 = nn.Linear(2*self.d_co_domain, 4*self.d_co_domain)
        self.fc2 = nn.Linear(4*self.d_co_domain, self.d_co_domain)

    def forward(self, x):
        # NOTE: the coordinate grid from get_grid was computed here but never
        # used (and forced a CUDA device); the call has been removed.

        # Channel-first -> channel-last so fc0 acts on the channel axis.
        x = x.permute(0,2,3,1)

        x_fc0 = self.fc0(x)

        x_fc0 = F.gelu(x_fc0)

        x_fc0 = x_fc0.permute(0, 3, 1, 2)

        x_fc0 = F.pad(x_fc0, [0,self.padding, 0,self.padding])

        D1,D2 = x_fc0.shape[-2],x_fc0.shape[-1]

        # Encoder
        x1_c0 = self.conv0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x2_c0 = self.w0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x_c0 = x1_c0 + x2_c0
        x_c0 = F.gelu(x_c0)

        x1_c1 = self.conv1(x_c0 ,D1//2,D2//2)
        x2_c1 = self.w1(x_c0 ,D1//2,D2//2)
        x_c1 = x1_c1 + x2_c1
        x_c1 = F.gelu(x_c1)

        x1_c2 = self.conv2(x_c1 ,D1//4,D2//4)
        x2_c2 = self.w2(x_c1 ,D1//4,D2//4)
        x_c2 = x1_c2 + x2_c2
        x_c2 = F.gelu(x_c2 )

        x1_c2_1 = self.conv2_1(x_c2,D1//8,D2//8)
        x2_c2_1 = self.w2_1(x_c2,D1//8,D2//8)
        x_c2_1 = x1_c2_1 + x2_c2_1
        x_c2_1 = F.gelu(x_c2_1)

        # Decoder with skip connections
        x1_c2_9 = self.conv2_9(x_c2_1,D1//4,D2//4)
        x2_c2_9 = self.w2_9(x_c2_1,D1//4,D2//4)
        x_c2_9 = x1_c2_9 + x2_c2_9
        x_c2_9 = F.gelu(x_c2_9)
        x_c2_9 = torch.cat([x_c2_9, x_c2], dim=1)

        x1_c3 = self.conv3(x_c2_9,D1//2,D2//2)
        x2_c3 = self.w3(x_c2_9,D1//2,D2//2)
        x_c3 = x1_c3 + x2_c3
        x_c3 = F.gelu(x_c3)
        x_c3 = torch.cat([x_c3, x_c1], dim=1)

        x1_c4 = self.conv4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x2_c4 = self.w4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x_c4 = x1_c4 + x2_c4
        x_c4 = F.gelu(x_c4)
        x_c4 = torch.cat([x_c4, x_c0], dim=1)

        x1_c5 = self.conv5(x_c4,D1,D2)
        x2_c5 = self.w5(x_c4,D1,D2)
        x_c5 = x1_c5 + x2_c5
        x_c5 = F.gelu(x_c5)

        x_c5 = torch.cat([x_c5, x_fc0], dim=1)
        if self.padding!=0:
            x_c5 = x_c5[..., :-self.padding, :-self.padding]

        x_c5 = x_c5.permute(0, 2, 3, 1)

        x_fc1 = self.fc1(x_c5)
        x_fc1 = F.gelu(x_fc1)

        x_out = self.fc2(x_fc1)
        x_out = x_out.permute(0,3,1,2)
        x_out = F.gelu(x_out)
        return x_out

    def get_grid(self, shape, device):
        """Return (x, y) coordinates in [0, 1] with shape (shape[0], shape[1], shape[2], 2) on ``device``."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        # Honour the requested device instead of hard-coding .cuda().
        return torch.cat((gridx, gridy), dim=-1).to(device)

In [8]:
class SpectralConv2d(nn.Module):
    """2D Fourier layer: FFT, learned mixing of retained low modes, inverse FFT.

    Args:
        in_channels, out_channels: channel counts (cast to int).
        dim1, dim2: default output spatial size; overwritten whenever
            ``forward`` is called with explicit sizes.
        modes1, modes2: Fourier modes kept per axis. When ``modes1`` is None
            they default to ``dim1 // 2 + 1`` and ``dim2 // 2``.
    """

    def __init__(self, in_channels, out_channels, dim1, dim2,modes1 = None, modes2 = None):
        super().__init__()
        self.in_channels = in_channels = int(in_channels)
        self.out_channels = out_channels = int(out_channels)
        # Output spatial dimensions.
        self.dim1 = dim1
        self.dim2 = dim2

        if modes1 is None:
            # Take the largest number of modes available for this grid.
            self.modes1 = dim1 // 2 + 1
            self.modes2 = dim2 // 2
        else:
            self.modes1 = modes1  # at most floor(N/2) + 1
            self.modes2 = modes2

        self.scale = (1 / (2 * in_channels)) ** 0.5
        weight_shape = (in_channels, out_channels, self.modes1, self.modes2)
        self.weights1 = nn.Parameter(self.scale * torch.randn(*weight_shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.randn(*weight_shape, dtype=torch.cfloat))

    def compl_mul2d(self, input, weights):
        """Channel mixing per mode: (b, i, x, y) x (i, o, x, y) -> (b, o, x, y)."""
        return torch.einsum('bixy,ioxy->boxy', input, weights)

    def forward(self, x, dim1 = None,dim2 = None):
        # Explicit output sizes are remembered for subsequent calls.
        if dim1 is not None:
            self.dim1, self.dim2 = dim1, dim2
        m1, m2 = self.modes1, self.modes2

        spectrum = torch.fft.rfft2(x)
        mixed = torch.zeros(
            x.shape[0], self.out_channels, self.dim1, self.dim2 // 2 + 1,
            dtype=torch.cfloat, device=x.device,
        )
        # Positive-frequency rows first, then negative-frequency rows.
        for rows, weight in ((slice(None, m1), self.weights1),
                             (slice(-m1, None), self.weights2)):
            mixed[:, :, rows, :m2] = self.compl_mul2d(spectrum[:, :, rows, :m2], weight)

        # Back to physical space at the requested resolution.
        return torch.fft.irfft2(mixed, s=(self.dim1, self.dim2))


class pointwise_op(nn.Module):
    """1x1 convolution followed by bicubic resampling onto a target grid.

    Args:
        in_channel, out_channel: channel counts (cast to int).
        dim1, dim2: default output spatial size (cast to int), used when
            ``forward`` receives no explicit size.
    """

    def __init__(self, in_channel, out_channel,dim1, dim2):
        super().__init__()
        self.conv = nn.Conv2d(int(in_channel), int(out_channel), 1)
        self.dim1, self.dim2 = int(dim1), int(dim2)

    def forward(self,x, dim1 = None, dim2 = None):
        target = (self.dim1, self.dim2) if dim1 is None else (dim1, dim2)
        projected = self.conv(x)
        return torch.nn.functional.interpolate(
            projected, size=target, mode='bicubic', align_corners=True
        )

class UNO(nn.Module):
    """U-shaped neural operator (U-NO) that fuses two channel-first RGB frames.

    1. Each frame goes through the shared 1x1 conv ``self.interpolation``; the
       two results are summed and lifted to ``d_co_domain`` channels by
       ``self.fc0``.
    2. An encoder/decoder of integral operators u' = (W + K)(u), with W a
       ``pointwise_op`` (``self.w*``) and K a ``SpectralConv2d``
       (``self.conv*``). Decoder stages concatenate the matching encoder
       activations on the channel axis (skip connections).
    3. ``self.fc1`` and ``self.fc2`` project back to 3 channels.

    Args:
        in_d_co_domain: input width of ``fc0``; must be 3 because ``fc0`` is
            applied to the 3-channel output of ``self.interpolation``.
        d_co_domain: width of the lifted representation.
        pad: zero padding added after both spatial axes, cropped again before
            the projection.
        factor: channel multiplier of the inner layers; ``factor / 4`` also
            scales the spatial size of the first and last inner stages.

    Input:  ``x`` indexable as ``x[0]``, ``x[1]``, each (batch, 3, H, W).
    Output: (batch, 3, H, W), channel-first.
    """

    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 16/4):
        super(UNO, self).__init__()

        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain
        self.factor = factor
        self.factor2 = factor/4
        self.padding = pad  # pad the domain if input is non-periodic
        # reduceColor / increaseColor are unused in forward but kept for
        # state_dict compatibility and identical parameter initialisation order.
        self.reduceColor = nn.Conv2d(3, 1, kernel_size=1)
        self.interpolation = nn.Conv2d(3, 3, kernel_size=1)

        self.fc0 = nn.Linear(self.in_d_co_domain, self.d_co_domain)

        self.conv0 = SpectralConv2d(self.d_co_domain, 4*factor*self.d_co_domain, 16, 16, 42, 42)

        self.conv1 = SpectralConv2d(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16, 21,21)

        self.conv2 = SpectralConv2d(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,10,10)

        self.conv2_1 = SpectralConv2d(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 4, 4,5,5)

        self.conv2_9 = SpectralConv2d(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 8, 8,5,5)

        self.conv3 = SpectralConv2d(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 16, 16,10,10)

        self.conv4 = SpectralConv2d(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 32, 32,21,21)

        self.conv5 = SpectralConv2d(8*factor*self.d_co_domain, self.d_co_domain, 48, 48,42,42) # will be reshaped

        self.w0 = pointwise_op(self.d_co_domain,4*factor*self.d_co_domain,75, 75)

        self.w1 = pointwise_op(4*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50)

        self.w2 = pointwise_op(8*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)

        self.w2_1 = pointwise_op(16*factor*self.d_co_domain, 32*factor*self.d_co_domain, 12, 12)
        # Disabled VAE layers (string literal, not executed).
        """
        #VAE Code begin
        self.linearFlatten1 = nn.Linear(384*12*12, 128)
        self.linearFlatten2 = nn.Linear(128, 4)
        self.linearFlatten3 = nn.Linear(128, 4)
        self.softmax1 = nn.Softmax(dim=1)
        self.N = torch.distributions.Normal(0, 10)
        self.N.loc = self.N.loc.cuda()
        self.N.scale = self.N.scale.cuda()
        self.kl = 0
        self.linearUnflatten = nn.Linear(4, 128)
        self.relu = nn.ReLU(True)
        self.linearUnflatten2 = nn.Linear(128, 384*12*12)
        self.relu2 = nn.ReLU(True)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(384, 12, 12))
        #VAE Code End
        """
        self.w2_9 = pointwise_op(32*factor*self.d_co_domain, 16*factor*self.d_co_domain, 25, 25)

        self.w3 = pointwise_op(32*factor*self.d_co_domain, 8*factor*self.d_co_domain, 50, 50)

        self.w4 = pointwise_op(16*factor*self.d_co_domain, 4*factor*self.d_co_domain, 75, 75)

        self.w5 = pointwise_op(8*factor*self.d_co_domain, self.d_co_domain, 100, 100) # will be reshaped

        self.fc1 = nn.Linear(2*self.d_co_domain, 4*self.d_co_domain)
        self.fc2 = nn.Linear(4*self.d_co_domain, 3)
        self.increaseColor = nn.Conv2d(1, 3, kernel_size=1)

    def forward(self, x):
        # NOTE: get_grid used to be called here on channel-first shapes; the
        # result was never used and it forced a CUDA device, so it was removed.

        x_l = x[0]
        x_r = x[1]
        x1 = self.interpolation(x_l)
        x2 = self.interpolation(x_r)
        x = x1 + x2
        # Channel-first -> channel-last so fc0 acts on the channel axis.
        x = x.permute(0,2,3,1)
        x_fc0 = self.fc0(x)

        x_fc0 = F.gelu(x_fc0)

        x_fc0 = x_fc0.permute(0, 3, 1, 2)

        x_fc0 = F.pad(x_fc0, [0,self.padding, 0,self.padding])

        D1,D2 = x_fc0.shape[-2],x_fc0.shape[-1]

        # Encoder
        x1_c0 = self.conv0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x2_c0 = self.w0(x_fc0,int(D1*self.factor2),int(D2*self.factor2))
        x_c0 = x1_c0 + x2_c0
        x_c0 = F.gelu(x_c0)

        x1_c1 = self.conv1(x_c0 ,D1//2,D2//2)
        x2_c1 = self.w1(x_c0 ,D1//2,D2//2)
        x_c1 = x1_c1 + x2_c1
        x_c1 = F.gelu(x_c1)

        x1_c2 = self.conv2(x_c1 ,D1//4,D2//4)
        x2_c2 = self.w2(x_c1 ,D1//4,D2//4)
        x_c2 = x1_c2 + x2_c2
        x_c2 = F.gelu(x_c2 )

        x1_c2_1 = self.conv2_1(x_c2,D1//8,D2//8)
        x2_c2_1 = self.w2_1(x_c2,D1//8,D2//8)
        x_c2_1 = x1_c2_1 + x2_c2_1
        x_c2_1 = F.gelu(x_c2_1)
        # Disabled VAE bottleneck (string literal, not executed).
        """
        #Variational Autoencoder part - Extract mu and sigma
        x_c2_1 = torch.flatten(x_c2_1, start_dim=1)
        x_c2_1 = self.linearFlatten1(x_c2_1)
        mu = self.linearFlatten2(x_c2_1)
        sigma = self.linearFlatten3(x_c2_1)
        std = torch.exp(0.5*sigma)
        eps = torch.randn_like(std)
        z = sigma*self.N.sample(mu.shape) + mu
        #z = mu + std*eps

        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        z = self.linearUnflatten(z)
        z = self.linearUnflatten2(z)
        x_c2_1 = self.unflatten(z)
        #print("x_c2_1 shape = ", x_c2_1.shape)
        #Variational Autoencoder -end
        """

        # Decoder with skip connections
        x1_c2_9 = self.conv2_9(x_c2_1,D1//4,D2//4)
        x2_c2_9 = self.w2_9(x_c2_1,D1//4,D2//4)
        x_c2_9 = x1_c2_9 + x2_c2_9
        x_c2_9 = F.gelu(x_c2_9)
        x_c2_9 = torch.cat([x_c2_9, x_c2], dim=1)

        x1_c3 = self.conv3(x_c2_9,D1//2,D2//2)
        x2_c3 = self.w3(x_c2_9,D1//2,D2//2)
        x_c3 = x1_c3 + x2_c3
        x_c3 = F.gelu(x_c3)
        x_c3 = torch.cat([x_c3, x_c1], dim=1)

        x1_c4 = self.conv4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x2_c4 = self.w4(x_c3,int(D1*self.factor2),int(D2*self.factor2))
        x_c4 = x1_c4 + x2_c4
        x_c4 = F.gelu(x_c4)
        x_c4 = torch.cat([x_c4, x_c0], dim=1)

        x1_c5 = self.conv5(x_c4,D1,D2)
        x2_c5 = self.w5(x_c4,D1,D2)
        x_c5 = x1_c5 + x2_c5
        x_c5 = F.gelu(x_c5)

        x_c5 = torch.cat([x_c5, x_fc0], dim=1)
        if self.padding!=0:
            x_c5 = x_c5[..., :-self.padding, :-self.padding]

        x_c5 = x_c5.permute(0, 2, 3, 1)

        x_fc1 = self.fc1(x_c5)
        x_fc1 = F.gelu(x_fc1)

        x_out = self.fc2(x_fc1)

        # Channel-last (batch, H, W, 3) -> channel-first (batch, 3, H, W).
        return x_out.permute(0, 3, 1, 2)

    def get_grid(self, shape, device):
        """Return (x, y) coordinates in [0, 1] with shape (shape[0], shape[1], shape[2], 2) on ``device``."""
        batchsize, size_x, size_y = shape[0], shape[1], shape[2]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
        gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
        gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
        # Honour the requested device instead of hard-coding .cuda().
        return torch.cat((gridx, gridy), dim=-1).to(device)

In [9]:
class NIO(nn.Module):
    """Four-stage stack of UNO blocks applied to a two-frame input.

    NOTE(review): as written, ``forward`` returns ``UNO_block_4(x)`` on the
    ORIGINAL input. The outputs of blocks 1-3 and of ``interpolation`` are
    computed but never reach the return value (the last ``y`` is never read),
    so those parameters receive no gradient. The earlier stages feed
    ``torch.stack((x_int, previous_out))`` forward, which suggests
    ``UNO_block_4(y)`` may have been intended. Confirm before changing:
    models saved whole with ``torch.save(G, ...)`` use the current class
    definition when re-loaded.
    """
    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 10/4):
        super(NIO, self).__init__()
        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain
        # factor / factor2 are stored but not read by forward(). `pad` is
        # ignored: every UNO block below is built with pad=0.
        self.factor = factor
        self.factor2 = factor/4
        # 1x1 conv over 3 channels, applied separately to x[0] and x[1].
        self.interpolation = nn.Conv2d(3, 3, kernel_size=1)
        self.UNO_block_1 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_2 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_3 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_4 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        #self.UNO_block_5 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=20/4)
    def forward(self, x):
        """Run the UNO stages and return ``UNO_block_4(x)``.

        ``x`` is indexed as ``x[0]`` / ``x[1]`` (two frames). It is presumably
        shaped (2, batch, 3, H, W), as in the ``summary`` call for this class;
        TODO confirm.
        """
        uno_1_out = self.UNO_block_1(x)
        #print(uno_1_out.shape)
        #print(x[0].shape)
        #print(x.shape)
        # Channel-mixed sum of the two frames (used as a skip input below).
        x_int_l = self.interpolation(x[0])
        x_int_r = self.interpolation(x[1])
        x_int = x_int_l + x_int_r
        #x_int = torch.permute(x_int, (0,2,3,1))
        y = torch.stack((x_int, uno_1_out))

        #y = x[0] + uno_1_out
        #print("Y Shape = ", y.shape)
        uno_2_out = self.UNO_block_2(y)
        y = torch.stack((x_int, uno_2_out))
        #y = x[1] + uno_2_out
        uno_3_out = self.UNO_block_3(y)
        # NOTE(review): this `y` is never used; block 4 consumes the raw `x`.
        y = torch.stack((x_int, uno_3_out))
        uno_4_out = self.UNO_block_4(x)
        #y = torch.stack((x_int, uno_4_out))
        #x = self.UNO_block_5(y)
        return uno_4_out

In [10]:
class NIO_split_channel(nn.Module):
    """Two-frame model that concatenates both frames along the channel axis.

    ``forward`` indexes ``x`` as ``x[0]`` (left frame) and ``x[1]`` (right
    frame). Each frame is permuted with ``(0, 3, 1, 2)``, so it is presumably
    channel-last ``(batch, H, W, 3)``. The two channel-first frames are
    concatenated into 6 channels and passed through ``UNO_block_1``.

    ``in_d_co_domain``, ``d_co_domain``, ``pad`` and ``factor`` are accepted
    for signature compatibility with the other NIO variants but are unused.
    """

    def __init__(self, in_d_co_domain, d_co_domain, pad=0, factor=3/4):
        super(NIO_split_channel, self).__init__()
        # 6 input channels (two stacked 3-channel frames) -> 3 output channels.
        self.UNO_block_1 = UNO_vanilla(6, 3, pad=0, factor=16/4)
        # Not used by forward(). Kept so that existing checkpoints and
        # state_dicts containing its parameters still load.
        self.UNO_block_2 = UNO_vanilla(9, 3, pad=0, factor=16/4)

    def forward(self, x):
        """Return ``UNO_block_1`` applied to the channel-concatenated frames."""
        x_l = x[0].permute(0, 3, 1, 2)
        x_r = x[1].permute(0, 3, 1, 2)
        input_stack_1 = torch.cat((x_l, x_r), dim=1)
        return self.UNO_block_1(input_stack_1)

    def sub_mean(self, x):
        """Subtract the per-(sample, channel) spatial mean from ``x``.

        Args:
            x: 4-D tensor; the mean is taken over dims 2 and 3.

        Returns:
            ``(x - mean, mean)``, where ``mean`` keeps dims 2 and 3 as size 1.
            The input tensor is not modified.
        """
        mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
        # Out-of-place on purpose: the previous ``x -= mean`` silently mutated
        # the caller's tensor and fails on leaf tensors that require grad.
        return x - mean, mean
        

In [11]:
class NIO_fine(nn.Module):
    """Thin wrapper that applies a single ``UNO`` block to its input.

    Args:
        in_d_co_domain: input channel count, forwarded to ``UNO``.
        d_co_domain: co-domain width, forwarded to ``UNO``.
        pad: accepted for signature compatibility; unused (``UNO`` is built
            with ``pad=0``).
        factor: stored as ``self.factor`` (with ``self.factor2 = factor / 4``).
            The ``UNO`` block itself always uses ``factor=16/4``.
    """

    def __init__(self, in_d_co_domain, d_co_domain, pad=0, factor=3/4):
        super(NIO_fine, self).__init__()
        self.factor = factor
        self.factor2 = factor / 4
        self.in_d_co_domain = in_d_co_domain
        self.d_co_domain = d_co_domain
        self.UNO_block = UNO(in_d_co_domain, self.d_co_domain, pad=0, factor=16/4)

    def forward(self, x):
        """Return the output of ``self.UNO_block`` for ``x``."""
        return self.UNO_block(x)

In [12]:
class NIO_huge(nn.Module):
    """Five-stage stack of UNO blocks for a channel-last two-frame input.

    ``forward`` returns ``UNO_block_5(stack((x_int, UNO_block_4(x))))``.

    NOTE(review): ``UNO_block_4`` is applied to the ORIGINAL ``x``, so the
    outputs of blocks 1-3 never reach the return value and those blocks get
    no gradient. The other stages feed ``torch.stack((x_int, previous_out))``
    forward, which suggests ``UNO_block_4(y)`` may have been intended.
    Confirm before changing: models saved whole with ``torch.save`` use the
    current class definition when re-loaded.
    """
    def __init__(self, in_d_co_domain, d_co_domain, pad = 0, factor = 10/4):
        super(NIO_huge, self).__init__()
        self.in_d_co_domain = in_d_co_domain # input channel
        self.d_co_domain = d_co_domain
        # factor / factor2 are stored but not read by forward(). `pad` is
        # ignored: every UNO block below is built with pad=0.
        self.factor = factor
        self.factor2 = factor/4
        # 1x1 conv over 3 channels, applied separately to each frame.
        self.interpolation = nn.Conv2d(3, 3, kernel_size=1)
        self.UNO_block_1 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_2 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_3 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_4 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
        self.UNO_block_5 = UNO(self.in_d_co_domain, d_co_domain, pad=0, factor=16/4)
    def forward(self, x):
        """Run the UNO stages on the two-frame input ``x``.

        Each frame ``x[0]`` / ``x[1]`` is permuted with (0, 3, 1, 2) before the
        conv, so it is presumably channel-last (batch, H, W, 3); TODO confirm.
        """
        uno_1_out = self.UNO_block_1(x)
        #print(uno_1_out.shape)
        #print(x[0].shape)
        #print(x.shape)
        # Conv2d needs channel-first input: permute in, mix, sum, permute back.
        x_int_l = self.interpolation(x[0].permute(0,3,1,2))
        x_int_r = self.interpolation(x[1].permute(0,3,1,2))
        x_int = x_int_l + x_int_r
        x_int = torch.permute(x_int, (0,2,3,1))
        y = torch.stack((x_int, uno_1_out))

        #y = x[0] + uno_1_out
        #print("Y Shape = ", y.shape)
        uno_2_out = self.UNO_block_2(y)
        y = torch.stack((x_int, uno_2_out))
        #y = x[1] + uno_2_out
        uno_3_out = self.UNO_block_3(y)
        # NOTE(review): this `y` is overwritten before use; block 4 consumes the raw `x`.
        y = torch.stack((x_int, uno_3_out))
        uno_4_out = self.UNO_block_4(x)
        y = torch.stack((x_int, uno_4_out))
        x = self.UNO_block_5(y)
        return x

In [13]:
# Build an NIO generator on the current CUDA device and print its torchinfo
# layer summary for an input of shape (2, 10, 3, 85, 85). This is presumably
# (frames, batch, C, H, W); confirm.
# NOTE(review): NIO ignores `pad`; each of its UNO blocks is built with pad=0.
G = NIO(3, 3, pad=1).cuda()
#NF = NIO_fine(3,4, pad=1).cuda()
#nn_params = sum(p.numel() for p in G.parameters() if p.requires_grad)
#print("Number generator parameters: ", nn_params)
#randTensor = torch.rand((200,100,100,3), dtype=torch.float32).cuda()
summary(G, input_size=(2,10, 3,85,85))


Layer (type:depth-idx)                   Output Shape              Param #
NIO                                      [10, 3, 85, 85]           --
├─UNO: 1-1                               [10, 3, 85, 85]           10
│    └─Conv2d: 2-1                       [10, 3, 85, 85]           12
│    └─Conv2d: 2-2                       [10, 3, 85, 85]           (recursive)
│    └─Linear: 2-3                       [10, 85, 85, 3]           12
│    └─SpectralConv2d: 2-4               [10, 48, 85, 85]          508,032
│    └─pointwise_op: 2-5                 [10, 48, 85, 85]          --
│    │    └─Conv2d: 3-1                  [10, 48, 85, 85]          192
│    └─SpectralConv2d: 2-6               [10, 96, 42, 42]          4,064,256
│    └─pointwise_op: 2-7                 [10, 96, 42, 42]          --
│    │    └─Conv2d: 3-2                  [10, 96, 85, 85]          4,704
│    └─SpectralConv2d: 2-8               [10, 192, 21, 21]         3,686,400
│    └─pointwise_op: 2-9                 [10, 192, 21

In [14]:
def _fspecial_gauss_1d(size, sigma):
    r"""Create 1-D gauss kernel
    Args:
        size (int): the size of gauss kernel
        sigma (float): sigma of normal distribution
    Returns:
        torch.Tensor: 1D kernel (1 x 1 x size)
    """
    coords = torch.arange(size, dtype=torch.float)
    coords -= size // 2

    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()

    return g.unsqueeze(0).unsqueeze(0)


def gaussian_filter(input, win):
    r"""Blur ``input`` with a separable 1-D Gaussian window.

    The same 1-D window is applied along each spatial axis in turn, as a
    grouped convolution with no padding ("valid"). Each filtered spatial
    dimension therefore shrinks by ``win.shape[-1] - 1``. Spatial axes
    shorter than the window are left unfiltered, and a ``UserWarning`` is
    emitted for each.

    Args:
        input (torch.Tensor): batch of shape (N, C, H, W) or (N, C, D, H, W)
        win (torch.Tensor): 1-D window shaped (C, 1, ..., 1, size), e.g.
            ``_fspecial_gauss_1d(...).repeat(...)`` as built by ``ssim``
    Returns:
        torch.Tensor: blurred tensor
    Raises:
        NotImplementedError: if ``input`` is neither 4-D nor 5-D.
    """
    # Local import: ``warnings`` is not among the file's visible imports, so
    # the skip branch below used to raise NameError instead of warning.
    import warnings

    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape
    if len(input.shape) == 4:
        conv = F.conv2d
    elif len(input.shape) == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    C = input.shape[1]
    out = input
    for i, s in enumerate(input.shape[2:]):
        if s >= win.shape[-1]:
            # Rotate the window's length axis onto spatial axis 2 + i.
            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
        else:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
            )

    return out


def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)):
    r"""Compute per-channel SSIM and contrast-structure terms for X and Y.

    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images, same shape as ``X``
        data_range (float or int): value range of the images (usually 1.0 or 255)
        win (torch.Tensor): 1-D Gaussian window, passed to ``gaussian_filter``
        size_average (bool, optional): accepted for API compatibility; unused here
        K (list or tuple, optional): stabilising constants ``(K1, K2)``
    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(ssim_per_channel, cs)``, each the
        spatial mean of its map, with shape ``(N, C)``.
    """
    k1, k2 = K
    compensation = 1.0
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    def blur(t):
        return gaussian_filter(t, win)

    mu_x = blur(X)
    mu_y = blur(Y)
    mu_x_sq = mu_x.pow(2)
    mu_y_sq = mu_y.pow(2)
    mu_xy = mu_x * mu_y

    var_x = compensation * (blur(X * X) - mu_x_sq)
    var_y = compensation * (blur(Y * Y) - mu_y_sq)
    cov_xy = compensation * (blur(X * Y) - mu_xy)

    # Contrast-structure term, then multiplied by the luminance term
    # (alpha = beta = gamma = 1).
    cs_map = (2 * cov_xy + c2) / (var_x + var_y + c2)
    ssim_map = ((2 * mu_xy + c1) / (mu_x_sq + mu_y_sq + c1)) * cs_map

    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)
    return ssim_per_channel, cs


def ssim(
    X,
    Y,
    data_range=255,
    size_average=True,
    win_size=11,
    win_sigma=1.5,
    win=None,
    K=(0.01, 0.03),
    nonnegative_ssim=False,
):
    r"""Structural similarity (SSIM) between two image batches.

    Args:
        X (torch.Tensor): a batch of images, (N,C,H,W) or (N,C,D,H,W)
        Y (torch.Tensor): a batch of images with the same shape and dtype as X
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if True, return one scalar averaged over batch and
            channels; otherwise return one value per image (mean over channels)
        win_size (int, optional): size of the Gaussian window (must be odd); ignored when win is given
        win_sigma (float, optional): standard deviation of the Gaussian window
        win (torch.Tensor, optional): precomputed 1-D Gaussian window; built from
            win_size and win_sigma when None
        K (list or tuple, optional): constants (K1, K2). Try a larger K2 (e.g. 0.4) for negative or NaN results.
        nonnegative_ssim (bool, optional): clamp per-channel SSIM at zero with ReLU
    Returns:
        torch.Tensor: a scalar if size_average, else a tensor of shape (N,)
    Raises:
        ValueError: on mismatched shapes or dtypes, unsupported rank, or an even window size.
    """
    if X.shape != Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop singleton spatial axes, highest index first; batch and channel are kept.
    for axis in reversed(range(2, X.dim())):
        X = X.squeeze(dim=axis)
        Y = Y.squeeze(dim=axis)

    if X.dim() not in (4, 5):
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if X.type() != Y.type():
        raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")

    if win is not None:
        win_size = win.shape[-1]
    if win_size % 2 != 1:
        raise ValueError("Window size should be odd.")
    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma).repeat([X.shape[1]] + [1] * (X.dim() - 1))

    per_channel, _cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
    if nonnegative_ssim:
        per_channel = torch.relu(per_channel)
    return per_channel.mean() if size_average else per_channel.mean(1)


def ms_ssim(
    X, Y, data_range=255, size_average=True, win_size=11, win_sigma=1.5, win=None, weights=None, K=(0.01, 0.03)
):

    r"""Multi-scale structural similarity (MS-SSIM) between two image batches.

    SSIM terms are computed at ``len(weights)`` scales. The images are halved
    with 2x average pooling between scales. The contrast-structure terms of
    all but the last scale and the SSIM of the last scale are ReLU-clamped,
    raised to the per-level weights, and multiplied together.

    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if True, return one scalar averaged over batch and
            channels; otherwise return one value per image (mean over channels)
        win_size (int, optional): size of the Gaussian window (must be odd); ignored when win is given
        win_sigma (float, optional): standard deviation of the Gaussian window
        win (torch.Tensor, optional): 1-D gauss kernel. If None, a new kernel is created from win_size and win_sigma
        weights (list, optional): per-level weights; defaults to 5 levels
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results.
    Returns:
        torch.Tensor: a scalar if size_average, else a tensor of shape (N,)
    Raises:
        ValueError: on mismatched shapes or dtypes, unsupported rank, or an even window size.
        AssertionError: if the smaller spatial side is not larger than
            (win_size - 1) * 2**4 (skipped when Python runs with -O).
    """
    if not X.shape == Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop singleton spatial axes (highest index first); batch/channel are kept.
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if not X.type() == Y.type():
        raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")

    # Pooling operator used to move to the next (coarser) scale.
    if len(X.shape) == 4:
        avg_pool = F.avg_pool2d
    elif len(X.shape) == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if win is not None:  # set win_size
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    # The 4 downsamplings must leave at least a full window at the coarsest scale.
    smaller_side = min(X.shape[-2:])
    assert smaller_side > (win_size - 1) * (
        2 ** 4
    ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = X.new_tensor(weights)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    levels = weights.shape[0]
    mcs = []
    for i in range(levels):
        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)

        if i < levels - 1:
            mcs.append(torch.relu(cs))
            # Pad odd-sized spatial dims by 1 before 2x pooling.
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    # Only the coarsest level contributes its full SSIM; finer levels contribute cs.
    ssim_per_channel = torch.relu(ssim_per_channel)  # (batch, channel)
    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)  # (level, batch, channel)
    ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0)

    if size_average:
        return ms_ssim_val.mean()
    else:
        return ms_ssim_val.mean(1)


class SSIM(torch.nn.Module):
    r"""``torch.nn.Module`` front-end for the functional ``ssim`` in this file.

    The Gaussian window is built once at construction time and passed to
    ``ssim`` as ``win`` on every call.

    NOTE(review): ``forward`` looks up the module-level name ``ssim`` when it
    runs. This file later executes
    ``from skimage.metrics import structural_similarity as ssim``, which
    rebinds that name. After that, ``forward`` would call scikit-image's
    function with keyword arguments (``size_average``, ``win``, ``K``,
    ``nonnegative_ssim``) it does not accept. Confirm before relying on this
    class (or subclasses) once that import has run.
    """

    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        K=(0.01, 0.03),
        nonnegative_ssim=False,
    ):
        r"""Store the SSIM settings and precompute the Gaussian window.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if True, average SSIM over the whole batch to a scalar
            win_size (int, optional): number of taps in the Gaussian window
            win_sigma (float, optional): standard deviation of the Gaussian window
            channel (int, optional): number of image channels (default: 3)
            spatial_dims (int, optional): number of spatial axes (2 for images)
            K (list or tuple, optional): constants (K1, K2). Try a larger K2 (e.g. 0.4) if results are negative or NaN.
            nonnegative_ssim (bool, optional): clamp per-channel SSIM at zero with ReLU
        """
        super(SSIM, self).__init__()
        base_window = _fspecial_gauss_1d(win_size, win_sigma)
        self.win = base_window.repeat([channel, 1] + [1] * spatial_dims)
        self.win_size = win_size
        self.data_range = data_range
        self.size_average = size_average
        self.K = K
        self.nonnegative_ssim = nonnegative_ssim

    def forward(self, X, Y):
        """Return ``ssim(X, Y)`` computed with the stored settings and window."""
        kwargs = dict(
            data_range=self.data_range,
            size_average=self.size_average,
            win=self.win,
            K=self.K,
            nonnegative_ssim=self.nonnegative_ssim,
        )
        return ssim(X, Y, **kwargs)


class MS_SSIM(torch.nn.Module):
    r"""``torch.nn.Module`` front-end for the functional ``ms_ssim`` in this file.

    The 1-D Gaussian window (``channel`` copies, shaped for ``spatial_dims``
    spatial axes) is built once here and reused on every call.
    """

    def __init__(
        self,
        data_range=255,
        size_average=True,
        win_size=11,
        win_sigma=1.5,
        channel=3,
        spatial_dims=2,
        weights=None,
        K=(0.01, 0.03),
    ):
        r"""Store the MS-SSIM settings and precompute the Gaussian window.

        Args:
            data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
            size_average (bool, optional): if True, average MS-SSIM over the whole batch to a scalar
            win_size (int, optional): number of taps in the Gaussian window
            win_sigma (float, optional): standard deviation of the Gaussian window
            channel (int, optional): number of image channels (default: 3)
            spatial_dims (int, optional): number of spatial axes (2 for images)
            weights (list, optional): per-level weights (None selects ms_ssim's 5-level default)
            K (list or tuple, optional): constants (K1, K2). Try a larger K2 (e.g. 0.4) if results are negative or NaN.
        """
        super(MS_SSIM, self).__init__()
        gauss = _fspecial_gauss_1d(win_size, win_sigma)
        repeat_dims = [channel, 1] + [1] * spatial_dims
        self.win = gauss.repeat(repeat_dims)
        self.win_size = win_size
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights
        self.K = K

    def forward(self, X, Y):
        """Return ``ms_ssim(X, Y)`` computed with the stored settings and window."""
        settings = {
            "data_range": self.data_range,
            "size_average": self.size_average,
            "win": self.win,
            "weights": self.weights,
            "K": self.K,
        }
        return ms_ssim(X, Y, **settings)

In [15]:
import torch
import torchvision

class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual (and optional style) loss on frozen VGG16 feature blocks.

    The first 23 layers of torchvision's pretrained VGG16 ``features`` are
    split into four blocks: ``[:4]``, ``[4:9]``, ``[9:16]`` and ``[16:23]``.
    Inputs are normalised with ImageNet mean/std and optionally resized to
    224x224. The loss sums the L1 distance between feature maps and/or
    Gram matrices over the selected blocks. All blocks live on CUDA.

    Args:
        resize (bool): bilinearly resize inputs to 224x224 before VGG.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Load the pretrained network once and carve the four blocks out of
        # that single copy. Previously vgg16(pretrained=True) was constructed
        # four times, loading the full weight set each time.
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            features[:4].eval().cuda(),
            features[4:9].eval().cuda(),
            features[9:16].eval().cuda(),
            features[16:23].eval().cuda(),
        ]
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks).cuda()
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=(0, 1, 2, 3), style_layers=()):
        """Return the summed perceptual loss between ``input`` and ``target``.

        Args:
            input, target: image batches, presumably (N, C, H, W) with values
                in [0, 1] (the ImageNet normalisation assumes this); TODO
                confirm. Inputs whose channel dim is not 3 are repeated 3x
                along it.
            feature_layers: indices of blocks whose feature maps add an L1 term.
            style_layers: indices of blocks whose Gram matrices add an L1 term.

        Returns:
            Scalar loss (``0.0`` if no block is selected).
        """
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        # Normalise on the input's device without rebinding the registered
        # buffers on every call.
        mean = self.mean.to(input.device)
        std = self.std.to(input.device)
        input = (input - mean) / std
        target = (target - mean) / std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class MeanShift(nn.Conv2d):
    """Fixed (non-trainable) 1x1 conv that shifts and scales RGB channels.

    Output channel c is ``(x_c + sign * rgb_mean[c]) / rgb_std[c]``. With the
    default ``sign=-1`` this is the usual ``(x - mean) / std`` normalisation.

    Args:
        rgb_mean: three per-channel means.
        rgb_std: three per-channel standard deviations.
        sign: -1 to subtract the mean, +1 to add it.
    """

    def __init__(self, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Actually freeze the layer. The previous ``self.requires_grad = False``
        # only created a plain attribute and left weight/bias trainable.
        for p in self.parameters():
            p.requires_grad = False

class VGG(nn.Module):
    """MSE perceptual loss on frozen, pretrained VGG19 features.

    Args:
        loss_type (str): which features to compare. ``'22'``, ``'33'``,
            ``'44'`` and ``'54'`` use ``features[:8]``, ``[:16]``, ``[:26]``
            and ``[:35]`` respectively. ``'P'`` sums the MSE over the four
            consecutive slices ``[:8]``, ``[8:16]``, ``[16:26]`` and
            ``[26:35]``.

    Raises:
        ValueError: if ``loss_type`` is not one of the values above.
    """

    # Number of leading ``vgg19().features`` layers per single-depth loss_type.
    _DEPTHS = {'22': 8, '33': 16, '44': 26, '54': 35}

    def __init__(self, loss_type):
        super(VGG, self).__init__()
        conv_index = loss_type
        # Validate before loading weights. Previously an unknown value failed
        # later with a confusing AttributeError on ``self.vgg``.
        if conv_index != 'P' and conv_index not in self._DEPTHS:
            raise ValueError(
                f"Unsupported VGG loss_type: {conv_index!r}; expected one of '22', '33', '44', '54', 'P'"
            )
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m.cuda() for m in vgg_features]
        if conv_index == 'P':
            self.vgg = nn.ModuleList([
                nn.Sequential(*modules[:8]),
                nn.Sequential(*modules[8:16]),
                nn.Sequential(*modules[16:26]),
                nn.Sequential(*modules[26:35])
            ])
        else:
            self.vgg = nn.Sequential(*modules[:self._DEPTHS[conv_index]])
        self.vgg = nn.DataParallel(self.vgg).cuda()

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229, 0.224, 0.225)
        self.sub_mean = MeanShift(vgg_mean, vgg_std)
        # Actually freeze VGG and the normalisation layer. Assigning
        # ``.requires_grad = False`` on a Module only sets a plain attribute.
        self.vgg.requires_grad_(False)
        self.sub_mean.requires_grad_(False)
        self.conv_index = conv_index

    def forward(self, sr, hr):
        """Return the VGG-feature MSE between ``sr`` and ``hr``.

        ``hr`` features are computed under ``torch.no_grad()``, so gradients
        flow only through ``sr``. Inputs are presumably (N, 3, H, W) in
        [0, 1]; TODO confirm.
        """
        def _normalise(x):
            # Run the mean shift on whatever device it lives on (CPU unless the
            # module was moved), rather than always forcing the input to CPU.
            return self.sub_mean(x.to(self.sub_mean.weight.device))

        def _forward(x):
            return self.vgg(_normalise(x)).cuda()

        def _forward_all(x):
            feats = []
            x = _normalise(x)
            for module in self.vgg.module:
                x = module(x.cuda())
                feats.append(x)
            return feats

        if self.conv_index == 'P':
            vgg_sr_feats = _forward_all(sr)
            with torch.no_grad():
                vgg_hr_feats = _forward_all(hr.detach())
            loss = 0
            for sr_feat, hr_feat in zip(vgg_sr_feats, vgg_hr_feats):
                loss += F.mse_loss(sr_feat, hr_feat)
        else:
            vgg_sr = _forward(sr)
            with torch.no_grad():
                vgg_hr = _forward(hr.detach())
            loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss

In [17]:
from datetime import datetime as time
from skimage.metrics import structural_similarity as ssim
import torchvision.transforms.functional as FF
import cv2
import PIL
# Loss objects for the training cell. Only `huberloss` is used by
# train_JSGANO's active code; the others are unused or appear only in
# commented-out code.
myloss = torch.nn.MSELoss()
huberloss = torch.nn.HuberLoss()
kldivloss = torch.nn.KLDivLoss()
# NOTE: constructing VGG loads pretrained VGG19 weights (downloading them if
# not cached) and moves them to CUDA when this cell runs.
vggloss = VGG('22')
# Target mean activation / weight for the KL sparsity penalty
# (kl_divergence / sparse_loss). NOTE(review): neither is referenced by the
# visible training code.
RHO = 0.05
BETA = 0.01
def kl_divergence(rho, rho_hat):
    """Sparsity penalty: summed KL(Bernoulli(rho) || Bernoulli(p)).

    Args:
        rho (float): target mean activation, in (0, 1).
        rho_hat (torch.Tensor): raw activations. They are squashed with a
            sigmoid and averaged over dim 1 to give the observed mean
            activation ``p``.

    Returns:
        torch.Tensor: scalar sum, over every remaining entry of ``p``, of
        ``rho*log(rho/p) + (1-rho)*log((1-rho)/(1-p))``.
    """
    # sigmoid turns activations into probabilities in (0, 1)
    rho_hat = torch.mean(torch.sigmoid(rho_hat), 1)
    # Constant target with rho_hat's shape, dtype and device. The previous
    # ``torch.tensor([rho] * len(rho_hat)).cuda()`` forced CUDA and only
    # broadcast correctly when rho_hat was 1-D (or its last dim matched).
    rho = torch.full_like(rho_hat, rho)
    return torch.sum(rho * torch.log(rho / rho_hat) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat)))

# define the sparse loss function
def sparse_loss(rho, images):
    """Return the KL sparsity penalty ``kl_divergence(rho, images)``.

    Args:
        rho (float): target mean activation.
        images (torch.Tensor): activations, passed unchanged to ``kl_divergence``.
    """
    return kl_divergence(rho, images)
#ssim_loss = SSIM()
# Overrides the `epochs` value from the parameters cell (2000). This is the
# epoch count the training cell passes to train_JSGANO.
epochs = 2500
class MS_SSIM_Loss(MS_SSIM):
    """MS-SSIM turned into a loss: ``100 * (1 - MS-SSIM)``; lower is more similar."""

    def forward(self, img1, img2):
        similarity = super(MS_SSIM_Loss, self).forward(img1, img2)
        return 100 * (1 - similarity)

class SSIM_Loss(SSIM):
    """SSIM turned into a loss: ``100 * (1 - SSIM)``; lower is more similar.

    NOTE(review): ``SSIM.forward`` resolves the global name ``ssim`` at call
    time. After ``from skimage.metrics import structural_similarity as ssim``
    in this cell, that name refers to scikit-image's function, so this class
    would fail if used from here on.
    """

    def forward(self, img1, img2):
        similarity = super(SSIM_Loss, self).forward(img1, img2)
        return 100 * (1 - similarity)
# Multi-scale SSIM loss (100 * (1 - MS-SSIM)) for 3-channel images.
# data_range=1.0 assumes inputs scaled to [0, 1]; confirm against the loader.
# NOTE(review): with win_size=3, ms_ssim asserts the smaller spatial side is
# larger than (3 - 1) * 2**4 = 32 pixels. Not used by train_JSGANO, which
# optimises huberloss.
ssim_loss = MS_SSIM_Loss(win_size=3, win_sigma=1.5, data_range=1.0, size_average=True, channel=3)
def train_JSGANO(D, G, train_data, epochs, D_optim, G_optim, scheduler=None):
    """Train the frame-interpolation generator ``G`` with a Huber loss.

    Despite the GAN-style signature, only ``G`` is optimised. ``D`` and
    ``D_optim`` are unused, and ``D`` is returned unchanged.

    For every batch ``(xl, xr, y)``:
    - the two input frames are stacked into one ``(2, ...)`` float32 tensor;
    - ``G`` predicts the target frame;
    - the Huber loss against ``y`` is back-propagated.

    A per-image SSIM between prediction and target is tracked for logging
    only. A full pickled copy of ``G`` is saved every 20 epochs.

    Args:
        D: unused (kept for signature compatibility).
        G: generator module, called as ``G(stack((xl, xr)))``.
        train_data: iterable of ``(xl, xr, y)`` batches. If ``None`` (as in
            this notebook), batches come from
            ``LoadDataBatches(0, isNormalized=False, batches=2, isTrain=True)``.
        epochs (int): number of training epochs.
        D_optim: unused.
        G_optim: optimiser over ``G``'s parameters.
        scheduler: optional LR scheduler. If given, ``scheduler.step()`` is
            called once at the end of every epoch.

    Returns:
        tuple: ``(losses_D, losses_G, D, G)``. ``losses_D`` and ``losses_G``
        are length-``epochs`` arrays of per-epoch mean losses (epoch ``i`` at
        index ``i - 1``). ``losses_D`` stays zero while validation is disabled.

    Raises:
        ValueError: if ``train_data`` yields no batches in an epoch.
    """
    losses_D = np.zeros(epochs)
    losses_G = np.zeros(epochs)
    ssims_G = np.zeros(epochs)
    if train_data is None:
        # Default data source for this notebook (the caller passes None).
        train_data = LoadDataBatches(0, isNormalized=False, batches=2, isTrain=True)
    loader_idx = 0  # a single loader is used; kept for the log line below
    for i in range(1, epochs + 1):
        t1 = time.now()
        loss_D = 0.0  # stays 0: the validation pass is disabled
        loss_G = 0.0
        ssim_batch = 0
        train_counter = 0
        for xl, xr, y in train_data:
            train_counter += 1
            # Stack left/right frames into a single (2, ...) float32 input.
            x = torch.stack((xl.detach().cpu(), xr.detach().cpu())).to(torch.float32).cuda()
            y = y.cuda()

            G_optim.zero_grad()
            x_syn = G(x)
            # Channel-last float32 copies for the per-image SSIM metric.
            x_syn2 = torch.permute(x_syn, (0, 2, 3, 1)).type(torch.float32)
            y2 = torch.permute(y, (0, 2, 3, 1)).type(torch.float32)

            loss = huberloss(x_syn.type(torch.float32), y.type(torch.float32))

            # Logging-only SSIM. ``ssim`` here is scikit-image's
            # structural_similarity, imported in this cell.
            ssim_individual = 0.0
            for k in range(y.shape[0]):
                ssim_individual += ssim(x_syn2[k].detach().cpu().numpy(), y2[k].detach().cpu().numpy(), multichannel=True)
            ssim_batch += ssim_individual / y.shape[0]

            loss.backward()
            loss_G += loss.item()
            G_optim.step()
            if train_counter % 25 == 0 or train_counter == 1:
                print("Batch = ",train_counter," Train Loss = ", loss_G/train_counter, " SSIM = ", ssim_batch/(train_counter))

        if train_counter == 0:
            raise ValueError("train_data yielded no batches")
        # The arrays have length ``epochs``, so 1-based epoch i goes to slot
        # i - 1. Indexing with ``i`` raised IndexError on the final epoch.
        idx = i - 1
        losses_D[idx] = loss_D / train_counter
        losses_G[idx] = loss_G / train_counter
        ssims_G[idx] = ssim_batch / train_counter
        t2 = time.now()
        # Validation is disabled, so the "validation SSIM" field below prints
        # the accumulated training SSIM sum (test_counter == 1).
        test_counter = 1
        print("Loader #: ", loader_idx, " Epoch: ", i, "/ ",epochs," -  Time: ", t2-t1, "s - Validation Loss: ", losses_D[idx], " - Train Loss: ", losses_G[idx], " - train SSIM: ", ssims_G[idx], " validation SSIM: ", ssim_batch/test_counter)
        if scheduler is not None:
            scheduler.step()
        if i % 20 == 0:
            torch.save(G, '/depot/bera89/data/hviswan/NIO_diff_data'+str(i)+'.pt')

    return losses_D, losses_G, D, G

In [18]:
def calculate_gradient_penalty(model, real_images, fake_images, device, *, target_norm=None):
    """Return the GANO gradient-penalty term for the critic ``model``.

    The critic is evaluated at random convex combinations
    ``alpha * real + (1 - alpha) * fake`` with one ``alpha ~ U[0, 1)`` per
    sample. The penalty is the batch mean of
    ``(||d model / d x||_2 - target_norm) ** 2``.

    Args:
        model: critic, called on a tensor shaped like ``real_images``.
        real_images: batch of real samples; batch on dim 0, any rank >= 1.
        fake_images: batch of generated samples, same shape as ``real_images``.
        device: unused; kept for backward compatibility. New tensors are
            created on ``real_images.device`` so they always match the inputs.
        target_norm: gradient norm the penalty pulls toward. ``None`` (default)
            uses ``1 / sqrt(res * res)`` with the module-level ``res``.

    Returns:
        Scalar tensor. It is differentiable (``create_graph=True``), so it can
        be added to the critic loss.
    """
    batch = real_images.size(0)
    # One weight per sample, broadcast over all remaining dims. Uniform on
    # [0, 1) keeps the points on the real-fake segment (as in WGAN-GP); a
    # standard-normal draw would mostly extrapolate outside it.
    alpha_shape = (batch,) + (1,) * (real_images.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_images.device)
    interpolates = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)

    model_interpolates = model(interpolates)

    # Gradient of the critic output w.r.t. the interpolated inputs.
    gradients = torch.autograd.grad(
        outputs=model_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(model_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch, -1)

    if target_norm is None:
        # GANO-specific target that scales with the grid resolution.
        target_norm = 1 / np.sqrt(res * res)
    return torch.mean((gradients.norm(2, dim=1) - target_norm) ** 2)

In [None]:
# Fine-tune a saved U-NO checkpoint ("NF") with train_JSGANO.
# Only the generator slot gets a model; the other model/optimizer slots are
# passed as None (the discriminator loads below are commented out).
#D = torch.load('/depot/bera89/data/hviswan/GANO_DISC300.pt')
#D.eval()
NF = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_lowres_vgg_rgb_huge800.pt')
# NOTE(review): this summary uses a channel-last input (2, 10, 85, 85, 3), while a
# later commented-out summary uses (2, 10, 3, 85, 85) -- confirm which layout NF expects.
summary(NF, input_size=(2,10, 85,85,3))
NF.train()
device = 'cuda:11'
#G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_full50.pt')
#D.train()
NF.train()  # repeats the NF.train() call above; harmless
G_optimizer = torch.optim.Adam(NF.parameters(), lr=1e-3, weight_decay=1e-3)
#D_optimizer = torch.optim.Adam(D.parameters(), lr=lr, weight_decay=1e-4)
# NOTE(review): positional arguments follow train_JSGANO's signature (defined earlier
# in the notebook); the None slots are presumably D, the data and D's optimizer -- confirm.
losses_D, losses_G, D, G = train_JSGANO(None, NF, None, epochs, None, G_optimizer)

In [19]:
# Continue training the generator `G` returned by the previous train_JSGANO call.
# Every checkpoint load below is commented out, so `G` must already exist in the session.
#D = torch.load('/depot/bera89/data/hviswan/GANO_DISC300.pt')
#D.eval()
#G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_huge_rgb20.pt')
#G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_complex_norm_rgb400.pt')
#G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_vgg_rgb100.pt')
#G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_lowres_rgb300.pt')
#G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_lowres_vgg_rgb_huge940.pt')
#summary(G, input_size=(2,10, 3,85,85))
G.train()
device = 'cuda:11'
#G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_full50.pt')
#D.train()
G.train()  # repeats the G.train() call above; harmless
# A fresh Adam instance: optimizer state from the previous run is not carried over.
G_optimizer = torch.optim.Adam(G.parameters(), lr=1e-3, weight_decay=1e-4)
#D_optimizer = torch.optim.Adam(D.parameters(), lr=lr, weight_decay=1e-4)
losses_D, losses_G, D, G = train_JSGANO(None, G, None, epochs, None, G_optimizer)

torch.Size([2, 4400, 3, 256, 256])
torch.Size([2, 4400, 3, 256, 256])
torch.Size([4400, 3, 256, 256])


  ssim_individual += ssim(x_syn2[k].detach().cpu().numpy(), y2[k].detach().cpu().numpy(), multichannel=True)


Batch =  1  Train Loss =  0.01808980107307434  SSIM =  0.4963040351867676
Batch =  25  Train Loss =  0.050299235358834266  SSIM =  0.474745671749115
Batch =  50  Train Loss =  0.04651483990252018  SSIM =  0.47658284023404124
Batch =  75  Train Loss =  0.04389621866246064  SSIM =  0.47789837390184403
Batch =  100  Train Loss =  0.04261285550892353  SSIM =  0.4753354243189096
Batch =  125  Train Loss =  0.04097736694663763  SSIM =  0.48061338770389556
Batch =  150  Train Loss =  0.03960523299872875  SSIM =  0.4877914507190386
Batch =  175  Train Loss =  0.03804480896464416  SSIM =  0.4940035080909729
Batch =  200  Train Loss =  0.03626095469109714  SSIM =  0.500408344939351
Batch =  225  Train Loss =  0.033812807930840384  SSIM =  0.5135735202497906
Batch =  250  Train Loss =  0.03144796752370894  SSIM =  0.5231404889822007
Batch =  275  Train Loss =  0.029301699197597123  SSIM =  0.5370410579442978
Batch =  300  Train Loss =  0.027526127914121994  SSIM =  0.549080718656381
Batch =  325 

Batch =  375  Train Loss =  0.004663847658783198  SSIM =  0.7559711496432622
Batch =  400  Train Loss =  0.0047232862983946685  SSIM =  0.7537263138592243
Batch =  425  Train Loss =  0.0046815972854657205  SSIM =  0.7547705804600435
Batch =  450  Train Loss =  0.004642137735952727  SSIM =  0.756720436182287
Batch =  475  Train Loss =  0.004656219087222493  SSIM =  0.7560167724207828
Batch =  500  Train Loss =  0.004671575756161474  SSIM =  0.7544475065469742
Batch =  525  Train Loss =  0.004644907453163926  SSIM =  0.7538551415148236
Batch =  550  Train Loss =  0.004665819248713722  SSIM =  0.751953940933401
Batch =  575  Train Loss =  0.004700095163678508  SSIM =  0.7513903702860293
Batch =  600  Train Loss =  0.004682861519686412  SSIM =  0.7513046279797951
Batch =  625  Train Loss =  0.004698662759084255  SSIM =  0.7505844020843506
Batch =  650  Train Loss =  0.00467893469801101  SSIM =  0.750648327103028
Batch =  675  Train Loss =  0.004683684570525117  SSIM =  0.7514386234460053
B

Batch =  750  Train Loss =  0.004426995695801452  SSIM =  0.7588753321568171
Batch =  775  Train Loss =  0.004416786466589978  SSIM =  0.7586694799123271
Batch =  800  Train Loss =  0.004408074664024752  SSIM =  0.7586009951680899
Batch =  825  Train Loss =  0.00440317212004271  SSIM =  0.7588097379063115
Batch =  850  Train Loss =  0.004377384093685952  SSIM =  0.7590428896160687
Batch =  875  Train Loss =  0.004373155800426113  SSIM =  0.7589303884335926
Batch =  900  Train Loss =  0.004363353888491272  SSIM =  0.759194438672728
Batch =  925  Train Loss =  0.004339089296204415  SSIM =  0.759556931788857
Batch =  950  Train Loss =  0.004357575191748573  SSIM =  0.7588500407181288
Batch =  975  Train Loss =  0.0043430239412312706  SSIM =  0.7590584739049275
Batch =  1000  Train Loss =  0.0043590031834319235  SSIM =  0.7587548944503069
Batch =  1025  Train Loss =  0.004344939658668165  SSIM =  0.7593044965587011
Batch =  1050  Train Loss =  0.004348146293777972  SSIM =  0.75989340055201

Batch =  1125  Train Loss =  0.003192560206102725  SSIM =  0.7719193821913666
Batch =  1150  Train Loss =  0.003195160064813377  SSIM =  0.7718948610872031
Batch =  1175  Train Loss =  0.0031991457140464535  SSIM =  0.7715363407927625
Batch =  1200  Train Loss =  0.003207601931038274  SSIM =  0.7712726442795247
Batch =  1225  Train Loss =  0.003207417959513675  SSIM =  0.7714798916633032
Batch =  1250  Train Loss =  0.0032102135952329263  SSIM =  0.7716656857162714
Batch =  1275  Train Loss =  0.0032180467825474253  SSIM =  0.7713639074476326
Batch =  1300  Train Loss =  0.0032084929260264078  SSIM =  0.7715025214077188
Batch =  1325  Train Loss =  0.003211693874481222  SSIM =  0.7715302766184762
Batch =  1350  Train Loss =  0.0032204305844950594  SSIM =  0.7710920913048365
Batch =  1375  Train Loss =  0.0032114532570862633  SSIM =  0.771312474686991
Batch =  1400  Train Loss =  0.0032119661251116278  SSIM =  0.7713807400768357
Batch =  1425  Train Loss =  0.0032006008882614735  SSIM =

Batch =  1475  Train Loss =  0.0030913917214108505  SSIM =  0.7762815603639109
Batch =  1500  Train Loss =  0.0030982607500918673  SSIM =  0.7759709618762135
Batch =  1525  Train Loss =  0.003093203701084617  SSIM =  0.7762438558921462
Batch =  1550  Train Loss =  0.003090053397078713  SSIM =  0.7765819236900537
Batch =  1575  Train Loss =  0.0030942651380633287  SSIM =  0.7763731043844943
Batch =  1600  Train Loss =  0.0030874161937299504  SSIM =  0.7765915574110113
Batch =  1625  Train Loss =  0.0030953907664202584  SSIM =  0.7766198519720481
Batch =  1650  Train Loss =  0.0030918775815216852  SSIM =  0.7768666698422396
Batch =  1675  Train Loss =  0.0030922511251514724  SSIM =  0.7768577472554214
Batch =  1700  Train Loss =  0.00309661461145665  SSIM =  0.7766431876381531
Batch =  1725  Train Loss =  0.003094222648971903  SSIM =  0.7768023140512514
Batch =  1750  Train Loss =  0.003091253862974034  SSIM =  0.776823261531336
Batch =  1775  Train Loss =  0.00308206279584187  SSIM =  0

Batch =  1825  Train Loss =  0.0030742654240489236  SSIM =  0.7811169219690643
Batch =  1850  Train Loss =  0.003087057042654566  SSIM =  0.7806518769928732
Batch =  1875  Train Loss =  0.0030915141214694205  SSIM =  0.7801189768572648
Batch =  1900  Train Loss =  0.0030810025621950013  SSIM =  0.7806921087460298
Batch =  1925  Train Loss =  0.003080845190138048  SSIM =  0.7805148883473564
Batch =  1950  Train Loss =  0.003085680568004206  SSIM =  0.7801362512375299
Batch =  1975  Train Loss =  0.0030821947687043206  SSIM =  0.780556160311533
Batch =  2000  Train Loss =  0.003087416369577113  SSIM =  0.7803976668287068
Batch =  2025  Train Loss =  0.003096990434784891  SSIM =  0.7799243863534044
Batch =  2050  Train Loss =  0.003102531190869319  SSIM =  0.779629661379064
Batch =  2075  Train Loss =  0.0031004698348795455  SSIM =  0.7796815175787513
Batch =  2100  Train Loss =  0.0031019428419780783  SSIM =  0.7796739663077252
Batch =  2125  Train Loss =  0.003103516168700641  SSIM =  0

Batch =  2175  Train Loss =  0.0030947238768868407  SSIM =  0.778871737757976
Batch =  2200  Train Loss =  0.003102096862620039  SSIM =  0.778642097208649
Loader #:  0  Epoch:  7 /  2500  -  Time:  0:04:09.132578 s - Validation Loss:  0.0  - Train Loss:  0.003102096862620039  - train SSIM:  0.778642097208649  validation SSIM:  1713.0126138590276
Batch =  1  Train Loss =  0.0034896358847618103  SSIM =  0.7317146956920624
Batch =  25  Train Loss =  0.0028736126003786923  SSIM =  0.7915590351819992
Batch =  50  Train Loss =  0.002777359586325474  SSIM =  0.7800050657987595
Batch =  75  Train Loss =  0.002677100361712898  SSIM =  0.7844027398029964
Batch =  100  Train Loss =  0.002814306407526601  SSIM =  0.7801142793893814
Batch =  125  Train Loss =  0.0028487937869504094  SSIM =  0.7833125944137573
Batch =  150  Train Loss =  0.002863559191270421  SSIM =  0.7836991451183954
Batch =  175  Train Loss =  0.0028464900642367346  SSIM =  0.7846872872114181
Batch =  200  Train Loss =  0.0029043

Batch =  250  Train Loss =  0.003271714201953728  SSIM =  0.770489084661007
Batch =  275  Train Loss =  0.0032530897291144357  SSIM =  0.7718002804301002
Batch =  300  Train Loss =  0.003266960070711017  SSIM =  0.7703640607992808
Batch =  325  Train Loss =  0.0032672493650954073  SSIM =  0.7694063438819005
Batch =  350  Train Loss =  0.0032992450481729714  SSIM =  0.7671142120446478
Batch =  375  Train Loss =  0.0032714507275183376  SSIM =  0.7683396993478139
Batch =  400  Train Loss =  0.0032384985545286325  SSIM =  0.7705017585679889
Batch =  425  Train Loss =  0.003239193599883412  SSIM =  0.7700715032044579
Batch =  450  Train Loss =  0.0032261210468843477  SSIM =  0.770528435740206
Batch =  475  Train Loss =  0.003200655746806756  SSIM =  0.7711695428584752
Batch =  500  Train Loss =  0.0032040384663851  SSIM =  0.7712117271721363
Batch =  525  Train Loss =  0.003199968310571941  SSIM =  0.7718220379522869
Batch =  550  Train Loss =  0.003204541736737486  SSIM =  0.77146753603761

Batch =  600  Train Loss =  0.0029768216250522527  SSIM =  0.7821094790597757
Batch =  625  Train Loss =  0.0029654431622009726  SSIM =  0.7816044093370438
Batch =  650  Train Loss =  0.0029756873895754464  SSIM =  0.7807073315519553
Batch =  675  Train Loss =  0.0029983419510622128  SSIM =  0.779312359823121
Batch =  700  Train Loss =  0.0030117231496842576  SSIM =  0.7792375825132642
Batch =  725  Train Loss =  0.002997830272652209  SSIM =  0.7802560262433413
Batch =  750  Train Loss =  0.0029809717434303214  SSIM =  0.7803068728248278
Batch =  775  Train Loss =  0.0029776357015925308  SSIM =  0.7806800186634064
Batch =  800  Train Loss =  0.00297586015331035  SSIM =  0.7809019935876131
Batch =  825  Train Loss =  0.0029799843797544863  SSIM =  0.7812115350455949
Batch =  850  Train Loss =  0.002980474454063155  SSIM =  0.7816314440790345
Batch =  875  Train Loss =  0.0029818507891281377  SSIM =  0.7815921046904155
Batch =  900  Train Loss =  0.0029804066059619396  SSIM =  0.78154150

Batch =  950  Train Loss =  0.0031190612775236857  SSIM =  0.7823336828225538
Batch =  975  Train Loss =  0.0031258048124623316  SSIM =  0.7818213746792231
Batch =  1000  Train Loss =  0.0031232828889915256  SSIM =  0.7819427665472031
Batch =  1025  Train Loss =  0.003114493347437507  SSIM =  0.7824488812685013
Batch =  1050  Train Loss =  0.0031017793143733536  SSIM =  0.783002932710307
Batch =  1075  Train Loss =  0.003135016529311881  SSIM =  0.7816911805923595
Batch =  1100  Train Loss =  0.0031223856240235777  SSIM =  0.7820280502736568
Batch =  1125  Train Loss =  0.0031072359836349884  SSIM =  0.7824186215930515
Batch =  1150  Train Loss =  0.0031093924326579207  SSIM =  0.7822683750028195
Batch =  1175  Train Loss =  0.003115233487469402  SSIM =  0.781755488451491
Batch =  1200  Train Loss =  0.0031036904126085572  SSIM =  0.7824176789199313
Batch =  1225  Train Loss =  0.003107174885262051  SSIM =  0.7819002103318974
Batch =  1250  Train Loss =  0.003106102610810194  SSIM =  0

Batch =  1300  Train Loss =  0.003091409479216404  SSIM =  0.7791423220789203
Batch =  1325  Train Loss =  0.003097285915150764  SSIM =  0.778842470508139
Batch =  1350  Train Loss =  0.003087262179437352  SSIM =  0.7792806511427517
Batch =  1375  Train Loss =  0.003077221319936639  SSIM =  0.7797594768215309
Batch =  1400  Train Loss =  0.0030720037049468372  SSIM =  0.7795191246244524
Batch =  1425  Train Loss =  0.0030634617335326447  SSIM =  0.7796998240419647
Batch =  1450  Train Loss =  0.0030753144965073543  SSIM =  0.7792451221341716
Batch =  1475  Train Loss =  0.0030695437636005423  SSIM =  0.7796664479400142
Batch =  1500  Train Loss =  0.003088959076364214  SSIM =  0.7787642786478003
Batch =  1525  Train Loss =  0.0030927912834421045  SSIM =  0.7788211634095575
Batch =  1550  Train Loss =  0.003084393329355085  SSIM =  0.7794797086883937
Batch =  1575  Train Loss =  0.003089188546448621  SSIM =  0.7791244068292398
Batch =  1600  Train Loss =  0.003091754588349431  SSIM =  0

Batch =  1650  Train Loss =  0.003087055037948162  SSIM =  0.7793354682917848
Batch =  1675  Train Loss =  0.0030866580706527596  SSIM =  0.7793829981493416
Batch =  1700  Train Loss =  0.003092438823557185  SSIM =  0.7792763994524584
Batch =  1725  Train Loss =  0.0030906257731249504  SSIM =  0.7792745630702247
Batch =  1750  Train Loss =  0.003091289583967799  SSIM =  0.7791504060519593
Batch =  1775  Train Loss =  0.0030920869200207575  SSIM =  0.7790041027585386
Batch =  1800  Train Loss =  0.0030868350140756066  SSIM =  0.7791766882501543
Batch =  1825  Train Loss =  0.0030881437788097895  SSIM =  0.7790474486575552
Batch =  1850  Train Loss =  0.003081908559556846  SSIM =  0.7796113112951453
Batch =  1875  Train Loss =  0.0030839756248965083  SSIM =  0.7795871953388055
Batch =  1900  Train Loss =  0.00308560115859444  SSIM =  0.7792311363216293
Batch =  1925  Train Loss =  0.0030850722308572256  SSIM =  0.779283864116901
Batch =  1950  Train Loss =  0.0030938080772709023  SSIM = 

Batch =  2000  Train Loss =  0.0031069120233150897  SSIM =  0.7782062455564738
Batch =  2025  Train Loss =  0.0031061212293499  SSIM =  0.7783116244975431
Batch =  2050  Train Loss =  0.003107757278385845  SSIM =  0.7782331279865126
Batch =  2075  Train Loss =  0.003102641662495794  SSIM =  0.778449426645256
Batch =  2100  Train Loss =  0.003104677383020143  SSIM =  0.7786480883402483
Batch =  2125  Train Loss =  0.0030945091961663874  SSIM =  0.7792159280075747
Batch =  2150  Train Loss =  0.0030984855517149404  SSIM =  0.7789989973015563
Batch =  2175  Train Loss =  0.003098878314178009  SSIM =  0.7790861517670511
Batch =  2200  Train Loss =  0.0031058137339740905  SSIM =  0.7788160467892885
Loader #:  0  Epoch:  14 /  2500  -  Time:  0:04:09.264203 s - Validation Loss:  0.0  - Train Loss:  0.0031058137339740905  - train SSIM:  0.7788160467892885  validation SSIM:  1713.3953029364347
Batch =  1  Train Loss =  0.0012357706436887383  SSIM =  0.9006042182445526
Batch =  25  Train Loss =

Batch =  75  Train Loss =  0.0028841786684157948  SSIM =  0.7854289493958155
Batch =  100  Train Loss =  0.002808906235732138  SSIM =  0.7890642014145851
Batch =  125  Train Loss =  0.002839019351406023  SSIM =  0.7849629414081574
Batch =  150  Train Loss =  0.002907514341835243  SSIM =  0.7816949301958084
Batch =  175  Train Loss =  0.002950042761097263  SSIM =  0.7802728175265449
Batch =  200  Train Loss =  0.0030185347357473804  SSIM =  0.7776013704389334
Batch =  225  Train Loss =  0.0030432726110383454  SSIM =  0.7775490064091153
Batch =  250  Train Loss =  0.00303743092273362  SSIM =  0.7793477480411529
Batch =  275  Train Loss =  0.002968073976226151  SSIM =  0.7827562207525427
Batch =  300  Train Loss =  0.003059631379728671  SSIM =  0.7801860929528872
Batch =  325  Train Loss =  0.0031454424378283036  SSIM =  0.7770981231561074
Batch =  350  Train Loss =  0.0031026256076958298  SSIM =  0.7773273464185851
Batch =  375  Train Loss =  0.0031108097650576383  SSIM =  0.777270465095

KeyboardInterrupt: 

# Build and show the output path of the GRF histogram figure for run `ite`.
ite = 2
fig_path_template = '~/GANO/Figures/GANO_GRF/{}GRF_HistogramGANO.pdf'
s = fig_path_template.format(ite)
print(s)

In [None]:
import cv2
#import pylab as plt

import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
#from skimage.metrics import peak_signal_noise_ratio as PSNR
from math import log10, sqrt
# Mean absolute error (torch L1Loss, default reduction="mean"); the evaluation
# loop in this cell reports it as the reconstruction loss.
myloss = torch.nn.L1Loss()
def PSNR(original, compressed):
    """Return the peak signal-to-noise ratio (dB) of *compressed* vs *original*.

    Uses an 8-bit peak value of 255. Identical inputs (zero error) return the
    sentinel 100, because the true PSNR would be infinite.
    """
    squared_err = np.mean((original - compressed) ** 2)
    if squared_err == 0:
        return 100
    peak = 255.0
    return 20 * log10(peak / sqrt(squared_err))
def mse(imageA, imageB):
    """Return the mean squared error between two equally shaped arrays.

    Both inputs are cast to float before subtracting, so unsigned integer
    images do not wrap around. The sum of squared differences is divided by
    the total element count. For 2-D images this equals the previous
    ``rows * cols`` normalisation. For (H, W, C) images or stacked batches it
    is still the true per-element mean instead of a channel-/batch-scaled value.
    """
    diff = imageA.astype("float") - imageB.astype("float")
    return np.sum(diff ** 2) / float(imageA.size)
# Evaluate the two-stage interpolation pipeline (base_G, then complex_G applied
# to base_G's output) on batches from LoadDataBatches. The loop reports running
# L1 loss, SSIM and PSNR on decoded images. It also shows up to 20 examples with
# SSIM > 0.75 next to the naive average of the two input frames.
base_G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_lowres_rgb300.pt')
complex_G = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_lowres_vgg_rgb_huge940.pt')
backSub = cv2.createBackgroundSubtractorMOG2()  # NOTE(review): not used in this cell
complex_G.eval()
base_G.eval()

# NOTE(review): isTrain=False presumably selects held-out data; the running
# totals keep their original "train_*" names.
train_data, y_normalizer = LoadDataBatches(0, isNormalized=True, batches=2, isTrain=False)
loss_D = 0.0
loss_G = 0.0
ssim_G = 0.0
train_counter = 0
ssim_batch = 0
psnr_G = 0.0
counter = 0.0
print_counter = 0
for xl, xr, y in train_data:
    train_counter += 1
    # Baseline: decoded average of the two input frames for the first sample.
    mean = (xl[0] + xr[0]) / 2.0
    mean = y_normalizer.decode(mean)
    mean = mean.cpu().detach().numpy()
    xl = xl.cpu().detach().numpy()
    xr = xr.cpu().detach().numpy()
    x = torch.from_numpy(np.asarray([xl, xr]).astype(np.float32))
    x = x.cuda()
    y = y.cuda()
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        base_x_out = base_G(x)
        x = torch.stack((base_x_out, base_x_out))
        x_syn = complex_G(x).reshape(x.shape[1], x.shape[2], x.shape[3], x.shape[4])
    y = y.reshape(y.shape[0], y.shape[1], y.shape[2], y.shape[3])
    x_syn = y_normalizer.decode(x_syn)
    y = y_normalizer.decode(y)
    W_loss = myloss(x_syn, y).item()
    y = y.detach().cpu().numpy()
    x_syn = x_syn.detach().cpu().numpy()
    # SSIM of the first sample (last axis = channels). Computed once and reused below.
    ssval = ssim(x_syn[0], y[0], channel_axis=-1)

    if ssval > 0.75 and print_counter < 20:
        # Colour-convert display copies only, so `mean`, x_syn and y keep the
        # same channel order for the SSIM comparisons below.
        mean_show = cv2.cvtColor(mean, cv2.COLOR_BGR2RGB)
        plt.imshow(mean_show.astype('uint8'))
        plt.show()

        x_syn_show = cv2.cvtColor(x_syn[0], cv2.COLOR_BGR2RGB)
        plt.imshow(x_syn_show.astype('uint8'))
        plt.show()
        gt = cv2.cvtColor(y[0], cv2.COLOR_BGR2RGB)
        plt.imshow(gt.astype('uint8'))
        plt.show()
        print("SSIM MEANvSYN ", ssim(mean, x_syn[0], channel_axis=-1))
        print("SSIM MEANvEXP ", ssim(mean, y[0], channel_axis=-1))
        print("SSIM SYNvEXP ", ssval)
        print("SSIM = ", ssval)
        print_counter += 1
    ssim_batch += ssval
    ssim_G = ssim_batch
    loss_G += W_loss
    psnr_G += PSNR(x_syn, y)
    if train_counter % 5 == 0:
        # Running averages over the train_counter batches seen so far.
        print("Batch = ", train_counter, " Train Loss = ", loss_G / train_counter, " SSIM = ", ssim_batch / train_counter, " PSNR = ", psnr_G / train_counter)
n_batches = max(train_counter, 1)  # avoid ZeroDivisionError on an empty loader
losses_D = loss_D / n_batches
losses_G = loss_G / n_batches
ssims_G = ssim_G / n_batches
print(losses_G)
print(ssims_G)
print(psnr_G / n_batches)
print(counter)

In [None]:
# Loss-curve plot (disabled). It expects per-epoch arrays; after the evaluation
# cell above, losses_D / losses_G hold scalar averages instead.
# plt.plot(np.arange(epochs), losses_D, c='k', label='D')
# plt.plot(np.arange(epochs), losses_G, c='b', label='G')
# plt.legend()
# plt.ylabel("Loss")
# plt.xlabel("Epochs")
# plt.show()
print(losses_G)
print(ssims_G)

# DISFA+ Dataset

In [None]:
import cv2
#import pylab as plt
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
#from skimage.metrics import peak_signal_noise_ratio as PSNR
from math import log10, sqrt
def PSNR(original, compressed):
    """Return the peak signal-to-noise ratio (dB) of *compressed* vs *original*.

    Uses the standard definition ``20 * log10(MAX / sqrt(MSE))``, which equals
    ``10 * log10(MAX**2 / MSE)``, with an 8-bit peak ``MAX = 255``. Identical
    inputs return the sentinel 100, because the true PSNR would be infinite.
    """
    mse_value = np.mean((original - compressed) ** 2)
    if mse_value == 0:
        return 100
    max_pixel = 255.0
    # Factor 20 (not 10): MAX / sqrt(MSE) is an amplitude ratio, not a power ratio.
    return 20 * log10(max_pixel / sqrt(mse_value))
def mse(imageA, imageB):
    """Return the mean squared error between two equally shaped arrays.

    Both inputs are cast to float before subtracting, so unsigned integer
    images do not wrap around. The sum of squared differences is divided by
    the total element count. For 2-D images this equals the previous
    ``rows * cols`` normalisation. For (H, W, C) images or stacked batches it
    is still the true per-element mean instead of a channel-/batch-scaled value.
    """
    diff = imageA.astype("float") - imageB.astype("float")
    return np.sum(diff ** 2) / float(imageA.size)
# Evaluate the Vimeo-trained U-NO on the DISFA+ frames: save one example, then
# report L1 loss, SSIM and PSNR over the test loader.
model_loaded = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_full60.pt')
model_loaded.eval()
# Defined here as well so this cell does not depend on the Vimeo cell.
myloss = torch.nn.L1Loss()
x_train = torch.load('../Fall22 IDEAS/x_train.pt')
y_train = torch.load('../Fall22 IDEAS/y_train.pt')
x_test = torch.load('../Fall22 IDEAS/x_test.pt')
y_test = torch.load('../Fall22 IDEAS/y_test.pt')
# Swap the first two axes so x_*[0] / x_*[1] are the two input frames.
x_train = torch.permute(x_train, (1,0,2,3))
x_test = torch.permute(x_test, (1,0,2,3))
# NOTE(review): the train-fitted normalizers are immediately replaced by
# test-fitted ones, and the data itself is left un-normalized (encodes disabled).
y_normalizer = MinMaxNormalizer(y_train)
x_normalizer = MinMaxNormalizer(x_train)
#x_train = x_normalizer.encode(x_train)
#x_train = y_normalizer.encode(y_train)
y_normalizer = MinMaxNormalizer(y_test)
x_normalizer = MinMaxNormalizer(x_test)
#x_test = x_normalizer.encode(x_test)
#x_test = y_normalizer.encode(y_test)
print(x_train.shape)
print(y_train.shape)
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train[0], x_train[1], y_train), batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test[0], x_test[1], y_test), batch_size=10, shuffle=True)

# Single-example visual check (test sample 12).
input_l = x_test[0][12].reshape(1, 100, 100).cpu().detach().numpy()
input_r = x_test[1][12].reshape(1, 100, 100).cpu().detach().numpy()
input = torch.from_numpy(np.asarray([input_l,input_r]).astype(np.float32))

with torch.no_grad():
    model_out = model_loaded(input.cuda())
model_out = model_out.reshape(100,100).cpu().detach().numpy()
output = y_test[12].reshape(100,100).cpu().detach().numpy()
plt.imshow( model_out)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/disfa_generated.png', model_out)
plt.imshow( output)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/disfa_groundTruth.png', output)
# Naive baseline: pixel-wise average of the two input frames.
input1 = ((input_l+ input_r)/2.0).reshape(100, 100)
plt.imshow( input1)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/disfa_mean.png', input1)
plt.imsave('/depot/bera89/data/hviswan/disfa_left.png', input_l.reshape(100,100))
plt.imsave('/depot/bera89/data/hviswan/disfa_right.png', input_r.reshape(100,100))
print(ssim(model_out, output))
print(ssim(input1, output))
print(ssim(model_out, input1))
print(mse(model_out, output))
print(mse(input1, output))
print(mse(model_out, input1))
print(PSNR(model_out, output))
print(PSNR(input1, output))
print(PSNR(model_out, input1))

loss_D = 0.0
loss_G = 0.0
ssim_G = 0.0
train_counter = 0
ssim_batch = 0
psnr_G = 0.0
for xl, xr, y in test_loader:
    train_counter += 1
    xl = xl.cpu().detach().numpy()
    xr = xr.cpu().detach().numpy()
    x = torch.from_numpy(np.asarray([xl, xr]).astype(np.float32))
    x = x.cuda()
    y = y.cuda()

    # x is 4-D (2, batch, 100, 100); reshape the output like y (there is no x.shape[4]).
    with torch.no_grad():
        x_syn = model_loaded(x).reshape(x.shape[1], 100, 100)
    y = y.reshape(x.shape[1], 100, 100)
    # L1Loss needs tensors, so compute it before the numpy conversion below.
    W_loss = myloss(x_syn, y.to(x_syn.dtype)).item()
    y2 = y_normalizer.encode(y).detach().cpu().numpy()
    x2 = y_normalizer.encode(x_syn).detach().cpu().numpy()

    y = y.detach().cpu().numpy()
    x_syn = x_syn.detach().cpu().numpy()
    # NOTE(review): on a (batch, 100, 100) stack, channel_axis=-1 treats the last
    # image axis as channels and (batch, 100) as the spatial plane. This matches the
    # original multichannel=True call; the >= 7 guard presumably keeps the batch axis
    # at least the default SSIM window size.
    if x.shape[1] >= 7:
        ssim_batch += ssim(x_syn, y, channel_axis=-1)
        ssim_G = ssim_batch

    # NOTE(review): PSNR here assumes a 255 peak, but x2/y2 are min-max encoded -- confirm the range.
    psnr_G += PSNR(x2, y2)
    loss_G += W_loss

    if train_counter % 10 == 0:
        print("Batch = ", train_counter, " Train Loss = ", loss_G / train_counter, " SSIM = ", ssim_batch / train_counter)
n_batches = max(train_counter, 1)  # avoid ZeroDivisionError on an empty loader
losses_D = loss_D / n_batches
losses_G = loss_G / n_batches
ssims_G = ssim_G / n_batches
print(losses_G)
print(ssims_G)
print(psnr_G / n_batches)

# DAVIS Dataset

In [None]:
import cv2
#import pylab as plt
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
#from skimage.metrics import peak_signal_noise_ratio as PSNR
from math import log10, sqrt
def PSNR(original, compressed):
    """Return the peak signal-to-noise ratio (dB), using max(*original*) as the peak.

    The peak is data-dependent rather than a fixed 8-bit 255. Identical inputs
    (zero mean squared error) return the sentinel 100.
    """
    err = np.mean((original - compressed) ** 2)
    if err != 0:
        peak = np.max(original)
        return 20 * log10(peak / sqrt(err))
    return 100
def mse(imageA, imageB):
    """Return the mean squared error between two equally shaped arrays.

    Both inputs are cast to float before subtracting, so unsigned integer
    images do not wrap around. The sum of squared differences is divided by
    the total element count. For 2-D images this equals the previous
    ``rows * cols`` normalisation. For (H, W, C) images or stacked batches it
    is still the true per-element mean instead of a channel-/batch-scaled value.
    """
    diff = imageA.astype("float") - imageB.astype("float")
    return np.sum(diff ** 2) / float(imageA.size)
# Evaluate the Vimeo-trained U-NO on DAVIS frames: save one example, then report
# mse / SSIM / PSNR averaged over the batches of `train_loader`.
model_loaded = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_full60.pt')
model_loaded.eval()
x_train = torch.load('../Fall22 IDEAS/x_train_davis.pt')
y_train = torch.load('../Fall22 IDEAS/y_train_davis.pt')
x_test = torch.load('../Fall22 IDEAS/x_test_davis.pt')
y_test = torch.load('../Fall22 IDEAS/y_test_davis.pt')
# Swap the first two axes so x_*[0] / x_*[1] are the two input frames.
x_train = torch.permute(x_train, (1,0,2,3))
x_test = torch.permute(x_test, (1,0,2,3))
y_normalizer = MinMaxNormalizer(y_train)
x_normalizer = MinMaxNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
y_train = y_normalizer.encode(y_train)
# NOTE(review): test data is normalized with normalizers refit on the test set,
# not with the train-fitted ones -- confirm this is intended.
y_normalizer = MinMaxNormalizer(y_test)
x_normalizer = MinMaxNormalizer(x_test)
x_test = x_normalizer.encode(x_test)
y_test = y_normalizer.encode(y_test)
print(x_train.shape)
print(y_train.shape)
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train[0], x_train[1], y_train), batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test[0], x_test[1], y_test), batch_size=10, shuffle=True)
# Single-example visual check (test sample 11). Note: `input` shadows the builtin.
input_l = x_test[0][11].reshape(1, 100, 100).cpu().detach().numpy()
input_r = x_test[1][11].reshape(1, 100, 100).cpu().detach().numpy()
input = torch.from_numpy(np.asarray([input_l,input_r]).astype(np.float32))



model_out = model_loaded(input.cuda())
model_out = model_out.reshape(100,100).cpu().detach().numpy()
output = y_test[11].reshape(100,100).cpu().detach().numpy()
#model_out = y_normalizer.decode(model_out).cpu().detach().numpy()
model_out = model_out.reshape(100,100)
plt.imshow( model_out)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/davis_generated.png', model_out)
plt.imshow( output)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/davis_groundTruth.png', output)
#x_train = torch.permute(x_train, (0,2,3,1))
#input = torch.permute(input, (0,3,1,2))
#input1 = ((input[0][0]+input[0][1])/2.0).reshape(100,100).cpu().detach().numpy()
# Naive baseline: pixel-wise average of the two input frames.
input1 = ((input_l+ input_r)/2.0).reshape(100, 100)
plt.imshow( input1)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/davis_mean.png', input1)
plt.imsave('/depot/bera89/data/hviswan/davis_left.png', input_l.reshape(100,100))
plt.imsave('/depot/bera89/data/hviswan/davis_right.png', input_r.reshape(100,100))
print(ssim(model_out, output))
print(ssim(input1, output))
print(ssim(model_out, input1))
print(mse(model_out, output))
print(mse(input1, output))
print(mse(model_out, input1))


# Batch loop. It iterates `train_loader` (the DAVIS train split), which is why
# the prints say "Train Loss". Totals are divided by the number of batches at the end.
loss_D = 0.0
loss_G = 0.0
ssim_G = 0.0
train_counter = 0
ssim_batch=0
psnr_G = 0.0
for xl,xr,y in train_loader:
    train_counter += 1
    xl = xl.cpu().detach().numpy()
    xr = xr.cpu().detach().numpy()
    x = torch.from_numpy(np.asarray([xl,xr]).astype(np.float32))
    x = x.cuda()
    y = y.cuda()
    #D_optimizer.zero_grad()
    #D_optimizer.zero_grad()
    #G_optim.zero_grad()
    #print(y.shape)
    #y_real = D(y)
    #print(y_real.shape)
    #loss_D_real = fn_loss(y_real, torch.ones_like(y_real).cuda())
    #print(x.shape)
    x_syn = model_loaded(x).reshape(x.shape[1], 100,100)
    y = y.reshape(x.shape[1], 100, 100)
    #print(x_syn.shape)
    #print(x.shape)
    #print(y.shape)
    #print(x_syn.shape)
    #print(y.shape)
    #print(x_syn.shape)
    #print(y.shape)
    #print(x.shape)
    #print("___")
    #print(x_syn.shape)
    #y_syn = D(x_syn)
    #print(y_syn.shape)
    #loss_D_syn = fn_loss(y_syn, torch.zeros_like(y_syn).cuda())
    # NOTE(review): y is already encoded (y_test/y_train were encoded above), so this
    # encodes a second time; y2/x2 are not used later in the loop.
    y2 = y_normalizer.encode(y).detach().cpu().numpy()
    x2 = y_normalizer.encode(x_syn).detach().cpu().numpy()
    y = y.detach().cpu().numpy()
    x_syn = x_syn.detach().cpu().numpy()
    # NOTE(review): x_syn / y are (batch, 100, 100) here -- confirm `mse` normalizes
    # by the full element count for 3-D input.
    W_loss = mse(x_syn, y)
    # SSIM over the whole (batch, 100, 100) stack treated as one volume. The >= 7
    # guard presumably keeps the batch axis at least the default SSIM window size.
    if(x.shape[1]>=7):
        ssim_batch += ssim(x_syn, y)
        ssim_G = ssim_batch
    #loss = loss_D_real + loss_D_syn
    loss = W_loss
    #print(loss.shape)
    #loss.backward()
    loss_G += loss
    #loss_D += loss.item()
    # This cell's PSNR uses np.max(x_syn) (its first argument) as the peak.
    psnr_G += PSNR(x_syn, y)
    #D_optim.step()

    #x_syn = G(x)
    #y_syn = D(x_syn)

    #loss = fn_loss(y_syn, torch.ones_like(y_syn).cuda())
    #loss = myloss(y, x_syn)
    #loss.backward()
    #loss_G += loss.item()

    #G_optim.step()
    if(train_counter %50 == 0):
        print("Batch = ",train_counter," Train Loss = ", loss_G/train_counter, " SSIM = ", ssim_batch/(train_counter))
losses_D = loss_D / train_counter
losses_G = loss_G / train_counter
ssims_G = ssim_G / train_counter
print(losses_G)
print(ssims_G)
print(psnr_G / (train_counter))

# UCF101 Dataset

In [None]:
import cv2
#import pylab as plt
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
#from skimage.metrics import peak_signal_noise_ratio as PSNR
from math import log10, sqrt
def PSNR(original, compressed, max_pixel=255.0):
    """Return the peak signal-to-noise ratio, in dB, between two arrays.

    Uses the standard definition ``20 * log10(max_pixel / RMSE)``, where the
    RMSE is taken over every element of the arrays.

    Args:
        original: reference array.
        compressed: array to compare against ``original`` (same shape).
        max_pixel: peak possible value of the signal. Defaults to 255 (8-bit
            images). NOTE(review): if the data were rescaled (e.g. by a
            min-max normaliser), pass the actual peak value instead.

    Returns:
        The PSNR in dB, or 100 when the arrays are identical (zero MSE means
        no noise, so PSNR is unbounded and a sentinel is returned).
    """
    mse_value = np.mean((original - compressed) ** 2)
    if mse_value == 0:
        return 100
    # 20*log10(MAX/RMSE) == 10*log10(MAX**2/MSE); the previous 10*log10(MAX/RMSE)
    # reported exactly half the true PSNR.
    return 20 * log10(max_pixel / sqrt(mse_value))
def mse(imageA, imageB):
    """Return the mean squared error between two equally-shaped arrays.

    The squared differences are averaged over *every* element, so the value
    is comparable between single images (H, W), colour images (H, W, C) and
    stacked batches (B, H, W). Dividing the sum by ``shape[0] * shape[1]``
    only would inflate the result by the product of any remaining axes
    (e.g. 100x for a batch of 100x100 images).
    """
    diff = np.asarray(imageA, dtype=float) - np.asarray(imageB, dtype=float)
    return np.mean(diff ** 2)
# Load the Vimeo-trained 2-frame model and the UCF101 splits, then visualise
# one test sample (index 23) next to its ground truth and the naive average
# of the two input frames.
# NOTE(review): this unpickles a whole nn.Module. On PyTorch >= 2.6 torch.load
# defaults to weights_only=True and will refuse it; only load trusted files.
model_loaded = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_full60.pt')
model_loaded.eval()
x_train = torch.load('../Fall22 IDEAS/x_train_ucf.pt')
y_train = torch.load('../Fall22 IDEAS/y_train_ucf.pt')
x_test = torch.load('../Fall22 IDEAS/x_test_ucf.pt')
y_test = torch.load('../Fall22 IDEAS/y_test_ucf.pt')
# Swap the first two axes so x_*[0] / x_*[1] are the left / right frames.
x_train = torch.permute(x_train, (1, 0, 2, 3))
x_test = torch.permute(x_test, (1, 0, 2, 3))
y_normalizer = MinMaxNormalizer(y_train)
x_normalizer = MinMaxNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
y_train = y_normalizer.encode(y_train)
# NOTE(review): the test split is scaled with normalisers fitted on the test
# data itself (and these test-fitted instances replace the train ones). Usual
# practice is to reuse the train statistics; confirm MinMaxNormalizer's
# statistics are shape-independent before switching.
y_normalizer = MinMaxNormalizer(y_test)
x_normalizer = MinMaxNormalizer(x_test)
x_test = x_normalizer.encode(x_test)
y_test = y_normalizer.encode(y_test)
print(x_train.shape)
print(y_train.shape)
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train[0], x_train[1], y_train), batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test[0], x_test[1], y_test), batch_size=10, shuffle=True)

input_l = x_test[0][23].reshape(1, 100, 100).cpu().detach().numpy()
input_r = x_test[1][23].reshape(1, 100, 100).cpu().detach().numpy()
# (2, 1, 100, 100): the frame pair as a batch of one.
# NOTE(review): `input` shadows the builtin; kept since other cells may use it.
input = torch.from_numpy(np.asarray([input_l, input_r]).astype(np.float32))

# Pure inference: no autograd graph needed.
with torch.no_grad():
    model_out = model_loaded(input.cuda())
model_out = model_out.reshape(100, 100).cpu().detach().numpy()
output = y_test[23].reshape(100, 100).cpu().detach().numpy()
plt.imshow(model_out)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/ucf_generated.png', model_out)
plt.imshow(output)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/ucf_groundTruth.png', output)
# Baseline: pixel-wise average of the two input frames.
input1 = ((input_l + input_r) / 2.0).reshape(100, 100)
plt.imshow(input1)
plt.show()
plt.imsave('/depot/bera89/data/hviswan/ucf_mean.png', input1)
plt.imsave('/depot/bera89/data/hviswan/ucf_left.png', input_l.reshape(100, 100))
plt.imsave('/depot/bera89/data/hviswan/ucf_right.png', input_r.reshape(100, 100))
print(ssim(model_out, output))
print(ssim(input1, output))
print(ssim(model_out, input1))
print(mse(model_out, output))
print(mse(input1, output))
print(mse(model_out, input1))

# Evaluate the loaded model on every batch of train_loader and report the
# running MSE, SSIM and PSNR. loss_D is never updated (no discriminator pass)
# and is kept only so losses_D keeps being defined.
loss_D = 0.0
loss_G = 0.0
ssim_G = 0.0
train_counter = 0
ssim_batch = 0
psnr_G = 0.0
for xl, xr, y in train_loader:
    train_counter += 1
    # Stack the left/right frames into a (2, batch, ...) float32 tensor --
    # the same layout np.asarray([xl, xr]) produced.
    x = torch.stack((xl, xr)).to(torch.float32).cuda()
    y = y.cuda()
    # Pure inference: do not build an autograd graph for every batch.
    with torch.no_grad():
        x_syn = model_loaded(x).reshape(x.shape[1], 100, 100)
    y = y.reshape(x.shape[1], 100, 100)
    y = y.detach().cpu().numpy()
    x_syn = x_syn.detach().cpu().numpy()
    loss = mse(x_syn, y)
    psnr_G += PSNR(x_syn, y)
    # Average SSIM over the individual 100x100 images. Passing the whole
    # (batch, 100, 100) stack makes SSIM treat the batch axis as a third
    # spatial dimension (blending neighbouring, shuffled, unrelated images)
    # and forced small batches to be skipped while still being counted in
    # the train_counter denominator.
    ssim_batch += np.mean([ssim(pred, target) for pred, target in zip(x_syn, y)])
    ssim_G = ssim_batch
    loss_G += loss
    if train_counter % 50 == 0:
        print("Batch = ",train_counter," Train Loss = ", loss_G/train_counter, " SSIM = ", ssim_batch/(train_counter))
losses_D = loss_D / train_counter
losses_G = loss_G / train_counter
ssims_G = ssim_G / train_counter
print(losses_G)
print(ssims_G)
print(psnr_G / (train_counter))

In [None]:
from PIL import Image
# Run the Vimeo-trained 2-frame model on one full-resolution left/right frame
# pair ("tig") and save the predicted frame.
# NOTE(review): unpickles a whole nn.Module; PyTorch >= 2.6 needs
# weights_only=False for this. Only load trusted files.
model_loaded = torch.load('/depot/bera89/data/hviswan/UNO_2f_vimeo_full60.pt')
model_loaded.eval()
left = torch.load('../Fall22 IDEAS/left_tig.pt')
right = torch.load('../Fall22 IDEAS/right_tig.pt')
print(left.shape)

# Move axis 0 to the end (presumably channel-first -> channel-last; confirm
# against how left_tig.pt / right_tig.pt were saved).
image_1_l = torch.permute(left, (1, 2, 0)).detach().cpu().numpy()
image_1_r = torch.permute(right, (1, 2, 0)).detach().cpu().numpy()

# Pair the LEFT and RIGHT frames; stacking the left frame twice would hide
# the right frame from the model entirely.
slicer = torch.from_numpy(np.asarray([image_1_l, image_1_r]).astype(np.float32)).reshape(2, 1, 700, 500, 1)
print("HERE ", slicer.shape)
# Fitted but not applied: the input is fed to the model un-normalised.
x_normalizer = MinMaxNormalizer(slicer)
# Pure inference: no autograd graph needed.
with torch.no_grad():
    model_out = model_loaded(slicer.cuda()).reshape(700, 500).cpu().detach().numpy()
plt.style.use('grayscale')
plt.imshow(model_out)
print(model_out.shape)
plt.imsave('../Fall22 IDEAS/output1tig.png', model_out)


# DAVIS RGB

In [None]:
import matplotlib.pyplot as plt
def mse(imageA, imageB):
    """Return the mean squared error between two equally-shaped arrays.

    The squared differences are averaged over *every* element, so the value
    is comparable between single images (H, W), colour images (H, W, C) and
    stacked batches (B, H, W). Dividing the sum by ``shape[0] * shape[1]``
    only would inflate the result by the product of any remaining axes
    (e.g. 3x for an RGB image).
    """
    diff = np.asarray(imageA, dtype=float) - np.asarray(imageB, dtype=float)
    return np.mean(diff ** 2)
# Evaluate the DAVIS RGB checkpoint: show one training sample next to its
# target, then report the running MSE and SSIM over train_loader.
x_train = torch.load('../Fall22 IDEAS/x_train_davis_rgb.pt')
y_train = torch.load('../Fall22 IDEAS/y_train_davis_rgb.pt')
x_test = torch.load('../Fall22 IDEAS/x_test_davis_rgb.pt')
y_test = torch.load('../Fall22 IDEAS/y_test_davis_rgb.pt')
# Swap the first two axes so x_*[0] / x_*[1] are the left / right frames.
x_train = torch.permute(x_train, (1, 0, 2, 3, 4))
x_test = torch.permute(x_test, (1, 0, 2, 3, 4))

print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)
print(x_train[0][0].shape)
ntrain = x_train.shape[1]
ntest = x_test.shape[1]
y_train = y_train.reshape(ntrain, 100, 100, 3)
y_test = y_test.reshape(ntest, 100, 100, 3)
dim = (100, 100)
train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train[0], x_train[1], y_train), batch_size=1, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test[0], x_test[1], y_test), batch_size=1, shuffle=True)
loss_D = 0.0
loss_G = 0.0
ssim_G = 0.0
train_counter = 0
ssim_batch = 0
psnr_G = 0.0
# NOTE(review): unpickles a whole nn.Module; PyTorch >= 2.6 needs
# weights_only=False for this. Only load trusted files.
model_loaded = torch.load('/depot/bera89/data/hviswan/UNO_2f_davis_rgb5.pt')
model_loaded.eval()
input_l = x_train[0][12].reshape(1, 100, 100, 3).cpu().detach().numpy()
input_r = x_train[1][12].reshape(1, 100, 100, 3).cpu().detach().numpy()
output = y_train[12].reshape(100, 100, 3).cpu().detach().numpy()

input = torch.from_numpy(np.asarray([input_l, input_r]).astype(np.float32))
# Evaluate the checkpoint loaded above (in eval mode), not whatever `G`
# happens to hold in this session.
with torch.no_grad():
    model_out = model_loaded(input.cuda()).reshape(100, 100, 3).cpu().detach().numpy()
print("OUTSHAPE = ", model_out.shape)
plt.imshow((model_out * 255).astype(np.uint8))
plt.show()
plt.imshow((output * 255).astype(np.uint8))
plt.show()

for xl, xr, y in train_loader:
    train_counter += 1
    # (2, 1, 100, 100, 3): the frame pair as a batch of one.
    x = torch.stack((xl, xr)).to(torch.float32).cuda()
    y = y.cuda()
    with torch.no_grad():
        x_syn = model_loaded(x).reshape(100, 100, 3).detach().cpu().numpy()
    y = y.reshape(100, 100, 3).detach().cpu().numpy()
    loss = mse(x_syn, y)
    # channel_axis=-1 replaces `multichannel=True`, deprecated since
    # scikit-image 0.19; SSIM is computed per colour channel and averaged.
    ssim_batch += ssim(x_syn, y, channel_axis=-1)
    ssim_G = ssim_batch
    loss_G += loss
    if train_counter % 50 == 0 or train_counter == 1:
        print("Batch = ",train_counter," Train Loss = ", loss_G/train_counter, " SSIM = ", ssim_batch/(train_counter))
print(loss_G/train_counter)
print(ssim_batch/train_counter)
