In [None]:
import argparse
import os
import json
import tqdm
import math
import numpy as np
import time
import torch
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image

# Restrict this process to GPU 0. The variable is set before the project
# modules below are imported and before the model is moved to CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from data.dataset import (
    get_dataset,
    SeqMaskDataset,
    LVISDataset,
    V3DetDataset,
    COCODataset,
    VisualGenomeDataset,
)
from model import Seq2SeqAutoEncoderModel

# Checkpoint directory of the trained auto-encoder that this notebook inspects.
model_dir = '/home/dchenbs/workspace/Seq2Seq-AutoEncoder/runs/Nov14_17-31-06_host19-SA1B-[327MB-16queries-1024]-[lr1e-05-bs16x1step-8gpu]/checkpoints/checkpoint_ep0_step350k'

# Load the weights, then move the model to the GPU and switch it to eval mode.
model = Seq2SeqAutoEncoderModel.from_pretrained(model_dir)
model = model.cuda().eval()

In [None]:
def decode_image_from_data(data, width, height, num_queries, img_channels=3):
    """Rebuild an image from a flattened per-pixel sequence.

    Layout read by this function (element 0 of ``data`` is skipped --
    presumably a start token, TODO confirm against the dataset encoder):

    * columns ``[0, img_channels)``: pixel values, multiplied by 255 here
      (assumes they are normalised to [0, 1]; anything outside is clipped),
    * column ``img_channels``: row-end score; ``> 0.3`` marks the last pixel
      of an image row,
    * column ``img_channels + 1``: is-data score; ``> 0.5`` marks a pixel that
      belongs to the segment. All other pixels are painted white (255).

    Args:
        data: object exposing ``.numpy()`` that yields a 2-D array with at
            least ``img_channels + 2`` columns (e.g. a CPU tensor).
        width, height, num_queries: accepted for interface compatibility;
            not read by this function. The image size is inferred from the
            row-end markers instead.
        img_channels: number of leading pixel-value columns (default 3).

    Returns:
        ``(image, is_data_seq, shape_encoding_seq)`` where ``image`` is the
        result of ``transforms.ToPILImage()`` and the two boolean arrays are
        the thresholded sequences truncated after the last row-end marker.
        If no row-end marker is present, a 1x1 white image and two empty
        boolean arrays are returned instead of raising ``IndexError``.
    """
    data = data.numpy()
    body = data[1:]
    segment_data = body[:, :img_channels]
    shape_encoding_seq = body[:, img_channels] > 0.3
    is_data_seq = body[:, img_channels + 1] > 0.5

    # Positions of the last pixel of every row.
    row_end_indices = np.flatnonzero(shape_encoding_seq)
    if row_end_indices.size == 0:
        # Nothing decodable (e.g. a reconstruction that never emitted a row
        # end): return a placeholder rather than crashing the caller.
        blank = np.full((1, 1, img_channels), 255, dtype=np.uint8)
        return (
            transforms.ToPILImage()(blank),
            np.array(is_data_seq[:0]),
            np.array(shape_encoding_seq[:0]),
        )

    # Everything after the last row end is padding and is discarded.
    seq_len = row_end_indices[-1] + 1
    shape_encoding_seq = shape_encoding_seq[:seq_len]
    is_data_seq = is_data_seq[:seq_len]
    # Clip so that out-of-range values saturate instead of being mangled by
    # the uint8 cast below.
    segment_data = np.clip(segment_data[:seq_len] * 255, 0, 255)

    # Length of each row. Prepending -1 (not 0) makes the first row count its
    # own first pixel, so no "+1 then crop the last column" workaround is
    # needed and the last pixel of the longest row is never dropped.
    row_lengths = np.diff(row_end_indices, prepend=-1)
    row_starts = row_end_indices - row_lengths + 1
    height_decoded = row_lengths.size
    width_decoded = int(row_lengths.max())

    segment = np.zeros((height_decoded, width_decoded, img_channels))
    mask = np.zeros((height_decoded, width_decoded), dtype=bool)
    for row_id, (start, length) in enumerate(zip(row_starts, row_lengths)):
        stop = start + length
        segment[row_id, :length] = segment_data[start:stop]
        mask[row_id, :length] = is_data_seq[start:stop]

    # Pixels outside the segment (and the padding of short rows) become white.
    segment[~mask] = 255
    segment = transforms.ToPILImage()(segment.astype(np.uint8))

    return segment, np.array(is_data_seq), np.array(shape_encoding_seq)


