<a href="https://colab.research.google.com/github/ArkS0001/HeritageAI-Generating-Cultural-Heritage-Imagery-with-LLMs/blob/main/PlanB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the notebook's runtime dependencies into the kernel's own environment.
# %pip (rather than a bare/shell pip) guarantees the packages land where the
# running kernel can import them; -q keeps the resolver log out of the notebook.
%pip install -q opencv-python numpy torch torchvision matplotlib gdown


Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-

In [3]:
# Clone the deepfillv2_colab repository into ./deepfillv2_colab.
!git clone https://github.com/vrindaprabhu/deepfillv2_colab.git
# Download the checkpoint file from Google Drive with gdown (saved in the current directory).
!gdown "https://drive.google.com/u/0/uc?id=1uMghKl883-9hDLhSiI8lRbHCzCmmRwV-&export=download"
# Move the downloaded checkpoint into the cloned repo's model/ directory,
# renaming it to deepfillv2_WGAN.pth (the path later passed as model_path).
!mv /content/deepfillv2_WGAN_G_epoch40_batchsize4.pth deepfillv2_colab/model/deepfillv2_WGAN.pth

Cloning into 'deepfillv2_colab'...
remote: Enumerating objects: 99, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 99 (delta 2), reused 1 (delta 1), pack-reused 96[K
Receiving objects: 100% (99/99), 571.56 KiB | 510.00 KiB/s, done.
Resolving deltas: 100% (44/44), done.
Downloading...
From: https://drive.google.com/u/0/uc?id=1uMghKl883-9hDLhSiI8lRbHCzCmmRwV-&export=download
To: /content/deepfillv2_WGAN_G_epoch40_batchsize4.pth
100% 64.8M/64.8M [00:00<00:00, 74.8MB/s]


In [4]:
# Import necessary libraries
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image
import gdown

# Generator network: plain convolutional encoder -> dilated bottleneck -> decoder
class Generator(nn.Module):
    """Convolutional encoder/decoder that maps a 4-channel input to a 3-channel image.

    Layout (parameter names follow the ``nn.Sequential`` indices, e.g.
    ``encoder.0.weight``, ``middle.6.bias``, ``decoder.4.weight``):

    * ``encoder``: one 5x5 stride-1 conv (4 -> 64 channels) followed by four
      3x3 stride-2 convs (64 -> 128 -> 256 -> 512 -> 512), each with ReLU.
      Spatial size is reduced by a factor of 16.
    * ``middle``: four 3x3 convs at 512 channels with dilation 2, 4, 8 and 16
      (padding equal to the dilation, so spatial size is unchanged), each
      with ReLU.
    * ``decoder``: three 4x4 stride-2 transposed convs (512 -> 256 -> 128 ->
      64), each with ReLU, then a 3x3 conv to 3 channels and ``Tanh``, so
      outputs lie in [-1, 1]. Spatial size is increased by a factor of 8.

    NOTE(review): the encoder downsamples 16x but the decoder upsamples only
    8x, so the output is half the input resolution (e.g. 256x256 in ->
    128x128 out) - confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch, **conv_kwargs):
            # A Conv2d immediately followed by its ReLU activation.
            return [nn.Conv2d(in_ch, out_ch, **conv_kwargs), nn.ReLU()]

        # --- encoder: stem conv, then four stride-2 downsampling convs ---
        encoder_layers = conv_relu(4, 64, kernel_size=5, stride=1, padding=2)
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512), (512, 512)):
            encoder_layers += conv_relu(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        self.encoder = nn.Sequential(*encoder_layers)

        # --- bottleneck: dilated convs with growing receptive field ---
        middle_layers = []
        for rate in (2, 4, 8, 16):
            middle_layers += conv_relu(
                512, 512, kernel_size=3, stride=1, padding=rate, dilation=rate
            )
        self.middle = nn.Sequential(*middle_layers)

        # --- decoder: three 2x upsampling stages, then projection to RGB ---
        decoder_layers = []
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            decoder_layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ]
        decoder_layers += [
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run ``x`` (N, 4, H, W) through encoder, bottleneck and decoder in turn."""
        features = self.encoder(x)
        features = self.middle(features)
        return self.decoder(features)

# Function to preprocess the image
def preprocess_image(image_path, mask_box=(100, 150, 100, 150)):
    """Load an image and build the 4-channel (RGB + mask) generator input.

    Args:
        image_path: Path of the image file to read with OpenCV.
        mask_box: ``(top, bottom, left, right)`` pixel bounds, in the resized
            256x256 image, of the rectangle flagged in the mask. The default
            is the 50x50 square this notebook has always used as a stand-in;
            a real mask should come from the user.

    Returns:
        A tuple ``(image, mask, image_with_mask)``:

        * ``image``: float tensor of shape (3, 256, 256), RGB scaled to [0, 1].
        * ``mask``: float tensor of shape (3, 256, 256); 1.0 inside
          ``mask_box`` and 0.0 elsewhere (same values in every channel, so it
          can be displayed directly as an RGB image).
        * ``image_with_mask``: float tensor of shape (1, 4, 256, 256) - the
          three image channels followed by ONE mask channel, with a leading
          batch dimension.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque cvtColor assertion error.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # OpenCV loads BGR; convert to RGB and resize to the fixed 256x256 input.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256))

    # uint8 HWC array -> float CHW tensor in [0, 1].
    image = transforms.ToTensor()(image)

    # Build the mask directly as a float tensor. Passing a uint8 mask through
    # ToTensor would divide it by 255, turning the "1" marker into ~0.004.
    top, bottom, left, right = mask_box
    mask = torch.zeros_like(image)
    mask[:, top:bottom, left:right] = 1.0

    # Append a single mask channel: 3 (RGB) + 1 (mask) = 4 channels, matching
    # the generator's first convolution. Concatenating all three mask
    # channels would produce a 6-channel input.
    image_with_mask = torch.cat((image, mask[:1]), dim=0)
    return image, mask, image_with_mask.unsqueeze(0)

# Run the generator in inference mode (no autograd bookkeeping)
def generate_image(generator, image_with_mask):
    """Apply ``generator`` to ``image_with_mask`` without tracking gradients.

    Args:
        generator: Callable (e.g. an ``nn.Module``) applied to the input.
        image_with_mask: Input tensor handed to ``generator`` unchanged.

    Returns:
        The generator's output, moved to CPU memory.
    """
    with torch.no_grad():
        prediction = generator(image_with_mask)
    host_prediction = prediction.cpu()
    return host_prediction

# Main function
def main(image_path, model_path):
    """Inpaint one image with ``Generator`` and show/save the result.

    Steps: preprocess the image, load the checkpoint at ``model_path`` into a
    fresh ``Generator`` (on CPU, eval mode), run it on the image+mask input,
    write the output to ``reconstructed_image.png`` in the working directory,
    and display the original, the mask and the reconstruction side by side.

    Args:
        image_path: Path of the input image.
        model_path: Path of the checkpoint holding the generator state dict.

    Raises:
        RuntimeError: If the checkpoint's parameter names/shapes do not match
            ``Generator`` (raised by ``load_state_dict``).

    NOTE(review): the recorded run of this cell fails in ``load_state_dict``:
    the downloaded checkpoint uses ``coarse.*`` / ``refine_*`` parameter
    names, not ``Generator``'s ``encoder`` / ``middle`` / ``decoder`` ones.
    Presumably the network definition from the cloned repository is what
    this checkpoint needs - confirm and swap the model class accordingly.
    """
    # Preprocess the image
    original_image, mask, image_with_mask = preprocess_image(image_path)

    # Initialize and load the pre-trained generator.
    # SECURITY: the checkpoint is downloaded from the internet and torch.load
    # is pickle-based; weights_only=True restricts unpickling to tensors and
    # plain containers so a tampered file cannot execute arbitrary code.
    generator = Generator()
    state_dict = torch.load(
        model_path, map_location=torch.device('cpu'), weights_only=True
    )
    generator.load_state_dict(state_dict)
    generator.eval()

    # Generate the reconstructed image
    reconstructed_image = generate_image(generator, image_with_mask)

    # Save the generated image
    save_image(reconstructed_image, 'reconstructed_image.png')

    # Display the original, mask, and reconstructed images. The generator
    # ends in Tanh, so its output spans [-1, 1]; clamp the displayed copy to
    # the [0, 1] range imshow expects for float RGB data.
    panels = (
        ('Original Image', original_image),
        ('Mask', mask),
        ('Reconstructed Image', reconstructed_image[0].clamp(0, 1)),
    )
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, tensor) in zip(axes, panels):
        ax.set_title(title)
        # CHW tensor -> HWC array for matplotlib.
        ax.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))
    plt.show()

# Input image to inpaint and the checkpoint path written by the download cell
# (absolute paths under Colab's /content working directory).
image_path = '/content/testPLAN-B.jpg'
model_path = '/content/deepfillv2_colab/model/deepfillv2_WGAN.pth'

# Disabled alternative: download a checkpoint straight to model_path with gdown.
# NOTE(review): this Drive id differs from the one used in the download cell -
# confirm which file is intended before re-enabling.
# gdown.download('https://drive.google.com/uc?id=1qdWbW0_0XBIkq2PZBdlrwbpa7GnGZBF7', model_path, quiet=False)

# Run the full pipeline: preprocess, load the model, inpaint, save and plot.
main(image_path, model_path)


RuntimeError: Error(s) in loading state_dict for Generator:
	Missing key(s) in state_dict: "encoder.0.weight", "encoder.0.bias", "encoder.2.weight", "encoder.2.bias", "encoder.4.weight", "encoder.4.bias", "encoder.6.weight", "encoder.6.bias", "encoder.8.weight", "encoder.8.bias", "middle.0.weight", "middle.0.bias", "middle.2.weight", "middle.2.bias", "middle.4.weight", "middle.4.bias", "middle.6.weight", "middle.6.bias", "decoder.0.weight", "decoder.0.bias", "decoder.2.weight", "decoder.2.bias", "decoder.4.weight", "decoder.4.bias", "decoder.6.weight", "decoder.6.bias". 
	Unexpected key(s) in state_dict: "coarse.0.conv2d.weight", "coarse.0.conv2d.bias", "coarse.0.mask_conv2d.weight", "coarse.0.mask_conv2d.bias", "coarse.1.conv2d.weight", "coarse.1.conv2d.bias", "coarse.1.mask_conv2d.weight", "coarse.1.mask_conv2d.bias", "coarse.2.conv2d.weight", "coarse.2.conv2d.bias", "coarse.2.mask_conv2d.weight", "coarse.2.mask_conv2d.bias", "coarse.3.conv2d.weight", "coarse.3.conv2d.bias", "coarse.3.mask_conv2d.weight", "coarse.3.mask_conv2d.bias", "coarse.4.conv2d.weight", "coarse.4.conv2d.bias", "coarse.4.mask_conv2d.weight", "coarse.4.mask_conv2d.bias", "coarse.5.conv2d.weight", "coarse.5.conv2d.bias", "coarse.5.mask_conv2d.weight", "coarse.5.mask_conv2d.bias", "coarse.6.conv2d.weight", "coarse.6.conv2d.bias", "coarse.6.mask_conv2d.weight", "coarse.6.mask_conv2d.bias", "coarse.7.conv2d.weight", "coarse.7.conv2d.bias", "coarse.7.mask_conv2d.weight", "coarse.7.mask_conv2d.bias", "coarse.8.conv2d.weight", "coarse.8.conv2d.bias", "coarse.8.mask_conv2d.weight", "coarse.8.mask_conv2d.bias", "coarse.9.conv2d.weight", "coarse.9.conv2d.bias", "coarse.9.mask_conv2d.weight", "coarse.9.mask_conv2d.bias", "coarse.10.conv2d.weight", "coarse.10.conv2d.bias", "coarse.10.mask_conv2d.weight", "coarse.10.mask_conv2d.bias", "coarse.11.conv2d.weight", "coarse.11.conv2d.bias", "coarse.11.mask_conv2d.weight", "coarse.11.mask_conv2d.bias", "coarse.12.gated_conv2d.conv2d.module.bias", "coarse.12.gated_conv2d.conv2d.module.weight_u", "coarse.12.gated_conv2d.conv2d.module.weight_v", "coarse.12.gated_conv2d.conv2d.module.weight_bar", "coarse.12.gated_conv2d.mask_conv2d.module.bias", "coarse.12.gated_conv2d.mask_conv2d.module.weight_u", "coarse.12.gated_conv2d.mask_conv2d.module.weight_v", "coarse.12.gated_conv2d.mask_conv2d.module.weight_bar", "coarse.13.conv2d.weight", "coarse.13.conv2d.bias", "coarse.13.mask_conv2d.weight", "coarse.13.mask_conv2d.bias", "coarse.14.gated_conv2d.conv2d.module.bias", "coarse.14.gated_conv2d.conv2d.module.weight_u", 
"coarse.14.gated_conv2d.conv2d.module.weight_v", "coarse.14.gated_conv2d.conv2d.module.weight_bar", "coarse.14.gated_conv2d.mask_conv2d.module.bias", "coarse.14.gated_conv2d.mask_conv2d.module.weight_u", "coarse.14.gated_conv2d.mask_conv2d.module.weight_v", "coarse.14.gated_conv2d.mask_conv2d.module.weight_bar", "coarse.15.conv2d.weight", "coarse.15.conv2d.bias", "coarse.15.mask_conv2d.weight", "coarse.15.mask_conv2d.bias", "coarse.16.conv2d.weight", "coarse.16.conv2d.bias", "coarse.16.mask_conv2d.weight", "coarse.16.mask_conv2d.bias", "refine_conv.0.conv2d.weight", "refine_conv.0.conv2d.bias", "refine_conv.0.mask_conv2d.weight", "refine_conv.0.mask_conv2d.bias", "refine_conv.1.conv2d.weight", "refine_conv.1.conv2d.bias", "refine_conv.1.mask_conv2d.weight", "refine_conv.1.mask_conv2d.bias", "refine_conv.2.conv2d.weight", "refine_conv.2.conv2d.bias", "refine_conv.2.mask_conv2d.weight", "refine_conv.2.mask_conv2d.bias", "refine_conv.3.conv2d.weight", "refine_conv.3.conv2d.bias", "refine_conv.3.mask_conv2d.weight", "refine_conv.3.mask_conv2d.bias", "refine_conv.4.conv2d.weight", "refine_conv.4.conv2d.bias", "refine_conv.4.mask_conv2d.weight", "refine_conv.4.mask_conv2d.bias", "refine_conv.5.conv2d.weight", "refine_conv.5.conv2d.bias", "refine_conv.5.mask_conv2d.weight", "refine_conv.5.mask_conv2d.bias", "refine_conv.6.conv2d.weight", "refine_conv.6.conv2d.bias", "refine_conv.6.mask_conv2d.weight", "refine_conv.6.mask_conv2d.bias", "refine_conv.7.conv2d.weight", "refine_conv.7.conv2d.bias", "refine_conv.7.mask_conv2d.weight", "refine_conv.7.mask_conv2d.bias", "refine_conv.8.conv2d.weight", "refine_conv.8.conv2d.bias", "refine_conv.8.mask_conv2d.weight", "refine_conv.8.mask_conv2d.bias", "refine_conv.9.conv2d.weight", "refine_conv.9.conv2d.bias", "refine_conv.9.mask_conv2d.weight", "refine_conv.9.mask_conv2d.bias", "refine_atten_1.0.conv2d.weight", "refine_atten_1.0.conv2d.bias", "refine_atten_1.0.mask_conv2d.weight", "refine_atten_1.0.mask_conv2d.bias", 
"refine_atten_1.1.conv2d.weight", "refine_atten_1.1.conv2d.bias", "refine_atten_1.1.mask_conv2d.weight", "refine_atten_1.1.mask_conv2d.bias", "refine_atten_1.2.conv2d.weight", "refine_atten_1.2.conv2d.bias", "refine_atten_1.2.mask_conv2d.weight", "refine_atten_1.2.mask_conv2d.bias", "refine_atten_1.3.conv2d.weight", "refine_atten_1.3.conv2d.bias", "refine_atten_1.3.mask_conv2d.weight", "refine_atten_1.3.mask_conv2d.bias", "refine_atten_1.4.conv2d.weight", "refine_atten_1.4.conv2d.bias", "refine_atten_1.4.mask_conv2d.weight", "refine_atten_1.4.mask_conv2d.bias", "refine_atten_1.5.conv2d.weight", "refine_atten_1.5.conv2d.bias", "refine_atten_1.5.mask_conv2d.weight", "refine_atten_1.5.mask_conv2d.bias", "refine_atten_2.0.conv2d.weight", "refine_atten_2.0.conv2d.bias", "refine_atten_2.0.mask_conv2d.weight", "refine_atten_2.0.mask_conv2d.bias", "refine_atten_2.1.conv2d.weight", "refine_atten_2.1.conv2d.bias", "refine_atten_2.1.mask_conv2d.weight", "refine_atten_2.1.mask_conv2d.bias", "refine_combine.0.conv2d.weight", "refine_combine.0.conv2d.bias", "refine_combine.0.mask_conv2d.weight", "refine_combine.0.mask_conv2d.bias", "refine_combine.1.conv2d.weight", "refine_combine.1.conv2d.bias", "refine_combine.1.mask_conv2d.weight", "refine_combine.1.mask_conv2d.bias", "refine_combine.2.gated_conv2d.conv2d.module.bias", "refine_combine.2.gated_conv2d.conv2d.module.weight_u", "refine_combine.2.gated_conv2d.conv2d.module.weight_v", "refine_combine.2.gated_conv2d.conv2d.module.weight_bar", "refine_combine.2.gated_conv2d.mask_conv2d.module.bias", "refine_combine.2.gated_conv2d.mask_conv2d.module.weight_u", "refine_combine.2.gated_conv2d.mask_conv2d.module.weight_v", "refine_combine.2.gated_conv2d.mask_conv2d.module.weight_bar", "refine_combine.3.conv2d.weight", "refine_combine.3.conv2d.bias", "refine_combine.3.mask_conv2d.weight", "refine_combine.3.mask_conv2d.bias", "refine_combine.4.gated_conv2d.conv2d.module.bias", "refine_combine.4.gated_conv2d.conv2d.module.weight_u", 
"refine_combine.4.gated_conv2d.conv2d.module.weight_v", "refine_combine.4.gated_conv2d.conv2d.module.weight_bar", "refine_combine.4.gated_conv2d.mask_conv2d.module.bias", "refine_combine.4.gated_conv2d.mask_conv2d.module.weight_u", "refine_combine.4.gated_conv2d.mask_conv2d.module.weight_v", "refine_combine.4.gated_conv2d.mask_conv2d.module.weight_bar", "refine_combine.5.conv2d.weight", "refine_combine.5.conv2d.bias", "refine_combine.5.mask_conv2d.weight", "refine_combine.5.mask_conv2d.bias", "refine_combine.6.conv2d.weight", "refine_combine.6.conv2d.bias", "refine_combine.6.mask_conv2d.weight", "refine_combine.6.mask_conv2d.bias". 