In [1]:
import torch
from pconv_cr.model import PConvUNet  # Importing a pre-trained Partial Convolution GAN (PConv)
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests
from io import BytesIO

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Load the pre-trained PConv-UNet model
CHECKPOINT_PATH = "pconv_unet_checkpoint.pth"  # Replace with path to actual weights

model = PConvUNet().to(device)
# map_location remaps stored tensors onto the current device, so a checkpoint saved
# on a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
# Accept both a {'state_dict': ...} wrapper and a bare state dict.
# NOTE(review): other training scripts may use a different key (e.g. 'model') -- confirm.
state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
model.load_state_dict(state_dict)
model.eval();  # Set the model to evaluation mode; ';' suppresses the long module repr

In [4]:
# Load an image (from a path or file-like object) and turn it into a model input tensor.
def load_image(image_path, size=256):
    '''
    Open an image, resize it, and convert it into a normalized, batched tensor.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (a file path or a file-like
            object such as ``BytesIO``).
        size: Target side length; the image is resized to ``size`` x ``size``.

    Returns:
        A float tensor with a leading batch dimension of 1, normalized with the
        per-channel mean/std below (the commonly used ImageNet statistics), placed
        on the notebook-level ``device``.
    '''
    steps = [
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    preprocess = transforms.Compose(steps)
    rgb_image = Image.open(image_path).convert('RGB')
    batched = preprocess(rgb_image).unsqueeze(0)  # prepend the batch dimension
    return batched.to(device)

In [5]:
# Create a mask that marks the image borders as the missing regions to be filled in.
def create_border_mask(image_tensor, border_width=50):
    '''
    Create a binary mask whose outer border is 0 (missing) and whose interior is 1.

    Args:
        image_tensor: A 4-D tensor; its last two dimensions are treated as
            height and width. The mask has the same shape, dtype and device.
        border_width: Number of pixels marked as missing on each of the four
            sides. 0 yields an all-ones mask; a width of at least half the
            height/width marks the whole image as missing.

    Returns:
        A tensor shaped like ``image_tensor``, 0 on the border and 1 elsewhere.

    Raises:
        ValueError: If ``border_width`` is negative.
    '''
    if border_width < 0:
        raise ValueError(f'border_width must be non-negative, got {border_width}')
    _, _, h, w = image_tensor.shape
    # ones_like already matches the input's dtype and device; no extra .to() needed.
    mask = torch.ones_like(image_tensor)
    # Use explicit start indices rather than negative slices: with border_width == 0,
    # `-0:` would select the entire axis and zero out the whole mask.
    mask[:, :, :border_width, :] = 0                # Top border
    mask[:, :, max(h - border_width, 0):, :] = 0    # Bottom border
    mask[:, :, :, :border_width] = 0                # Left border
    mask[:, :, :, max(w - border_width, 0):] = 0    # Right border
    return mask

In [6]:
# Example image loading: download the image bytes once and preview them inline.
image_url = 'https://example.com/image.jpg'  # Replace with the URL of an image (local paths need Image.open instead)
# Without a timeout, a stalled server would hang the kernel indefinitely.
response = requests.get(image_url, timeout=30)
# Fail here with a clear HTTP error instead of a confusing PIL decode error later.
response.raise_for_status()
img = Image.open(BytesIO(response.content))
# img.show() launches an external viewer (useless on remote/headless kernels);
# render inside the notebook instead.
plt.imshow(img)
plt.axis('off')
plt.show()

In [None]:
# Preprocess the downloaded image for model input
input_image = load_image(BytesIO(response.content))

# Create a mask for the missing border regions (0 = border to generate, 1 = known)
border_mask = create_border_mask(input_image)

# Generate the border with the PConv-UNet model (it takes the image and the mask as input)
with torch.no_grad():
    raw_output, _ = model(input_image, border_mask)

# The network's prediction also alters the known interior; as in standard
# partial-convolution inpainting, keep the known pixels from the input and take
# only the masked border from the model.
# NOTE(review): assumes raw_output has the same shape as input_image -- confirm.
output = border_mask * input_image + (1 - border_mask) * raw_output

In [None]:
# Postprocess the output and convert it back to a displayable image
def postprocess_and_show(output_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    '''
    Undo input normalization on a model output tensor and display it with matplotlib.

    Args:
        output_tensor: Tensor with a leading batch dimension of size 1, expressed in
            the same normalized space as the tensors produced by ``load_image``
            (NOTE(review): assumes the model's output uses the input normalization -- confirm).
        mean: Per-channel means used for normalization; the defaults match ``load_image``.
        std: Per-channel standard deviations used for normalization; defaults match
            ``load_image``. Pass ``None`` for ``mean`` or ``std`` to skip de-normalization.
    '''
    # detach/cpu/squeeze give a new view; every later step builds a new tensor,
    # so the caller's tensor is never modified.
    image = output_tensor.detach().cpu().squeeze(0)  # Remove the batch dimension
    if mean is not None and std is not None:
        mean_t = torch.as_tensor(mean, dtype=image.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=image.dtype).view(-1, 1, 1)
        image = image * std_t + mean_t
    # ToPILImage multiplies float tensors by 255 and casts to uint8, so values
    # outside [0, 1] would wrap around into wrong colours; clamp first.
    image = image.clamp(0.0, 1.0)
    pil_image = transforms.ToPILImage()(image)
    plt.imshow(pil_image)
    plt.axis('off')
    plt.show()

In [None]:
# Display the model's output image
postprocess_and_show(output_tensor=output)