In [None]:
# VQ-GAN variant; selects the chk_points/vqgan_imagenet_{Model}.yaml / .ckpt pair loaded below.
Model = "f16_1024" #param ["f16_1024", "f16_16384", "f16_16384_hf"]
import cv2
import torch
import yaml
import torch
from PIL import Image
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel

def load_config(config_path, display=False):
  """Load an OmegaConf config from `config_path`.

  When `display` is true, the config is also printed as YAML.
  Returns the loaded OmegaConf object.
  """
  cfg = OmegaConf.load(config_path)
  if display:
    as_container = OmegaConf.to_container(cfg)
    print(yaml.dump(as_container))
  return cfg

def load_vqgan(config, ckpt_path=None):
  """Build a VQModel from `config.model.params` and optionally load its weights.

  Args:
    config: OmegaConf config that has a `model.params` section.
    ckpt_path: optional checkpoint path. Its "state_dict" entry is loaded with
      strict=False, so missing/unexpected keys are silently ignored.

  Returns:
    The model switched to eval mode.
  """
  vq = VQModel(**config.model.params)
  if ckpt_path is None:
    return vq.eval()
  # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
  checkpoint = torch.load(ckpt_path, map_location="cpu")
  vq.load_state_dict(checkpoint["state_dict"], strict=False)
  return vq.eval()

def preprocess_vqgan(x, roll=True):
  """Rescale an image batch from [0, 1] to [-1, 1] and wrap it in a float tensor.

  Args:
    x: numpy array. With roll=True, axis 3 (channels in an NHWC batch) is moved
      to position 1, giving NCHW.
    roll: whether to move the channel axis.

  Returns:
    torch.Tensor (float32) holding the rescaled values.
  """
  scaled = 2. * x - 1.
  if not roll:
    return torch.Tensor(scaled)
  channels_first = np.rollaxis(scaled, 3, 1)
  return torch.Tensor(channels_first)

def preprocess(x, permt=True):
  """Map model output from [-1, 1] back to [0, 1] for display.

  Args:
    x: with permt=True, a torch tensor in NCHW order that is permuted to NHWC and
      converted to numpy; otherwise an array-like used as-is.
    permt: whether to permute and convert the tensor first.

  Returns:
    numpy array with values clipped to [-1, 1] and then mapped to [0, 1].
  """
  arr = x.permute(0, 2, 3, 1).numpy() if permt else x
  return (np.clip(arr, -1., 1.) + 1.) / 2.

def custom_to_pil(x):
  """Convert an image array with values in [-1, 1] into an RGB PIL image.

  Values are clipped to [-1, 1], mapped to 0..255 and cast to uint8. Images
  whose PIL mode is not RGB are converted to RGB.
  """
  unit_range = (np.clip(x, -1., 1.) + 1.) / 2.
  pil_img = Image.fromarray((255 * unit_range).astype(np.uint8))
  return pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")

# Load the VQ-GAN config and checkpoint selected by `Model` and move the model to the GPU.
vq_conf = load_config(f"chk_points/vqgan_imagenet_{Model}.yaml", display=False)
vq_model = load_vqgan(vq_conf, ckpt_path=f"chk_points/vqgan_imagenet_{Model}.ckpt").to('cuda')

# Disabled reconstruction demo (encode -> quantize -> decode, then plot original vs
# reconstruction). It is kept inside a string literal so it never executes.
# NOTE(review): it references `DS`, which is not defined in the visible notebook.
'''
sz = []

plt.figure(figsize=(20, 40))
img_rec = []
for i in range(1):
  #quant_states, indices = V_encoder.encode(DS.obs[i+2][0])
  x = preprocess_vqgan(DS.obs[i+2])
  with torch.no_grad():
    z, _, [_, _, ind] = vq_model.encode(x.to('cuda'))
    b,c,h,w = z.shape
    nz = vq_model.quantize.get_codebook_entry(ind, (b,h,w,c))
    rec = vq_model.decode(nz).detach().cpu()
    sz.append(h*w)
  #print(rec.shape)
  img_rec.append(preprocess(rec))

for i in range(1):
  for j in range(1):
    plt.subplot(10, 6, i*6+j+1)
    plt.axis("off")
    plt.imshow(DS.obs[j+2][i])
    plt.title(f'origin {DS.obs[j+2][i].shape}')
  for j in range(1):
    plt.subplot(10, 6, i*6+j+4)
    plt.axis("off")
    plt.imshow(img_rec[j][i])
    plt.title(f'sintetic {sz[j]} token')

plt.show();
'''

