In [None]:
import skimage

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
import argparse
import cv2
from scipy import io
from tqdm.notebook import tqdm
import os
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_msssim import ssim
from skimage.metrics import structural_similarity as compare_ssim
from modules import utils
from modules.models import INR

from torchsummary import summary
from modules.encoding import Encoding
import time
from encoding import MultiResHashGrid
import torch
import lpips
from pytorch_msssim import ssim

# Build the two LPIPS perceptual-distance networks used for evaluation:
# the VGG-backed one first, then the AlexNet-backed one.
lpips_vgg_model, lpips_alex_model = (
    lpips.LPIPS(net=backbone) for backbone in ("vgg", "alex")
)

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True,
    so every non-empty string would switch the flag on.

    Args:
        value: Either a bool (returned unchanged) or a string token.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='INCODE')
# Shared Parameters
parser.add_argument('--input',type=str, default='/root/autodl-tmp/INCODE-main/0070.png', help='Input image path')
# parser.add_argument('--input',type=str, default='/home/cy/Poly/INCODE-main/middle.png', help='Input image path')
parser.add_argument('--inr_model',type=str, default='siren', help='[gauss, mfn, relu, siren, wire, wire2d, ffn, incode]')
parser.add_argument('--ffn_type',type=str, default='siren', help='[relu, siren, swish]')
parser.add_argument('--lr',type=float, default=5e-3, help='Learning rate')
parser.add_argument('--using_schedular', type=_str2bool, default=True, help='Whether to use the learning-rate scheduler')
parser.add_argument('--scheduler_b', type=float, default=0.1, help='Learning-rate multiplier reached after niters scheduler steps')
parser.add_argument('--maxpoints', type=int, default=128*128, help='Batch size')
parser.add_argument('--niters', type=int, default=101, help='Number of iterations')
parser.add_argument('--steps_til_summary', type=int, default=10, help='Number of steps till summary visualization')

# INCODE Parameters
parser.add_argument('--a_coef',type=float, default=0.1993, help='a coefficient')
parser.add_argument('--b_coef',type=float, default=0.0196, help='b coefficient')
parser.add_argument('--c_coef',type=float, default=0.0588, help='c coefficient')
parser.add_argument('--d_coef',type=float, default=0.0269, help='d coefficient')
parser.add_argument('--using_cosoptim', type=_str2bool, default=False, help='Whether to use the cosine-annealing optimizer/scheduler setup')
parser.add_argument('--eta_min',type=float, default=1e-8, help='Minimum learning rate for cosine annealing')
parser.add_argument('--T_max',type=float, default=5001, help='T_max (period) for cosine annealing')
# An empty argument list keeps the defaults: inside a notebook sys.argv does
# not hold arguments meant for this parser.
args = parser.parse_args(args=[])

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(args.input)

## Loading Data
# Read the image as float32, pass it through utils.normalize, then shrink it
# to a quarter of its size along each axis with area interpolation.
im = utils.normalize(plt.imread(args.input).astype(np.float32), True)
# im=cv2.resize(im, None, fx=1024/im.shape[1], fy=1024/im.shape[0], interpolation=cv2.INTER_AREA)
im = cv2.resize(im, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_AREA)
H, W, _ = im.shape
print(f'H: {H} W: {W}')

steps = 250
# Number of full batches of args.maxpoints pixels contained in one image.
per_epoch = int((H * W) / args.maxpoints)
# args.niters=int(steps/per_epoch)
print(f'per_epoch: {per_epoch} args.niters: {args.niters}')

# Positional-encoding configurations handed to the INR models.
# Frequency Encoding
pos_encode_freq = dict(type='frequency', use_nyquist=True, mapping_input=int(max(H, W)/3))

# Gaussian Encoding
pos_encode_gaus = dict(type='gaussian', scale_B=10, mapping_input=256)

# No Encoding
pos_encode_no = dict(type=None)

# Per-model result stores, keyed by the model label.
model_psnr = {}
model_ssim = {}
total_time = {}

In [None]:
# relu+hash
class SineLayer(nn.Module):
    '''
    Linear layer used as the building block of the hash-grid MLP.

    Despite the name, ``forward`` only applies the linear map; no sine is
    applied here, so any non-linearity has to be applied by the caller.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If True, the linear map includes a bias term. Default is True.
        is_first (bool, optional): Selects the weight range used by ``init_weights``. Default is False.
        omega_0 (float, optional): Stored on the instance; not used by this class itself. Default is 30.
        scale (float, optional): Accepted but unused. Default is 10.0.
        init_weights (bool, optional): Accepted but ignored: the constructor never calls
            ``init_weights``, so ``nn.Linear``'s default initialisation is kept unless the
            method is called explicitly. Default is True.

    '''

    def __init__(self, in_features, out_features, bias=True,
                is_first=False, omega_0=30, scale=10.0, init_weights=True):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # The custom initialisation is left disabled in this variant:
        # if init_weights:
        #     self.init_weights()

    def init_weights(self):
        '''Re-draw the weights uniformly in [-limit, limit); the bias is left untouched.'''
        fan_in = self.in_features
        if self.is_first:
            limit = 1 / fan_in
        else:
            limit = np.sqrt(1 / fan_in)
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, input):
        '''Return the linear projection of ``input``.'''
        projected = self.linear(input)
        return projected

class Siren(nn.Module):
    """
        Sine activation from the SIREN paper: sin(w0 * x).
        https://arxiv.org/abs/2006.09661
    """

    def __init__(self, w0=1):
        """
            w0 is the frequency factor multiplied into the input before
            the sine (end of section 3 of the paper: 30 for the first
            layer, 1 for the rest).
        """
        super().__init__()
        # Kept as a plain tensor attribute (neither a parameter nor a buffer).
        self.w0 = torch.tensor(w0)

    def forward(self, x):
        # return torch.sin(self.w0*(torch.abs(x)+1)*x)
        scaled = self.w0 * x
        return torch.sin(scaled)

    def extra_repr(self):
        return f"w0={self.w0}"
    
