# Plot predicted (fused) images (画出预测图像)

In [None]:
import torch as th
import numpy as np
import h5py

from datasets.FLIR_2 import FLIRDataset
from datasets.TNO import TNODataset

In [None]:
# TNO test split, read in dataset order one sample at a time.
path = '/home/ZiHanCao/datasets/TNO'
ds = TNODataset(path, 'test', no_split=True)
dl = th.utils.data.DataLoader(dataset=ds, batch_size=1, shuffle=False)


In [None]:
import torch


from model.build_network import build_network

# Inference device; set_device also makes it the current CUDA device.
device = torch.device('cuda:0')
torch.cuda.set_device(device)

net = build_network(
    'dcformer_mwsa',
    spectral_num=1,
    added_c=1,
    block_list=[4, [4, 3], [4, 3, 2]],
    mode='C',
)
# Other checkpoints previously tried here:
#   /home/ZiHanCao/exps/panformer/weight/dcformer_379zkf3e/ep_550.pth  (2n8eo45b)
#   ./weight/dcformer_17rgbfmz/ep_490.pth
net.load_state_dict(th.load('./weight/dcformer_3l0px5zm/ep_260.pth', map_location=device)['model'])
net = net.to(device)

In [None]:
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from scipy.io import savemat
import cv2

from model.base_model import PatchMergeModule

def convert_uint8(img):
    """Convert an image array to ``uint8`` for saving.

    Non-``uint8`` input is treated as intensities in ``[0, 1]``. It is clipped
    to that range, scaled to ``[0, 255]``, and rounded to the nearest integer.
    ``uint8`` input is returned unchanged (as a copy).

    The shape is printed without a trailing newline as a progress trace.

    Args:
        img: NumPy array of any shape.

    Returns:
        A new ``uint8`` array with the same shape as ``img``.
    """
    print(img.shape, end=' ')
    if img.dtype != np.uint8:
        # clip() returns a copy, so the caller's array is never modified.
        # Round rather than truncate: a bare astype() floors, which biases
        # every pixel down by up to one grey level (e.g. 0.999 -> 254).
        img = np.rint(img.clip(0, 1) * 255)
    return img.astype(np.uint8)

import os

# cv2.imwrite neither creates missing directories nor raises on failure (it
# only returns False), so create the output folder up front.
save_dir = './visualized_img/ir'
os.makedirs(save_dir, exist_ok=True)

net.eval()
# Run the network patch-wise (64x64 tiles, scale 1) and merge the tiles back.
patch_merge_net = PatchMergeModule(net, 1, patch_size=64, scale=1)
with th.no_grad():
    for i, (ir, ms, vis, gt) in enumerate(dl):
        ir, ms, vis, gt = ir.cuda(), ms.cuda(), vis.cuda(), gt.cuda()

        pan_nc = ir.size(1)
        ms_nc = ms.size(1)
        # Zero-pad ir up to ms_nc channels. new_zeros keeps ir's dtype and
        # device. The padding takes ir's own batch and spatial size because
        # torch.cat needs every non-concatenated dim to match ir exactly.
        ir_padded = torch.cat(
            [ir, ir.new_zeros(ir.size(0), ms_nc - pan_nc, *ir.shape[-2:])], dim=1
        )
        net_inputs = (
            F.interpolate(ms, size=tuple(vis.shape[-2:]), mode='bilinear', align_corners=True),
            vis,
            ir_padded,
        )
        sr = patch_merge_net.forward_chop(*net_inputs)[0]
        # Keep the first batch element and move channels last for plotting.
        sr = sr.detach().cpu().numpy()[0]
        sr_show = sr.transpose([1, 2, 0])
        vis_show = vis.detach().cpu().numpy()[0].transpose([1, 2, 0])
        ir_show = ir.detach().cpu().numpy()[0].transpose([1, 2, 0])

        fig, axes = plt.subplots(ncols=3, figsize=(12, 4), dpi=200)
        axes = axes.flatten()

        for img, name, ax in zip([vis_show, ir_show, sr_show],
                                 ['vis', 'ir', 'fuse'],
                                 axes):
            ax.imshow(img, 'gray')
            ax.set_axis_off()
            ax.set_title(name)

        plt.subplots_adjust(wspace=0, hspace=0)
        plt.tight_layout(pad=0)
        plt.show()
        # Free the figure explicitly; one figure is created per sample.
        plt.close(fig)

        sr_show = convert_uint8(sr_show)
        img_path = f'{save_dir}/{i}.bmp'
        if not cv2.imwrite(img_path, sr_show):
            raise OSError(f'cv2.imwrite failed for {img_path}')
        print('img saved to {}'.format(img_path))

        # Optional: also dump fuse/vis/ir as a .mat file (sr_show is already uint8 here):
        # savemat(f'{save_dir}/{i}.mat',
        #         dict(fuse=sr_show, vis=convert_uint8(vis_show), ir=convert_uint8(ir_show)))

