In [None]:
from transformers import RobertaConfig, RobertaTokenizerFast, TFRobertaModel
from transformers.optimization_tf import create_optimizer, WarmUp
from tensorflow.keras.preprocessing.sequence import pad_sequences
from keras_bert import gelu

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import pandas as pd
import numpy as np
import itertools
import random
import tensorflow as tf
import json
import datetime as dt
from pprint import pformat
import logging
import pickle

import seaborn


In [None]:
import tensorflow.keras as keras
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Input, Dense, TimeDistributed, Bidirectional, LSTM, Concatenate, Conv1D, Dropout, Dot, Lambda, GlobalMaxPool1D, GlobalAvgPool1D, Add, GaussianNoise, Embedding, RepeatVector, LayerNormalization, MultiHeadAttention
from tensorflow.keras.models import Model, load_model, Sequential
from keras_pos_embd import TrigPosEmbedding, PositionEmbedding
from keras_bert.layers import TokenEmbedding

from tensorflow.keras import layers


In [None]:
# Sequence length handed to vectorize_seq() and pad_sequences() further down;
# it is also embedded in the MODEL_DIR directory name.
MAX_LEN = 32

In [None]:
# Switches selecting which vocabulary / model variant the notebook works with.
isProtein = False
isCLS = False

import string


# Character-level vocabulary: maps each symbol to its integer token id.
tokenizer_dict = {}
tokenizer_dict["<pad>"] = 0
tokenizer_dict["<mask>"] = 1

if not isProtein:
    # Earlier, larger set of extra characters (kept for reference):
    #speslZnaky = "áéíóúýůžščřťďěň°®…”“„’ì»�–§•✔üß€đĐ$öëä"
    speslZnaky = "áéíóúýůžščřťďěňüöëä"

    # Ids are handed out in first-seen order. lower() turns the upper-case ASCII
    # letters into duplicates of the lower-case ones, so they are skipped here.
    for aa in string.printable.lower()+speslZnaky:
        if aa not in tokenizer_dict:
            tokenizer_dict[aa] = len(tokenizer_dict)

    # Integer keys 0-9 receive the same ids (2-11) as the characters "0"-"9".
    for digit in range(10):
        tokenizer_dict[digit] = digit + 2

    # Hard-coded aliases onto existing ids; 70 is the id assigned to " " above.
    tokenizer_dict["\u200b"] =70
    tokenizer_dict["\uf02d"] =70
    # NOTE(review): with the current alphabet, ids 85/81/84/88 belong to
    # "č"/"ý"/"š"/"ď" — confirm these aliases are the intended ones.
    tokenizer_dict["ø"]=85
    tokenizer_dict["ù"]=81
    tokenizer_dict["è"]=84
    tokenizer_dict["ì"]=88
    tokenizer_dict["<unknown>"] =  len(tokenizer_dict)
    unknown_token_id = tokenizer_dict["<unknown>"]
else:
    for aa in string.ascii_uppercase:
        tokenizer_dict[aa] = len(tokenizer_dict)
    # A space is treated as padding in the protein vocabulary.
    tokenizer_dict[" "] = tokenizer_dict["<pad>"]

pad_token = "<pad>"
pad_token_id = tokenizer_dict["<pad>"]
mask_token = "<mask>"
mask_token_id = tokenizer_dict["<mask>"]
# unknown_token_id is only defined in the non-protein branch, because the
# protein vocabulary has no "<unknown>" entry.

In [None]:
# Directory of the saved model: the prefix follows isProtein, the sequence
# length is embedded in the name, and "-cls" is appended when isCLS is set.
MODEL_DIR = (
    f"Protein_proteopairs13-ml{MAX_LEN}-mask" if isProtein
    else f"Word_proteopairs13-ml{MAX_LEN}-mask"
) + ("-cls" if isCLS else "")

