In [None]:
# Implementation of the ADP and modified ADP frameworks for the deblurring problem with IFT-based (implicit function theorem) algorithms

import torch
import torchvision
from torch import nn, optim
from torch.autograd import Variable
from torch.autograd.functional import hessian
from torch.autograd.functional import jacobian
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor, Lambda
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.autograd.functional import hessian
from torch.autograd.functional import jacobian
import torch.nn.functional as F

# Image transformation pipeline: random horizontal flip, then conversion to a tensor.
# NOTE(review): not referenced by the 1D experiments in the visible cells.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)


# Preferred compute device.
# NOTE(review): no tensor in the visible code is moved to it, so everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Bilevel IFT ADP with smoothed TV.
# Log-scale hyperparameters read (as a global) by ADP_IFT.phi:
#   exp(theta[0]) -> weight of the smoothed TV term,
#   exp(theta[1]) -> smoothing parameter nu of the TV term,
#   exp(theta[2]) -> weight of the squared-norm penalty.
theta = torch.tensor([-2.95, -10.0, -12], requires_grad=True)


# NOTE(review): ADP_IFT.solver defines its own local step size, so this global is unused there.
stepsize = 5 * 1e-3
losslist = list()  # upper-level loss values, appended by ADP_IFT.solver

class ADP_IFT:
    """Bilevel learning of the data-term operator ``B`` with implicit differentiation.

    The lower-level reconstruction ``x(B) = argmin_x phi(x, B)`` is computed with
    L-BFGS.  The hypergradient of the upper-level loss with respect to ``B`` is
    assembled from the solution ``q`` of a Hessian system (conjugate gradients
    driven by Hessian-vector products, see ``CG``/``Hess``) and the mixed second
    derivative of ``phi`` (see ``Jac``), i.e. the implicit function theorem.

    ``phi`` reads the module-level tensor ``theta``: ``exp(theta[0])`` weights the
    smoothed TV term, ``exp(theta[1])`` is the TV smoothing parameter and
    ``exp(theta[2])`` weights ``||x||^2``.

    Args:
        A: forward operator, applied as ``A @ x``.
        B: initial value of the learned operator.
        x: ground truth; stored as ``self.x`` (with an extra dimension) but not
            used by any method.
        x0: starting point of the lower-level solver.
        y_delta: noisy data.
        modified: if True the data term is ``0.5*||B(Ax - y_delta)||^2``
            (modified ADP), otherwise ``0.5*||Bx - y_delta||^2`` (ADP).
        inexact: if True the inner-solver tolerance is multiplied by 0.9 after
            every outer iteration (floored at 1e-12).
        max_iter: maximum number of outer iterations.
        eps: initial tolerance of the L-BFGS and CG inner solvers.
    """

    def __init__(self, A, B, x, x0, y_delta, modified=True, inexact=False, max_iter=100, eps=1e-12) -> None:
        self.A = A
        self.B = B
        self.ydelta = y_delta
        self.x = x.unsqueeze(1)  # ground truth as a column; not read by the solver
        self.x0 = x0
        self.modified = modified
        self.max_iter = max_iter
        self.inexact = inexact
        self.eps = eps

    def TV(self, x, nu):
        """Smoothed 1D total variation ``sum_i sqrt((x[i+1] - x[i])^2 + nu^2)`` along dim 0."""
        return torch.sum(torch.sqrt((x[1:] - x[:-1]) ** 2 + nu ** 2))

    def TV2d(self, x, nu):
        """Smoothed anisotropic total variation of a 2D tensor ``x``.

        Sums ``sqrt(d^2 + nu^2)`` over every horizontal and every vertical
        forward difference ``d`` between neighbouring entries.
        """
        horizontal = torch.sum(torch.sqrt((x[:, 1:] - x[:, :-1]) ** 2 + nu ** 2))
        # Each row is compared with the next row (not with the last row).
        vertical = torch.sum(torch.sqrt((x[1:, :] - x[:-1, :]) ** 2 + nu ** 2))
        return horizontal + vertical

    def phi(self, x, B):
        """Lower-level objective.

        modified:  0.5*||B(A x - y_delta)||^2 + exp(theta[0])*TV(x, exp(theta[1])) + exp(theta[2])*||x||^2
        otherwise: 0.5*||B x - y_delta||^2    + exp(theta[0])*TV(x, exp(theta[1])) + exp(theta[2])*||x||^2
        """
        if self.modified:
            return (0.5 *torch.linalg.norm(torch.matmul(B,torch.matmul(self.A,x)- self.ydelta).float())**2 + torch.exp(theta[0]).float() * self.TV(x,torch.exp(theta[1])) + torch.exp(theta[2]).float() * torch.linalg.norm(x)**2 ).float()
        return (0.5 *torch.linalg.norm(torch.matmul(B,x)- self.ydelta.float())**2 + torch.exp(theta[0]).float() * self.TV(x,torch.exp(theta[1]))).float() + torch.exp(theta[2]).float() * torch.linalg.norm(x)**2

    def Hess(self, x, B, d):
        """Hessian-vector product ``(d^2 phi / dx^2)(x, B) @ d`` via double backward.

        Sets ``x.requires_grad`` in place. Returns a detached tensor shaped like ``x``.
        """
        x.requires_grad_(True)
        out = self.phi(x, B).float()
        grad_x = torch.autograd.grad(outputs=out, inputs=x, create_graph=True, retain_graph=True, allow_unused=True)[0]
        hvp = torch.autograd.grad(outputs=grad_x, inputs=x, grad_outputs=d, create_graph=True, retain_graph=True, allow_unused=True)[0]
        return hvp.detach()

    def Jac(self, x, B, d):
        """Mixed derivative ``d/dB <grad_x phi(x, B), d>``, detached, shaped like ``B``.

        Sets ``x.requires_grad`` and ``B.requires_grad`` in place.
        """
        x.requires_grad_(True)
        B.requires_grad_(True)
        out = self.phi(x, B)
        grad_x = torch.autograd.grad(outputs=out, inputs=x, create_graph=True, retain_graph=True, allow_unused=True)[0]
        gradvp = torch.tensordot(grad_x.flatten(), d, dims=([0], [0]))
        jvp = torch.autograd.grad(outputs=gradvp, inputs=B, grad_outputs=torch.ones_like(gradvp), create_graph=True, retain_graph=True, allow_unused=True)[0]
        return jvp.detach()

    def CG(self, x, B, b, tol):
        """Conjugate gradients for ``H q = b``, ``H`` the Hessian of ``phi`` in ``x`` at ``(x, B)``.

        Starts from zero, uses ``Hess`` for matrix-vector products and stops when
        the residual norm drops below ``tol`` or after 2000 iterations.
        """
        iteration = 0
        r = b
        p = r
        rsold = float(torch.linalg.norm(r) ** 2)
        solution = torch.zeros(x.shape)
        while torch.linalg.norm(r) > tol and iteration < 2000:
            # Hess already returns a detached tensor; .to avoids the copy warning of torch.tensor(tensor).
            Ap = self.Hess(x, B, p).to(torch.float32)
            alpha = rsold / torch.tensordot(p, Ap, dims=([0], [0]))
            solution = solution + alpha * p
            r = r - alpha * Ap
            rsnew = torch.linalg.norm(r) ** 2
            if torch.sqrt(rsnew) < tol:
                return solution
            p = r + (rsnew / rsold) * p
            rsold = rsnew
            iteration += 1
        return solution

    def lbfgs(self, x, B, tol, max_iter=10000):
        """Minimise ``phi(., B)`` with L-BFGS starting from ``x`` and return the minimiser.

        A detached copy of ``x`` is optimised, so the argument is not modified.
        """
        # Optimise a fresh leaf that requires grad. Passing ``x`` itself breaks when
        # it does not require grad: x.grad stays None, L-BFGS sees a zero gradient
        # and stops immediately without moving.
        x = x.clone().detach().requires_grad_(True)
        B_fixed = B.detach()  # B is a constant of the lower-level problem
        lbfgs_optimiser = torch.optim.LBFGS([x], lr=0.05, max_iter=max_iter, max_eval=None, tolerance_grad=tol, tolerance_change=tol, history_size=100, line_search_fn=None)

        def closure():
            lbfgs_optimiser.zero_grad()
            loss = self.phi(x, B_fixed)
            # Accumulate gradients into x only (theta would otherwise collect them too).
            loss.backward(inputs=[x])
            return loss

        lbfgs_optimiser.step(closure)
        return x

    def Upper_level(self, x_hat):
        """Upper-level loss ``0.5*||A x_hat - y_delta||^2`` plus a penalty on ``self.B``.

        The penalty is ``0.05*||B - I||^2`` when ``modified`` and ``0.05*||B - A||^2`` otherwise.
        Prints the current discrepancy ``||A x_hat - y_delta||``.
        """
        print ("dicrepancy ", torch.linalg.norm(torch.matmul(self.A,x_hat)- self.ydelta))
        if self.modified:
            return 0.5 * torch.linalg.norm(torch.matmul(self.A,x_hat)- self.ydelta)**2+ 0.05*torch.linalg.norm((-torch.eye(self.A.shape[0])+self.B))**2
        return 0.5 * torch.linalg.norm(torch.matmul(self.A,x_hat)- self.ydelta)**2+ 0.05*torch.linalg.norm((-self.A+self.B))**2

    def solver(self):
        """Run the bilevel optimisation of ``B`` and return the final reconstruction.

        Each outer iteration solves the lower level with L-BFGS, records the
        upper-level loss in the global ``losslist``, solves ``H q = A^T(A x - y)``
        with CG, forms the hypergradient ``p`` and takes a gradient step on ``B``
        (or an Adam step when ``stochastic``). The loop stops after ``max_iter``
        iterations or once ``||A x - y|| <= ||A x0 - y||``. The lower level is then
        re-solved with tolerance 1e-14 for the learned ``B``.
        """
        stochastic = False
        eps = self.eps
        x_hat = self.x0
        Bk = torch.nn.parameter.Parameter(data=self.B.clone().detach(), requires_grad=True)
        optimiser = torch.optim.Adam([Bk], lr= 7*1e-4, betas=(0.9, 0.999), eps=1e-10, weight_decay=0, amsgrad=True)
        stepsize = 1e-5  # step size of the plain gradient-descent update of B
        identity = torch.eye(self.A.shape[0])
        dp = torch.linalg.norm(torch.matmul(self.A, self.x0) - self.ydelta)
        print("initial discrepancy ", dp)
        for k in range(self.max_iter):
            x_hat = self.lbfgs(x_hat, Bk, eps)
            # Store a detached value so the list does not keep autograd graphs alive.
            losslist.append(self.Upper_level(x_hat).detach())
            if k % 100 == 0:
                print('loss at iteration', k, 'is', losslist[-1].item())
            rhs = torch.matmul(self.A.transpose(0, 1), torch.matmul(self.A, x_hat) - self.ydelta).detach()
            q = self.CG(x_hat, Bk, rhs, eps)
            B_val = Bk.detach()
            # NOTE(review): Upper_level penalises 0.05*||B - I||^2 (modified) or
            # 0.05*||B - A||^2 (ADP), whose gradients are 0.1*(B - I) / 0.1*(B - A);
            # the term below has an extra B^T factor and always uses I. Confirm intent.
            p = -self.Jac(x_hat, Bk, q) + 0.1 * B_val.T @ (-identity + B_val)

            if stochastic:
                with torch.no_grad():
                    param_shape = Bk.shape
                    Bk.grad = (p.reshape(param_shape))
                optimiser.step(lambda : self.Upper_level(x_hat))
                if (k+1)%20 == 0:
                    optimiser.param_groups[0]['lr'] *= 0.9
            else:
                if self.inexact:
                    eps = max(eps * 0.9, 1e-12)
                # Keep Bk a leaf: without detach each update chains onto the autograd
                # history of all previous updates.
                Bk = (Bk - stepsize * p).detach().requires_grad_(True)

            # Record B before a possible break so self.B matches the B used below.
            self.B = Bk
            # NOTE(review): stops once the discrepancy does not exceed the initial
            # one; confirm this is the intended stopping rule.
            if torch.linalg.norm(torch.matmul(self.A, x_hat) - self.ydelta) <= dp:
                break
        return self.lbfgs(x_hat, Bk, 1e-14, max_iter=10000)

