In [None]:
from generator_model import Generator 
from PIL import Image
from torchvision import transforms
import torch
import matplotlib.pyplot as plt

def load_checkpoint(checkpoint_file: str, model, device):
    """Restore ``model``'s weights from a checkpoint file.

    Args:
        checkpoint_file: Path to a file written by ``torch.save`` that holds a
            dict with a ``"state_dict"`` entry.
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``. The module itself is not moved to ``device``.
        device: Passed as ``map_location`` so the stored tensors are
            deserialized onto this device before being copied into ``model``.
    """
    print("Loading checkpoint ", checkpoint_file)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    saved_weights = torch.load(checkpoint_file, map_location=device)["state_dict"]
    model.load_state_dict(saved_weights)

def preprocess_image(image_path, size=(256, 256)):
    """Load an image from disk and turn it into a batched model input.

    Args:
        image_path: Path to the image file.
        size: Target size passed to ``transforms.Resize``; the default keeps the
            original hard-coded 256x256.

    Returns:
        (image_tensor, image): ``image_tensor`` is the ``ToTensor`` output of the
        resized RGB image with a leading batch axis of size 1; ``image`` is the
        RGB PIL image at its original resolution (used for display).
    """
    preprocess = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # NOTE(review): there is no Normalize step, so pixel values stay in ToTensor's
    # [0, 1] range. If the generators were trained on a different input range
    # (e.g. [-1, 1]), this must match -- TODO confirm against the training code.
    # Use a context manager so the file handle is released right away;
    # convert() returns a new, fully loaded image that outlives the file.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    image_tensor = preprocess(image).unsqueeze(0)  # add the batch axis
    return image_tensor, image

def postprocess_output(output_tensor):
    """Convert a generator output tensor into a PIL image.

    Args:
        output_tensor: Image tensor of shape (C, H, W), optionally with a leading
            batch axis of size 1. Float tensors are expected to hold values in
            [0, 1].

    Returns:
        A ``PIL.Image`` produced by ``transforms.ToPILImage``.
    """
    # detach()/cpu() make this safe to call directly on a raw model output.
    image_tensor = output_tensor.detach().cpu()
    # Drop only the batch axis; a bare squeeze() would also collapse any other
    # size-1 dimension.
    if image_tensor.dim() == 4:
        image_tensor = image_tensor.squeeze(0)
    if image_tensor.is_floating_point():
        # ToPILImage multiplies float tensors by 255 and casts to uint8, so values
        # outside [0, 1] would overflow into garbage pixels; clamp first.
        # NOTE(review): if the generator emits values in [-1, 1] (e.g. a tanh
        # head), rescale with (x + 1) / 2 instead of clamping -- TODO confirm.
        image_tensor = image_tensor.clamp(0.0, 1.0)
    return transforms.ToPILImage()(image_tensor)

def run_inference(generator_model, input_image_tensor, device):
    """Run ``generator_model`` on one input without tracking gradients.

    Side effects: ``generator_model`` is moved to ``device`` and switched to
    eval mode, both in place on the module.

    Args:
        generator_model: The ``torch.nn.Module`` to evaluate.
        input_image_tensor: Model input; moved to ``device`` before the call.
        device: Device (or device string) to run inference on.

    Returns:
        The raw output tensor, left on ``device``.
    """
    # Move the model as well as the input: a model left on CPU fed a CUDA input
    # raises a device-mismatch error.
    generator_model.to(device)
    # Inference-time behaviour for layers like dropout / batch-norm.
    generator_model.eval()
    input_image_tensor = input_image_tensor.to(device)
    with torch.no_grad():
        output = generator_model(input_image_tensor)
    return output

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_image_tensor, original_image = preprocess_image('./TestImages/(6).JPEG')

# Generator(3): presumably the number of input image channels, matching the
# convert('RGB') in preprocess_image -- confirm against generator_model.Generator.
generator_T_model = Generator(3)
generator_S_model = Generator(3)
# `device` is only used as torch.load's map_location; the models themselves are
# never moved to it, so the cells below run them on CPU.
load_checkpoint("./TestModels/200_gen_t.pth.tar", generator_T_model, device)
load_checkpoint("./TestModels/200_gen_s.pth.tar", generator_S_model, device)
# This notebook only runs inference, so switch both models to eval mode
# (affects layers such as dropout / batch-norm, if the Generator has any).
generator_T_model.eval()
generator_S_model.eval()

