# Training for N Batches with Dynamically Generated Data

In [113]:
import os
import json
import cv2
import string
import random
import albumentations as A
import copy
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, precision_recall_curve,f1_score, confusion_matrix, accuracy_score
from sklearn.naive_bayes import BernoulliNB 
import pickle
import warnings
warnings.filterwarnings('ignore') 

### Define Image Augmentation Functions


In [70]:
def generate_random_field(length = 0):
    """Return a random string of ASCII letters, digits and punctuation.

    If ``length`` is falsy (the default 0), a length between 2 and 20
    (inclusive) is drawn with ``random.randint`` first.
    """
    if not length:
        length = random.randint(2, 20)
    alphabet = string.ascii_letters + string.digits + string.punctuation
    chosen = random.choices(alphabet, k=length)
    return ''.join(chosen)

# Albumentations pipeline applied in agument_image to the 200x200 composited
# document image, before it is converted to grayscale and thresholded.
# NOTE(review): per the albumentations docs, a OneOf's own `p` is the chance the
# group runs at all and the inner `p` values act as relative weights for which
# single member is picked — confirm against the installed version.
# NOTE(review): `quality_lower` is deprecated/removed in newer albumentations
# releases (replaced by `quality_range`); pin the version or update if it errors.
transform = A.Compose([
        # compression artefacts
        A.ImageCompression(quality_lower=10, p=0.1),
        # noise
        A.OneOf([
            A.GaussNoise(p=0.8),
            A.ISONoise(p=0.2),
            A.MultiplicativeNoise(p=.05)
        ], p=0.1),
        # blur
        A.OneOf([
            A.MotionBlur(p=.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ], p=0.05),
        # rigid geometry: shift / scale / rotate
        A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.2, rotate_limit=30, p=0.5),
        # non-rigid distortions
        A.OneOf([
            A.OpticalDistortion(p=0.5),
            A.GridDistortion(p=.5),
            A.PiecewiseAffine(p=0.5),
        ], p=0.5),
        # contrast / sharpening
        A.OneOf([
            A.CLAHE(clip_limit=2),
            A.Sharpen(),
            A.Emboss(),
            A.RandomBrightnessContrast(),            
        ], p=0.05),
        # weather / lighting effects
        A.OneOf([
            A.RandomFog(),
            A.RandomRain(),
            A.RandomSnow(),
            A.RandomSunFlare(),            
        ], p=0.1),
        # colour jitter
        A.HueSaturationValue(p=0.01)
    ])

# OpenCV font faces; agument_image picks one at random for each text field it draws.
# The Hershey faces are listed once each (SCRIPT_SIMPLEX was previously missing and
# SCRIPT_COMPLEX duplicated, which skewed the sampling).
fonts = [
    cv2.FONT_HERSHEY_SIMPLEX,
    cv2.FONT_HERSHEY_COMPLEX,
    cv2.FONT_HERSHEY_PLAIN,
    cv2.FONT_HERSHEY_DUPLEX,
    cv2.FONT_HERSHEY_TRIPLEX,
    cv2.FONT_HERSHEY_COMPLEX_SMALL,
    cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
    cv2.FONT_HERSHEY_SCRIPT_COMPLEX,
    # NOTE(review): FONT_ITALIC is a style flag meant to be OR-ed with a face; used
    # alone it acts as FONT_HERSHEY_SIMPLEX | FONT_ITALIC — confirm this is intended.
    cv2.FONT_ITALIC]

def generate_target_dictionary():
    """Load and return the ``'target_data'`` entry of ``data_dictionary.json``.

    The file is read from the current working directory. Elsewhere in this
    notebook the result maps category keys to integer class labels.
    """
    with open('data_dictionary.json') as handle:
        return json.load(handle)['target_data']

##### Variables needed for image augmentation

In [71]:
# Paths and lookup tables shared by the augmentation and training code.
template_directory = 'templates_img'
# Read via a context manager so the file handle is closed instead of leaked.
with open('text_locations.json', 'r') as _locations_file:
    text_locations = json.load(_locations_file)
# NOTE(review): image_dir is not referenced in the visible code.
image_dir =  "data/"
backgrounds_dir = 'image_backgrounds'

# category key -> class label, loaded from data_dictionary.json
categories = generate_target_dictionary()

### Modeling Functions

