In [None]:
import sys
import os
import logging

# Notebook-wide logger. The logger itself lets everything from DEBUG upwards
# through; the console handler below narrows the visible output to INFO+.
logger = logging.getLogger('retrain_ipynb')
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(levelname)s - %(asctime)s - %(name)s - %(message)s')

# logging.getLogger() returns the same object for the same name, so re-running
# this cell used to attach one more StreamHandler every time and each message
# was then printed once per accumulated handler. Reuse the handler if it is
# already attached; create it only on the first run.
_console_handlers = [h for h in logger.handlers if type(h) is logging.StreamHandler]
if _console_handlers:
    ch = _console_handlers[0]
else:
    ch = logging.StreamHandler()
    logger.addHandler(ch)
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)

# Emitted at DEBUG level, i.e. below the handler's INFO threshold.
logger.debug('--=logging started=--')



# Config

In [None]:
# --- run mode ---------------------------------------------------------------
TRAIN = False            # when False the training step is skipped
TEST_FLOW = False        # when True the model build / train / evaluate cells are skipped
DEBUG = False            # enables the extra diagnostic cells
USE_CONTROL_SET = True   # read documents from the 'documents_temp' collection

# --- train / evaluation split -----------------------------------------------
TRAIN_TEST_SPLIT_SEED = 49
TEST_SIZE = 0.1

# --- starting weights (CP == checkpoint) ------------------------------------
TRAIN_FROM_CP = True
# Optional MLflow model URI; when set it takes precedence over TRAIN_FROM_CP.
# CHECKPOINT_URL = 'runs:/89d43de209874227af95fcbeaf048340/model'
CHECKPOINT_URL = None

# --- optimisation -----------------------------------------------------------
EPOCHS = 120
LR = 5.000e-4            # learning rate
BATCH_SIZE = 72
EMB = 1024               # presumably the embedding width; not referenced in this notebook -- TODO confirm

In [None]:


# Echo the effective configuration (NAME=repr(value), one per line) so that it
# is captured in the notebook output next to the results.
_echoed_settings = (
    'USE_CONTROL_SET',
    'TRAIN_FROM_CP',
    'LR',
    'EPOCHS',
    'TRAIN',
    'DEBUG',
    'CHECKPOINT_URL',
    'BATCH_SIZE',
)
for _setting in _echoed_settings:
    print(f'{_setting}={globals()[_setting]!r}')

In [None]:

# Put the parent of the current working directory on the import path (once),
# so the packages imported below can be resolved from there.
nb_dir = os.path.dirname(os.getcwd())
if not (nb_dir in sys.path):
    sys.path.append(nb_dir)
 

In [None]:
import random
import math
import json
import warnings

from os import path
from pathlib import Path



# Imports

In [None]:
%matplotlib inline

import tensorflow as tf
from tensorflow import keras
import keras.backend as K

print(f'{tf.__version__=}')

import seaborn as sns
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML



import pickle
import random
import numpy as np
import pandas as pd

from pandas import DataFrame 

from bson import json_util

 
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger
from keras.models import Sequential, Model, load_model
 


#-------- ours

from analyser.legal_docs import LegalDocument, make_headline_attention_vector
from analyser.headers_detector import TOKEN_FEATURES
from analyser.hyperparams import models_path, work_dir, notebooks_dir, reports_dir
from analyser.persistence import DbJsonDoc

from trainsets.retrain_contract_uber_model import UberModelTrainsetManager

from tf_support import super_contract_model
from tf_support.super_contract_model import semantic_map_keys_contract
from tf_support.super_contract_model import validate_datapoint
from tf_support.super_contract_model import make_xyw
from tf_support.super_contract_model import config, make_att_model
from tf_support.super_contract_model import FEATURES 
from tf_support.super_contract_model import sigmoid_focal_crossentropy, losses

from tf_support.tf_subject_model import decode_subj_prediction
from tf_support.tools import KerasTrainingContext

from colab_support.renderer import *
 
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
 



# Init mlflow

In [None]:
import mlflow

# Resolve the tracking server address and point the MLflow client at it.
# NOTE(review): `configured` is not imported by name anywhere in this notebook;
# presumably it arrives through one of the star imports above -- confirm.
ml_flow_url = configured('MLFLOW_URL')
mlflow.set_tracking_uri(ml_flow_url)
print(f'{ml_flow_url=}', 'set MLFLOW_URL env var to re-define')
# The experiment name is Russian for "Analyser training".
mlflow.set_experiment("Обучение анализатора")

# Switch on MLflow's automatic logging for TensorFlow.
mlflow.tensorflow.autolog()

# Prepare trainset


In [None]:
from integration.db import get_mongodb_connection  
from bson.objectid import ObjectId
from pandas import DataFrame

In [None]:

# Control-set runs read from the temporary collection, regular runs from the
# main one.
COLLECTION_NAME = 'documents_temp' if USE_CONTROL_SET else 'documents'

mongodb_connection = get_mongodb_connection()

