In [1]:
%load_ext autoreload
%load_ext tensorboard
%autoreload 2
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('../..'))

import importlib

import contextlib

import json
import os
import random

import numpy as np
import torch
from collections import defaultdict
from fileutils import get_default_path

from automation.automationmanager import make_default_manager
from automation.crawler.midjourneycrawl import crawl_gallery_user, crawl_gallery_feed
from automation.midjourney.midjourneyutils import FeedType

from storage.data.image.remoteimageinfo import imgs_to_cmds
from storage.data.image.crawledimagegroups import CrawledImageGroups
from storage.data.command import Command
from storage.data.command.commandbuilder import CommandBuilder
from storage.data.user.userids import MJ_USER_TO_ID
from storage.data.user.mjuser import MJUser
import time
from util import Stopwatch
import datetime

import ai.stabledisco as sd
import ai.torchmodules as torchmodules
import ai.torchmodules.data as torchdata
import ai.torchmodules.utils as torchutils
import ai.stabledisco.utils as sdutils
import clip
import ai.nlp
import torch
import torch.nn as nn
import pandas as pd
from collections import defaultdict

In [2]:
# Location of the augmented-prompt dataset inside the "large_datasets" folder.
df_file_path = get_default_path("large_datasets", "aug_prompts.feather")
# Unbind any frame left over from a previous run of this cell before reading
# the file again (presumably so the old copy can be freed first).
globals().pop("prompt_dataframe", None)
prompt_dataframe = pd.read_feather(df_file_path)

In [24]:
# Persist the current prompt frame to the feather file under "large_datasets".
# df_file_path is (re)bound here and is reused by later cells.
df_file_path = get_default_path("large_datasets", "aug_prompts.feather")
prompt_dataframe.to_feather(df_file_path)

In [25]:
# Number of rows currently held in the prompt frame.
print(len(prompt_dataframe))

16798395


In [None]:
import markovify

# Tunables: how many prompts feed each Markov model and how many sentences
# to attempt in total, spread evenly over ten freshly built models.
prompt_sample_num = 2000000
to_gen = 2000000
per_markov = to_gen // 10
iters = to_gen // per_markov

# Accumulates every non-empty generated sentence across all rounds.
sentences = []

for markov_num in range(iters):
    # Draw a fresh random subset of prompts for this round's model.
    prompts = prompt_dataframe.sample(n=prompt_sample_num)["prompt"].tolist()
    print(f"Making markov {markov_num}")
    prompt_markov = markovify.NewlineText(prompts, state_size=2, retain_original=False)
    # The raw prompt list is no longer needed once the model is built.
    del prompts
    prompt_markov.compile(inplace=True)

    for idx in range(per_markov):
        if idx % 2500 == 0:
            print(idx)
        sent = prompt_markov.make_sentence(test_output=False)
        # make_sentence can come back empty; only keep real sentences.
        if not sent:
            continue
        sentences.append(sent)

    # Drop the model before the next round builds a new one.
    del prompt_markov

In [None]:
# One record per non-empty generated sentence: the lower-cased text plus its
# CLIP token ids (first row of the tokenizer output, stored as uint16).
rows = [
    {
        "prompt": sent.lower(),
        "text_tokens": clip.tokenize(sent, truncate=True)[0].numpy().astype(np.uint16),
    }
    for sent in sentences
    if sent
]

# Append the new records, keep a single row per distinct prompt, then shuffle.
# Each step rebinds prompt_dataframe so the previous frame is released right away.
prompt_dataframe = pd.concat(
    [prompt_dataframe, pd.DataFrame.from_records(rows)],
    ignore_index=True,
)
prompt_dataframe = prompt_dataframe.drop_duplicates(subset="prompt", ignore_index=True)
prompt_dataframe = prompt_dataframe.sample(frac=1).reset_index(drop=True)

# Write the enlarged dataset back out and report its size.
df_file_path = get_default_path("large_datasets", "aug_prompts.feather")
prompt_dataframe.to_feather(df_file_path)
print(len(prompt_dataframe))

In [19]:
# Display the dtype of each column in the prompt frame.
prompt_dataframe.dtypes

prompt         object
text_tokens    object
dtype: object

In [23]:
# Some rows hold their token array wrapped inside another sequence; when the
# first element is itself an ndarray, replace the cell with that inner array.
for row in prompt_dataframe.itertuples(index=True):
    first_item = row.text_tokens[0]
    if type(first_item) == np.ndarray:
        prompt_dataframe.at[row.Index, "text_tokens"] = first_item
        

In [None]:
# Shuffle the generated sentences in place, then print the first 1000 of them
# with a blank line between consecutive sentences.
random.shuffle(sentences)
print(*sentences[:1000], sep="\n\n")

In [18]:
import pickle

df_file_path = get_default_path("large_datasets", "aug_prompts.feather")
markov_file_path = get_default_path("large_datasets", "aug_markov.pk")

# The generation cell deletes prompt_markov at the end of every round, so the
# model may not exist when this cell runs. Check BEFORE opening the file:
# opening for writing truncates it, which would wipe a previously saved model
# even though the dump itself then fails with a NameError.
if "prompt_markov" in globals():
    # Plain write mode is enough; the handle is never read back.
    with open(markov_file_path, 'wb') as outfile:
        pickle.dump(prompt_markov, outfile)
else:
    print(f"prompt_markov is not defined; leaving {markov_file_path} untouched")
    

In [None]:
import clip
from clip.clip import _tokenizer as clip_tokenizer
# Vocabulary id of the end-of-text marker, looked up in the tokenizer's encoder table.
eot_token = clip_tokenizer.encoder["<|endoftext|>"]