In [28]:
def generate_models():
    """Create fresh, untrained Bernoulli naive Bayes models.

    Returns
    -------
    (models, nb_full)
        ``models`` maps every key of ``categories`` to its own
        ``BernoulliNB(alpha=0.1)`` (used one-vs-rest); ``nb_full`` is a single
        ``BernoulliNB(alpha=10)`` used as the multiclass model.
    """
    multiclass = BernoulliNB(alpha=10)
    per_category = {name: BernoulliNB(alpha=0.1) for name in categories}
    return per_category, multiclass

def fit_on_batch(x,y, models, nb_full):
    """Incrementally train every model on one batch via ``partial_fit``.

    ``nb_full`` learns all eight labels 0-7 directly; each entry of ``models``
    learns a binary target that is True where ``y`` equals that category's label.
    """
    nb_full.partial_fit(x, y, classes=[0, 1, 2, 3, 4, 5, 6, 7])
    for name, label in categories.items():
        is_this_class = y == label  # one-vs-rest target
        models[name].partial_fit(x, is_this_class, classes=[0, 1])


In [106]:
def _pick_unique(scores, best, undecided):
    """Per row, return the column index of the best score, or ``undecided`` on an exact tie."""
    is_best = scores == best(scores, axis=1, keepdims=True)
    return np.where(is_best.sum(axis=1) > 1, undecided, np.argmax(is_best, axis=1))


def _write_metrics(log, title, y_true, y_pred, total, labels=None):
    """Append one titled section of accuracy/precision/recall/F1 and a confusion matrix to ``log``."""
    log.write(f'\n\n{title}')
    log.write(f'\n\nData Size:\t{len(y_pred)}/{total}')
    log.write(f'\n\tAccuracy:\t{accuracy_score(y_true, y_pred)}')
    log.write(f'\n\tPrecision:\t{precision_score(y_true, y_pred, labels=labels, average="macro")}')
    log.write(f'\n\tRecall:\t{recall_score(y_true, y_pred, labels=labels, average="macro")}')
    log.write(f'\n\tF1:\t{f1_score(y_true, y_pred, labels=labels, average="macro")}\n')
    # No `labels` here, so an "undecided" (-1) column appears when any prediction tied.
    log.write(str(confusion_matrix(y_true, y_pred)))


def test_on_batch(x,y, models, nb_full, batch):
    """Evaluate the one-vs-rest ensemble and the multiclass model on one batch.

    Metrics are appended to ``training_log.txt`` in three sections:
    "Base OVR Ensemble" (all rows), "Filtered OVR Ensemble" (only rows where
    the ensemble produced a unique prediction) and "Base Multiclass".

    Parameters
    ----------
    x : feature matrix for the batch (as built in ``main``).
    y : true integer class labels for each row of ``x``.
    models : dict mapping each ``categories`` key to its one-vs-rest model.
    nb_full : the multiclass model.
    batch : iteration number written into the log header.

    Predictions that tie for the best score are marked ``-1``. That value is
    not a real class; the previous marker ``5`` collided with a genuine class
    label and corrupted both the base and filtered metrics.
    """
    undecided = -1
    class_labels = sorted(categories.values())

    multiclass_proba = nb_full.predict_proba(x)
    # Column k holds predict_proba(...)[:, 0] of the model for class k, i.e. the
    # probability of its "not this class" label (classes=[0, 1] in fit_on_batch).
    ovr_not_target = np.zeros((y.shape[0], multiclass_proba.shape[1]))
    for key, label in categories.items():
        ovr_not_target[:, label] = models[key].predict_proba(x)[:, 0]

    # Ensemble: class whose model is least confident the sample is NOT that class.
    ovr_pred = _pick_unique(ovr_not_target, np.amin, undecided)
    # Multiclass: most probable class.
    multiclass_pred = _pick_unique(multiclass_proba, np.amax, undecided)
    decided = ovr_pred != undecided

    with open("training_log.txt", "a") as log:
        log.write('\n\n\n' + '\n'.join(['*|' * 50] * 3) + '\n\n')
        log.write(f'\t\t\t\t\tIteration:{batch}')

        # `labels` keeps the undecided marker out of the macro averages; undecided
        # rows still count as misses in accuracy and recall.
        _write_metrics(log, 'Base OVR Ensemble', y, ovr_pred, len(y), class_labels)
        log.write('\n\n' + "===" * 20)

        _write_metrics(log, 'Filtered OVR Ensemble', y[decided], ovr_pred[decided], len(y))
        log.write('\n\n' + "===" * 20)

        _write_metrics(log, 'Base Multiclass', y, multiclass_pred, len(y), class_labels)