class PolyReLUCode(nn.Module):
    """
    Small MLP fed by a multi-resolution hash-grid encoding of 2-D inputs.

    ``forward`` runs: hash-grid encoding -> linear1 -> nolinear1 ->
    linear2 -> nolinear2 -> final linear layer with 3 outputs.

    Args:
        activate (str): 'ReLU' or 'Siren'; selects the modules stored as
            ``nolinear1`` .. ``nolinear6``. Any other value creates none of
            them, so ``forward`` then fails with AttributeError.
        norm_type (str): 'LayerNorm', 'BatchNorm1d' or 'None'; selects which
            ``norm*`` sub-modules are created. ``forward`` never calls them.
    """

    def __init__(
        self,activate='ReLU',norm_type='None'
        ) -> None:
        super().__init__()
        # 32 = 16 levels * 2 features per level
        # (presumably the width of the encoder output - confirm against MultiResHashGrid).
        input_dim = 32
        hidden_channel = 64

        self.positional_encoding = MultiResHashGrid(
            dim=2,
            n_levels=16,
            n_features_per_level=2,
            log2_hashmap_size=14,
            base_resolution=16,
            finest_resolution=256,
        )

        # Per-layer frequency factors; only w1..w4 are used below.
        w1, w2, w3, w4, w5, w6 = 100, 2, 1, 1, 1, 1

        self.linear1 = SineLayer(input_dim, hidden_channel, omega_0=w1, is_first=True)

        # Normalisation layers are registered as norm1, norm2, ... in order.
        if norm_type == 'LayerNorm':
            for idx in range(1, 7):
                setattr(self, f'norm{idx}', nn.LayerNorm(hidden_channel))
        elif norm_type == 'BatchNorm1d':
            for idx in range(1, 5):
                setattr(self, f'norm{idx}', nn.BatchNorm1d(65536))
        elif norm_type == 'None':
            for idx in range(1, 4):
                setattr(self, f'norm{idx}', nn.Identity())

        self.linear2 = SineLayer(hidden_channel, hidden_channel, omega_0=w2)

        # Activations are registered as nolinear1 .. nolinear6 in order.
        if activate == 'ReLU':
            for idx in range(1, 7):
                setattr(self, f'nolinear{idx}', nn.ReLU())
        if activate == 'Siren':
            for idx, freq in enumerate((w1, w2, w3, w4, 1, 1), start=1):
                setattr(self, f'nolinear{idx}', Siren(freq))

        # Output head: one linear layer producing 3 values per input point.
        # layers.append(nn.Sigmoid())
        self.layers = nn.Sequential(SineLayer(hidden_channel, 3, is_first=True))

    def forward(self, input):
        encoded = self.positional_encoding(input)
        hidden = self.nolinear1(self.linear1(encoded))
        hidden = self.nolinear2(self.linear2(hidden))
        return self.layers(hidden)


# Model configuration for the hash-grid + ReLU run.
args.inr_model = 'relu'
args.lr = 5e-3

model = PolyReLUCode(activate='ReLU', norm_type='LayerNorm')
model = model.to(device)

# print(model)
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Total number of parameters: {num_params/1e6}(M)')


# Optimizer setup
if args.inr_model == 'wire':
    args.lr = args.lr * min(1, args.maxpoints / (H * W))

if not args.using_cosoptim:
    # Plain Adam; the LR multiplier decays from 1 to scheduler_b over niters scheduler steps.
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.LambdaLR(
        optim,
        lambda epoch: args.scheduler_b ** min(epoch / args.niters, 1),
    )
else:
    optim = torch.optim.Adam(
        params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.99),
        eps=1e-15,
        weight_decay=0,
    )
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optim, T_max=args.T_max, eta_min=args.eta_min)


# torch.optim.lr_scheduler.StepLR(optim, step_size  = step, gamma = 0.8)

# Per-epoch metric histories
psnr_values, ssim_values = [], []
mse_array = torch.zeros(args.niters, device=device)

# Best loss so far starts at +infinity
best_loss = torch.tensor(float('inf'))

# Coordinate grid with a leading batch axis
coords = utils.get_coords(H, W, dim=2)[None, ...]

# Ground-truth pixels as a (1, H*W, 3) tensor on the training device
gt = torch.tensor(im).reshape(1, H * W, 3).to(device)

# Buffer holding the latest model prediction for every pixel
rec = torch.zeros_like(gt)
# The ground-truth image in (1, 3, H, W) layout never changes, so it is built
# once here rather than once per epoch.
original_img = torch.tensor(im).permute(2, 0, 1).unsqueeze(0)