# This used to be bound only when USE_CONTROL_SET was True, which left
# `documents_collection` undefined (NameError on first use) for regular runs
# even though COLLECTION_NAME had already been resolved for them.
documents_collection = mongodb_connection[COLLECTION_NAME]
    


## Load DS metafile

In [None]:
# Trainset manager rooted at the working directory; reports go to reports_dir.
umtm = UberModelTrainsetManager(work_dir, reports_dir=reports_dir)

umtm.load_contract_trainset_meta()  # 'contract_trainset_meta.csv'
stats = umtm.stats

if DEBUG:
    # Jupyter renders only a cell's last *top-level* expression; a bare
    # `stats` nested inside this `if` showed nothing, so display it explicitly.
    display(stats)

In [None]:
# stats[ ['org-1-alias', 'org-2-alias'] ]

# Rows whose 'unseen' flag is False.
_not_unseen = stats['unseen'].eq(False)
user_dataset = stats.loc[_not_unseen]

for _label, _frame in (('user_dataset', user_dataset), ('stats', stats)):
    print(f'len({_label})={len(_frame)!r}')


# mlflow.log_param('dataset_len_user', len(user_dataset) )
# mlflow.log_param('dataset_len', len(stats) )

In [None]:
if DEBUG:
    # A bare expression inside an `if` body is never rendered by Jupyter, so
    # the original line displayed nothing; show the rows explicitly.
    display(user_dataset[user_dataset.subj_len >= 150])

# Weights: вычисление весов samples  

 - weight is proportional to the log of the contract price (fewer errors in expensive contracts)
 - more weight for user-corrected datapoints
 - normalize weights, so the sum == Number of samples
 - smaller weight for docs with human mark-up errors

In [None]:
# Optional per-document table of human mark-up errors, indexed by document id.
errors_file = reports_dir / 'user_markup_errors.csv'
print(f'{errors_file=}')
try:
  errors_df = pd.read_csv(errors_file, index_col=0)
except Exception as e:
  # Best effort: the run continues without the table. `Exception` (instead of
  # a bare `except:`) no longer swallows KeyboardInterrupt / SystemExit, and
  # the actual reason is logged rather than lost.
  print(f'cannot read {errors_file}')
  logger.warning(f'{errors_file} is not usable ({e!r}); continuing with an empty error table')
  errors_df = DataFrame(columns=['errors severity'])

In [None]:


# The helper is first used in this cell, but the notebook imported it only in
# a later one, so a fresh "Restart & Run All" stopped here with a NameError.
from trainsets.trainset_tools import get_feature_log_weights

# Drop the document types that are not part of the contract trainset.
stats = umtm.stats
stats = stats[stats.documentType != 'ANNEX']
stats = stats[stats.documentType != 'undefined']

print(len(stats))
# NOTE(review): the result is discarded (and not rendered, since this is not
# the cell's last expression); presumably this call was meant for inspection.
get_feature_log_weights(stats, 'documentType')

if DEBUG:
    # A bare `stats` inside an `if` body is never rendered; display explicitly.
    display(stats)

In [None]:
# from trainsets.trainset_tools import get_feature_log_weights
# _w=get_feature_log_weights(umtm.stats, 'subject')
# _w*_w*_w

In [None]:
%%time

import datetime
from trainsets.trainset_tools import get_feature_log_weights

# Per-subject weights from the project helper, then cubed.
subject_weights = get_feature_log_weights(stats, 'subject')
subject_weights = subject_weights * subject_weights * subject_weights

# Compute two weights per document and store them in `stats`:
#   sample_weight  = base (0.5, or 5.0 if user-corrected) * value_weight / error_weight
#   subject_weight = cubed per-subject weight
# Both get a small random jitter in [0, 0.05) added.
for i, row in stats.iterrows():
  subj_name = row['subject']

  # error weight: documents listed in errors_df are penalised by 1 + severity
  error_weight = 1.0 
  if i in errors_df.index:
      error_weight = 1.0 + errors_df.at[i, 'errors severity']


  sample_weight = 0.5 
  value_weight = 1.0
    
  if i in errors_df.index:
       
      # boost when the errors table holds a string in its 'Дата' column for this document
      if type(errors_df.at[i, 'Дата'])==str:
#             print (errors_df.at[i, 'Дата'], i, 'EXISTS')
          value_weight *=1.5
    
  if not pd.isna(row['user_correction_date']):  # more weight for user-corrected datapoints
    sample_weight = 5.0   # TODO: must be estimated anyhow smartly    

  
  if not pd.isna(row['value_log1p']):
    # The weight is proportional to the logarithm of the contract price
    # (so that there are fewer errors in contracts for large sums).
    # NOTE(review): this assignment replaces value_weight, so the *1.5 boost
    # above is lost whenever value_log1p is present -- confirm this is intended.
    value_weight = 1.0 + row['value_log1p']

  sample_weight *=  value_weight 
  subject_weight =  subject_weights[subj_name] 
    
  sample_weight /= error_weight  

  # NOTE(review): no seed for `random` is set in the visible notebook, so the
  # jitter presumably differs from run to run.
  stats.at[i, 'subject_weight'] = subject_weight + random.random()*0.05
  stats.at[i, 'sample_weight'] = sample_weight + random.random()*0.05

