In [None]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy
import matplotlib.pyplot as plt
from collections import defaultdict
import string

In [2]:
# Root directory under which the generated dataset folders are written.
# The original cluster path remains the default; set the CREATIVITY_DATA_ROOT
# environment variable to redirect output when running on another machine.
DATA_ROOT = os.environ.get(
    "CREATIVITY_DATA_ROOT",
    "/data/locus/project_data/project_data2/chenwu2/creativity_data",
)

In [3]:
def build_dicts(entities):
    """Build index <-> entity lookup tables, keeping first-occurrence order.

    Args:
        entities: Iterable of hashable entities. Duplicates are allowed and
            collapse onto the index of their first occurrence.

    Returns:
        A tuple ``(ind2entity, entity2ind)``: ``ind2entity`` is the list of
        distinct entities in first-seen order, and ``entity2ind`` maps each
        entity to its position in that list.
    """
    entity2ind = dict()
    ind2entity = []
    for entity in entities:
        # Test membership on the dict (constant time) rather than scanning the
        # list, so the whole pass stays linear in the number of entities.
        if entity not in entity2ind:
            entity2ind[entity] = len(ind2entity)
            ind2entity.append(entity)
    return ind2entity, entity2ind

def choose(arr, ratio_or_count):
    """Draw a random subset of ``arr`` without replacement.

    Args:
        arr: Sequence to sample from (must support ``len`` and indexing).
        ratio_or_count: A float is interpreted as the fraction of ``arr`` to
            keep (rounded to the nearest count); an int is the number of items
            to keep. NumPy integer and floating scalars are accepted too.

    Returns:
        ``arr`` itself (same object, original order) when the requested count
        is at least ``len(arr)``; otherwise a new list holding the requested
        number of elements taken from distinct positions, in random order.

    Raises:
        TypeError: If ``ratio_or_count`` is neither an integer nor a float.
        ValueError: If the resulting count is negative.
    """
    # bool is a subclass of int, but True/False is never a meaningful count.
    if isinstance(ratio_or_count, (bool, np.bool_)):
        raise TypeError(
            "ratio_or_count must be an int or a float, got {}".format(type(ratio_or_count).__name__)
        )
    if isinstance(ratio_or_count, (int, np.integer)):
        num = int(ratio_or_count)
    elif isinstance(ratio_or_count, (float, np.floating)):
        num = round(float(ratio_or_count) * len(arr))
    else:
        # An explicit exception (instead of `assert False`) still fires when
        # Python runs with assertions disabled.
        raise TypeError(
            "ratio_or_count must be an int or a float, got {}".format(type(ratio_or_count).__name__)
        )
    if num < 0:
        raise ValueError("cannot choose a negative number of items: {}".format(num))
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]

In [4]:
def form_triangle(hash_str, a, b, c):
    """Create one triangle training example keyed by a hash prefix.

    The prompt is ``hash_str`` followed by ``" tri: "``. The target repeats
    the prompt and then lists the three edges ``a b``, ``b c`` and ``c a``
    (tokens concatenated without spaces), separated by ``<sep>`` and closed
    with ``</a>``.

    Returns:
        A dict with the keys ``"input_text"`` and ``"target_text"``.
    """
    prompt = "".join((hash_str, " tri: "))
    edge_tokens = (a, b, "<sep>", b, c, "<sep>", c, a, "</a>")
    return {
        "input_text": prompt,
        "target_text": prompt + "".join(edge_tokens),
    }


def form_triangle_test(hash_str):
    """Create a held-out triangle prompt for a hash that has no stored triangle.

    The prompt has the same shape as in ``form_triangle``; the target is only
    a placeholder made of the prompt followed by the closing ``</a>`` token.

    Returns:
        A dict with the keys ``"input_text"`` and ``"target_text"``.
    """
    prompt = "".join((hash_str, " tri: "))
    placeholder_target = prompt + "</a>"
    return {"input_text": prompt, "target_text": placeholder_target}


def form_edge(u, v):
    """Create one edge training example.

    The prompt is the constant ``"edge: "``. The target repeats the prompt and
    then spells the edge in both directions (``u v`` then ``v u``, tokens
    concatenated without spaces), separated by ``<sep>`` and closed with
    ``</a>``.

    Returns:
        A dict with the keys ``"input_text"`` and ``"target_text"``.
    """
    prompt = "edge: "
    both_directions = "".join((u, v, "<sep>", v, u, "</a>"))
    return {"input_text": prompt, "target_text": prompt + both_directions}


