In [1]:
import scipy.io as io
import numpy as np
from tqdm import tqdm
import argparse
from einops import rearrange, repeat
import torch
import torchvision
from torch.utils.data import TensorDataset, DataLoader
from torchvision.utils import save_image

from model import *
from datasets import *


In [2]:
# def process_data(dataname='carl',view=0,Del=0.1,fold=0):
    
#     data=io.loadmat(f"./data/carl.mat")
#     #folds=io.loadmat(f"./data/carl_del_{Del}.mat")['folds']
#     X=data['X']
#     x=rearrange(torch.tensor(X[0,view]/255), '(n i) (h w) -> n i h w', i=1,h=145,w=100).float()

#     dataset=TensorDataset(x)
#     dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True)
#     return dataloader

In [3]:
# Compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Experiment identifiers. `dataname`, `Del` and `fold` are interpolated into
# the checkpoint filenames loaded by the model cells.
dataname = 'carl'
Del = 0.1
fold = 0
view = 0  # default view index
# dataloader=process_data(dataname=dataname,view=0,Del=0.1,fold=0)

In [4]:
# Load the pretrained autoencoders, one per view (indices 0, 1, 2), ready for inference.
n_feat = 64
arch = [[2, 2], [3, 2], [5, 5], [5, 5]]


def _load_autoencoder(view_idx):
    """Build the AE2d for one view, load its checkpoint and return it in eval mode.

    The checkpoint path is built from the experiment identifiers
    `dataname`, `Del` and `fold` defined in the configuration cell.
    """
    model = AE2d(in_channels=1, n_feat=n_feat, arch=arch).to(device)
    ckpt = f"./models/AE_{dataname}_view{view_idx}_del{Del}_fold{fold}_ep100.pth"
    model.load_state_dict(torch.load(ckpt, map_location=device))
    # This notebook only runs inference: eval() puts any dropout / batch-norm
    # layers AE2d may contain into inference mode (nn.Module.eval returns self).
    return model.eval()


AE = [_load_autoencoder(v) for v in range(3)]


In [5]:
# Load the conditional diffusion models, one per target view (indices 0, 1, 2).
n_T = 1000             # DDPM `n_T` (number of diffusion timesteps)
betas = (1e-6, 2e-2)   # DDPM `betas` — presumably the (start, end) of the noise schedule; confirm in DDPM
pairedrate = 0.1       # only used here to select the checkpoint file
drop_prob = 0.1        # DDPM `drop_prob`

configs = {
    'arch': [4, 4, 4, 8],  # UNet architecture spec
    'dim_x': 512,          # size of the latent the diffusion model samples
    'dim_c': 1024,         # UNet `feature_dim`: two concatenated dim_x latents (the two observed views)
}


def _load_ddpm(view_idx):
    """Build the DDPM (UNet backbone) for one view, load its checkpoint and return it in eval mode."""
    model = DDPM(
        nn_model=UNet(in_channels=1, n_feat=n_feat, feature_dim=configs['dim_c'], arch=configs['arch']),
        betas=betas, n_T=n_T, device=device, drop_prob=drop_prob,
    ).to(device)
    ckpt = f"./models/ddpm_{dataname}_view{view_idx}_pairedrate{pairedrate}_fold{fold}_ep1000.pth"
    model.load_state_dict(torch.load(ckpt, map_location=device))
    # Inference only: put any dropout / normalization layers into eval mode.
    return model.eval()


ddpm = [_load_ddpm(v) for v in range(3)]

In [6]:
from PIL import Image

# Load the demo samples: for each sample id, one image per view, shaped
# (1, 1, IMG_H, IMG_W) and divided by 255 (assumes 8-bit grayscale PNGs).
IMG_H, IMG_W = 145, 100
SAMPLE_IDS = [30, 181, 375]


def _load_view_image(view_dir, sample_id):
    """Read ./data/<view_dir>/<sample_id>.png as a (1, 1, IMG_H, IMG_W) tensor scaled by 1/255."""
    # The context manager closes the file handle as soon as the pixels are read
    # instead of leaving it open until garbage collection.
    with Image.open(f'./data/{view_dir}/{sample_id}.png') as image:
        array = np.array(image).reshape(1, 1, IMG_H, IMG_W)
    return torch.tensor(array) / 255