# normalize weights, so the sum == Number of samples
# stats.sample_weight /= stats.sample_weight.mean()
# stats.subject_weight /= stats.subject_weight.mean()

# Summary of the (not yet normalized) weights.
print(f'{stats.sample_weight.mean()=}')
print(f'{stats.subject_weight.mean()=}')
print(f'{stats.sample_weight.min()=}')
print(f'{stats.subject_weight.min()=}')
print(f'{stats.sample_weight.max()=}')
print(f'{stats.subject_weight.max()=}')

# Persist the filtered, weighted table into work_dir.
# NOTE(review): the file name matches the one mentioned next to
# load_contract_trainset_meta(); if it is the same file, the rows filtered out
# above are dropped from it -- confirm.
stats.to_csv( work_dir / 'contract_trainset_meta.csv', index=True)

# stats

# Validating training set

In [None]:
%%time


# stats['valid'] = True
stats['error'] = ''

# Validate every datapoint. A failing one is logged and flagged, not fatal.
_n_total = len(stats.index)
for pos, i in enumerate(stats.index):
  try:
    validate_datapoint(str(i), stats)
  except Exception as e:
    ms = f'{pos} of {_n_total} : {e}'
    logger.error(ms)

    stats.at[i, 'valid'] = False
    stats.at[i, 'error'] = str(e)

# The loop only ever writes False into 'valid'. Without this step the mask
# below raised KeyError when the column did not exist yet, or failed on the NaN
# left in rows that never failed. Rows with no verdict count as valid; existing
# False flags are kept.
if 'valid' not in stats.columns:
  stats['valid'] = True
stats['valid'] = stats['valid'].isna() | stats['valid'].eq(True)

# Remove big docs from the trainset. The filter used to be applied to
# `umtm.stats`, which is overwritten at the end of this cell, so it never
# reached the table the batch generator reads. Apply it to `stats` itself.
stats = stats[pd.isna(stats.value_span) | (stats.value_span < 10000)]

stats_valid = stats[stats['valid']]
print(len(stats_valid))

# From here on the notebook and the trainset manager share the validated table.
stats = stats_valid
umtm.stats = stats_valid

In [None]:
_weight_columns = ('sample_weight', 'subject_weight')

# Means before normalization.
for _col in _weight_columns:
    print(f'stats.{_col}.mean()={stats[_col].mean()!r}')

# Rescale each weight column so that its mean becomes 1.
for _col in _weight_columns:
    stats[_col] = stats[_col] / stats[_col].mean()



In [None]:
# Means after normalization, followed by the range of the sample weights.
_sample_w = stats['sample_weight']
_subject_w = stats['subject_weight']
print(f'stats.sample_weight.mean()={_sample_w.mean()!r}')
print(f'stats.subject_weight.mean()={_subject_w.mean()!r}')

print('\n\nsample_weight')
for _label, _reducer in (('MIN', _sample_w.min), ('MAX', _sample_w.max), ('MEAN', _sample_w.mean)):
    print(f'{_label}\t', _reducer())

In [None]:
from sklearn.utils import class_weight
from trainsets.trainset_tools import get_feature_log_weights

# Distinct subject labels present in the validated trainset.
_subject_labels = stats_valid['subject']
_classes = _subject_labels.unique().tolist()
print(f'classes: {_classes}')

# Per-subject weights from the project helper (sklearn's 'balanced' class
# weights were the previous approach); the bare name renders as cell output.
class_weights = get_feature_log_weights(stats_valid, 'subject')
class_weights

In [None]:
from trainsets.trainset_tools import get_feature_log_weights
import mlflow

# Overlaid histograms of both weight columns, logged to MLflow as a figure.
fig = plt.figure(figsize=(13, 6))
_hist_ax = fig.gca()

for _col in ('subject_weight', 'sample_weight'):
    stats_valid[_col].hist(bins=40, alpha=0.5, ax=_hist_ax)

_hist_ax.set_xscale('linear')  # log?
plt.show()
mlflow.log_figure(fig, 'Weights Distribution.png')

In [None]:
# fig = plt.figure(figsize=(13, 6))

# Joint distribution of the two weights over the validated trainset.
p = sns.jointplot(
    data=stats_valid,
    x="subject_weight",
    y="sample_weight",
)
plt.show()
# print(p)

# look into trainset (take a sample)

In [None]:
%matplotlib inline
# umtm.calculate_samples_weights()
 
# Take the third validated document as a sample and look at its semantic map.
SAMPLE_DOC_ID = stats_valid.index[2]
print('SAMPLE_DOC_ID', SAMPLE_DOC_ID)

_sample_datapoint = make_xyw(SAMPLE_DOC_ID, stats)
(emb, tok_f), (sm, subj), (sample_weight, subject_weight) = _sample_datapoint

print('semantic map shape is:', sm.shape)

