In [1]:
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage import io, util, color
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


  "cipher": algorithms.TripleDES,
  "class": algorithms.Blowfish,
  "class": algorithms.TripleDES,


In [2]:

class UNet(nn.Module):
    """Four-level U-Net that also exposes its bottleneck activations.

    The encoder widens the channels 64 -> 128 -> 256 -> 512, the bottleneck
    has 1024 channels, and the decoder mirrors the encoder with transposed
    convolutions and skip connections.

    Because the input is halved four times and then doubled four times, the
    spatial size should be divisible by 16 so the skip tensors line up when
    concatenated.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input image.
    out_channels : int
        Number of channels produced by the final 1x1 convolution.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Contracting path.
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        # Deepest layer; its output is returned as the latent feature map.
        self.bottleneck = self.conv_block(512, 1024)
        # Expanding path: each stage upsamples, then fuses the matching skip.
        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)
        # 1x1 projection to the requested number of output channels.
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 same-padded convolutions, each followed by ReLU."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network.

        Returns
        -------
        (out, b) : tuple of Tensor
            ``out`` is the reconstructed image and ``b`` is the bottleneck
            feature map.
        """
        # Encoder: remember every stage output for the skip connections.
        skips = []
        features = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            features = encoder(features)
            skips.append(features)
            features = nn.functional.max_pool2d(features, 2)

        b = self.bottleneck(features)

        # Decoder: walk back up, pairing each stage with the matching skip.
        upsamplers = (self.up4, self.up3, self.up2, self.up1)
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)
        features = b
        for upsample, decoder, skip in zip(upsamplers, decoders, reversed(skips)):
            fused = torch.cat((upsample(features), skip), dim=1)
            features = decoder(fused)

        return self.final(features), b  # Return bottleneck features

# Custom Dataset with Multiple Domains
class MultiDomainDataset(Dataset):
    """Image dataset that pairs every clean image with a randomly corrupted copy.

    The noise family depends on ``mode``:

    * ``'train'``: gaussian or speckle.
    * ``'val'``: the train draw with probability 1/3, salt-and-pepper otherwise.
    * ``'test'``: gaussian/speckle when ``idx % 4 == 0``; otherwise poisson
      (1 in 4) or salt-and-pepper (3 in 4).
    * any other value: no ``'noisy'`` entry is produced.

    NOTE(review): this notebook defines ``MultiDomainDataset`` a second time
    in the following cell (with narrower salt-and-pepper ranges). Under
    "Run All" the later definition replaces this one.

    Parameters
    ----------
    image_paths : sequence of str
        Paths readable by ``skimage.io.imread``.
    transform : callable, optional
        Applied to the clean uint8 image before noise is added.
    mode : str
        One of ``'train'``, ``'val'``, ``'test'``.
    """

    # (low, high) bounds of the salt-and-pepper ``amount`` per mode group.
    TRAIN_VAL_SP_AMOUNT = (0.02, 0.1)
    TEST_SP_AMOUNT = (0.04, 0.1)

    def __init__(self, image_paths, transform=None, mode='train'):
        self.image_paths = image_paths
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.image_paths)

    def _load_clean(self, idx):
        """Read image ``idx`` as a single-channel uint8 array."""
        image = io.imread(self.image_paths[idx])
        if len(image.shape) == 3:
            image = color.rgb2gray(image)
        # rgb2gray yields floats in [0, 1] that must be rescaled. An image that
        # was already integer grayscale must NOT be multiplied: uint8 * 255
        # wraps around modulo 256 and scrambles the intensities.
        if np.issubdtype(image.dtype, np.floating):
            image = (image * 255).astype(np.uint8)
        return image

    @staticmethod
    def _corrupt(image, noise_type, sp_amount):
        """Apply one ``skimage.util.random_noise`` model to ``image``."""
        if noise_type == 'gaussian':
            return util.random_noise(image, mode='gaussian', var=random.uniform(0.01, 0.05))
        if noise_type == 'speckle':
            return util.random_noise(image, mode='speckle', var=random.uniform(0.01, 0.05))
        if noise_type == 'poisson':
            return util.random_noise(image, mode='poisson')
        if noise_type == 'salt_pepper':
            return util.random_noise(image, mode='s&p', amount=random.uniform(*sp_amount))
        raise ValueError(f'unknown noise type: {noise_type!r}')

    def __getitem__(self, idx):
        """Return ``{'clean': image}`` plus ``'noisy'`` for train/val/test modes."""
        image = self._load_clean(idx)

        if self.transform:
            image = self.transform(image)

        sample = {'clean': image}

        if self.mode == 'train' or self.mode == 'val':
            noise_type = random.choice(['gaussian', 'speckle'])
            if self.mode == 'val':
                # Validation is skewed towards a noise type unseen in training.
                noise_type = random.choice([noise_type, 'salt_pepper', 'salt_pepper'])
            sp_amount = self.TRAIN_VAL_SP_AMOUNT
        elif self.mode == 'test':
            # Use specific noise types based on test requirements
            if idx % 4 == 0:
                noise_type = random.choice(['gaussian', 'speckle'])
            else:
                noise_type = random.choice(['poisson', 'salt_pepper', 'salt_pepper', 'salt_pepper'])
            sp_amount = self.TEST_SP_AMOUNT
        else:
            return sample

        noisy_image = self._corrupt(image, noise_type, sp_amount)
        sample['noisy'] = torch.tensor(noisy_image, dtype=torch.float32)
        return sample


In [3]:

# Custom Dataset with Multiple Domains
class MultiDomainDataset(Dataset):
    """Image dataset that pairs every clean image with a randomly corrupted copy.

    The noise family depends on ``mode``:

    * ``'train'``: gaussian or speckle.
    * ``'val'``: the train draw with probability 1/3, salt-and-pepper otherwise.
    * ``'test'``: gaussian/speckle when ``idx % 4 == 0``; otherwise poisson
      (1 in 4) or salt-and-pepper (3 in 4).
    * any other value: no ``'noisy'`` entry is produced.

    NOTE(review): an earlier cell defines a class with the same name whose
    salt-and-pepper amounts go up to 0.1. This later definition is the one
    in effect after "Run All".

    Parameters
    ----------
    image_paths : sequence of str
        Paths readable by ``skimage.io.imread``.
    transform : callable, optional
        Applied to the clean uint8 image before noise is added.
    mode : str
        One of ``'train'``, ``'val'``, ``'test'``.
    """

    # (low, high) bounds of the salt-and-pepper ``amount`` per mode group.
    TRAIN_VAL_SP_AMOUNT = (0.02, 0.05)
    TEST_SP_AMOUNT = (0.04, 0.05)

    def __init__(self, image_paths, transform=None, mode='train'):
        self.image_paths = image_paths
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.image_paths)

    def _load_clean(self, idx):
        """Read image ``idx`` as a single-channel uint8 array."""
        image = io.imread(self.image_paths[idx])
        if len(image.shape) == 3:
            image = color.rgb2gray(image)
        # rgb2gray yields floats in [0, 1] that must be rescaled. An image that
        # was already integer grayscale must NOT be multiplied: uint8 * 255
        # wraps around modulo 256 and scrambles the intensities.
        if np.issubdtype(image.dtype, np.floating):
            image = (image * 255).astype(np.uint8)
        return image

    @staticmethod
    def _corrupt(image, noise_type, sp_amount):
        """Apply one ``skimage.util.random_noise`` model to ``image``."""
        if noise_type == 'gaussian':
            return util.random_noise(image, mode='gaussian', var=random.uniform(0.01, 0.05))
        if noise_type == 'speckle':
            return util.random_noise(image, mode='speckle', var=random.uniform(0.01, 0.05))
        if noise_type == 'poisson':
            return util.random_noise(image, mode='poisson')
        if noise_type == 'salt_pepper':
            return util.random_noise(image, mode='s&p', amount=random.uniform(*sp_amount))
        raise ValueError(f'unknown noise type: {noise_type!r}')

    def __getitem__(self, idx):
        """Return ``{'clean': image}`` plus ``'noisy'`` for train/val/test modes."""
        image = self._load_clean(idx)

        if self.transform:
            image = self.transform(image)

        sample = {'clean': image}

        if self.mode == 'train' or self.mode == 'val':
            noise_type = random.choice(['gaussian', 'speckle'])
            if self.mode == 'val':
                # Validation is skewed towards a noise type unseen in training.
                noise_type = random.choice([noise_type, 'salt_pepper', 'salt_pepper'])
            sp_amount = self.TRAIN_VAL_SP_AMOUNT
        elif self.mode == 'test':
            # Use specific noise types based on test requirements
            if idx % 4 == 0:
                noise_type = random.choice(['gaussian', 'speckle'])
            else:
                noise_type = random.choice(['poisson', 'salt_pepper', 'salt_pepper', 'salt_pepper'])
            sp_amount = self.TEST_SP_AMOUNT
        else:
            return sample

        noisy_image = self._corrupt(image, noise_type, sp_amount)
        sample['noisy'] = torch.tensor(noisy_image, dtype=torch.float32)
        return sample


In [4]:

import random
import warnings

# Location of the BSDS500 images inside the Kaggle input mount.
BSDS500_IMAGE_DIR = '/kaggle/input/berkeley-segmentation-dataset-500-bsds500/images'
# Seed of the private RNG used for the split, so the split is repeatable
# without touching the global `random` state used for noise sampling.
SPLIT_SEED = 42
# Sizes of the re-drawn train / validation / test partitions.
NUM_TRAIN, NUM_VAL, NUM_TEST = 300, 100, 100

# Load all image paths
train_image_paths = glob.glob(f'{BSDS500_IMAGE_DIR}/train/*.jpg')
val_image_paths = glob.glob(f'{BSDS500_IMAGE_DIR}/val/*.jpg')
test_image_paths = glob.glob(f'{BSDS500_IMAGE_DIR}/test/*.jpg')

# Combine all images. glob returns files in filesystem order, so sort first;
# otherwise the seeded shuffle below would still differ between machines.
all_image_paths = sorted(train_image_paths + val_image_paths + test_image_paths)

if not all_image_paths:
    # Surface the problem here instead of letting a DataLoader fail later
    # with "num_samples=0".
    warnings.warn(f'No .jpg images found under {BSDS500_IMAGE_DIR}; all splits will be empty.')

# Shuffle the images
random.Random(SPLIT_SEED).shuffle(all_image_paths)

# Split into train, val, and test
train_image_paths = all_image_paths[:NUM_TRAIN]
val_image_paths = all_image_paths[NUM_TRAIN:NUM_TRAIN + NUM_VAL]
test_image_paths = all_image_paths[NUM_TRAIN + NUM_VAL:NUM_TRAIN + NUM_VAL + NUM_TEST]

print(f"Number of training images: {len(train_image_paths)}")
print(f"Number of validation images: {len(val_image_paths)}")
print(f"Number of testing images: {len(test_image_paths)}")