In [None]:
# Restore the saved Keras model and its training state when MODEL_DIR exists on disk.
# NOTE(review): siamese_loss is not defined in the cells visible here — it has to
# exist in the kernel before this cell runs.
if MODEL_DIR is not None and os.path.isdir(MODEL_DIR):
    print("Loading model from", MODEL_DIR)
    model = load_model(
        os.path.join(MODEL_DIR, "model"),
        custom_objects={"WarmUp": WarmUp, "siamese_loss": siamese_loss},
    )

    # config.json sits next to the model and carries the training state.
    with open(f"{MODEL_DIR}/config.json", "r", encoding="utf-8") as fr:
        TRAIN_STATE = json.load(fr)
    INITIAL_EPOCH = TRAIN_STATE["INITIAL_EPOCH"]
    print(f"Clasification will use model with TrainState of: {TRAIN_STATE}")

In [None]:
# Keep the optimizer restored with the model when there is one; otherwise
# build a fresh optimizer (with warm-up) from the training constants.
if model.optimizer is not None:
    print("Skipped model compilation, optimizer already initialized")
    optimizer = model.optimizer
else:
    print("Compiling model with an optimizer")
    optimizer, lr_schedule = create_optimizer(
        init_lr=LR_START,
        num_train_steps=TRAIN_STEPS*EPOCHS,
        num_warmup_steps=TRAIN_STEPS*WARMUP,
    )

# "out" uses the siamese loss; the three auxiliary heads share the same
# loss and metric. "total_out" has weight 0 in the combined loss.
model.compile(
    optimizer=optimizer,
    loss={
        "out": siamese_loss,
        **dict.fromkeys(("pos_out", "neg_out", "total_out"), "binary_crossentropy"),
    },
    metrics=dict.fromkeys(("pos_out", "neg_out", "total_out"), "accuracy"),
    loss_weights={
        "out": 1.0,
        "pos_out": 0.1,
        "neg_out": 0.1,
        "total_out": 0.,
    },
)

In [None]:
# Pull the sub-layers used for manual inference out of the loaded model, by name.
modelEmbedding, modelConvLayer, modelEmbedder, modelDist = (
    model.get_layer(name=layer_name)
    for layer_name in ("embedding", "conv1d", "embedder", "dist_model")
)
print(modelDist.outputs)

In [None]:

def vlastni_vectorize_batch(batch, mask=False):
    """Tokenize every item of `batch` and run it through the embedder stack.

    Each item is converted with ``vectorize_seq(item, MAX_LEN, mask)``, the
    results are post-padded to ``MAX_LEN`` with ``pad_token_id``, and the padded
    array is passed through ``modelEmbedding`` -> ``modelConvLayer`` ->
    ``modelEmbedder``.

    Returns:
        A two-item list ``[pos1, pos2]`` holding the two outputs of ``modelEmbedder``.
    """
    token_ids = [vectorize_seq(segment, MAX_LEN, mask) for segment in batch]
    padded = pad_sequences(
        token_ids,
        maxlen=MAX_LEN,
        dtype=np.int32,
        padding="post",
        value=pad_token_id,
    )
    pos1, pos2 = modelEmbedder(modelConvLayer(modelEmbedding(padded)))
    return [pos1, pos2]

In [None]:
def Mapa(vstup,exportName:str = "",title:str = ""):
    """Draw `vstup` as a seaborn heatmap and either save or show it.

    The colour map is viridis resampled to 256 entries, with the lowest 16
    entries replaced by black and the highest 16 by pink.

    Args:
        vstup: data accepted by ``seaborn.heatmap``.
        exportName: when non-empty, the figure is saved to this path and closed;
            otherwise it is shown with ``plt.show()``.
        title: title placed on the axes.
    """
    # `mpl` was referenced here without being imported anywhere in the visible
    # cells, so it is imported locally next to pyplot.
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    viridis = mpl.colormaps['viridis'].resampled(256)
    newcolors = viridis(np.linspace(0, 1, 256))
    pink = np.array([248/256, 24/256, 148/256, 1])
    black = np.array([0/256, 0/256, 0/256, 1])
    newcolors[-16:, :] = pink
    newcolors[:16,:] = black

    ax = seaborn.heatmap(vstup,annot =False, cmap =  mpl.colors.ListedColormap(newcolors))
    ax.set(xlabel="", ylabel="",title=title)
    # Column labels go on top and are rotated so long labels stay readable.
    ax.xaxis.tick_top()
    plt.xticks(rotation=90)
    if exportName != "":
        plt.savefig(exportName,bbox_inches="tight")
        plt.close()
    else:
        plt.show()