# Only the first `_crop` rows are plotted.
_crop = 500
plot_embedding(sm[:_crop], title=f'Semantic map {SAMPLE_DOC_ID}')



In [None]:

# Debug view: every second column of the sample's semantic map (starting at
# column 1), limited to the first 200 rows.
if DEBUG: 
  plot_embedding(sm[:, 1::2][:200], title=f'Semantic map {SAMPLE_DOC_ID}')

In [None]:
if DEBUG:
    # Rows of the sample's semantic map that have a positive value in any of
    # the odd-indexed columns.
    nonzerozz = np.where(sm[:, 1::2] > 0)[0]
    max_len = 1536

    # Simulate 2000 random crop windows around those rows and add 1 to every
    # row a window covers, on top of the map scaled by 100.
    sm = sm*100.
    _last_row = len(emb) - 1
    for _step in range(2000):
        segment_center = random.choice(nonzerozz)
        _off = random.randint(-max_len//40, max_len//2)

        # Clamp the window start into [0, len(emb) - 1].
        start_from = min(max(segment_center - _off, 0), _last_row)

        sm[start_from: start_from+max_len] += 1

    plot_embedding(sm, title=f'Semantic map {SAMPLE_DOC_ID}')


# Batch generator & TODOs 🙏


- [X] TODO: add outliers to the trainset ?
- [ ] TODO: try sparse_categorical_entropy instead of one-hot encodings
- [ ] TODO: model 5.2, 5.1: bipolar concat layer is wrong because we concatenate things of different magnitudes. Add a Sigmoid activation layer
- [ ] TODO: check what is better: to pad with zeros or to pad with means
- [X] TODO: add weights to samples
- [ ] TODO: sum semantic map alongside vertical axis, and multiply it (as a mask) by the subject detection seq
- [ ] TODO: introduce individual per-tag thresholds; also, the current 0.3 threshold is strange.

In [None]:


# Upper bound for the window length handed to trim_maxlen().
MAX_LEN = 1536
def make_generator(self, indices: [int], batch_size: int, augment_samples=False):
  """Endlessly yield training batches built from the notebook-level `stats`.

  Each batch draws `batch_size` document ids from `indices` with
  `np.random.choice` (i.e. with replacement), builds every datapoint with
  `make_xyw(doc_id, stats)` and passes it through
  `UberModelTrainsetManager.trim_maxlen(dp, start_from, max_len)`.

  Args:
    self: not used by the body; callers pass the trainset manager here.
    indices: document ids to sample from.
    batch_size: number of documents per batch.
    augment_samples: when True, `max_len` is drawn once per batch from
      [300, MAX_LEN] and, for about 60% of the documents, the window start is
      placed at a random offset around a row that has a positive value in one
      of the odd-indexed semantic-map columns. Otherwise every window starts
      at 0 and `max_len` is MAX_LEN.

  Yields:
    ([embeddings, token_features], [semantic_maps, subjects],
     [sample_weights, subject_weights]), each converted with `np.array`.

  Raises:
    ValueError: if a sample weight or subject weight is NaN.
  """
  while True:
    # next batch
    batch_indices = np.random.choice(a=indices, size=batch_size)

    # One window length per batch.
    max_len = MAX_LEN
    if augment_samples:
      max_len = random.randint(300, MAX_LEN)

    batch_input_emb = []
    batch_input_token_f = []
    batch_output_sm = []
    batch_output_subj = []

    weights = []
    weights_subj = []

    # Read in each input, perform preprocessing and get labels
    for doc_id in batch_indices:

      dp = make_xyw(doc_id, stats)
      (emb, tok_f), (sm, subj), (sample_weight, subject_weight) = dp

      # Multiplier for the subject weight; currently always 1.0.
      subject_weight_K = 1.0

      start_from = 0
      if augment_samples and random.random() < 0.6:  # 60% of samples
        # rows with a positive value in any odd-indexed column (the original
        # note calls these the end marks)
        nonzerozz = np.where(sm[:, 1::2] > 0)[0]

        # The emptiness check used to come *after* random.choice(), so a
        # document without any such row raised before the fallback could
        # apply. Test first and centre on row 0 in that case.
        if len(nonzerozz) == 0:
          segment_center = 0
        else:
          segment_center = random.choice(nonzerozz)

        _off = random.randint(-max_len//10, max_len//2)
        # Clamp the window start into [0, len(emb) - 1].
        start_from = min(max(segment_center - _off, 0), len(emb) - 1)

      dp = UberModelTrainsetManager.trim_maxlen(dp, start_from, max_len)

      # TODO: find samples maxlen

      (emb, tok_f), (sm, subj), (sample_weight, subject_weight) = dp
      subject_weight *= subject_weight_K

      batch_input_emb.append(emb)
      batch_input_token_f.append(tok_f)

      batch_output_sm.append(sm)
      batch_output_subj.append(subj)

      # A NaN weight would silently poison the loss; fail loudly and say which
      # document is responsible.
      if np.isnan(sample_weight):
        raise ValueError(f'sample_weight is NaN for document {doc_id}')

      if np.isnan(subject_weight):
        raise ValueError(f'subject_weight is NaN for document {doc_id}')

      weights.append(sample_weight)
      weights_subj.append(subject_weight)

    # Returns a tuple of (input, output, weights) to feed the network
    yield ([np.array(batch_input_emb), np.array(batch_input_token_f)],
           [np.array(batch_output_sm), np.array(batch_output_subj)],
           [np.array(weights), np.array(weights_subj)])
    

    
# Stratified (by subject) split of the validated documents, with a fixed seed.
_train, _test = train_test_split(
    stats_valid,
    test_size=TEST_SIZE,
    stratify=stats_valid[['subject']],
    random_state=TRAIN_TEST_SPLIT_SEED,
)

train_indices = _train.index.tolist()
test_indices = _test.index.tolist()

####---test
# Smoke test: pull one augmented batch of 10 and plot its first document.
_gen = make_generator(umtm, train_indices, batch_size=10, augment_samples=True)
sample = next(_gen)
del _gen

(emb, tok_f), (sm, subj), (sample_weight, subject_weight) = sample

print('semantic map shape is:', sm.shape)
_crop = 1500

_ = plot_embedding(tok_f[0][:_crop], title='Tokens features')

_first_semantic_map = pd.DataFrame(sm[0], columns=semantic_map_keys_contract)
_ = plot_embedding(_first_semantic_map[:_crop], title='Semantic map', height=8)

## [debug] Diagnose SM Rows in TS

In [None]:
# plot_embedding(pd.DataFrame( sm[0], columns= semantic_map_keys_contract) [:_crop],    title=f'Semantic map', height=8)
if DEBUG:
    # Stack semantic-map column 4 of the first document from each of
    # `batch_size` augmented batches (first `_crop` rows each) into one
    # matrix, plot it and log the figure.
    _crop = 350
    batch_size = 100
    mtx = np.zeros((batch_size, _crop))

    _gen = make_generator(umtm, train_indices, batch_size=batch_size, augment_samples=True)

    for i in range(batch_size):
        sample = next(_gen)
        (emb, tok_f), (sm, subj), (sample_weight, subject_weight) = sample

        # columns 4..5 of the first document in the batch, cropped
        sub = sm[0][:, 4:6][:_crop]
        _first_column = sub[:, 0]
        mtx[i, :len(_first_column)] = _first_column

    fig = plot_embedding(
        mtx.T,
        title=f'Semantic map, combined Date rows of {batch_size} samples',
        height=5,
    )
    mlflow.log_figure(fig, 'Diagnose SM Rows in TS.png')
    del _gen
    del mtx

In [None]:
from IPython.display import display, HTML, Markdown

# Record the split sizes in MLflow, then show them as markdown headings.
_n_train, _n_eval = len(train_indices), len(test_indices)

for _param_name, _size in (('training set', _n_train), ('evaluation set', _n_eval)):
    mlflow.log_param(_param_name, _size)

for _size, _split_name in ((_n_train, 'training'), (_n_eval, 'validation')):
    _s = f"#### {_size} -- total  docs {_split_name} set"
    display(Markdown(_s))

In [None]:
# Debug view: build the datapoint of the first document in `stats` (reading
# from `stats_valid`) and plot the first 500 rows of its semantic map.
if DEBUG:
    dp = make_xyw(stats.index[0], stats_valid)

    (emb, tok_f), (sm, subj), (sample_weight, subject_weight) = dp

    fig = plot_embedding(sm[:500],    title=f'Semantic map ')

In [None]:


# Show the first id of each split as a quick check of what the split produced.
print('train_indices[0]:', train_indices[0])
print('test_indices[0]:', test_indices[0])


def plot_subject_distr(df:DataFrame, title):
    """Plot how many rows of `df` fall into each 'subject' value.

    The chart is saved under `reports_dir` as
    'Distribution of subjects -<title>.png', shown inline, and the saved file
    is logged to MLflow as an artifact.

    Args:
        df: frame with a 'subject' column.
        title: label used in the chart title and in the file name.
    """
    column = 'subject'

    plt.figure(figsize=(16, 4))
    sns.set(style="whitegrid")
    ordered = df.sort_values(column)
    sns.countplot(y=column, data=ordered)
    plt.title(f'{title}: Frequency Distribution of subjects; {len(df)} total')

    image_file = reports_dir / f'Distribution of subjects -{title}.png'
    plt.savefig(image_file, bbox_inches='tight', pad_inches=0)
    plt.show()

    mlflow.log_artifact(image_file)


 
# Subject distribution of the whole validated set and of each split.
_distribution_views = (
    ('ALL', stats_valid),
    ('training', stats_valid[stats_valid.index.isin(train_indices)]),
    ('eval', stats_valid[stats_valid.index.isin(test_indices)]),
)
for _view_title, _view_frame in _distribution_views:
    plot_subject_distr(_view_frame, _view_title)


if DEBUG:
  # Pull one augmented training batch for manual inspection.
  train_gen = make_generator(umtm, train_indices, BATCH_SIZE, augment_samples=True)
  x, y, w = next(train_gen)
  
#   print('X:', len(x), 'X[0]=', x[0].shape, 'X[1]=', x[1].shape)
#   print('Y:', len(y), 'Y[0]=', y[0].shape, 'Y[1]=', y[1].shape)
  

#   plot_embedding(x[0][0], 'X2: Token Embeddings')
#   plot_embedding(x[1][0], 'X1: Token Features')
#   plot_embedding(y[0][0], 'Y: Semantic Map')
  
#   print(y[0][1])

#   del x
#   del w
#   del y
#   del train_gen

In [None]:

# Training context: checkpoints are kept under the trainset manager's reports directory.
ctx = KerasTrainingContext(checkpoints_path=umtm.reports_dir, session_index=1)

# NOTE(review): the meaning and order of these three arguments is defined by
# KerasTrainingContext, which is not visible here; the call passes the batch
# size, the evaluation-set size and four times the training-set size -- confirm.
ctx.set_batch_size_and_trainset_size(BATCH_SIZE, 
                                     len(test_indices), 
                                     4 * len(train_indices))

DEFAULT_TRAIN_CTX = ctx
CLASSES = 43
# NOTE(review): this rebinds FEATURES, which was imported from
# tf_support.super_contract_model earlier in the notebook; from here on the
# notebook sees 14 regardless of the imported value -- confirm they agree.
FEATURES = 14

metrics = [  'mse', 'binary_crossentropy']
# metrics = ['kullback_leibler_divergence', 'mse', 'binary_crossentropy']


def train(umodel):
  """Train `umodel` on the training split and validate on the test split.

  Training batches are augmented, validation batches are not. The actual
  fitting is delegated to `ctx.train_and_evaluate_model`.
  """
  validation_batches = make_generator(umtm, test_indices, BATCH_SIZE)
  training_batches = make_generator(umtm, train_indices, BATCH_SIZE, augment_samples=True)
  ctx.train_and_evaluate_model(
      umodel,
      generator=training_batches,
      test_generator=validation_batches,
  )

def overtrain(umodel):
  """Train `umodel` on the union of the training and test splits.

  The same union also feeds the validation generator (without augmentation),
  so the validation figures are not measured on held-out documents.
  """
  every_index = list(train_indices) + list(test_indices)
  validation_batches = make_generator(umtm, every_index, BATCH_SIZE)
  training_batches = make_generator(umtm, list(every_index), BATCH_SIZE, augment_samples=True)
  ctx.train_and_evaluate_model(
      umodel,
      generator=training_batches,
      test_generator=validation_batches,
  )


# Models 🦖

In [None]:
import analyser
def get_weights_filename(model_factory_fn):
    """Return the path of the '<factory name>.h5' weights file to load.

    The training context's checkpoint directory is tried first; if the file is
    not there, the path under `analyser.hyperparams.models_path` is returned
    instead (whether or not it exists). Each candidate is printed together
    with its existence flag.
    """
    file_name = f'{model_factory_fn.__name__}.h5'

    candidate = ctx.model_checkpoint_path / file_name
    found = candidate.is_file()
    print(found, candidate)
    if found:
        return candidate

    fallback = Path(analyser.hyperparams.models_path) / file_name
    print(fallback.is_file(), fallback)
    return fallback

# get_weights_filename(uber_detection_model_005_1_1)

## 🥰 Att model

In [None]:



model_factory_fn = make_att_model

# Weights file to start from; None means "start from freshly built weights".
#TODO: fix this mess with TRAIN_FROM_CP and CHECKPOINT_URL flags
weights = get_weights_filename(model_factory_fn) if TRAIN_FROM_CP else None

if not TEST_FLOW:

    umodel = make_att_model()
    print(f'{umodel.name=}')

    # Precedence: an MLflow checkpoint URL beats the local weights file.
    if CHECKPOINT_URL is not None:
        logger.info(f'LOADING {CHECKPOINT_URL}')
        umodel = mlflow.tensorflow.load_model(CHECKPOINT_URL)
        # TODO:TEST IT, FIX IT
    elif TRAIN_FROM_CP:
        logger.info(f'LOADING {weights}')
        print(f'LOADING {weights=}')
        # Layers are matched by name; mismatching ones are skipped.
        umodel.load_weights(weights, by_name=True, skip_mismatch=True)
    else:
        logger.warning(f'skip loading weights, because {TRAIN_FROM_CP=}')

    umodel.summary()

# raise "forsibly stopped"


In [None]:
# dot_img_file = f'{umodel.name}.png'
# keras.utils.plot_model(umodel, to_file=dot_img_file, show_shapes=True)



In [None]:
def train_and_evaluate_model(self, model:Model, generator, test_generator, retrain=False, lr=None):
    """Fit `model` from generators, logging to CSV and checkpointing the best weights.

    Args:
        self: training context; the body reads its `trained_models`,
            `EVALUATE_ONLY`, `session_index`, `model_checkpoint_path`,
            `EPOCHS`, `validation_steps`, `steps_per_epoch`, `reduce_lr` and
            `HISTORIES`, and calls `get_lr_epoch_from_log` / `save_stats`.
        model: compiled Keras model.
        generator: training batch generator.
        test_generator: validation batch generator.
        retrain: when True start from epoch 0 and overwrite the CSV logs;
            otherwise append to the logs and resume the learning rate and
            epoch found in the previous log.
        lr: optional learning rate; applied last, so it overrides a resumed one.

    Returns:
        The Keras History object, or None when `self.EVALUATE_ONLY` is set.
    """
    model_name = model.name
    print(f'model.name == {model_name}')
    self.trained_models[model_name] = model_name

    if self.EVALUATE_ONLY:
        print(f'training skipped EVALUATE_ONLY = {self.EVALUATE_ONLY}')
        return

    # Two copies of the per-epoch CSV log: one next to the checkpoints, one in
    # the current working directory.
    append_to_log = not retrain
    _log_fn = f'{model_name}.{self.session_index}.log.csv'
    _logger1 = CSVLogger(self.model_checkpoint_path / _log_fn, separator=',', append=append_to_log)
    _logger2 = CSVLogger(_log_fn, separator=',', append=append_to_log)

    # Keep only the weights with the lowest validation loss.
    checkpoint_weights = ModelCheckpoint(
        self.model_checkpoint_path / (model_name + ".h5"),
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    )

    if retrain:
        lr_logged, epoch = None, 0
    else:
        lr_logged, epoch = self.get_lr_epoch_from_log(model_name)

    # Resumed rate first, explicit `lr` second, so the explicit one wins.
    for rate in (lr_logged, lr):
        if rate is not None:
            K.set_value(model.optimizer.lr, rate)

    print(f'continue: lr:{K.get_value(model.optimizer.lr)}, epoch:{epoch}')

    history = model.fit(
        generator,
        batch_size=BATCH_SIZE,
        epochs=self.EPOCHS,
        initial_epoch=epoch,
        steps_per_epoch=self.steps_per_epoch,
        validation_data=test_generator,
        validation_steps=self.validation_steps,
        callbacks=[self.reduce_lr, checkpoint_weights, _logger2, _logger1],
    )

    self.HISTORIES[model_name] = history
    self.save_stats(model_name)

    return history


# Training driver: optionally trains `umodel`, then reports the weights of the
# 'O1_tagging' layer.
if not TEST_FLOW:

    if TRAIN:
      config.LR = LR
      ctx.unfreezeModel(umodel)
    #   umodel.summary()

      ctx.EPOCHS = EPOCHS
      ctx.EVALUATE_ONLY = False

      # NOTE(review): this rebinds the notebook-level BATCH_SIZE (72 in the
      # config cell) for everything that runs afterwards, while `ctx` was
      # sized with the earlier value -- confirm the step counts still match.
      BATCH_SIZE = 96
      # Both generators draw from train + test ids, so the validation numbers
      # of this run are not measured on held-out documents.
      test_gen = make_generator(umtm, train_indices + test_indices, BATCH_SIZE)
      train_gen = make_generator(umtm, train_indices + test_indices, BATCH_SIZE, augment_samples=True) 

      # Calls the notebook-level function defined above, passing `ctx` as its `self`.
      train_and_evaluate_model(ctx, umodel, train_gen, test_generator=test_gen, retrain=True, lr=config.LR)
    else:
      logger.warning(f'skip training, because TRAIN={TRAIN}')


    # get_weights() returns a list of arrays; an empty list means the layer has no weights.
    threshold = umodel.get_layer('O1_tagging').get_weights()
    if threshold:
        # presumably the learned tagging threshold -- TODO confirm against the model definition
        print('threshold=', threshold[0][0])

        mlflow.log_metric('trained_tags_threshold', threshold[0][0])

## Register model in MLFlow

In [None]:
# Upload the session checkpoint to MLflow. The file exists only if a training
# run has written it into the context's checkpoint directory; when the weights
# came from the models_path fallback instead, the unconditional upload stopped
# the notebook with a file-not-found error. Skip with a warning in that case.
_checkpoint_file = ctx.model_checkpoint_path / f'{model_factory_fn.__name__}.h5'
if _checkpoint_file.is_file():
    mlflow.log_artifact(str(_checkpoint_file))
else:
    logger.warning(f'no checkpoint at {_checkpoint_file}; nothing to log to MLflow')

# Evaluate last checkpoint

In [None]:
if not TEST_FLOW:
    # Drop the model used for training before building a fresh one.
    if umodel:
        del umodel

    model_fn = make_att_model
    # model_fn = uber_detection_model_003

    # Checkpoint written into the context's checkpoint directory.
    weights = ctx.model_checkpoint_path / f'{model_factory_fn.__name__}.h5'
    _loading_message = f'LOADING {weights}'
    logger.info(_loading_message)
    print(_loading_message)

    # Strict load: every layer has to match the checkpoint.
    umodel = make_att_model()
    umodel.load_weights(weights, by_name=False, skip_mismatch=False)
    umodel.trainable = False
    umodel.summary()

    # umodel = ctx.init_model(model_fn, trained=True, trainable=False, weights=ctx.model_checkpoint_path / f'{model_fn.__name__}.h5')

    #TODO: remove next 2 lines
    _model_name = umodel.name
    ctx.trained_models[_model_name] = _model_name
    models = ctx.trained_models

### training history

In [None]:
def plot_compare_models(
    models: [str],
    metrics, 
    title="metric/epoch",
    image_save_path = umtm.reports_dir):
  """Plot training vs validation curves for every model and metric.

  For each model name the training log is fetched with `ctx.get_log`. For each
  metric (names starting with 'val_' are skipped; the validation variant is
  looked up automatically as 'val_' + metric) one figure is drawn with the raw
  curve (faint) and a Gaussian-smoothed curve of the last 100 log rows, saved
  as '<model>-<metric>.png' under `image_save_path`, and shown.

  Args:
    models: model names known to the training context.
    metrics: metric column names of the training log.
    title: text used in each figure title.
    image_save_path: target directory; the default is captured when the
      function is defined.
  """
  _metrics = [m for m in metrics if not m.startswith('val_')]

  for m in models:

    data: pd.DataFrame = ctx.get_log(m)

    if data is None:
      logger.error('cannot plot')
      continue

    # The former `data.set_index('epoch')` discarded its result (a no-op), and
    # 'epoch' is read as a column below, so the call is dropped.
    for metric in _metrics:
      fig = plt.figure(figsize=(16, 6))

      for metric_variant in ['', 'val_']:
        key = metric_variant + metric
        if key not in data:
          continue

        x = data['epoch'][-100:]
        y = data[key][-100:]

        # training curve in blue, validation curve in red
        c = 'blue' if metric_variant == '' else 'red'
        plt.plot(x, y, label=f'{key}', alpha=0.2, color=c)

        y = y.rolling(4, win_type='gaussian').mean(std=4)
        plt.plot(x, y, label=f'{key} SMOOTH', color=c)

        plt.legend(loc='upper right')

      # The per-metric title set at the top of the loop was always overwritten
      # here, so figures for different metrics carried identical titles; keep
      # one title and include the metric in it.
      plt.title(f'{[m]} {title} ({metric})')
      plt.grid(True)
      img_path = os.path.join(image_save_path, f'{m}-{metric}.png')

      plt.savefig(img_path, bbox_inches='tight')
      plt.show()
      # Release the figure; one is created per model and metric.
      plt.close(fig)
    
# Plot the logged training history for every model the context knows about:
# first the total loss, then the losses of the two outputs.
if not TEST_FLOW:
    models = list(ctx.trained_models.keys())


    plot_compare_models(models, ['loss'], 'Loss')

    # plot_compare_models(models, ['O1_tagging_kullback_leibler_divergence'], 'TAGS: Kullback Leibler divergence')
    # plot_compare_models(models, ['O1_tagging_mse'], 'TAGS: MSE')
    # plot_compare_models(models, ['O2_subject_kullback_leibler_divergence'], 'Subj: Kullback Leibler divergence')
    # plot_compare_models(models, ['O2_subject_mse'],  'Subjects: MSE')

    plot_compare_models(models, ['O1_tagging_loss', 'O2_subject_loss'], 'Loss')


In [None]:
from tf_support.super_contract_model import make_xyw
# Reminder that the sample prediction below is known to be broken; the block
# is switched off with `if False:` and never runs.
logger.error("fix prediction!!")
if False:
    # third document that has a positive 'value'
    sample_index = umtm.stats [umtm.stats['value']>0].index[2]
    logger.info(f'making prediction for sample doc {sample_index}')

    x, y, _ = make_xyw(sample_index, umtm.stats)
    print(f'shape of x[0]={x[0].shape}')
    print(f'shape of x[1]={x[1].shape}')

    # add a leading batch axis of size 1 to both inputs
    t1 = np.expand_dims(x[0], axis=0)
    t2 = np.expand_dims(x[1], axis=0)

    print(f'shape of t1={t1.shape}')
    print(f'shape of t2={t2.shape}')
    print(f'umodel.name ={umodel.name}')

    prediction = umodel.predict(x=[t1, t2], batch_size=1)

    # first output, first (only) batch item, labelled with the semantic-map keys
    tagsmap = pd.DataFrame(prediction[0][0], columns=semantic_map_keys_contract)
    # .T
    plot_embedding(tagsmap[:500], f'Predicted Semantic Map {tagsmap.shape}')

# Evaluate recent model (with external notebook)

In [None]:
# Run the evaluation notebook: `-i` executes it in this notebook's namespace,
# so it sees the variables defined above, and `-t` prints timing information.
if not TEST_FLOW:
    %run -i -t {notebooks_dir}/eval_contract_uber_model.ipynb
 

In [None]:
# Print a link to the most recent MLflow run.
print('see results at')
_base_uri = mlflow.get_registry_uri()
_run_info = mlflow.last_active_run().info
print(f'{_base_uri}/#/experiments/{_run_info.experiment_id}/runs/{_run_info.run_id}')