# Load the ViT-L/14 CLIP checkpoint; clip.load returns the model and its preprocess object.
vit14_clip_model, vit14_clip_preprocess = clip.load('ViT-L/14')
# Cast the model to float32 and move it to the GPU (requires a CUDA device).
vit14_clip_model = vit14_clip_model.float().cuda()

In [4]:
# Vocabulary id of the start-of-text marker, looked up in the tokenizer's encoder table.
sot_token = clip_tokenizer.encoder["<|startoftext|>"]

In [5]:
# Start every vocabulary id at zero so ids that never occur still get an entry.
token_cnt = {token: 0 for token in clip_tokenizer.encoder.values()}
# This used to be `eos_token = clip`, which bound the name to the clip module
# rather than a token id. Keep the name defined, as an alias of eot_token.
eos_token = eot_token
for idx, tokens in enumerate(prompt_dataframe["text_tokens"]):
    for val in tokens:
        token_cnt[val] += 1
        # The end-of-text marker is counted once; positions after it are skipped.
        if val == eot_token:
            break

    # Progress indicator every 100k prompts.
    if idx % 100000 == 0:
        print(idx)

0
100000
200000
300000
400000
500000
600000
700000
800000
900000
1000000
1100000
1200000
1300000
1400000
1500000
1600000
1700000
1800000
1900000
2000000
2100000
2200000


In [11]:
import numpy as np

# (count, token) pairs ordered by ascending count, plus the bare counts.
cnt_token = sorted((count, token) for token, count in token_cnt.items())
cnts = list(token_cnt.values())

# Tokens below the 90th-percentile count are topped up; tokens strictly between
# the 90th and 99.9th percentile counts may be swapped out to make room.
target_min = replacement_min = np.percentile(cnts, 90)
replacement_max = np.percentile(cnts, 99.9)

# Print the count distribution in 5-point percentile steps.
for perc in range(0, 101, 5):
    print(f"perc: {perc}", np.percentile(cnts, perc))

# [remaining deficit, token] for every token still under the target count.
to_add_cnts = [
    [target_min - count, token]
    for token, count in token_cnt.items()
    if count < target_min
]
# Token ids whose count lies strictly inside (replacement_min, replacement_max).
replacement_candidates = {
    token
    for token, count in token_cnt.items()
    if replacement_min < count < replacement_max
}

perc: 0 424.0
perc: 5 441.0
perc: 10 443.0
perc: 15 445.0
perc: 20 446.0
perc: 25 447.0
perc: 30 448.0
perc: 35 449.0
perc: 40 450.0
perc: 45 452.0
perc: 50 453.0
perc: 55 455.0
perc: 60 457.0
perc: 65 461.0
perc: 70 467.0
perc: 75 477.0
perc: 80 492.0
perc: 85 524.0
perc: 90 709.0
perc: 95 1831.2999999999884
perc: 100 11642068.0


In [10]:
import random
import copy

# Probability that a replacement-candidate token is left as it is.
keep_chance = 0.75

# Token sequences in which at least one token was swapped.
new_tokens = []

for row_idx, orig_tokens in enumerate(prompt_dataframe["text_tokens"]):
    # Nothing left to top up: stop. Checking here also covers the case where
    # to_add_cnts is already empty when the cell starts (e.g. on a re-run),
    # which previously crashed in random.randint(0, -1).
    if not to_add_cnts:
        break

    if row_idx % 100000 == 0:
        print(row_idx, "rem", len(to_add_cnts))

    replaced = 0
    # Work on a copy so the row stored in the frame is not modified.
    tokens = copy.copy(orig_tokens)

    # Position 0 is skipped; `pos` indexes into `tokens` directly. A separate
    # name is used so the row index above is no longer shadowed.
    for pos, token in enumerate(tokens[1:], start=1):
        if token == eot_token:
            break

        if token not in replacement_candidates or random.random() < keep_chance:
            continue

        replaced += 1
        # Pick one of the still-underrepresented tokens uniformly at random.
        rep_token_idx = random.randrange(len(to_add_cnts))
        rep_token = to_add_cnts[rep_token_idx][1]
        tokens[pos] = rep_token

        # One fewer occurrence needed; retire the token once its deficit is gone.
        to_add_cnts[rep_token_idx][0] -= 1
        if to_add_cnts[rep_token_idx][0] < 1:
            del to_add_cnts[rep_token_idx]

        if not to_add_cnts:
            break

    if replaced > 0:
        new_tokens.append(tokens)

print(len(new_tokens))
new_tokens = np.array(new_tokens)

0 rem 46711
100000 rem 46709
200000 rem 46705
300000 rem 46697
400000 rem 46690
500000 rem 46684
600000 rem 46676
700000 rem 46670
800000 rem 46662


KeyboardInterrupt: 

In [82]:
# Decode each augmented token sequence back to text, keeping element 0 of
# whatever sdutils.decode_clip_tokens returns for it.
text = [sdutils.decode_clip_tokens(aug_tokens)[0] for aug_tokens in new_tokens]

In [83]:
# Build the augmented rows with one token array per row. list(new_tokens)
# keeps each row as an ndarray with its existing dtype, matching how generated
# rows are stored elsewhere in this notebook (uint16 arrays), whereas
# new_tokens.tolist() would store nested Python lists of plain ints.
new_df = pd.DataFrame.from_dict({"prompt": text, "text_tokens": list(new_tokens)})
prompt_dataframe = pd.concat([prompt_dataframe, new_df], ignore_index=True)
# The temporary frame is no longer needed once it has been appended.
del new_df

In [84]:
# Shuffle all rows, renumber the index from zero, and write the frame to
# df_file_path (which must already be bound by an earlier cell).
prompt_dataframe = prompt_dataframe.sample(frac=1).reset_index(drop=True)
prompt_dataframe.to_feather(df_file_path)