In [None]:
import pickle
class Embeddings_ready():
    """Holds the embedder outputs for one text and builds pairwise maps from them.

    ``arr_embeddings`` is either the ``[pos1, pos2]`` list returned by
    ``vlastni_vectorize_batch``, a value loaded from a pickle file, or a value
    supplied by the caller. Index 0 is the "pos1" column, index 1 is "pos2".
    """

    # Class-level fallbacks; __init__ gives every instance its own containers.
    arr_embeddings = []
    strings = []

    def __init__(self, path:str = "",embeddings=None,text:str =""):
        """Fill the instance from exactly one source, checked in this order.

        Args:
            path: pickle file to load the embeddings from.
            embeddings: ready-made embeddings; used only when non-empty.
            text: raw text that is split and embedded via ``prepEmbeddings``.
        """
        # Per-instance containers, so nothing is shared through the class attributes.
        self.arr_embeddings = []
        self.strings = []
        if path != "":
            self.loadPickle(path)
        elif text != "":
            self.arr_embeddings = self.prepEmbeddings(text)
        elif embeddings is not None and len(embeddings) > 0:
            self.arr_embeddings = embeddings

    def loadPickle(self,path:str):
        """Replace ``arr_embeddings`` with the object pickled at `path`."""
        # NOTE: pickle.load executes arbitrary code from the file — only load trusted files.
        with open(path,"rb") as p:
            self.arr_embeddings= pickle.load(p)

    def savePickle(self, path:str):
        """Pickle ``arr_embeddings`` to `path`; returns the result of ``pickle.dump``."""
        with open(path,"wb") as p:
            return pickle.dump(self.arr_embeddings,p)

    def getEmbeddings(self):
        """Return the stored embeddings."""
        return self.arr_embeddings

    # firstSEC, secondSEC select which embeddings column is used: 0 = pos1, 1 = pos2
    def getMapReadyEmbeddings(self, anotherClass = None,firstSEC:int = 1, secondSEC:int = 0):
        """Build the grid of embedding pairs that ``generateMap`` consumes.

        Cell ``[i][j]`` is the two-item list
        ``[self.arr_embeddings[firstSEC][i:i+1], anotherClass.arr_embeddings[secondSEC][j:j+1]]``.
        The grid has one row per entry of ``self.arr_embeddings[0]`` and one
        column per entry of ``anotherClass.arr_embeddings[0]``.

        Args:
            anotherClass: the instance supplying the columns; defaults to `self`.
            firstSEC: embeddings column used for the row items.
            secondSEC: embeddings column used for the column items.
        """
        if anotherClass is None:
            anotherClass = self

        n_rows = len(self.arr_embeddings[0])
        n_cols = len(anotherClass.arr_embeddings[0])
        return [
            [
                [self.arr_embeddings[firstSEC][i:i+1], anotherClass.arr_embeddings[secondSEC][j:j+1]]
                for j in range(n_cols)
            ]
            for i in range(n_rows)
        ]

    def prepEmbeddings(self,text:str,n_Chars:int = 1):
        """Split `text` with ``textSplitter`` and embed the pieces.

        The split pieces are remembered in ``self.strings``; the value returned
        by ``vlastni_vectorize_batch`` is handed back to the caller.
        """
        splitted = textSplitter(text,str(n_Chars))
        self.strings = splitted

        embedded = vlastni_vectorize_batch(splitted)
        return embedded

    def generateMap(self,matrix=None,firstSEC:int = 1, secondSEC:int = 0):
        """Evaluate ``modelDist`` on every cell of `matrix` and return the values.

        Args:
            matrix: grid of embedding pairs; when omitted it is built with
                ``getMapReadyEmbeddings(firstSEC=..., secondSEC=...)``.

        Returns:
            A list of rows; each value is ``modelDist(pair)[1][0][0].numpy()[0]``.
        """
        # `is None` rather than `== None`: the latter is ambiguous for array inputs.
        if matrix is None:
            matrix = self.getMapReadyEmbeddings(firstSEC=firstSEC, secondSEC=secondSEC)
        retVal = []
        # len() instead of np.size(): np.size would convert the whole nested
        # grid to an array again on every row just to read its dimensions.
        for i in range(len(matrix)):
            retVal.append([])
            for j in range(len(matrix[i])):
                retVal[i].append(modelDist(matrix[i][j])[1][0][0].numpy()[0])
        return retVal

    # Probably has no reason to exist yet.
    def getWords(self,number:int = 0):
        """Return ``words1`` (number 0) or ``words2`` falling back to ``words1`` (number 1).

        This class never sets either attribute itself, so they are read with
        ``getattr`` and ``None`` is returned when they are absent.
        """
        words1 = getattr(self, "words1", None)
        words2 = getattr(self, "words2", None)
        if number == 0:
            return words1
        elif number == 1:
            return words1 if words2 is None else words2
        return None