Number of training images: 0
Number of validation images: 0
Number of testing images: 0


In [None]:
# # Data Augmentation
# transform = transforms.Compose([
#     transforms.ToPILImage(),
#     transforms.Resize((256, 256)),
#     transforms.ToTensor(),
# ])

# # Create datasets
# train_dataset = MultiDomainDataset(train_image_paths, transform=transform, mode='train')
# val_dataset = MultiDomainDataset(val_image_paths, transform=transform, mode='val')
# test_dataset = MultiDomainDataset(test_image_paths, transform=transform, mode='test')

# # Create data loaders
# train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
# val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)
# test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [5]:
import warnings
# Silence every FutureWarning for the rest of the kernel session.
warnings.filterwarnings("ignore", category=FutureWarning)


In [6]:
# from torchmetrics.functional import structural_similarity_index_measure

def _uniform_window_ssim(x, y, window_size=7, data_range=1.0):
    """Mean SSIM of two image batches, computed with a uniform sliding window."""
    # Shrink the window for images smaller than it, so pooling stays valid.
    win = min(window_size, x.shape[-2], x.shape[-1])
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2

    def local_mean(t):
        return nn.functional.avg_pool2d(t, win, stride=1)

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    var_x = local_mean(x * x) - mu_x * mu_x
    var_y = local_mean(y * y) - mu_y * mu_y
    cov_xy = local_mean(x * y) - mu_x * mu_y

    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x * mu_x + mu_y * mu_y + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()


def reconstruction_loss(output, target):
    """Weighted L1 + (1 - SSIM) reconstruction loss.

    Previously the SSIM term was the constant ``1``, so it added a fixed 0.3
    to the loss and contributed no gradient. It is now a differentiable SSIM
    written in plain PyTorch.

    Parameters
    ----------
    output, target : Tensor
        Image batches of identical shape, ``(N, C, H, W)`` or ``(C, H, W)``.
        Values are expected in [0, 1]; they are clamped to that range for
        the SSIM term only.

    Returns
    -------
    Tensor
        Scalar ``0.7 * L1 + 0.3 * (1 - SSIM)``; zero for identical inputs.
    """
    # L1 Loss
    l1_loss = nn.L1Loss()(output, target)

    # SSIM Loss
    output = output.clamp(0, 1)
    target = target.clamp(0, 1)
    ssim_loss = 1 - _uniform_window_ssim(output, target, data_range=1.0)

    # Weighted combination
    return 0.7 * l1_loss + 0.3 * ssim_loss


In [7]:
import torch
from torchmin2 import minimize

def pairwise_distances(x, y, power=2, sum_dim=2):
    """Return the matrix of summed element-wise powered differences.

    Parameters
    ----------
    x : Tensor of shape (n, d)
    y : Tensor of shape (m, d)
    power : number
        Exponent applied to every coordinate difference.
    sum_dim : int
        Dimension of the (n, m, d) difference tensor that is summed away.

    Returns
    -------
    Tensor
        With the defaults, an (n, m) matrix of squared Euclidean distances.
    """
    num_x, feat_dim = x.size(0), x.size(1)
    num_y = y.size(0)
    target_shape = (num_x, num_y, feat_dim)

    # Broadcast both operands to a common (n, m, d) view without copying.
    x_grid = x.unsqueeze(1).expand(*target_shape)
    y_grid = y.unsqueeze(0).expand(*target_shape)

    return torch.pow(x_grid - y_grid, power).sum(sum_dim)

def StandardScaler(x,with_std=False):
    """Centre ``x`` column-wise and optionally scale it to unit variance.

    Parameters
    ----------
    x : Tensor
        Data with samples along dimension 0. The input is not modified.
    with_std : bool
        When True, the centred data is also divided by the population
        standard deviation (plus 1e-10 to avoid division by zero).

    Returns
    -------
    Tensor
        The centred (and possibly scaled) copy of ``x``.
    """
    column_mean = x.mean(0, keepdim=True)
    column_std = x.std(0, unbiased=False, keepdim=True)
    centred = x - column_mean
    if not with_std:
        return centred
    return centred / (column_std + 1e-10)

