# In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
from torchsummary import summary
from PIL import Image
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import numpy as np
import skimage
import skimage
from skimage import io
import time
from networks2.conv_layers import *
import sys

#
# This code implements DisplacementMLP+, described in the paper:
# D. Mangileva et al., DisplacementMLP+: Unsupervised Neural Network for Dynamic Scene Analysis, 2025

###
# Size of the small coordinate grid (st x st = 60 x 60). The 240 x 240 frame
# is tiled by st00 x st00 patches of that size; st0 is half a patch side.
st = 60
st0 = st // 2
st00 = 240 // st

# Expose only GPU 0 to this process (must happen before CUDA is initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# NOTE(review): `device` is not used afterwards; the script relies on
# `.cuda()` and the CUDA tensor type `dtype` instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and pick convolution algorithms (input sizes are fixed).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dtype = torch.cuda.FloatTensor

def np_to_torch(img_np):
    """Wrap a NumPy array as a torch tensor with a new leading batch axis.

    The result shares memory with ``img_np`` (``torch.from_numpy``) and has
    shape ``(1, *img_np.shape)``.
    """
    tensor = torch.from_numpy(img_np)
    # Prepend a singleton axis; the trailing ``:`` keeps every existing axis.
    batched = tensor[None, :]
    return batched
###
# SineLayer implementation (nn.Linear replaced by a 1x1 nn.Conv2d), adapted from:
# Sitzmann V. et al. Implicit neural representations with periodic activation functions, 2020

class SineLayer(nn.Module):
    """Pointwise SIREN layer computing ``sin(omega_0 * conv1x1(x))``.

    The usual ``nn.Linear`` is a ``kernel_size=1`` ``nn.Conv2d`` here, so the
    layer transforms every pixel of a channels-first map independently.
    ``need_sigmoid`` / ``need_tanh`` are accepted for signature compatibility
    and are ignored.
    """

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30, need_sigmoid=True, need_tanh=False):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Conv2d(in_features, out_features,
                                kernel_size=1, padding=0, bias=bias)
        self.init_weights()

    def init_weights(self):
        """Re-draw the conv weights from the SIREN uniform distribution.

        First layer: U(-1/in_features, 1/in_features). Other layers:
        U(-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0).
        Biases keep the default ``nn.Conv2d`` initialisation.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``sin(omega_0 * conv(input))``."""
        activated, _ = self.forward_with_intermediate(input)
        return activated

    def forward_with_intermediate(self, input):
        """Return ``(sin(z), z)`` where ``z = omega_0 * conv(input)`` is the pre-activation."""
        pre_activation = self.omega_0 * self.linear(input)
        return torch.sin(pre_activation), pre_activation


class Siren(nn.Module):
    """SIREN MLP built from 1x1 convolutions, applied pointwise to a coordinate map.

    Every layer is a ``kernel_size=1`` conv (``SineLayer`` or ``nn.Conv2d``), so
    the input is a channels-first map with ``in_features`` channels (e.g. a
    ``(1, 2, H, W)`` grid of (x, y) coordinates) and each pixel is mapped
    independently.

    Args:
        in_features: channels of the input coordinate map.
        hidden_features: width of every hidden SineLayer.
        hidden_layers: number of hidden SineLayers after the first one.
        out_features: channels of the output map.
        outermost_linear: if True the last layer is a plain 1x1 conv with
            SIREN-style uniform init; otherwise it is another SineLayer.
        first_omega_0: frequency factor of the first SineLayer.
        hidden_omega_0: frequency factor of the remaining SineLayers (also
            used for the init bound of the linear head).
        need_sigmoid: append ``nn.Sigmoid`` (default True, so outputs lie in
            (0, 1)); takes precedence over ``need_tanh``.
        need_tanh: append ``nn.Tanh`` when ``need_sigmoid`` is False.
    """

    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30, need_sigmoid=True, need_tanh=False):
        super().__init__()

        # Layers are created in a fixed order (first, hidden..., head,
        # activation). This order fixes both the RNG draws during init and the
        # nn.Sequential state_dict keys ("net.0.", "net.1.", ...) that saved
        # checkpoints rely on.
        layers = [SineLayer(in_features, hidden_features,
                            is_first=True, omega_0=first_omega_0)]

        for _ in range(hidden_layers):
            layers.append(SineLayer(hidden_features, hidden_features,
                                    is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = nn.Conv2d(hidden_features, out_features, kernel_size=1, padding=0)

            # SIREN init for the linear head:
            # U(-sqrt(6/hidden)/omega, sqrt(6/hidden)/omega).
            with torch.no_grad():
                final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
                                             np.sqrt(6 / hidden_features) / hidden_omega_0)

            layers.append(final_linear)
        else:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))

        # Optional output squashing; sigmoid wins when both flags are set.
        if need_sigmoid:
            layers.append(nn.Sigmoid())
        elif need_tanh:
            layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, coords):
        """Return ``(net(coords), coords)``; the input is passed back unchanged."""
        output = self.net(coords)
        return output, coords

    def forward_with_activations(self, coords, retain_grad=False):
        """Run the network and record every intermediate tensor.

        Returns an ``OrderedDict`` whose first key is ``'input'`` (a fresh leaf
        copy of ``coords`` with ``requires_grad=True``), followed by keys
        ``"<layer class>_<counter>"``. A SineLayer contributes two entries
        (its pre-activation ``omega_0 * conv(x)``, then its sine output); every
        other layer contributes its output. With ``retain_grad=True`` the
        recorded non-leaf tensors keep their ``.grad`` after ``backward()``.
        """
        # OrderedDict is not imported at the top of the file; import it here so
        # this method does not depend on what the wildcard import re-exports.
        from collections import OrderedDict

        activations = OrderedDict()
        activation_count = 0
        x = coords.clone().detach().requires_grad_(True)
        activations['input'] = x
        for layer in self.net:
            if isinstance(layer, SineLayer):
                x, intermed = layer.forward_with_intermediate(x)

                if retain_grad:
                    x.retain_grad()
                    intermed.retain_grad()

                activations['_'.join((str(layer.__class__), "%d" % activation_count))] = intermed
                activation_count += 1
            else:
                x = layer(x)

                if retain_grad:
                    x.retain_grad()

            activations['_'.join((str(layer.__class__), "%d" % activation_count))] = x
            activation_count += 1

        return activations
