In [None]:
import os
import glob
import math
import random
import csv

# Per-dataset configuration, keyed by the lowercase id stored in `dataset_name`.
DATASET_INFO = {
    "premier": {
        # Dataset directory name under GRAPHS; also the prefix of each CSV row label.
        "name": "Premier",
        # Experiment sub-folders; one set of splits is built per entry.
        "classes": ["social", "ffmpeg", "avidemux"],
        # Root under which the per-class split .txt lists are written.
        "videos-path": "/Prove/Shullani/GNN-Video-Features/Premier/",
        # Output directory for the graph_<dataset>_<class>_training.csv files.
        "graph-training": "./graph-training/"
    }
}
# Key of the DATASET_INFO entry this notebook run operates on.
dataset_name = "premier"
# Root directory that is globbed for the .bin graph files.
GRAPHS = "/Prove/Shullani/GNN-Video-Features/small-graphs/"

In [None]:
def write_collection(txt_filename, video_list):
    """Write each string of ``video_list`` to ``txt_filename``, one per line.

    The file is opened in "w" mode, so any existing content is replaced.
    Every entry is followed by a newline, including the last one.
    """
    with open(txt_filename, "w") as out_file:
        out_file.writelines(entry + "\n" for entry in video_list)

def write_csv_collection(csv_filename, video_list):
    """Write ``video_list`` to ``csv_filename`` as CSV, one row per item.

    Args:
        csv_filename: Destination path; an existing file is overwritten.
        video_list: Iterable of rows, each row being an iterable of fields.
    """
    # newline="" hands line-ending control to the csv module. Without it,
    # platforms that translate "\n" on write turn the writer's "\r\n" row
    # terminator into "\r\r\n", which shows up as blank rows.
    with open(csv_filename, "w", newline="") as f:
        csv.writer(f).writerows(video_list)
    
def get_row(path):
    """Turn a graph file path into its ``<label><os.sep><video_name>`` row.

    ``label`` is the name of the directory that directly contains the file
    and ``video_name`` is the file name without its trailing ``.bin``.

    Args:
        path: Path made of at least two ``os.sep``-separated components.

    Returns:
        The string ``label + os.sep + video_name``.

    Raises:
        IndexError: If ``path`` contains no ``os.sep`` (no parent directory).
    """
    dirs_split = path.split(os.sep)
    video_name = dirs_split[-1]
    label = dirs_split[-2]
    # Strip only the extension: a plain str.replace would also remove ".bin"
    # occurring inside the name (e.g. "a.binned.bin" -> "aned").
    suffix = ".bin"
    if video_name.endswith(suffix):
        video_name = video_name[:-len(suffix)]
    return "{}{}{}".format(label, os.sep, video_name)


def get_split(data_set):
    """Map every path in ``data_set`` through ``get_row``.

    Returns a new list holding one row string per input path, in input order.
    """
    return [get_row(graph_path) for graph_path in data_set]


def get_all_files(graphs_path, dataset_name, experiment_name="social"):
    """Collect the ``.bin`` files of one experiment and the devices they come from.

    Args:
        graphs_path: Root directory of the graph files.
        dataset_name: Dataset sub-directory below ``graphs_path``.
        experiment_name: Experiment sub-directory below the dataset.

    Returns:
        A ``(files, devices)`` tuple: the list of matching paths and the set of
        device ids, where a device id is the part of the file name before the
        first underscore.
    """
    pattern = os.path.join(graphs_path, dataset_name, experiment_name, "**/*.bin")
    # glob is called without recursive=True, so "**" acts like a single "*":
    # only files exactly one directory below the experiment folder match.
    premier = glob.glob(pattern)
    device_list = {os.path.basename(bin_path).split("_")[0] for bin_path in premier}
    return premier, device_list


def get_train_test_files(graphs_path, dataset_name, experiment_name="social"):
    """Build leave-one-device-out train/test file lists.

    For each device found by ``get_all_files``, the test list holds the files
    matching ``<device>_*.bin`` and the train list holds every other file of
    the experiment. A one-line summary is printed per device.

    Returns:
        ``(train_set, test_set)``: two dicts keyed by device id, each mapping
        to a list of file paths.
    """
    premier, device_list = get_all_files(graphs_path, dataset_name, experiment_name)
    experiment_dir = os.path.join(graphs_path, dataset_name, experiment_name)
    train_set, test_set = {}, {}
    for dev in device_list:
        held_out = glob.glob(os.path.join(experiment_dir, "**/{}_*.bin".format(dev)))
        test_set[dev] = held_out
        # Training data is everything that does not belong to the held-out device.
        train_set[dev] = list(set(premier) - set(held_out))
        print(f"{dev}| test: {len(test_set[dev])}| train: {len(train_set[dev])}")
    return train_set, test_set