cumulative_times = []
start_time = time.time()
for step in tqdm(range(args.niters)):
    # Randomize the order of data points for each iteration
    indices = torch.randperm(H*W)

    # Process data points in batches
    for b_idx in range(0, H*W, args.maxpoints):
        b_indices = indices[b_idx:min(H*W, b_idx+args.maxpoints)]
        b_coords = coords[:, b_indices, ...].to(device)
        b_indices = b_indices.to(device)

        # Calculate model output ('incode' models also return coefficients)
        if args.inr_model == 'incode':
            model_output, coef = model(b_coords)
        else:
            model_output = model(b_coords)

        # Store this batch's prediction in the full reconstruction buffer
        with torch.no_grad():
            rec[:, b_indices, :] = model_output

        # Calculate the output loss
        output_loss = ((model_output - gt[:, b_indices, :])**2).mean()

        if args.inr_model == 'incode':
            # Regularization loss: penalise negative coefficients
            a_coef, b_coef, c_coef, d_coef = coef[0]
            reg_loss = args.a_coef * torch.relu(-a_coef) + \
                    args.b_coef * torch.relu(-b_coef) + \
                    args.c_coef * torch.relu(-c_coef) + \
                    args.d_coef * torch.relu(-d_coef)

            # Total loss for 'incode' model
            loss = output_loss + reg_loss
        else:
            # Total loss for other models
            loss = output_loss

        # Perform backpropagation and update model parameters
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Calculate PSNR from the MSE over the whole reconstruction buffer
    with torch.no_grad():
        mse_array[step] = ((gt - rec)**2).mean().item()
        psnr = -10*torch.log10(mse_array[step])
        psnr_values.append(psnr.item())

    # Adjust learning rate once per epoch. The former if/else on
    # ('incode' and step > 30) called scheduler.step() in both branches.
    if args.using_schedular:
        scheduler.step()

    # Prepare reconstructed image for visualization
    imrec = rec[0, ...].reshape(H, W, 3).detach().cpu().numpy()
    current_total_time = time.time() - start_time
    # Record the cumulative elapsed time up to this epoch
    cumulative_times.append(current_total_time)

    # Check if the current iteration's loss is the best so far
    if (mse_array[step] < best_loss) or (step == 0):
        best_loss = mse_array[step]
        best_img = imrec
        # Min-max rescale the best reconstruction to [0, 1]
        best_img = (best_img - best_img.min()) / (best_img.max() - best_img.min())
        # SSIM depends only on best_img, so it is refreshed only when
        # best_img changes (step 0 always takes this branch).
        reconstruct_img = torch.tensor(best_img).permute(2, 0, 1).unsqueeze(0)
        ms_ssim = ssim(original_img, reconstruct_img, data_range=1, size_average=False)

    ssim_values.append(ms_ssim.item())
    # Display intermediate results at specified intervals
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Total Loss: {:.5f} | PSNR: {:.4f} | SSIM:{:.4f}".format(step, 
                                                                    mse_array[step].item(),
                                                                    psnr.item(),ms_ssim.item()))
        
        # Plot
        # fig, axes = plt.subplots(1, 3, figsize=(12, 12))
        # axes[0].set_title('Ground Truth')
        # axes[0].imshow(im)
        # axes[0].axis('off')
        # axes[1].set_title('Reconstructed')
        # axes[1].imshow(best_img)
        # axes[1].axis('off')
        # axes[2].set_title('error')
        # axes[2].imshow((im-best_img))
        # axes[2].axis('off')
        # plt.show()

# Label under which this run's curves are stored.
args.inr_model='Hash_relu'
print('--------------------')
print('Max PSNR:', max(psnr_values))
print('Max SSIM:', max(ssim_values))
print('--------------------')
model_psnr[args.inr_model]=psnr_values
total_time[args.inr_model]=np.array(cumulative_times)
model_ssim[args.inr_model]=ssim_values

# Perceptual distance between ground truth and best reconstruction.
# The reconstruction was min-max rescaled to [0, 1] and SSIM is evaluated with
# data_range=1, whereas LPIPS expects [-1, 1] unless normalize=True is passed.
# NOTE(review): assumes `im` is in [0, 1] as well - confirm utils.normalize.
with torch.no_grad():  # evaluation only, no autograd graph needed
    vgg_distance = lpips_vgg_model(original_img, reconstruct_img, normalize=True)
    alex_distance = lpips_alex_model(original_img, reconstruct_img, normalize=True)
print("VGG: LPIPS distance:", vgg_distance.item())
print("ALEX: LPIPS distance:", alex_distance.item())
print('----------------------------------\n\n')





In [None]:
# Switch matplotlib to its 'default' style sheet for the figures drawn below.
plt.style.use('default')

In [None]:
# Save the ground truth and the best reconstruction as PNG files, read them
# back, and plot both next to their absolute difference.
best_img_save_name='/root/autodl-tmp/INCODE-main/result//'+args.inr_model+'_best_img.png'

# cv2.imwrite only returns False on failure (for example when the folder is
# missing), so create the folder and check the return value explicitly.
os.makedirs(os.path.dirname(best_img_save_name), exist_ok=True)
# [:, :, ::-1] reverses the channel order (RGB -> BGR) for OpenCV.
if not cv2.imwrite('/root/autodl-tmp/INCODE-main/result/im.png',im[:,:, ::-1]*255):
    raise IOError('Could not write /root/autodl-tmp/INCODE-main/result/im.png')
if not cv2.imwrite(best_img_save_name,best_img[:,:, ::-1]*255):
    raise IOError('Could not write ' + best_img_save_name)


image1 = cv2.imread('/root/autodl-tmp/INCODE-main/result//im.png')
image2 = cv2.imread(best_img_save_name)
if image1 is None or image2 is None:
    raise IOError('Could not read back the saved images')

# Make sure both images have the same size
image1 = cv2.resize(image1, (image2.shape[1], image2.shape[0]))

# Per-pixel absolute difference
difference = cv2.absdiff(image1, image2)

# Convert the difference to grayscale (scaled to [0, 1]) so it is easier to see
gray_difference = cv2.cvtColor(difference, cv2.COLOR_BGR2GRAY)/255

# Show both images and the difference map
plt.figure(figsize=(12, 12))
plt.subplot(1, 3, 1)
plt.imshow(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
plt.title('Image 1')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 3, 2)
plt.imshow(cv2.cvtColor(image2, cv2.COLOR_BGR2RGB))
plt.title('Image 2')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 3, 3)
plt.imshow(gray_difference, cmap='jet',vmin=0,vmax=0.1)
plt.title('Difference')
plt.xticks([])
plt.yticks([])

In [None]:
# Wire
### Model Configurations

# Model configuration for the WIRE run.
args.inr_model = 'wire'
args.lr = 2e-3

wire_config = dict(
    in_features=2,
    hidden_features=370,
    hidden_layers=3,
    out_features=3,
    wire_type='complex',
    outermost_linear=True,
    first_omega_0=20,
    hidden_omega_0=20,
    sigma=30.0,
    pos_encode_configs=pos_encode_no,
)
model = INR(args.inr_model).run(**wire_config).to(device)
# print(model)
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Total number of parameters: {num_params/1e6}(M)')


# Optimizer setup
# if args.inr_model == 'wire':
#     args.lr = args.lr * min(1, args.maxpoints / (H * W))

if not args.using_cosoptim:
    # Plain Adam; the LR multiplier decays from 1 to scheduler_b over niters scheduler steps.
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.LambdaLR(
        optim,
        lambda epoch: args.scheduler_b ** min(epoch / args.niters, 1),
    )