In [None]:
# Graph / dataset configuration. build_dataset() reads these as module-level
# globals rather than taking them as arguments.
# NOTE(review): no random seed is set in the visible cells, so every run
# produces a different graph and dataset.
D = 3  # Target degree for the initial random-edge phase
alpha = 1.2  # Slack factor: a vertex is only offered new edges while its degree <= alpha * D
T = 6  # Number of triangle-closing steps credited to each vertex
num_nodes = 999  # Number of vertices in the graph
triangle_prob = 1/3  # Probability that a training draw is a triangle (otherwise an edge)
num_samples = 15000  # Number of training draws attempted (some draws may yield no example)


def generate_graph_with_triangles(D, alpha, T, num_nodes):
    """Randomly build an undirected graph, then close triangles around each vertex.

    Phase 1 visits the vertices in creation order and tops each one up towards
    degree ``D`` by linking it to randomly sampled vertices that are not yet
    adjacent to it and whose current degree is at most ``alpha * D``.

    Phase 2 visits every vertex ``u`` and, until ``u`` has been credited with
    ``T`` triangles, samples two of its neighbours, links them if they are not
    linked yet, and credits one triangle to each of the three vertices. A
    vertex with fewer than two neighbours is skipped after a message is
    printed.

    Args:
        D: Target degree for phase 1.
        alpha: Multiplier on ``D`` giving the degree limit for candidates.
        T: Number of triangle credits each vertex must reach in phase 2.
        num_nodes: Number of vertices; they are named ``<a_0>``, ``<a_1>``, ...

    Returns:
        Dict mapping each vertex name to the list of its neighbours. Every
        edge is stored in both endpoints' lists.
    """
    graph = {}
    for i in range(num_nodes):
        graph["<a_{}>".format(i)] = []

    degree_limit = alpha * D

    # Phase 1: random edges subject to the degree limit.
    for v, v_adj in graph.items():
        candidates = []
        for u, u_adj in graph.items():
            if u != v and u not in v_adj and len(u_adj) <= degree_limit:
                candidates.append(u)

        shortfall = max(0, D - len(v_adj))
        picks = random.sample(candidates, min(shortfall, len(candidates)))

        for u in picks:
            v_adj.append(u)
            graph[u].append(v)

    # Phase 2: close triangles; each closure is credited to all three corners.
    triangle_credit = dict.fromkeys(graph, 0)
    for u in graph:
        while triangle_credit[u] < T:
            u_adj = graph[u]
            if len(u_adj) < 2:
                print("Not enough neighbors to form a triangle")
                break

            v, w = random.sample(u_adj, 2)

            # Only insert the closing edge when it is missing; the credit is
            # given either way.
            if w not in graph[v]:
                graph[v].append(w)
                graph[w].append(v)

            for corner in (u, v, w):
                triangle_credit[corner] += 1

    return graph


def build_dataset(hash_str_len, num_test_samples=1024):
    """Generate the random graph and the train / test examples derived from it.

    Reads the module-level settings ``D``, ``alpha``, ``T``, ``num_nodes``,
    ``num_samples`` and ``triangle_prob``.

    Training data: ``num_samples`` draws are made. With probability
    ``triangle_prob`` a draw looks for a triangle: a random vertex and two of
    its neighbours are sampled repeatedly until the two neighbours are
    adjacent, and the triangle is stored under a fresh unique hash string.
    If the sampled vertex has fewer than two neighbours the draw is dropped.
    Otherwise the draw emits an edge example for a random vertex and one of
    its neighbours (nothing is emitted if the vertex has no neighbours).
    The training list can therefore be shorter than ``num_samples``.

    Test data: ``num_test_samples`` triangle prompts with fresh hash strings
    that were not used for any training example.

    Args:
        hash_str_len: Length of the random lowercase-alphanumeric hash string.
        num_test_samples: Number of held-out prompts to create.

    Returns:
        ``(entities_vocab, train_sequences, test_sequences, edges)`` where
        ``edges`` is the adjacency dict of the generated graph.

    Raises:
        ValueError: If more unique hashes are requested than strings of
            length ``hash_str_len`` exist.
    """
    entities_vocab = ["<a_{}>".format(i) for i in range(num_nodes)]

    edges = generate_graph_with_triangles(D, alpha, T, num_nodes)
    # The vertex list never changes, so build it once instead of on every draw.
    nodes = list(edges.keys())

    chars = string.ascii_lowercase + string.digits
    hash_space_size = len(chars) ** hash_str_len
    used_hashes = set()  # Every hash handed out so far (train and test)

    def new_unique_hash():
        """Return a random hash string that has not been handed out before."""
        # Without this guard the rejection loop below would never terminate
        # once every possible string has been used.
        if len(used_hashes) >= hash_space_size:
            raise ValueError(
                "hash space exhausted: all strings of length {} are in use".format(hash_str_len)
            )
        while True:
            candidate = "".join(random.choices(chars, k=hash_str_len))
            if candidate not in used_hashes:
                used_hashes.add(candidate)
                return candidate

    train_sequences, test_sequences = [], []
    for _ in tqdm(range(num_samples)):
        if random.random() < triangle_prob:
            # Rejection-sample until two neighbours of u are themselves adjacent.
            while True:
                u = random.choice(nodes)
                neighbors = edges[u]
                if len(neighbors) < 2:
                    break  # Drop this draw: u cannot be a triangle corner

                v, w = random.sample(neighbors, 2)
                if w in edges[v]:
                    # Reserve a hash only once a triangle is confirmed, so
                    # rejected attempts do not consume hash strings.
                    train_sequences.append(form_triangle(new_unique_hash(), u, v, w))
                    break
        else:
            u = random.choice(nodes)
            neighbors = edges[u]
            if neighbors:
                v = random.choice(neighbors)
                train_sequences.append(form_edge(u, v))

    for _ in range(num_test_samples):
        test_sequences.append(form_triangle_test(new_unique_hash()))

    return entities_vocab, train_sequences, test_sequences, edges

