# In [None]:  (notebook cell marker kept as a comment so the file parses as Python)
from transformers import AutoImageProcessor, SwinForMaskedImageModeling
from PIL import Image
import torch
import matplotlib.pyplot as plt

# Reconstruct a masked image with a Swin SimMIM checkpoint and display the result.
#
# Steps: load the original and masked images, load the saved processor/model,
# run masked-image modelling on the masked image, print the loss, and plot the
# reconstruction after mapping it back to displayable [0, 1] RGB values.

# Load and convert original and masked images to RGB
original_image_path = "p1.png"
masked_image_path = "masked_p1.png"  # Path to the masked image

original_image = Image.open(original_image_path).convert("RGB")
masked_image = Image.open(masked_image_path).convert("RGB")

# Initialize processor and model, either load pre-trained from Hugging Face or local path
# Save processor and model after initialization
processor_path = "processor"
model_path = "model"

# Uncomment the following lines if you haven't already saved the model and processor
# image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
# model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
# model.save_pretrained(model_path)
# image_processor.save_pretrained(processor_path)

# Load processor and model from saved paths
image_processor = AutoImageProcessor.from_pretrained(processor_path)
model = SwinForMaskedImageModeling.from_pretrained(model_path)

# Process original and masked images.
# NOTE: pixel_values_original is kept for comparison/inspection; it is not
# passed to the model below.
pixel_values_original = image_processor(images=original_image, return_tensors="pt").pixel_values
pixel_values_masked = image_processor(images=masked_image, return_tensors="pt").pixel_values

# Create a mask for the masked image (Assuming the mask is known or can be derived)
num_patches = (model.config.image_size // model.config.patch_size) ** 2
# Here, you should load or define the mask that corresponds to the masked image.
# For demonstration a random mask is used: it does NOT correspond to the regions
# that are actually hidden in masked_p1.png, and it differs on every run.
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

# Run the model with the masked image. This is pure inference, so skip building
# the autograd graph.
with torch.no_grad():
    outputs = model(pixel_values_masked, bool_masked_pos=bool_masked_pos)
# NOTE(review): the loss is presumably measured against the pixel values passed
# to the model (the masked image), not the original image -- confirm.
loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction

# Print results
print("Reconstructed pixel values shape:", reconstructed_pixel_values.shape)
print("Loss:", loss.item())

# Convert reconstructed pixel values to image and display.
# Index the batch dimension explicitly instead of a bare squeeze(), which would
# also drop any other size-1 dimension.
reconstruction = reconstructed_pixel_values.detach().cpu()[0]

# The model works in the processor's normalized space; undo the mean/std
# normalization so the values are back in [0, 1]. Without this, imshow clips the
# normalized floats and the colours come out wrong.
if getattr(image_processor, "do_normalize", False):
    channel_mean = torch.tensor(image_processor.image_mean, dtype=reconstruction.dtype).view(-1, 1, 1)
    channel_std = torch.tensor(image_processor.image_std, dtype=reconstruction.dtype).view(-1, 1, 1)
    reconstruction = reconstruction * channel_std + channel_mean

# Clamp any overshoot, then move channels last (H, W, C) for matplotlib.
reconstructed_image = reconstruction.clamp(0.0, 1.0).permute(1, 2, 0).numpy()
plt.imshow(reconstructed_image)
plt.title("Reconstructed Image from Masked Input")
plt.axis("off")
plt.show()