In [None]:
# Discrete setting: midpoint grid t on (-1, 1) with spacing h.
h = 0.01
t = torch.arange(-1 + h / 2, 1, h, dtype=torch.float)
bsp = 0  # ground-truth selector: 0 piecewise constant, 1 two Gaussian bumps, 2 piecewise linear
# torch.manual_seed(0)
if bsp == 0:
    x = 1 * (t > -0.6) - 0.7 * (t > -0.5) - 0.3 * (t > 0) + 0.7 * (t > 0.2) - 0.7 * (t > 0.5)
elif bsp == 1:
    x = 0.8 * torch.exp(-32 * (t + 0.3) ** 2) + 0.4 * torch.exp(-16 * (t - 0.1) ** 2)
elif bsp == 2:
    x = (
        (t + 0.7) * (t > -0.7)
        - (2 * t + 1) * (t > -0.5)
        + (t + 0.3) * (t > -0.3)
        + (0.9 - t) * (t > -0.1)
        + (t - 0.2) * (t > 0.2)
        - 0.7 * (t > 0.4)
        + (3 * t - 1.8) * (t > 0.6)
        - (3 * t - 1.8) * (t > 0.8)
    )

# Gaussian blur matrix (standard deviation 5 grid points), evaluated in float64
# and stored as float32: A[i, j] = exp(-(i - j)^2 / 50) / (sqrt(2*pi) * 5).
n_grid = x.shape[0]
blur_scale = 1 / (np.sqrt(2 * np.pi) * 5)
A = torch.tensor(
    [[blur_scale * np.exp(-1 / 50 * (i - j) ** 2) for j in range(n_grid)] for i in range(n_grid)],
    dtype=torch.float,
)
B = torch.eye(A.shape[0], dtype=torch.float)
y = A @ x.reshape(-1, 1)

