In [2]:
pip install pycryptodome

Collecting pycryptodome
  Downloading pycryptodome-3.23.0-cp37-abi3-win_amd64.whl.metadata (3.5 kB)
Downloading pycryptodome-3.23.0-cp37-abi3-win_amd64.whl (1.8 MB)
   ---------------------------------------- 0.0/1.8 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.8 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.8 MB ? eta -:--:--
   ---------------------------------------- 0.0/1.8 MB ? eta -:--:--
   ----- ---------------------------------- 0.3/1.8 MB ? eta -:--:--
   ----------- ---------------------------- 0.5/1.8 MB 985.5 kB/s eta 0:00:02
   ----------------- ---------------------- 0.8/1.8 MB 1.0 MB/s eta 0:00:01
   ----------------- ---------------------- 0.8/1.8 MB 1.0 MB/s eta 0:00:01
   ----------------------- ---------------- 1.0/1.8 MB 1.0 MB/s eta 0:00:01
   ---------------------------------- ----- 1.6/1.8 MB 1.2 MB/s eta 0:00:01
   ---------------------------------------- 1.8/1.8 MB 1.2 MB/s eta 0:00:00
Installing collected packages: p


[notice] A new release of pip is available: 24.2 -> 25.3
[notice] To update, run: D:\Research and Design Simulation\fusion_env\Scripts\python.exe -m pip install --upgrade pip


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from Crypto.Cipher import AES

# ================================================================
# 1. Encoder-Decoder Networks
# ================================================================
class Encoder(nn.Module):
    def __init__(self, latent_channels=16):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(9, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, latent_channels, 3, padding=1)
        )

    def forward(self, img1, img2, img3):
        x = torch.cat([img1, img2, img3], dim=1)
        return self.conv(x)


class Decoder(nn.Module):
    def __init__(self, latent_channels=16):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(latent_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 9, 3, padding=1)
        )

    def forward(self, fusion):
        x = self.conv(fusion)
        return torch.split(x, 3, dim=1)


# ================================================================
# 2. Dataset for 3 input images
# ================================================================
class ImageTripletDataset(torch.utils.data.Dataset):
    def __init__(self, folder1, folder2, folder3, transform=None):
        self.files1 = sorted(os.listdir(folder1))
        self.files2 = sorted(os.listdir(folder2))
        self.files3 = sorted(os.listdir(folder3))

        self.folder1 = folder1
        self.folder2 = folder2
        self.folder3 = folder3
        self.transform = transform

    def __len__(self):
        return min(len(self.files1), len(self.files2), len(self.files3))

    def __getitem__(self, idx):
        img1 = Image.open(os.path.join(self.folder1, self.files1[idx])).convert("RGB")
        img2 = Image.open(os.path.join(self.folder2, self.files2[idx])).convert("RGB")
        img3 = Image.open(os.path.join(self.folder3, self.files3[idx])).convert("RGB")

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)

        return img1, img2, img3


# ================================================================
# 3. Tensor → PIL Image
# ================================================================
def tensor_to_pil(tensor):
    img = tensor.detach().cpu()
    img = img[:3]
    img = img.permute(1, 2, 0)
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    img = (img * 255).byte().numpy()
    return Image.fromarray(img, mode="RGB")


# ================================================================
# 4. AES Encryption / Decryption + Encrypted Preview
# ================================================================
AES_KEY = b"1234567890ABCDEF1234567890ABCDEF"  # 32 bytes AES-256 key

def save_encrypted_image_preview(ciphertext, save_path, shape):
    c, h, w = shape
    expected = c * h * w

    arr = np.frombuffer(ciphertext, dtype=np.uint8)

    if len(arr) < expected:
        arr = np.pad(arr, (0, expected - len(arr)), 'constant')

    arr = arr[:expected].reshape((c, h, w))
    arr = np.transpose(arr, (1, 2, 0))

    if arr.shape[2] < 3:
        arr = np.repeat(arr, 3, axis=2)
    else:
        arr = arr[:, :, :3]

    Image.fromarray(arr.astype(np.uint8), "RGB").save(save_path)


def encrypt_tensor(tensor, save_bin_path, save_png_path):
    raw_bytes = tensor.detach().cpu().numpy().tobytes()
    cipher = AES.new(AES_KEY, AES.MODE_EAX)
    ciphertext, tag = cipher.encrypt_and_digest(raw_bytes)

    with open(save_bin_path, "wb") as f:
        f.write(cipher.nonce + tag + ciphertext)

    save_encrypted_image_preview(ciphertext, save_png_path, tensor.shape)


