In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('../..'))

import importlib

import contextlib

import json
import os
import random

import numpy as np
import torch
import pandas_streaming
from collections import defaultdict
from fileutils import get_default_path


from automation.automationmanager import make_default_manager
from automation.crawler.midjourneycrawl import crawl_gallery_user, crawl_gallery_feed
from automation.midjourney.midjourneyutils import FeedType

from storage.data.image.remoteimageinfo import imgs_to_cmds
from storage.data.image.crawledimagegroups import CrawledImageGroups
from storage.data.command import Command
from storage.data.command.commandbuilder import CommandBuilder
from storage.data.user.userids import MJ_USER_TO_ID
from storage.data.user.mjuser import MJUser
import time
from util import Stopwatch
import datetime

import ai.stabledisco as sd
import ai.torchmodules as torchmodules
import ai.torchmodules.data as torchdata
import ai.torchmodules.utils as torchutils
import ai.stabledisco.utils as sdutils
import clip
import ai.nlp
import torch
import torch.nn as nn
import pandas as pd
from storage.data.command.stablediscoprompt import arg_prompt_split

In [2]:
df_path = get_default_path("large_datasets", "aug_prompts.feather")
# Release any previously loaded frame before reading the file again
# (presumably to avoid holding two large copies in memory at once).
globals().pop("prompt_dataframe", None)
prompt_dataframe = pd.read_feather(df_path)

In [None]:
# Persist the current frame back to the feather dataset file.
prompt_dataframe.to_feather(path=df_path)

In [None]:
# Shrink stored CLIP token arrays to uint16, printing the number of rows still
# ahead every 250k rows as a progress indicator.
for row in prompt_dataframe.itertuples():
    idx = row[0]
    if idx and idx % 250000 == 0:
        print(len(prompt_dataframe) - idx)
    if row.text_tokens is not None:
        current_tokens = prompt_dataframe.at[idx, "text_tokens"]
        prompt_dataframe.at[idx, "text_tokens"] = current_tokens.astype(np.uint16)

In [None]:
device = torchutils.get_default_device()
# clip.load returns (model, preprocessing transform). Later cells refer to the
# transform both as `preprocessor` and as `vit14_clip_preprocess`, so bind both names.
# NOTE(review): `device` is not passed to clip.load, which chooses its own default device.
vit14_clip_model, preprocessor = clip.load('ViT-L/14')
vit14_clip_preprocess = preprocessor

In [None]:
import torchvision.transforms as T
# Import the submodule explicitly: `import PIL` alone does not guarantee that
# `PIL.Image` is loaded (it only worked because torchvision had imported it).
import PIL.Image

img_to_tensor = T.ToTensor()


def imgs_to_tensor(imgs):
    """Convert one PIL image, or an iterable of same-sized PIL images, to a CUDA tensor.

    A single image gives one ToTensor result; an iterable is stacked along a new
    leading batch dimension (torch.stack requires all images to share a size).
    """
    if isinstance(imgs, PIL.Image.Image):
        return img_to_tensor(imgs).cuda()
    return torch.stack([img_to_tensor(img) for img in imgs]).cuda()


