In [1]:
import repaint_sampling as RS
import repaint_patcher as RP
import prepare_glide_inpaint as PGI
from image_util import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch as th
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder, DatasetFolder
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import json
import os

size = 64         # final side length (px) produced by `common_transform`
large_size = 256  # side length (px) of the square center crop taken first

# PIL image -> float tensor of shape (3, size, size) with values in [-1, 1]:
# force RGB, resize the short side to large_size, take a centered
# large_size x large_size crop, downscale it to size x size, then rescale.
common_transform = transforms.Compose(
    [
        transforms.Lambda(lambda img: img.convert("RGB")),  # force 3 channels
        transforms.Resize(large_size),                      # short side -> large_size
        transforms.CenterCrop(large_size),                  # square crop
        transforms.Resize((size, size)),                    # downscale
        transforms.ToTensor(),                              # -> [0, 1]
        transforms.Lambda(lambda t: t * 2 - 1),             # [0, 1] -> [-1, 1]
    ]
)
# ImageNet label table indexed by `imagenet_target_transform`.
# JSON is UTF-8 by spec; pass the encoding explicitly so the read does not
# depend on the platform's locale encoding (e.g. cp1252 on Windows).
with open('data/datasets/imagenet.json', encoding='utf-8') as f:
    imagenet_labels = json.load(f)

def imagenet_target_transform(target):
    """Map an ImageFolder class index to its human-readable ImageNet label.

    `target` indexes the `classes` list of `datasets['imagenet_val']`. That
    class name (a folder name) is parsed as an int and used to index
    `imagenet_labels`.

    Note: `datasets` is resolved at call time. By then it refers to the
    module-level dict of datasets, not the `torchvision.datasets` module.
    The dict must contain an 'imagenet_val' entry, otherwise this raises
    KeyError.
    NOTE(review): assumes the ImageFolder subdirectories are named by numeric
    class id and that `imagenet_labels` is a list indexed by that id.
    """
    class_dirs = datasets['imagenet_val'].classes
    label_index = int(class_dirs[target])
    return imagenet_labels[label_index]



