Import packages

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import random
import copy
import piq  # PSNR SSIM LIPIS
from torchsummary import summary

Seed setting, early stopping and data extraction

In [None]:
def set_seed(seed):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU generator, the
    current CUDA device and all CUDA devices), turns cuDNN autotuning off and
    its deterministic mode on, and stores the seed in the ``PYTHONHASHSEED``
    environment variable.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then the CUDA ones.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: fixed algorithms, no benchmark-driven algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = f"{seed}"

class EarlyStopping:
    """Callable stop criterion driven by a stream of loss values.

    Calling the instance with the current loss returns ``True`` when training
    should stop, which happens in either of two cases:

    * the loss is strictly below ``stop_threshold``; or
    * the loss has failed to beat ``best_loss - delta`` for ``patience``
      consecutive calls.

    Attributes:
        patience: Number of consecutive non-improving calls tolerated.
        delta: Minimum decrease below ``best_loss`` that counts as improvement.
        stop_threshold: Loss value below which training stops immediately.
        best_loss: Lowest loss accepted as an improvement so far.
        counter: Current run of non-improving calls.
    """

    def __init__(self, patience=2000, delta=0.000001, stop_threshold=0.000009):
        self.patience, self.delta = patience, delta
        self.stop_threshold = stop_threshold
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, current_loss):
        """Record ``current_loss`` and return whether training should stop."""
        # Absolute target reached: stop regardless of the patience counter.
        if current_loss < self.stop_threshold:
            print(f"Loss reached threshold {self.stop_threshold}, early stopping")
            return True

        improved = current_loss < self.best_loss - self.delta
        if improved:
            self.best_loss, self.counter = current_loss, 0
            return False

        # No sufficient improvement: extend the stall run and compare to patience.
        self.counter += 1
        return bool(self.counter >= self.patience)
def extract_patch_data(patch_img, num_freq=10):
    """Turn an image patch into per-pixel (coordinate, colour) training data.

    Pixels are listed row by row (x varies fastest). Each coordinate is
    normalised to ``[-1, 1]`` along its axis; an axis with a single pixel is
    placed at ``0.0`` instead of dividing by zero.

    Args:
        patch_img: Anything ``np.array`` converts to an array of shape
            ``(h, w, 3)``.
        num_freq: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(coords, rgb, h, w)`` where ``coords`` has shape
        ``(h * w, 2)`` holding ``[x_norm, y_norm]`` per pixel, ``rgb`` has
        shape ``(h * w, 3)`` holding the pixel values divided by 255, and
        ``h``/``w`` are the patch height and width.
    """
    patch_np = np.array(patch_img)
    h, w, _ = patch_np.shape

    def _axis(n):
        # A lone sample has no extent: centre it rather than computing 0 / 0.
        if n == 1:
            return np.zeros(1, dtype=np.float64)
        # Same arithmetic as (k / (n - 1)) * 2 - 1, evaluated for every k at once.
        return np.arange(n, dtype=np.float64) / (n - 1) * 2 - 1

    # Default 'xy' indexing yields (h, w) grids; C-order ravel makes x vary fastest.
    xs, ys = np.meshgrid(_axis(w), _axis(h))
    coords = np.stack([xs.ravel(), ys.ravel()], axis=1)
    return coords, patch_np.reshape(-1, 3) / 255.0, h, w

