In [1]:
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from collections import OrderedDict


from Joint_HDRDN import Joint_HDRDN
from alignment_network import Alignment
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# 1. Build the full Joint_HDRDN network. It is wrapped in DataParallel so that its
#    state_dict keys carry the same 'module.' prefix as the saved checkpoint
#    (the strict load below reports that all keys matched).
full_model = Joint_HDRDN(n_channel=8, out_channel=4, embed_dim=60, depths=[4, 4, 4])
full_model = nn.DataParallel(full_model)

# 2. Load the checkpoint onto the CPU. full_model is never moved to a GPU here
#    (this cell only exists to extract weights), so mapping to 'cuda' would just
#    stage an extra GPU copy and would fail outright on CPU-only machines.
state_dict = torch.load('Joint_HDRDN.pth', map_location='cpu')
full_model.load_state_dict(state_dict)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<All keys matched successfully>

In [3]:
# Top-level sub-modules of Joint_HDRDN that make up the alignment front-end.
ALIGNMENT_PARTS = ('align_head', 'att1', 'att2', 'conv_first')

full_state_dict = full_model.state_dict()


def _top_level_child(key):
    """Return the first sub-module name of a state_dict key, ignoring a DataParallel 'module.' prefix."""
    if key.startswith('module.'):
        key = key[len('module.'):]
    return key.split('.', 1)[0]


# Match on the top-level child name rather than on a substring anywhere in the key:
# a plain `'att1' in k` would also pick up e.g. 'att10' or a nested layer whose
# name merely contains one of these strings.
new_state_dict = {
    k: v for k, v in full_state_dict.items() if _top_level_child(k) in ALIGNMENT_PARTS
}
assert new_state_dict, 'no alignment weights found in the full model state_dict'
torch.save(new_state_dict, 'alignment_only_weights.pth')

