In [1]:
# Base imports
import sys
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pycocotools.coco import COCO
from PIL import Image
from torchvision import transforms
# Custom imports
from model import InfusionNet
import tools.dct as dct_tools
from tools.plot_dct import plot_dct, plot_bbox

# Root folders of the FLIR and LLVIP datasets, relative to the notebook's
# working directory.
flir_path = '../data/FLIR/'
llvip_path = '../data/LLVIP/'

# Directory listings of the training folders, one per dataset and modality.
# os.listdir returns every entry in the folder, so the counts below are
# entry counts, not validated image counts.
flir_rgbimages = os.listdir(os.path.join(flir_path, 'images_rgb_train/data'))
flir_thermalimages = os.listdir(os.path.join(flir_path, 'images_thermal_train/data'))
llvip_rgbimages = os.listdir(os.path.join(llvip_path, 'visible/train'))
llvip_thermalimages = os.listdir(os.path.join(llvip_path, 'infrared/train'))

# Report the size of each listing: FLIR first, then LLVIP.
for _label, _listing in (
    ('Number of FLIR RGB images: ', flir_rgbimages),
    ('Number of FLIR Thermal images: ', flir_thermalimages),
    ('Number of LLVIP RGB images: ', llvip_rgbimages),
    ('Number of LLVIP Thermal images: ', llvip_thermalimages),
):
    print(_label, len(_listing))

# Load the LLVIP annotation file through the COCO API.
llvip = COCO(os.path.join(llvip_path, 'LLVIP.json'))
llvip_ids = llvip.getImgIds()

# Use the image id at index 1 as the sample, and fetch its image record and
# its annotations.
sample_id = llvip_ids[1]
img_obj = llvip.loadImgs([sample_id])
anns_obj = llvip.loadAnns(llvip.getAnnIds(imgIds=[sample_id]))

# The same file name is looked up in both the visible and the infrared folder.
sample_name = img_obj[0]['file_name']
visible_dir = llvip_path + 'visible/train/'
infrared_dir = llvip_path + 'infrared/train/'
rgb_img = Image.open(visible_dir + sample_name)
ir_img = Image.open(infrared_dir + sample_name)


#plt.figure(figsize=(20,20))
#plt.subplot(1,2,1)
#plt.imshow(rgb_img)
# draw bounding boxes
#plot_bbox(rgb_img, anns_obj)
#plt.subplot(1,2,2)
#plt.imshow(ir_img)
# draw bounding boxes
#plot_bbox(ir_img, anns_obj)
#plt.show()

# PIL image -> tensor converter. It is bound to its own name so that the
# imported torchvision `transforms` module is not overwritten by this cell
# (the previous `transforms = transforms.Compose(...)` made the module
# unreachable for the rest of the notebook).
to_tensor = transforms.ToTensor()

# Second argument handed to dct_tools.mask_image for each modality.
RGB_MASK_LEVEL = 0.1
IR_MASK_LEVEL = 0.05


def dct_mask_roundtrip(image, mask_level):
    """Run one image through the DCT -> mask -> inverse-DCT pipeline.

    Args:
        image: PIL image to process.
        mask_level: Value forwarded as the second argument of
            ``dct_tools.mask_image``; its exact meaning is defined by that
            helper.

    Returns:
        A tuple ``(tensor, dct, masked_dct, reconstructed)``:
        ``tensor`` is the image converted by ``to_tensor`` with a leading
        axis added, ``dct`` is ``dct_tools.dct_2d(tensor, norm='ortho')``,
        ``masked_dct`` is the result of ``dct_tools.mask_image`` on ``dct``,
        and ``reconstructed`` is ``dct_tools.idct_2d(masked_dct, norm='ortho')``.
    """
    tensor = to_tensor(image).unsqueeze(0)
    dct = dct_tools.dct_2d(tensor, norm='ortho')
    masked_dct = dct_tools.mask_image(dct, mask_level)
    reconstructed = dct_tools.idct_2d(masked_dct, norm='ortho')
    return tensor, dct, masked_dct, reconstructed


# Visible-light sample.
rgb_tensor, rgb_dct, rgb_masked_dct, masked_rgb = dct_mask_roundtrip(rgb_img, RGB_MASK_LEVEL)

# Optional visual check (disabled):
#plot_dct(rgb_tensor, rgb_masked_dct, masked_rgb)
#plot_bbox(rgb_tensor, anns_obj)

# Infrared sample.
ir_tensor, ir_dct, ir_masked_dct, masked_ir = dct_mask_roundtrip(ir_img, IR_MASK_LEVEL)

# Optional visual check (disabled):
#plot_dct(ir_tensor, ir_masked_dct, masked_ir)
#plot_bbox(ir_tensor, anns_obj)

# Backward-compatible aliases: the original cell reused these two names for
# both modalities and left them holding the infrared results.
dct_tensor = ir_dct
masked_tensor = ir_masked_dct

Number of FLIR RGB images:  10319
Number of FLIR Thermal images:  10742
Number of LLVIP RGB images:  12025
Number of LLVIP Thermal images:  12025
Loading annotations into memory...
Done (t=0.13s)
Creating index...
index created!


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [2]:
from model import InfusionNet
import torch
# Instantiate the fusion network.
infusion_model = InfusionNet(num_features=16, reduction=8, tau=0.2)

# Concatenate the RGB and IR tensors along dim 1, then stack that result with
# itself along dim 0 so the model receives a batch of two identical samples.
rgb_ir_pair = torch.cat((rgb_tensor, ir_tensor), dim=1)
input_tensors = torch.cat((rgb_ir_pair, rgb_ir_pair), dim=0)

# Forward pass; only the shape of the result is reported.
output_tensor = infusion_model(input_tensors)
print(output_tensor.shape)

Phase 0
Input to phase 0: torch.Size([2, 3, 1024, 1280])
Input to phase 0: torch.Size([2, 3, 1024, 1280])
Phase 1
Input shape to inner phases: torch.Size([2, 1, 507, 635])
torch.Size([2, 16, 507, 635]) torch.Size([2, 16, 507, 635])
Input shape to inner phases: torch.Size([2, 1, 507, 635])
torch.Size([2, 16, 507, 635]) torch.Size([2, 16, 507, 635])
Phase 2
Input shape to inner phases: torch.Size([2, 1, 507, 635])
torch.Size([2, 16, 507, 635]) torch.Size([2, 16, 507, 635])
Input shape to inner phases: torch.Size([2, 1, 507, 635])
torch.Size([2, 16, 507, 635]) torch.Size([2, 16, 507, 635])
Phase 3
Input shape to inner phases: torch.Size([2, 1, 507, 635])
torch.Size([2, 16, 507, 635]) torch.Size([2, 16, 507, 635])
Input shape to inner phases: torch.Size([2, 1, 507, 635])
torch.Size([2, 16, 507, 635]) torch.Size([2, 16, 507, 635])
Output phase 3: torch.Size([2, 1, 507, 635]) torch.Size([2, 1, 507, 635])
torch.Size([2, 3, 507, 635])