else:
    optim = torch.optim.Adam(
        params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.99),
        eps=1e-15,
        weight_decay=0,
    )
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optim, T_max=args.T_max, eta_min=args.eta_min)
# torch.optim.lr_scheduler.StepLR(optim, step_size  = step, gamma = 0.8)

# Per-epoch metric histories
psnr_values, ssim_values = [], []
mse_array = torch.zeros(args.niters, device=device)

# Best loss so far starts at +infinity
best_loss = torch.tensor(float('inf'))

# Coordinate grid with a leading batch axis
coords = utils.get_coords(H, W, dim=2)[None, ...]

# Ground-truth pixels as a (1, H*W, 3) tensor on the training device
gt = torch.tensor(im).reshape(1, H * W, 3).to(device)

# Buffer holding the latest model prediction for every pixel
rec = torch.zeros_like(gt)
# The ground-truth image in (1, 3, H, W) layout never changes, so it is built
# once here rather than once per epoch.
original_img = torch.tensor(im).permute(2, 0, 1).unsqueeze(0)

cumulative_times = []
start_time = time.time()
for step in tqdm(range(args.niters)):
    # Randomize the order of data points for each iteration
    indices = torch.randperm(H*W)

    # Process data points in batches
    for b_idx in range(0, H*W, args.maxpoints):
        b_indices = indices[b_idx:min(H*W, b_idx+args.maxpoints)]
        b_coords = coords[:, b_indices, ...].to(device)
        b_indices = b_indices.to(device)

        # Calculate model output ('incode' models also return coefficients)
        if args.inr_model == 'incode':
            model_output, coef = model(b_coords)
        else:
            model_output = model(b_coords)

        # Store this batch's prediction in the full reconstruction buffer
        with torch.no_grad():
            rec[:, b_indices, :] = model_output

        # Calculate the output loss
        output_loss = ((model_output - gt[:, b_indices, :])**2).mean()

        if args.inr_model == 'incode':
            # Regularization loss: penalise negative coefficients
            a_coef, b_coef, c_coef, d_coef = coef[0]
            reg_loss = args.a_coef * torch.relu(-a_coef) + \
                    args.b_coef * torch.relu(-b_coef) + \
                    args.c_coef * torch.relu(-c_coef) + \
                    args.d_coef * torch.relu(-d_coef)

            # Total loss for 'incode' model
            loss = output_loss + reg_loss
        else:
            # Total loss for other models
            loss = output_loss

        # Perform backpropagation and update model parameters
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Calculate PSNR from the MSE over the whole reconstruction buffer
    with torch.no_grad():
        mse_array[step] = ((gt - rec)**2).mean().item()
        psnr = -10*torch.log10(mse_array[step])
        psnr_values.append(psnr.item())

    # Adjust learning rate once per epoch. The former if/else on
    # ('incode' and step > 30) called scheduler.step() in both branches.
    if args.using_schedular:
        scheduler.step()

    # Prepare reconstructed image for visualization
    imrec = rec[0, ...].reshape(H, W, 3).detach().cpu().numpy()
    current_total_time = time.time() - start_time
    # Record the cumulative elapsed time up to this epoch
    cumulative_times.append(current_total_time)

    # Check if the current iteration's loss is the best so far
    if (mse_array[step] < best_loss) or (step == 0):
        best_loss = mse_array[step]
        best_img = imrec
        # Min-max rescale the best reconstruction to [0, 1]
        best_img = (best_img - best_img.min()) / (best_img.max() - best_img.min())
        # SSIM depends only on best_img, so it is refreshed only when
        # best_img changes (step 0 always takes this branch).
        reconstruct_img = torch.tensor(best_img).permute(2, 0, 1).unsqueeze(0)
        ms_ssim = ssim(original_img, reconstruct_img, data_range=1, size_average=False)

    ssim_values.append(ms_ssim.item())
    # Display intermediate results at specified intervals
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Total Loss: {:.5f} | PSNR: {:.4f} | SSIM:{:.4f}".format(step, 
                                                                    mse_array[step].item(),
                                                                    psnr.item(),ms_ssim.item()))
        
        # # Plot
        # fig, axes = plt.subplots(1, 2, figsize=(12, 12))
        # axes[0].set_title('Ground Truth')
        # axes[0].imshow(im)
        # axes[0].axis('off')
        # axes[1].set_title('Reconstructed')
        # axes[1].imshow(best_img)
        # axes[1].axis('off')
        # plt.show()

# Print the best PSNR / SSIM reached during training
print('--------------------')
print('Max PSNR:', max(psnr_values))
print('Max SSIM:', max(ssim_values))
print('--------------------')

# Store this run's curves under the 'wire' label. The store and the LPIPS
# evaluation used to be copy-pasted twice in a row; once is enough.
args.inr_model='wire'
model_psnr[args.inr_model]=psnr_values
total_time[args.inr_model]=np.array(cumulative_times)
model_ssim[args.inr_model]=ssim_values

# Perceptual distance between ground truth and best reconstruction.
# The reconstruction was min-max rescaled to [0, 1] and SSIM is evaluated with
# data_range=1, whereas LPIPS expects [-1, 1] unless normalize=True is passed.
# NOTE(review): assumes `im` is in [0, 1] as well - confirm utils.normalize.
with torch.no_grad():  # evaluation only, no autograd graph needed
    vgg_distance = lpips_vgg_model(original_img, reconstruct_img, normalize=True)
    alex_distance = lpips_alex_model(original_img, reconstruct_img, normalize=True)
print("VGG: LPIPS distance:", vgg_distance.item())
print("ALEX: LPIPS distance:", alex_distance.item())
print('----------------------------------\n\n')

In [None]:
# Save the ground truth and the best reconstruction as PNG files, read them
# back, and plot both next to their absolute difference.
best_img_save_name='/root/autodl-tmp/INCODE-main/result//'+args.inr_model+'_best_img.png'