###
# Folder holding the pair of consecutive frames: './<n>/0.png' and
# './<n>/1.png', where n is the command-line argument minus one.
if len(sys.argv) < 2:
    sys.exit('usage: {prog} FOLDER_NUMBER  (frames are read from ./<FOLDER_NUMBER - 1>/)'.format(prog=sys.argv[0]))
n = int(sys.argv[1]) - 1
###
# Coordinate grids on [-1, 1] x [-1, 1]:
#   * large grid: 240 x 240 points covering the whole frame;
#   * small grid: st x st points, the input of every patch generator.
# Each grid tensor is (1, 2, size, size) on the GPU, channel 0 = x (varies
# along columns), channel 1 = y (varies along rows).


def _make_coordinate_grid(size):
    """Return ``(xs, ys, grid_hw2, grid_tensor)`` for a size x size grid.

    ``grid_hw2`` is the NumPy ``(size, size, 2)`` stack of ``np.meshgrid(xs, ys)``;
    ``grid_tensor`` holds the same data channels-first as a
    ``(1, 2, size, size)`` tensor of type ``dtype`` moved to the GPU.
    """
    xs = np.linspace(-1, 1, size)
    ys = np.linspace(-1, 1, size)
    grid_hw2 = np.stack(np.meshgrid(xs, ys), -1)
    grid_tensor = np_to_torch(grid_hw2.transpose(2, 0, 1)).type(dtype).cuda()
    return xs, ys, grid_hw2, grid_tensor


# Large 240 x 240 grid. xy_grid_batch_var2 is an independent copy (the learned
# deformation is later measured against it); model_input2 is a detached clone
# that is fed to the networks.
xy_grid_batch2 = []
coords_x2, coords_y2, xy_grid2, xy_grid_var2 = _make_coordinate_grid(240)
xy_grid_batch_var2 = xy_grid_var2.repeat(1, 1, 1, 1)
grid_input_single_gd2 = xy_grid_var2.detach().clone()
model_input2 = grid_input_single_gd2

# Small st x st grid, shared by all patch generators.
xy_grid_batch = []
coords_x, coords_y, xy_grid, xy_grid_var = _make_coordinate_grid(st)
xy_grid_batch_var = xy_grid_var.repeat(1, 1, 1, 1)
grid_input_single_gd = xy_grid_var.detach().clone()
model_input = grid_input_single_gd
###
# Read the pair of consecutive frames from './<n>/'. Each frame becomes a
# (1, 240, 240) float32 GPU tensor scaled by 1/255:
# ground_truth = frame 0 (0.png), ground_truth1 = frame 1 (1.png).
image0 = io.imread('./{n}/0.png'.format(n = n))
image1 = io.imread('./{n}/1.png'.format(n = n))
# reshape(-1, 240, 240) accepts any array whose size is a multiple of 240*240,
# so e.g. an RGB (240, 240, 3) frame would silently become 3 scrambled
# "channels". Require exactly one 240x240 channel instead.
for _frame_name, _frame in (('0.png', image0), ('1.png', image1)):
    if np.size(_frame) != 240 * 240:
        raise ValueError(
            './{n}/{name}: expected a single-channel 240x240 image, got shape {shape}'
            .format(n=n, name=_frame_name, shape=np.shape(_frame)))