# Length of the random lowercase-alphanumeric string that prefixes each triangle example.
HASH_STR_LEN = 10

# Generate the graph plus the train / test examples; these names are reused by the cells below.
entity_vocab, train_sequences, test_sequences, edges = build_dataset(HASH_STR_LEN)

In [None]:
# Count ordered triangle triples (u, v, w): v and w are distinct neighbours of u
# that are adjacent to each other. Orderings are counted separately.
triangle_num = 0
for u, neighbors in edges.items():
    if len(neighbors) < 2:
        continue
    for v in neighbors:
        v_adj = edges[v]
        triangle_num += sum(1 for w in neighbors if v != w and w in v_adj)
print("triangle_num:", triangle_num)

# Count adjacency entries, i.e. directed edges (an edge stored in both
# endpoints' lists is counted twice).
edge_num = sum(len(adjacent) for adjacent in edges.values())
print("edge_num:", edge_num)

In [None]:
# Vocabulary = all entity tokens followed by the special tokens.
special_tokens = ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]
vocab = list(entity_vocab) + special_tokens
# Every token must be unique.
assert len(set(vocab)) == len(vocab)
print("vocab size:", len(vocab))

In [8]:
# Number of held-out prompts to keep; also the number of training examples probed in the export cell.
test_size = 1024
# choose() returns its input unchanged when the requested count is >= len(arr);
# build_dataset() creates 1024 test prompts, so this is a no-op unless test_size is lowered.
test_sequences = choose(test_sequences, test_size)

In [None]:
# Number of training examples actually generated; this can be below num_samples
# because build_dataset() drops draws that land on vertices with too few neighbours.
print(len(train_sequences))

In [None]:
# Output folder name: "triangle.<hash length>", with a ".T<T>" suffix only when
# T differs from 6.
dataset_name = "triangle.{}".format(HASH_STR_LEN)
if T != 6:
    dataset_name += ".T{}".format(T)
output_dir = os.path.join(DATA_ROOT, dataset_name)
os.makedirs(output_dir, exist_ok=True)

# The full training set is written as-is (no downsampling is applied).
train_sequences_ds = train_sequences

# Report how many distinct prompts exist versus the total number of examples.
input_texts = [example["input_text"] for example in train_sequences_ds]
unique_input_texts = list(set(input_texts))

print(len(unique_input_texts))
print(len(train_sequences_ds))

# Probe set: a random sample of training examples followed by all held-out
# prompts, each tagged with the split it came from.
probes = []
probe_sources = (
    ('train', choose(train_sequences_ds, test_size)),
    ('test', test_sequences),
)
for split_label, split_items in probe_sources:
    for item in split_items:
        probe = deepcopy(item)
        probe['type'] = split_label
        probes.append(probe)

# Write every artefact as JSON into the dataset folder. Note that valid.json
# holds the held-out prompts and test.json holds the tagged probe set.
json_outputs = [
    ("train.json", train_sequences_ds),
    ("valid.json", test_sequences),
    ("test.json", probes),
    ("vocab.json", vocab),
    ("edges.json", edges),
]
for file_name, payload in json_outputs:
    with open(os.path.join(DATA_ROOT, dataset_name, file_name), "w", encoding='utf-8') as f:
        json.dump(payload, f)