# cv2.imwrite only returns False on failure (for example when the folder is
# missing), so create the folder and check the return value explicitly.
os.makedirs(os.path.dirname(best_img_save_name), exist_ok=True)
# [:, :, ::-1] reverses the channel order (RGB -> BGR) for OpenCV.
if not cv2.imwrite('/root/autodl-tmp/INCODE-main/result/im.png',im[:,:, ::-1]*255):
    raise IOError('Could not write /root/autodl-tmp/INCODE-main/result/im.png')
if not cv2.imwrite(best_img_save_name,best_img[:,:, ::-1]*255):
    raise IOError('Could not write ' + best_img_save_name)


image1 = cv2.imread('/root/autodl-tmp/INCODE-main/result//im.png')
image2 = cv2.imread(best_img_save_name)
if image1 is None or image2 is None:
    raise IOError('Could not read back the saved images')

# Make sure both images have the same size
image1 = cv2.resize(image1, (image2.shape[1], image2.shape[0]))

# Per-pixel absolute difference
difference = cv2.absdiff(image1, image2)

# Convert the difference to grayscale (scaled to [0, 1]) so it is easier to see
gray_difference = cv2.cvtColor(difference, cv2.COLOR_BGR2GRAY)/255

# Show both images and the difference map
plt.figure(figsize=(12, 12))
plt.subplot(1, 3, 1)
plt.imshow(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
plt.title('Image 1')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 3, 2)
plt.imshow(cv2.cvtColor(image2, cv2.COLOR_BGR2RGB))
plt.title('Image 2')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 3, 3)
plt.imshow(gray_difference, cmap='jet',vmin=0,vmax=0.1)
plt.title('Difference')
plt.xticks([])
plt.yticks([])

In [None]:
#siren
### Model Configurations
# Model configuration for the SIREN run.
args.inr_model = 'siren'
args.lr = 1e-3

siren_config = dict(
    in_features=2,
    out_features=3,
    hidden_features=256,
    hidden_layers=3,
    first_omega_0=30.0,
    hidden_omega_0=30.0,
    pos_encode_configs=pos_encode_no,
)
model = INR(args.inr_model).run(**siren_config).to(device)
# print(model)
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f'Total number of parameters: {num_params/1e6}(M)')


# Optimizer setup
if args.inr_model == 'wire':
    args.lr = args.lr * min(1, args.maxpoints / (H * W))

if not args.using_cosoptim:
    # Plain Adam; the LR multiplier decays from 1 to scheduler_b over niters scheduler steps.
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.LambdaLR(
        optim,
        lambda epoch: args.scheduler_b ** min(epoch / args.niters, 1),
    )
else:
    optim = torch.optim.Adam(
        params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.99),
        eps=1e-15,
        weight_decay=0,
    )
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optim, T_max=args.T_max, eta_min=args.eta_min)
# torch.optim.lr_scheduler.StepLR(optim, step_size  = step, gamma = 0.8)

# Per-epoch metric histories
psnr_values, ssim_values = [], []
mse_array = torch.zeros(args.niters, device=device)

# Best loss so far starts at +infinity
best_loss = torch.tensor(float('inf'))

# Coordinate grid with a leading batch axis
coords = utils.get_coords(H, W, dim=2)[None, ...]

# Ground-truth pixels as a (1, H*W, 3) tensor on the training device
gt = torch.tensor(im).reshape(1, H * W, 3).to(device)

# Buffer holding the latest model prediction for every pixel
rec = torch.zeros_like(gt)
# The ground-truth image in (1, 3, H, W) layout never changes, so it is built
# once here rather than once per epoch.
original_img = torch.tensor(im).permute(2, 0, 1).unsqueeze(0)

cumulative_times = []
start_time = time.time()
for step in tqdm(range(args.niters)):
    # Randomize the order of data points for each iteration
    indices = torch.randperm(H*W)

    # Process data points in batches
    for b_idx in range(0, H*W, args.maxpoints):
        b_indices = indices[b_idx:min(H*W, b_idx+args.maxpoints)]
        b_coords = coords[:, b_indices, ...].to(device)
        b_indices = b_indices.to(device)

        # Calculate model output ('incode' models also return coefficients)
        if args.inr_model == 'incode':
            model_output, coef = model(b_coords)
        else:
            model_output = model(b_coords)

        # Store this batch's prediction in the full reconstruction buffer
        with torch.no_grad():
            rec[:, b_indices, :] = model_output

        # Calculate the output loss
        output_loss = ((model_output - gt[:, b_indices, :])**2).mean()

        if args.inr_model == 'incode':
            # Regularization loss: penalise negative coefficients
            a_coef, b_coef, c_coef, d_coef = coef[0]
            reg_loss = args.a_coef * torch.relu(-a_coef) + \
                    args.b_coef * torch.relu(-b_coef) + \
                    args.c_coef * torch.relu(-c_coef) + \
                    args.d_coef * torch.relu(-d_coef)

            # Total loss for 'incode' model
            loss = output_loss + reg_loss
        else:
            # Total loss for other models
            loss = output_loss

        # Perform backpropagation and update model parameters
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Calculate PSNR from the MSE over the whole reconstruction buffer
    with torch.no_grad():
        mse_array[step] = ((gt - rec)**2).mean().item()
        psnr = -10*torch.log10(mse_array[step])
        psnr_values.append(psnr.item())

    # Adjust learning rate once per epoch. The former if/else on
    # ('incode' and step > 30) called scheduler.step() in both branches.
    if args.using_schedular:
        scheduler.step()

    # Prepare reconstructed image for visualization
    imrec = rec[0, ...].reshape(H, W, 3).detach().cpu().numpy()
    current_total_time = time.time() - start_time
    # Record the cumulative elapsed time up to this epoch
    cumulative_times.append(current_total_time)

    # Check if the current iteration's loss is the best so far
    if (mse_array[step] < best_loss) or (step == 0):
        best_loss = mse_array[step]
        best_img = imrec
        # Min-max rescale the best reconstruction to [0, 1]
        best_img = (best_img - best_img.min()) / (best_img.max() - best_img.min())
        # SSIM depends only on best_img, so it is refreshed only when
        # best_img changes (step 0 always takes this branch).
        reconstruct_img = torch.tensor(best_img).permute(2, 0, 1).unsqueeze(0)
        ms_ssim = ssim(original_img, reconstruct_img, data_range=1, size_average=False)

    ssim_values.append(ms_ssim.item())
    # Display intermediate results at specified intervals
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Total Loss: {:.5f} | PSNR: {:.4f} | SSIM:{:.4f}".format(step, 
                                                                    mse_array[step].item(),
                                                                    psnr.item(),ms_ssim.item()))
        
        # # Plot
        # fig, axes = plt.subplots(1, 2, figsize=(12, 12))
        # axes[0].set_title('Ground Truth')
        # axes[0].imshow(im)
        # axes[0].axis('off')
        # axes[1].set_title('Reconstructed')
        # axes[1].imshow(best_img)
        # axes[1].axis('off')
        # plt.show()

