In [1]:
import torch
import os
import random

import torch_geometric


from tqdm.notebook import tqdm_notebook

from torch_geometric.data import Data, Batch

from pathlib import Path

import zstandard as zstd
from io import BytesIO

In [2]:
def load_compressed_batch(file_path):
    """
    Read a zstd-compressed file and deserialize its content with torch.load.

    Args:
        file_path: path of the compressed file (anything accepted by open()).

    Returns:
        Whatever object torch.load reconstructs from the decompressed bytes.
    """
    with open(file_path, 'rb') as compressed_file:
        compressed_bytes = compressed_file.read()

    # The whole payload is decompressed in memory before being handed to torch.
    raw_bytes = zstd.ZstdDecompressor().decompressobj().decompress(compressed_bytes)

    # NOTE(review): torch.load unpickles the payload; only use on trusted files.
    return torch.load(BytesIO(raw_bytes))

In [3]:
def load_all_and_merge(number_reference_dict, data_dir):
    """
    Load one reference batch per key and merge all their graphs into one shuffled list.

    Args:
        number_reference_dict: dictionary with keys as atom letters and values as the reference number batch to load
        data_dir: pathlib.Path of the directory holding one sub-directory per key;
            the batch files are looked up in data_dir / key.

    Returns:
        A list with the graphs of every loaded batch, shuffled in place with
        random.shuffle (no seed is set inside this function).

    Raises:
        FileNotFoundError: if data_dir / key contains no file whose reference
            number equals the requested one.
    """
    all_graphs = []
    for key, value in tqdm_notebook(number_reference_dict.items()):
        key_dir = data_dir / key
        wanted_number = str(value)  # also accept reference numbers given as int

        # The reference number is the text between the last '_' and the next '.'
        # of the file name. The path is resolved afresh for every key so a
        # missing file can never fall back to the previous key's batch.
        batch_path = next(
            (
                key_dir / file_name
                for file_name in os.listdir(key_dir)
                if file_name.split('_')[-1].split('.')[0] == wanted_number
            ),
            None,
        )
        if batch_path is None:
            raise FileNotFoundError(
                "no batch with reference number {!r} found in {}".format(wanted_number, key_dir)
            )

        graph_batch = load_compressed_batch(batch_path)
        all_graphs += graph_batch.to_data_list()

    #shuffling the list of graphs
    random.shuffle(all_graphs)

    return all_graphs
    

In [20]:
def merge_under_represented(atom_letter, first_number, final_number, data_dir):
    """
    Merge several consecutive compressed batches of one atom type into a single batch.

    Args:
        atom_letter: name of the sub-directory of data_dir and prefix of the file names.
        first_number: number of the first batch file to load.
        final_number: number of the last batch file to load (inclusive).
        data_dir: pathlib.Path of the directory containing the atom_letter sub-directory.

    Returns:
        The result of Batch.from_data_list on the graphs of all loaded batches,
        in file-number order.
    """
    merged_graphs = []
    for batch_number in tqdm_notebook(range(first_number, final_number + 1)):
        batch_file = data_dir / atom_letter / '{}10000_{}.zst'.format(atom_letter, batch_number)
        merged_graphs.extend(load_compressed_batch(batch_file).to_data_list())

    return Batch.from_data_list(merged_graphs)
    
def save_graph_batch_zst(graph_batch, file_path):
    """
    Serialize graph_batch with torch.save and write it zstd-compressed to file_path.

    Args:
        graph_batch: object to serialize.
        file_path: destination path; opened in 'wb' mode, so an existing file is overwritten.
    """
    with open(file_path, 'wb') as out_file:
        zstd_writer = zstd.ZstdCompressor().stream_writer(out_file)
        torch.save(graph_batch, zstd_writer)
        # Explicit flush with FLUSH_FRAME before the file is closed
        # (per python-zstandard, this ends the current zstd frame).
        zstd_writer.flush(zstd.FLUSH_FRAME)