In [None]:
import time
#time.sleep(3600*8)#16:00

In [None]:
# Print the current wall-clock time (presumably to track how long the cells in between take).
print(pd.Timestamp.now())

In [None]:
# (width, height) passed to cv2.resize before encoding. 144 = 9 * 16 -- presumably chosen
# so the f16 VQ-GAN yields a 9x9 token grid (TODO confirm).
imsize = [16 * 9] * 2

In [None]:
import codecs
# Source of the full annotation text (one entry per line). Set the flag to True to read the
# plans file instead of imgs_descs.txt.
READ_FULL_TEXT_FROM_PLANS = False
path = './data/image_annotations_plans.txt' if READ_FULL_TEXT_FROM_PLANS else './data/imgs_descs.txt'
with codecs.open(path, 'r', 'utf8', errors='ignore') as f:
    full_text_list = f.readlines()

# Lines that have already been tokenized and tagged (the cells below append to this file).
path = './data/image_annotations_plans.txt'
with codecs.open(path, 'r', 'utf8', errors='ignore') as f:
    annotated_text_list = f.readlines()

In [None]:
# Lines of the full text that are not in the annotations file yet (matched purely by line
# count), joined back into a single string.
pending_lines = full_text_list[len(annotated_text_list):]
print('to add:', len(pending_lines))
text_for_annotation = ''.join(pending_lines)

In [None]:
from pathlib import Path

# Tokenize every image referenced as <<relative/path>> in `text_for_annotation` and replace
# that marker with the image's VQ-GAN token string ("<i0><i1>...").
p = Path("./data/imgs")
i = 0  # number of referenced images tokenized so far
for img_name in p.rglob("*"):
    if not img_name.is_file():
        continue  # rglob("*") also yields sub-directories
    # Path relative to the image root. The previous str(...).replace('data\imgs\\', '') only
    # worked with Windows separators (and '\i' is an invalid escape sequence).
    img_name_short = str(img_name.relative_to(p))
    marker = f'<<{img_name_short}>>'
    # Test the full marker, not the bare name: "1.jpg" is also a substring of "11.jpg".
    if marker not in text_for_annotation:
        continue
    i += 1
    # Load the image (OpenCV returns BGR uint8), resize, convert to RGB and scale to [0, 1].
    img_bgr = cv2.imread(str(img_name))
    if img_bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f'cannot read image: {img_name}')
    img_rgb = cv2.cvtColor(cv2.resize(img_bgr, tuple(imsize)), cv2.COLOR_BGR2RGB)
    img_orig = img_rgb.astype(np.float32) / 255.
    img = preprocess_vqgan(np.stack([img_orig]), True)

    # Inference only: no_grad stops autograd from recording the encoder/decoder on the GPU.
    with torch.no_grad():
        z, _, [_, _, ind] = vq_model.encode(img.to('cuda'))
        ind = ind.reshape(-1)  # flat index list; unlike squeeze_(), stays 1-D for a 1x1 grid
        b, c, h, w = z.shape
        token_string = ''.join(f'<{el}>' for el in ind.cpu().tolist())

        if np.random.rand() < 0.002:
            # Occasionally decode the tokens back and show original vs reconstruction.
            # Decoding only here avoids a full decoder pass for every image.
            nz = vq_model.quantize.get_codebook_entry(ind, (b, h, w, c))
            rec = vq_model.decode(nz).detach().cpu()
            print(i)
            plt.imshow(img_orig)
            plt.show()
            plt.imshow(preprocess(rec)[0])
            plt.show()

    text_for_annotation = text_for_annotation.replace(marker, token_string)

In [None]:
# Tag the sections: '<OUT>' after every field label, '<IN>' after every '<END>' CRLF line break.
# NOTE(review): not idempotent -- running this cell twice doubles the tags.
for _old, _new in (
    ('description:', 'description:<OUT>'),
    ('forecast vars:', 'forecast vars:<OUT>'),
    ('forecast img:', 'forecast img:<OUT>'),
    ('plan:', 'plan:<OUT>'),
    ('<END>\r\n', '<END>\r\n<IN>'),
):
    text_for_annotation = text_for_annotation.replace(_old, _new)

In [None]:
# Append the newly tokenized and tagged text to the annotations/plans file.
with codecs.open('data/image_annotations_plans.txt', mode='a', encoding='utf8') as plans_file:
    plans_file.write(text_for_annotation)