# Tensor-side CLIP input pipeline: resize and center-crop to the model's input
# resolution, then normalize with the per-channel mean/std constants below.
preprocess = T.Compose([
        T.Resize(vit14_clip_model.visual.input_resolution, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(vit14_clip_model.visual.input_resolution),
        T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

In [None]:
# Rows still missing image features, then rows that already have them.
missing_mask = prompt_dataframe["img_features"].isna()
print(int(missing_mask.sum()))
print(int((~missing_mask).sum()))
del missing_mask

In [None]:
PARTIAL_FEATURES_COL = "partial_img_features"
# (Re)initialise the column to None for every row; the extraction loop selects
# rows whose value is None.
prompt_dataframe[PARTIAL_FEATURES_COL] = None

In [None]:
import contextlib
# Toggle: True -> store per-token transformer outputs in "partial_img_features";
# False -> store encode_image embeddings in "img_features".
partial = False
col = "partial_img_features" if partial else "img_features"

def save_features(data_frame, idx_paths, partial=False):
    """Compute CLIP image features for a batch of rows and store them in ``data_frame``.

    Args:
        data_frame: frame updated in place via ``.at``.
        idx_paths: list of ``(row_label, image_path)`` pairs.
        partial: when True, store ``get_partial_features`` output in
            "partial_img_features"; otherwise store ``get_img_features`` output
            in "img_features".

    Returns:
        Number of rows whose features were written. This is 0 for empty input, so
        callers can always do ``written += save_features(...)``. Images that fail
        to load are skipped and not counted.
    """
    if len(idx_paths) == 0:
        print("No paths")
        return 0

    # imgs_to_tensor stacks a batch into one tensor, which requires identical
    # image sizes, so group the loaded images by size.
    imgs_by_size = defaultdict(list)
    skipped = 0
    for idx, img_path in idx_paths:
        try:
            img = sdutils.load_img(img_path)
        except Exception:
            # Best-effort: skip unreadable images, but no longer swallow
            # KeyboardInterrupt/SystemExit and report how many were skipped.
            skipped += 1
            continue
        imgs_by_size[img.size].append((idx, img))
    if skipped:
        print(f"Skipped {skipped} unreadable images")

    col = "partial_img_features" if partial else "img_features"
    extract = get_partial_features if partial else get_img_features

    written = 0
    with torch.no_grad():
        for entries in imgs_by_size.values():
            idxs = [entry[0] for entry in entries]
            features = extract([entry[1] for entry in entries])
            for idx, feature in zip(idxs, features):
                data_frame.at[idx, col] = feature
            del features
            written += len(entries)

    del imgs_by_size
    return written
    
def get_img_features(imgs):
    """Encode a batch of same-sized PIL images with ``vit14_clip_model.encode_image``.

    Images go through ``imgs_to_tensor`` and ``preprocess`` first. The embeddings
    are returned unnormalized as nested Python lists (one entry per image).
    """
    batch = preprocess(imgs_to_tensor(imgs))
    embeddings = vit14_clip_model.encode_image(batch)
    return embeddings.cpu().numpy().tolist()

def get_partial_features(imgs):
    """Run the ViT image tower of ``vit14_clip_model`` through its transformer blocks.

    The input is converted to half precision and preprocessed. It then goes
    through: patch conv, flatten to a token sequence, a prepended class token,
    positional embeddings, ``ln_pre`` and the transformer. The per-token outputs
    (class token first) are returned as nested lists. No post-transformer
    pooling or projection is applied here.
    """
    visual = vit14_clip_model.visual
    hidden = preprocess(imgs_to_tensor(imgs).half())

    hidden = visual.conv1(hidden)  # [*, width, grid, grid]
    # [*, width, grid ** 2] -> [*, grid ** 2, width]
    hidden = hidden.reshape(hidden.shape[0], hidden.shape[1], -1).permute(0, 2, 1)

    cls_token = visual.class_embedding.to(hidden.dtype) + torch.zeros(
        hidden.shape[0], 1, hidden.shape[-1], dtype=hidden.dtype, device=hidden.device
    )
    hidden = torch.cat([cls_token, hidden], dim=1)
    hidden = hidden + visual.positional_embedding.to(hidden.dtype)
    hidden = visual.ln_pre(hidden)

    # The transformer works sequence-first: NLD -> LND, then back to NLD.
    hidden = visual.transformer(hidden.permute(1, 0, 2)).permute(1, 0, 2)
    return hidden.cpu().numpy().tolist()
    

def _report_progress(batch_len, elapsed):
    """Print per-row timing for the last batch and an ETA for rows whose ``col`` is still missing."""
    per_entry = elapsed / batch_len
    rem_rows = int(prompt_dataframe[col].isna().sum())
    print(f"Time per entry {per_entry}. Rem time {per_entry*rem_rows/60} minutes for {rem_rows} rows")
    print(f"Completed rows {len(prompt_dataframe) - rem_rows}")


def _flush(pending):
    """Extract features for ``pending`` (label, path) pairs, free GPU memory, return rows written."""
    # Forward `partial` so partial runs fill "partial_img_features" instead of "img_features".
    count = save_features(prompt_dataframe, pending, partial=partial)
    torchutils.torch_garbage_collect()
    return count


imgs_paths_to_process = []
per_df_write = 1000
start = time.perf_counter()
written = 0
per_save = 20000
torchutils.torch_garbage_collect()
for row in prompt_dataframe.itertuples():
    idx = row[0]

    # Only rows whose target column is still empty need processing.
    existing = row.partial_img_features if partial else row.img_features
    if existing is None:
        imgs_paths_to_process.append((idx, row.img_path))

    if len(imgs_paths_to_process) >= per_df_write:
        batch_len = len(imgs_paths_to_process)
        written += _flush(imgs_paths_to_process)
        imgs_paths_to_process = []

        if written >= per_save:
            print(f"Writting {written}")
            written = 0
            prompt_dataframe.to_feather(df_path)

        end = time.perf_counter()
        _report_progress(batch_len, end - start)
        start = end

# The loop only flushes full batches; the last partial batch is handled here.
if imgs_paths_to_process:
    batch_len = len(imgs_paths_to_process)
    written += _flush(imgs_paths_to_process)
    imgs_paths_to_process = []
    end = time.perf_counter()
    _report_progress(batch_len, end - start)
    start = end

# Persist everything written since the last periodic save.
if written:
    print(f"Writting {written}")
    written = 0
    prompt_dataframe.to_feather(df_path)
print("done")

In [7]:
# Drop exact-duplicate prompts, then shuffle rows. The seed makes the resulting
# row order (and any position-based split derived from it) reproducible.
SHUFFLE_SEED = 0
print(len(prompt_dataframe))
prompt_dataframe = (
    prompt_dataframe
    .drop_duplicates(subset="prompt", ignore_index=True)
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)
len(prompt_dataframe)

16798194


16793602

In [9]:
# Write the deduplicated, shuffled frame back to the feather dataset file.
prompt_dataframe.to_feather(path=df_path)

In [8]:
print(prompt_dataframe.shape[0])

16793602


In [None]:
# Shrink numeric columns to the smallest dtype that holds their values.
prompt_dataframe["cfg"] = pd.to_numeric(prompt_dataframe["cfg"], downcast="float")
for int_column in ("width", "height", "num_imgs", "steps", "seed"):
    prompt_dataframe[int_column] = pd.to_numeric(prompt_dataframe[int_column], downcast="integer")

In [None]:
# NOTE(review): astype(str) stringifies missing values (e.g. None -> "None");
# confirm that is acceptable for these columns.
for text_column in ("prompt", "sampler", "user"):
    prompt_dataframe[text_column] = prompt_dataframe[text_column].astype(str)

In [None]:
# Re-resolve the dataset location, then persist the frame there as feather.
df_path = get_default_path("large_datasets", "aug_prompts.feather")
prompt_dataframe.to_feather(path=df_path)

In [10]:
# Label-based slice: rows 0..1000 inclusive.
prompt_dataframe.loc[:1000, "prompt"].tolist()

['a drawing of beethoven',
 'breathtaking detailed painting of a beautiful side portrait of zendaya in a white cat sitting with her hands, cloak, charming, dnd character art by artgerm and greg rutkowski and artgerm and liam brazier and victo ngai and tristan eaton',
 'farm girl with goats and lambs   is this not just too adorable and so retro - love it',
 'arreglos de mickey mouse | decoracion con globos de mickey mouse-3 ...',
 "these greek salad hummus pita pizzas are the perfect summertime appetizer or main course meal. olive oil grilled pita topped with spicy hummus and fresh greek salad. you're going to love this!",
 'william robinson leigh poster featuring the painting zuni pottery maker by william robinson leigh',
 'interesting 60 containers home on dwts',
 'easy wedding cake cupcakes cupcake elegance wedding cakes glass house mountains 13836',
 'passing through the ethereal portal to the iris van herpen doom experience, majestic dream in the space ',
 'image for tropical beach

In [None]:
# Per-column memory footprint, including the contents of object columns.
prompt_dataframe.memory_usage(index=True, deep=True)

In [None]:
# WARNING: discards any previously computed tokens/features for every row.
feature_columns = ("text_tokens", "img_features")
for column in feature_columns:
    prompt_dataframe[column] = None
for column in feature_columns:
    prompt_dataframe[column] = prompt_dataframe[column].astype(object)

In [None]:
step_size = 100
# Number of batches to time before stopping (the original benchmark stopped
# after the first batch); set to None to process every row.
max_batches = 1


def load_img_for_row(path, width, height):
    """Load ``path``, crop it from the top-left to (width, height) if its size differs, and preprocess it.

    ``preprocessor`` is the transform returned by ``clip.load`` in the model-loading cell.
    """
    img = sdutils.load_img(path)
    # Crop only when the size differs (cropping to the current size is a no-op).
    if img.size != (width, height):
        img = img.crop([0, 0, width, height])
    return preprocessor(img)


def get_img_features(imgs_paths, widths, heights):
    """Encode a batch of images with ViT-L/14 and return L2-normalized embeddings as nested lists."""
    preprocessed_imgs = torch.stack([
        load_img_for_row(img_path, width, height)
        for img_path, width, height in zip(imgs_paths, widths, heights)
    ]).cuda()
    encoded_imgs = vit14_clip_model.encode_image(preprocessed_imgs).float()
    return (encoded_imgs / encoded_imgs.norm(dim=-1, keepdim=True)).cpu().numpy().tolist()


def get_text_tokens(prompts):
    """Tokenize prompts with CLIP's tokenizer (truncate=True) and return nested lists."""
    return clip.tokenize(list(prompts), truncate=True).cpu().numpy().tolist()


start_time = time.perf_counter()
prompts = prompt_dataframe["prompt"]
img_paths = prompt_dataframe["img_path"]
widths = prompt_dataframe["width"]
heights = prompt_dataframe["height"]

tokens = []
img_features = []

with torch.no_grad():
    for batch_num, batch_start in enumerate(range(0, len(prompt_dataframe), step_size)):
        if max_batches is not None and batch_num >= max_batches:
            break
        batch_end = batch_start + step_size
        tokens += get_text_tokens(prompts.iloc[batch_start:batch_end])
        img_features += get_img_features(img_paths.iloc[batch_start:batch_end],
                                         widths.iloc[batch_start:batch_end],
                                         heights.iloc[batch_start:batch_end])

# Estimate from the rows actually processed instead of a stale global `idx`.
processed = len(tokens)
end = time.perf_counter()
per_datum = (end - start_time) / max(processed, 1)
print(f"Time per datum: {per_datum}")
print(f"Rem time: {per_datum * (len(prompt_dataframe) - processed)/60} minutes")


In [None]:
# Column dtypes after the conversions above.
prompt_dataframe.dtypes.copy()

In [3]:
from util import ReTerm
import re
def arg_prompt_split(command):
    """Split a Stable Diffusion command into prompt text and recognized CLI arguments.

    Argument spans are located with ``ReTerm.sd_args_regex`` and removed from the
    command. Recognized flags (long or short form, see ``command_mapper``) are
    normalized to their short form. Unrecognized flags, and value-taking flags
    with no value, are appended back onto the prompt (they are usually typos that
    belong to the prompt text).

    Args:
        command: raw command string.

    Returns:
        ``(prompt, args)``: the lower-cased prompt text, and a dict mapping short
        flags (e.g. ``"-W"``) to string values (``""`` for flags in ``no_value``).
    """
    raw_args = [match.group(0).strip() for match in ReTerm.sd_args_regex.finditer(command)]

    # The args are literal text. Escape them so characters like "." or "(" are not
    # interpreted as regex syntax (wrong matches or re.error). Skip empty spans:
    # an empty alternative matches at every position and would block all removal.
    removal_pattern = '|'.join(re.escape(arg) for arg in raw_args if arg)
    prompt = re.sub(removal_pattern, "", command).strip() if removal_pattern else command.strip()

    command_mapper = {
        '-width': "-W",
        "-height": "-H",
        "-cfg": "-C",
        "-cfg_scale": "-C",
        "-seed": "-S",
        "-steps": "-s",
        "-sampler": "-A",
        "-prior": "-p",
        "-ascii": "-a",
        "-separate-images": "-i",
        "-grid": "-g",
        "-number": "-n",
        "-tokenize": "-t"
    }
    # Flags that are valid without a value.
    no_value = {"-a", "-t", "-i", "-g", "-p"}

    recognized_args = command_mapper.keys() | command_mapper.values()

    def split_key_val(arg):
        """Split one arg into (normalized flag, value); value is "" when absent or ambiguous."""
        split = [part.strip() for part in re.split('[ =]', arg.strip())]
        split[0] = re.sub(" +", " ", split[0])
        # Normalize en dashes and double dashes to a single "-".
        split[0] = re.sub("–", "-", split[0])
        split[0] = re.sub("--", "-", split[0]).strip()

        # Handle values glued onto the flag, e.g. "-W512" -> ("-W", "512").
        without_nums = re.sub(r"\d+", "", split[0])
        if without_nums in recognized_args and without_nums != split[0]:
            nums = re.findall(r"\d+", split[0])
            split.append(nums[0])
            split[0] = without_nums

        if split[0].lower() in command_mapper:
            split[0] = command_mapper[split[0].lower()]

        if len(split) != 2:
            return split[0], ""

        return split[0], split[1]

    kept_pairs = []
    for raw_arg, (key, val) in zip(raw_args, (split_key_val(arg) for arg in raw_args)):
        if key not in recognized_args or (not val and key not in no_value):
            # Re-attach the original arg text to the prompt. This is a common typo
            # which affects the tokens.
            prompt += f" {raw_arg}"
        else:
            kept_pairs.append((key, val))

    return prompt.lower(), dict(kept_pairs)

def _first_number(value, cast):
    """Return the first (optionally negative) number in ``value`` converted with ``cast``, or None."""
    pattern = r"-?\d+\.?\d*" if cast is float else r"-?\d+"
    match = re.search(pattern, str(value))
    return cast(match.group()) if match else None


def fix_row_args(prompt_dataframe, idx, split_prompt, args, apply_args=False):
    """Store a cleaned prompt and its CLIP tokens for row ``idx``; optionally apply parsed args.

    Args:
        prompt_dataframe: frame updated in place via ``.at``.
        idx: row label.
        split_prompt: prompt text with CLI arguments removed (from ``arg_prompt_split``).
        args: dict of short flags to string values (from ``arg_prompt_split``).
        apply_args: when True, also write recognized argument values into their
            columns (seed, steps, cfg, num_imgs, width, height, sampler). The
            default False only updates "prompt" and "text_tokens".
    """
    prompt_dataframe.at[idx, "prompt"] = split_prompt.strip()
    prompt_dataframe.at[idx, "text_tokens"] = clip.tokenize(split_prompt, truncate=True)[0].numpy().astype(np.uint16)
    if not apply_args:
        return

    numeric_args = (
        ("-S", "seed", int),
        ("-s", "steps", int),
        ("-C", "cfg", float),
        ("-n", "num_imgs", int),
        ("-W", "width", int),
        ("-H", "height", int),
    )
    for flag, column, cast in numeric_args:
        if flag not in args:
            continue
        value = _first_number(args[flag], cast)
        if value is None:
            # Unparseable value: leave the column untouched.
            continue
        prompt_dataframe.at[idx, column] = value
        if column in ("width", "height") and value != 512:
            # NOTE(review): stored features are cleared for non-512 sizes so the
            # extraction loop (which selects rows with None features) redoes them.
            prompt_dataframe.at[idx, "img_features"] = None

    if "-A" in args:
        prompt_dataframe.at[idx, "sampler"] = args["-A"]


fix_cnt = 0
cut_cnt = 0  # never incremented: the truncation pass noted below is disabled
to_remove = []
for row in prompt_dataframe.itertuples(index=True):
    idx = row[0]

    if idx != 0 and idx % 100000 == 0:
        print(f"Fixed {fix_cnt} at idx {idx}")
        print(f"Cut {cut_cnt} at idx {idx}")

    prompt = row.prompt

    # Disabled truncation pass (kept for reference):
    #   cut_idx = re.search(r"[` ]*(\*\*WARNING\*\*|The seeds for each individual image are)", row.prompt)
    #   if cut_idx:
    #       cut_idx = cut_idx.span()[0]
    #       cut_cnt += 1
    #       prompt = prompt[:cut_idx]
    #       prompt_dataframe.at[idx, "prompt"] = prompt
    #       prompt_dataframe.at[idx, "text_tokens"] = clip.tokenize(prompt, truncate=True)[0].numpy().astype(np.uint16)
    split_prompt, args = arg_prompt_split(prompt)

    if split_prompt != prompt.strip():
        fix_cnt += 1
        fix_row_args(prompt_dataframe, idx, split_prompt, args)
    elif row.text_tokens is None:
        # fix_row_args already stores tokens of split_prompt for fixed rows, so
        # only rows that were not fixed need their missing tokens filled here.
        prompt_dataframe.at[idx, "text_tokens"] = clip.tokenize(split_prompt, truncate=True)[0].numpy().astype(np.uint16)

print("Cut fixes", cut_cnt)
print(fix_cnt)
print(len(to_remove))

Fixed 105 at idx 100000
Cut 0 at idx 100000
Fixed 257 at idx 200000
Cut 0 at idx 200000
Fixed 357 at idx 300000
Cut 0 at idx 300000
Fixed 468 at idx 400000
Cut 0 at idx 400000
Fixed 567 at idx 500000
Cut 0 at idx 500000
Fixed 679 at idx 600000
Cut 0 at idx 600000
Fixed 794 at idx 700000
Cut 0 at idx 700000
Fixed 895 at idx 800000
Cut 0 at idx 800000
Fixed 1000 at idx 900000
Cut 0 at idx 900000
Fixed 1096 at idx 1000000
Cut 0 at idx 1000000
Fixed 1205 at idx 1100000
Cut 0 at idx 1100000
Fixed 1310 at idx 1200000
Cut 0 at idx 1200000
Fixed 1421 at idx 1300000
Cut 0 at idx 1300000
Fixed 1536 at idx 1400000
Cut 0 at idx 1400000
Fixed 1649 at idx 1500000
Cut 0 at idx 1500000
Fixed 1745 at idx 1600000
Cut 0 at idx 1600000
Fixed 1844 at idx 1700000
Cut 0 at idx 1700000
Fixed 1949 at idx 1800000
Cut 0 at idx 1800000
Fixed 2056 at idx 1900000
Cut 0 at idx 1900000
Fixed 2170 at idx 2000000
Cut 0 at idx 2000000
Fixed 2284 at idx 2100000
Cut 0 at idx 2100000
Fixed 2400 at idx 2200000
Cut 0 at idx 

In [4]:
def calc_jaccard_similarity(text_a, text_b):
    """Jaccard similarity of the case-insensitive, whitespace-split word sets of two texts.

    Returns 0 when both texts contain no words.
    """
    words_a = set(text_a.lower().split())
    words_b = set(text_b.lower().split())
    union_size = len(words_a | words_b)
    return len(words_a & words_b) / union_size if union_size else 0

def _surplus_candidates(candidates, keep):
    """Return row labels to drop from a run of near-duplicates, keeping the ``keep`` longest prompts.

    ``candidates`` is a list of ``(prompt, row_label)`` pairs.
    """
    if len(candidates) <= keep:
        return []
    # Sort by prompt length (the previous key, len(tuple), was always 2 and so a no-op).
    by_length = sorted(candidates, key=lambda cand: len(cand[0]))
    return [cand[1] for cand in by_length[:len(by_length) - keep]]


# Sorting by prompt puts near-identical prompts next to each other.
prompt_dataframe.sort_values('prompt', inplace=True, ascending=False)
prompt_dataframe.reset_index(drop=True, inplace=True)

last_text = ""
threshold = 0.80
to_remove = []
next_remove_candidates = []
max_similar = 2
for cnt, row in enumerate(prompt_dataframe.itertuples(True)):
    if cnt % 100000 == 0:
        print(f"Removed {len(to_remove)} at cnt {cnt}")

    if len(row.prompt) == 0:
        print(row.prompt)
        to_remove.append(row[0])
    if calc_jaccard_similarity(last_text, row.prompt) > threshold:
        next_remove_candidates.append((str(row.prompt), row[0]))
    else:
        # The run of rows similar to `last_text` has ended; trim it.
        to_remove += _surplus_candidates(next_remove_candidates, max_similar)
        next_remove_candidates = []
        last_text = row.prompt

# The final run has no dissimilar row after it to trigger the flush above.
to_remove += _surplus_candidates(next_remove_candidates, max_similar)
next_remove_candidates = []

print(len(to_remove))

Removed 0 at cnt 0
Removed 4 at cnt 100000
Removed 4 at cnt 200000
Removed 4 at cnt 300000
Removed 6 at cnt 400000
Removed 6 at cnt 500000
Removed 6 at cnt 600000
Removed 6 at cnt 700000
Removed 6 at cnt 800000
Removed 6 at cnt 900000
Removed 8 at cnt 1000000
Removed 11 at cnt 1100000
Removed 11 at cnt 1200000
Removed 11 at cnt 1300000
Removed 13 at cnt 1400000
Removed 13 at cnt 1500000
Removed 14 at cnt 1600000
Removed 16 at cnt 1700000
Removed 17 at cnt 1800000
Removed 17 at cnt 1900000
Removed 19 at cnt 2000000
Removed 20 at cnt 2100000
Removed 23 at cnt 2200000
Removed 27 at cnt 2300000
Removed 27 at cnt 2400000
Removed 28 at cnt 2500000
Removed 32 at cnt 2600000
Removed 32 at cnt 2700000
Removed 32 at cnt 2800000
Removed 32 at cnt 2900000
Removed 32 at cnt 3000000
Removed 32 at cnt 3100000
Removed 32 at cnt 3200000
Removed 33 at cnt 3300000
Removed 34 at cnt 3400000
Removed 38 at cnt 3500000
Removed 38 at cnt 3600000
Removed 39 at cnt 3700000
Removed 40 at cnt 3800000
Removed 41 a

In [None]:
# Spot-check the first 100 prompts flagged for removal.
for remove_idx in iter(to_remove[:100]):
    flagged_prompt = prompt_dataframe.at[remove_idx, "prompt"]
    print(flagged_prompt)

In [None]:
print(f"{len(to_remove)}")

In [6]:
print(len(prompt_dataframe))
# to_remove holds positions; translate them to index labels before dropping.
labels_to_drop = prompt_dataframe.index[to_remove]
prompt_dataframe.drop(index=labels_to_drop, inplace=True)
print(len(prompt_dataframe))
del to_remove, labels_to_drop

16798395
16798194


In [None]:
# Micro-benchmark of a single filesystem existence check.
# NOTE(review): hard-coded absolute path; consider deriving it from get_default_path.
%timeit os.path.exists("/home/ubuntu/Main/sd_discord/preprocessed_imgs/0.png")

In [None]:
def load_img_for_row(img_path, width, height):
    """Load ``img_path`` and crop it from the top-left to (width, height) when its size differs."""
    img = sdutils.load_img(img_path)
    if img.size != (width, height):
        img = img.crop([0, 0, width, height])

    return img


def preprocess_img(preprocessor, img_path, width, height):
    """Load and crop one image via ``load_img_for_row``, then apply ``preprocessor`` to it."""
    return preprocessor(load_img_for_row(img_path, width, height))


# Smoke test on the first row, using `preprocessor`, the transform bound by
# clip.load in the model-loading cell.
preprocessed = preprocess_img(preprocessor, prompt_dataframe.at[0, "img_path"],
                              prompt_dataframe.at[0, "width"], prompt_dataframe.at[0, "height"])

In [None]:
print(preprocessed.size())

In [None]:
def load_img_for_row(row):
    """Load ``row.img_path`` and crop it from the top-left to (row.width, row.height) when sizes differ."""
    img = sdutils.load_img(row.img_path)
    if img.size != (row.width, row.height):
        img = img.crop([0, 0, row.width, row.height])

    return img


def _encode_imgs(imgs):
    """Encode each image separately and return the first output of each as a plain list."""
    # NOTE(review): `clip_model` / `encode_image_features` are not defined in this
    # notebook (only `vit14_clip_model` is) — confirm which model object is intended.
    return [clip_model.encode_image_features(img)[0].cpu().numpy().tolist() for img in imgs]


def get_img_features(rows):
    """Load and encode the images referenced by ``rows``."""
    return _encode_imgs([load_img_for_row(row) for row in rows])


def _write_batch(row_labels, texts, batch_imgs):
    """Tokenize and encode one batch, store results in ``prompt_dataframe``, return the batch size."""
    batch_tokens = clip.tokenize(texts, truncate=True).cpu().numpy().tolist()
    img_features = _encode_imgs(batch_imgs)
    for row_idx, text_tokens, img_feature in zip(row_labels, batch_tokens, img_features):
        prompt_dataframe.at[row_idx, 'text_tokens'] = text_tokens
        prompt_dataframe.at[row_idx, 'img_features'] = img_feature
    return len(row_labels)


idxs = []
text = []
imgs = []
per_process = 10000

finished_cnt = 0
finished_per_save = 20000
with torch.no_grad():
    start = time.perf_counter()
    for idx, row in enumerate(prompt_dataframe.itertuples()):
        # Store by index label (row.Index), which is what `.at` expects.
        idxs.append(row.Index)
        text.append(row.prompt)
        imgs.append(load_img_for_row(row))

        if len(idxs) >= per_process:
            # Write the results before clearing the buffers (previously they were
            # cleared first, so nothing was ever written).
            batch_len = _write_batch(idxs, text, imgs)
            idxs, text, imgs = [], [], []
            finished_cnt += batch_len
            if finished_cnt > finished_per_save:
                prompt_dataframe.to_feather(df_path)
                finished_cnt = 0

            end = time.perf_counter()
            per_datum = (end - start) / batch_len
            print(f"Time per datum: {per_datum}")
            print(f"Rem time: {per_datum * (len(prompt_dataframe) - idx - 1)/60} minutes")
            start = time.perf_counter()

    # Encode the trailing partial batch instead of pairing its labels with the
    # previous batch's results.
    if idxs:
        finished_cnt += _write_batch(idxs, text, imgs)
        idxs, text, imgs = [], [], []

# Persist anything written since the last periodic save.
if finished_cnt > 0:
    prompt_dataframe.to_feather(df_path)
    finished_cnt = 0


In [None]:
# Inspect the stored tokens and image features of the first row.
for feature_column in ('text_tokens', 'img_features'):
    print(prompt_dataframe.loc[0, feature_column])

In [None]:
def process_elem(x, target_type):
    """Parse a stringified list such as ``"[1, 2, 3]"`` back into a list of ``target_type``.

    Non-string values (already-parsed lists or arrays, None/NaN) are returned
    unchanged. ``ast.literal_eval`` handles quoted strings (no stray quotes are left
    on str elements) and empty lists. Text it cannot parse (e.g. ``nan`` entries)
    falls back to a plain comma split.
    """
    import ast

    if not isinstance(x, str):
        return x
    try:
        parsed = ast.literal_eval(x)
    except (ValueError, SyntaxError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return [target_type(elem) for elem in parsed]

    inner = x.strip().strip('[]').strip()
    if not inner:
        return []
    return [target_type(elem.strip().strip("'\"")) for elem in inner.split(',')]

# Convert stringified list columns (from a CSV round-trip) back into lists.
for list_column, element_type in (('text_tokens', int), ('urls', str), ('img_features', float)):
    prompt_dataframe[list_column] = [process_elem(value, element_type) for value in prompt_dataframe[list_column]]