def generate_patch(model, device, actual_size, num_freq=10):
    """Render a patch by evaluating ``model`` on a normalised pixel grid.

    The grid lists pixels row by row (x varies fastest) with each axis
    normalised to ``[-1, 1]``; an axis with a single pixel is placed at
    ``0.0`` instead of dividing by zero.

    Args:
        model: Callable mapping a ``(h * w, 2)`` float32 tensor of
            ``[x_norm, y_norm]`` rows to a ``(h * w, 3)`` tensor.
        device: Device the coordinate tensor is moved to before the call.
        actual_size: ``(h, w)`` of the patch to render.
        num_freq: Unused; kept so existing callers keep working.

    Returns:
        ``uint8`` array of shape ``(h, w, 3)``: the model output clipped to
        ``[0, 1]`` and scaled by 255.
    """
    h, w = actual_size

    def _axis(n):
        # A lone sample has no extent: centre it rather than computing 0 / 0.
        if n == 1:
            return np.zeros(1, dtype=np.float64)
        # Same arithmetic as (k / (n - 1)) * 2 - 1, evaluated for every k at once.
        return np.arange(n, dtype=np.float64) / (n - 1) * 2 - 1

    # Default 'xy' indexing yields (h, w) grids; C-order ravel makes x vary fastest.
    xs, ys = np.meshgrid(_axis(w), _axis(h))
    coords = np.stack([xs.ravel(), ys.ravel()], axis=1)

    coords_tensor = torch.tensor(coords, dtype=torch.float32).to(device)
    with torch.no_grad():
        rgb = model(coords_tensor).cpu().numpy()

    return (np.clip(rgb, 0, 1) * 255).astype(np.uint8).reshape(h, w, 3)

MLP

In [None]:
def positional_encoding(x, y, num_frequencies=10):
    """Encode a coordinate pair with sines and cosines at doubling frequencies.

    The frequencies are ``2**0, 2**1, ..., 2**(num_frequencies - 1)``.

    Args:
        x: First coordinate.
        y: Second coordinate.
        num_frequencies: Number of frequency bands per coordinate.

    Returns:
        Tensor concatenating, in order, ``sin(f * x)``, ``cos(f * x)``,
        ``sin(f * y)`` and ``cos(f * y)`` over all frequencies ``f``.
    """
    bands = 2 ** torch.arange(num_frequencies, dtype=torch.float32)

    pieces = []
    for value in (x, y):
        scaled = bands * value
        pieces.append(torch.sin(scaled))
        pieces.append(torch.cos(scaled))
    return torch.cat(pieces)