class CocoFolder(Dataset):
    """Flat directory of COCO ``.jpg`` images, each paired with one caption.

    Parameters
    ----------
    root : str
        Directory holding the images. Only file names ending in ``.jpg``
        are used; every such file must be listed in ``annotations_json``.
    annotations_json : str
        Path to a COCO captions file with an ``images`` list
        (``id``, ``file_name``) and an ``annotations`` list
        (``image_id``, ``caption``).
    transform : callable, optional
        Applied to the PIL image in ``__getitem__``.
    target_transform : callable, optional
        Applied to the caption string in ``__getitem__``.

    Attributes
    ----------
    classes : list[str]
        Sorted file names of every image listed in the annotations file.
    image_paths, targets : list[str]
        Image paths under ``root`` and their captions, ordered by file name.
    samples : list[tuple[str, str]]
        ``(image_path, caption)`` pairs, ordered by file name.

    Raises
    ------
    KeyError
        If an image in ``annotations_json`` has no caption, or a ``.jpg`` in
        ``root`` is not listed in ``annotations_json``.
    """

    def __init__(self, root, annotations_json, transform=None, target_transform=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        # COCO annotation files are UTF-8 and captions may be non-ASCII;
        # don't depend on the platform's locale encoding.
        with open(annotations_json, encoding='utf-8') as f:
            coco_json = json.load(f)

        # COCO has several captions per image; this dict keeps only the last
        # annotation seen for each image_id.
        annotations_by_id = {ann['image_id']: ann for ann in coco_json['annotations']}
        caption_by_filename = {
            img['file_name']: annotations_by_id[img['id']]['caption']
            for img in coco_json['images']
        }
        self.classes = sorted(caption_by_filename)

        # os.listdir order is filesystem-dependent; sort it so that a given
        # index refers to the same image on every machine and every run.
        image_names = sorted(name for name in os.listdir(root) if name.endswith('.jpg'))
        self.targets = [caption_by_filename[name] for name in image_names]
        self.image_paths = [os.path.join(root, name) for name in image_names]
        self.samples = list(zip(self.image_paths, self.targets))

    def __len__(self):
        """Number of ``(image, caption)`` samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for sample ``idx``, with transforms applied."""
        path, target = self.samples[idx]
        # Read the pixels eagerly and close the file immediately, instead of
        # leaving the handle of the lazily-loaded image open until it is GC'd.
        with Image.open(path) as img:
            img.load()
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

# Evaluation datasets at size x size resolution (via `common_transform`),
# keyed by name. Only COCO val2017 is currently enabled.
# NOTE: this dict shadows the `torchvision.datasets` module imported above;
# `imagenet_target_transform` indexes this dict at call time.
datasets = {
    # 'imagenet_val': ImageFolder('data/datasets/ILSVRC2012_img_val_subset', transform=common_transform, target_transform=imagenet_target_transform),
    'coco_val2017': CocoFolder(
        root='data/datasets/val2017',
        annotations_json='data/annotations/captions_val2017.json',
        transform=common_transform,
    ),
    # 'places_365_train': ImageFolder('data/datasets/places365_standard/train', transform=common_transform, target_transform=lambda x: p365t_classes[x].replace("_", " ")),
    # 'places_365_val': ImageFolder('data/datasets/places365_standard/val', transform=common_transform,  target_transform=lambda x: p365v_classes[x].replace("_", " ")),
}

# Class-name tables read lazily by the Places365 target_transform lambdas above.
# p365t_classes = datasets['places_365_train'].classes
# p365v_classes = datasets['places_365_val'].classes

# Inpainting masks loaded at 64x64, keyed by the mask file's stem.
masks = {
    mask_name: read_mask(f'data/masks/64/{mask_name}.png', size=64)
    for mask_name in ('ex64', 'genhalf', 'sr64', 'thick', 'thin', 'vs64')
}

In [3]:
# Same preprocessing as `common_transform`, but it keeps the full
# large_size x large_size center crop instead of downscaling to `size`.
# The final Resize is effectively a no-op, because CenterCrop(large_size)
# already produces a large_size square. It is kept so this pipeline mirrors
# `common_transform` step for step.
common_transform_large = transforms.Compose(
    [
        transforms.Lambda(lambda img: img.convert("RGB")),  # force 3 channels
        transforms.Resize(large_size),                      # short side -> large_size
        transforms.CenterCrop(large_size),                  # square crop
        transforms.Resize((large_size, large_size)),
        transforms.ToTensor(),                              # -> [0, 1]
        transforms.Lambda(lambda t: t * 2 - 1),             # [0, 1] -> [-1, 1]
    ]
)

# large_size x large_size counterparts of `datasets`; only COCO val2017 is enabled.
# NOTE(review): the commented-out entries still pass the small
# `common_transform`; switch them to `common_transform_large` if re-enabled.
datasets_large = {
    # 'imagenet_val': ImageFolder('data/datasets/ILSVRC2012_img_val_subset', transform=common_transform, target_transform=imagenet_target_transform),
    'coco_val2017': CocoFolder(
        root='data/datasets/val2017',
        annotations_json='data/annotations/captions_val2017.json',
        transform=common_transform_large,
    ),
    # 'places_365_train': ImageFolder('data/datasets/places365_standard/train', transform=common_transform, target_transform=lambda x: p365t_classes[x].replace("_", " ")),
    # 'places_365_val': ImageFolder('data/datasets/places365_standard/val', transform=common_transform,  target_transform=lambda x: p365v_classes[x].replace("_", " ")),
}

# The 64px mask files loaded at 256x256. 'sr64' and 'vs64' are read with
# nearest-neighbour resampling; the others use read_mask's default.
# NOTE(review): confirm the default is intended for the remaining masks. A
# smoothing resample filter would leave non-binary values along mask edges.
masks_large = {
    mask_name: read_mask(
        f'data/masks/64/{mask_name}.png',
        size=256,
        **({'resample': Image.NEAREST} if mask_name in ('sr64', 'vs64') else {}),
    )
    for mask_name in ('ex64', 'genhalf', 'sr64', 'thick', 'thin', 'vs64')
}

In [5]:
import random

import torchvision.transforms.functional as F

# Destination folder for the sample PNGs written below. It is created on
# first run and left as-is if it already exists.
output_dir = 'large_masked_coco'
os.makedirs(name=output_dir, exist_ok=True)

def apply_mask(image, mask):
    """Multiply `image` element-wise by `mask`, after squeezing the mask's leading axis.

    Pixels where the mask is 0 become 0. For images in [-1, 1], that is
    mid-gray once mapped back to [0, 1].
    Example: a (1, H, W) mask becomes (H, W) and broadcasts over a (C, H, W) image.
    NOTE(review): the exact shape returned by read_mask is not visible here.
    """
    squeezed_mask = mask.squeeze(0)
    return image * squeezed_mask

# Fixed seed so that re-running this cell regenerates the same sample images.
SAMPLE_SEED = 0
# Number of random COCO images rendered per mask.
SAMPLES_PER_MASK = 2


def _caption_to_filename_part(caption, max_len=80):
    """Turn a free-text caption into a fragment that is safe inside a file name.

    Every character other than a letter, digit or '-' becomes '_'. This keeps
    path separators ('/'), punctuation and newlines in captions from breaking
    the output path. Leading and trailing '_' are stripped, and the result is
    truncated to `max_len` characters to stay well under file-name limits.
    """
    cleaned = ''.join(ch if ch.isalnum() or ch == '-' else '_' for ch in caption)
    return cleaned.strip('_')[:max_len]


coco_dataset = datasets_large['coco_val2017']
sample_rng = random.Random(SAMPLE_SEED)

# For each mask, save a few random COCO images next to their masked copies.
for mask_name, mask in masks_large.items():
    n_samples = min(SAMPLES_PER_MASK, len(coco_dataset))
    indices = sample_rng.sample(range(len(coco_dataset)), n_samples)

    for idx in indices:
        # Index the dataset once per sample: each access re-reads the image
        # from disk and re-runs the transform.
        img, target = coco_dataset[idx]
        masked_img = apply_mask(img, mask)

        target_label = _caption_to_filename_part(target)
        original_img_path = os.path.join(output_dir, f'{mask_name}_original_{idx}.png')
        masked_img_path = os.path.join(output_dir, f'{mask_name}_masked_{idx}_{target_label}.png')

        # Tensors are in [-1, 1]; map them back to [0, 1] for to_pil_image and
        # clamp away float round-off before the 8-bit conversion.
        F.to_pil_image(((img + 1) / 2).clamp(0, 1)).save(original_img_path)
        F.to_pil_image(((masked_img + 1) / 2).clamp(0, 1)).save(masked_img_path)