In [2]:
import skimage
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
import argparse
import cv2
from scipy import io
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from INCODE_modules import utils
from modules import models
!pip install -q pytorch-msssim # Install this library if error pops up (Uncomment and run)
from pytorch_msssim import ssim
from modules.models1 import INR #the inr in model configuration wasnt working properly so i added this -SB

# Pin GPU 0 only when CUDA is present. An unconditional set_device raises on
# CPU-only runtimes, although the device selection in the config cell falls
# back to CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

/content/drive/My Drive/INR


In [3]:
def str2bool(value):
    """Convert a command-line string such as "true", "False" or "0" to a bool.

    argparse's ``type=bool`` is a trap: bool("False") is True because any
    non-empty string is truthy, so such a flag could never be switched off
    from the command line.

    Args:
        value: the raw string from the command line (bools pass through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='INCODE')

run_name = "Band_test_1"
image = "kodim20"

# Shared Parameters
parser.add_argument('--input', type=str, default=f"Data/images/kodak/{image}.png", help='Input image path')
parser.add_argument('--inr_model', type=str, default='BandRC', help='[gauss, mfn, relu, siren, wire, wire2d, ffn, incode, BandRC]')
parser.add_argument('--lr', type=float, default=9e-4, help='Learning rate')
parser.add_argument('--using_schedular', type=str2bool, default=True, help='Whether to use schedular')
parser.add_argument('--scheduler_b', type=float, default=0.1, help='Learning rate scheduler')
parser.add_argument('--maxpoints', type=int, default=256*256, help='Batch size')
parser.add_argument('--niters', type=int, default=500, help='Number of iterations')
parser.add_argument('--steps_til_summary', type=int, default=50, help='Number of steps till summary visualization')

# INCODE Parameters (weights of the coefficient regularization terms)
parser.add_argument('--a_coef', type=float, default=0.1993, help='a coefficient')
parser.add_argument('--b_coef', type=float, default=0.0196, help='b coefficient')
parser.add_argument('--c_coef', type=float, default=0.0588, help='c coefficient')
parser.add_argument('--d_coef', type=float, default=0.0269, help='d coefficient')

# Notebook use: parse an empty argv so only the defaults above apply.
args = parser.parse_args(args=[])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Wall-clock seconds since training start, one entry per epoch.
time_array = torch.zeros(args.niters, device=device)

In [4]:
# All artifacts of this run (weights, figures) go under Results/<run_name>_<model>.
save_name = "_".join([run_name, args.inr_model])
save_path = "Results/" + save_name
os.makedirs(save_path, exist_ok=True)

## Loading Data

In [6]:
# Load the ground-truth image as float32. It is then passed through the project
# helper utils.normalize (see INCODE_modules for the meaning of its flag).
raw_image = plt.imread(args.input).astype(np.float32)
# Downstream code assumes exactly 3 channels (out_features=3 and
# reshape(H * W, 3) in training). Replicate grayscale images and drop any
# alpha channel a PNG may carry.
if raw_image.ndim == 2:
    raw_image = np.repeat(raw_image[..., None], 3, axis=-1)
elif raw_image.shape[-1] > 3:
    raw_image = raw_image[..., :3]
im_RGB_gt = utils.normalize(raw_image, True)
im = im_RGB_gt
H, W, _ = im.shape

## Defining Positional Encodings

In [9]:
# Frequency Encoding
pos_encode_freq = {'type':'frequency', 'use_nyquist': True, 'mapping_input': int(max(H, W)/3)}

# Gaussian Encoding
pos_encode_gaus = {'type':'gaussian', 'scale_B': 10, 'mapping_input': 256}

# No Encoding
pos_encode_no = {'type': None}

### Model Configurations

In [10]:
### Harmonizer Configurations
# The harmonizer's output width is act_prams * (hidden_lay + 1). Presumably this
# is act_prams activation parameters for each of the hidden_lay + 1 layers of
# the INR (TODO confirm against modules.models1.INR).
act_prams = 2
hidden_lay = 3
MLP_configs = {'task': 'image',
               'model': 'resnet34',
               'truncated_layer': 5,
               'in_channels': 64,
               'hidden_channels': [64, 32, act_prams * (hidden_lay + 1)],
               'mlp_bias': 0.3120,
               'activation_layer': nn.SiLU,
               # Ground truth as a (1, C, H, W) tensor on `device`.
               'GT': torch.tensor(im).to(device)[None, ...].permute(0, 3, 1, 2),
               'T_range': [0, 10],
               'c_range': [0, 3],
               }

### Model Configurations
# hidden_layers reuses hidden_lay, so the harmonizer width above stays in sync
# with the INR depth when either is changed.
# Other model types may need extra kwargs, e.g. first_omega_0 / hidden_omega_0,
# pos_encode_configs (one of the encoding dicts) or ffn_type.
model = INR(args.inr_model).run(in_features=2,
                                out_features=3,
                                hidden_features=256,
                                hidden_layers=hidden_lay,
                                MLP_configs=MLP_configs,
                                ).to(device)



Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 198MB/s]


## Training Code

In [11]:
init_time = time.time()  # reference point for the per-epoch timings stored in time_array

# Optimizer setup. For wire/BandRC the learning rate is scaled down by the
# fraction of pixels seen per batch. The result goes into effective_lr instead
# of overwriting args.lr, so re-running this cell does not compound the scaling.
effective_lr = args.lr
if args.inr_model in ('wire', 'BandRC'):
    effective_lr = args.lr * min(1, args.maxpoints / (H * W))
optim = torch.optim.Adam(lr=effective_lr, params=model.parameters())
# LR multiplier moves geometrically from 1 (step 0) to scheduler_b (step
# niters) and stays at scheduler_b afterwards.
scheduler = lr_scheduler.LambdaLR(optim, lambda x: args.scheduler_b ** min(x / args.niters, 1))

# Per-epoch metrics: PSNR as Python floats, MSE as a tensor on `device`
psnr_values = []
mse_array = torch.zeros(args.niters, device=device)

# Best (lowest) epoch MSE seen so far
best_loss = torch.tensor(float('inf'))

# Pixel coordinate grid with a leading batch axis; the training loop moves
# each batch of it to `device`.
coords = utils.get_coords(H, W, dim=2)[None, ...]

# Ground truth flattened to (1, H*W, 3)
gt = torch.tensor(im).reshape(H * W, 3)[None, ...].to(device)

# Reconstruction buffer, filled batch by batch during training
rec = torch.zeros_like(gt)

In [None]:
for step in tqdm(range(args.niters)):
    # Fresh random order of all pixel indices every epoch
    indices = torch.randperm(H * W)

    # Process pixels in mini-batches of at most args.maxpoints
    # (slicing past the end simply yields a shorter final batch)
    for b_idx in range(0, H * W, args.maxpoints):
        b_indices = indices[b_idx:b_idx + args.maxpoints]
        b_coords = coords[:, b_indices, ...].to(device)
        b_indices = b_indices.to(device)

        # INCODE additionally returns its harmonizer coefficients
        if args.inr_model == 'incode':
            model_output, coef = model(b_coords)
        else:
            model_output = model(b_coords)

        # Record this batch's prediction into the full-image reconstruction
        with torch.no_grad():
            rec[:, b_indices, :] = model_output

        # Pixel-wise MSE against the ground truth for this batch
        output_loss = ((model_output - gt[:, b_indices, :]) ** 2).mean()

        if args.inr_model == 'incode':
            # Penalize negative coefficients, weighted by the --*_coef arguments
            a_coef, b_coef, c_coef, d_coef = coef[0]
            reg_loss = (args.a_coef * torch.relu(-a_coef)
                        + args.b_coef * torch.relu(-b_coef)
                        + args.c_coef * torch.relu(-c_coef)
                        + args.d_coef * torch.relu(-d_coef))
            loss = output_loss + reg_loss
        else:
            loss = output_loss

        optim.zero_grad()
        loss.backward()
        optim.step()

    time_array[step] = time.time() - init_time  # cumulative wall-clock seconds

    # Epoch metrics over the whole reconstruction
    with torch.no_grad():
        mse_array[step] = ((gt - rec) ** 2).mean().item()
        psnr = -10 * torch.log10(mse_array[step])
        psnr_values.append(psnr.item())

    # Step the LR scheduler once per epoch. For INCODE it only starts after
    # step 30. The previous if/else called scheduler.step() in both branches,
    # which made that condition a no-op.
    if args.using_schedular and (args.inr_model != 'incode' or step > 30):
        scheduler.step()

    # Reconstructed image for visualization, shape (H, W, 3)
    imrec = rec[0, ...].reshape(H, W, 3).detach().cpu().numpy()

    # Keep the best epoch so far. rec is overwritten in place every epoch, and
    # on CPU .cpu().numpy() shares memory with it, so store real copies rather
    # than references.
    if (mse_array[step] < best_loss) or (step == 0):
        best_loss = mse_array[step]
        best_flat_img = rec.clone()
        best_img = imrec.copy()
        best_epoch = step

    # Periodic progress report
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Total Loss: {:.5f} | PSNR: {:.4f}".format(step,
                                                                     mse_array[step].item(),
                                                                     psnr.item()))

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0 | Total Loss: 0.12861 | PSNR: 8.9073
Epoch: 50 | Total Loss: 0.00256 | PSNR: 25.9241
Epoch: 100 | Total Loss: 0.00174 | PSNR: 27.5874
Epoch: 150 | Total Loss: 0.00139 | PSNR: 28.5572
Epoch: 200 | Total Loss: 0.00129 | PSNR: 28.8804
Epoch: 250 | Total Loss: 0.00116 | PSNR: 29.3555


In [None]:
# PSNR (dB) of the reconstruction after the last training epoch
print("Final PSNR: {}".format(psnr_values[-1]))

In [None]:
def get_np_loss(image1, image2):
    """Mean squared error between two NumPy arrays, computed in float32.

    Args:
        image1, image2: NumPy arrays with broadcast-compatible shapes.

    Returns:
        The MSE as a NumPy float32 scalar.
    """
    diff = image1.astype(np.float32) - image2.astype(np.float32)
    return (diff ** 2).mean()


def get_np_psnr(image1, image2, peak=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(peak**2 / MSE).

    Args:
        image1, image2: NumPy arrays with broadcast-compatible shapes.
        peak: maximum possible pixel value. The default of 1.0 matches images
            scaled to [0, 1]; use 255.0 for 8-bit data.

    Returns:
        The PSNR in dB, or ``inf`` when the images are identical.
    """
    mse = get_np_loss(image1, image2)
    if mse == 0:
        # log10(0) would emit a divide-by-zero warning; identical images have infinite PSNR
        return float('inf')
    # Written as 20*log10(peak) - 10*log10(mse) so peak=1.0 reproduces -10*log10(mse) exactly
    return 20 * np.log10(peak) - 10 * np.log10(mse)

