In [1]:
import sys
sys.path.append('../../30_data_tools/')

In [2]:
from pathlib import Path
from tqdm import tqdm
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from PIL import Image
import numpy as np
import pickle
import random
import cv2

import pandas as pd
import sqlite3
from time import time

In [3]:
from mask_functions import get_config, load_mask_img, is_below_max_size, is_above_min_size, is_text_mask, filter_intersected_masks, save_masks
from helper import load_dotenv

In [4]:
from PIL import Image
Image.MAX_IMAGE_PIXELS = None

In [5]:
config = get_config()
config['target_variant'] = 'halftone600dpi'

In [6]:
dotenv = load_dotenv()

In [7]:
con = sqlite3.connect(dotenv['DB_PATH'])

In [8]:
remaining_pages = pd.read_sql(
    f'''
        SELECT cf.* FROM (
        	SELECT * FROM related_file
        	WHERE variant_name = '{ config["target_variant"] }' AND "type" = '4c'
        ) cf
        LEFT JOIN (
        	SELECT job, pdf_filename, 1 AS has_mask FROM related_file 
        	WHERE variant_name = '{ config["target_variant"] }' AND "type" = 'masks'
        ) mf ON cf.job=mf.job AND cf.pdf_filename=mf.pdf_filename 
        WHERE mf.has_mask IS NULL
    ''',
    con
).sample(frac=1)

In [9]:
import torch

In [10]:
device = torch.device('mps')

In [11]:
sam = sam_model_registry["vit_h"](checkpoint=dotenv['MODEL_DIR'] / "sam_vit_h_4b8939.pth")
#sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

In [12]:
relevant_indexes = []

for i in tqdm(range(remaining_pages.shape[0])):
    row = remaining_pages.iloc[i]
    img_path = dotenv['DATA_DIR'] / row['job'] / row['variant_name'] / row['filename']
    
    mask_path = img_path.parent / f'{ img_path.name.strip( "." + row["type"] + img_path.suffix ) }.masks.pkl'
    mask_path_300dpi = dotenv['DATA_DIR'] / row['job'] / 'halftone300dpi' / mask_path.name

    if mask_path.exists() == False and mask_path_300dpi.exists() == False:
        relevant_indexes.append(row.name)

100%|██████████████████████████████████| 1083/1083 [00:00<00:00, 24852.59it/s]


In [13]:
f"{remaining_pages.shape[0] - len(relevant_indexes)}/{remaining_pages.shape[0]}; { len(relevant_indexes) } verbleibend"

'956/1083; 127 verbleibend'

In [14]:
0 / 0

ZeroDivisionError: division by zero

In [None]:
time_rows = []

for idx in tqdm(relevant_indexes):
    row = remaining_pages.loc[idx]
    img_path = dotenv['DATA_DIR'] / row['job'] / row['variant_name'] / row['filename']
    
    mask_path = img_path.parent / f'{ img_path.name.strip( "." + row["type"] + img_path.suffix ) }.masks.pkl'
    mask_path_300dpi = dotenv['DATA_DIR'] / row['job'] / 'halftone300dpi' / mask_path.name

    if mask_path_300dpi.exists() == False:
        times = []
        img = Image.open(img_path)
        orig_size = img.size
        img = img.resize((
            int( round(img.size[0] * config['MASK_IMG_SCALE_FACTOR']) ),
            int( round(img.size[1] * config['MASK_IMG_SCALE_FACTOR']) )
        ))
        
        start = time()
        times.append(("start",time()))
        masks = mask_generator.generate( np.array(img.convert("RGB")) )
        #print( f"mask generation took: { ( time() - start ) }")
        times.append(("masks generated",time()))
        
        for m in masks:
            factor_x = orig_size[0] / m['segmentation'].shape[1]
            factor_y = orig_size[1] / m['segmentation'].shape[0]
            
            m['segmentation'] = m['segmentation'][
                int(m['bbox'][1]):int(m['bbox'][1]+m['bbox'][3]+1),
                int(m['bbox'][0]):int(m['bbox'][0]+m['bbox'][2]+1)
            ]
            m['bbox'] = [
                int(round(m['bbox'][0] * factor_x)),
                int(round(m['bbox'][1] * factor_y)),
                int(round(m['bbox'][2] * factor_x)),
                int(round(m['bbox'][3] * factor_y))
            ]
            m['point_coords'] = [[
                int(round(m['point_coords'][0][0] * factor_x)),
                int(round(m['point_coords'][0][1] * factor_y)),
            ]]
    
            m['crop_box'] = [
                int(round(m['crop_box'][0] * factor_x)),
                int(round(m['crop_box'][1] * factor_y)),
                int(round(m['crop_box'][2] * factor_x)),
                int(round(m['crop_box'][3] * factor_y))
            ]

        masks = [m for m in masks if m['area'] < (img.size[0] * img.size[1] * 0.25)]
        times.append(("area filtered",time()))
        
        masks_out = []

        for m in masks:
            mask_out = {
                'mask' : m['segmentation'],
                'bbox' : m['bbox'],
                'predicted_iou' : m['predicted_iou'],
                'stability_score' : m['stability_score'],
                'img_size' : orig_size
            }
            mask_out['mask'] = load_mask_img(mask_out)
            masks_out.append(mask_out)

        
        masks_out = [m for m in masks_out if is_below_max_size(m)]
        # filter by size
        masks_out = [m for m in masks_out if is_above_min_size(m)]
        times.append(("size filtered",time()))
        # filter by text box
        masks_out = [m for m in masks_out if is_text_mask( img, m ) == False]
        times.append(("text mask",time()))
        # filter duplicates
        masks_out = filter_intersected_masks( masks_out )
        times.append(("intersected",time()))
        time_rows.append(times)
        
        save_masks( masks_out, mask_path )

In [None]:
times = []

for i in range(len(time_rows)):
    times += [(i+1, t[0], t[1] - time_rows[i][0][1]) for t in time_rows[i]]

time_df = pd.DataFrame(
    times,
    columns=['run','step','duration']
).set_index(['run','step'])

time_df = time_df.unstack('step')
time_df.columns = [c[1] for c in time_df.columns]

In [None]:
import plotly.express as px

In [None]:
px.bar(time_df.intersected)

In [None]:
time_df.intersected.describe()

In [None]:
time_df.intersected - time_df['text mask']