# Print the best PSNR / SSIM reached during training
print('--------------------')
print('Max PSNR:', max(psnr_values))
print('Max SSIM:', max(ssim_values))
print('--------------------')

# Store this run's curves under the 'siren' label.
args.inr_model='siren'
model_psnr[args.inr_model]=psnr_values
total_time[args.inr_model]=np.array(cumulative_times)
model_ssim[args.inr_model]=ssim_values

# Perceptual distance between ground truth and best reconstruction.
# The reconstruction was min-max rescaled to [0, 1] and SSIM is evaluated with
# data_range=1, whereas LPIPS expects [-1, 1] unless normalize=True is passed.
# NOTE(review): assumes `im` is in [0, 1] as well - confirm utils.normalize.
with torch.no_grad():  # evaluation only, no autograd graph needed
    vgg_distance = lpips_vgg_model(original_img, reconstruct_img, normalize=True)
    alex_distance = lpips_alex_model(original_img, reconstruct_img, normalize=True)
print("VGG: LPIPS distance:", vgg_distance.item())
print("ALEX: LPIPS distance:", alex_distance.item())




In [None]:
# Save the ground truth and the best reconstruction as PNG files, read them
# back, and plot both next to their absolute difference.
best_img_save_name='/root/autodl-tmp/INCODE-main/result//'+args.inr_model+'_best_img.png'

# cv2.imwrite only returns False on failure (for example when the folder is
# missing), so create the folder and check the return value explicitly.
os.makedirs(os.path.dirname(best_img_save_name), exist_ok=True)
# [:, :, ::-1] reverses the channel order (RGB -> BGR) for OpenCV.
if not cv2.imwrite('/root/autodl-tmp/INCODE-main/result/im.png',im[:,:, ::-1]*255):
    raise IOError('Could not write /root/autodl-tmp/INCODE-main/result/im.png')
if not cv2.imwrite(best_img_save_name,best_img[:,:, ::-1]*255):
    raise IOError('Could not write ' + best_img_save_name)


image1 = cv2.imread('/root/autodl-tmp/INCODE-main/result//im.png')
image2 = cv2.imread(best_img_save_name)
if image1 is None or image2 is None:
    raise IOError('Could not read back the saved images')

# Make sure both images have the same size
image1 = cv2.resize(image1, (image2.shape[1], image2.shape[0]))

# Per-pixel absolute difference
difference = cv2.absdiff(image1, image2)

# Convert the difference to grayscale (scaled to [0, 1]) so it is easier to see
gray_difference = cv2.cvtColor(difference, cv2.COLOR_BGR2GRAY)/255

# Show both images and the difference map
plt.figure(figsize=(12, 12))
plt.subplot(1, 3, 1)
plt.imshow(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
plt.title('Image 1')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 3, 2)
plt.imshow(cv2.cvtColor(image2, cv2.COLOR_BGR2RGB))
plt.title('Image 2')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 3, 3)
plt.imshow(gray_difference, cmap='jet',vmin=0,vmax=0.1)
plt.title('Difference')
plt.xticks([])
plt.yticks([])

In [None]:
# HO+siren
# poly_siren
class SineLayer(nn.Module):
    '''
    Linear layer with SIREN-style weight initialisation.

    Despite the name, ``forward`` only applies the linear map; no sine is
    applied here, so any non-linearity has to be applied by the caller.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If True, the linear map includes a bias term. Default is True.
        is_first (bool, optional): Selects the weight range used by ``init_weights``. Default is False.
        omega_0 (float, optional): Divides the weight range of non-first layers. Default is 30.
        scale (float, optional): Accepted but unused. Default is 10.0.
        init_weights (bool, optional): If True, ``init_weights`` is called at construction. Default is True.

    '''

    def __init__(self, in_features, out_features, bias=True,
                is_first=False, omega_0=30, scale=10.0, init_weights=True):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)

        if init_weights:
            self.init_weights()

    def init_weights(self):
        '''
        Re-draw the weights uniformly in [-limit, limit); the bias is left untouched.

        limit is 1 / in_features for a first layer and
        sqrt(6 / in_features) / omega_0 otherwise.
        '''
        # self.linear.bias.data.fill_(10)
        fan_in = self.in_features
        if self.is_first:
            limit = 1 / fan_in
            # self.linear.bias.data.uniform_(-0.8,0.8)
        else:
            limit = np.sqrt(6 / fan_in) / self.omega_0
            # self.linear.bias.data.uniform_(-8,8)
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, input):
        '''Return the linear projection of ``input``.'''
        projected = self.linear(input)
        return projected

class Siren(nn.Module):
    """
        Sine activation from the SIREN paper: sin(w0 * x).
        https://arxiv.org/abs/2006.09661
    """

    def __init__(self, w0=1):
        """
            w0 is the frequency factor multiplied into the input before
            the sine (end of section 3 of the paper: 30 for the first
            layer, 1 for the rest).
        """
        super().__init__()
        # Kept as a plain tensor attribute (neither a parameter nor a buffer).
        self.w0 = torch.tensor(w0)

    def forward(self, x):
        # return torch.sin(self.w0*(torch.abs(x)+1)*x)
        scaled = self.w0 * x
        return torch.sin(scaled)

    def extra_repr(self):
        return f"w0={self.w0}"
    
    