In [None]:
path = './data/image_annotations_plans.txt'
with codecs.open(path, 'r', 'utf8', errors='ignore') as f:
    annotated_text = f.read()

# Put an <IN> tag at the start of every line that follows an <END> and begins with an image
# token ("<0".."<9"). CRLF endings on those lines are normalised to LF, and adjacent
# <IN><IN> pairs are collapsed once per digit pass.
for num in range(10):
    annotated_text = (
        annotated_text
        .replace(f'<END>\n<{num}', f'<END>\n<IN><{num}')
        .replace(f'<END>\r\n<{num}', f'<END>\n<IN><{num}')
        .replace('<IN><IN>', "<IN>")
    )

with codecs.open('data/image_annotations_plans.txt', 'w', 'utf8') as f:
    f.write(annotated_text)

In [None]:
# Print the current wall-clock time (presumably to track how long the cells in between take).
print(pd.Timestamp.now())

In [None]:
# Update all_txt: the following cells regenerate data/all_txt.txt

In [None]:
# Annotated pictures: rebuild data/all_txt.txt from scratch, starting with the annotation file.
path = 'data/image_annotations_plans.txt'
with codecs.open(path, 'r', 'utf8') as f:
    texts = f.read()
# Applied in exactly this order: normalise <s>/</s> markers and CRLF endings, add a space
# after '>' when it is followed by f/d/p/s/k, collapse double spaces once, then start every
# line with <IN> and collapse doubled <IN> tags.
for src, dst in (
    ('<s>', '<END>'), ('\r\n', '\n'), ('</s>', ''),
    ('>f', '> f'), ('>d', '> d'), ('>p', '> p'), ('>s', '> s'), ('>k', '> k'),
    ('  ', ' '), ('\n', '\n<IN>'), ('<IN><IN>', '<IN>'),
):
    texts = texts.replace(src, dst)
print('texts', len(texts))
with codecs.open('data/all_txt.txt', 'w', 'utf8') as f:
    f.write(texts)
    
    
def process_book(path, drop_spaces=True, out_path='data/all_txt.txt'):
    """Normalise a text file and append it to the combined corpus file.

    `<s>` markers become `<END>`, `</s>` markers are removed and CRLF line endings
    become LF.

    Args:
        path: UTF-8 text file to read (undecodable bytes are ignored).
        drop_spaces: if True, the whole file is flattened onto a single line: line
            breaks become tabs and the chunk is prefixed with a newline. If False,
            the original line structure is kept.
        out_path: file the chunk is appended to (defaults to data/all_txt.txt).
    """
    with codecs.open(path, 'r', 'utf8', errors='ignore') as f:
        texts = f.read()
    texts = texts.replace('<s>', '<END>').replace('\r\n', '\n').replace('</s>', '')
    if drop_spaces:
        texts = '\n' + texts.replace('\n', '\t')  # this is the format our dataset uses
    # Always end the chunk with a newline. A flattened chunk ends with a tab, so without this
    # the next appended chunk (e.g. the same book with drop_spaces=False) would be glued
    # onto the same line.
    if texts and not texts.endswith('\n'):
        texts += '\n'
    print(path, 'texts', len(texts))
    with codecs.open(out_path, 'a', 'utf8') as f:
        f.write(texts)
# Text corpora appended to data/all_txt.txt, in this order. For each file, the tuple lists the
# drop_spaces values to use: True = whole file flattened onto one line, False = original lines.
BOOK_SOURCES = [
    ('data/wiki_data.txt', (True, False)),                    # wiki
    ('data/toy_text_doom_tasks.txt', (False,)),
    ('data/formal_logic_textbook.txt', (True, False)),        # logic
    ('data/hpmor.txt', (True, False)),
    ("data/Book 1 - The Philosopher's Stone.txt", (True, False)),
    ('data/Book 2 - The Chamber of Secrets.txt', (True, False)),
    ('data/Book 4 - The Goblet of Fire.txt', (True, False)),
    ('data/Map and Territory.txt', (True, False)),            # rationality
    ('data/doom fanfics.txt', (True, False)),
    ('data/treasure island.txt', (True, False)),
    ('data/robinson crusoe.txt', (True, False)),
    ('data/Sherlock Holmes.txt', (True,)),
    ('data/scrum.txt', (True, False)),
    ('data/military stories.txt', (True, False)),
    ('data/military stories 2.txt', (True, False)),
    ('data/doom wiki.txt', (True, False)),
    ('data/homm.txt', (True, False)),
    ('data/military materials.txt', (True, False)),
    ('data/military materials 2.txt', (True, False)),
    ('data/anatomy.txt', (True, False)),
    ('data/churchill.txt', (True, False)),
    ('data/summary_data.txt', (False,)),
    ('data/dialogues_text.txt', (False,)),
    ('data/chat_data.txt', (False,)),
]
for path, drop_modes in BOOK_SOURCES:
    for drop_spaces in drop_modes:
        process_book(path, drop_spaces=drop_spaces)