def H_Distance(FX,y,l,sigma=None,lamda=1e-2,device=torch.device('cpu')):
    """Sum, over all source domains, of an approximate Hellinger distance to the target domain.

    The domain with the largest label in ``l`` is treated as the target;
    every other distinct label is a source domain. For each source/target
    pair a class-masked Gaussian kernel matrix is built, a coefficient
    vector ``theta`` is fitted with L-BFGS, and the resulting divergence
    estimate is added to the total.

    Parameters
    ----------
    FX : Tensor of shape (N, D)
        Feature vectors.
    y : Tensor of shape (N,)
        Class labels; kernel entries between different classes are zeroed.
    l : Tensor of shape (N,)
        Domain labels.
    sigma : Tensor or float, optional
        Kernel bandwidth. Defaults to the median of the non-zero squared
        pairwise distances of ``FX``.
    lamda : float
        Weight of the L2 penalty on ``theta``.
    device : torch.device
        Retained for backward compatibility. ``theta`` is now allocated on
        the device (and with the dtype) of ``FX``, the only placement that
        can work with the kernel matrix.

    Returns
    -------
    Tensor or float
        The accumulated divergence (``0.0`` when there is only one domain).
    """
    if sigma is None:
        pairwise_dist = torch.cdist(FX, FX, p=2) ** 2
        sigma = torch.median(pairwise_dist[pairwise_dist != 0])

    # torch.unique returns sorted labels, so the last entry is the target.
    # Select it by VALUE, not by count: `len(domain_label) - 1` only equals
    # the target label when labels are exactly 0..k-1.
    domain_label = torch.unique(l)
    target_mask = l == domain_label[-1]
    FXt, yt = FX[target_mask], y[target_mask]
    nt = len(yt)

    def _divergence(K, theta, ns):
        """Approximation of Hellinger distance"""
        return 2. - (torch.mean(torch.exp(-torch.matmul(K[:ns], theta)))
                     + torch.mean(torch.exp(torch.matmul(K[ns:], theta))))

    div = 0.0
    for dl in domain_label[:-1]:
        source_mask = l == dl
        FXs, ys = FX[source_mask], y[source_mask]
        ns = len(ys)
        FXst = torch.cat((FXs, FXt), dim=0)
        yst = torch.cat((ys, yt), dim=0)
        FXst_norm = torch.sum(FXst ** 2, dim=-1)
        sq_dist = FXst_norm[:, None] + FXst_norm[None, :] - 2 * torch.matmul(FXst, FXst.t())
        Kst = torch.exp(-sq_dist / sigma) * (yst[:, None] == yst)

        # The inner fit only needs the VALUES of the kernel; detaching keeps
        # the optimiser from building an autograd graph through the features
        # on every objective evaluation.
        Kst_fixed = Kst.detach()

        def Obj(theta):
            reg = lamda * torch.matmul(theta, theta)
            return -_divergence(Kst_fixed, theta, ns) + reg

        theta_0 = torch.zeros(ns + nt, device=FX.device, dtype=FX.dtype)
        result = minimize(Obj, theta_0, method='l-bfgs')
        theta_hat = result.x
        # Evaluate with the attached kernel so gradients reach FX.
        div = div + _divergence(Kst, theta_hat, ns)
    return div


In [9]:
# Train Function for Proposed Model
def train_proposed_model(model, train_loader, val_loader, optimizer, num_epochs=50, lambda_h=0.5):
    """Train ``model`` with reconstruction loss plus a Hellinger-distance penalty.

    NOTE(review): in its current form this function does not read any data
    from ``train_loader`` or ``val_loader``. Each epoch runs a single step on
    freshly sampled ``torch.randn`` tensors, which looks like a smoke-test
    placeholder for real batches. The loaders are kept in the signature so
    callers do not change.

    Parameters
    ----------
    model : nn.Module
        Must return ``(output, latent)`` from ``forward``.
    train_loader, val_loader :
        Currently unused (see note above).
    optimizer : torch.optim.Optimizer
        Optimiser over ``model``'s parameters.
    num_epochs : int
        Number of epochs to run.
    lambda_h : float
        Weight of the Hellinger-distance term in the total loss.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        num_batches = 0
        for i in range(1):
            clean_train = torch.randn((4, 1, 256, 256)).to(device)
            noisy_train = torch.randn((4, 1, 256, 256)).to(device)
            noisy_val = torch.randn((4, 1, 256, 256)).to(device)

            # Domain labels: 0 = training (source) samples, 1 = validation (target) samples.
            l = torch.concat([torch.zeros(noisy_train.shape[0]), torch.ones(noisy_val.shape[0])]).to(device)
            noisy = torch.concat([noisy_train, noisy_val], dim=0)
            # Single forward pass over both domains; X is the bottleneck feature map.
            output, X = model(noisy)
            X = X.reshape(X.shape[0], -1)
            output_train = output[l == 0]

            # Only the training half has clean targets.
            loss_rec = reconstruction_loss(output_train, clean_train)
            # All samples share one dummy class label, so no kernel entry is masked out.
            y = torch.zeros((X.shape[0],)).to(device)

            # Hellinger distance between source and target latents
            loss_hellinger = H_Distance(X,y,l,sigma=None,lamda=1e-2,device=device)

            # Total loss
            loss = loss_rec + lambda_h * loss_hellinger

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Average over the steps actually executed. Dividing by
        # len(train_loader) would be wrong whenever the loader length
        # differs from the number of steps taken above.
        avg_loss = epoch_loss / num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

# Instantiate and train
# Instantiate and train
# NOTE(review): the two " " arguments are placeholders for train_loader and
# val_loader; train_proposed_model currently generates its own random batches
# and does not read data from them.
model_proposed = UNet()
optimizer_proposed = optim.Adam(model_proposed.parameters(), lr=1e-4)
print("Training Proposed Model with Hellinger Distance...")
train_proposed_model(model_proposed, " ", " ", optimizer_proposed, num_epochs=4, lambda_h=0.2)

Training Proposed Model with Hellinger Distance...
Epoch [1/4], Loss: 1.1509


KeyboardInterrupt: 

In [None]:
# Train Function for Baseline Model
def train_baseline_model(model, train_loader, optimizer, num_epochs=5):
    """Train ``model`` with the reconstruction loss only.

    Parameters
    ----------
    model : nn.Module
        Must return ``(output, latent)`` from ``forward``; the latent is ignored.
    train_loader : iterable
        Yields dict batches with ``'clean'`` and ``'noisy'`` tensors and
        supports ``len()``.
    optimizer : torch.optim.Optimizer
        Optimiser over ``model``'s parameters.
    num_epochs : int
        Number of passes over ``train_loader``.

    Prints the mean batch loss after every epoch; returns nothing.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model.to(device)

    for epoch_idx in range(num_epochs):
        model.train()
        running_loss = 0
        for batch in train_loader:
            target = batch['clean'].to(device)
            corrupted = batch['noisy'].to(device)

            # The model also returns its bottleneck features; unused here.
            prediction, _latent = model(corrupted)
            batch_loss = reconstruction_loss(prediction, target)

            # Standard step: clear old gradients, backpropagate, update.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        mean_loss = running_loss / len(train_loader)
        print(f'Epoch [{epoch_idx+1}/{num_epochs}], Loss: {mean_loss:.4f}')