def checkpoint(models, nb_full, iteration):
    """Pickle the ensemble and multiclass models under ``models/`` for this iteration.

    Writes ``models/EnsembleModels_<iteration>`` and
    ``models/MulticlassModel_<iteration>``. The directory is created if needed.
    The files are opened with ``'wb'`` so a rerun replaces them; the previous
    append mode stacked pickles, and ``pickle.load`` would then return the
    stale first one.
    """
    os.makedirs('models', exist_ok=True)
    with open(f'models/EnsembleModels_{iteration}', 'wb') as file:
        pickle.dump(models, file)

    with open(f'models/MulticlassModel_{iteration}', 'wb') as file:
        pickle.dump(nb_full, file)



### Image functions

In [107]:
def load_backgrounds():
    """Read every image in ``backgrounds_dir`` as a 3-channel image and return them as a list.

    Entries for which ``cv2.imread`` returns None (unreadable or non-image
    files) are skipped. The list is now created once, before the loop;
    previously it was reset on every file, so only the last image was kept.

    Raises
    ------
    FileNotFoundError
        If no readable image is found, because agument_image needs at least one background.
    """
    backgrounds = []
    for filename in os.listdir(backgrounds_dir):
        img = cv2.imread(os.path.join(backgrounds_dir, filename), 1)
        if img is not None:
            backgrounds.append(img)
    if not backgrounds:
        raise FileNotFoundError(f'no readable background images found in {backgrounds_dir!r}')
    return backgrounds



def agument_image(image, backgrounds, locations=None):
    """Render random text onto a template, paste it onto a random background and augment it.

    Parameters
    ----------
    image : template image (as loaded with ``cv2.imread``). It is not modified;
        the text is drawn on a copy.
    backgrounds : non-empty list of background images; one is chosen uniformly.
    locations : optional iterable of dicts with ``'x'`` and ``'y'`` keys giving
        where each random text field is drawn. Defaults to the module-level
        ``doc_info`` for backward compatibility.

    Returns
    -------
    The 200x200 grayscale image after ``transform`` and Gaussian adaptive
    thresholding (pixel values 0 or 255).
    """
    fields = doc_info if locations is None else locations
    # random.choice is uniform; the old randint(0, len)-1 index gave the last background double weight.
    background_img = random.choice(backgrounds)
    # cv2.putText draws in place, and callers reuse one template for many samples,
    # so draw on a copy to stop text from accumulating across samples.
    img = image.copy()
    for loc in fields:
        font = random.choice(fonts)
        cv2.putText(img, generate_random_field(),
                    (int(loc['x']), int(loc['y'])), font,
                    1, (0, 0, 0), 1)

    # Enlarge the background by a random margin per axis (0, or 151..400 px) and
    # place the template at an offset within it. The offset is smaller than the margin,
    # so an 850x1100 template always fits.
    x_size = random.randint(-150, 400)
    y_size = random.randint(-150, 400)
    x_size = x_size if x_size > 150 else 0
    y_size = y_size if y_size > 150 else 0
    x_offset = int(x_size / 1.5)
    y_offset = int(y_size / 1.5)
    background_img = cv2.resize(background_img, (850 + x_size, 1100 + y_size))

    background_img[y_offset:y_offset + img.shape[0], x_offset:x_offset + img.shape[1]] = img
    img = transform(image=cv2.resize(background_img, (200, 200)))['image']

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)


# Module-level pass over the templates. Its lasting effect is the global
# `doc_info`, which ends up holding the text locations of the LAST template
# with non-empty locations; agument_image reads `doc_info` as a global.
# (A no-op `for i in range(3): pass` that used to follow has been removed.)
# NOTE(review): the `image` read here is not used by the visible code; it still
# makes a missing or unreadable template fail early.
for filename in text_locations:
    if text_locations[filename] != {}:
        image = cv2.imread(template_directory+ '/'+filename, 1)
        image = cv2.resize(image, (850, 1100)) 

        doc_info = text_locations[filename]
            

### Main Training Loop