# NOTE(review): dividing by 255 assumes 8-bit frames — confirm.
image = image0
images_warp_np = np.array(image).reshape(-1,240,240)
images_warp_np = np.array(images_warp_np/255,dtype = np.float32)
img_gt_batch_var = torch.from_numpy(images_warp_np).type(dtype).cuda()
ground_truth = img_gt_batch_var

image = image1
images_warp_np1 = np.array(image1).reshape(-1,240,240)
images_warp_np1 = np.array(images_warp_np1/255,dtype = np.float32)

img_gt_batch_var1 = torch.from_numpy(images_warp_np1).type(dtype).cuda()
ground_truth1 = img_gt_batch_var1
###
# Global MLP image generator: evaluated on the large 240 x 240 grid.
img_siren = Siren(in_features=2, hidden_features=256, hidden_layers=4,
                  out_features=1, outermost_linear=True).cuda()

###
# One local MLP image generator per st x st patch, arranged as an
# st00 x st00 grid. Creation is row-major, which fixes the order of the
# random weight initialisations.
models = []
for v1 in range(st00):
    row = []
    for v2 in range(st00):
        row.append(Siren(in_features=2, hidden_features=256, hidden_layers=4,
                         out_features=1, outermost_linear=True).cuda())
    models.append(row)

###
# np.array turns the nested list into an (st00, st00) object array so the
# generators can be indexed as models[v1, v2].
models = np.array(models)
# Parameter groups for the joint optimiser: the global generator first, then
# every patch generator in row-major order.
model_params_list = [{'params': img_siren.parameters()}]
model_params_list.extend(
    {'params': models[v1, v2].parameters()}
    for v1 in range(st00)
    for v2 in range(st00)
)
###
# Phase 1: jointly fit the patch generators and the global generator.
#  * z00 tiles the outputs of the st00 x st00 patch generators (each evaluated
#    on the small st x st grid) into a full 240 x 240 image;
#  * z01 is the global generator evaluated patch by patch on the matching
#    slices of the large grid.
# The loss (L1 + MSE for each pair) pulls z00 and z01 towards frame 0 and
# towards each other. Only the z01-vs-frame-0 term (loss0) is logged.
# The target of each patch is simply the matching slice of ground_truth, so
# no per-patch target tensor needs to be built inside the loop.
loss00 = []
total_steps = 3001
optim = torch.optim.Adam(lr=1e-4, params=model_params_list)
for step in range(total_steps):
    # zeros_like keeps ground_truth's device, so both buffers are already on the GPU.
    z00 = torch.zeros_like(ground_truth)
    z01 = torch.zeros_like(ground_truth)
    for v1, i in enumerate(range(st0, 240, st)):
        for v2, j in enumerate(range(st0, 240, st)):
            rows = slice(i - st0, i + st0)
            cols = slice(j - st0, j + st0)

            model_output, h = models[v1, v2](model_input)
            z00[:, rows, cols] = model_output[0]

            model_output, h = img_siren(model_input2[:, :, rows, cols])
            z01[:, rows, cols] = model_output[0]

    # Global generator vs frame 0: logged, and reused as part of the total loss.
    loss0 = F.l1_loss(z01, ground_truth) + F.mse_loss(z01, ground_truth)
    loss = (F.l1_loss(z00, ground_truth) + F.mse_loss(z00, ground_truth)
            + F.l1_loss(z01, z00) + F.mse_loss(z01, z00)
            + loss0)

    if step % 25 == 0:
        print('Epoch %d, loss = %.06f' % (step, float(loss0)))
        loss00.append(round(float(loss0), 6))

    optim.zero_grad()
    loss.backward()
    optim.step()

# Phase 2: refine the global generator alone on the full 240 x 240 grid.
# The Adam optimiser from phase 1 is reused as is: its parameter groups
# (img_siren plus every patch generator) were fixed when it was built, so
# nothing needs to be added to it. Only img_siren receives gradients here.
total_steps = 301
for step in range(total_steps):
    model_output2, h = img_siren(model_input2)
    prediction = model_output2[0]
    loss = F.l1_loss(prediction, ground_truth) + F.mse_loss(prediction, ground_truth)

    if step % 25 == 0:
        print('Epoch %d, loss = %.06f' % (step, float(loss)))
        loss00.append(round(float(loss), 6))

    optim.zero_grad()
    loss.backward()
    optim.step()