In [None]:
from hashlib import blake2b 
class MapResult():
    """Distance map for one text (or a pair of texts) plus helpers to hash,
    analyse and pickle it.
    """

    # These imports live in the class body, which also exposes them as class
    # attributes; they are kept so such access keeps working. The methods
    # themselves use the module-level `pickle`.
    import pickle
    import math
    statistika = None
    words1= None
    words2 = None
    emb1 = ""   # path to a pickled embeddings object (first text)
    emb2 = ""   # path to a pickled embeddings object (second text)
    vzdalenostniMapa = None
    embeddingy = None
    # Record layout used by prumery() for the tuples produced by Statistics().
    dtype = [("avg",float),("index",float),("nRecords",float)]
    words = None

    def __init__(self,emb1 = "", emb2 = "", words1 = None, words2=None,vzdalenostniMapa=None,text = ""):
        """Store the given fields; when `text` is non-empty, embed it and
        compute its distance map immediately.
        """
        self.emb1 = emb1
        self.emb2 = emb2
        self.words1 = words1
        self.words2 = words2
        self.vzdalenostniMapa = vzdalenostniMapa
        if text != "":
            self.words = text
            # The embeddings object stays local: only the resulting map is kept.
            embeddingy = Embeddings_ready(text=text)
            # TODO: look up / load the embeddings from disk by name, hash or database.
            self.vzdalenostniMapa = embeddingy.generateMap()

    def toString(self):
        """Return a multi-line summary of the stored fields (falsy values print as "None")."""
        return("words1: " + str(self.words1 or "None") + "\n" +
              "words2: " + str(self.words2 or "None") + "\n" +
               "emb1: " + str(self.emb1) + "\n"+
               "emb2: " + str(self.emb2) + "\n"+
               "statistika: " + str(self.statistika or "None") + "\n"+
               "vzdalenostniMapa: " + str(self.vzdalenostniMapa or "None") + "\n"              )

    def getHash(self, words:str = None):
        """Return the BLAKE2b hex digest of the UTF-8 encoded text.

        Called on an instance, the text is ``self.words`` or, when that is
        None, ``self.words1 + self.words2``; the `words` argument is ignored.
        Called as ``MapResult.getHash(None, words)``, `words` itself is hashed.
        """
        hasher = blake2b()
        if self is not None:
            if self.words is None:
                words = self.words1+self.words2
            else:
                words = self.words

        hasher.update(bytes(words,"UTF-8"))
        return hasher.hexdigest()

    def getEmbeddings(self,number:int = 0):
        """Unpickle and return the embeddings object behind ``emb1`` / ``emb2``.

        Args:
            number: 0 loads ``emb1``; 1 loads ``emb2`` and falls back to
                ``emb1`` when ``emb2`` is empty.

        Returns:
            The unpickled object, or None when no path is set or `number` is
            neither 0 nor 1.
        """
        if number == 0:
            path = self.emb1
        elif number == 1:
            path = self.emb2 if self.emb2 != "" else self.emb1
        else:
            return None
        if path == "":
            return None  # should probably raise an exception instead
        # `with` closes the file handle; NOTE: only unpickle trusted files.
        with open(path,"rb") as handle:
            return pickle.load(handle)

    def getVzdalenostniMapa(self):
        """Return the distance map, generating it first when it is missing."""
        # `is None` keeps this working when the map is a numpy array.
        if self.vzdalenostniMapa is None:
            self.generateMap()
        return self.vzdalenostniMapa

    def generateMap(self,matrix=None,firstSEC:int = 0, secondSEC:int = 1):
        """Evaluate ``modelDist`` on every cell of `matrix` and store the
        result in ``self.vzdalenostniMapa``.

        When `matrix` is omitted it is built from the pickled embeddings
        (``getEmbeddings()`` for the rows, ``getEmbeddings(number=1)`` for the
        columns). With no embeddings path configured, ``getEmbeddings()``
        returns None and this raises AttributeError.
        """
        if matrix is None:
            matrix = self.getEmbeddings().getMapReadyEmbeddings(anotherClass=self.getEmbeddings(number=1),
                                                                firstSEC=firstSEC,
                                                                secondSEC=secondSEC)
        retVal = []
        # len() instead of np.size(): np.size would convert the whole nested
        # grid to an array again on every row just to read its dimensions.
        for i in range(len(matrix)):
            retVal.append([])
            for j in range(len(matrix[i])):
                retVal[i].append(modelDist(matrix[i][j])[1][0][0].numpy()[0])
        self.vzdalenostniMapa= retVal

    def Statistics(self, mapaData=None):
        """Average the diagonals of the distance map around the main diagonal.

        Args:
            mapaData: map to analyse; defaults to ``getVzdalenostniMapa()``.

        Returns:
            A list of ``(average, offset, rangeMax - offset)`` tuples for the
            offsets ``-rangeMax .. rangeMax - 1``, where ``rangeMax`` is
            ``round((rows - 1) / 2)``.
        """
        # TODO: more statistics, possibly a bias term for off-axis diagonals.
        if mapaData is None:
            mapaData = self.getVzdalenostniMapa()
        grid = np.array(mapaData)
        nRecords = np.size(grid,0)-1
        rangeMax = round(nRecords/2)
        all_Avgs = []
        for i in range(-rangeMax,rangeMax):
            # NOTE(review): the third value lands in the "nRecords" field of
            # `dtype`, but rangeMax - i is not the diagonal's length — confirm intent.
            all_Avgs.append((np.average(grid.diagonal(i)),i,rangeMax-i))
        return all_Avgs

    def prumery(self,statistika):
        """Return `statistika` as a structured array sorted by "avg", highest first.

        The sorted array used to be computed and then discarded, so callers
        always received None.
        """
        return np.flip(np.sort(np.array(statistika,dtype=self.dtype),order="avg"))

    def pickleThis(self,path:str = "./Pickles"):
        """Pickle this object to ``<path>/<getHash()>.pickle``, creating `path` if needed."""
        digest = self.getHash()
        os.makedirs(path, exist_ok=True)
        with open(path+"/"+digest+".pickle","wb") as p:
            pickle.dump(self,p)
        
    