In [110]:
def main(batch_size, num_batches, test_frequency, models=None, nb_full=None):
    """Train the models on ``num_batches`` freshly generated batches, testing periodically.

    Each iteration renders ``batch_size // <number of templates>`` augmented
    samples for every template in ``text_locations`` that has non-empty
    locations, labelled ``categories[filename[:4]]``. When
    ``iteration % test_frequency == 0`` (excluding iteration 0), the batch is
    evaluated with ``test_on_batch`` and the models are checkpointed instead of
    being trained on.

    Parameters
    ----------
    batch_size : int
        Maximum number of samples per iteration. Rows left over when it is not a
        multiple of the template count are dropped, not fed as blank images.
    num_batches : int
        Number of iterations.
    test_frequency : int
        Test/checkpoint every this many iterations.
    models, nb_full : optional
        Models to continue training. New ones come from ``generate_models``
        when ``models`` is falsy (``None`` or the legacy ``False``).

    Returns
    -------
    (models, nb_full)

    Raises
    ------
    ValueError
        If ``text_locations`` has no template with text locations.
    """
    # agument_image reads the module-level `doc_info` for where to draw text.
    # Without this declaration, the assignment below only created a local, and
    # every template was rendered with the last template's locations.
    global doc_info

    backgrounds = load_backgrounds()

    # Templates do not change between iterations, so read and resize them once.
    templates = []
    for filename in text_locations:
        if text_locations[filename] != {}:
            template = cv2.imread(template_directory+ '/'+filename, 1)
            templates.append((filename, cv2.resize(template, (850, 1100)), text_locations[filename]))
    if not templates:
        raise ValueError('text_locations contains no template with text locations')

    # Equal share of the batch per template. The old `batch_size//8` overflowed with
    # more than 8 templates and left zero rows labelled 0 with fewer.
    chunk = batch_size // len(templates)

    if not models:
        models, nb_full = generate_models()

    for iteration in range(num_batches):
        print(f'\nIteration:\t{iteration}')
        x = np.zeros(shape=(batch_size, 200*200))
        y = np.zeros((batch_size,))
        index = 0
        # generate our data
        for filename, template, locations in templates:
            doc_info = locations
            label = categories[filename[:4]]
            for _ in range(chunk):
                print(index, end='\r', flush=True)
                # Pass a copy: drawing is in place and the template is reused for every sample.
                img = agument_image(template.copy(), backgrounds)
                x[index] = np.reshape(img, (200*200))
                y[index] = label
                index += 1

        # Keep only the rows that were actually generated.
        x, y = x[:index], y[:index]

        if iteration % test_frequency == 0 and iteration != 0:
            test_on_batch(x, y, models, nb_full, iteration)
            checkpoint(models, nb_full, iteration)
        else:
            fit_on_batch(x, y, models, nb_full)
    return models, nb_full


In [115]:
# Run 7 iterations of 8000 samples. Iterations 3 and 6 are held out for
# testing and checkpointing (iteration % 3 == 0, excluding 0); the rest train.
models, nb_full = main(
    batch_size=8000,
    num_batches=7,
    test_frequency=3,
)


Iteration:	0
6012

{'Inco': BernoulliNB(alpha=0.1),
 'Teac': BernoulliNB(alpha=0.1),
 'Cons': BernoulliNB(alpha=0.1),
 'Publ': BernoulliNB(alpha=0.1),
 'Econ': BernoulliNB(alpha=0.1),
 'TeaF': BernoulliNB(alpha=0.1),
 'Gene': BernoulliNB(alpha=0.1),
 'Reaf': BernoulliNB(alpha=0.1)}



array([[-3.47265635, -3.47265635, -3.41858912, ..., -3.41858912,
        -3.36729583, -3.47265635],
       [-0.69314718, -0.69314718, -0.69314718, ..., -0.69314718,
        -0.69314718, -0.69314718],
       [-0.69314718, -0.69314718, -0.69314718, ..., -0.69314718,
        -0.69314718, -0.69314718],
       ...,
       [-0.69314718, -0.69314718, -0.69314718, ..., -0.69314718,
        -0.69314718, -0.69314718],
       [-0.69314718, -0.69314718, -0.69314718, ..., -0.69314718,
        -0.69314718, -0.69314718],
       [-0.69314718, -0.69314718, -0.69314718, ..., -0.69314718,
        -0.69314718, -0.69314718]])