img_classic = [_load_view_image('classic', i) for i in SAMPLE_IDS]
img_infrared = [_load_view_image('infrared', i) for i in SAMPLE_IDS]
img_thermal = [_load_view_image('thermal', i) for i in SAMPLE_IDS]


In [7]:
## Encode every loaded image into its autoencoder latent.
# F.pad's pad=[0, 0, 2, 3] leaves the width alone and adds 2 zero rows on top
# and 3 at the bottom, so a 145-row image reaches the encoder with 150 rows.
def _encode(ae, img):
    """Pad one image tensor, encode it with `ae`, and flatten the latent to (-1, configs['dim_x'])."""
    x = torch.nn.functional.pad(img.to(device), pad=[0, 0, 2, 3], mode='constant', value=0)
    # dim_x is the latent size the diffusion cell samples (512), kept in one place.
    return ae.forward_z(x).reshape([-1, configs['dim_x']])


with torch.no_grad():
    z_classic = [_encode(AE[0], img) for img in img_classic]
    z_infrared = [_encode(AE[1], img) for img in img_infrared]
    z_thermal = [_encode(AE[2], img) for img in img_thermal]

In [8]:
## Diffusion: for demo sample k, regenerate one view's latent from the
## latents of the other two views of that same sample.
SEED = 0  # seeds PyTorch's RNGs so the sampled demo images are reproducible


def _condition(z_a, z_b):
    """Reshape two latents to (1, -1, 1) and stack them along dim -2 as the DDPM condition."""
    return torch.cat([z_a.reshape(1, -1, 1), z_b.reshape(1, -1, 1)], dim=-2)


def _sample(model, c):
    """Draw a single sample (n_sample=1, latent size [1, dim_x], guidance weight 1.0) conditioned on `c`."""
    return model.ddpm_sample(c=c, n_sample=1, size=[1, configs['dim_x']], device=device, guide_w=1.0)


# torch.manual_seed seeds the CPU and all CUDA generators; if ddpm_sample
# draws from another RNG (e.g. numpy) that is not covered here.
torch.manual_seed(SEED)
with torch.no_grad():
    # infrared + thermal ==> classic (sample 0)
    recover_z_classic = _sample(ddpm[0], _condition(z_infrared[0], z_thermal[0]))
    # classic + thermal ==> infrared (sample 1)
    recover_z_infrared = _sample(ddpm[1], _condition(z_classic[1], z_thermal[1]))
    # classic + infrared ==> thermal (sample 2)
    recover_z_thermal = _sample(ddpm[2], _condition(z_classic[2], z_infrared[2]))
    


sampling timestep 1000

In [9]:
## Decode the regenerated latents back to image space.
def _decode(ae, z):
    """Decode one latent with `ae`, crop rows [2:-3] (the padding added before encoding), move to CPU."""
    reconstruction = ae.forward_x_rec(z.reshape(1, -1, 1, 1))
    cropped = reconstruction[:, :, 2:-3, :]
    return cropped.cpu()


with torch.no_grad():
    recover_classic = _decode(AE[0], recover_z_classic)
    recover_infrared = _decode(AE[1], recover_z_infrared)
    recover_thermal = _decode(AE[2], recover_z_thermal)
    

In [10]:

def _save_gray(img, path):
    """Save a 145x100 image tensor (values nominally in [0, 1]) as an 8-bit grayscale PNG.

    Values are clipped to [0, 1] before scaling: casting out-of-range floats
    straight to uint8 wraps around / is platform-dependent. They are rounded
    rather than truncated so pixels are not biased down by up to one level.
    """
    scaled = np.clip(np.array(img.reshape(145, 100)), 0.0, 1.0) * 255
    Image.fromarray(np.rint(scaled).astype(np.uint8)).save(path)


_save_gray(recover_classic, 'demo_classic.png')
_save_gray(recover_infrared, 'demo_infrared.png')
_save_gray(recover_thermal, 'demo_thermal.png')
