In [None]:
import argparse
import os
import json
import tqdm
import math
import numpy as np
import random
import time
import torch
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image
import pprint
import pandas as pd
import seaborn as sns
import cv2

# Seed the global RNGs (stdlib `random` and legacy NumPy) so that any random
# sampling performed by the dataset loaders below is repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

from data.segmentation_dataset import SA1BDataset, COCODataset, LVISDataset, SeqMaskDataset, V3DetDataset, VisualGenomeDataset


In [None]:
# Root folder that contains every dataset directory. Set the
# SEGMENTATION_DATASETS_ROOT environment variable to point elsewhere instead of
# editing the hard-coded paths; the default keeps the original location.
DATASETS_ROOT = os.environ.get('SEGMENTATION_DATASETS_ROOT', '/home/dchenbs/workspace/datasets')

sa1b = SA1BDataset(sa1b_root=os.path.join(DATASETS_ROOT, 'sa1b'))
# Further datasets; uncomment (and add to `datasets`) to include them below.
# coco = COCODataset(coco_root=os.path.join(DATASETS_ROOT, 'coco2017'), split='train')
# lvis = LVISDataset(lvis_root=os.path.join(DATASETS_ROOT, 'lvis'), coco_root=os.path.join(DATASETS_ROOT, 'coco2017'), split='train')
# v3det = V3DetDataset(v3det_root=os.path.join(DATASETS_ROOT, 'v3det'), split='train')
# visual_genome = VisualGenomeDataset(visual_genome_root=os.path.join(DATASETS_ROOT, 'VisualGenome'), split='train')

# datasets = [v3det, visual_genome, sa1b, coco, lvis]
datasets = [sa1b]

In [None]:
# Visually inspect segments whose bounding boxes are extremely tall and thin.
segment_per_dataset = 32    # number of example segments to show per dataset
MAX_ASPECT_RATIO = 1 / 8    # keep boxes with width / height below this
MIN_BBOX_AREA = 10000       # ...and with width * height above this (pixels^2)


def show_segment_in_context(segment, title=None):
    """Plot a segment's full image with its bbox drawn, next to the masked patch.

    Parameters
    ----------
    segment : dict
        One item returned by ``dataset.load_segments_from_one_image()``. Reads
        the 'image_path', 'bbox' (x, y, w, h), 'patch' and 'mask' keys.
    title : str, optional
        Figure title (e.g. the dataset name).
    """
    # cv2.rectangle only accepts integer pixel coordinates; bboxes may be floats.
    x, y, w, h = (int(v) for v in segment['bbox'][:4])
    image = np.array(Image.open(segment['image_path']).convert('RGB'))
    image = cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 8)

    # Work on a copy so whitening the background does not mutate the patch
    # stored in the segment dict handed out by the dataset.
    segment_image = np.array(segment['patch'])
    segment_image[np.asarray(segment['mask']) == 0] = 255

    fig, (ax_full, ax_patch) = plt.subplots(1, 2, figsize=(10, 5))
    ax_full.imshow(image)
    ax_full.axis('off')
    ax_patch.imshow(segment_image)
    ax_patch.axis('off')
    if title:
        fig.suptitle(title)
    plt.show()
    plt.close(fig)


for dataset in datasets:
    num_segments = 0
    while num_segments < segment_per_dataset:
        for segment in dataset.load_segments_from_one_image():
            width = segment['bbox'][2]
            height = segment['bbox'][3]
            # Degenerate boxes (zero height) would raise ZeroDivisionError below.
            if height <= 0:
                continue
            # For very wide boxes instead use: width / height > 1 / MAX_ASPECT_RATIO
            if width / height < MAX_ASPECT_RATIO and width * height > MIN_BBOX_AREA:
                show_segment_in_context(segment, title=dataset.dataset_name)
                num_segments += 1
                break  # at most one example per image, for variety

In [None]:
# Collect bbox widths and heights from a large sample of segments per dataset.
# Uses the `tqdm` module imported at the top rather than `from tqdm import tqdm`,
# which would shadow that module name for the rest of the session.
segment_per_dataset = 200000
widths_and_heights = {}
for dataset in datasets:
    widths = []
    heights = []
    # The context manager closes the bar even if loading raises mid-way.
    with tqdm.tqdm(total=segment_per_dataset, desc=dataset.dataset_name) as pbar:
        # Whole images are consumed, so slightly more than
        # `segment_per_dataset` segments may be collected.
        while len(widths) < segment_per_dataset:
            segments = dataset.load_segments_from_one_image()
            for segment in segments:
                widths.append(segment['bbox'][2])
                heights.append(segment['bbox'][3])
            pbar.update(len(segments))

    widths = np.array(widths)
    heights = np.array(heights)
    widths_and_heights[dataset.dataset_name] = (widths, heights)

In [None]:
# Joint distribution of bbox width vs. height, with aspect-ratio guide lines
# and an iso-area curve.
sns.set_theme(style="darkgrid")
limit = 300        # axis range (pixels) shown for both width and height
fig_height = 8     # size in inches of each square joint plot
# Sequential palette; indices 1, 2 and 4 colour the aspect-ratio guide lines.
colors = sns.color_palette('ch:s=.25,rot=-.25')

for dataset in datasets[:3]:
    widths, heights = widths_and_heights[dataset.dataset_name]

    # jointplot builds its own figure; calling plt.figure() first only produced
    # an extra, empty figure per dataset.
    grid = sns.jointplot(
        x=widths, y=heights,
        xlim=(0, limit), ylim=(0, limit),
        s=8, alpha=0.1, linewidth=0,
        height=fig_height,
    )
    ax = grid.ax_joint
    ax.figure.suptitle(dataset.dataset_name, y=1.02)

    # Guide lines through the origin: height = width (aspect ratio 1), and
    # height = k * width / height = width / k for k in {2, 4}.
    ax.plot([0, limit], [0, limit], linestyle='-', alpha=0.8, label='Aspect Ratio = 1', color=colors[1], linewidth=2)
    for i in range(1, 3):
        slope = 2 ** i
        color = colors[i * 2]
        ax.plot([0, limit], [0, limit * slope], linestyle='-', alpha=0.8, color=color, linewidth=2, label=f'Aspect Ratio = {slope}')
        ax.plot([0, limit], [0, limit / slope], linestyle='-', alpha=0.8, color=color, linewidth=2)

    # Iso-area curve width * height = 1024 (a 32x32 box), i.e. y = 1024 / x.
    x = np.linspace(1, limit, 1000)
    ax.plot(x, 1024 / x, color='red', label='Width * Height = 1024', linewidth=1)

    ax.legend(loc='upper right')
    ax.set_xlabel('Width')
    ax.set_ylabel('Height')
    ax.set_xlim(0, limit)
    ax.set_ylim(0, limit)
    plt.show()