# Additive Gaussian noise with standard deviation 0.005.
eta = 0.005 * torch.randn(y.shape, dtype=y.dtype)
print('noise', torch.linalg.norm(eta))
ydelta = y + eta

# Starting point: Tikhonov solution of (A^T A + 0.0015 I) x = A^T ydelta.
M = A.T @ A + 0.0015 * torch.eye(x.shape[0])
b = A.T @ ydelta
x_tik = torch.linalg.solve(M, b)

for curve, curve_label in ((x, 'Ground truth'), (ydelta, 'Blurry and noisy data'), (x_tik, 'Tikhonov')):
    plt.plot(t, curve, label=curve_label)
plt.legend()
plt.savefig('G_B_T.png', dpi=300)
plt.show()



In [None]:
def psnr(x, x_hat, max_val=1.0):
    """Peak signal-to-noise ratio in dB between reference ``x`` and estimate ``x_hat``.

    Computes ``10*log10(max_val**2 / mean((x - x_hat)**2))``; the default
    ``max_val=1.0`` gives the original ``10*log10(1/mse)``. Tensors or
    array-likes (e.g. NumPy arrays) are accepted. The difference is broadcast,
    so pass inputs of the same shape (a (n,) vs (n, 1) pair would compare every
    pair of entries).
    """
    x = torch.as_tensor(x)
    x_hat = torch.as_tensor(x_hat)
    mse = torch.mean((x - x_hat) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)
# xtik = torch.randn(x.shape, dtype=x.dtype)