In [None]:
# Forward ("encrypt") pass through generator T. no_grad avoids building an
# autograd graph that is never used for backprop here.
with torch.no_grad():
    encrypted_image = generator_T_model(input_image_tensor)

In [None]:
# The first run replaces the raw tensor in `encrypted_image` with a PIL image.
# Guard so re-running this cell reuses the cached tensor instead of calling
# .detach() on the PIL image.
if isinstance(encrypted_image, torch.Tensor):
    encrypted_image_tensor = encrypted_image.detach().cpu()
encrypted_image = postprocess_output(encrypted_image_tensor)

In [None]:
# Reverse ("reconstruct") pass through generator S on the encrypted tensor.
# no_grad avoids building an unused autograd graph.
with torch.no_grad():
    reconstructed_image = generator_S_model(encrypted_image_tensor)

In [None]:
# The first run replaces the raw tensor in `reconstructed_image` with a PIL image.
# Guard so re-running this cell reuses the cached tensor instead of calling
# .detach() on the PIL image.
if isinstance(reconstructed_image, torch.Tensor):
    reconstructed_image_tensor = reconstructed_image.detach().cpu()
reconstructed_image = postprocess_output(reconstructed_image_tensor)

In [None]:
# One row with three panels: original input, generator-T output, generator-S reconstruction.
fig, ax = plt.subplots(1, 3, figsize=(10, 5))

# Display the original input image
ax[0].imshow(original_image)
ax[0].set_title("Original Image")
ax[0].axis('off')  # Hide axes

# Display the encrypted image (generator T output)
ax[1].imshow(encrypted_image)
ax[1].set_title("Encrypted Image")
ax[1].axis('off')  # Hide axes

# Display the reconstructed image (generator S applied to the encrypted image)
ax[2].imshow(reconstructed_image)
ax[2].set_title("Reconstructed Image")
ax[2].axis('off')  # Hide axes

plt.show()

In [None]:
import os
from generator_model import Generator 
from PIL import Image
from torchvision import transforms
import torch
import matplotlib.pyplot as plt

# Checkpoints for the two generators: T produces the "encrypted" image,
# S reconstructs the original from it.
gen_t_path, gen_s_path = (
    "./TestModels/200_gen_t.pth.tar",
    "./TestModels/200_gen_s.pth.tar",
)
# Folder scanned for input images by the loop below.
test_images_dir = './TestImages'

def load_checkpoint(checkpoint_file: str, model, device):
    """Restore ``model``'s weights from a checkpoint file.

    Args:
        checkpoint_file: Path to a file written by ``torch.save`` that holds a
            dict with a ``"state_dict"`` entry.
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``. The module itself is not moved to ``device``.
        device: Passed as ``map_location`` so the stored tensors are
            deserialized onto this device before being copied into ``model``.
    """
    print("Loading checkpoint ", checkpoint_file)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    saved_weights = torch.load(checkpoint_file, map_location=device)["state_dict"]
    model.load_state_dict(saved_weights)

def preprocess_image(image_path, size=(256, 256)):
    """Load an image from disk and turn it into a batched model input.

    Args:
        image_path: Path to the image file.
        size: Target size passed to ``transforms.Resize``; the default keeps the
            original hard-coded 256x256.

    Returns:
        (image_tensor, image): ``image_tensor`` is the ``ToTensor`` output of the
        resized RGB image with a leading batch axis of size 1; ``image`` is the
        RGB PIL image at its original resolution (used for display).
    """
    preprocess = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # NOTE(review): there is no Normalize step, so pixel values stay in ToTensor's
    # [0, 1] range. If the generators were trained on a different input range
    # (e.g. [-1, 1]), this must match -- TODO confirm against the training code.
    # Use a context manager so the file handle is released right away;
    # convert() returns a new, fully loaded image that outlives the file.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')
    image_tensor = preprocess(image).unsqueeze(0)  # add the batch axis
    return image_tensor, image