class PolySiren(nn.Module):
    """Four-stage coordinate network with multiplicative residual blocks.

    The first stage is ``act1(linear1(x))``. Stages 2-4 each compute
    ``act(norm(x + x * linear(x)))``. The head is a final linear layer with
    3 outputs followed by a sigmoid.

    Args:
        activate: ``'ReLU'`` or ``'Siren'``. Selects the activation applied
            after every stage.
        norm_type: ``'LayerNorm'``, ``'BatchNorm1d'`` or ``'None'``. Selects
            the normalization used in stages 2-4.

    Raises:
        ValueError: If ``activate`` or ``norm_type`` is not one of the
            supported strings.
    """

    def __init__(
        self, activate='ReLU', norm_type='None'
        ) -> None:
        super(PolySiren, self).__init__()
        input_dim = 2
        hidden_channel = 256

        # Per-stage frequency factors. They are passed to the linear layers as
        # omega_0 and, for the 'Siren' activation, reused as w0.
        w1 = 100
        w2 = 2
        w3 = 1
        w4 = 1

        self.linear1 = SineLayer(input_dim, hidden_channel, omega_0=w1, is_first=True)

        # forward() uses norm2, norm3 and norm4, so every branch must define
        # all three of them.
        if norm_type == 'LayerNorm':
            self.norm2 = nn.LayerNorm(hidden_channel)
            self.norm3 = nn.LayerNorm(hidden_channel)
            self.norm4 = nn.LayerNorm(hidden_channel)
        elif norm_type == 'BatchNorm1d':
            # NOTE(review): the feature count is hard-coded to 65536, which
            # presumably ties this option to batches of exactly 65536 points
            # -- confirm before relying on it.
            self.norm1 = nn.BatchNorm1d(65536)
            self.norm2 = nn.BatchNorm1d(65536)
            self.norm3 = nn.BatchNorm1d(65536)
            self.norm4 = nn.BatchNorm1d(65536)
        elif norm_type == 'None':
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()
            self.norm3 = nn.Identity()
            # Previously missing: forward() failed with AttributeError on
            # self.norm4 when norm_type was 'None'.
            self.norm4 = nn.Identity()
        else:
            raise ValueError(
                "norm_type must be 'LayerNorm', 'BatchNorm1d' or 'None', "
                "got {!r}".format(norm_type)
            )

        self.linear2 = SineLayer(hidden_channel, hidden_channel, omega_0=w2)
        self.linear3 = SineLayer(hidden_channel, hidden_channel, omega_0=w3)
        self.linear4 = SineLayer(hidden_channel, hidden_channel, omega_0=w4)

        # forward() uses nolinear1..nolinear4, so every branch must define all
        # four of them.
        if activate == 'ReLU':
            self.nolinear1 = nn.ReLU()
            self.nolinear2 = nn.ReLU()
            self.nolinear3 = nn.ReLU()
            # Previously missing: forward() failed with AttributeError on
            # self.nolinear4 when activate was 'ReLU'.
            self.nolinear4 = nn.ReLU()
        elif activate == 'Siren':
            self.nolinear1 = Siren(w1)
            self.nolinear2 = Siren(w2)
            self.nolinear3 = Siren(w3)
            self.nolinear4 = Siren(w4)
            # Not used by forward(); kept so the module structure is unchanged.
            self.nolinear5 = Siren(1)
            self.nolinear6 = Siren(1)
        else:
            raise ValueError(
                "activate must be 'ReLU' or 'Siren', got {!r}".format(activate)
            )

        # Output head: 3 channels squashed into (0, 1) by the sigmoid.
        self.layers = nn.Sequential(
            SineLayer(hidden_channel, 3, is_first=True),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Map coordinates (last dimension 2) to 3 sigmoid-activated outputs."""
        x = self.nolinear1(self.linear1(input))
        # Each stage multiplies the features by a learned linear map of
        # themselves and adds the result back before normalizing.
        x = self.nolinear2(self.norm2(x + x * self.linear2(x)))
        x = self.nolinear3(self.norm3(x + x * self.linear3(x)))
        x = self.nolinear4(self.norm4(x + x * self.linear4(x)))
        return self.layers(x)


# Run configuration: take the generic (non-'incode', non-'wire') code paths
# below and override the learning rate stored in args.
args.inr_model = 'siren'
args.lr = 6e-3
model = PolySiren(activate='Siren', norm_type='LayerNorm').to(device)

# Report the number of trainable parameters, in millions.
num_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f'Total number of parameters: {num_params/1e6}(M)')


# Optimizer setup
if args.inr_model == 'wire':
    # Shrink the learning rate when a batch covers only part of the image.
    args.lr *= min(1, args.maxpoints / (H * W))

if args.using_cosoptim:
    # Adam with explicit hyper-parameters plus cosine annealing.
    optim = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.99),
        eps=1e-15,
        weight_decay=0,
    )
    scheduler = lr_scheduler.CosineAnnealingLR(
        optim, T_max=args.T_max, eta_min=args.eta_min
    )
else:
    # Default Adam; the multiplier decays from 1 to args.scheduler_b over
    # args.niters scheduler steps and is held constant afterwards.
    optim = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = lr_scheduler.LambdaLR(
        optim,
        lr_lambda=lambda epoch: args.scheduler_b ** min(epoch / args.niters, 1),
    )

# Per-epoch metric histories
psnr_values, ssim_values = [], []
mse_array = torch.zeros(args.niters, device=device)

# Best loss so far; starts at +inf so that any finite loss is an improvement
best_loss = torch.tensor(float('inf'))

# Coordinates from utils.get_coords with a leading batch axis added
coords = utils.get_coords(H, W, dim=2)[None, ...]

# Ground truth flattened to one row of 3 values per pixel, batched, on device
gt = torch.tensor(im).reshape(H * W, 3).unsqueeze(0).to(device)

# Buffer that accumulates the model's prediction for every pixel
rec = torch.zeros_like(gt)

# Wall-clock bookkeeping: elapsed seconds are appended once per epoch
cumulative_times = []
start_time = time.time()
# The ground-truth image never changes, so convert it to a channels-first
# tensor with a leading batch axis once instead of on every epoch.
original_img = torch.tensor(im).permute(2, 0, 1).unsqueeze(0)

for step in tqdm(range(args.niters)):
    # Randomize the order of data points for each iteration
    indices = torch.randperm(H*W)

    # Process data points in batches (slicing already clamps at the end)
    for b_idx in range(0, H*W, args.maxpoints):
        b_indices = indices[b_idx:b_idx + args.maxpoints]
        b_coords = coords[:, b_indices, ...].to(device)
        b_indices = b_indices.to(device)

        # Only the 'incode' model returns coefficients next to its prediction
        if args.inr_model == 'incode':
            model_output, coef = model(b_coords)
        else:
            model_output = model(b_coords)

        # Store this batch's prediction in the full reconstruction buffer
        with torch.no_grad():
            rec[:, b_indices, :] = model_output

        # Mean squared error on the current batch
        output_loss = ((model_output - gt[:, b_indices, :])**2).mean()

        if args.inr_model == 'incode':
            # Penalize negative coefficients, each with its own weight
            a_coef, b_coef, c_coef, d_coef = coef[0]
            reg_loss = args.a_coef * torch.relu(-a_coef) + \
                    args.b_coef * torch.relu(-b_coef) + \
                    args.c_coef * torch.relu(-c_coef) + \
                    args.d_coef * torch.relu(-d_coef)
            loss = output_loss + reg_loss
        else:
            loss = output_loss

        # Perform backpropagation and update model parameters
        optim.zero_grad()
        loss.backward()
        optim.step()

    # Full-image MSE and PSNR for this epoch
    with torch.no_grad():
        mse_array[step] = ((gt - rec)**2).mean().item()
        psnr = -10*torch.log10(mse_array[step])
        psnr_values.append(psnr.item())

    # Advance the learning-rate schedule once per epoch. The earlier code
    # branched on `args.inr_model == 'incode' and 30 < step`, but both branches
    # called scheduler.step(), so the condition had no effect.
    if args.using_schedular:
        scheduler.step()

    # Prepare reconstructed image for visualization
    imrec = rec[0, ...].reshape(H, W, 3).detach().cpu().numpy()

    # Append the current cumulative wall-clock time to the list
    current_total_time = time.time() - start_time
    cumulative_times.append(current_total_time)

    # Keep the min-max normalized reconstruction with the lowest MSE so far.
    # SSIM depends only on best_img, so it is recomputed only when best_img
    # changes; step 0 always enters this branch, so ms_ssim is always defined.
    if (mse_array[step] < best_loss) or (step == 0):
        best_loss = mse_array[step]
        best_img = (imrec - imrec.min()) / (imrec.max() - imrec.min())
        reconstruct_img = torch.tensor(best_img).permute(2, 0, 1).unsqueeze(0)
        ms_ssim = ssim(original_img, reconstruct_img, data_range=1, size_average=False)
    ssim_values.append(ms_ssim.item())

    # Display intermediate results at specified intervals
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Total Loss: {:.5f} | PSNR: {:.4f} | SSIM:{:.4f}".format(step, 
                                                                    mse_array[step].item(),
                                                                    psnr.item(),ms_ssim.item()))
        
        # # Plot
        # fig, axes = plt.subplots(1, 2, figsize=(12, 12))
        # axes[0].set_title('Ground Truth')
        # axes[0].imshow(im)
        # axes[0].axis('off')
        # axes[1].set_title('Reconstructed')
        # axes[1].imshow(best_img)
        # axes[1].axis('off')
        # plt.show()

# Relabel the run; args.inr_model is used afterwards to build the file name
# of the saved best image.
args.inr_model='HO_siren'

# Print the best PSNR and SSIM reached during training
print('--------------------')
print('Max PSNR:', max(psnr_values))
print('Max SSIM:', max(ssim_values))
print('--------------------')


In [None]:
# Save the ground truth and the best reconstruction, then visualize their
# absolute difference.
best_img_save_name='/root/autodl-tmp/INCODE-main/result//'+args.inr_model+'_best_img.png'
gt_img_save_name = '/root/autodl-tmp/INCODE-main/result/im.png'

# The channel axis is reversed (RGB -> BGR) for OpenCV and values are scaled
# by 255. cv2.imwrite reports failure through its return value, so check it
# instead of failing later on a missing file.
# NOTE(review): float arrays are handed to imwrite here, which relies on its
# internal conversion to 8-bit -- confirm for the installed OpenCV version.
if not cv2.imwrite(gt_img_save_name, im[:,:, ::-1]*255):
    raise IOError(f'Could not write image to {gt_img_save_name}')
if not cv2.imwrite(best_img_save_name, best_img[:,:, ::-1]*255):
    raise IOError(f'Could not write image to {best_img_save_name}')

# cv2.imread returns None when the file cannot be read.
image1 = cv2.imread(gt_img_save_name)
image2 = cv2.imread(best_img_save_name)
if image1 is None:
    raise IOError(f'Could not read image from {gt_img_save_name}')
if image2 is None:
    raise IOError(f'Could not read image from {best_img_save_name}')

# Make sure both images have the same size
image1 = cv2.resize(image1, (image2.shape[1], image2.shape[0]))

# Compute the per-pixel absolute difference
difference = cv2.absdiff(image1, image2)

# Convert the difference to grayscale (scaled by 1/255) so it is easier to see
gray_difference = cv2.cvtColor(difference, cv2.COLOR_BGR2GRAY)/255

# Show both images and the difference map side by side
plt.figure(figsize=(12, 12))

plt.subplot(1, 3, 1)
plt.imshow(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
plt.title('Image 1')
plt.xticks([])
plt.yticks([])

plt.subplot(1, 3, 2)
plt.imshow(cv2.cvtColor(image2, cv2.COLOR_BGR2RGB))
plt.title('Image 2')
plt.xticks([])
plt.yticks([])

# The color scale is capped at 0.1 so that small errors remain visible
plt.subplot(1, 3, 3)
plt.imshow(gray_difference, cmap='jet',vmin=0,vmax=0.1)
plt.title('Difference')
plt.xticks([])
plt.yticks([])