# Modified ADP: the learned operator B acts on the residual A x - ydelta, starting from B = I.
x_start = torch.clone(x_tik)
model = ADP_IFT(A, B, x, x_start, ydelta, max_iter=1000, inexact=True, eps=1e-1)
x_k = model.solver().detach().numpy()
torch.save(model.B, 'B_Mod.pt')
np.save('ADPIFT2.npy', x_k)
# x_k = np.load('ADPIFT2.npy')
plt.plot(t, x_k, label='Modified ADP_IFT', linestyle='--', color='red')
plt.plot(t, x, label='Ground truth')
# x_k is a column (n, 1) and x is (n,): reshape so the PSNR compares entry-wise.
print('PSNR of modified ADP_IFT is', psnr(x, torch.from_numpy(x_k).reshape(x.shape)).item())

# Classical ADP: the learned operator B replaces A, starting from B = A.
# Pass a fresh copy of x_tik (not x_tik itself) so the starting point of later
# cells cannot be modified by the solver.
x_start = torch.clone(x_tik)
B = A.clone().detach()
model2 = ADP_IFT(A, B, x, x_start, ydelta, modified=False, inexact=True, max_iter=1000, eps=1e-1)
x_k = model2.solver().detach().numpy()
np.save('ADP2.npy', x_k)
torch.save(model2.B, 'B.pt')  # B learned by the classical ADP model
# x_k = np.load('ADP.npy')
plt.plot(t, x_k, label='ADP_IFT', linestyle='dotted', color='green')
plt.legend()
print('PSNR of ADP_IFT is', psnr(x, torch.from_numpy(x_k).reshape(x.shape)).item())
plt.savefig('ADPvsModifiedADP2.png', dpi=300)
plt.show()

In [None]:
#################################
#
# ADP solved by an unrolled like algorithm
#
#################################
def proxl1(x, alph):
    """Soft-thresholding, the proximal map of ``alph * ||.||_1`` for ``alph >= 0``.

    Positive entries are reduced by ``alph`` and negative entries increased by
    ``alph``; entries whose magnitude does not exceed ``alph`` become zero.
    """
    positive_part = F.relu(F.relu(x) - alph)
    negative_part = F.relu(F.relu(-x) - alph)
    return positive_part - negative_part
Bk = A
# Bk = 1.1*torch.eye(xk.shape[0], dtype=A.dtype) - 0.1*A

al_adp_l1 = 0.108  # weight of the l1 term of the elastic-net regulariser
al_adp_l2 = 0.027  # weight of the squared-l2 term (enters the gradient as al_adp_l2 * x)
lamb_B = 0.1       # step size of the upper-level update of B
lamb_x = 0.1       # step size of the lower-level proximal-gradient iterations
lamb = 0.05        # step size of the warm-start proximal-gradient iterations


def _elastic_net_warm_start(x_init):
    """Proximal-gradient elastic-net reconstruction with the true operator ``A``.

    Iterates ``x <- proxl1(x - lamb*(A^T(A x - ydelta) + al_adp_l2*x), al_adp_l1*lamb)``
    for at most 10000 steps. When the largest entry-wise change drops below 1e-6
    the iteration index is printed and the previous iterate is returned.
    """
    xk = x_init
    for k in range(10000):
        xkn = xk - lamb * (A.T @ (A @ xk - ydelta)) - lamb * al_adp_l2 * xk
        xkn = proxl1(xkn, al_adp_l1 * lamb)
        if (xkn - xk).abs().max() < 0.000001:
            print(k)
            break
        xk = xkn
    return xk


def _adjoint(step, IdBB, xi):
    """Adjoint vector ``v`` used to form the hypergradient of the elastic-net lower level.

    Coordinates with ``|step| < al_adp_l1`` (those the prox maps to zero) are
    treated as inactive: their rows of ``IdBB`` are zeroed (in place) before
    solving ``(I - IdBB) v = xi`` and their entries of ``v`` are zeroed afterwards.
    """
    inactive = step.abs().reshape(-1) < al_adp_l1
    IdBB[inactive] = 0
    identity = torch.eye(IdBB.shape[0], dtype=IdBB.dtype)
    v = torch.linalg.solve(identity - IdBB, xi)
    v[inactive] = 0
    return v


xk = _elastic_net_warm_start(x_tik)

eye = torch.eye(xk.shape[0], dtype=xk.dtype)
# Elastic net ADP: lower level 0.5*||Bk x - ydelta||^2 + elastic net,
# upper-level gradient in x given by A^T(A x - ydelta).
for k in range(1, 7000):
    # Lower level: x(Bk) by proximal gradient, warm-started from the previous x.
    for j in range(500):
        xkn = xk - lamb_x * (Bk.T @ (Bk @ xk - ydelta)) - lamb_x * al_adp_l2 * xk
        xkn = proxl1(xkn, al_adp_l1 * lamb_x)
        if (xkn - xk).abs().sum() < 0.000001:
            break
        xk = xkn

    # Gradient of the upper-level loss w.r.t. Bk via the implicit derivative of x(Bk).
    xi = A.T @ (A @ xkn - ydelta)
    step = (1 - al_adp_l2) * xkn - Bk.T @ (Bk @ xkn - ydelta)
    IdBB = (1 - al_adp_l2) * eye - Bk.T @ Bk
    v = _adjoint(step, IdBB, xi)

    first = h * ((Bk @ xk) @ v.T)
    second = h * ((Bk @ v) @ xk.T)
    third = h * (ydelta @ v.T)
    # fourth = 2 *0.05*B.T @(B- A)
    gradB = -first - second + third

    Bk = Bk - lamb_B * gradB

