"""
Copyright 2024 Georgia Institute of Technology

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
 """

# Gradient Testing on Generation

In [None]:
%load_ext autoreload
%autoreload 2

import time
import sys
import numpy as np
from math import log
import math
from scipy.stats import entropy
from Generator.cython_print_ips import print_ips
oldstderr = sys.stderr

import matplotlib.pyplot as plt
from iplibrary import AS_Processor
from Generator.SeedObject import SeedObject
from Generator.RunGeneration import Training_Evaluation, ComparisonModel_online
from Generator.Generator import Generator_AUL
from Generator.ModelWrapper import ModelBase
from Generator.AdditionalGenerators import  IterativeOnPriorPatternsLowerGeneratorFaster, FixedLowBit
from Generator.AllocationGenerator import NaiveWeightUpdateWithoutZerosEvenFirst
from Generator.NN_models import GeneratorMaskedLSTM
from Generator.DatasetEncoders import AllocationEncoder
from Generator.Sampling import predict_base_function, predict_base_function_faster
from Generator.MultiProcessSIDGenerator import MultiProcessSIDGenerator
from Generator.MultiProcessIIDGenerator import MultiProcessIIDGenerator, MultiProcessSamplingIIDGenerator


import config as conf

## Run Model

In [None]:
# Load the seed dataset for this experiment and report how long the load took.
# The globals set here (Comparison_name, seedDatasetObject, t1, t2) are read by
# the following cell, so their names are kept as-is.
Comparison_name = "EXPERIMENT_NAME_HERE"  # placeholder: replace with the experiment name
t1 = time.time()
seedDatasetObject = SeedObject(
    Comparison_name,
    # Which lower-64 name sets to use, and the lower flag.
    lower=True,
    lower_names_to_use=["all_ips"],
    # File/checkpoint locations all come from the shared config module.
    sid_checkpoint=conf.CHECKPOINT_LSTM,
    dataset_prefix=conf.DATASET_FILE,
    prefix_filename=conf.UPDATED_PFX_AS_FILE,
)
t2 = time.time()
print("Seed Dataset Time: ", t2 - t1)

In [None]:
# Weighted Search - 100M - Balanced with Hits
# Run ComparisonModel_online over the seed dataset built in the previous cell and
# report total wall-clock time and time per IP.
# NOTE(review): assumes the first positional argument of ComparisonModel_online is
# the number of IPs to generate (the original cell divided the elapsed time by this
# same value) -- confirm against Generator.RunGeneration.
run_budget = 100_000_000  # single source for both the call and the per-IP denominator
t1 = time.time()
c = ComparisonModel_online(
    run_budget,
    Comparison_name,
    Generator_AUL,
    seedDatasetObject,
    ppi=1000000,
    per_iteration=100000,
    Lower64=MultiProcessIIDGenerator,
    Upper64=MultiProcessSIDGenerator,
    Allocations=NaiveWeightUpdateWithoutZerosEvenFirst,
    Upper64_HPs={
        "sampler": predict_base_function_faster,
        "model": GeneratorMaskedLSTM,
        "sampling_batch_size": 10500,
        "gpus": 8,
        "lr": 1e-3,
        "dropout": 0.2,
        "layers": [512, 256],
        "encoder": AllocationEncoder,
        "preload": True,
        "validation_split": 0.15,
    },
    Lower64_HPs={
        "Allocations": seedDatasetObject.allocation_proc_models.allocation_strings,
        "subprocesses": 40,
    },
    Allocation_HPs={
        "threshold": 0.5,
        "observation_window": 3,
    },
)
t2 = time.time()
print("Time: ", t2-t1)
# Divide by the same budget passed above so the two can never drift apart.
print("Time per IP:", (t2-t1)/run_budget)

## Train Model

In [None]:
# Re-create the seed dataset so the "Train Model" section can run on its own,
# then report the load time. Global names are unchanged because the next cell
# reads Comparison_name and seedDatasetObject.
Comparison_name = "EXPERIMENT_NAME_HERE"  # placeholder: replace with the experiment name
t1 = time.time()
seedDatasetObject = SeedObject(
    Comparison_name,
    prefix_filename=conf.UPDATED_PFX_AS_FILE,
    dataset_prefix=conf.DATASET_FILE,
    sid_checkpoint=conf.CHECKPOINT_LSTM,
    lower_names_to_use=["all_ips"],
    lower=True,
)
t2 = time.time()
print("Seed Dataset Time: ", t2 - t1)

In [None]:
# Weighted Search
# Run Training_Evaluation over the seed dataset built in the previous cell and
# report total wall-clock time and time per IP.
# The per-IP denominator used to be a hard-coded 100000000, copied from the 100M
# run cell. Here it is the budget actually passed to this call.
# NOTE(review): assumes the second positional argument of Training_Evaluation is
# the number of IPs generated (it sits where ComparisonModel_online takes its
# budget); the meaning of the leading 50 is not visible here -- confirm.
train_budget = 5_000_000
t1 = time.time()
c = Training_Evaluation(
    50,
    train_budget,
    Comparison_name,
    Generator_AUL,
    seedDatasetObject,
    ppi=100000,
    Lower64=FixedLowBit,
    Upper64=ModelBase,
    Allocations=NaiveWeightUpdateWithoutZerosEvenFirst,
    Upper64_HPs={
        "sampler": predict_base_function_faster,
        "model": GeneratorMaskedLSTM,
        "sampling_batch_size": 10500,
        "lr": 1e-3,
        "dropout": 0.2,
        "layers": [512, 256],
        "encoder": AllocationEncoder,
        "preload": False,
        "validation_split": 0.15,
    },
    Lower64_HPs={
        "Allocations": seedDatasetObject.allocation_proc_models.allocation_strings,
    },
    Allocation_HPs={
        "threshold": 0.5,
        "observation_window": 3,
    },
)
t2 = time.time()
print("Time: ", t2-t1)
print("Time per IP:", (t2-t1)/train_budget)