# In [7]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

# Define the UNet architecture (same as training)
class UNet(nn.Module):
    """Small conv encoder/decoder; must mirror the training-time definition.

    NOTE(review): despite the name, there are no skip connections between the
    two paths, so this is a plain convolutional encoder/decoder rather than a
    U-Net proper.

    The attribute names and layer order determine the state_dict keys
    (e.g. ``contracting_path.0.weight``); changing them breaks loading of the
    trained checkpoint.

    Expects 3 input channels and produces 3 output channels. Spatial size is
    halved by the max-pool and doubled by the transposed conv, so even H/W are
    preserved while odd H/W come back one pixel smaller.
    """

    def __init__(self):
        super().__init__()

        # Encoder: two 3x3 convs (3 -> 64 -> 64 channels), then halve H and W.
        down_layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.contracting_path = nn.Sequential(*down_layers)

        # Decoder: two 3x3 convs at 64 channels, then a stride-2 transposed
        # conv that doubles H and W and projects back to 3 channels.
        # The last layer has no activation, so outputs are unbounded.
        up_layers = [
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2),
        ]
        self.expansive_path = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x`` with the contracting path, then decode it."""
        encoded = self.contracting_path(x)
        return self.expansive_path(encoded)

# Load the trained model.
# map_location='cpu' lets a checkpoint saved from a GPU session load here;
# inference below runs on CPU (nothing moves the model or inputs to a GPU).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# on torch >= 1.13 consider weights_only=True to restrict loading to tensors.
model = UNet()
model.load_state_dict(torch.load('trained_unet_model.pth', map_location='cpu'))
model.eval()  # Set the model to evaluation mode

# Define transformations for the input image:
# resize to a fixed (height, width) = (1632, 1224), then convert to a float
# tensor of shape (3, H, W) with values in [0, 1].
# Both sides are even, so UNet's 2x max-pool / 2x transposed-conv round trip
# returns the same spatial size.
# NOTE(review): presumably this matches the training resolution — confirm.
transform = transforms.Compose([
    transforms.Resize((1632, 1224)),
    transforms.ToTensor(),
])

# Load and preprocess the input image
input_image = Image.open('data/images/image0.jpg').convert('RGB')
input_tensor = transform(input_image).unsqueeze(0)  # Add batch dimension -> (1, 3, H, W)

# Generate the output image with the trained model (no autograd bookkeeping)
with torch.no_grad():
    output_tensor = model(input_tensor)

# UNet's last layer is a bare ConvTranspose2d, so outputs are not bounded to
# [0, 1]. Clamp first: casting out-of-range floats to uint8 does not saturate,
# which turns over/undershoots into garbage pixels.
masked_image = output_tensor.squeeze(0).clamp(0.0, 1.0).permute(1, 2, 0).numpy()  # (H, W, 3)

# Convert the numpy array to a PIL Image; round (rather than truncate) so the
# scaling back to 0..255 is as close as possible to the inverse of ToTensor.
masked_image_pil = Image.fromarray((masked_image * 255).round().astype('uint8'))

# Save the output image.
# NOTE(review): JPEG is lossy; if this is meant to be a segmentation mask,
# a lossless format such as PNG may be preferable.
masked_image_pil.save('masked_image.jpg')