# Logged losses of phases 1 and 2, concatenated.
loss00 = np.array(loss00)

np.save('./{n}/loss_new_0_60_0.npy'.format(n = n), loss00)
# NOTE: despite the file name, this stores the parameters after the final
# optimisation step, not the lowest-loss ones.
torch.save(img_siren.state_dict(), './{n}/best-model-img00.pt'.format(n = n))

###

# Reload the image generator saved at the end of phase 2 (its final
# parameters) and use it as a fixed renderer. Only img_grid is optimised
# afterwards, so img_siren's parameters are frozen. Otherwise every backward
# pass would compute, and keep accumulating, gradients that are never read or
# zeroed. Gradients still flow through it to the coordinates it is evaluated on.

img_siren = Siren(in_features=2, out_features=1, hidden_features=256,
                        hidden_layers=4, outermost_linear=True).cuda()
img_siren.load_state_dict(torch.load('./{n}/best-model-img00.pt'.format(n = n)))
img_siren.eval()
img_siren.requires_grad_(False)
###
# Load the MLP grid generator and initialise it from the pre-trained weights
# in './best-model-grid02.pt' (read from the working directory, not './<n>/').
# NOTE(review): assumes conv_layers (from networks2.conv_layers) is an
# nn.Module mapping a 2-channel coordinate map to a 2-channel coordinate map;
# with need_tanh=True its output is presumably bounded to (-1, 1) — confirm.
img_grid = conv_layers(2,2, num_hidden = 256, need_sigmoid = False, need_tanh = True)
img_grid.cuda()
img_grid.load_state_dict(torch.load('./best-model-grid02.pt'))
# NOTE(review): eval mode stays on during the optimisation that follows; this
# only matters if conv_layers contains dropout / batch-norm layers — confirm.
img_grid.eval()
###
# Phase 3: optimise the grid generator so that rendering the frame-0 image
# generator at the deformed coordinates reproduces frame 1.
model_params_list = [{'params': img_grid.parameters()}]
optim = torch.optim.Adam(model_params_list, lr=1e-4)
total_steps = 5001
loss00 = []
# Keep only the lowest-loss grid, detached. Storing one graph-attached
# (1, 2, 240, 240) GPU tensor per step (5001 of them, ~2.3 GB) just to pick
# one would waste GPU memory.
best_loss = None
best_grid = None
for step in range(total_steps):
    h0 = img_grid(model_input2)
    # Small Gaussian jitter on the coordinates during training; the recorded
    # grid is the noise-free one.
    h1 = h0 + torch.randn_like(h0) * 0.0005
    # NOTE(review): the 1.1 scale presumably lets tanh-bounded outputs reach
    # slightly beyond [-1, 1] — confirm.
    grid_output0 = 1.1 * h1
    model_output0, _ = img_siren(grid_output0)
    # NOTE(review): crop 18:221 drops 18 px top/left but 19 px bottom/right —
    # confirm the asymmetry is intended.
    pred = model_output0[0][:, 18:221, 18:221]
    target = ground_truth1[:, 18:221, 18:221]
    loss = F.l1_loss(pred, target) + F.mse_loss(pred, target)
    loss_value = float(loss)
    if step % 25 == 0:
        print('Epoch %d, loss = %.09f' % (step, loss_value))
        loss00.append(round(loss_value, 6))
    # Strict '<' picks the first step attaining the minimum loss rounded to
    # 9 decimals. The grid is the one that produced this step's loss
    # (before the update).
    rounded_loss = round(loss_value, 9)
    if best_loss is None or rounded_loss < best_loss:
        best_loss = rounded_loss
        best_grid = (1.1 * h0).detach()
    optim.zero_grad()
    loss.backward()
    optim.step()

loss00 = np.array(loss00)
np.save('./{n}/loss_new_1_60_0.npy'.format(n = n), loss00)
###


# Convert the best deformed grid into a per-pixel displacement field and save
# it to './<n>/new_60_0.npy'; the grid itself is also saved to './<n>/tensor.pt'.
torch.save(best_grid, './{n}/tensor.pt'.format(n = n))
refined_xy = best_grid
# Deformed grid minus the identity grid, in normalised [-1, 1] units.
refined_warp = (refined_xy - xy_grid_batch_var2)
# linspace(-1, 1, 240) puts -1 and 1 on the first and last pixel, so one
# normalised unit is (240 - 1) / 2 pixels. Channel 0 = x, channel 1 = y.
refined_uv = (240 - 1.0) * refined_warp / 2
# (240, 240, 2) array of per-pixel displacements.
warp_img = refined_uv[0].detach().cpu().numpy().transpose(1,2,0)

np.save('./{n}/new_60_0.npy'.format(n = n), warp_img)
###