plt.plot(t, xk.detach().numpy(), label='ADP', linestyle='dotted', color='green')
plt.plot(t, x, label='Ground truth')
# xk is a column (n, 1) and x is (n,): flatten so the PSNR compares entry-wise.
print('PSNR of ADP is', psnr(x, xk.reshape(-1)).item())

# Elastic net ADP modified: lower level 0.5*||Bk (A x - ydelta)||^2 + elastic net.
Bk = torch.eye(x_tik.shape[0], dtype=A.dtype)
xk = _elastic_net_warm_start(x_tik)

for k in range(1, 7000):
    ATBT = A.T @ Bk.T  # fixed while Bk is held constant in the lower level
    for j in range(500):
        xkn = xk - lamb_x * (ATBT @ (Bk @ (A @ xk - ydelta))) - lamb_x * al_adp_l2 * xk
        xkn = proxl1(xkn, al_adp_l1 * lamb_x)
        if (xkn - xk).abs().sum() < 0.000001:
            break
        xk = xkn

    xi = A.T @ (A @ xkn - ydelta)
    step = (1 - al_adp_l2) * xkn - ATBT @ (Bk @ (A @ xkn - ydelta))
    IdBB = (1 - al_adp_l2) * eye - ATBT @ (Bk @ A)
    v = _adjoint(step, IdBB, xi)

    # Same construction as the ADP branch, now for the data term with B(Ax - y):
    # d/dBk <Bk r, Bk A v> = (Bk A v) r^T + (Bk r)(A v)^T with r = A xk - ydelta,
    # built from the current Bk and iterate xk.
    r = A @ xk - ydelta
    Av = A @ v
    first = h * ((Bk @ r) @ Av.T)
    second = h * ((Bk @ Av) @ r.T)
    # Term pulling Bk towards I, same form 0.1 * B^T (B - I) as in ADP_IFT.solver.
    third = 2 * 0.05 * Bk.T @ (Bk - eye)
    gradB = -first - second + third

    Bk = Bk - lamb_B * gradB
plt.plot(t, xk.detach().numpy(), label='Modified ADP', linestyle='dashed', color='red')
plt.legend()
plt.savefig('ADPvsModifiedADP_ElasticNet.png', dpi=300)
print('PSNR of ADP modified: ', psnr(x, xk.detach().reshape(-1)).item())
plt.show()
 

In [None]:
# Implementation of the ADP and modified ADP frameworks for the deblurring problem with an IFT-based (implicit function theorem) algorithm

import torch
import torchvision
from torch import nn, optim
from torch.autograd import Variable
from torch.autograd.functional import hessian
from torch.autograd.functional import jacobian
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor, Lambda
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.autograd.functional import hessian
from torch.autograd.functional import jacobian
import torch.nn.functional as F

# Image transformation pipeline: random horizontal flip, then conversion to a tensor.
# NOTE(review): `transform` is not referenced elsewhere in this cell.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# Preferred compute device.
# NOTE(review): no tensor in this cell is moved to it, so everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Bilevel IFT ADP with smoothed TV.
# Log-scale hyperparameters read (as a global) by ADP_IFT.phi:
#   exp(theta[0]) -> weight of the smoothed TV term,
#   exp(theta[1]) -> smoothing parameter nu of the TV term,
#   exp(theta[2]) -> weight of the squared-norm penalty.
# theta = torch.tensor([-7.0, -7.9, -12], requires_grad=True)
theta = torch.tensor([-8.7, -11.7, -12], requires_grad=True)

# NOTE(review): ADP_IFT.solver defines its own local step size, so this global is unused there.
stepsize = 5 * 1e-3
losslist = list()  # upper-level loss values, appended by ADP_IFT.solver

