In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import json
import itertools
import pdb

def compare_dicts(dict1, dict2):
    """Return True when every checked key of ``dict1`` agrees with ``dict2``.

    ``dict1`` drives the comparison: only its keys are inspected, so keys that
    exist solely in ``dict2`` are ignored. Comparison rules per key:

    * ``'num_workers'`` is skipped entirely (it does not matter for equality).
    * ``'data_source'`` and ``'ts_encoder'`` are compared by their ``'name'``
      entry only.
    * ``'text_selection_method'`` and ``'negatives_creation'`` are compared on
      their first two elements, because they start out as tuples but come back
      from JSON as lists, so a plain ``!=`` would report a mismatch.
    * every other key is compared with ``!=``.

    Args:
        dict1: Reference parameters (the caller in this file passes the
            parameters read back from the results file).
        dict2: Candidate parameters to compare against.

    Returns:
        bool: False on the first mismatch, or when a checked key of ``dict1``
        is absent from ``dict2``; True otherwise.
    """
    name_only_keys = ('data_source', 'ts_encoder')
    pair_keys = ('text_selection_method', 'negatives_creation')

    for key, reference in dict1.items():
        if key == 'num_workers':  # num workers doesn't matter
            continue
        if key not in dict2:
            # A parameter recorded on one side but missing on the other means
            # the two configurations differ; report that instead of raising
            # KeyError.
            return False
        candidate = dict2[key]
        if key in name_only_keys:
            if reference['name'] != candidate['name']:
                return False
        elif key in pair_keys:
            # These have serialise/deserialise issues: originally a tuple,
            # saved (and reloaded) as a list, so compare element by element.
            if reference[0] != candidate[0] or reference[1] != candidate[1]:
                return False
        elif reference != candidate:
            return False
    return True
    
def check_args_not_used(data_parameters, model_parameters, output_file):
    """Check whether a parameter pair is absent from the stored results.

    Loads ``output_file`` as JSON and iterates over its entries, reading each
    entry's ``'dataset_params'`` and ``'model_params'`` and comparing them to
    the given parameters with ``compare_dicts``.

    Args:
        data_parameters: Dataset parameters to look for.
        model_parameters: Model parameters to look for.
        output_file: Path of the JSON results file.

    Returns:
        tuple: ``(False, entry)`` for the first stored entry whose dataset and
        model parameters both match, or ``(True, None)`` when nothing matches.
    """
    with open(output_file, 'r') as results_handle:
        stored_entries = json.load(results_handle)

    for entry in stored_entries:
        stored_dataset = entry['dataset_params']
        stored_model = entry['model_params']
        dataset_ok = compare_dicts(stored_dataset, data_parameters)
        if dataset_ok and compare_dicts(stored_model, model_parameters):
            # Already seen: report "not unused" together with the stored entry.
            return False, entry

    return True, None

def find_matching_combinations(json_file_path, model_param_grid, dataset_param_grid):
    """Collect the stored runs that match any combination of the two grids.

    Each grid maps a parameter name to a list of candidate values. Every
    Cartesian combination of model parameters is paired with every combination
    of dataset parameters (model combinations vary slowest), and for each pair
    the first entry of the results file whose ``'dataset_params'`` and
    ``'model_params'`` both satisfy ``compare_dicts`` is collected.

    The results file is parsed once and reused for all combinations instead of
    being re-opened for each one.

    Args:
        json_file_path: Path of the JSON results file (a list of entries).
        model_param_grid: Dict of model parameter name -> list of values.
        dataset_param_grid: Dict of dataset parameter name -> list of values.

    Returns:
        list: Matching stored entries, one per matched combination, in
        combination order. Empty when nothing matches or when either grid
        yields no combination (the file is not opened in that case).
    """
    def _expand(grid):
        # Cartesian product of the value lists, re-keyed by parameter name.
        names = list(grid.keys())
        return [dict(zip(names, values)) for values in itertools.product(*grid.values())]

    model_candidates = _expand(model_param_grid)
    dataset_candidates = _expand(dataset_param_grid)
    if not model_candidates or not dataset_candidates:
        # Nothing to look up, so the results file is never touched.
        return []

    # Loop-invariant: read and parse the results file a single time.
    with open(json_file_path, 'r') as results_file:
        stored_runs = json.load(results_file)

    matching_entries = []
    for model_param, dataset_param in itertools.product(model_candidates, dataset_candidates):
        for run in stored_runs:
            seen_dataset_params = run['dataset_params']
            seen_model_params = run['model_params']
            if compare_dicts(seen_dataset_params, dataset_param) and compare_dicts(seen_model_params, model_param):
                # Keep only the first stored run for this combination.
                matching_entries.append(run)
                break

    return matching_entries