In [None]:
from scipy.io import savemat, loadmat

# Load a previously exported .mat file and inspect the minimum of its 'fuse' image.
path = r'./visualized_img/ir/data_2.mat'
mat_d = loadmat(path)
fuse = mat_d['fuse']
fuse.min()

In [None]:
# Sanity check: concatenating on the last axis places the arrays side by side,
# (3, 64, 64) + (3, 64, 64) -> (3, 64, 128).
x = np.random.randn(3, 64, 64)
y = np.random.randn(3, 64, 64)

np.concatenate((x, y), axis=-1).shape

# Visualize attention (可视化attention)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import h5py

from model.module.attention import MultiScaleWindowCrossAttention, CAttention
from model.module.swin import window_partition
from model.module.helper_func import exists
from model.dcformer_mwsa import DCFormerMWSA

from datasets.HISR import HISRDataSets

In [None]:
# Build a DCFormerMWSA (spectral_num=31) in eval mode; no checkpoint is loaded in this cell.
net = DCFormerMWSA(spectral_num=31, added_c=3, block_list=[4, [4, 3], [4, 3, 2]], mode='C')  # .cuda(1)
net.eval()

path = "/home/ZiHanCao/datasets/HISI/new_cave/test_cave(with_up)x4.h5"
# Open read-only explicitly: h5py < 3.0 defaulted to mode 'a', which can
# create or modify the file.
ds = HISRDataSets(h5py.File(path, 'r'))


In [None]:
from einops import rearrange

# Buffers filled by the forward hooks (one NumPy array appended per hooked call).
(CATTEN_LIST, CATTN_OUT_LIST, MWSA_ATTN_LIST,
 MWSA_ATTN_OUT_LIST, GHOST_OUT_LIST) = [], [], [], [], []

def cattention_hook(m, inp, outp):
    """Forward hook for ``CAttention``: record its attention map and its output.

    The attention matrix is recomputed from the hook input using the module's
    own ``qkv``, ``qkv_dwconv``, ``num_heads`` and ``temperature``. The
    softmaxed map is appended to ``CATTEN_LIST`` and the module output to
    ``CATTN_OUT_LIST``, both as CPU NumPy arrays.
    """
    feat = inp[0]
    with torch.no_grad():
        qkv = m.qkv_dwconv(m.qkv(feat))
        q, k, v = (
            rearrange(part, "b (head c) h w -> b head c (h w)", head=m.num_heads)
            for part in qkv.chunk(3, dim=1)
        )
        # L2-normalise along the flattened (h w) axis, so q @ k^T is a cosine similarity.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1) * m.temperature).softmax(dim=-1)

    CATTEN_LIST.append(attn.detach().cpu().numpy())
    CATTN_OUT_LIST.append(outp.detach().cpu().numpy())
    