class MLP(nn.Module):
    """Fully connected ReLU network with a sigmoid on the output.

    Args:
        input_dim: Size of each input vector.
        hidden_dims: Widths of the hidden layers, in order. Any iterable of
            ints is accepted; ``None`` selects the default
            ``[32, 32, 32, 32, 16]``.
        output_dim: Size of each output vector.
    """

    def __init__(self, input_dim=40, hidden_dims=None, output_dim=3):
        super().__init__()
        # Build the default per instance instead of sharing one mutable list
        # object in the signature.
        if hidden_dims is None:
            hidden_dims = [32] * 4 + [16]

        # Unpacking accepts tuples and other iterables, not only lists.
        dims = [input_dim, *hidden_dims, output_dim]
        last_linear = len(dims) - 2

        layers = []
        for i, (fan_in, fan_out) in enumerate(zip(dims, dims[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # ReLU after every linear layer except the final one.
            if i < last_linear:
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``sigmoid(net(x))``, so every output lies in (0, 1)."""
        return torch.sigmoid(self.net(x))

SIREN


In [None]:
class SineLayer(nn.Module):
    """Linear layer followed by ``sin(omega_0 * .)``.

    Weights and bias are re-drawn uniformly from ``[-bound, bound]`` where
    ``bound`` is ``1 / in_features`` for the first layer and
    ``sqrt(6 / in_features) / omega_0`` otherwise.

    Args:
        in_features: Size of each input vector.
        out_features: Size of each output vector.
        is_first: Whether this is the network's first layer (selects the bound).
        omega_0: Frequency factor applied before the sine.
    """

    def __init__(self, in_features, out_features, is_first=False, omega_0=30.0):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.linear = nn.Linear(in_features, out_features)

        # Corrected initialisation strategy: overwrite nn.Linear's default
        # initialisation, weight first and then bias.
        limit = 1 / in_features if self.is_first else np.sqrt(6 / in_features) / omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)
            bias = self.linear.bias
            if bias is not None:
                bias.uniform_(-limit, limit)

    def forward(self, input):
        """Return ``sin(omega_0 * linear(input))``."""
        pre_activation = self.linear(input)
        return torch.sin(self.omega_0 * pre_activation)

class Siren(nn.Module):
    """Stack of ``SineLayer`` modules followed by a linear + sigmoid head.

    The sine stack is: one first layer (``in_features -> hidden_features``),
    ``hidden_layers - 2`` layers of width ``hidden_features``, and one last
    layer narrowing to ``bottleneck_features``. The head maps
    ``bottleneck_features -> out_features`` and applies a sigmoid.

    Args:
        in_features: Size of each input coordinate vector.
        hidden_features: Width of the first and intermediate sine layers.
        hidden_layers: Controls depth; ``hidden_layers - 2`` intermediate
            sine layers are created (none when the value is 2 or less).
        out_features: Size of each output vector.
        bottleneck_features: Width of the last sine layer feeding the head.
            Defaults to 16, the value that used to be hard-coded.
    """

    def __init__(self, in_features=2, hidden_features=32, hidden_layers=5, out_features=3,
                 bottleneck_features=16):
        super().__init__()
        self.net = nn.ModuleList()

        # First layer uses the dedicated first-layer initialisation.
        self.net.append(SineLayer(in_features, hidden_features, is_first=True))

        # Intermediate layers, then the narrowing layer in front of the head.
        for _ in range(hidden_layers - 2):
            self.net.append(SineLayer(hidden_features, hidden_features))
        self.net.append(SineLayer(hidden_features, bottleneck_features))

        # Output head.
        self.final_layer = nn.Sequential(
            nn.Linear(bottleneck_features, out_features),
            nn.Sigmoid()
        )

        # Output-head weight initialisation (the bias keeps nn.Linear's default).
        # NOTE(review): the bound is derived from hidden_features, not from the
        # head's actual fan-in (bottleneck_features); kept as-is so existing
        # runs stay reproducible - confirm whether this is intended.
        bound = np.sqrt(6 / hidden_features)
        with torch.no_grad():
            self.final_layer[0].weight.uniform_(-bound, bound)

    def forward(self, coords):
        """Run ``coords`` through every sine layer, then the sigmoid head."""
        x = coords
        for layer in self.net:
            x = layer(x)
        return self.final_layer(x)

FilmSIREN

In [None]:

def frequency_init(freq):
    """Build a weight initialiser suitable for ``nn.Module.apply``.

    Args:
        freq: Divisor applied to the uniform bound.

    Returns:
        A function ``init(m)`` that, for ``nn.Linear`` modules only, redraws
        the weight uniformly from ``[-b, b]`` with
        ``b = sqrt(6 / fan_in) / freq`` (``fan_in`` being the weight's last
        dimension). Other modules, and the bias, are left untouched.
    """
    def init(m):
        if not isinstance(m, nn.Linear):
            return
        fan_in = m.weight.size(-1)
        with torch.no_grad():
            m.weight.uniform_(-np.sqrt(6 / fan_in) / freq, np.sqrt(6 / fan_in) / freq)
    return init


class FiLMLayer(nn.Module):
    """Linear layer whose output is modulated and passed through a sine.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Size of each output vector.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        """Return ``sin(freq * layer(x) + phase_shift)``.

        ``freq`` and ``phase_shift`` are combined with the linear output by
        ordinary broadcasting.
        """
        projected = self.layer(x)
        modulated = freq * projected + phase_shift
        return torch.sin(modulated)

class CustomMappingNetwork(nn.Module):
    """MLP mapping a latent vector to two equally sized halves.

    The network is four linear layers with ``LeakyReLU(0.2)`` between them.
    All linear weights are re-initialised with ``frequency_init(25)`` and the
    last layer's weights are then scaled by 0.25.

    Args:
        z_dim: Size of the latent input.
        map_hidden_dim: Width of the three hidden layers.
        map_output_dim: Size of the output before it is split in two.
    """

    def __init__(self, z_dim, map_hidden_dim, map_output_dim):
        super().__init__()
        # Three (Linear, LeakyReLU) pairs followed by the output projection.
        widths = [z_dim] + [map_hidden_dim] * 3
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.LeakyReLU(0.2))
        stages.append(nn.Linear(map_hidden_dim, map_output_dim))
        self.network = nn.Sequential(*stages)

        self.network.apply(frequency_init(25))
        with torch.no_grad():
            # Shrink the output projection after the shared initialisation.
            self.network[-1].weight.mul_(0.25)

    def forward(self, z):
        """Return ``(first_half, second_half)`` of the network output.

        The split is along the last dimension at ``size // 2``.
        """
        out = self.network(z)
        half = out.shape[-1] // 2
        return out[..., :half], out[..., half:]

class FilmSIREN(nn.Module):
    """Three FiLM-modulated sine layers driven by a learnable latent code.

    A mapping network turns the latent ``z`` into per-layer frequencies and
    phase shifts; the coordinates pass through the FiLM layers and a final
    linear layer with a sigmoid.

    Args:
        input_dim: Size of each input coordinate vector.
        z_dim: Size of the learnable latent code ``z`` (stored as ``(1, z_dim)``).
        hidden_dim: Width of every FiLM layer.
        output_dim: Size of each output vector.
        device: Stored on the instance as ``self.device``; not otherwise used
            by this class.
    """

    def __init__(self, input_dim=2, z_dim=64, hidden_dim=36, output_dim=3, device=None):
        super().__init__()
        self.device = device
        self.z = nn.Parameter(torch.randn(1, z_dim))

        film_inputs = (input_dim, hidden_dim, hidden_dim)
        self.network = nn.ModuleList(
            [FiLMLayer(fan_in, hidden_dim) for fan_in in film_inputs]
        )
        self.final_layer = nn.Linear(hidden_dim, output_dim)
        # One frequency and one phase value per hidden unit per FiLM layer.
        self.mapping_network = CustomMappingNetwork(
            z_dim, 6, len(self.network) * hidden_dim * 2
        )

        # Shared initialisation first, then the first layer gets its own bound.
        for module in (self.network, self.final_layer):
            module.apply(frequency_init(25))
        self.network[0].layer.weight.data.uniform_(-1. / input_dim, 1. / input_dim)

    def forward(self, coords):
        """Map ``coords`` to outputs in (0, 1).

        The mapping network's first output half is rescaled as
        ``frequencies * 15 + 30``; each FiLM layer receives its own
        ``hidden_dim``-wide slice of the frequencies and phase shifts.
        """
        frequencies, phase_shifts = self.mapping_network(self.z)
        frequencies = frequencies * 15 + 30

        # All FiLM layers share one width, read from the first layer.
        width = self.network[0].layer.out_features
        hidden = coords
        for idx, film in enumerate(self.network):
            band = slice(idx * width, (idx + 1) * width)
            hidden = film(hidden, frequencies[:, band], phase_shifts[:, band])

        return torch.sigmoid(self.final_layer(hidden))

In [None]:
def train_siren(model, coords, rgb_values, device, total_steps=10000,
                lr=1e-4, lr_step_size=5000, lr_gamma=0.5, patience=4000, log_every=2000):
    """Fit ``model`` to map ``coords`` to ``rgb_values`` with full-batch Adam.

    Uses an MSE loss, a ``StepLR`` schedule and ``EarlyStopping``. The early
    stopping check sees the loss computed before the step's parameter update.

    Args:
        model: Module to train in place; called as ``model(coords_tensor)``.
        coords: Input coordinates, converted to a float32 tensor.
        rgb_values: Target values, converted to a float32 tensor.
        device: Device both tensors are moved to.
        total_steps: Maximum number of optimisation steps.
        lr: Adam learning rate.
        lr_step_size: Number of steps between learning-rate decays.
        lr_gamma: Multiplicative learning-rate decay factor.
        patience: ``patience`` passed to ``EarlyStopping``.
        log_every: Print the loss every this many steps; a falsy value
            disables the periodic print.

    Returns:
        The same ``model`` object, after training.
    """
    coords_tensor = torch.tensor(coords, dtype=torch.float32).to(device)
    rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32).to(device)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
    early_stop = EarlyStopping(patience=patience)

    for step in range(total_steps):
        optimizer.zero_grad()
        pred = model(coords_tensor)
        loss = criterion(pred, rgb_tensor)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Read the scalar once per step; it is used for both the stop check
        # and the periodic log line.
        loss_value = loss.item()

        if early_stop(loss_value):
            print(f"Early stopping at step {step}")
            break

        if log_every and step % log_every == 0:
            print(f"Step {step}, Loss: {loss_value:.6f}")
    return model

if __name__ == "__main__":
    # Patch-wise SIREN fitting: train one Siren per overlapping patch, save its
    # weights, and blend the rendered patches into one reconstruction by
    # averaging wherever patches overlap.

    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Parameters
    # NOTE(review): `cropped_image_path` is not defined in this cell; it has to
    # be set before this cell runs.
    image_path = cropped_image_path
    patch_size = 64
    overlap = patch_size // 4
    stride = patch_size - overlap

    # Load image
    full_img = Image.open(image_path).convert("RGB")
    orig_w, orig_h = full_img.size
    original_np = np.array(full_img)

    # Reconstruction canvas: running sum of rendered patches and per-pixel hit count.
    sum_canvas = np.zeros((orig_h, orig_w, 3), dtype=np.float32)
    count_canvas = np.zeros((orig_h, orig_w, 1), dtype=np.float32)

    # Metrics storage
    psnr_values = []
    ssim_values = []
    patch_count = 0

    # Create a directory to save the weights
    weights_dir = "saved_siren_016_weights"
    os.makedirs(weights_dir, exist_ok=True)

    # Clamp every stride position so the window stays inside the image, then
    # drop repeats (order preserved): near the border several stride positions
    # clamp to the same origin, which would train the identical patch again.
    lefts = list(dict.fromkeys(
        max(0, min(i, orig_w - patch_size)) for i in range(0, orig_w, stride)
    ))
    uppers = list(dict.fromkeys(
        max(0, min(j, orig_h - patch_size)) for j in range(0, orig_h, stride)
    ))

    # Process each patch
    for left in lefts:
        for upper in uppers:
            # Cap the window at the image border so an image smaller than
            # patch_size yields a patch matching its canvas slice.
            right = min(left + patch_size, orig_w)
            lower = min(upper + patch_size, orig_h)

            patch = full_img.crop((left, upper, right, lower))
            coords, rgb, h_patch, w_patch = extract_patch_data(patch)

            # Initialize the model
            model = Siren(in_features=2, hidden_features=32, hidden_layers=6).to(device)
            print(f"\nTraining patch at ({left},{upper})")
            trained_model = train_siren(model, coords, rgb, device)

            # Print the layer summary once, for the first patch only.
            if patch_count == 0:
                summary(model, (2,))

            # Save weights after training each patch
            patch_weight_path = os.path.join(weights_dir, f'patch_{patch_count}.pth')
            torch.save(trained_model.state_dict(), patch_weight_path)
            # Advance the index so each patch gets its own file instead of
            # overwriting patch_0.pth.
            patch_count += 1

            # Generate results
            with torch.no_grad():
                coords_tensor = torch.tensor(coords, dtype=torch.float32).to(device)
                output = trained_model(coords_tensor).cpu().numpy()

            generated = (output * 255).clip(0, 255).astype(np.uint8).reshape(h_patch, w_patch, 3)
            sum_canvas[upper:lower, left:right] += generated
            count_canvas[upper:lower, left:right] += 1

    # Generate final image
    count_canvas[count_canvas == 0] = 1  # Avoid division by zero
    final_image = (sum_canvas / count_canvas).astype(np.uint8)
    Image.fromarray(final_image).save("full_reconstruction_siren_016_64.png")


In [None]:

def train_siren(model, coords, rgb_values, device, total_steps=10000,
                lr=1e-4, lr_step_size=5000, lr_gamma=0.5, patience=4000, log_every=2000):
    """Fit ``model`` to map ``coords`` to ``rgb_values`` with full-batch Adam.

    Uses an MSE loss, a ``StepLR`` schedule and ``EarlyStopping``. The early
    stopping check sees the loss computed before the step's parameter update.

    Args:
        model: Module to train in place; called as ``model(coords_tensor)``.
        coords: Input coordinates, converted to a float32 tensor.
        rgb_values: Target values, converted to a float32 tensor.
        device: Device both tensors are moved to.
        total_steps: Maximum number of optimisation steps.
        lr: Adam learning rate.
        lr_step_size: Number of steps between learning-rate decays.
        lr_gamma: Multiplicative learning-rate decay factor.
        patience: ``patience`` passed to ``EarlyStopping``.
        log_every: Print the loss every this many steps; a falsy value
            disables the periodic print.

    Returns:
        The same ``model`` object, after training.
    """
    coords_tensor = torch.tensor(coords, dtype=torch.float32).to(device)
    rgb_tensor = torch.tensor(rgb_values, dtype=torch.float32).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
    criterion = nn.MSELoss()
    early_stop = EarlyStopping(patience=patience)

    for step in range(total_steps):
        optimizer.zero_grad()
        pred = model(coords_tensor)
        loss = criterion(pred, rgb_tensor)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Read the scalar once per step; it is used for both the stop check
        # and the periodic log line.
        loss_value = loss.item()

        if early_stop(loss_value):
            print(f"Early stopping at step {step}")
            break

        if log_every and step % log_every == 0:
            print(f"Step {step}, Loss: {loss_value:.6f}")

    return model

if __name__ == "__main__":
    # Patch-wise FilmSIREN fitting. Each patch's model is warm-started from the
    # previous patch's weights (everything except the latent `z`), and the
    # rendered patches are averaged into one reconstruction.
    # NOTE(review): orig_w, orig_h, stride, patch_size and full_img are not
    # defined in this cell; they must already exist from an earlier cell.
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Running sum of rendered patches and per-pixel hit count.
    sum_canvas = np.zeros((orig_h, orig_w, 3), dtype=np.float32)
    count_canvas = np.zeros((orig_h, orig_w, 1), dtype=np.float32)

    # Per-patch bookkeeping
    psnr_values = []
    ssim_values = []
    prev_state_dict = None
    patch_count = 0
    weights_dir = "saved_filmsiren_016_weights"
    os.makedirs(weights_dir, exist_ok=True)

    # Clamp every stride position so the window stays inside the image, then
    # drop repeats (order preserved): near the border several stride positions
    # clamp to the same origin, which would train the identical patch again.
    lefts = list(dict.fromkeys(
        max(0, min(i, orig_w - patch_size)) for i in range(0, orig_w, stride)
    ))
    uppers = list(dict.fromkeys(
        max(0, min(j, orig_h - patch_size)) for j in range(0, orig_h, stride)
    ))

    for left in lefts:
        for upper in uppers:
            # Cap the window at the image border so an image smaller than
            # patch_size yields a patch matching its canvas slice.
            right = min(left + patch_size, orig_w)
            lower = min(upper + patch_size, orig_h)

            # Current patch
            patch = full_img.crop((left, upper, right, lower))
            actual_h, actual_w = patch.size[1], patch.size[0]

            # Training data
            coords, rgb, h_patch, w_patch = extract_patch_data(patch)

            # Fresh model, warm-started from the previous patch when available.
            # strict=False because the saved dict deliberately lacks 'z'.
            model = FilmSIREN(device=device).to(device)
            if prev_state_dict is not None:
                model.load_state_dict(prev_state_dict, strict=False)
            # Print the layer summary once, for the first patch only.
            if patch_count == 0:
                summary(model, input_size=(2,), device=device.type)

            # Train on the current patch
            print(f"\nTraining patch at ({left},{upper})")
            trained_model = train_siren(model, coords, rgb, device)
            patch_weight_path = os.path.join(weights_dir, f'patch_{patch_count}.pth')
            torch.save(trained_model.state_dict(), patch_weight_path)
            # Advance the index so each patch gets its own file instead of
            # overwriting patch_0.pth.
            patch_count += 1

            # Keep the trained weights (except z) to initialise the next patch.
            current_state_dict = copy.deepcopy(trained_model.state_dict())
            del current_state_dict['z']
            prev_state_dict = current_state_dict

            # Render the patch with the trained model
            with torch.no_grad():
                coords_tensor = torch.tensor(coords, dtype=torch.float32).to(device)
                output = trained_model(coords_tensor).cpu().numpy()

            generated = (output * 255).clip(0, 255).astype(np.float32)
            generated = generated.reshape(actual_h, actual_w, 3)

            # Accumulate into the canvas
            sum_canvas[upper:lower, left:right] += generated
            count_canvas[upper:lower, left:right] += 1

    # Final image: average of all patches covering each pixel.
    count_canvas[count_canvas == 0] = 1  # Avoid division by zero
    final_image = (sum_canvas / count_canvas).astype(np.uint8)
    Image.fromarray(final_image).save("full_reconstruction_filsiren_016_64.png")




recon

In [None]:
from PIL import Image

# def enlarge_image(input_path, output_path, new_size=(4094, 2400)):
def enlarge_image(input_path, output_path, new_size=(1920, 1080)):
    """Centre an image on a white RGB canvas and save the result.

    Args:
        input_path: Path of the image to place on the canvas.
        output_path: Path the padded image is written to.
        new_size: ``(width, height)`` of the canvas.
    """
    # Context manager closes the source file even if pasting fails.
    with Image.open(input_path) as original_image:
        original_width, original_height = original_image.size
        new_image = Image.new("RGB", new_size, (255, 255, 255))

        # Top-left corner that centres the original on the canvas.
        top_left_x = (new_size[0] - original_width) // 2
        top_left_y = (new_size[1] - original_height) // 2

        new_image.paste(original_image, (top_left_x, top_left_y))

    new_image.save(output_path)


# NOTE(review): `path` is not defined in this cell; set it to the source image
# before running.
input_image_path = path  # image path

output_image_path = 'enlarged_image.png'  # larger image
enlarge_image(input_image_path, output_image_path)

# NOTE(review): the original read `image.size`, but no `image` is defined in
# this cell. The size is now read from the file just written - confirm that
# the enlarged image (not the source image) is the one to report.
with Image.open(output_image_path) as enlarged:
    width1, height1 = enlarged.size


print(f"Image Size: {width1} x {height1}")

In [None]:
import odak
import torch
import torch.nn as nn
from argparse import Namespace

# Propagation settings shared by the propagator and the tensors moved below.
args_prop = Namespace(
    wavelengths=[639e-9, 515e-9, 473e-9],
    pixel_pitch=3.74e-6,
    volume_depth=5e-3,
    d_val=0.,
    pad_size=[1080, 1920],
    aperture_size = 1500,
    # Fall back to the CPU when CUDA is unavailable instead of failing on
    # the first `.to("cuda")`.
    device = "cuda" if torch.cuda.is_available() else "cpu"
)


propagator = odak.learn.wave.propagator(
    resolution = args_prop.pad_size,
    wavelengths = args_prop.wavelengths,
    pixel_pitch = args_prop.pixel_pitch,
    number_of_frames = 3,
    number_of_depth_layers = 3,
    volume_depth = args_prop.volume_depth,
    image_location_offset = args_prop.d_val,
    propagation_type = 'Bandlimited Angular Spectrum',
    propagator_type = 'forward',
    laser_channel_power = torch.eye(3),
    aperture_size = args_prop.aperture_size,
    aperture = None,
    method = 'conventional',
    device = args_prop.device
)
# Load the enlarged image scaled by 1/255, then map it to phase values in [0, 2*pi).
phase_map = odak.learn.tools.load_image(r"enlarged_image.png", normalizeby = 255., torch_style = True).to(args_prop.device)
phase_map = (phase_map * 2 * odak.pi) % (2 * odak.pi)

try:
    recon_output = propagator.reconstruct(phase_map, amplitude=None, no_grad = True)
except Exception as e:
    print("Error during reconstruction:", e)
    # Re-raise: without a result the code below would only fail later with
    # an unrelated NameError on `recon_output`.
    raise

print(recon_output.size())
# Sum over the first dimension (presumably frames - TODO confirm against the
# odak documentation) and save each remaining slice as its own image.
reconstruction_intensities = torch.sum(recon_output, dim = 0)
for idx, recon in enumerate(reconstruction_intensities):
    odak.learn.tools.save_image(
        f"./recon_{idx}.png",
        recon,
        cmin=0.,
        cmax=1.0
    )