In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

Approach 1: Deep Image Prior

Helper Functions

In [None]:
"""
Function to load image and convert to tensor. This function will
rescale the pixel values to lie in [0,1], and add a batch dimension so that
the image shape is (batch, channels, height, width)
"""
def load_image_as_tensor(img_path, img_size=256):
  img = Image.open(img_path)
  transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
  return transform(img).unsqueeze(0)


"""
Function to convert a tensor image into a numpy image. Input image shape is assumed
to be (batch, channels, height, width). Output image shape will be (height, width, channels)
"""
def tensor_image_to_numpy(img_tensor):
  img_np = img_tensor.detach().cpu().numpy()
  return np.transpose( img_np[0,...], (1, 2, 0))

"""
Function to add per-pixel gaussian noise to a given image, with standard deviation of sigma.
"""
def add_noise_to_image(x, sigma):
  x_noise = x + torch.normal(torch.zeros(x.shape), torch.ones(x.shape)*sigma)
  return x_noise.to(x.device)

In [None]:
"""
Function to load image and convert to tensor. This function will
rescale the pixel values to lie in [0,1], and add a batch dimension so that
the image shape is (batch, channels, height, width)
"""
def load_image_as_tensor(img_path, img_size=512):
  img = Image.open(img_path)
  transform = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor()])
  return transform(img).unsqueeze(0)

"""
Function to convert a tensor image into a numpy image. Input image shape is assumed
to be (batch, channels, height, width). Output image shape will be (height, width, channels)
"""
def tensor_image_to_numpy(img_tensor):
  img_np = img_tensor.detach().cpu().numpy()
  return np.transpose( img_np[0,...], (1, 2, 0))

"""
Function to add per-pixel gaussian noise to a given image, with standard deviation of sigma.
"""
def add_noise_to_image(x, sigma):
  x_noise = x + torch.normal(torch.zeros(x.shape), torch.ones(x.shape)*sigma)
  return x_noise.to(x.device)