In [None]:
# Annotated pictures with memory: append the imgs_descs_memory*.txt files to data/all_txt.txt.
# NOTE(review): the chunk starts with `texts`, i.e. the annotated-picture text built in the
# previous cell, which that cell already wrote to data/all_txt.txt -- so that text lands in
# the file twice. Kept as-is; confirm whether this duplication is intended.
memory_paths = ['data/imgs_descs_memory.txt'] + [
    f'data/imgs_descs_memory_{k}.txt' for k in range(2, 12)
]
# Build a new string instead of growing `texts` in place. Otherwise, re-running this cell
# would re-append the memory files accumulated by the previous run.
memory_chunks = [texts]
running_len = len(texts)
for path in memory_paths:
    with codecs.open(path, 'r', 'utf8') as f:
        chunk = f.read()
    memory_chunks.append(chunk)
    running_len += 1 + len(chunk)  # +1 for the '\n' separator between chunks
    print('texts', running_len)

memory_text = '\n'.join(memory_chunks)
memory_text = memory_text.replace('<s>', '<END>').replace('\r\n', '\n').replace('</s>', '')
with codecs.open('data/all_txt.txt', 'a', 'utf8') as f:
    f.write(memory_text)

In [None]:
# Print the current wall-clock time (presumably to track how long the cells in between take).
print(pd.Timestamp.now())

In [None]:
#cut text
import codecs
# Truncate every line of data/all_txt.txt to at most `thresh` characters and write the result
# to data/all_txt_cut.txt.
thresh = 30000000  # maximum number of characters kept per line
with codecs.open('data/all_txt.txt', 'r', 'utf8', errors='ignore') as f:
    lines = f.read().splitlines()
# Re-join with a single '\n'. The previous version joined readlines() output (each line still
# ending in '\n') with '\n', which inserted an empty line after every line.
cut_text = '\n'.join(line[:thresh] for line in lines)
if lines:
    cut_text += '\n'

with codecs.open('data/all_txt_cut.txt', 'w', 'utf8') as f:
    f.write(cut_text)

In [None]:
def img_to_tokens(path):
    """Encode one image file into VQ-GAN codebook indices.

    The image is read with OpenCV, resized to the global ``imsize``, converted from
    BGR to RGB, scaled to [0, 1] and encoded by the global ``vq_model`` on CUDA.

    Args:
        path: path (str or Path) to an image file readable by ``cv2.imread``.

    Returns:
        (ind, token_string): ``ind`` is the flattened index tensor from the encoder
        (not moved to CPU); ``token_string`` renders the indices as ``'<i0><i1>...'``.

    Raises:
        OSError: if OpenCV cannot read or decode the file.
    """
    img_bgr = cv2.imread(str(path))
    if img_bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f'cannot read image: {path}')
    img_rgb = cv2.cvtColor(cv2.resize(img_bgr, tuple(imsize)), cv2.COLOR_BGR2RGB)
    img_orig = img_rgb.astype(np.float32) / 255.
    img = preprocess_vqgan(np.stack([img_orig]), True)

    # Inference only: no_grad stops autograd from recording the encoder on the GPU.
    with torch.no_grad():
        _, _, [_, _, ind] = vq_model.encode(img.to('cuda'))
    # reshape(-1) rather than squeeze_(): always 1-D, even for a 1x1 token grid.
    ind = ind.reshape(-1)
    token_string = ''.join(f'<{el}>' for el in ind.cpu().tolist())
    return ind, token_string

In [None]:
# Quick check on one sample image; the cell displays the (indices, token string) pair.
path = './data/imgs/images (2).jpg'
sample_tokens = img_to_tokens(path)
sample_tokens

In [None]:
# Deliberate ZeroDivisionError: halts "Run All" here (presumably so nothing after this point runs automatically).
1/0