In [39]:
# Merge the 'F' batches numbered 51 through 60 into one batch and store it
# zstd-compressed as F100000_6.zst inside the 'F' folder of the dataset.
data_dir = Path('..', '..', 'DataPipeline', 'data', 'prepared_dataset')
graph_batch = merge_under_represented(
    atom_letter='F', first_number=51, final_number=60, data_dir=data_dir
)
save_graph_batch_zst(graph_batch=graph_batch, file_path=data_dir / 'F' / 'F100000_6.zst')

  0%|          | 0/10 [00:00<?, ?it/s]

In [22]:
# Display the repr of the object currently bound to `graph_batch`.
graph_batch

DataBatch(x=[1303092, 7], edge_index=[2, 2623676], edge_attr=[2623676, 4], y=[700000], batch=[1303092], ptr=[100001])

In [12]:
# For every key (a sub-directory of the dataset), the reference number of the
# batch file that should be loaded from it.
number_reference_dict = {
    'C': '123',
    'N': '24',
    'O': '27',
    'S': '25',
    'F': '26',
    'Cl': '26',
    'stop': '525',
}

# Dataset root, relative to the notebook's working directory.
data_dir = Path('..', '..', 'DataPipeline', 'data', 'prepared_dataset')



In [13]:
# Load the reference batch of every key in number_reference_dict and merge
# their graphs into one shuffled list.
all_graphs = load_all_and_merge(number_reference_dict, data_dir)

  0%|          | 0/7 [00:00<?, ?it/s]

In [18]:
# Total number of graphs in the merged list.
len(all_graphs)

430000

In [14]:
def compute_average_y(graph_list):
    """
    Stack the `y` tensor of every graph and return the sum over the graphs.

    Note: despite the function's name, the result is a sum along dim 0, not a mean.

    Args:
        graph_list: iterable of objects exposing a tensor attribute `y`; all
            `y` tensors must be stackable with torch.stack.

    Returns:
        torch.Tensor equal to the sum of all stacked `y` tensors along dim 0.

    Side effects:
        Prints the stacked tensor.
    """
    stacked_labels = torch.stack([graph.y for graph in graph_list])
    print(stacked_labels)
    return stacked_labels.sum(dim=0)

In [15]:
def count_number_of_occurence(graph_list):
    """
    Count, per position in `y`, how many graphs hold a value strictly between 0.4 and 0.6.

    Args:
        graph_list: iterable of objects exposing an iterable attribute `y`.

    Returns:
        dict mapping the 0-based position inside `y` to the number of graphs
        whose value at that position is in the open interval (0.4, 0.6).
        Positions that never match are absent from the dict.
    """
    counts_by_index = {}
    for graph in graph_list:
        for index, label in enumerate(graph.y):
            if not (0.4 < label < 0.6):
                continue
            counts_by_index[index] = counts_by_index.get(index, 0) + 1
    return counts_by_index

In [16]:
# Per position in `y`, count the graphs whose value lies strictly between 0.4 and 0.6.
count_number_of_occurence(all_graphs)

{1: 4227, 2: 6909, 0: 5887, 4: 100, 3: 277, 5: 48}

In [17]:
# Prints the stacked `y` tensors, then shows their sum over all graphs
# (a sum, not a mean, despite the function's name).
compute_average_y(all_graphs)

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 1.]])


tensor([101836.5000,  99431.5000,  98829.5000,   9940.5000,   9977.0000,
          9985.0000, 100000.0000])

In [11]:
# NOTE(review): same call as the earlier count cell; the saved output here has a
# lower execution count and different numbers, so it presumably reflects a
# different `all_graphs` — confirm, or drop this cell.
count_number_of_occurence(all_graphs)

{1: 3580, 2: 4688, 0: 2203, 3: 219, 5: 44, 4: 54}

In [7]:
# NOTE(review): same call as the earlier sum cell; the saved output here has a
# lower execution count and different numbers, so it presumably reflects a
# different `all_graphs` — confirm, or drop this cell.
compute_average_y(all_graphs)

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])


tensor([ 98898.5000,  98685.0000, 102313.0000, 100054.5000,  10027.0000,
         10022.0000, 100000.0000])