"""
Function to display a grid of images.
"""
def image_grid(imgs, rows, cols):
  assert len(imgs) == rows*cols

  w, h = imgs[0].size
  grid = Image.new('RGB', size=(cols*w, rows*h))
  grid_w, grid_h = grid.size

  for i, img in enumerate(imgs):
      grid.paste(img, box=(i%cols*w, i//cols*h))
  return grid

U-Net-style encoder-decoder (no skip connections)

In [None]:
class encoder_block(nn.Module):
  """Two 3x3 convolutions, each followed by LeakyReLU and BatchNorm.

  The first convolution maps in_channels -> conv_channels at the input
  resolution; the second uses stride 2 and therefore downsamples.
  """

  def __init__(self, in_channels, conv_channels):
    super().__init__()

    # Resolution-preserving stage that changes the channel count.
    same_resolution_stage = [
        nn.Conv2d(in_channels, conv_channels, kernel_size=3, padding='same'),
        nn.LeakyReLU(),
        nn.BatchNorm2d(conv_channels),
    ]
    # Strided stage that performs the downsampling.
    downsampling_stage = [
        nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1),
        nn.LeakyReLU(),
        nn.BatchNorm2d(conv_channels),
    ]
    self.nn_layers = nn.Sequential(*same_resolution_stage, *downsampling_stage)

  def forward(self, x):
    """Apply the block to a (batch, in_channels, height, width) tensor."""
    return self.nn_layers(x)

In [None]:
class decoder_block(nn.Module):
  """Upsample by 2, then apply two 3x3 'same' convolutions.

  Layer order: Upsample -> Conv (conv_channels -> conv_channels) -> LeakyReLU
  -> BatchNorm -> Conv (conv_channels -> out_channels) -> Sigmoid -> BatchNorm.
  """

  def __init__(self, conv_channels, out_channels):
    super().__init__()

    layers = []
    layers.append(nn.Upsample(scale_factor=2))
    layers.append(nn.Conv2d(conv_channels, conv_channels, kernel_size=3, padding='same'))
    layers.append(nn.LeakyReLU())
    layers.append(nn.BatchNorm2d(conv_channels))
    layers.append(nn.Conv2d(conv_channels, out_channels, kernel_size=3, padding='same'))
    layers.append(nn.Sigmoid())
    layers.append(nn.BatchNorm2d(out_channels))
    self.nn_layers = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the block to a (batch, conv_channels, height, width) tensor."""
    return self.nn_layers(x)

In [None]:
class end_block(nn.Module):
  """Final block: Upsample by 2 -> Conv -> BatchNorm -> Sigmoid -> Conv.

  The last layer is a plain convolution with no activation after it, so the
  output values are not restricted to any range.
  """

  def __init__(self, conv_channels, out_channels):
    super().__init__()

    upsample = nn.Upsample(scale_factor=2)
    hidden_conv = nn.Conv2d(conv_channels, conv_channels, kernel_size=3, padding='same')
    hidden_norm = nn.BatchNorm2d(conv_channels)
    hidden_act = nn.Sigmoid()
    output_conv = nn.Conv2d(conv_channels, out_channels, kernel_size=3, padding='same')
    self.nn_layers = nn.Sequential(upsample, hidden_conv, hidden_norm, hidden_act, output_conv)

  def forward(self, x):
    """Apply the block to a (batch, conv_channels, height, width) tensor."""
    return self.nn_layers(x)

In [None]:
class ENCODER_DECODER(nn.Module):
    """Sequential encoder-decoder CNN built from encoder/decoder/end blocks.

    ``depth`` encoder blocks double the channel count at every level
    (conv_channels, 2*conv_channels, ...). ``depth - 1`` decoder blocks then
    halve it again, and a final ``end_block`` maps conv_channels to
    out_channels. The blocks are applied strictly one after another; there are
    no skip connections.

    Args:
        depth: Number of encoder blocks (>= 1). The default of 4 builds
            4 encoder blocks, 3 decoder blocks and 1 end block.
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output image.
        conv_channels: Output channels of the first encoder block.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """
    def __init__(self, depth = 4, in_channels = 3, out_channels = 3, conv_channels = 128):
        super(ENCODER_DECODER, self).__init__()

        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.depth = depth
        self.in_channels = in_channels # Number of channels of input image
        self.out_channels = out_channels # Number of channels of output image
        self.conv_channels = conv_channels # Number of output channels per convolution

        self.block_array = nn.ModuleList()

        # Channel width at each encoder level: c, 2c, 4c, ...
        widths = [conv_channels * 2**level for level in range(depth)]

        # Encoder path: the first block consumes the input image, every later
        # block consumes the previous level's width.
        previous_width = in_channels
        for width in widths:
            self.block_array.append(encoder_block(previous_width, width))
            previous_width = width

        # Decoder path: walk the widths back down, one level at a time.
        for level in range(depth - 1, 0, -1):
            self.block_array.append(decoder_block(widths[level], widths[level - 1]))

        # The end block performs the last upsampling and maps to out_channels.
        self.block_array.append(end_block(conv_channels, out_channels))


    def forward(self, x):
        """Run ``x`` through every block in order and return the result."""
        for layer in self.block_array:
            x = layer(x)
        return x

Actual Masking and Inpainting:

In [None]:
"""
This function will take an input image tensor of shape (batch, channels, height, width),
and set all pixel values to 0 within the n x n region with top left corner located at x,y.
"""
def black_out_region(img, n, x, y):
  mask = torch.ones((img.shape))
  mask[:, :, x:x+n, y:y+n] = 0
  return mask


In [None]:
# Load the reference image onto the GPU.
img_path = 'endless_field.jpg' # REPLACE WITH YOUR IMAGE NAME HERE
img_size = 256
x = load_image_as_tensor(img_path, img_size).to('cuda')

# Build the mask (0 inside a 32 x 32 square starting at index 150 along both
# spatial axes, 1 elsewhere) and apply it to get the image to be inpainted.
mask = black_out_region(x, 32, 150, 150).to('cuda')
x_inpainted = x * mask

# Show the masked image and print its value range (min first, then max).
x_inpainted_np = tensor_image_to_numpy(x_inpainted)
plt.imshow(x_inpainted_np)
plt.axis(False)
print(x_inpainted_np.min(), x_inpainted_np.max(), sep='\n')

In [None]:
# Now, run gradient descent using your CNN to inpaint the masked image! Use the ADAM optimizer
# with a learning rate <= 0.01.

gt_inpainted_image = x_inpainted.to('cuda')

# Create input noise image
in_channels = 16 # Number of channels of input noise image
# Take the spatial size from the target itself: the loaded image is not
# guaranteed to be img_size x img_size, and the loss below needs the network
# output and the target to have the same height and width.
input_noise = torch.randn(1, in_channels, *gt_inpainted_image.shape[-2:])*0.1
input_noise = input_noise.to('cuda')

net = ENCODER_DECODER(depth=4, in_channels=in_channels, out_channels=3, conv_channels=256)
net = net.to('cuda')
lr = 1e-4
optim = torch.optim.Adam(net.parameters(), lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim)

n_iterations = 201 # Number of optimisation steps; replace with whatever number you want
for i in range(n_iterations):

  output = net(input_noise)
  # Only pixels where mask == 1 contribute to the loss, so the blacked-out
  # region is left unconstrained and has to be filled in by the network.
  loss = torch.nn.functional.mse_loss(gt_inpainted_image * mask, output * mask)

  # Keep these three following lines as is.
  optim.zero_grad()
  loss.backward()
  optim.step()
  # The scheduler only needs the scalar value of the loss.
  scheduler.step(loss.item())

  if i%10 == 0:
    print(loss.item())
    # Read the learning rate(s) from the optimizer; this works regardless of
    # whether the scheduler class provides get_last_lr().
    print([group['lr'] for group in optim.param_groups])
    plt.figure()
    # The network output is unbounded, so clip to the displayable [0, 1] range.
    plt.imshow(tensor_image_to_numpy(output).clip(0, 1))
    plt.show()
  # You probably want to write some code here that will store the current image
  # into an array sporadically, so that you can create a figure from them.



Approach 2: Stable Diffusion

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
!pip install diffusers
!pip install transformers scipy ftfy accelerate

In [None]:
# Text-to-image Stable Diffusion v1.5 pipeline, loaded with float16 weights and moved to the GPU.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

In [None]:
# Stable Diffusion 2 inpainting pipeline (float16, on the GPU).
# NOTE: this rebinds `pipe`, replacing the pipeline loaded in the previous cell.

from diffusers import StableDiffusionInpaintPipeline

model_id = "stabilityai/stable-diffusion-2-inpainting"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to("cuda")

In [None]:
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'

from segment_anything import sam_model_registry, SamPredictor

!wget -q -nc https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
CHECKPOINT_PATH='/content/sam_vit_b_01ec64.pth'

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_b"

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(device=DEVICE)

In [None]:
import cv2

# Give the path of your image
IMAGE_PATH = 'swan.jpeg'
# Read the image from the path. cv2.imread returns None instead of raising
# when the file is missing or unreadable, so fail early with a clear message.
image = cv2.imread(IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f"Could not read image at {IMAGE_PATH!r}")
print(image.shape)
# cv2.imread yields BGR channel order; convert to RGB for the SAM predictor
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam.to(device=DEVICE)
mask_predictor = SamPredictor(sam)
mask_predictor.set_image(image_rgb)

# Provide points as input prompt [X, Y]-coordinates
# (these coordinates were chosen for the image above; adjust them for other images)
input_point = np.array([[71, 41], [64, 44], [85, 109], [141, 110]])
# One label per point; all four points are labelled 1 here
input_label = np.array([1, 1, 1, 1])

# Predicting Segmentation mask
masks, scores, logits = mask_predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=False,
)

# masks has a leading mask axis; take the first mask -- presumably the only
# one, since multimask_output=False (verify against the SAM documentation) --
# and turn it into an 8-bit black/white image: 255 where the scaled mask
# exceeds 100, 0 elsewhere. uint8 is a depth PNG writers handle directly.
bw_image = np.where(masks[0].astype(float) * 255 > 100, 255, 0).astype(np.uint8)
if not cv2.imwrite('mask.png', bw_image):
    raise OSError("Could not write 'mask.png'")
del sam, mask_predictor   # delete models to conserve GPU memory
torch.cuda.empty_cache()  # hand the cached GPU blocks back as well

# Force three channels so RGBA / grayscale / palette files are accepted too
image = Image.open(IMAGE_PATH).convert('RGB')
mask = Image.open('mask.png')
prompt = "a surfer riding a surf board"

# The pipeline repaints the mask_image region of image, guided by prompt.
inpainted = pipe(image=image, mask_image=mask, prompt=prompt, guidance_scale=2)
inpainted['images'][0].save('inpainted.png')
inpainted['images'][0]