In [None]:
from PIL import Image
from torchvision import transforms
import torch

# Preprocessing applied to every exposure: resize to 256x256 and convert to a
# float tensor with values in [0, 1].
# NOTE(review): this must match the preprocessing used during training. If the
# training pipeline also normalised the inputs (e.g. to [-1, 1]), the same
# normalisation has to be added here -- confirm against the training dataset.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Same spatial size as used during training
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
])


def load_exposure(path):
    """Load one exposure from disk and return it as a (3, 256, 256) tensor.

    The image is forced to RGB so that grayscale, RGBA or CMYK files still
    contribute exactly three channels; otherwise the concatenated input would
    not have the 9 channels the network is configured for.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(path) as img:
        return transform(img.convert('RGB'))


# Load your three input images (already transformed to tensors)
img_bright = load_exposure('/path/to/bright_image.jpg')
img_mid = load_exposure('/path/to/mid_image.jpg')
img_dark = load_exposure('/path/to/dark_image.jpg')

# Stack the exposures along the channel axis (3 + 3 + 3 = 9 channels) and add
# a leading batch dimension.
input_image = torch.cat((img_bright, img_mid, img_dark), dim=0).unsqueeze(0)

# Fail early, with a readable message, if the input does not have the expected layout.
assert input_image.shape == (1, 9, 256, 256), (
    f"unexpected input shape {tuple(input_image.shape)}"
)


In [None]:
import sys

from options.test_options import TestOptions
from models import create_model

# TestOptions().parse() reads its configuration from sys.argv, so the
# command-line arguments are simulated here. `sys` must be imported in this
# notebook, otherwise the assignment below raises NameError on a fresh kernel.
sys.argv = [
    'test.py',
    '--dataroot', '/media/nouman/New Volume/realEstatePhoto/RealEstateDataset',
    '--name', 'real_estate_pix2pix',
    '--model', 'pix2pix',
    '--direction', 'AtoB',
    '--input_nc', '9',  # 3 channels for each of bright, mid, dark
    '--output_nc', '3',  # 3 channels for the final image
    '--gpu_ids', '0',
    '--no_dropout'
]

# Parse options for testing
opt = TestOptions().parse()

# Create and load the model
model = create_model(opt)
model.setup(opt)
# NOTE(review): in the upstream pix2pix code base setup() presumably already
# loads the checkpoint in test mode, which would make this explicit call
# redundant (but harmless) -- confirm before removing.
model.load_networks('latest')  # Load the latest saved model
model.eval()  # Set model to evaluation mode


In [None]:
# Prepare a dictionary to mimic the dataloader output.
# With `--model pix2pix` the upstream set_input() reads both the source ('A')
# and the target ('B') entry, so a placeholder target is supplied; it is not
# used to compute the generated image.
# NOTE(review): the placeholder uses 3 channels to match `--output_nc 3` --
# adjust if the output channel count changes.
placeholder_target = input_image.new_zeros(
    (input_image.shape[0], 3) + tuple(input_image.shape[2:])
)
data = {
    'A': input_image,  # The concatenated image (3 channels each for bright, mid, dark)
    'B': placeholder_target,  # Dummy ground truth, only present to satisfy set_input()
    'A_paths': '/path/to/bright_image.jpg'  # You can add paths if needed
}

# Set the input to the model
model.set_input(data)

# Run inference
model.test()  # Forward pass for testing

# Get the generated output
visuals = model.get_current_visuals()
generated_output = visuals['fake_B']

# The pix2pix generators end with tanh, so the output is expected to lie in
# [-1, 1] (TODO confirm if a custom generator is used). ToPILImage expects
# float tensors in [0, 1]; without rescaling, negative values wrap around when
# cast to uint8 and the saved image is corrupted.
output_tensor = generated_output.detach().squeeze(0).float().cpu()
output_unit_range = ((output_tensor + 1.0) / 2.0).clamp(0.0, 1.0)

# Convert the output to a PIL image
output_image = transforms.ToPILImage()(output_unit_range)

# Optionally display the image
output_image.show()

# Save the output image
output_image.save('/path/to/save/generated_image.jpg')


In [None]:
import matplotlib.pyplot as plt

# Side-by-side comparison: the three input exposures followed by the
# generated result, laid out in a single row of four panels.
fig = plt.figure(figsize=(15, 5))
to_pil = transforms.ToPILImage()

# (title, picture) pairs in left-to-right display order.
panels = [
    ('Input Bright', to_pil(img_bright.cpu())),
    ('Input Mid', to_pil(img_mid.cpu())),
    ('Input Dark', to_pil(img_dark.cpu())),
    ('Generated Output', output_image),
]

for position, (caption, picture) in enumerate(panels, start=1):
    axis = fig.add_subplot(1, 4, position)
    axis.imshow(picture)
    axis.set_title(caption)

plt.show()