def postprocess_output(output_tensor):
    """Convert a generator output tensor into a PIL image.

    Args:
        output_tensor: Image tensor of shape (C, H, W), optionally with a leading
            batch axis of size 1. Float tensors are expected to hold values in
            [0, 1].

    Returns:
        A ``PIL.Image`` produced by ``transforms.ToPILImage``.
    """
    # detach()/cpu() make this safe to call directly on a raw model output.
    image_tensor = output_tensor.detach().cpu()
    # Drop only the batch axis; a bare squeeze() would also collapse any other
    # size-1 dimension.
    if image_tensor.dim() == 4:
        image_tensor = image_tensor.squeeze(0)
    if image_tensor.is_floating_point():
        # ToPILImage multiplies float tensors by 255 and casts to uint8, so values
        # outside [0, 1] would overflow into garbage pixels; clamp first.
        # NOTE(review): if the generator emits values in [-1, 1] (e.g. a tanh
        # head), rescale with (x + 1) / 2 instead of clamping -- TODO confirm.
        image_tensor = image_tensor.clamp(0.0, 1.0)
    return transforms.ToPILImage()(image_tensor)

def run_inference(generator_model, input_image_tensor, device):
    """Run ``generator_model`` on one input without tracking gradients.

    Side effects: ``generator_model`` is moved to ``device`` and switched to
    eval mode, both in place on the module.

    Args:
        generator_model: The ``torch.nn.Module`` to evaluate.
        input_image_tensor: Model input; moved to ``device`` before the call.
        device: Device (or device string) to run inference on.

    Returns:
        The raw output tensor, left on ``device``.
    """
    # Move the model as well as the input: a model left on CPU fed a CUDA input
    # raises a device-mismatch error.
    generator_model.to(device)
    # Inference-time behaviour for layers like dropout / batch-norm.
    generator_model.eval()
    input_image_tensor = input_image_tensor.to(device)
    with torch.no_grad():
        output = generator_model(input_image_tensor)
    return output

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Generator(3): presumably the number of input image channels, matching the
# convert('RGB') in preprocess_image -- confirm against generator_model.Generator.
generator_T_model = Generator(3)
generator_S_model = Generator(3)
# `device` is only used as torch.load's map_location; the models themselves are
# never moved to it, so the loop below runs them on CPU.
load_checkpoint(gen_t_path, generator_T_model, device)
load_checkpoint(gen_s_path, generator_S_model, device)
# This notebook only runs inference, so switch both models to eval mode
# (affects layers such as dropout / batch-norm, if the Generator has any).
generator_T_model.eval()
generator_S_model.eval()

# Run every test image through T ("encrypt") and then S ("reconstruct"), and show
# all three stages. sorted() makes the figure order reproducible, because
# os.listdir returns entries in arbitrary order.
for filename in sorted(os.listdir(test_images_dir)):
    # Case-insensitive match on the actual extension: accepts e.g. ".Jpg" and
    # ".JPEG", but not a name that merely ends in "JPG" without a dot.
    if os.path.splitext(filename)[1].lower() not in {'.jpg', '.jpeg', '.png'}:
        continue
    image_path = os.path.join(test_images_dir, filename)

    # Preprocess input image
    input_image_tensor, original_image = preprocess_image(image_path)

    # Inference only: no_grad skips building an autograd graph for every image.
    with torch.no_grad():
        encrypted_image_tensor = generator_T_model(input_image_tensor).cpu()
        reconstructed_image_tensor = generator_S_model(encrypted_image_tensor).cpu()

    encrypted_image = postprocess_output(encrypted_image_tensor)
    reconstructed_image = postprocess_output(reconstructed_image_tensor)

    # Display the original, encrypted, and reconstructed images (1 row, 3 columns)
    fig, ax = plt.subplots(1, 3, figsize=(10, 5))
    fig.suptitle(filename)  # identify which file this figure belongs to

    # Display the original input image
    ax[0].imshow(original_image)
    ax[0].set_title("Original Image")
    ax[0].axis('off')  # Hide axes

    # Display the encrypted image
    ax[1].imshow(encrypted_image)
    ax[1].set_title("Encrypted Image")
    ax[1].axis('off')  # Hide axes

    # Display the reconstructed image
    ax[2].imshow(reconstructed_image)
    ax[2].set_title("Reconstructed Image")
    ax[2].axis('off')  # Hide axes

    plt.show()
    # Release the figure so figures from earlier iterations don't accumulate in memory.
    plt.close(fig)