class ADP_IFT:
    """Bilevel learning of the data-term operator ``B`` for 2D images via implicit differentiation.

    The lower-level reconstruction ``x(B) = argmin_x phi(x, B)`` is computed with
    L-BFGS.  The hypergradient of the upper-level loss with respect to ``B`` is
    assembled from the solution ``q`` of a Hessian system (conjugate gradients
    driven by Hessian-vector products, see ``CG``/``Hess``) and the mixed second
    derivative of ``phi`` (see ``Jac``), i.e. the implicit function theorem.

    ``phi`` reads the module-level tensor ``theta``: ``exp(theta[0])`` weights the
    smoothed 2D TV term, ``exp(theta[1])`` is the TV smoothing parameter and
    ``exp(theta[2])`` weights ``||x||^2``.

    Args:
        A: forward operator, applied as ``A @ x``.
        B: initial value of the learned operator.
        x: ground truth; stored as ``self.x`` (with an extra dimension) but not
            used by any method.
        x0: starting point of the lower-level solver.
        y_delta: noisy data.
        modified: if True the data term is ``0.5*||B(Ax - y_delta)||^2``
            (modified ADP), otherwise ``0.5*||Bx - y_delta||^2`` (ADP).
        inexact: if True the inner-solver tolerance is multiplied by 0.9 after
            every outer iteration (floored at 1e-12).
        max_iter: maximum number of outer iterations.
        eps: initial tolerance of the L-BFGS and CG inner solvers.
    """

    def __init__(self, A, B, x, x0, y_delta, modified=True, inexact=False, max_iter=100, eps=1e-12) -> None:
        self.A = A
        self.B = B
        self.ydelta = y_delta
        self.x = x.unsqueeze(1)  # ground truth with an extra dimension; not read by the solver
        self.x0 = x0
        self.modified = modified
        self.max_iter = max_iter
        self.inexact = inexact
        self.eps = eps

    def TV(self, x, nu):
        """Smoothed 1D total variation ``sum_i sqrt((x[i+1] - x[i])^2 + nu^2)`` along dim 0."""
        return torch.sum(torch.sqrt((x[1:] - x[:-1]) ** 2 + nu ** 2))

    def TV2d(self, x, nu):
        """Smoothed isotropic total variation of a 2D tensor ``x``.

        Forward differences along both axes are taken at the same pixel (Neumann
        boundary: the difference past the last row/column is zero) and
        ``sqrt(dx^2 + dy^2 + nu^2)`` is summed over all pixels.
        """
        # Pad the differences back to x's shape so dx[i, j] and dy[i, j] both belong
        # to pixel (i, j); flattening the unpadded (n-1, m) and (n, m-1) arrays and
        # adding them would pair differences taken at different pixels.
        dx = F.pad(x[1:, :] - x[:-1, :], (0, 0, 0, 1))
        dy = F.pad(x[:, 1:] - x[:, :-1], (0, 1))
        return torch.sum(torch.sqrt(dx ** 2 + dy ** 2 + nu ** 2))

    def phi(self, x, B):
        """Lower-level objective.

        modified:  0.5*||B(A x - y_delta)||^2 + exp(theta[0])*TV2d(x, exp(theta[1])) + exp(theta[2])*||x||^2
        otherwise: 0.5*||B x - y_delta||^2    + exp(theta[0])*TV2d(x, exp(theta[1])) + exp(theta[2])*||x||^2
        """
        if self.modified:
            return (0.5 *torch.linalg.norm(torch.matmul(B,torch.matmul(self.A,x)- self.ydelta).float())**2 + torch.exp(theta[0]).float() * self.TV2d(x,torch.exp(theta[1])) + torch.exp(theta[2]).float() * torch.linalg.norm(x)**2 ).float()
        return (0.5 *torch.linalg.norm(torch.matmul(B,x)- self.ydelta.float())**2 + torch.exp(theta[0]).float() * self.TV2d(x,torch.exp(theta[1]))).float() + torch.exp(theta[2]).float() * torch.linalg.norm(x)**2

    def Hess(self, x, B, d):
        """Hessian-vector product ``(d^2 phi / dx^2)(x, B) @ d`` via double backward.

        Sets ``x.requires_grad`` in place. Returns a detached tensor shaped like ``x``.
        """
        x.requires_grad_(True)
        out = self.phi(x, B).float()
        grad_x = torch.autograd.grad(outputs=out, inputs=x, create_graph=True, retain_graph=True, allow_unused=True)[0]
        hvp = torch.autograd.grad(outputs=grad_x, inputs=x, grad_outputs=d, create_graph=True, retain_graph=True, allow_unused=True)[0]
        return hvp.detach()

    def Jac(self, x, B, d):
        """Mixed derivative ``d/dB <grad_x phi(x, B), d>`` for 2D ``x`` and ``d``, detached.

        Sets ``x.requires_grad`` and ``B.requires_grad`` in place.
        """
        x.requires_grad_(True)
        B.requires_grad_(True)
        out = self.phi(x, B)
        grad_x = torch.autograd.grad(outputs=out, inputs=x, create_graph=True, retain_graph=True, allow_unused=True)[0]
        gradvp = torch.tensordot(grad_x, d, dims=([0, 1], [0, 1])).reshape(-1)
        jvp = torch.autograd.grad(outputs=gradvp, inputs=B, grad_outputs=torch.ones_like(gradvp), create_graph=True, retain_graph=True, allow_unused=True)[0]
        return jvp.detach()

    def CG(self, x, B, b, tol):
        """Conjugate gradients for ``H q = b`` (2D ``b``), ``H`` the Hessian of ``phi`` in ``x``.

        Starts from zero, uses ``Hess`` for matrix-vector products and stops when
        the residual norm drops below ``tol`` or after 2000 iterations.
        """
        iteration = 0
        r = b
        p = r
        rsold = float((torch.linalg.norm(r) ** 2).item())
        solution = torch.zeros(x.shape)
        while torch.linalg.norm(r) > tol and iteration < 2000:
            # Hess already returns a detached tensor; .to avoids the copy warning of torch.tensor(tensor).
            Ap = self.Hess(x, B, p).to(torch.float32)
            alpha = rsold / torch.tensordot(p, Ap, dims=([0, 1], [0, 1]))
            solution = solution + alpha * p
            r = r - alpha * Ap
            rsnew = torch.linalg.norm(r) ** 2
            if torch.sqrt(rsnew) < tol:
                return solution
            p = r + (rsnew / rsold) * p.clone().detach()
            rsold = rsnew
            iteration += 1
        return solution

    def lbfgs(self, x, B, tol, max_iter=10000):
        """Minimise ``phi(., B)`` with L-BFGS starting from ``x``; returns a tensor shaped like ``x``.

        A detached, flattened copy of ``x`` is optimised, so the argument is not modified.
        """
        shape = x.shape
        x = x.flatten().clone().detach().requires_grad_(True)
        B_fixed = B.detach()  # B is a constant of the lower-level problem
        lbfgs_optimiser = torch.optim.LBFGS([x], lr=0.05, max_iter=max_iter, max_eval=None, tolerance_grad=tol, tolerance_change=tol, history_size=100, line_search_fn=None)

        def closure():
            lbfgs_optimiser.zero_grad()
            loss = self.phi(x.reshape(shape), B_fixed)
            # Accumulate gradients into x only (theta would otherwise collect them too).
            loss.backward(inputs=[x])
            return loss

        lbfgs_optimiser.step(closure)
        return x.reshape(shape)

    def Upper_level(self, x_hat):
        """Upper-level loss ``0.5*||A x_hat - y_delta||^2`` plus a penalty on ``self.B``.

        The penalty is ``0.05*||B - I||^2`` when ``modified`` and ``0.05*||B - A||^2`` otherwise.
        Prints the current discrepancy ``||A x_hat - y_delta||``.
        """
        print ("dicrepancy ", torch.linalg.norm(torch.matmul(self.A,x_hat)- self.ydelta))
        if self.modified:
            return 0.5 * torch.linalg.norm(torch.matmul(self.A,x_hat)- self.ydelta)**2+ 0.05*torch.linalg.norm((-torch.eye(self.A.shape[0])+self.B))**2
        return 0.5 * torch.linalg.norm(torch.matmul(self.A,x_hat)- self.ydelta)**2+ 0.05*torch.linalg.norm((-self.A+self.B))**2

    def solver(self):
        """Run the bilevel optimisation of ``B`` and return the final reconstruction.

        Each outer iteration solves the lower level with L-BFGS, records the
        upper-level loss in the global ``losslist``, solves ``H q = A^T(A x - y)``
        with CG, forms the hypergradient ``p`` and takes a gradient step on ``B``
        (or an Adam step when ``stochastic``). The loop stops after ``max_iter``
        iterations or once ``||A x - y|| <= ||A x0 - y||``. The lower level is then
        re-solved with tolerance 1e-14 for the learned ``B``.
        """
        stochastic = False
        eps = self.eps
        x_hat = self.x0
        Bk = torch.nn.parameter.Parameter(data=self.B.clone().detach(), requires_grad=True)
        optimiser = torch.optim.Adam([Bk], lr= 7*1e-4, betas=(0.9, 0.999), eps=1e-10, weight_decay=0, amsgrad=True)
        # optimiser = torch.optim.SGD([Bk], lr= 9*1e-4, momentum=0.9, dampening=0, weight_decay=0, nesterov=True)
        stepsize = 1e-5  # step size of the plain gradient-descent update of B
        identity = torch.eye(self.A.shape[0])
        dp = torch.linalg.norm(torch.matmul(self.A, self.x0) - self.ydelta)
        print("initial discrepancy ", dp)
        for k in range(self.max_iter):
            x_hat = self.lbfgs(x_hat, Bk, eps)
            # Store a detached value so the list does not keep autograd graphs alive.
            losslist.append(self.Upper_level(x_hat).detach())
            if k % 10 == 0:
                print('loss at iteration', k, 'is', losslist[-1].item())
            rhs = torch.matmul(self.A.transpose(0, 1), torch.matmul(self.A, x_hat) - self.ydelta).squeeze(0).squeeze(0).detach()
            q = self.CG(x_hat, Bk, rhs, eps)
            B_val = Bk.detach()
            # NOTE(review): Upper_level penalises 0.05*||B - I||^2 (modified) or
            # 0.05*||B - A||^2 (ADP), whose gradients are 0.1*(B - I) / 0.1*(B - A);
            # the term below has an extra B^T factor and always uses I. Confirm intent.
            p = -self.Jac(x_hat, Bk, q) + 0.1 * B_val.T @ (-identity + B_val)

            if stochastic:
                with torch.no_grad():
                    param_shape = Bk.shape
                    Bk.grad = (p.reshape(param_shape))
                optimiser.step(lambda : self.Upper_level(x_hat))
                if (k+1)%20 == 0:
                    optimiser.param_groups[0]['lr'] *= 0.9
            else:
                if self.inexact:
                    eps = max(eps * 0.9, 1e-12)
                # Keep Bk a leaf: without detach each update chains onto the autograd
                # history of all previous updates.
                Bk = (Bk - stepsize * p).detach().requires_grad_(True)

            # Record B before a possible break so self.B matches the B used below.
            self.B = Bk
            # NOTE(review): stops once the discrepancy does not exceed the initial
            # one; confirm this is the intended stopping rule.
            if torch.linalg.norm(torch.matmul(self.A, x_hat) - self.ydelta) <= dp:
                break
        return self.lbfgs(x_hat, Bk, 1e-14, max_iter=10000)

