In [1]:
import os
import pickle

# Location of the pickled dataset on disk.
data_folder = "./Data"
file_path = os.path.join(data_folder, "pytorch_database.pkl")

# Deserialize the whole database into memory.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only open files that come from a trusted source.
with open(file_path, mode="rb") as pickle_file:
    loaded_data = pickle.load(pickle_file)


In [2]:
# Show the top-level keys of the loaded dictionary to see which fields it provides.
loaded_data.keys()

dict_keys(['seq_tensor', 'sim_tensor', 'file_names', 'mode_descriptions'])

In [3]:
# Spot-check one entry of 'mode_descriptions' (the recorded output is a pair of mode labels).
# NOTE(review): index 124 looks like an arbitrary sample position -- confirm it has no special meaning.
loaded_data['mode_descriptions'][124]

('Keyhole2', 'Conduction2')

In [4]:
# Modes for testing: every ordered pairing of the two labels below,
# generated in the order (first label varies slowest).
_TEST_MODE_LABELS = ('Conduction2', 'Keyhole2')
test_modes = [
    (first_label, second_label)
    for first_label in _TEST_MODE_LABELS
    for second_label in _TEST_MODE_LABELS
]

In [5]:
mode_descriptions = loaded_data['mode_descriptions']

# Identify indexes for test data: keep the position of every sample whose
# mode description is one of the entries in test_modes.
test_indexes = []
for position, mode_pair in enumerate(mode_descriptions):
    if mode_pair in test_modes:
        test_indexes.append(position)


In [6]:
# Identify indexes for train data: every position that is not a test index.
# A set has no defined iteration order, so the difference is sorted to keep
# the training samples in their original dataset order on every run
# (list(set) would leave the row order to an implementation detail).
total_indexes = set(range(len(mode_descriptions)))
train_indexes = sorted(total_indexes - set(test_indexes))

# Function to extract sub-tensors based on indexes
def extract_sub_tensors(tensor, indexes):
    """Return the entries of ``tensor`` selected by ``indexes``.

    The selection is delegated to ``tensor[indexes]``, so the accepted index
    types and the copy/view behaviour are whatever the container's own
    indexing provides.

    NOTE(review): presumably ``tensor`` is a torch tensor and ``indexes`` a
    list of integer positions along its first axis -- confirm against the
    pickled data.
    """
    selected = tensor[indexes]
    return selected

# Allocate data to train and test dictionaries
def _build_split(indexes):
    """Assemble one split dictionary from ``loaded_data`` for the given indexes.

    The two tensor fields are sliced through ``extract_sub_tensors``; the two
    list fields are rebuilt element by element in the order of ``indexes``.
    """
    split = {}
    for tensor_key in ("seq_tensor", "sim_tensor"):
        split[tensor_key] = extract_sub_tensors(loaded_data[tensor_key], indexes)
    for list_key in ("file_names", "mode_descriptions"):
        split[list_key] = [loaded_data[list_key][i] for i in indexes]
    return split


train_data = _build_split(train_indexes)

test_data = _build_split(test_indexes)


In [7]:
# Function to print file names and modes for a given dataset
def print_file_info(dataset, dataset_name):
    """Print one line per sample listing its file name and mode.

    Args:
        dataset: Mapping with the keys "file_names" and "mode_descriptions";
            the two values are iterated in parallel (pairing stops at the
            shorter one).
        dataset_name: Label shown in the header line.

    Returns:
        None. Everything is written to standard output, followed by a
        trailing blank separator.
    """
    header = f"Files in {dataset_name}:"
    print(header)

    names = dataset["file_names"]
    modes = dataset["mode_descriptions"]
    for file_name, mode in zip(names, modes):
        line = f"File Name: {file_name}, Mode: {mode}"
        print(line)

    # print() appends its own newline, so this emits two line breaks.
    print("\n")


In [8]:
# Print the file name and mode of every sample in the train split.
# The listing has one line per sample, so the output can be long.
print_file_info(train_data, "Train Data")


Files in Train Data:
File Name: ('File_0', 'File_10'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_1', 'File_11'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_2', 'File_12'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_3', 'File_13'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_4', 'File_14'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_5', 'File_15'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_6', 'File_16'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_7', 'File_17'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_8', 'File_18'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_9', 'File_19'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_20', 'File_30'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_21', 'File_31'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_22', 'File_32'), Mode: ('Conduction 2', 'Conduction 2')
File Name: ('File_23', 'Fil

In [9]:
# Print the file name and mode of every sample in the test split.
# The listing has one line per sample, so the output can be long.
print_file_info(test_data, "Test Data")

Files in Test Data:
File Name: ('File_70', 'File_85'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_71', 'File_86'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_72', 'File_87'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_73', 'File_88'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_74', 'File_89'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_75', 'File_90'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_76', 'File_91'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_77', 'File_92'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_78', 'File_93'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_79', 'File_94'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_80', 'File_95'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_81', 'File_96'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_82', 'File_97'), Mode: ('Conduction2', 'Conduction2')
File Name: ('File_83', 'File_98'), Mode: ('C

In [10]:
import torch

# Persist both splits with torch.save; the target paths carry no file extension.
train_output_path = './Data/train_classification'
torch.save(obj=train_data, f=train_output_path)

test_output_path = './Data/test_classification'
torch.save(obj=test_data, f=test_output_path)