In [None]:
# Score the best-epoch reconstruction against the ground truth (NumPy helpers)
rec_loss, rec_psnr = get_np_loss(best_img, im), get_np_psnr(best_img, im)
print(rec_loss, rec_psnr)

In [None]:
# Checkpoint the weights together with the arrays needed to redo the plots.
best_psnr = max(psnr_values)
weight_path = os.path.join(save_path, f"iters_{len(psnr_values)}_psnr_{best_psnr:.2f}.pth")
checkpoint = {
    'state_dict': model.state_dict(),
    'best_epoch': best_epoch,
    'gt': im_RGB_gt,
    'rec': best_img,
    'time_array': time_array.detach().cpu().numpy(),
    'mse_array': mse_array.detach().cpu().numpy(),
    # NOTE(review): despite the key name, these are the per-epoch PSNRs
    # measured on the RGB image during training.
    'psnr_vals_hsv_domain': np.array(psnr_values),
}
torch.save(checkpoint, weight_path)

# Side-by-side comparison: ground truth vs best reconstruction
fig, axes = plt.subplots(1, 2, figsize=(10, 10))
panels = (('Ground Truth', im_RGB_gt),
          (f"RGB_PSNR= {rec_psnr:.3f} ", best_img))
for ax, (panel_title, panel_img) in zip(axes, panels):
    ax.set_title(panel_title)
    ax.imshow(panel_img)
    ax.axis('off')