def psnr(x, x_hat, max_val=1.0):
    """Peak signal-to-noise ratio in dB between reference ``x`` and estimate ``x_hat``.

    Computes ``10*log10(max_val**2 / mean((x - x_hat)**2))``; the default
    ``max_val=1.0`` gives the original ``10*log10(1/mse)``. Tensors or
    array-likes (e.g. NumPy arrays) are accepted; pass inputs of the same shape.
    """
    x = torch.as_tensor(x)
    x_hat = torch.as_tensor(x_hat)
    mse = torch.mean((x - x_hat) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)


def _gaussian_blur_matrix(n, sigma):
    """Return the n x n float32 matrix ``exp(-(i-j)^2 / (2 sigma^2)) / (sqrt(2 pi) sigma)``.

    Evaluated in float64 (vectorised) before the cast to float32.
    """
    idx = torch.arange(n, dtype=torch.float64)
    diff2 = (idx[:, None] - idx[None, :]) ** 2
    scale = float(1 / (np.sqrt(2 * np.pi) * sigma))
    return (scale * torch.exp(-1 / (2 * sigma ** 2) * diff2)).to(torch.float)


input_size = 512
x = Image.open('SheppLogan_Phantom.png')
# x = Image.open('Cameraman.png')
x = transforms.Resize((input_size, input_size))(x)
x = transforms.Grayscale(num_output_channels=1)(x)
x = transforms.ToTensor()(x).clamp(0, 1)
plt.imshow(x[0].detach().numpy(), cmap='gray')
plt.show()