def mwsa_attn_hook(m, inp, outp):
    """Forward hook for ``MultiScaleWindowCrossAttention``.

    Recomputes the window cross-attention map and the ghost-module output
    from the hook inputs ``(tgt, mem)`` using the module's own layers, then
    appends, as CPU NumPy arrays:

    * the softmaxed attention map ``[b*nw, nh, wh1*ww1, wh2*ww2]`` to
      ``MWSA_ATTN_LIST``,
    * the module output to ``MWSA_ATTN_OUT_LIST``,
    * the ghost-module output to ``GHOST_OUT_LIST``.

    The hook does not modify ``m``. When the module has no window sizes set,
    they are looked up in ``m.window_dict`` into local variables instead of
    being written back onto the module. Writing them back would freeze the
    sizes seen on the first hooked call.
    """
    tgt, mem = inp
    with torch.no_grad():
        window_size1, window_size2 = m.window_size1, m.window_size2
        if not exists(window_size1) and not exists(window_size2):
            # Looked up before match_c, from the raw inputs' widths.
            window_size1 = m.window_dict[tgt.size(-1)]
            window_size2 = m.window_dict[mem.size(-1)]

        mem = m.match_c(mem)
        q = window_partition(tgt.permute(0, 2, 3, 1), window_size1)  # [nw*b, wh1, ww1, c]
        kv = window_partition(mem.permute(0, 2, 3, 1), window_size2)  # [nw*b, wh2, ww2, c]

        q = m.q(q)
        k, _ = m.kv(kv).chunk(2, dim=-1)  # v is not needed for the attention map

        # q: [b*nw, nh, wh1*ww1, c]
        # k: [b*nw, nh, wh2*ww2, c]
        q = rearrange(q, "b h w (head c) -> b head (h w) c", head=m.num_heads)
        k = rearrange(k, "b h w (head c) -> b head (h w) c", head=m.num_heads)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        # [b*nw, nh, wh1*ww1, wh2*ww2]
        attn = (q @ k.transpose(-2, -1) * m.temperature).softmax(-1)

        # Ghost branch input: tgt concatenated with mem resized to a square of side tgt.shape[-1].
        ghost_out = m.ghost_module(
            torch.cat([tgt, F.interpolate(mem, tgt.shape[-1], mode="bilinear")], dim=1)
        )
    MWSA_ATTN_LIST.append(attn.detach().cpu().numpy())
    MWSA_ATTN_OUT_LIST.append(outp.detach().cpu().numpy())
    GHOST_OUT_LIST.append(ghost_out.detach().cpu().numpy())
    
    

In [None]:
c_attn_handlers = []
mwsa_attn_hooks = []

# Attach the recording hooks to every attention module in the network.
for module in net.modules():
    if isinstance(module, CAttention):
        c_attn_handlers.append(module.register_forward_hook(cattention_hook))
        print('c attention hook')

    if isinstance(module, MultiScaleWindowCrossAttention):
        mwsa_attn_hooks.append(module.register_forward_hook(mwsa_attn_hook))
        print('mwsa attention hook')
        
        
def remove_all_hooks(hs):
    """Call ``remove()`` on every hook handle in ``hs``.

    ``hs`` is an iterable of handles, e.g. those returned by
    ``register_forward_hook``. The container itself is left untouched.
    """
    for handle in hs:
        handle.remove()
        

In [None]:
# Detach only the CAttention hooks. The MWSA hooks in mwsa_attn_hooks stay
# registered, so later forward passes still append to the MWSA_* / GHOST_* lists.
remove_all_hooks(c_attn_handlers)

In [None]:
# One sample per batch, in dataset order.
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=1, shuffle=False)

In [None]:
# Take the first sample and run one validation pass; any forward hooks still
# registered on `net` record their tensors as a side effect.
pan, ms, lms, gt = next(iter(dl))

# Inference only: unless val_step already disables gradients (not visible
# here), running it outside no_grad would build and retain an autograd graph.
with torch.no_grad():
    out = net.val_step(ms.float(),   # .cuda(1)
                       lms.float(),  # .cuda(1)
                       pan.float())  # .cuda(1)

In [None]:

# norm to [0, 1]
def norm(x):
    """Min-max normalise ``x`` to ``[0, 1]``.

    Works on NumPy arrays and torch tensors. A constant input (zero range)
    maps to all zeros instead of producing NaN/inf from a division by zero.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        span = 1
    return (x - lo) / span

In [None]:
import matplotlib.pyplot as plt

# Show each recorded MWSA output (first batch element, averaged over axis 1,
# presumably channels) on a 5-column grid sized to the number of maps.
ncols = 5
nrows = max(1, -(-len(MWSA_ATTN_OUT_LIST) // ncols))  # ceil division
fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows))
axes = axes.flatten()
for i, c in enumerate(MWSA_ATTN_OUT_LIST):
    c = c.mean(1)[0]
    axes[i].imshow(norm(c), 'hot')
    axes[i].set_title(f'MWSA Attention {i}')

# Hide only the grid cells that received no map.
for ax in axes[len(MWSA_ATTN_OUT_LIST):]:
    ax.set_axis_off()

In [None]:
# Shape of the first recorded MWSA module output.
MWSA_ATTN_OUT_LIST[0].shape

# Saliency map

In [None]:
import numpy as np
import torch as th
import PIL.Image as pim

from sailency import LAM


In [None]:

# NOTE(review): LAM is called with no arguments here; check its signature in
# the `sailency` module before relying on this cell.
LAM()

# multi-source projection head

In [None]:
from model.dcformer_mwsa import DCFormerMWSA
from model.base_model import BaseModel, register_model

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualConvBlock(nn.Module):
    """Conv -> PReLU -> Conv with an identity skip: ``forward(x) = rcb(x) + x``.

    Both convolutions keep the channel count. The skip addition therefore
    needs the convolutions to preserve spatial size too, e.g. ``stride=1`` and
    ``padding=kernel_size // 2`` for odd kernels.
    """

    def __init__(
        self, channels: int, kernel_size: int, stride: int, padding: int
    ) -> None:
        super().__init__()

        def conv() -> nn.Conv2d:
            return nn.Conv2d(channels, channels, kernel_size, stride, padding)

        # Layer order fixes the state_dict keys: rcb.0 (conv), rcb.1 (PReLU), rcb.2 (conv).
        self.rcb = nn.Sequential(conv(), nn.PReLU(), conv())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.rcb(x) + x


# Reference from `https://github.com/SHI-Labs/Cross-Scale-Non-Local-Attention/blob/master/src/model/utils/tools.py`
class MultiSourceProjection(nn.Module):
    """Fuse two high-resolution sources, then back-project against a low-res input.

    ``forward(x, x_is, x_cs)`` computes:

    1. ``out = x_is + encoder(x_cs - x_is)``, a residual correction of ``x_is``
       by the encoded difference to ``x_cs``;
    2. ``out + diff_encode1(x - down_conv1(out))``: ``out`` is downsampled by
       ``scale``, compared with ``x``, and the error is upsampled back and added.

    The down/up convolutions use ``kernel = 3 * scale``, ``stride = scale`` and
    ``padding = scale``. This reproduces the per-scale table for scales 2/3/4:
    (6, 2, 2), (9, 3, 3) and (12, 4, 4). With these values a side length ``H``
    divisible by ``scale`` maps ``H -> H / scale -> H`` exactly.

    Args:
        channels: channel count of all three inputs and of the output.
        kernel_size: kernel size of the residual encoder (stride 1,
            padding ``kernel_size // 2``).
        scale: positive integer resolution ratio between ``x_is``/``x_cs``
            and ``x``.

    Raises:
        ValueError: if ``scale`` is not a positive integer.
    """

    def __init__(self, channels, kernel_size, scale) -> None:
        super().__init__()
        if scale != int(scale) or scale < 1:
            raise ValueError(f"scale must be a positive integer, got {scale!r}")
        stride = padding = int(scale)
        de_kernel_size = 3 * stride

        self.down_conv1 = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                (de_kernel_size, de_kernel_size),
                (stride, stride),
                (padding, padding),
            ),
            nn.PReLU(),
        )
        self.diff_encode1 = nn.Sequential(
            nn.ConvTranspose2d(
                channels,
                channels,
                (de_kernel_size, de_kernel_size),
                (stride, stride),
                (padding, padding),
            ),
            nn.PReLU(),
        )

        self.encoder = ResidualConvBlock(channels, kernel_size, 1, kernel_size // 2)

    def forward(
        self, x: torch.Tensor, x_is: torch.Tensor, x_cs: torch.Tensor
    ) -> torch.Tensor:
        # Residual fusion of the two high-resolution sources
        # (naming suggests intra-scale / cross-scale features -- unconfirmed).
        out = x_is + self.encoder(x_cs - x_is)
        # Back-projection: correct `out` with the upsampled low-resolution error.
        back_err = self.diff_encode1(x - self.down_conv1(out))
        return out + back_err

In [None]:
# Shape check at scale 4: a 16x16 low-res input and two 64x64 high-res sources.
msp = MultiSourceProjection(3, 3, 4)

x, in_x, cs_x = (
    torch.randn(1, 3, 16, 16),
    torch.randn(1, 3, 64, 64),
    torch.randn(1, 3, 64, 64),
)

print(msp(x, in_x, cs_x).shape)

In [4]:
import scipy.io as sio
import numpy as np
import h5py

# Machine-specific absolute path to the test_cave(with_up)x4 HDF5 file; adjust locally.
path = "/media/office-401-remote/Elements SE/cao/ZiHanCao/datasets/HISI/new_cave/test_cave(with_up)x4.h5"

In [7]:
# Read the whole 'GT' dataset into memory, then close the file immediately.
# Open read-only: h5py < 3.0 defaulted to mode 'a', which can modify the file.
with h5py.File(path, 'r') as h5_file:
    gt = h5_file['GT'][:]
gt.shape

(11, 31, 512, 512)

In [None]:
# Export the ground-truth array to gt.mat under the MATLAB variable name 'gt'.
sio.savemat('gt.mat', mdict={'gt': gt})