fig.savefig(f"{save_path}/comparison.png")
plt.show()

# Summary of the run
separator = '--------------------'
print(separator)
print('Max PSNR:', best_psnr)
print(separator)
print(best_epoch)

# Convergence Rate

In [None]:
# Font settings: `font` styles the axis labels. `axfont` is applied as rc
# settings (ticks, legend, ...) to this figure only.
font = {'family': 'serif', 'size': 12}
axfont = {'family': 'serif', 'weight': 'regular', 'size': 10}
axfont_rc = {f'font.{key}': value for key, value in axfont.items()}

# plt.rc(...) would switch the font for every later figure in the kernel.
# rc_context scopes it to this plot; drawing, saving and showing all happen inside.
with plt.rc_context(axfont_rc):
    plt.figure()
    # Plot all epochs, including the final one
    epochs = np.arange(len(psnr_values))
    plt.plot(epochs, psnr_values, label=f"{args.inr_model.upper()}")
    plt.xlabel('# Epochs', fontdict=font)
    plt.ylabel('PSNR (dB)', fontdict=font)
    plt.title('Image Representation', fontdict={'family': 'serif', 'size': 12, 'weight': 'bold'})
    plt.legend()
    plt.grid(True, color='lightgray')

    save_file = os.path.join(save_path, f'iters_{len(psnr_values)}_psnr_{max(psnr_values):.2f}_mse_psnr_plot.png')
    plt.tight_layout()
    plt.savefig(save_file)
    plt.show()

In [None]:
# mse_array is a torch tensor (copied to host NumPy here); psnr_values is
# already a plain Python list.
MSE_ARRAY = mse_array.detach().cpu().numpy()
PSNR_ARRAY = psnr_values

# Two panels side by side: per-epoch MSE (left) and PSNR (right)
fig, (ax_mse, ax_psnr) = plt.subplots(1, 2, figsize=(12, 5))
panel_specs = (
    (ax_mse, MSE_ARRAY, 'o', 'blue', "MSE Plot", "MSE Value"),
    (ax_psnr, PSNR_ARRAY, 's', 'green', "PSNR Plot", "PSNR (dB)"),
)
for ax, series, marker, color, title, ylabel in panel_specs:
    ax.plot(series, marker=marker, linestyle='-', color=color)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True)

# NOTE(review): this MSE+PSNR figure is saved with a `_psnr_plot.png` suffix;
# the file name is kept unchanged for compatibility with existing result folders.
save_file = os.path.join(save_path, f'iters_{len(psnr_values)}_psnr_{max(psnr_values):.2f}_psnr_plot.png')
fig.tight_layout()
fig.savefig(save_file)
print(f"Plot saved to: {save_file}")

plt.show()