def get_train_test_valid_files(graphs_path, dataset_name, valid_perc=0.1, experiment_name="social", seed=None):
    """Build leave-one-device-out train/valid/test file lists.

    For each device found by ``get_all_files``:

    * test  = the files matching ``<device>_*.bin``;
    * valid = the files of ``floor(n_devices * valid_perc)`` other devices,
      picked at random;
    * train = every remaining file of the experiment.

    A one-line summary is printed per device.

    Args:
        graphs_path: Root directory of the graph files.
        dataset_name: Dataset sub-directory below ``graphs_path``.
        valid_perc: Fraction of the devices whose files form the validation
            set. The count is rounded down, so it is 0 (empty validation
            sets) whenever ``n_devices * valid_perc < 1``.
        experiment_name: Experiment sub-directory below the dataset.
        seed: Optional seed for the validation-device shuffle. When ``None``
            the module-level ``random`` generator is used, so a prior
            ``random.seed(...)`` call still controls the result.

    Returns:
        ``(train_set, test_set, valid_set)``: three dicts keyed by device id,
        each mapping to a list of file paths.
    """
    premier, device_list = get_all_files(graphs_path, dataset_name, experiment_name)
    # A set of str iterates in an order that changes between interpreter runs
    # (hash randomisation). Sorting gives the shuffle a fixed starting point,
    # without which even a seeded run would not be reproducible.
    devices = sorted(device_list)
    valid_devices = math.floor(len(devices) * valid_perc)
    rng = random if seed is None else random.Random(seed)

    # Glob once per device instead of once per (test device, valid device) pair.
    experiment_dir = os.path.join(graphs_path, dataset_name, experiment_name)
    files_by_device = {
        dev: glob.glob(os.path.join(experiment_dir, "**/{}_*.bin".format(dev)))
        for dev in devices
    }
    all_files = set(premier)

    test_set = {}
    train_set = {}
    valid_set = {}
    for dev in devices:
        test_set[dev] = list(files_by_device[dev])

        # build valid: files of randomly chosen devices other than the test device
        valid_dev_list = [other for other in devices if other != dev]
        rng.shuffle(valid_dev_list)
        valid_set[dev] = []
        for v_dev in valid_dev_list[:valid_devices]:
            valid_set[dev].extend(files_by_device[v_dev])

        # build train: whatever is neither test nor valid, in a stable order
        train = all_files - set(test_set[dev]) - set(valid_set[dev])
        train_set[dev] = sorted(train)
        print(f"{dev}| test: {len(test_set[dev])}| train: {len(train_set[dev])} | valid: {len(valid_set[dev])}")
    return train_set, test_set, valid_set


def write_split(output_path, data_set, set_type, experiment_name="social"):
    """Write one ``<device>_<experiment_name>_<set_type>.txt`` list per device.

    Args:
        output_path: Directory that receives the ``.txt`` files.
        data_set: Dict mapping a device id to its list of graph file paths.
        set_type: Split name used in the file name (e.g. "train").
        experiment_name: Experiment name used in the file name.

    Each path is converted with ``get_split`` before being written with
    ``write_collection``; the destination path is printed for every file.
    """
    for dev, dev_files in data_set.items():
        rows = get_split(dev_files)
        file_name = "{}_{}_{}.txt".format(dev, experiment_name, set_type)
        data_path = os.path.join(output_path, file_name)
        print(data_path)
        write_collection(data_path, rows)

# Build train/valid/test splits (leave-one-device-out)

In [None]:
# Compute the leave-one-device-out splits of every class of the selected dataset.
data_info = DATASET_INFO[dataset_name]
splits_dict = {i_class: {"train": [], "valid": [], "test": []} for i_class in data_info["classes"]}
for i_class in data_info["classes"]:
    print(i_class, "-------------------------------------")
    # The helper returns its three dicts in (train, test, valid) order.
    train_split, test_split, valid_split = get_train_test_valid_files(
        GRAPHS,
        dataset_name=data_info["name"],
        experiment_name=i_class,
    )
    class_splits = splits_dict[i_class]
    class_splits["train"] = train_split
    class_splits["test"] = test_split
    class_splits["valid"] = valid_split

# Write the split lists to disk

In [None]:
# Write the train/test/valid lists of every class below its videos-path folder.
# For the "premier" dataset the train and valid lists always come from the
# "social" class; the test list always comes from the class being written.
for i_class in data_info["classes"]:
    output_path = os.path.join(data_info["videos-path"], i_class)
    fit_class = "social" if dataset_name == "premier" else i_class
    for set_type, source_class in (("train", fit_class), ("test", i_class), ("valid", fit_class)):
        write_split(output_path, splits_dict[source_class][set_type], set_type, experiment_name=i_class)

# Build the graph-training CSV index for each class

In [None]:
# For every class, list its *_test.txt files and write one CSV row per file:
# [<dataset name>-<class>, <train list name>, <test list name>, <valid list name>].
test_paths = {i_class: [] for i_class in data_info["classes"]}
for key in test_paths:
    # collect all test.txt files
    test_paths[key] = glob.glob(os.path.join(data_info["videos-path"], key, "*_test.txt"))
    # File name up to its first dot, e.g. "<device>_<class>_test".
    test_stems = [os.path.basename(txt_path).split(".")[0] for txt_path in sorted(test_paths[key])]
    local_class = [
        [
            data_info["name"] + "-" + key,
            stem.replace("_test", "_train"),
            stem,
            stem.replace("_test", "_valid"),
        ]
        for stem in test_stems
    ]
    csv_path = os.path.join(data_info["graph-training"], f"graph_{dataset_name}_{key}_training.csv")
    write_csv_collection(csv_path, local_class)