# Model-side grid: each key maps to the list of values to combine.
model_param_grid = {
    # Other encoders left disabled here: "InformerModel", "AutoFormerModel".
    "ts_encoder": [{"name": "TimeSeriesTransformerModel"}],
    # Other text encoder left disabled here: "bert-base-cased".
    "text_encoder": [{"name": "bert-base-uncased"}],
    "text_encoder_pretrained": [True],
    "text_aggregation_method": ["mean"],
    "projection_dim": [400, 500, 600],
    "learning_rate": [1e-5],
    "optimizer": ["adam"],
    "criterion": ["CosineEmbeddingLoss"],
    "num_epochs": [10],
    "batch_size": [6],
    "num_workers": [4],
}

# Dataset-side grid: each key maps to the list of values to combine.
dataset_param_grid = {
    # Earlier note: 4, 6 & 7 had a random error out; candidates were 3, 4, 5, 6, 7, 10.
    "ts_window": [6],
    "ts_overlap": ["start"],
    # Candidates were 3, 4, 5, 6, 7.
    "text_window": [3],
    "text_selection_method": [("TFIDF", 5)],
    "data_source": [
        {
            "name": "EDT",
            "text_path": "./data/EDT/evaluate_news.json",
            "ts_path": "./data/stock_emotions/price/",
            "ts_date_col": "Date",
            "text_date_col": "date",
            "text_col": "text",
            "train_dates": "01/01/2020 - 03/09/2020",
            "test_dates": "04/09/2020 - 31/12/2020",
        },
        {
            "name": "stock_emotion",
            "text_path": "./data/stock_emotions/tweet/processed_stockemo.csv",
            "ts_path": "./data/stock_emotions/price/",
            "ts_date_col": "Date",
            "text_date_col": "date",
            "text_col": "text",
            "train_dates": "01/01/2020 - 03/09/2020",
            "test_dates": "04/09/2020 - 31/12/2020",
        },
        {
            "name": "stock_net",
            "text_path": "./data/stocknet/tweet/organised_tweet.csv",
            "ts_path": "./data/stocknet/price/raw/",
            "ts_date_col": "Date",
            "text_date_col": "created_at",
            "text_col": "text",
            "train_dates": "01/01/2014 - 01/08/2015",
            "test_dates": "01/08/2015 - 01/01/2016",
        },
    ],
    # Alternatives left disabled here: ("sentence_transformer_dissimilarity", "max"),
    # ("sentence_transformer_dissimilarity", "min"), ("naive", 30), ("naive", 45), ("naive", 60).
    "negatives_creation": [("sentence_transformer_dissimilarity", "mean")],
    "random_state": [42, 43, 44],
}

# Look up which of the grid combinations already have a stored run.
matching_results = find_matching_combinations(
    "./output_frand_normalized_plotting.json",
    model_param_grid,
    dataset_param_grid,
)
print(matching_results)

[{'end_time': None, 'search_index': 1006, 'epochs': 10, 'dataset_params': {'ts_window': 6, 'ts_overlap': 'start', 'text_window': 3, 'text_selection_method': ['TFIDF', 5], 'data_source': {'name': 'EDT', 'text_path': './data/EDT/evaluate_news.json', 'ts_path': './data/stock_emotions/price/', 'ts_date_col': 'Date', 'text_date_col': 'date', 'text_col': 'text', 'train_dates': '01/01/2020 - 03/09/2020', 'test_dates': '04/09/2020 - 31/12/2020'}, 'negatives_creation': ['sentence_transformer_dissimilarity', 'mean'], 'random_state': 42}, 'model_params': {'ts_encoder': {'name': 'TimeSeriesTransformerModel', 'ts_window': 6, 'context_length': 1, 'prediction_length': 0, 'lags_sequence': [1, 2, 3, 4, 5], 'num_features': 3}, 'text_encoder': {'name': 'bert-base-uncased'}, 'text_encoder_pretrained': True, 'text_aggregation_method': 'mean', 'projection_dim': 500, 'learning_rate': 1e-05, 'optimizer': 'adam', 'criterion': 'CosineEmbeddingLoss', 'num_epochs': 10, 'batch_size': 6, 'num_workers': 4}, 'df_len'

In [5]:
# Number of stored runs collected in matching_results.
len(matching_results)

9

In [3]:
import json

def save_to_json(data, file_path):
    """Serialise ``data`` as JSON into ``file_path``.

    The file is opened in write mode, so existing content is replaced. Output
    is pretty-printed with a four-space indent.

    Args:
        data: JSON-serialisable object to store.
        file_path: Destination path of the JSON file.

    Returns:
        None
    """
    with open(file_path, mode='w') as destination:
        json.dump(data, fp=destination, indent=4)


# Write the matched runs to ./normalized/windows.json; the file is opened with
# mode 'w', so the ./normalized directory has to exist beforehand.
save_to_json(matching_results, './normalized/windows.json')