def visualize_segments(sample_info, original_segment, reconstructed_segment):
    """Plot an original segment next to its reconstruction.

    When ``sample_info`` has an ``'image_path'`` entry, a three-panel figure
    is produced: the source image with the ``sample_info['bbox']`` rectangle
    (x, y, w, h) outlined in red, the original segment, and the
    reconstruction; the figure title is ``sample_info['name']``. Otherwise
    only the two segments are shown side by side.

    Returns:
        The matplotlib figure that was created.
    """
    has_source_image = 'image_path' in sample_info.keys()

    if not has_source_image:
        # No source image available: just the two segments.
        fig, panels = plt.subplots(1, 2)
        panels[0].imshow(original_segment)
        panels[1].imshow(reconstructed_segment)
        return fig

    fig, panels = plt.subplots(1, 3)
    fig.set_size_inches(12, 4)
    fig.suptitle(sample_info['name'])

    source_panel, original_panel, reconstructed_panel = panels

    # Source image with the segment's bounding box drawn on top.
    source_panel.imshow(Image.open(sample_info['image_path']))
    left, top, box_w, box_h = sample_info['bbox']
    outline = plt.Rectangle((left, top), box_w, box_h, fill=False, color='red')
    source_panel.add_patch(outline)

    original_panel.imshow(original_segment)
    reconstructed_panel.imshow(reconstructed_segment)
    return fig

In [None]:
def _as_seq_mask_dataset(source_dataset):
    """Wrap ``source_dataset`` in a SeqMaskDataset configured from the loaded model.

    Every evaluation dataset in this cell uses the same wrapper settings:
    ``num_queries`` and ``data_seq_length`` are read from ``model.config``;
    ``virtual_dataset_size`` and ``min_resize_ratio`` are fixed.
    """
    return SeqMaskDataset(
        dataset=source_dataset,
        num_queries=model.config.num_queries,
        virtual_dataset_size=244707,
        data_seq_length=model.config.data_seq_length,
        min_resize_ratio=1,
    )


# Dataset locations on disk.
coco_root = '/home/dchenbs/workspace/datasets/coco2017'
lvis_root = '/home/dchenbs/workspace/datasets/lvis'
v3det_root = '/home/dchenbs/workspace/datasets/v3det'
visual_genome_root = '/home/dchenbs/workspace/datasets/VisualGenome'

# Validation splits of the four datasets, all wrapped identically.
coco_dataset = _as_seq_mask_dataset(
    COCODataset(coco_root=coco_root, split='val')
)
lvis_dataset = _as_seq_mask_dataset(
    LVISDataset(lvis_root=lvis_root, coco_root=coco_root, split='val')
)
v3det_dataset = _as_seq_mask_dataset(
    V3DetDataset(v3det_root=v3det_root, split='val')
)
visual_genome_dataset = _as_seq_mask_dataset(
    VisualGenomeDataset(visual_genome_root=visual_genome_root, split='val')
)

# Qualitative check: for every dataset, encode a few random samples, generate
# their reconstruction from the latents, and plot original vs. reconstruction.
num_rounds = 1   # batches visualised per dataset
batch_size = 1   # samples per batch

for dataset in [coco_dataset, lvis_dataset, v3det_dataset, visual_genome_dataset]:
    for _ in range(num_rounds):
        batch_data = []
        batch_sample_info = []
        for _ in range(batch_size):
            # One random sample per batch slot. The previous version loaded 50
            # samples here and kept only the last one, which only cost time.
            index = np.random.randint(0, len(dataset))
            this_data, this_sample_info = dataset[index]
            batch_data.append(this_data)
            batch_sample_info.append(this_sample_info)

        batch_data = torch.stack(batch_data).cuda()

        # Inference only: no autograd graph is needed, and the outputs are
        # converted with .numpy() below, which requires grad-free tensors.
        with torch.no_grad():
            batch_latents = model.encode(batch_data)
            batch_reconstructed = model.generate(batch_latents, show_progress_bar=True)

        for data, reconstructed, sample_info in zip(batch_data, batch_reconstructed, batch_sample_info):
            original_segment, original_is_data, original_shape_encoding = decode_image_from_data(
                data.cpu(),
                sample_info['width'],
                sample_info['height'],
                dataset.num_queries,
                img_channels=dataset.img_channels,
            )
            reconstructed_segment, reconstructed_is_data, reconstructed_shape_encoding = decode_image_from_data(
                reconstructed.cpu(),
                sample_info['width'],
                sample_info['height'],
                dataset.num_queries,
                img_channels=dataset.img_channels,
            )

            fig = visualize_segments(sample_info, original_segment, reconstructed_segment)
            print(f"[{dataset.dataset.dataset_name}]: {sample_info['caption']}")

            plt.show()
            # Release the figure so repeated runs do not accumulate open figures.
            plt.close(fig)
        