# Instantiate and train the baseline (reconstruction loss only).
# NOTE(review): `train_loader` is only assigned in the data-loader cell, which
# is currently commented out; on a fresh kernel this call raises NameError.
model_baseline = UNet()
optimizer_baseline = optim.Adam(model_baseline.parameters(), lr=1e-4)
print("Training Baseline Model...")
train_baseline_model(model_baseline, train_loader, optimizer_baseline, num_epochs=20)

In [None]:
# Evaluation on Test Data
def evaluate_model(model, test_loader):
    """Print the mean PSNR and SSIM of ``model`` over ``test_loader``.

    Parameters
    ----------
    model : nn.Module
        Must return ``(output, latent)`` from ``forward``; the latent is ignored.
    test_loader : iterable
        Yields dict batches with ``'clean'`` and ``'noisy'`` tensors.

    Each image is scored individually, using the dynamic range of its clean
    version as ``data_range``. Nothing is returned.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model.to(device)
    model.eval()

    psnr_scores, ssim_scores = [], []

    with torch.no_grad():
        for batch in test_loader:
            clean_batch = batch['clean'].to(device)
            noisy_batch = batch['noisy'].to(device)

            denoised_batch, _latent = model(noisy_batch)

            for k in range(clean_batch.size(0)):
                reference = clean_batch[k].cpu().numpy().squeeze()
                restored = denoised_batch[k].cpu().numpy().squeeze()
                # Both metrics are scaled by the clean image's own range.
                value_range = reference.max() - reference.min()
                psnr_scores.append(psnr(reference, restored, data_range=value_range))
                ssim_scores.append(ssim(reference, restored, data_range=value_range))

    avg_psnr = np.mean(psnr_scores)
    avg_ssim = np.mean(ssim_scores)
    print(f'Average PSNR: {avg_psnr:.2f}, Average SSIM: {avg_ssim:.4f}')

# Score both models on the same test loader.
# NOTE(review): `test_loader` is only assigned in the data-loader cell, which
# is currently commented out; on a fresh kernel these calls raise NameError.
print("Evaluating Proposed Model...")
evaluate_model(model_proposed, test_loader)

print("Evaluating Baseline Model...")
evaluate_model(model_baseline, test_loader)


In [None]:

# Visualization for a test sample with salt and pepper noise
def visualize_sample(model_proposed, model_baseline, sample_idx=0, dataset=None):
    """Denoise one sample with both models, print PSNR/SSIM and plot the results.

    Parameters
    ----------
    model_proposed, model_baseline : nn.Module
        Models returning ``(output, latent)`` from ``forward``.
    sample_idx : int
        Index of the sample to draw from ``dataset``.
    dataset : indexable, optional
        Source of samples; each item is a dict with ``'clean'`` and
        ``'noisy'`` tensors. Defaults to the notebook-level ``test_dataset``
        so existing calls keep working.

    The noisy panel is titled generically: the dataset picks the noise type
    at random, so this function cannot know that a given sample is
    salt-and-pepper.
    """
    if dataset is None:
        # Fall back to the global created by the data-loading cell.
        dataset = test_dataset

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_proposed.to(device)
    model_baseline.to(device)
    model_proposed.eval()
    model_baseline.eval()

    sample = dataset[sample_idx]
    noisy_image = sample['noisy'].unsqueeze(0).to(device)
    clean_image = sample['clean'].cpu().numpy().squeeze()

    with torch.no_grad():
        output_proposed, _ = model_proposed(noisy_image)
        output_baseline, _ = model_baseline(noisy_image)

    output_proposed = output_proposed.cpu().numpy().squeeze()
    output_baseline = output_baseline.cpu().numpy().squeeze()

    # Every metric is scaled by the dynamic range of the clean image.
    data_range = clean_image.max() - clean_image.min()
    psnr_proposed = psnr(clean_image, output_proposed, data_range=data_range)
    ssim_proposed = ssim(clean_image, output_proposed, data_range=data_range)
    psnr_baseline = psnr(clean_image, output_baseline, data_range=data_range)
    ssim_baseline = ssim(clean_image, output_baseline, data_range=data_range)

    print(f'Proposed Model - PSNR: {psnr_proposed:.2f}, SSIM: {ssim_proposed:.4f}')
    print(f'Baseline Model - PSNR: {psnr_baseline:.2f}, SSIM: {ssim_baseline:.4f}')

    panels = [
        ('Clean Image', clean_image),
        ('Noisy Image', noisy_image.cpu().numpy().squeeze()),
        ('Proposed Model Output', output_proposed),
        ('Baseline Model Output', output_baseline),
    ]

    plt.figure(figsize=(15, 5))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.title(title)
        plt.imshow(image, cmap='gray')
        plt.axis('off')

    plt.show()

In [None]:
# NOTE(review): assuming test_dataset uses mode='test', index 10 (10 % 4 != 0)
# draws 'poisson' or 'salt_pepper' (1:3), so salt-and-pepper is likely but not guaranteed.
print("Visualizing a sample with salt and pepper noise...")
visualize_sample(model_proposed, model_baseline, sample_idx=10)


In [None]:
# NOTE(review): assuming test_dataset uses mode='test', index 90 (90 % 4 != 0)
# draws 'poisson' or 'salt_pepper' (1:3), so salt-and-pepper is likely but not guaranteed.
print("Visualizing a sample with salt and pepper noise...")
visualize_sample(model_proposed, model_baseline, sample_idx=90)


In [None]:
# NOTE(review): assuming test_dataset uses mode='test', index 60 (60 % 4 == 0)
# draws 'gaussian' or 'speckle' noise, never salt-and-pepper; the message below is misleading.
print("Visualizing a sample with salt and pepper noise...")
visualize_sample(model_proposed, model_baseline, sample_idx=60)


In [None]:
# NOTE(review): assuming test_dataset uses mode='test', index 33 (33 % 4 != 0)
# draws 'poisson' or 'salt_pepper' (1:3), so salt-and-pepper is likely but not guaranteed.
print("Visualizing a sample with salt and pepper noise...")
visualize_sample(model_proposed, model_baseline, sample_idx=33)


In [None]:
# NOTE(review): assuming test_dataset uses mode='test', index 51 (51 % 4 != 0)
# draws 'poisson' or 'salt_pepper' (1:3), so salt-and-pepper is likely but not guaranteed.
print("Visualizing a sample with salt and pepper noise...")
visualize_sample(model_proposed, model_baseline, sample_idx=51)


In [None]:
from abc import ABC, abstractmethod
import torch
from torch import Tensor
from scipy.optimize import OptimizeResult
from scipy.optimize.optimize import _status_message

from .function import ScalarFunction
from .line_search import strong_wolfe


class HessianUpdateStrategy(ABC):
    """Base class for quasi-Newton (inverse-)Hessian approximations.

    Subclasses implement ``solve`` (turn a gradient into a search direction)
    and ``_update`` (absorb a new ``(s, y)`` pair). ``update`` is the public
    entry point: it applies the curvature safeguard and counts the updates
    that were accepted.
    """

    def __init__(self):
        # Number of (s, y) pairs accepted so far.
        self.n_updates = 0

    @abstractmethod
    def solve(self, grad):
        """Return the search direction for gradient ``grad``."""

    @abstractmethod
    def _update(self, s, y, rho_inv):
        """Incorporate the pair ``(s, y)``, given ``rho_inv = y . s``."""

    def update(self, s, y):
        """Feed a step ``s`` and gradient change ``y`` to the approximation.

        The pair is ignored when the curvature ``y . s`` is at most 1e-10.
        """
        curvature = y.dot(s)
        rejected = curvature <= 1e-10
        if not rejected:
            self._update(s, y, curvature)
            self.n_updates += 1


class L_BFGS(HessianUpdateStrategy):
    """Limited-memory BFGS inverse-Hessian approximation.

    Keeps at most ``history_size`` recent ``(s, y)`` pairs and applies the
    implicit inverse Hessian to a gradient with a backward and a forward
    sweep over that history.
    """

    def __init__(self, x, history_size=100):
        super().__init__()
        self.history_size = history_size
        # Histories, oldest entry first.
        self.s = []
        self.y = []
        self.rho = []
        # Scale of the initial (diagonal) inverse-Hessian guess.
        self.H_diag = 1.
        # Scratch buffer for the backward-sweep coefficients.
        self.alpha = x.new_empty(history_size)

    def solve(self, grad):
        """Return the L-BFGS search direction for gradient ``grad``."""
        n_pairs = len(self.y)
        direction = grad.neg()

        # Backward sweep: newest pair to oldest.
        for k in range(n_pairs - 1, -1, -1):
            self.alpha[k] = self.s[k].dot(direction) * self.rho[k]
            direction.add_(self.y[k], alpha=-self.alpha[k])

        direction.mul_(self.H_diag)

        # Forward sweep: oldest pair to newest.
        for k, (s_k, y_k, rho_k) in enumerate(zip(self.s, self.y, self.rho)):
            beta_k = y_k.dot(direction) * rho_k
            direction.add_(s_k, alpha=self.alpha[k] - beta_k)

        return direction

    def _update(self, s, y, rho_inv):
        """Append ``(s, y)`` to the history, evicting the oldest pair if full."""
        if len(self.y) == self.history_size:
            for history in (self.y, self.s, self.rho):
                history.pop(0)
        self.y.append(y)
        self.s.append(s)
        self.rho.append(rho_inv.reciprocal())
        # Rescale the initial guess with the most recent curvature pair.
        self.H_diag = rho_inv / y.dot(y)


class BFGS(HessianUpdateStrategy):
    """Dense BFGS approximation of the Hessian or of its inverse.

    With ``inverse=True`` the inverse Hessian ``H`` is stored and a solve is
    a matrix-vector product. Otherwise the Hessian ``B`` is stored and a
    solve goes through a Cholesky factorisation.
    """

    def __init__(self, x, inverse=True):
        super().__init__()
        self.inverse = inverse
        identity = torch.eye(x.numel(), device=x.device, dtype=x.dtype)
        if inverse:
            self.I = identity
            self.H = self.I.clone()
        else:
            self.B = identity

    def solve(self, grad):
        """Return the quasi-Newton direction for gradient ``grad``."""
        negative_grad = grad.neg()
        if self.inverse:
            return torch.matmul(self.H, negative_grad)
        cholesky_factor = torch.linalg.cholesky(self.B)
        solution = torch.cholesky_solve(negative_grad.unsqueeze(1), cholesky_factor)
        return solution.squeeze(1)

    def _update(self, s, y, rho_inv):
        """Apply the BFGS update for the pair ``(s, y)`` in place."""
        rho = rho_inv.reciprocal()
        is_first = self.n_updates == 0

        if not self.inverse:
            if is_first:
                # Rescale the identity guess before the first update.
                self.B.mul_(rho * y.dot(y))
            Bs = torch.mv(self.B, s)
            self.B.addr_(y, y, alpha=rho)
            self.B.addr_(Bs, Bs, alpha=-1./s.dot(Bs))
            return

        if is_first:
            # Rescale the identity guess before the first update.
            self.H.mul_(rho_inv / y.dot(y))
        # H <- R H R^T + rho * s s^T, with R = I - rho * s y^T
        R = torch.addr(self.I, s, y, alpha=-rho)
        sandwiched = torch.linalg.multi_dot((R, self.H, R.t()))
        torch.addr(sandwiched, s, s, alpha=rho, out=self.H)


@torch.no_grad()
def _minimize_bfgs_core(
        fun, x0, lr=1., low_mem=False, history_size=100, inv_hess=True,
        max_iter=None, line_search='strong-wolfe', gtol=1e-5, xtol=1e-9,
        normp=float('inf'), callback=None, disp=0, return_all=False):
    """Minimize a multivariate function with BFGS or L-BFGS.

    We choose from BFGS/L-BFGS with the `low_mem` argument.

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize
    x0 : Tensor
        Initialization point
    lr : float
        Step size for parameter updates. If using line search, this will be
        used as the initial step size for the search.
    low_mem : bool
        Whether to use L-BFGS, the "low memory" variant of the BFGS algorithm.
    history_size : int
        History size for L-BFGS hessian estimates. Ignored if `low_mem=False`.
    inv_hess : bool
        Whether to parameterize the inverse hessian vs. the hessian with BFGS.
        Must be True when `low_mem=True` (L-BFGS always parameterizes the
        inverse); otherwise a ValueError is raised.
    max_iter : int, optional
        Maximum number of iterations to perform. Defaults to 200 * x0.numel()
    line_search : str
        Line search specifier. Currently the available options are
        {'none', 'strong-wolfe'}.
    gtol : float
        Termination tolerance on 1st-order optimality (gradient norm).
    xtol : float
        Termination tolerance on function/parameter changes. Also used as the
        threshold below which a direction is rejected as non-descent.
    normp : Number or str
        The norm type to use for termination conditions. Can be any value
        supported by `torch.norm` p argument.
    callback : callable, optional
        Function to call after each iteration with the current parameter
        state, e.g. ``callback(x)``.
    disp : int or bool
        Display (verbosity) level. Set to >0 to print status messages.
    return_all : bool, optional
        Set to True to return a list of the best solution at each of the
        iterations.

    Returns
    -------
    result : OptimizeResult
        Result of the optimization routine. ``status`` is 0 on convergence,
        1 when `max_iter` was reached, 2 on a non-finite function value and
        4 when a non-descent direction was encountered.

    Raises
    ------
    ValueError
        If `low_mem=True` is combined with `inv_hess=False`, or if
        `line_search` is not one of the supported options.
    """
    lr = float(lr)
    disp = int(disp)
    if max_iter is None:
        max_iter = x0.numel() * 200
    if low_mem and not inv_hess:
        raise ValueError('inv_hess=False is not available for L-BFGS.')

    # construct scalar objective function
    sf = ScalarFunction(fun, x0.shape)
    closure = sf.closure
    if line_search == 'strong-wolfe':
        # only the strong-wolfe branch below uses the directional evaluator
        dir_evaluate = sf.dir_evaluate

    # compute initial f(x) and f'(x)
    # work on a flattened private copy so x0 itself is never modified
    x = x0.detach().view(-1).clone(memory_format=torch.contiguous_format)
    f, g, _, _ = closure(x)
    if disp > 1:
        print('initial fval: %0.4f' % f)
    if return_all:
        allvecs = [x]

    # initial settings
    if low_mem:
        hess = L_BFGS(x, history_size)
    else:
        hess = BFGS(x, inv_hess)
    # the first iteration is a steepest-descent step ...
    d = g.neg()
    # ... whose initial step size is capped by the inverse L1 norm of the gradient
    t = min(1., g.norm(p=1).reciprocal()) * lr
    n_iter = 0

    # BFGS iterations
    for n_iter in range(1, max_iter+1):

        # ==================================
        #   compute Quasi-Newton direction
        # ==================================

        # from the second iteration on, use the curvature model
        if n_iter > 1:
            d = hess.solve(g)

        # directional derivative
        gtd = g.dot(d)

        # check if directional derivative is below tolerance
        if gtd > -xtol:
            warnflag = 4
            msg = 'A non-descent direction was encountered.'
            break

        # ======================
        #   update parameter
        # ======================

        if line_search == 'none':
            # no line search, move with fixed-step
            x_new = x + d.mul(t)
            f_new, g_new, _, _ = closure(x_new)
        elif line_search == 'strong-wolfe':
            #  Determine step size via strong-wolfe line search
            f_new, g_new, t, ls_evals = \
                strong_wolfe(dir_evaluate, x, t, d, f, g, gtd)
            x_new = x + d.mul(t)
        else:
            raise ValueError('invalid line_search option {}.'.format(line_search))

        if disp > 1:
            print('iter %3d - fval: %0.4f' % (n_iter, f_new))
        if return_all:
            allvecs.append(x_new)
        if callback is not None:
            callback(x_new)

        # ================================
        #   update hessian approximation
        # ================================

        s = x_new.sub(x)
        y = g_new.sub(g)

        # HessianUpdateStrategy.update skips pairs whose curvature y.s is too small
        hess.update(s, y)

        # =========================================
        #   check conditions and update buffers
        # =========================================

        # convergence by insufficient progress
        # (checked before the state update, so x/f/g keep the previous iterate on this exit)
        if (s.norm(p=normp) <= xtol) | ((f_new - f).abs() <= xtol):
            warnflag = 0
            msg = _status_message['success']
            break

        # update state
        f[...] = f_new
        x.copy_(x_new)
        g.copy_(g_new)
        # after the first step every iteration starts from the nominal step size
        t = lr

        # convergence by 1st-order optimality
        if g.norm(p=normp) <= gtol:
            warnflag = 0
            msg = _status_message['success']
            break

        # precision loss; exit
        if ~f.isfinite():
            warnflag = 2
            msg = _status_message['pr_loss']
            break

    else:
        # if we get to the end, the maximum num. iterations was reached
        # (this for/else branch runs only when the loop finished without `break`)
        warnflag = 1
        msg = _status_message['maxiter']

    if disp:
        print(msg)
        print("         Current function value: %f" % f)
        print("         Iterations: %d" % n_iter)
        print("         Function evaluations: %d" % sf.nfev)
    result = OptimizeResult(fun=f, x=x.view_as(x0), grad=g.view_as(x0),
                            status=warnflag, success=(warnflag==0),
                            message=msg, nit=n_iter, nfev=sf.nfev)
    if not low_mem:
        # `2 * x0.shape` repeats the shape tuple, giving a matrix indexed by x0's dimensions twice
        if inv_hess:
            result['hess_inv'] = hess.H.view(2 * x0.shape)
        else:
            result['hess'] = hess.B.view(2 * x0.shape)
    if return_all:
        result['allvecs'] = allvecs

    return result


def _minimize_bfgs(
        fun, x0, lr=1., inv_hess=True, max_iter=None,
        line_search='strong-wolfe', gtol=1e-5, xtol=1e-9,
        normp=float('inf'), callback=None, disp=0, return_all=False):
    """Minimize a multivariate function with BFGS

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize.
    x0 : Tensor
        Initialization point.
    lr : float
        Step size for parameter updates. If using line search, this will be
        used as the initial step size for the search.
    inv_hess : bool
        Whether to parameterize the inverse hessian vs. the hessian with BFGS.
    max_iter : int, optional
        Maximum number of iterations to perform. Defaults to
        ``200 * x0.numel()``.
    line_search : str
        Line search specifier. Currently the available options are
        {'none', 'strong_wolfe'}.
    gtol : float
        Termination tolerance on 1st-order optimality (gradient norm).
    xtol : float
        Termination tolerance on function/parameter changes.
    normp : Number or str
        The norm type to use for termination conditions. Can be any value
        supported by :func:`torch.norm`.
    callback : callable, optional
        Function to call after each iteration with the current parameter
        state, e.g. ``callback(x)``.
    disp : int or bool
        Display (verbosity) level. Set to >0 to print status messages.
    return_all : bool, optional
        Set to True to return a list of the best solution at each of the
        iterations.

    Returns
    -------
    result : OptimizeResult
        Result of the optimization routine.
    """
    return _minimize_bfgs_core(
        fun, x0, lr, low_mem=False, inv_hess=inv_hess, max_iter=max_iter,
        line_search=line_search, gtol=gtol, xtol=xtol,
        normp=normp, callback=callback, disp=disp, return_all=return_all)


def _minimize_lbfgs(
        fun, x0, lr=1., history_size=100, max_iter=None,
        line_search='strong-wolfe', gtol=1e-5, xtol=1e-9,
        normp=float('inf'), callback=None, disp=0, return_all=False):
    """Minimize a multivariate function with L-BFGS

    Parameters
    ----------
    fun : callable
        Scalar objective function to minimize.
    x0 : Tensor
        Initialization point.
    lr : float
        Step size for parameter updates. If using line search, this will be
        used as the initial step size for the search.
    history_size : int
        History size for L-BFGS hessian estimates.
    max_iter : int, optional
        Maximum number of iterations to perform. Defaults to
        ``200 * x0.numel()``.
    line_search : str
        Line search specifier. Currently the available options are
        {'none', 'strong_wolfe'}.
    gtol : float
        Termination tolerance on 1st-order optimality (gradient norm).
    xtol : float
        Termination tolerance on function/parameter changes.
    normp : Number or str
        The norm type to use for termination conditions. Can be any value
        supported by :func:`torch.norm`.
    callback : callable, optional
        Function to call after each iteration with the current parameter
        state, e.g. ``callback(x)``.
    disp : int or bool
        Display (verbosity) level. Set to >0 to print status messages.
    return_all : bool, optional
        Set to True to return a list of the best solution at each of the
        iterations.

    Returns
    -------
    result : OptimizeResult
        Result of the optimization routine.
    """
    return _minimize_bfgs_core(
        fun, x0, lr, low_mem=True, history_size=history_size,
        max_iter=max_iter, line_search=line_search, gtol=gtol, xtol=xtol,
        normp=normp, callback=callback, disp=disp, return_all=return_all)