# In [14]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# Define the UNet architecture (same as training)
class UNet(nn.Module):
    """Small encoder/decoder convolutional network.

    The contracting path applies two 3x3 convolutions (3 -> 64 -> 64
    channels), each followed by ReLU, then a 2x2 max pooling with stride 2.
    The expansive path applies two more 3x3 convolutions (64 -> 64 -> 64),
    each followed by ReLU, then a stride-2 transposed convolution that
    outputs 3 channels. The two paths are chained sequentially; there are no
    skip connections between them.

    The attribute names and the layer order inside each ``nn.Sequential``
    determine the ``state_dict`` keys, so both must be kept for previously
    saved checkpoints to load.
    """

    @staticmethod
    def _conv_relu_pair(in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]

    def __init__(self):
        super().__init__()
        # Encoder: conv/ReLU pair followed by 2x2 down-sampling.
        self.contracting_path = nn.Sequential(
            *self._conv_relu_pair(3, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Decoder: conv/ReLU pair followed by 2x up-sampling to 3 channels.
        self.expansive_path = nn.Sequential(
            *self._conv_relu_pair(64, 64),
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Run ``x`` through the contracting path, then the expansive path."""
        encoded = self.contracting_path(x)
        return self.expansive_path(encoded)

# Load the trained model
# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load checkpoint files that come from a trusted source.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
model.load_state_dict(torch.load('trained_unet_model.pth', map_location=device))
model.eval()  # Set the model to evaluation mode

# Define the transformation applied to the input image before inference
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize to model input size
    transforms.ToTensor(),
])

# Load and preprocess the input image; the context manager releases the file
# handle as soon as the RGB copy has been made.
with Image.open('path_to_input_image.jpg') as source_image:
    input_image = source_image.convert('RGB')
input_tensor = transform(input_image).unsqueeze(0).to(device)  # Add batch dimension

# Generate the binary mask using the trained model
with torch.no_grad():
    output_tensor = model(input_tensor)
    # The network runs on a 256x256 copy of the image, so resize the
    # per-pixel probabilities back to the original resolution before
    # thresholding. Without this, the mask cannot be multiplied with an
    # original image whose size is not exactly 256x256.
    mask_probs = nn.functional.interpolate(
        output_tensor.sigmoid(),
        size=(input_image.height, input_image.width),
        mode='bilinear',
        align_corners=False,
    )
    masked_output = (mask_probs > 0.5).cpu().numpy()  # Binarize the output
    masked_output = masked_output.squeeze(0).transpose(1, 2, 0)  # Convert CHW to HWC format

# Apply the mask to the original image (both arrays are now H x W x 3)
input_np = np.array(input_image)
masked_image = input_np * masked_output  # Element-wise multiplication

# Convert numpy array to PIL Image
masked_image_pil = Image.fromarray(masked_image.astype('uint8'))

# Save the masked image
masked_image_pil.save('masked_text_image.jpg')