## Deep image prior
Baseline for the super-resolution/inpainting challenge.

# Imports

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision as tv
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import numpy as np
from matplotlib import pyplot as plt
import torch.nn.functional as F
from PIL import Image
import yaml
from data import get_loaders
from utils import *
from network import skip
torch.nn.Module.add = add_module

# Get data

In [None]:
# Default settings, kept for reference only: the YAML configuration loaded
# below replaces this dictionary wholesale.
config = {
    "path_train": "data/data_train_plus_test_sourceres/train/",
    "path_test": "data/data_train_plus_test_sourceres/",
    "N": 1,
    'verbose': 1,
    "batchsize": 1
}

# NOTE: machine-specific absolute paths; adjust them when running elsewhere.
CONFIG_PATH = "c:/Users/MauriceKingma/Documents/GitHub_repositories/AIMI/configs/dip.yaml"
CODE_DIR = "c:/Users/MauriceKingma/Documents/GitHub_repositories/AIMI/code"

# Read the experiment configuration; the context manager closes the file
# handle, which the bare open() call used previously never did.
with open(CONFIG_PATH) as config_file:
    config = yaml.safe_load(config_file)

# Work from the code directory (presumably so that relative paths in the
# configuration resolve as expected -- confirm against get_loaders).
if os.getcwd() != CODE_DIR:
    os.chdir(CODE_DIR)

# Fetch one training sample. These two lines used to be commented out, which
# left `set` referring to the builtin type and made the subscripts below fail
# on a fresh run. The sample is also no longer bound to the name `set`, so the
# builtin is not shadowed.
train_loader, val_loader = get_loaders(config)
sample = train_loader.dataset[150]
source = sample["source"][0]
target = sample["target"]

# Create mask and masked image

In [None]:
# Inputs: `source` and `target` are expected to exist already (data cell).
# On Colab they can be restored from disk instead:
#   source = torch.load("source.pt")
#   target = torch.load("target.pt")

# Geometry of the full-resolution image and the columns that are known:
# every second column, starting at column index 1.
full_shape = (256, 192)
known_columns = (slice(None), slice(1, None, 2))

# Masked image: zeros everywhere except the known columns, which receive the
# source divided by 255 (presumably rescaling 8-bit intensities -- confirm).
masked = np.zeros(full_shape, dtype=float)
masked[known_columns] = source / 255
masked = masked[np.newaxis, ...]

# Binary mask: 1.0 on the known columns, 0.0 on the columns to reconstruct.
mask = np.zeros(full_shape, dtype=float)
mask[known_columns] = np.ones((256, 96), dtype=float)
mask = mask[np.newaxis, ...]

# Visual sanity check: first the mask, then the masked image.
for preview in (mask, masked):
    plt.imshow(preview[0], cmap="gray")
    plt.show()

# Convert to CPU tensors (append .cuda() to each tensor to run on a GPU).
masked_tensor, mask_tensor, target_tensor = (
    torch.tensor(array) for array in (masked, mask, target)
)
print(target_tensor.unsqueeze(dim=0).shape)


# Official deep-image-prior implementation

In [42]:
# Deep-image-prior fit: optimise a randomly initialised network so that its
# output, restricted to the known pixels (mask), matches the masked image.

# Parameters
input_depth = 32        # number of channels of the random input tensor
img_shape = 1           # second positional argument of skip(); presumably the number of output channels
# dtype = torch.cuda.FloatTensor # With GPU
dtype = torch.FloatTensor # With CPU
iterations = 2000
show_every = 1          # plot/print interval, in iterations
plot = True
learning_rate = 0.01
reg_noise_std = 1 / 30  # std of the per-iteration input perturbation

# Create model
net = skip(input_depth, img_shape,
    num_channels_down = [128] * 5,
    num_channels_up =   [128] * 5,
    num_channels_skip =    [128] * 5,
    filter_size_up = 3, filter_size_down = 3,
    upsample_mode='bilinear', filter_skip_size=1,
    need_sigmoid=True, need_bias=True, pad='reflection', act_fun='LeakyReLU'
).type(dtype)

# Create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Define loss and tensor types
mse = torch.nn.MSELoss().type(dtype)
# lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg').type(dtype)
masked_tensor = masked_tensor.type(dtype)
mask_tensor = mask_tensor.type(dtype)

# Create the fixed network input. Its size follows input_depth and the spatial
# size of the masked image instead of repeating the literals 32 / (256, 192),
# and it is cast to dtype so that switching to the GPU type above is enough.
height, width = masked_tensor.shape[-2:]
net_input = (0.1 * torch.rand((1, input_depth, height, width))).type(dtype)

# Images list for gif
images = []

# losses list for plot
mse_losses = []
lpips_losses = []

for iteration in range(iterations):
    optimizer.zero_grad()

    # Perturb the *fixed* input with fresh noise every iteration. Previously
    # the noise was added onto net_input itself, so it accumulated into a
    # random walk and the input drifted further from its initial value with
    # every step.
    noisy_input = net_input + reg_noise_std * torch.randn_like(net_input)

    out = net(noisy_input)
    out = out.squeeze(0)

    # Loss is evaluated on the known pixels only
    mse_loss = mse(out * mask_tensor, masked_tensor)
    # lpips_loss = lpips(out * mask_tensor, masked_tensor)

    # Append losses
    mse_losses.append(mse_loss.item())
    # lpips_losses.append(lpips_loss.item())

    # Backpropagate to compute gradients
    mse_loss.backward()

    # Update the network weights
    optimizer.step()

    # Plot output and print losses
    if plot and iteration % show_every == 0:
        plt.imshow(out[0].detach().cpu().numpy() * 255, cmap="gray")
        plt.show()
        print (f"Iteration {iteration}")
        print(f"MSE Loss {mse_loss.item()}")
        # print(f"LPIPS Loss {lpips_loss.item()}")

# Show the output of the last iteration
plt.imshow(out[0].detach().cpu().numpy() * 255, cmap="gray")

TypeError: Figure.savefig() takes 2 positional arguments but 3 were given

<Figure size 640x480 with 0 Axes>