h = 0.01
t = torch.arange(-1 + h / 2, 1, h, dtype=torch.float)
bsp = 0
sigma = 3.0  # Adjust the blurring sigma here

# Blur is applied as A @ image, i.e. it mixes the rows of each column.
A = _gaussian_blur_matrix(x.shape[2], sigma)

B = torch.eye(A.shape[0], dtype=torch.float)
y = torch.matmul(A, x[0])
eta = 0.005 * torch.randn(y.shape, dtype=y.dtype)
print('noise:', torch.linalg.norm(eta))
ydelta = y + eta

torch.save(ydelta, 'ydelta.pt')

plt.imshow(ydelta.detach().numpy(), cmap='gray')
plt.show()
print('PSNR of noisy image is', psnr(x[0], ydelta))
# Starting point
M = torch.matmul(A.transpose(0, 1), A) + 0.0015 * torch.eye(x.shape[2])
b = torch.matmul(A.transpose(0, 1), ydelta)
x_tik = torch.linalg.solve(M, b)
plt.imshow(x_tik.detach().numpy(), cmap='gray')
plt.show()
print('PSNR of Tikhonov is', psnr(x[0], x_tik).item())
torch.save(x_tik, 'x_tik.pt')

# Modified ADP: learned B acts on the residual, starting from B = I.
x_start = torch.clone(x_tik)
model = ADP_IFT(A, B, x, x_start, ydelta, max_iter=200, inexact=True, eps=1e-1, modified=True)
x_k = model.solver().detach().numpy()
np.save('x_k_mod.npy', x_k)
plt.imshow(x_k, cmap='gray')
plt.show()
# torch.save(model.B, 'B_mod_2d.pt')
print('PSNR of modified ADP_IFT is', psnr(x[0], torch.from_numpy(x_k)).item())

# Classical ADP: learned B replaces A (its upper-level penalty is ||B - A||^2),
# so it starts from B = A as in the 1D experiment.
x_start = torch.clone(x_tik)
model = ADP_IFT(A, A.clone().detach(), x, x_start, ydelta, max_iter=200, inexact=True, eps=1e-1, modified=False)
x_k = model.solver().detach().numpy()
torch.save(model.B, 'B_2d.pt')
np.save('x_k_ADP.npy', x_k)
plt.imshow(x_k, cmap='gray')
plt.show()
print('PSNR of ADP_IFT is', psnr(x[0], torch.from_numpy(x_k)).item())

plot = True
if plot:
    # The saved results were computed on ydelta's grid, so reload the ground truth
    # at that resolution; a different size makes the PSNR comparisons fail.
    image2 = torch.load('ydelta.pt').detach().numpy()
    input_size = image2.shape[-1]
    x = Image.open('SheppLogan_Phantom.png')
    x = transforms.Resize((input_size, input_size))(x)
    x = transforms.Grayscale(num_output_channels=1)(x)
    x = transforms.ToTensor()(x).clamp(0, 1)
    image1 = x[0].detach().numpy()
    psnr_blur = psnr(x[0], torch.from_numpy(image2)).item()

    # Ground truth next to the blurred, noisy data
    fig, axes = plt.subplots(1, 2)
    for ax in axes:
        ax.axis('off')
    axes[0].imshow(image1, cmap='gray')
    axes[0].set_title('Ground truth')
    axes[1].imshow(image2, cmap='gray')
    axes[1].set_title(f'Blurred image\n PSNR = {psnr_blur:.2f} dB')
    plt.tight_layout()
    plt.savefig('Ground_blur_phantom.png', dpi = 300)
    plt.show()

    # Tikhonov, ADP and modified ADP reconstructions with their PSNR
    image1 = torch.load('x_tik.pt').detach().numpy()
    psnr1 = psnr(x[0], torch.from_numpy(image1)).item()
    image2 = np.load('x_k_ADP.npy')
    psnr2 = psnr(x[0], torch.from_numpy(image2)).item()
    image3 = np.load('x_k_mod.npy')
    psnr3 = psnr(x[0], torch.from_numpy(image3)).item()
    fig, axes = plt.subplots(1, 3)
    for ax in axes:
        ax.axis('off')
    axes[0].imshow(image1, cmap='gray')
    axes[0].set_title(f'Tikhonov\n PSNR = {psnr1:.2f} dB')
    axes[1].imshow(image2, cmap='gray')
    axes[1].set_title(f'ADP\n PSNR = {psnr2:.2f} dB')
    axes[2].imshow(image3, cmap='gray')
    axes[2].set_title(f'Modified ADP\n PSNR = {psnr3:.2f} dB')
    plt.tight_layout()
    plt.savefig('Tik_ADP.png', dpi = 300)
    plt.show()