In [None]:

def GenerateMap(vstup1, vstup2 = "", delic="1"):
    """Run the full model on neighbouring segment pairs of two texts and plot
    three heatmaps with ``Mapa``.

    Both texts are split with ``textSplitter(text, delic)``. For every pair of
    neighbouring segments ``(a[i], a[i+1], b[j], b[j+1])`` the model is called
    once and its "neg_out" value is stored at ``[i, j+1]`` of the cross map.
    The two self-maps that get plotted come from ``heatmapa``.

    Args:
        vstup1: first text.
        vstup2: second text; defaults to `vstup1` when empty.
        delic: passed through to ``textSplitter`` and ``heatmapa``.

    NOTE(review): `vectorized` is only bound inside the loop, so an input with
    fewer than two segments fails at the final ``model.predict`` call.
    """
    if vstup2 == "":
        vstup2 = vstup1
    segments_a = textSplitter(vstup1, delic)
    segments_b = textSplitter(vstup2, delic)
    count_a, count_b = len(segments_a), len(segments_b)

    cross_map = np.zeros((count_a, count_b))
    self_map_a = np.zeros((count_a, count_a))
    self_map_b = np.zeros((count_b, count_b))

    for row in range(count_a - 1):
        for col in range(count_b - 1):
            quadruple = (
                segments_a[row], segments_a[row + 1],
                segments_b[col], segments_b[col + 1],
            )
            vectorized = vectorize_batch([(quadruple, 0)], mask=False)
            out = model.predict(vectorized)
            cross_map[row, col + 1] = out["neg_out"][0][0][0]
            self_map_a[row, row + 1] = out["pos_out"][0][0][0]

    # The top-left cell is taken from the "out" head for the last batch built above.
    cross_map[0, 0] = model.predict(vectorized, batch_size = 32)["out"][0][0]

    # The self-maps filled in the loop are replaced by the heatmapa() results.
    self_map_a = heatmapa(vstup1, vstup1, delic)
    self_map_b = heatmapa(vstup2, vstup2, delic)

    frames = [
        pd.DataFrame(data = cross_map, index = segments_a, columns = segments_b),
        pd.DataFrame(data = self_map_a, index = segments_a, columns = segments_a),
        pd.DataFrame(data = self_map_b, index = segments_b, columns = segments_b),
    ]
    for frame in frames:
        Mapa(frame)