def decrypt_tensor(bin_path, tensor_shape):
    with open(bin_path, "rb") as f:
        file = f.read()

    nonce = file[:16]
    tag = file[16:32]
    ciphertext = file[32:]

    cipher = AES.new(AES_KEY, AES.MODE_EAX, nonce)
    decrypted = cipher.decrypt_and_verify(ciphertext, tag)

    array = torch.frombuffer(decrypted, dtype=torch.float32)
    return array.reshape(tensor_shape)


# ================================================================
# 5. Load Pretrained Models
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder().to(device)
decoder = Decoder().to(device)

encoder.load_state_dict(torch.load(
    r"D:\Research and Design Simulation\fusion_env\Include\models\encoder_3img.pth",
    map_location=device
))
decoder.load_state_dict(torch.load(
    r"D:\Research and Design Simulation\fusion_env\Include\models\decoder_3img.pth",
    map_location=device
))

encoder.eval()
decoder.eval()


# ================================================================
# 6. Load Test Images
# ================================================================
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

test_dataset = ImageTripletDataset(
    r"D:\Research and Design Simulation\fusion_env\Include\data\test\img1",
    r"D:\Research and Design Simulation\fusion_env\Include\data\test\img2",
    r"D:\Research and Design Simulation\fusion_env\Include\data\test\img3",
    transform=transform
)

test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


# ================================================================
# 7. Three Output Folders
# ================================================================
save_encoded = "test_results_3img_encoded"          # encrypted .bin + noisy PNG
save_decoded = "test_results_3img_decoded"          # decrypted fused image
save_reconstructed = "test_results_3img_reconstructed"  # images after decoder

os.makedirs(save_encoded, exist_ok=True)
os.makedirs(save_decoded, exist_ok=True)
os.makedirs(save_reconstructed, exist_ok=True)


# ================================================================
# 8. Processing Loop
# ================================================================
with torch.no_grad():
    mse_loss_fn = nn.MSELoss()
    
    for i, (img1, img2, img3) in enumerate(test_loader):
        img1, img2, img3 = img1.to(device), img2.to(device), img3.to(device)

        # Encode to fused tensor
        fused = encoder(img1, img2, img3)

        # -------------------------------
        # Encrypt
        # -------------------------------
        bin_path = f"{save_encoded}/fused_{i}.bin"
        enc_png = f"{save_encoded}/fused_{i}.png"

        encrypt_tensor(fused[0], bin_path, enc_png)

        # -------------------------------
        # Decrypt
        # -------------------------------
        dec_fused = decrypt_tensor(bin_path, fused[0].shape).to(device)

        # Save decrypted fused image (visualized)
        tensor_to_pil(dec_fused).save(f"{save_decoded}/decrypted_fused_{i}.png")

        # -------------------------------
        # Decode (Reconstruction)
        # -------------------------------
        r1, r2, r3 = decoder(dec_fused.unsqueeze(0))

        tensor_to_pil(r1[0]).save(f"{save_reconstructed}/recon1_{i}.png")
        tensor_to_pil(r2[0]).save(f"{save_reconstructed}/recon2_{i}.png")
        tensor_to_pil(r3[0]).save(f"{save_reconstructed}/recon3_{i}.png")
        

        # ============================================================
        # 9. Compute MSE for this image triplet
        # ============================================================
        mse1 = mse_loss_fn(r1, img1).item()
        mse2 = mse_loss_fn(r2, img2).item()
        mse3 = mse_loss_fn(r3, img3).item()

        print(f"Image {i} → MSE1: {mse1:.6f}, MSE2: {mse2:.6f}, MSE3: {mse3:.6f}")

        # Save results to file
        with open("mse_results.txt", "a") as f:
            f.write(f"Image {i}: MSE1={mse1:.6f}, MSE2={mse2:.6f}, MSE3={mse3:.6f}\n")


print("✅ Completed: Encrypted + Decrypted + Reconstructed Saved Successfully!")


Image 0 → MSE1: 0.000919, MSE2: 0.001052, MSE3: 0.001308
✅ Completed: Encrypted + Decrypted + Reconstructed Saved Successfully!


  array = torch.frombuffer(decrypted, dtype=torch.float32)