In [11]:
class AlignmentOnlyModel(nn.Module):
    """Pretrained alignment front-end followed by a small, freshly initialised conv head.

    The sub-modules ``align_head``, ``att1``, ``att2`` and ``conv_first`` are taken
    by reference from ``pretrained_alignment`` (not copied), so they share
    parameters with it. A three-layer conv head then maps the fused
    ``embed_dim``-channel features to ``out_channels`` channels.

    Args:
        pretrained_alignment: module exposing ``align_head``, ``att1``, ``att2`` and
            ``conv_first`` attributes (pass ``.module`` when it is wrapped in
            ``nn.DataParallel``).
        embed_dim: number of channels produced by ``conv_first``; must match the
            pretrained model's ``embed_dim``.
        out_channels: number of output channels (default 3, the original behaviour).
            NOTE(review): the reference HDR in this notebook has 4 channels —
            confirm which is intended before training against it.
    """

    def __init__(self, pretrained_alignment, embed_dim=60, out_channels=3):
        super(AlignmentOnlyModel, self).__init__()

        # Reuse the pre-trained alignment parts (shared modules, not copies).
        self.align_head = pretrained_alignment.align_head
        self.att1 = pretrained_alignment.att1
        self.att2 = pretrained_alignment.att2
        self.conv_first = pretrained_alignment.conv_first
        self.embed_dim = embed_dim
        self.out_channels = out_channels

        # New layers on top: embed_dim -> 64 -> 32 -> out_channels.
        self.extra_layers = nn.Sequential(
            nn.Conv2d(self.embed_dim, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, self.out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x1, x2, x3):
        """Align the three exposures and decode them with the new conv head.

        Returns a tensor with ``out_channels`` channels and the spatial size of
        ``conv_first``'s output (all head convs are 3x3 with padding 1).
        """
        # --- Alignment part ---
        # align_head also returns the un-attended f1/f3 features, unused here.
        f1_att, f2, f3_att, _f1, _f3 = self.align_head(x1, x2, x3)
        # Residual: add the middle-exposure features f2 to each side branch.
        f1_att = f2 + f1_att
        f3_att = f2 + f3_att

        # Re-weight each side branch by att1/att2 computed against f2
        # (presumably attention masks — see Joint_HDRDN).
        f1_att = self.att1(f1_att, f2) * f1_att
        f3_att = self.att2(f3_att, f2) * f3_att

        # Fuse along the channel axis; conv_first is expected to emit embed_dim channels.
        x = self.conv_first(torch.cat((f1_att, f2, f3_att), dim=1))

        # --- New layers ---
        return self.extra_layers(x)



In [12]:
# 1. Rebuild the alignment-only network with the same hyper-parameters as the full
#    model, then load the filtered checkpoint. It is wrapped in DataParallel before
#    loading because the saved keys carry the 'module.' prefix. map_location/.to use
#    the notebook's `device` so this also runs on CPU-only machines.
pretrained_alignment = Alignment(n_channel=8, out_channel=4, embed_dim=60, depths=[4, 4, 4])
pretrained_alignment = nn.DataParallel(pretrained_alignment).to(device)
pretrained_alignment.load_state_dict(
    torch.load('alignment_only_weights.pth', map_location=device)
)

# 2. Create the new network on top of the unwrapped (.module) pretrained parts.
model = AlignmentOnlyModel(pretrained_alignment.module, embed_dim=60)
model = nn.DataParallel(model).to(device)

In [13]:
# One RAW exposure bracket stored as an .npz with 'sht', 'mid', 'lng' and 'hdr'
# arrays. Set HDR_NPZ_PATH to use another file without editing the notebook.
npypath = os.environ.get(
    'HDR_NPZ_PATH',
    '/data/asim/ISP/HDR_transformer/data/RAW/raw-2022-0606-2151-4147.npz',
)

# np.load on an .npz returns an archive object that keeps the file open; indexing
# it reads each array into memory, so the archive can be closed afterwards.
with np.load(npypath) as imdata:
    sht = imdata['sht']
    mid = imdata['mid']
    lng = imdata['lng']
    hdr = imdata['hdr']

# Square centre crop. The offsets are derived from the reference HDR and applied
# to every array, so all of them must share the same spatial size.
crop_size = 992
H, W = hdr.shape[1], hdr.shape[2]
for _name, _arr in (('sht', sht), ('mid', mid), ('lng', lng)):
    assert _arr.shape[1:] == hdr.shape[1:], (
        f'{_name} spatial shape {_arr.shape[1:]} != hdr {hdr.shape[1:]}'
    )
# A crop larger than the image would give negative offsets, which NumPy slicing
# silently treats as "count from the end", returning a smaller, off-centre crop.
if crop_size > H or crop_size > W:
    raise ValueError(f'crop_size={crop_size} exceeds image size {H}x{W}')
start_h = (H - crop_size) // 2
start_w = (W - crop_size) // 2

sht_crop = sht[:, start_h:start_h+crop_size, start_w:start_w+crop_size]
mid_crop = mid[:, start_h:start_h+crop_size, start_w:start_w+crop_size]
lng_crop = lng[:, start_h:start_h+crop_size, start_w:start_w+crop_size]
hdr_crop = hdr[:, start_h:start_h+crop_size, start_w:start_w+crop_size]

print(sht_crop.shape, hdr_crop.shape)

(8, 992, 992) (4, 992, 992)


In [14]:
def to_tensor(np_array):
    """Wrap a NumPy array as a float32 torch tensor (no device move, no batch dim).

    torch.from_numpy shares memory with ``np_array``; when the array is already
    float32 the result aliases it, otherwise the dtype cast produces a copy.
    """
    return torch.from_numpy(np_array).to(torch.float32)

# Convert each cropped array to float32, move it to `device`, and add a leading
# batch dimension of 1.
im1, im2, im3, ref_hdr = (
    to_tensor(arr).to(device).unsqueeze(0)
    for arr in (sht_crop, mid_crop, lng_crop, hdr_crop)
)
im1[:, :4, :, :].shape, ref_hdr.shape

(torch.Size([1, 4, 992, 992]), torch.Size([1, 4, 992, 992]))

In [15]:
# Put the model in evaluation mode before inference: without eval(), any dropout
# or batch-norm layers inside it would still use their training behaviour.
model.eval()
with torch.no_grad():
    generate_hdr = model(im1, im2, im3)
generate_hdr.shape

torch.Size([1, 3, 992, 992])