In [None]:
from os.path import exists
import os
import time
from IPython.display import clear_output
# Progress counters: newly computed maps, maps already on disk, over-long lines
# skipped, and lines seen (count is also what marks the first line).
done = 0
skippedDone = 0
tooLong = 0
count = 0

# The validation file and the pickle directory follow the isProtein / isCLS switches.
inputText = "uniref90_tax-free_shuffled_val.tsv" if isProtein else "CC-aug2018-oct2021_ces.10G_val.tsv"
pickleRoute = "./Pickles/" + ("Protein" if isProtein else "Words") + ("CLS" if isCLS else "")

with open("./"+inputText,"r",encoding = "utf8") as reader:
    startTime = time.time()
    for line in reader:
        # The first line of the file is skipped.
        if count == 0:
            count += 1
            continue

        # The text is the third tab-separated column.
        vstup = line.split('\t')[2]
        if len(vstup) >= 1200:
            tooLong += 1
        else:
            print("Current Len: " , str(len(vstup)), "(+32) / Avg: ", str((time.time()-startTime)/done) if (not done == 0) else 0, " secs" )
            # A map pickled earlier under this text's hash is not recomputed.
            if exists(pickleRoute+"/"+MapResult.getHash(self = None, words = vstup)+".pickle"):
                clear_output(wait=True)
                skippedDone += 1
                print("Done:",str(done+skippedDone), " / " ,str(done), " / " ,str(skippedDone) )
                print("Too Long Skipped:" ,str(tooLong))
                continue

            vysl = MapResult(text=vstup)
            vysl.pickleThis(path = pickleRoute)
            done += 1

        count += 1
        clear_output(wait=True)
        print("Done:",str(done+skippedDone), " / " ,str(done), " / " ,str(skippedDone) )
        print("Too Long Skipped:" ,str(tooLong))