In [1]:
import os
import pandas as pd
import numpy as np
import csv
import json

## Plot


In [3]:
import json
from collections import Counter
import pandas as pd
import kaleido
import plotly.express as px

# Load your JSON data here
with open('../data/2018/design2018_cpc.json') as file:
    json_data = json.load(file)

# Count occurrences of each CPC code
cpc_counter = Counter()
for entry in json_data:
    for cpc in entry['cpc']:
        cpc_counter[cpc] += 1

# Aggregate counts by CPC classes (first 3 characters) and main classes (first character)
class_counter = Counter()
main_class_counter = Counter()

for cpc, count in cpc_counter.items():
    class_counter[cpc[:3]] += count
    main_class_counter[cpc[0]] += count

# Construct DataFrame for hierarchical plot
df_hierarchy = pd.DataFrame([
    ('CPC', None, sum(main_class_counter.values())),
] + [
    (f'{mc}', 'CPC', count) for mc, count in main_class_counter.items()
] + [
    (f'{cl}', f'{cl[0]}', count) for cl, count in class_counter.items()
] + [
    (cpc, f'{cpc[:3]}', count) for cpc, count in cpc_counter.items()
], columns=['id', 'parent', 'value'])

fig = px.sunburst(
    df_hierarchy,
    names='id',
    parents='parent',
    values='value',
    color_discrete_sequence=px.colors.qualitative.Pastel,
)

fig.update_traces(
    branchvalues='total',  # Ensures hierarchy spreads properly
    sort=False             # Optional: disables sorting which can distort layout
)

fig.update_layout(
    uniformtext=dict(minsize=10),
    height=700,
    width=700
)

fig.write_image('cpc_hierarchy_plot.svg')  # Uncomment if you want to save to PNG
fig.show()

In [7]:
import json
from collections import Counter
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# Load your JSON data here
with open('../data/2018/design2018_cpc.json') as file:
    json_data = json.load(file)

# Count occurrences of each CPC code
cpc_counter = Counter()
for entry in json_data:
    for cpc in entry['cpc']:
        cpc_counter[cpc] += 1

# Aggregate counts by CPC classes (first 3 characters) and main classes (first character)
class_counter = Counter()
main_class_counter = Counter()

for cpc, count in cpc_counter.items():
    class_counter[cpc[:3]] += count
    main_class_counter[cpc[0]] += count

# Construct DataFrame for hierarchical plot
df_hierarchy = pd.DataFrame(
    [('CPC', None, sum(main_class_counter.values()))] +
    [(mc, 'CPC', cnt) for mc, cnt in main_class_counter.items()] +
    [(cl, cl[0], cnt) for cl, cnt in class_counter.items()] +
    [(cpc, cpc[:3], cnt) for cpc, cnt in cpc_counter.items()],
    columns=['id', 'parent', 'value']
)

# Define pastel colors per main family
families = list(main_class_counter.keys())
palette = px.colors.qualitative.Pastel
# Extend palette if needed
colors = palette * ((len(families) // len(palette)) + 1)
family_colors = {fam: colors[i] for i, fam in enumerate(families)}

# Assign fill colors
def get_fill_color(label):
    if label == 'CPC':
        return '#FFFFFF'
    if label in family_colors:
        return family_colors[label]
    return family_colors[label[0]]

fill_colors = [get_fill_color(lbl) for lbl in df_hierarchy['id']]

# Highlight borders for A, A01, A01A
highlight = {'A', 'A01', 'A01A'}
line_colors = ['black' if lbl in highlight else '#FFFFFF' for lbl in df_hierarchy['id']]
line_widths = [3 if lbl in highlight else 0 for lbl in df_hierarchy['id']]

# Create sunburst chart
fig = go.Figure(go.Sunburst(
    labels=df_hierarchy['id'],
    parents=df_hierarchy['parent'],
    values=df_hierarchy['value'],
    branchvalues='total',
    marker=dict(
        colors=fill_colors,
        line=dict(color=line_colors, width=line_widths)
    ),
    insidetextorientation='radial'
))

# Layout adjustments for full circle
fig.update_layout(
    uniformtext=dict(minsize=10),
    height=700,
    width=700,
    margin=dict(t=40, l=20, r=20, b=20)
)

# Save to PNG (ensure kaleido is installed)
fig.write_image('cpc_hierarchy_highlight.svg')
fig.show()


In [10]:
import json
from collections import Counter
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# Load your JSON data here
with open('../data/2018/design2018_cpc.json') as file:
    json_data = json.load(file)

# Count occurrences of each CPC code
cpc_counter = Counter()
for entry in json_data:
    for cpc in entry['cpc']:
        cpc_counter[cpc] += 1

# Aggregate counts by CPC classes (first 3 characters) and main classes (first character)
class_counter = Counter()
main_class_counter = Counter()

for cpc, count in cpc_counter.items():
    class_counter[cpc[:3]] += count
    main_class_counter[cpc[0]] += count

# Construct DataFrame for hierarchical plot
df_hierarchy = pd.DataFrame(
    [('CPC', None, sum(main_class_counter.values()))] +
    [(mc, 'CPC', cnt) for mc, cnt in main_class_counter.items()] +
    [(cl, cl[0], cnt) for cl, cnt in class_counter.items()] +
    [(cpc, cpc[:3], cnt) for cpc, cnt in cpc_counter.items()],
    columns=['id', 'parent', 'value']
)


# Define pastel colors per main family
families = list(main_class_counter.keys())
palette = px.colors.qualitative.Pastel
colors = palette * ((len(families) // len(palette)) + 1)
family_colors = {fam: colors[i] for i, fam in enumerate(families)}

# Assign fill colors based on family
fill_colors = []
for lbl in df_hierarchy['id']:
    if lbl == 'CPC':
        fill_colors.append('#FFFFFF')
    elif lbl in family_colors:
        fill_colors.append(family_colors[lbl])
    else:
        fill_colors.append(family_colors[lbl[0]])

# Highlight borders for specific nodes and increase default white gaps
highlight = {'A', 'A01', 'A01G'}
line_colors = []
line_widths = []
for lbl in df_hierarchy['id']:
    if lbl in  highlight:
        # thicker black border for highlighted nodes
        line_colors.append('black')
        line_widths.append(2.5)
    else:
        # default white border for separation
        line_colors.append('white')
        line_widths.append(0.75)

# Create sunburst chart with increased separation
fig = go.Figure(go.Sunburst(
    labels=df_hierarchy['id'],
    parents=df_hierarchy['parent'],
    values=df_hierarchy['value'],
    branchvalues='total',
    marker=dict(
        colors=fill_colors,
        line=dict(color=line_colors, width=line_widths)
    ),
    insidetextorientation='radial'
))

# Layout adjustments for full circle
fig.update_layout(
    
    uniformtext=dict(minsize=10),
    height=700,
    width=700,
    margin=dict(t=40, l=20, r=20, b=20)
)

# Save to PNG (ensure kaleido is installed)
fig.write_image('cpc_hierarchy_highlight.svg')
fig.show()


In [None]:
# Define pastel colors per main family
families = list(main_class_counter.keys())
palette = px.colors.qualitative.Pastel
colors = palette * ((len(families) // len(palette)) + 1)
family_colors = {fam: colors[i] for i, fam in enumerate(families)}

# Assign fill colors based on family
fill_colors = []
for lbl in df_hierarchy['id']:
    if lbl == 'CPC':
        fill_colors.append('#FFFFFF')
    elif lbl in family_colors:
        fill_colors.append(family_colors[lbl])
    else:
        fill_colors.append(family_colors[lbl[0]])

In [8]:
df_hierarchy

Unnamed: 0,id,parent,value
0,All,,727402
1,Main Class A,All,232539
2,Main Class B,All,158097
3,Main Class G,All,96837
4,Main Class Y,All,16775
...,...,...,...
875,H25B,Class H25,7
876,Y02W,Class Y02,6
877,D04G,Class D04,22
878,C10J,Class C10,7


## CPC

### Main dataset (patent->CPC)

In [2]:
# Directory containing the files
directory_path = "patent_cpc_data"

# Initialize a list to store data from all files
all_data = []

# Loop through each file in the directory
for filename in os.listdir(directory_path):
    # Construct the full file path
    file_path = os.path.join(directory_path, filename)

    # Check if it's a .txt file
    if filename.endswith(".txt"):
        print(f"Processing file: {filename}")

        # Read the file and extract Patent ID and CPC
        with open(file_path, "r") as file:
            for line in file:
                if len(line) > 22:
                    try:
                        # Extract Patent ID (11 digits after the 10th position)
                        patent_id = line[10:21]

                        # Extract Main CPC (1 characters after the Patent ID)
                        main_cpc = line[21]

                        # Extract Big CPC (2 characters after the Patent ID)
                        big_cpc = line[21:24]

                        # Extract Medium CPC (4 characters after the Patent ID)
                        medium_cpc = line[21:25]

                        # Extract Refined CPC (remainder of the CPC code until a space)
                        refined_cpc = line[25:].split()[0]

                        # Append to the data list
                        all_data.append(
                            {
                                "Patent ID": patent_id,
                                "Main CPC": main_cpc,
                                "Big CPC": big_cpc,
                                "Medium CPC": medium_cpc,
                                "Refined CPC": refined_cpc,
                            }
                        )
                    except IndexError:
                        print(f"Skipping invalid line in {filename}: {line.strip()}")

# Convert the list of dictionaries to a dataframe
df = pd.DataFrame(all_data)

# Display the combined dataframe
print(df)

# Save the dataframe to a CSV file (optional)
output_file = "combined_patent_cpc_data.csv"
df.to_csv(output_file, index=False)
print(f"Combined data saved to {output_file}")

FileNotFoundError: [Errno 2] No such file or directory: 'patent_cpc_data'

In [None]:
#Json with the patents and cpc codes
json_path = "../data/2018/design2018_cpc.json"

# Initialize a list to store data from all files
all_data = []
patents=set()

# Read the file and extract Patent ID and CPC
with open(json_path, "r") as file:
    data = json.load(file)
    for image in data:
        for cpc in image["cpc"]:
            # Extract Patent ID 
            figure_id = image["subfigure_file"]
            patent_id = image["patentID"]
            # Extract Main CPC 
            medium_cpc = cpc

            # Append to the data list
            all_data.append(
                {
                    "Figure ID": figure_id,
                    "Patent ID": patent_id,
                    "Main CPC": medium_cpc[0],
                    "Big CPC": medium_cpc[:3],
                    "Medium CPC": medium_cpc,
                }
            )
            

            
# Convert the list of dictionaries to a dataframe
df = pd.DataFrame(all_data)
print(df)

# Dataframe to a CSV file 
output_file = "data/2018/combined_design_patent_cpc_data.csv"
df.to_csv(output_file, index=False)
print(f"Combined data saved to {output_file}")

                               Figure ID            Patent ID Main CPC  \
0       USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
1       USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
2       USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
3       USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
4       USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
...                                  ...                  ...      ...   
727397  USD0830379-20181009-D00002_3.png  USD0830379-20181009        G   
727398  USD0830379-20181009-D00002_3.png  USD0830379-20181009        H   
727399  USD0830379-20181009-D00002_3.png  USD0830379-20181009        G   
727400  USD0834328-20181127-D00006_6.png  USD0834328-20181127        A   
727401  USD0834328-20181127-D00006_6.png  USD0834328-20181127        A   

       Big CPC Medium CPC  
0          A23       A23L  
1          A47       A47G  
2          A23       A23J  

In [23]:
import xml.etree.ElementTree as ET

# Path to the .xsd file
xsd_file = "./venv/CPCDefinitionsSchema10.xsd"

# Parse the .xsd file
try:
    tree = ET.parse(xsd_file)
    root = tree.getroot()

    namespace = {"xs": "http://www.w3.org/2001/XMLSchema"}

    # Find all elements with name 'definitions'
    definitions = root.findall('.//xs:element[@name="definition-item"]', namespace)

    # Extract details of each 'definitions' element
    for definition in definitions:
        print(f"Definition Element: {definition.attrib}")

    def print_tree(element, level=0):
        indent = "  " * level
        print(f"{indent}<{element.tag} {dict(element.attrib)}>")
        for child in element:
            print_tree(child, level + 1)
        print(f"{indent}</{element.tag}>")

    # Print the root and its children
    print_tree(root)
except FileNotFoundError:
    print(f"The file {xsd_file} was not found.")
except ET.ParseError as e:
    print(f"Error parsing the XSD file: {e}")

Definition Element: {'name': 'definition-item', 'type': 'definition-item-type'}
<{http://www.w3.org/2001/XMLSchema}schema {}>
  <{http://www.w3.org/2001/XMLSchema}annotation {}>
    <{http://www.w3.org/2001/XMLSchema}documentation {}>
    </{http://www.w3.org/2001/XMLSchema}documentation>
  </{http://www.w3.org/2001/XMLSchema}annotation>
  <{http://www.w3.org/2001/XMLSchema}element {'name': 'definitions', 'type': 'definitions-type'}>
    <{http://www.w3.org/2001/XMLSchema}annotation {}>
      <{http://www.w3.org/2001/XMLSchema}documentation {}>
      </{http://www.w3.org/2001/XMLSchema}documentation>
    </{http://www.w3.org/2001/XMLSchema}annotation>
  </{http://www.w3.org/2001/XMLSchema}element>
  <{http://www.w3.org/2001/XMLSchema}element {'name': 'abbreviations', 'type': 'section-body-type'}>
    <{http://www.w3.org/2001/XMLSchema}annotation {}>
      <{http://www.w3.org/2001/XMLSchema}documentation {}>
      </{http://www.w3.org/2001/XMLSchema}documentation>
    </{http://www.w3.o

### IPC dataset

In [36]:
import os
import xml.etree.ElementTree as ET


def extract_cpc_and_title_from_xml(file_path):
    """
    Extract CPC and Title from a given XML file.

    Args:
        file_path (str): Path to the XML file.

    Returns:
        list: A list of tuples containing CPC and Title pairs.
    """
    try:
        tree = ET.parse(file_path)
        root = tree.getroot()

        results = []
        # for item in root.findall(".//definition-item"):
        cpc = root.find(".//classification-symbol")
        title = root.find(".//definition-title")
        if cpc is not None and title is not None:
            results.append((cpc.text.strip(), title.text.strip()))
        return results
    except ET.ParseError:
        print(f"Error parsing {file_path}. Skipping...")
        return []


def process_directory(directory_path):
    """
    Process all XML files in a directory and extract CPC and Title.

    Args:
        directory_path (str): Path to the directory containing XML files.

    Returns:
        dict: A dictionary with CPC codes as keys and Titles as values.
    """
    combined_results = {}

    for file_name in os.listdir(directory_path):
        if file_name.endswith(".xml"):
            file_path = os.path.join(directory_path, file_name)
            file_results = extract_cpc_and_title_from_xml(file_path)
            for cpc, title in file_results:
                # Combine results, overwrite if duplicate CPC is found
                combined_results[cpc] = title

    return combined_results


# Directory containing XML files
directory_path = "venv/FullCPCDefinitionXML202501"

combined_results = process_directory(directory_path)

# Print the combined results
for cpc, title in combined_results.items():
    print(f"CPC: {cpc}, Title: {title}")

import csv


def save_dict_to_csv(data, output_file):
    """
    Save a dictionary to a CSV file.

    Args:
        data (dict): The dictionary to save (keys as rows, values as columns).
        output_file (str): Path to the output CSV file.
    """
    try:
        with open(output_file, mode="w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            # Write header
            writer.writerow(["CPC", "Title"])
            # Write data
            for cpc, title in data.items():
                writer.writerow([cpc, title])
        print(f"Dictionary saved to {output_file}")
    except Exception as e:
        print(f"Error saving dictionary to CSV: {e}")
        
# Save the combined results to a CSV file
output_csv_path = "cpc_definitions.csv"
save_dict_to_csv(combined_results, output_csv_path)

CPC: G06E, Title: OPTICAL COMPUTING DEVICES; {COMPUTING DEVICES USING OTHER RADIATIONS WITH SIMILAR PROPERTIES} (optical logic elements per se
CPC: A41D, Title: OUTERWEAR; PROTECTIVE GARMENTS; ACCESSORIES
CPC: B64G, Title: COSMONAUTICS; VEHICLES OR EQUIPMENT THEREFOR
CPC: B65C, Title: LABELLING OR TAGGING MACHINES, APPARATUS, OR PROCESSES  (nailing or stapling in general
CPC: A61F, Title: FILTERS IMPLANTABLE INTO BLOOD VESSELS; PROSTHESES; DEVICES PROVIDING PATENCY TO, OR PREVENTING COLLAPSING OF, TUBULAR STRUCTURES OF THE BODY, e.g. STENTS; ORTHOPAEDIC, NURSING OR CONTRACEPTIVE DEVICES; FOMENTATION; TREATMENT OR PROTECTION OF EYES OR EARS; BANDAGES, DRESSINGS OR ABSORBENT PADS; FIRST-AID KITS  (dental prosthetics
CPC: F16M, Title: FRAMES, CASINGS OR BEDS OF ENGINES, MACHINES OR APPARATUS, NOT SPECIFIC TO ENGINES, MACHINES OR APPARATUS PROVIDED FOR ELSEWHERE; STANDS; SUPPORTS
CPC: A61Q, Title: SPECIFIC USE OF COSMETICS OR SIMILAR TOILETRY PREPARATIONS
CPC: D03D, Title: WOVEN FABRICS; M

In [19]:
import pandas as pd
from transformers import CLIPModel, CLIPProcessor
import torch

# Load the CLIP model and processor from Hugging Face
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# Assuming `medium_ipc` is a DataFrame with CPC codes and titles
results = []
medium_ipc = pd.read_csv("../data/2018/graph/cpc_definitions.csv")

# Iterate through the DataFrame and encode titles using CLIP
for i, row in medium_ipc.iterrows():
    cpc_code = row[0]  
    title = row[1]     
    
    # Truncate the title if it's too long
    title = title[:200]  # Arbitrary cutoff, adjust as needed
    
    try:
        # Preprocess and tokenize the title
        inputs = processor(text=[title], return_tensors="pt", padding=True, truncation=True, max_length=77).to(device)
        
        # Encode the text using the CLIP model
        with torch.no_grad():
            text_features = model.get_text_features(**inputs)  # Extract text embeddings
            text_features = text_features.cpu().numpy()[0]  # Convert to numpy array and get the first item
        
        # Append the CPC code and encoded features to the results
        results.append([cpc_code, text_features])
        
    except Exception as e:
        print(f"Error processing {cpc_code}: {e}")
        continue

# Create a DataFrame with the results
big_features_df = pd.DataFrame(results, columns=["CPC", "Title_Embedding"])

# Save the DataFrame to a pickle file
output_file = "../data/2018/graph/medium_ipc_embeddings.pkl"
big_features_df.to_pickle(output_file)

  cpc_code = row[0]  # Assuming the first column is CPC
  title = row[1]     # Assuming the second column is the title


### Images embeddings

In [18]:
import os
import torch
import pandas as pd
from tqdm import tqdm
import torchvision.io as tvio
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPModel, CLIPProcessor
import numpy as np
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, image_paths, processor):
        self.image_paths = image_paths
        self.processor = processor
        # Define transforms for tensor inputs
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            # Load image using torchvision.io
            image = tvio.read_image(path)
            
            # Convert to float and normalize to [0, 1]
            image = image.float() / 255.0
            
            # Handle grayscale images (1 channel)
            if image.shape[0] == 1:
                image = image.repeat(3, 1, 1)
            # Handle RGBA images (4 channels)
            elif image.shape[0] == 4:
                image = image[:3]
                
            # Apply transforms
            image = self.transform(image)
            
            return image, os.path.basename(path)
        except Exception as err:
            print(f"Error processing {path}: {err}")
            # Return a placeholder in case of error
            return None, os.path.basename(path)

def get_image_paths(directory, list_images, valid_exts=(".jpg", ".jpeg", ".png", ".bmp", ".tiff")):
    """Return a sorted list of full image file paths for the given directory."""
    return sorted([
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.lower().endswith(valid_exts) and fname in list_images
    ])

def compute_embeddings(dataloader, model, device):
    """Process the image files using DataLoader and compute image embeddings."""
    image_names = []
    embeddings = []

    with torch.no_grad():
        for batch_images, batch_names in tqdm(dataloader, desc="Processing batches"):
            # Filter out None values (failed images)
            valid_indices = [i for i, img in enumerate(batch_images) if img is not None]
            if not valid_indices:
                continue
                
            valid_images = torch.stack([batch_images[i] for i in valid_indices]).to(device)
            valid_names = [batch_names[i] for i in valid_indices]
            
            # Get image features
            batch_features = model.get_image_features(pixel_values=valid_images)
            
            # Normalize the features
            batch_features = batch_features / batch_features.norm(dim=1, keepdim=True)
            
            # Convert tensor features to numpy arrays
            batch_embeddings = batch_features.cpu().numpy()

            embeddings.extend(batch_embeddings)
            image_names.extend(valid_names)

    return image_names, embeddings

def collate_fn(batch):
    """Custom collate function to handle None values."""
    images = []
    names = []
    for image, name in batch:
        images.append(image)
        names.append(name)
    return images, names

def main():
    # Set the directory containing your images
    image_dir = "../data/2018/test_query"  # <-- update this path
    
    # Set output pickle file path
    output_file = "../data/2018/graph/query_images_embeddings_model_9_v3.pkl"

    # Set the device. Use GPU if available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the CLIP model and processor from Hugging Face
    model = CLIPModel.from_pretrained(f'models/patent-wise/fine_tuned_clip_model_9_v2').to(device)
    processor = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711]
            )
        ])
    model.eval()

    # Get all image paths from the directory
    print(f"Scanning directory: {image_dir}")
    image_paths = get_image_paths(image_dir, images_graph)
    if not image_paths:
        print(f"No images found in directory: {image_dir}")
        return

    # Create dataset and dataloader
    dataset = ImageDataset(image_paths, processor)
    dataloader = DataLoader(
        dataset,
        batch_size=128,
        shuffle=False,
        num_workers=32,
        pin_memory=True,
        collate_fn=collate_fn,
        prefetch_factor=16,
        persistent_workers=True
    )

    print(f"Processing {len(image_paths)} images with DataLoader")

    # Compute embeddings for all images
    names, embeddings = compute_embeddings(dataloader, model, device)

    # Create a DataFrame to store results
    df = pd.DataFrame({
        "image_name": names,
        "embedding": list(embeddings)
    })

    # Save the DataFrame to a pickle file
    df.to_pickle(output_file)
    print(f"Embeddings saved for {len(df)} images in '{output_file}'")
    
    # Print sample of the dataframe
    print("\nDataFrame sample:")
    print(df.head())
    print(f"\nEmbedding shape: {df['embedding'].iloc[0].shape}")

if __name__ == "__main__":
    main()

Using device: cuda
Scanning directory: ../data/2018/test_query
Processing 27101 images with DataLoader


Processing batches: 100%|██████████| 212/212 [01:03<00:00,  3.33it/s]


Embeddings saved for 27101 images in '../data/2018/graph/query_images_embeddings_model_9_v3.pkl'

DataFrame sample:
                          image_name  \
0   USD0806351-20180102-D00002_5.png   
1  USD0806351-20180102-D00004_12.png   
2   USD0806352-20180102-D00002_2.png   
3   USD0806352-20180102-D00007_7.png   
4   USD0806353-20180102-D00006_6.png   

                                           embedding  
0  [0.01102656, -0.036491994, -0.009097237, -0.00...  
1  [0.01062351, -0.03684217, -0.008084138, -0.006...  
2  [0.009231105, -0.032987878, -0.009340418, -0.0...  
3  [0.010716212, -0.036647614, -0.011199252, -0.0...  
4  [0.010633418, -0.033302743, -0.009265652, -0.0...  

Embedding shape: (512,)


### IPC Big

In [2]:
import os
import re
import pandas as pd

# Directory containing .txt files
directory = "../data/IPC definitions"

# Regex pattern for 3-digit codes
pattern = r"^([A-Z]\d{2})\t\t(.+)"

# List to store extracted data
data = []

# Iterate through all files in the directory
for filename in os.listdir(directory):
    if filename.endswith(".txt"):
        filepath = os.path.join(directory, filename)

        # Open and read the file
        with open(filepath, "r", encoding="utf-8") as file:
            for line in file:
                # Match the pattern and extract the code and description
                match = re.match(pattern, line)
                if match:
                    code, description = match.groups()
                    data.append({"Code": code, "Description": description})

# Create a DataFrame from the data
big_ipc = pd.DataFrame(data)

# Display the DataFrame
print(big_ipc)

    Code                                        Description
0    E01        CONSTRUCTION OF ROADS, RAILWAYS, OR BRIDGES
1    E02  HYDRAULIC ENGINEERING; FOUNDATIONS; SOIL SHIFTING
2    E03                             WATER SUPPLY; SEWERAGE
3    E04                                           BUILDING
4    E05        LOCKS; KEYS; WINDOW OR DOOR FITTINGS; SAFES
..   ...                                                ...
132  G11                                INFORMATION STORAGE
133  G12                                 INSTRUMENT DETAILS
134  G16  INFORMATION AND COMMUNICATION TECHNOLOGY [ICT]...
135  G21               NUCLEAR PHYSICS; NUCLEAR ENGINEERING
136  G99  SUBJECT MATTER NOT OTHERWISE PROVIDED FOR IN T...

[137 rows x 2 columns]


In [30]:
import pandas as pd
from transformers import CLIPModel, CLIPProcessor
import torch

# Load the CLIP model and processor from Hugging Face
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# Assuming `big_ipc` is a DataFrame with CPC codes and titles
results = []

# Iterate through the DataFrame and encode titles using CLIP
for i, row in big_ipc.iterrows():
    cpc_code = row[0]  # Assuming the first column is CPC
    title = row[1]     # Assuming the second column is the title

    # Preprocess and tokenize the title
    inputs = processor(text=[title], return_tensors="pt", padding=True).to(device)

    # Encode the text using the CLIP model
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)  # Extract text embeddings
        text_features = text_features.cpu().numpy()  # Convert to numpy array if needed

    # Append the CPC code and encoded features to the results
    results.append([cpc_code, text_features])

# Create a DataFrame with the results
big_features_df = pd.DataFrame(results, columns=["CPC", "Title_Embedding"])

# Save the DataFrame to a pickle file
output_file = "big_ipc_embeddings.pkl"
big_features_df.to_pickle(output_file)

  cpc_code = row[0]  # Assuming the first column is CPC
  title = row[1]     # Assuming the second column is the title


### IPC MAIN

In [41]:
data = [
    ["A", "HUMAN NECESSITIES"],
    ["B", "PERFORMING OPERATIONS; TRANSPORTING"],
    ["C", "CHEMISTRY; METALLURGY"],
    ["D", "TEXTILES; PAPER"],
    ["E", "FIXED CONSTRUCTIONS"],
    ["F", "MECHANICAL ENGINEERING; LIGHTING; HEATING; WEAPONS; BLASTING"],
    ["G", "PHYSICS"],
    ["H", "ELECTRICITY"],
    ["Y",
        "GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS",
    ],
]
main_ipc = pd.DataFrame(data, columns=["CPC", "Title"])

In [43]:
results = []
# Iterate through the DataFrame and encode titles using CLIP
for i, row in main_ipc.iterrows():
    cpc_code = row[0]  # Assuming the first column is CPC
    title = row[1]     # Assuming the second column is the title

    # Preprocess and tokenize the title
    inputs = processor(text=[title], return_tensors="pt", padding=True).to(device)

    # Encode the text using the CLIP model
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)  # Extract text embeddings
        text_features = text_features.cpu().numpy()  # Convert to numpy array if needed

    # Append the CPC code and encoded features to the results
    results.append([cpc_code, text_features])

main_features_df = pd.DataFrame(results, columns=["CPC", "Title"])
output_file = "main_ipc_embeddings.pkl"
main_features_df.to_pickle(output_file)

  cpc_code = row[0]  # Assuming the first column is CPC
  title = row[1]     # Assuming the second column is the title


Unnamed: 0,CPC,Title
0,A,"[[0.09007428, 0.11744866, -0.17544757, 0.12049..."
1,B,"[[-0.32633457, 0.06152109, -0.2303406, -0.3052..."
2,C,"[[-0.3054312, 0.13740289, -0.22999647, 0.04931..."
3,D,"[[0.12014803, 0.17657274, 0.14181863, -0.27732..."
4,E,"[[0.12759846, -0.4255452, 0.0639952, 0.0106926..."
5,F,"[[-0.34825346, 0.36170906, -0.35306203, 0.0689..."
6,G,"[[-0.27509984, 0.14176705, -0.002594471, -0.18..."
7,H,"[[-0.13190566, -0.22574532, 0.16931647, 0.1632..."
8,Y,"[[0.28983444, -0.47066772, -0.07885873, -0.111..."


### Patent title

In [None]:
#Json with the patents and cpc codes
json_path = "../data/2018/design2018_cpc.json"

# Initialize a list to store data from all files
all_data = []
patents=set()

# Read the file and extract Patent ID and CPC
with open(json_path, "r") as file:
    data = json.load(file)
    for image in data:
        if image["patentID"] not in patents:
            
            patent_id = image["patentID"]
                # Extract Main CPC 
            title = image["object_title"]

                # Append to the data list
            all_data.append(
                {
                "Patent ID": patent_id,
                "Title": title,
                }
                )
            patents.add(image["patentID"])

            
# Convert the list of dictionaries to a dataframe
df_pat = pd.DataFrame(all_data)
print(df_pat)

# Dataframe to a CSV file 
output_file = "patent_title_data.csv"
df_pat.to_csv(output_file, index=False)
print(f"Combined data saved to {output_file}")

In [50]:
import pandas as pd
from transformers import CLIPModel, CLIPProcessor
import torch

# Load the CLIP model and processor from Hugging Face
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

# Assuming `big_ipc` is a DataFrame with CPC codes and titles
results = []

# Iterate through the DataFrame and encode titles using CLIP
for i, row in df_pat.iterrows():
    patent_id = row[0]  # Assuming the first column is CPC
    title = row[1]     # Assuming the second column is the title

    # Preprocess and tokenize the title
    inputs = processor(text=[title], return_tensors="pt", padding=True).to(device)

    # Encode the text using the CLIP model
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)  # Extract text embeddings
        text_features = text_features.cpu().numpy()  # Convert to numpy array if needed

    # Append the CPC code and encoded features to the results
    results.append([patent_id, text_features])

# Create a DataFrame with the results
patent_features_df = pd.DataFrame(results, columns=["Patent ID", "Title_Embedding"])

# Save the DataFrame to a pickle file
output_file = "patent_embeddings.pkl"
patent_features_df.to_pickle(output_file)

  patent_id = row[0]  # Assuming the first column is CPC
  title = row[1]     # Assuming the second column is the title


In [3]:
csv_file_path = "filtered_pgpub_claims_2012_2014.csv"  # Replace with your CSV file path
# Replace with your desired Parquet file path

# Read the CSV file
df_claims = pd.read_csv(csv_file_path)

In [3]:
df_claims

Unnamed: 0,pub_no,appl_id,claim_no,claim_txt,dependencies,ind_flg
0,20120003120,13159492,1,A method for recovering heat in a device for t...,,1
1,20120003120,13159492,2,"A method according to claim 1, wherein the pre...",1,0
2,20120003120,13159492,3,A device for the sterilization of biological m...,,1
3,20120003150,13231050,15,"The method of claim 1, further comprising admi...",1,0
4,20120003150,13231050,17,"The method of claim 1, wherein said ribonuclea...",1,0
...,...,...,...,...,...,...
19388245,20140380539,13986990,1,1. A new and distinct variety of Phalaenopsis ...,,1
19388246,20140380540,13986991,1,1. A new and distinct variety of Phalaenopsis ...,,1
19388247,20140380541,13986997,1,1. A new and distinct cultivar of Agapanthus p...,,1
19388248,20140380542,13987008,1,1. A new and distinct cultivar of Campanula pl...,,1


In [4]:
df_claims = df_claims.drop(columns=["appl_id", "claim_no", "dependencies", "ind_flg"])

In [5]:
df_claims["claim_txt"] = df_claims["claim_txt"].fillna("").astype(str)

In [6]:
import numpy as np

In [19]:
# Split dataset into 1000 chunks
chunks = np.array_split(df_claims, 1000)


# Function to aggregate claims for a chunk
def process_chunk(chunk):
    return chunk.groupby("pub_no", as_index=False).agg({"claim_txt": " ".join})


# Process each chunk and collect results
results = [process_chunk(chunk) for chunk in chunks[:100]]
final_result = pd.concat(results).reset_index(drop=True)

print(final_result)

  return bound(*args, **kwds)


            pub_no                                          claim_txt
0      20120003120  A method for recovering heat in a device for t...
1      20120003150  The method of claim 1, further comprising admi...
2      20120003151  (Canceled). (Canceled). (Canceled). The method...
3      20120003153  (Canceled). The indolyl-oxadiazolyl-diazabicyc...
4      20120003156  A method for treating neoplasia in a subject, ...
...            ...                                                ...
89296  20120069587  1. A light clip comprising: a base; a magnet a...
89297  20120069588  1. A light interfacing board, comprising: a st...
89298  20120069589  1. An LED landing light arrangement for an air...
89299  20120069590  1. A direction indicator comprising: a directi...
89300  20120069591  1. An exterior mirror vision system for a vehi...

[89301 rows x 2 columns]


In [20]:
chunks[1]

Unnamed: 0,pub_no,claim_txt
19389,20120015887,A pharmaceutical composition for enhancement o...
19390,20120015887,"The 4-copy branched peptide of claim 1, wherei..."
19391,20120015887,"The 4-copy branched peptide of claim 1, wherei..."
19392,20120015887,A 4-copy branched peptide represented by a for...
19393,20120015889,The method of claim 18 wherein said primary tu...
...,...,...
38773,20120034231,The antibody or binding protein according to c...
38774,20120034231,A process for the production of the antibody o...
38775,20120034231,"The antibody of claim 4, wherein the mutated o..."
38776,20120034231,(Canceled).


In [8]:
print(final_result.head(1)["claim_txt"].values[0])

A method for recovering heat in a device for the sterilization of biological material, comprising transferring heat from a sterilized effluent stream to a stream in a heat recovery circuit transferring heat from the stream in the heat recovery circuit to a stream of biologically contaminated feed while maintaining the pressure (p12) in the sterilized effluent stream higher than the pressure (p10) in the heat recovery circuit, which is maintained higher than the pressure (p11) in the stream of biologically contaminated feed. A method according to claim 1, wherein the pressure p12 in the sterilized effluent stream is over 6 bar, the pressure p10 in the heat recovery circuit is at minimum 1 bar and at maximum 3 bar, and the pressure p11 in the stream of biologically contaminated feed is 0.5 bar or less. A device for the sterilization of biological material, comprising: a feed line for contaminated material; a unit for heat treatment of said material; an effluent line for sterilized materi

In [21]:
final = pd.DataFrame()
for i in range(10):
    results_ = [process_chunk(chunk) for chunk in chunks[i * 100 : 100 * i + 100]]
    final_result = pd.concat(results)
    final = pd.concat([final_result, final]).reset_index(drop=True)
print(final)

             pub_no                                          claim_txt
0       20120003120  A method for recovering heat in a device for t...
1       20120003150  The method of claim 1, further comprising admi...
2       20120003151  (Canceled). (Canceled). (Canceled). The method...
3       20120003153  (Canceled). The indolyl-oxadiazolyl-diazabicyc...
4       20120003156  A method for treating neoplasia in a subject, ...
...             ...                                                ...
893005  20120069587  1. A light clip comprising: a base; a magnet a...
893006  20120069588  1. A light interfacing board, comprising: a st...
893007  20120069589  1. An LED landing light arrangement for an air...
893008  20120069590  1. A direction indicator comprising: a directi...
893009  20120069591  1. An exterior mirror vision system for a vehi...

[893010 rows x 2 columns]


In [14]:
final3 = final.drop_duplicates(subset="pub_no", keep="first").reset_index(drop=True)
print(final3)

            pub_no                                          claim_txt
0      20120003120  A method for recovering heat in a device for t...
1      20120003150  The method of claim 1, further comprising admi...
2      20120003151  (Canceled). (Canceled). (Canceled). The method...
3      20120003153  (Canceled). The indolyl-oxadiazolyl-diazabicyc...
4      20120003156  A method for treating neoplasia in a subject, ...
...            ...                                                ...
89199  20120069587  1. A light clip comprising: a base; a magnet a...
89200  20120069588  1. A light interfacing board, comprising: a st...
89201  20120069589  1. An LED landing light arrangement for an air...
89202  20120069590  1. A direction indicator comprising: a directi...
89203  20120069591  1. An exterior mirror vision system for a vehi...

[89204 rows x 2 columns]


In [3]:
final3["pub_no"].nunique()

89204

## Graph creation


In [7]:
import pandas as pd
import numpy as np

### Adjancy Matrix

In [8]:
df= pd.read_csv("../data/2018/graph/combined_design_patent_cpc_data.csv")

In [9]:
#df= pd.read_csv("../data/2019/combined_design_patent_cpc_data_2019.csv")

In [10]:
df['Date'] = df['Patent ID'].str[-8:]  

# Convert the extracted date to a datetime object
df['Date'] = pd.to_datetime(df['Date'], format='%Y%m%d')

# Filter for patents from January to April (inclusive)
filtered_df = df[(df['Date'].dt.month >= 1) & (df['Date'].dt.month <= 6)]

# Drop the temporary 'Date' column if not needed
df = filtered_df.drop(columns=['Date'])

# Display the filtered DataFrame
print(df)

                                Figure ID            Patent ID Main CPC  \
0        USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
1        USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
2        USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
3        USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
4        USD0806351-20180102-D00001_1.png  USD0806351-20180102        A   
...                                   ...                  ...      ...   
726651  USD0821270-20180626-D00009_10.png  USD0821270-20180626        A   
726652  USD0821270-20180626-D00009_10.png  USD0821270-20180626        F   
726653  USD0821270-20180626-D00009_10.png  USD0821270-20180626        A   
726654  USD0821270-20180626-D00009_10.png  USD0821270-20180626        A   
726655  USD0821270-20180626-D00009_10.png  USD0821270-20180626        A   

       Big CPC Medium CPC  
0          A23       A23L  
1          A47       A47G  
2          A23 

In [11]:
import os
import pandas as pd


def filter_df(df, folder_path):
    # Get a list of all image files in the folder
    try:
        image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]
    except FileNotFoundError:
        print("The folder does not exist")
        return None


    # Filter the DataFrame to only include rows where the image is present in the folder
    filtered_df = df[df['Figure ID'].isin(image_files)]
    return filtered_df

folder_path = '../data/2018/test_query'
filtered_df = filter_df(df, folder_path)
if filtered_df is not None:
    print(filtered_df)

                               Figure ID            Patent ID Main CPC  \
5       USD0806366-20180102-D00002_3.png  USD0806366-20180102        A   
6       USD0806366-20180102-D00002_3.png  USD0806366-20180102        A   
7       USD0806366-20180102-D00002_3.png  USD0806366-20180102        A   
20      USD0806419-20180102-D00001_2.png  USD0806419-20180102        A   
21      USD0806532-20180102-D00004_4.png  USD0806532-20180102        A   
...                                  ...                  ...      ...   
726628  USD0819951-20180612-D00002_2.png  USD0819951-20180612        A   
726629  USD0819951-20180612-D00002_2.png  USD0819951-20180612        A   
726630  USD0819951-20180612-D00002_2.png  USD0819951-20180612        A   
726631  USD0819951-20180612-D00002_2.png  USD0819951-20180612        A   
726632  USD0819951-20180612-D00002_2.png  USD0819951-20180612        A   

       Big CPC Medium CPC  
5          A42       A42C  
6          A42       A42B  
7          A61       A61F  

In [12]:
images_graph= filtered_df["Figure ID"].unique()

In [13]:
len(images_graph)

27101

In [19]:
from scipy.sparse import coo_matrix

# Step 1: Get unique nodes for each category
figure_nodes = filtered_df["Figure ID"].unique()
patent_nodes = filtered_df["Patent ID"].unique()
medium_ipc_nodes = filtered_df["Medium CPC"].unique()
big_ipc_nodes = filtered_df["Big CPC"].unique()
main_ipc_nodes = filtered_df["Main CPC"].unique()

# Step 2: Create index mappings
figure_idx = {node: i for i, node in enumerate(figure_nodes)}
patent_idx = {node: i for i, node in enumerate(patent_nodes)}
medium_ipc_idx = {node: i for i, node in enumerate(medium_ipc_nodes)}
big_ipc_idx = {node: i for i, node in enumerate(big_ipc_nodes)}
main_ipc_idx = {node: i for i, node in enumerate(main_ipc_nodes)}

In [20]:
len(figure_idx)

27101

In [21]:

import pickle# Save the dictionary
with open('image_index_2018.pkl', 'wb') as f:
    pickle.dump(figure_idx, f)

In [22]:
prefix = 'USD0806717-20180102'
filtered_dict = {key: value for key, value in figure_idx.items() if key.startswith(prefix)}

print(filtered_dict.values())

dict_values([416, 688])


In [23]:
print("Graph Statistics:")
print("Patent Figures : ",len(figure_idx))
print("Patent IDs : ",len(patent_idx))
print("Medium CPC codes : ",len(medium_ipc_idx))
print("Big CPC codes : ",len(big_ipc_idx))
print("Main CPC codes : ",len(main_ipc_idx))

Graph Statistics:
Patent Figures :  27101
Patent IDs :  13552
Medium CPC codes :  578
Big CPC codes :  126
Main CPC codes :  9


In [24]:
# Figure - Patent 
figure_patent_edges = [
    (figure_idx[row["Figure ID"]], patent_idx[row["Patent ID"]])
    for _, row in filtered_df.iterrows()
]
# Add reverse edges for undirected graph
rows, cols = zip(*set(figure_patent_edges))
figure_patent_adj = coo_matrix(
    (np.ones(len(rows)), (rows, cols)), shape=(len(figure_nodes), len(patent_nodes))
)

# Patent - Medium IPC
patent_medium_edges = [
    (patent_idx[row["Patent ID"]], medium_ipc_idx[row["Medium CPC"]])
    for _, row in filtered_df.iterrows()
]
rows, cols = zip(*set(patent_medium_edges))
patent_medium_adj = coo_matrix(
    (np.ones(len(rows)), (rows, cols)), shape=(len(patent_nodes), len(medium_ipc_nodes))
)

# Medium IPC - Big IPC
medium_big_edges = [
    (medium_ipc_idx[row["Medium CPC"]], big_ipc_idx[row["Big CPC"]])
    for _, row in filtered_df.iterrows()
]
rows, cols = zip(*set(medium_big_edges))
medium_big_adj = coo_matrix(
    (np.ones(len(rows)), (rows, cols)),
    shape=(len(medium_ipc_nodes), len(big_ipc_nodes)),
)

# Big IPC - Main IPC
big_main_edges = [
    (big_ipc_idx[row["Big CPC"]], main_ipc_idx[row["Main CPC"]])
    for _, row in filtered_df.iterrows()
]
rows, cols = zip(*set(big_main_edges))
big_main_adj = coo_matrix(
    (np.ones(len(rows)), (rows, cols)), shape=(len(big_ipc_nodes), len(main_ipc_nodes))
)

# Print adjacency matrices as
print("Figure - Patent Adjacency Matrix:\n", figure_patent_adj.toarray())
print("Patent - Medium IPC Adjacency Matrix:\n", patent_medium_adj.toarray())
print("Medium IPC - Big IPC Adjacency Matrix:\n", medium_big_adj.toarray())
print("Big IPC - Main IPC Adjacency Matrix:\n", big_main_adj.toarray())

Figure - Patent Adjacency Matrix:
 [[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Patent - Medium IPC Adjacency Matrix:
 [[1. 1. 1. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Medium IPC - Big IPC Adjacency Matrix:
 [[1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
Big IPC - Main IPC Adjacency Matrix:
 [[1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 ...
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 1.]]


In [25]:
print(f"figure_patent_adj shape: {figure_patent_adj.shape}")
print(f"patent_medium_adj shape: {patent_medium_adj.shape}")
print(f"medium_big_adj shape: {medium_big_adj.shape}")
print(f"big_main_adj shape: {big_main_adj.shape}")

figure_patent_adj shape: (27101, 13552)
patent_medium_adj shape: (13552, 578)
medium_big_adj shape: (578, 126)
big_main_adj shape: (126, 9)


In [26]:
from scipy.sparse import coo_matrix, hstack, vstack

# Define zero matrices where connections don't exist
patent_zero_figure = coo_matrix(
    (len(figure_nodes), len(patent_nodes))
)  # (1057711, 668)
zero_patent_medium = coo_matrix(
    (len(patent_nodes), len(medium_ipc_nodes))
)  # (1057711, 668)
zero_medium_big = coo_matrix((len(medium_ipc_nodes), len(big_ipc_nodes)))  # (668, 32)
zero_big_main = coo_matrix((len(big_ipc_nodes), len(main_ipc_nodes)))  # (32, 9)

figure_self_connections = coo_matrix(
    (
        np.ones(len(figure_nodes)),
        (range(len(figure_nodes)), range(len(figure_nodes))),
    ),  # Diagonal elements
    shape=(len(figure_nodes), len(figure_nodes)),
)
patent_self_connections = coo_matrix(
    (
        np.ones(len(patent_nodes)),
        (range(len(patent_nodes)), range(len(patent_nodes))),
    ),  # Diagonal elements
    shape=(len(patent_nodes), len(patent_nodes)),
)
medium_self_connections = coo_matrix(
    (
        np.ones(len(medium_ipc_nodes)),
        (range(len(medium_ipc_nodes)), range(len(medium_ipc_nodes))),
    ),  # Diagonal elemmedium_ipc  shape=(len(medium_ipc_nodes), len(medium_ipc_nodes))
)
big_self_connections = coo_matrix(
    (
        np.ones(len(big_ipc_nodes)),
        (range(len(big_ipc_nodes)), range(len(big_ipc_nodes))),
    ),  # Diagonal elements
    shape=(len(big_ipc_nodes), len(big_ipc_nodes)),
)
main_self_connections = coo_matrix(
    (
        np.ones(len(main_ipc_nodes)),
        (range(len(main_ipc_nodes)), range(len(main_ipc_nodes))),
    ),  # Diagonal elements
    shape=(len(main_ipc_nodes), len(main_ipc_nodes)),
)
# Top row: patent-medium connections, and no direct connections to big or main IPCs
upper = hstack(
    [
        figure_self_connections,
        figure_patent_adj,  # Patent-medium connections
        coo_matrix(
            (len(figure_nodes), len(medium_ipc_nodes))
        ),
        coo_matrix(
            (len(figure_nodes), len(big_ipc_nodes))
        ),  # No direct patent-big connections
        coo_matrix(
            (len(figure_nodes), len(main_ipc_nodes))
        ),  # No direct patent-main connections
    ]
)
print(upper.shape)

# Second row: medium-patent transpose, medium-big connections, and no direct medium-main connections
middle1 = hstack(
    [   figure_patent_adj.T,
        patent_self_connections,
        patent_medium_adj,  # Patent-medium connections
        coo_matrix(
            (len(patent_nodes), len(big_ipc_nodes))
        ),  # No direct patent-big connections
        coo_matrix(
            (len(patent_nodes), len(main_ipc_nodes))
        ),  # No direct patent-main connections
    ]
)
print(middle1.shape)
# Third row: big-medium transpose, big-main connections, and no direct big-patent connections
middle2 = hstack(
    [   coo_matrix(
            (len(medium_ipc_nodes), len(figure_nodes))
        ),
        patent_medium_adj.T,  # Medium-patent connections
        medium_self_connections,  # No self-connections for medium IPCs
        medium_big_adj,  # Medium-big connections
        coo_matrix(
            (len(medium_ipc_nodes), len(main_ipc_nodes))
        ),  # No direct medium-main connections
    ]
)
print(middle2.shape)
# Bottom row: main-big transpose, and no direct connections to patents or medium IPCs
middle3 = hstack(
    [   coo_matrix(
            (len(big_ipc_nodes), len(figure_nodes))
        ),
        coo_matrix(
            (len(big_ipc_nodes), len(patent_nodes))
        ),  # No direct big-patent connections
        medium_big_adj.T,  # Big-medium connections
        big_self_connections,  # No self-connections for big IPCs
        big_main_adj,  # Big-main connections
    ]
)
print(middle3.shape)

lower = hstack(
    [   coo_matrix(
            (len(main_ipc_nodes), len(figure_nodes))
        ),
        coo_matrix(
            (len(main_ipc_nodes), len(patent_nodes))
        ),  # No direct main-patent connections
        coo_matrix(
            (len(main_ipc_nodes), len(medium_ipc_nodes))
        ),  # No direct main-medium connections
        big_main_adj.T,  # Main-big connections
        main_self_connections,  # No self-connections for main IPCs
    ]
)
print(lower.shape)
# Stack all rows to create the full adjacency matrix
combined_adj = vstack([upper, middle1, middle2,middle3, lower])

(27101, 41366)
(13552, 41366)
(578, 41366)
(126, 41366)
(9, 41366)


In [27]:
# Check if the combined adjacency matrix is symmetric
if (combined_adj != combined_adj.T).nnz == 0:
    print("The adjacency matrix is symmetric.")
else:
    print("The adjacency matrix is not symmetric.")

The adjacency matrix is symmetric.


In [28]:
from scipy.sparse import save_npz, load_npz

save_npz("../data/2018/graph/combined_adj_query_old_v2_big.npz", combined_adj)

In [5]:
from scipy.sparse import load_npz

# Load the adjacency matrix
adj = load_npz("../data/2018/graph/combined_adj_query_old_v2_big.npz")

# Count non-zero entries (each one represents an edge if undirected, or one direction if directed)
num_edges = adj.nnz
num_edges = adj.nnz // 2
print("Number of edges:", num_edges)


Number of edges: 93400


In [29]:
print(combined_adj.toarray())

[[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]


### Feature Matrix

In [30]:
# load the pickles with the emebeddings
figure_features_df = pd.read_pickle("../data/2018/graph/query_images_embeddings_model_9_v3.pkl")
figure_features_df['Embedding'] = figure_features_df['embedding']
figure_features_df.drop('embedding', axis=1, inplace=True)

patent_features_df = pd.read_pickle("../data/2018/graph/patent_embeddings.pkl")
patent_features_df['Embedding'] = patent_features_df['Title_Embedding']
patent_features_df.drop('Title_Embedding', axis=1, inplace=True)

medium_features_df = pd.read_pickle("../data/2018/graph/medium_ipc_embeddings.pkl")
medium_features_df['Embedding'] = medium_features_df['Title_Embedding']
medium_features_df.drop('Title_Embedding', axis=1, inplace=True)

big_features_df = pd.read_pickle("../data/2018/graph/big_ipc_embeddings.pkl")
big_features_df['Embedding'] = big_features_df['Title_Embedding']
big_features_df.drop('Title_Embedding', axis=1, inplace=True)

main_features_df = pd.read_pickle("../data/2018/graph/main_ipc_embeddings.pkl")
main_features_df['Embedding'] = main_features_df['Title']
main_features_df.drop('Title', axis=1, inplace=True)

In [31]:
# Step 2: Create index mappings
figure_idx = {str(node): i for i, node in enumerate(figure_nodes)}
patent_idx = {str(node): i for i, node in enumerate(patent_nodes)}
medium_ipc_idx = {str(node): i for i, node in enumerate(medium_ipc_nodes)}
big_ipc_idx = {str(node): i for i, node in enumerate(big_ipc_nodes)}
main_ipc_idx = {str(node): i for i, node in enumerate(main_ipc_nodes)}

In [32]:
len(main_features_df.head(1)['Embedding'][0])

1

In [33]:
import pandas as pd
import numpy as np
from scipy.sparse import coo_matrix, save_npz
import torch


def align_features(df, node_mapping, feature_dim):
    aligned_features = np.zeros(
        (len(node_mapping), feature_dim)
    )  # Initialize a zero matrix
    for node, idx in node_mapping.items():
        if node in df.index:  # Ensure the node exists in the DataFrame
            feature_vector = df.loc[node, "Embedding"]  # Get the feature vector
            aligned_features[idx, :] = np.array(
                feature_vector
            )  # Assign the feature vector
        else:
            aligned_features[idx, :] = np.zeros(
                feature_dim
            )  # Assign a zero vector for missing nodes
    return torch.tensor(aligned_features, dtype=torch.float32)


# Convert the features DataFrames to numpy arrays and set the index to the respective node IDs
feature_dim = 512  # Example feature dimension (adjust as needed)

# For figures
figure_features_df.set_index("image_name", inplace=True)
aligned_figure_features = align_features(figure_features_df, figure_idx, feature_dim)

# For patents
patent_features_df.set_index("Patent ID", inplace=True)
aligned_patent_features = align_features(patent_features_df, patent_idx, feature_dim)

# For medium IPCs
medium_features_df.set_index("CPC", inplace=True)
aligned_medium_features = align_features(
    medium_features_df, medium_ipc_idx, feature_dim
)

# For big IPCs
big_features_df.set_index("CPC", inplace=True)
aligned_big_features = align_features(big_features_df, big_ipc_idx, feature_dim)

# For main IPCs
main_features_df.set_index("CPC", inplace=True)
aligned_main_features = align_features(main_features_df, main_ipc_idx, feature_dim)

# Step 3: Stack the features
combined_features = np.vstack(
    [   
        aligned_figure_features,
        aligned_patent_features,
        aligned_medium_features,
        aligned_big_features,
        aligned_main_features,
    ]
)


In [34]:

# Step 4: Save the combined feature matrix
save_npz("../data/2018/graph/combined_features_matrix_old_v3_big.npz", coo_matrix(combined_features))

# Optional: Verify the shape and order
print(f"Combined feature matrix shape: {combined_features.shape}")
print(f"Number of figures features: {aligned_figure_features.shape[0]}")
print(f"Number of patent features: {aligned_patent_features.shape[0]}")
print(f"Number of medium IPC features: {aligned_medium_features.shape[0]}")
print(f"Number of big IPC features: {aligned_big_features.shape[0]}")
print(f"Number of main IPC features: {aligned_main_features.shape[0]}")

Combined feature matrix shape: (41366, 512)
Number of figures features: 27101
Number of patent features: 13552
Number of medium IPC features: 578
Number of big IPC features: 126
Number of main IPC features: 9


In [35]:
aligned_figure_features

tensor([[ 1.0463e-02, -3.3161e-02, -9.6905e-03,  ...,  1.1711e-03,
          8.5174e-03,  7.1239e-04],
        [ 1.1471e-02, -3.4681e-02, -1.0460e-02,  ..., -1.1522e-03,
          9.4798e-03, -1.5956e-04],
        [ 1.0490e-02, -3.4362e-02, -1.1014e-02,  ..., -1.3435e-03,
          7.7000e-03, -4.9574e-05],
        ...,
        [ 1.3617e-02, -3.2675e-02, -8.5478e-03,  ..., -6.5707e-04,
          9.2585e-03,  8.9988e-04],
        [ 1.1752e-02, -3.2056e-02, -9.6562e-03,  ..., -1.4564e-03,
          9.8141e-03,  4.0854e-04],
        [ 1.1757e-02, -3.3799e-02, -9.9008e-03,  ..., -1.4917e-04,
          9.0486e-03,  4.9185e-04]])

In [36]:
aligned_figure_features

tensor([[ 1.0463e-02, -3.3161e-02, -9.6905e-03,  ...,  1.1711e-03,
          8.5174e-03,  7.1239e-04],
        [ 1.1471e-02, -3.4681e-02, -1.0460e-02,  ..., -1.1522e-03,
          9.4798e-03, -1.5956e-04],
        [ 1.0490e-02, -3.4362e-02, -1.1014e-02,  ..., -1.3435e-03,
          7.7000e-03, -4.9574e-05],
        ...,
        [ 1.3617e-02, -3.2675e-02, -8.5478e-03,  ..., -6.5707e-04,
          9.2585e-03,  8.9988e-04],
        [ 1.1752e-02, -3.2056e-02, -9.6562e-03,  ..., -1.4564e-03,
          9.8141e-03,  4.0854e-04],
        [ 1.1757e-02, -3.3799e-02, -9.9008e-03,  ..., -1.4917e-04,
          9.0486e-03,  4.9185e-04]])

## Hyper good

In [39]:
import numpy as np
import scipy.sparse as sp
import random
import json
import os
from collections import defaultdict

def prepare_training_data(A, X, counts, output_dir, neg_ratio=10, fig_pair_ratio=5):
    """
    Prepare training data including figure-to-figure pairs from the same patent.
    
    Args:
        A: Adjacency matrix
        X: Feature matrix
        counts: Dictionary with counts of different node types
        output_dir: Directory to save the prepared data
        neg_ratio: Ratio of negative to positive samples for figure-patent pairs
        fig_pair_ratio: Ratio of negative to positive samples for figure-figure pairs
    """
    # Unpack counts
    num_figures = counts['figures']
    num_patents = counts['patents']
    num_medium_cpcs = counts['medium_cpcs']
    num_big_cpcs = counts['big_cpcs']
    num_main_cpcs = counts['main_cpcs']

    # Compute index offsets
    idx_figures_end = num_figures
    idx_patents_start = num_figures
    idx_patents_end = idx_patents_start + num_patents
    idx_medium_cpcs_start = idx_patents_end
    idx_medium_cpcs_end = idx_medium_cpcs_start + num_medium_cpcs
    idx_big_cpcs_start = idx_medium_cpcs_end
    idx_big_cpcs_end = idx_big_cpcs_start + num_big_cpcs
    idx_main_cpcs_start = idx_big_cpcs_end
    idx_main_cpcs_end = idx_main_cpcs_start + num_main_cpcs

    total_nodes = idx_main_cpcs_end
    
    # Calculate LABEL_NUM for validation
    label_num_check = (idx_main_cpcs_end - idx_patents_start)
    print(f"Data Prep Check: Calculated LABEL_NUM = {label_num_check}")

    # Label offsets for evaluation slicing
    label_offsets = {
        'patents': idx_patents_start,
        'medium_cpcs': idx_medium_cpcs_start,
        'big_cpcs': idx_big_cpcs_start,
        'main_cpcs': idx_main_cpcs_start,
    }

    Y_pos = []
    implication = []
    exclusion = []
    Y_neg = []
    
    # New: For figure-to-figure pairs
    figure_to_patent = defaultdict(set)  # Maps figure_idx -> set of patent_idx
    patent_to_figures = defaultdict(set)  # Maps patent_idx -> set of figure_idx
    positive_figure_pairs = []  # Pairs of figures from the same patent
    negative_figure_pairs = []  # Pairs of figures from different patents

    # Initialize dictionaries for hierarchical mappings
    patent_to_medium = {}
    medium_to_big = {}
    big_to_main = {}

    if not isinstance(A, sp.coo_matrix):
        A_coo = A.tocoo()
    else:
        A_coo = A

    positive_figure_patent_map = {}

    print("Processing adjacency matrix to extract relationships...")
    for i, j, _ in zip(A_coo.row, A_coo.col, A_coo.data):
        # 1. Figure -> Patent (Positive Node-Label Pairs)
        if 0 <= i < idx_figures_end and idx_patents_start <= j < idx_patents_end:
            figure_original_idx = i
            patent_relative_idx = j - idx_patents_start
            Y_pos.append((figure_original_idx, patent_relative_idx))
            
            # Store for negative sampling
            if figure_original_idx not in positive_figure_patent_map:
                positive_figure_patent_map[figure_original_idx] = set()
            positive_figure_patent_map[figure_original_idx].add(j)  # absolute j
            
            # Store for figure-to-figure pair generation
            figure_to_patent[figure_original_idx].add(j)  # absolute patent idx
            patent_to_figures[j].add(figure_original_idx)  # absolute figure idx

        # 2. Patent -> Medium CPC (Hierarchical Implication)
        elif idx_patents_start <= i < idx_patents_end and idx_medium_cpcs_start <= j < idx_medium_cpcs_end:
            child_relative_idx = i - idx_patents_start
            parent_relative_idx = j - idx_patents_start
            implication.append((child_relative_idx, parent_relative_idx))
            patent_to_medium[child_relative_idx] = parent_relative_idx

        # 3. Medium CPC -> Big CPC (Hierarchical Implication)
        elif idx_medium_cpcs_start <= i < idx_medium_cpcs_end and idx_big_cpcs_start <= j < idx_big_cpcs_end:
            child_relative_idx = i - idx_patents_start
            parent_relative_idx = j - idx_patents_start
            implication.append((child_relative_idx, parent_relative_idx))
            medium_to_big[child_relative_idx] = parent_relative_idx

        # 4. Big CPC -> Main CPC (Hierarchical Implication)
        elif idx_big_cpcs_start <= i < idx_big_cpcs_end and idx_main_cpcs_start <= j < idx_main_cpcs_end:
            child_relative_idx = i - idx_patents_start
            parent_relative_idx = j - idx_patents_start
            implication.append((child_relative_idx, parent_relative_idx))
            big_to_main[child_relative_idx] = parent_relative_idx

    print(f"Found {len(Y_pos)} positive figure-patent pairs.")
    print(f"Found {len(implication)} hierarchical implication pairs.")

    # Generate figure-to-figure pairs
    print("Generating figure-to-figure pairs...")
    
    # First, generate positive pairs (figures from the same patent)
    for patent_idx, figures in patent_to_figures.items():
        figures_list = list(figures)
        if len(figures_list) >= 2:  # Need at least 2 figures to form a pair
            # Generate all possible pairs of figures from this patent
            for i in range(len(figures_list)):
                for j in range(i+1, len(figures_list)):
                    positive_figure_pairs.append((figures_list[i], figures_list[j]))
    
    print(f"Generated {len(positive_figure_pairs)} positive figure-figure pairs.")
    
    # Generate negative pairs (figures from different patents)
    # We'll sample randomly to keep the number manageable
    target_neg_fig_pairs = len(positive_figure_pairs) * fig_pair_ratio
    all_figures = list(range(num_figures))
    
    neg_fig_pairs_count = 0
    max_attempts = target_neg_fig_pairs * 10  # Heuristic limit
    attempts = 0
    
    while neg_fig_pairs_count < target_neg_fig_pairs and attempts < max_attempts:
        # Sample two random figures
        fig1 = random.choice(all_figures)
        fig2 = random.choice(all_figures)
        
        if fig1 == fig2:
            attempts += 1
            continue  # Skip same figure
        
        # Check if they belong to different patents
        patents1 = figure_to_patent.get(fig1, set())
        patents2 = figure_to_patent.get(fig2, set())
        
        # If they share any patent, they're not a negative pair
        if patents1.intersection(patents2):
            attempts += 1
            continue
        
        # Ensure we don't add duplicate pairs
        pair = (min(fig1, fig2), max(fig1, fig2))  # Order to avoid duplicates
        if pair not in negative_figure_pairs:
            negative_figure_pairs.append(pair)
            neg_fig_pairs_count += 1
        
        attempts += 1
    
    print(f"Generated {len(negative_figure_pairs)} negative figure-figure pairs.")

    # Generate Negative Samples (Y_neg) with exclusivity constraints
    print(f"Generating negative samples (ratio: {neg_ratio})...")

    all_patent_absolute_indices = list(range(idx_patents_start, idx_patents_end))
    figures_with_pos = list(positive_figure_patent_map.keys())
    random.shuffle(figures_with_pos)

    for fig_idx in figures_with_pos:
        positive_patents_for_fig_abs = positive_figure_patent_map.get(fig_idx, set())
        # Build sets of hierarchical category assignments for the positive patents
        pos_patents_rel = {p_abs - idx_patents_start for p_abs in positive_patents_for_fig_abs}
        pos_big = set()
        pos_main = set()
        for p_rel in pos_patents_rel:
            if p_rel in patent_to_medium:
                medium = patent_to_medium[p_rel]
                if medium in medium_to_big:
                    big = medium_to_big[medium]
                    pos_big.add(big)
                    if big in big_to_main:
                        main = big_to_main[big]
                        pos_main.add(main)
        
        # Decide how many negatives to generate for this figure
        num_pos_for_fig = len(positive_patents_for_fig_abs)
        target_neg_count = num_pos_for_fig * neg_ratio
        neg_count = 0
        max_attempts = target_neg_count * len(all_patent_absolute_indices)
        attempts = 0
        current_negatives = []

        while neg_count < target_neg_count and attempts < max_attempts:
            neg_patent_abs_idx = random.choice(all_patent_absolute_indices)
            candidate_rel = neg_patent_abs_idx - idx_patents_start

            # Basic exclusion: candidate should not be a positive
            if neg_patent_abs_idx in positive_patents_for_fig_abs:
                attempts += 1
                continue

            # Check hierarchical exclusivity
            candidate_medium = patent_to_medium.get(candidate_rel, None)
            candidate_big = medium_to_big.get(candidate_medium, None) if candidate_medium is not None else None
            candidate_main = big_to_main.get(candidate_big, None) if candidate_big is not None else None

            # Exclude if candidate's big CPC is same as one of the positives, or candidate's main CPC is same
            if candidate_big is not None and candidate_big in pos_big:
                attempts += 1
                continue
            if candidate_main is not None and candidate_main in pos_main:
                attempts += 1
                continue

            # If passed all checks, add as a negative sample
            neg_patent_relative_idx = candidate_rel
            current_negatives.append((fig_idx, neg_patent_relative_idx))
            neg_count += 1
            attempts += 1

        Y_neg.extend(current_negatives)

        if attempts >= max_attempts and neg_count < target_neg_count:
            print(f"Warning: Could only generate {neg_count}/{target_neg_count} negative samples for figure {fig_idx} after {attempts} attempts.")

    print(f"Generated {len(Y_neg)} negative figure-patent pairs.")

    # Prepare data dictionary
    X_figures_only = X[:num_figures, :]

    prepared_data = {
        'X_figures': X_figures_only,
        'Y_pos': Y_pos,
        'Y_neg': Y_neg,
        'implication': implication,
        'exclusion': exclusion,
        'label_offsets': label_offsets,
        'positive_figure_pairs': positive_figure_pairs,  # New
        'negative_figure_pairs': negative_figure_pairs   # New
    }

    # Save the data
    print(f"Saving data to directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    npz_path = os.path.join(output_dir, 'training_data_cross.npz')
    json_path = os.path.join(output_dir, 'label_offsets_cross.json')

    Y_pos_np = np.array(Y_pos, dtype=np.int32) if Y_pos else np.empty((0, 2), dtype=np.int32)
    Y_neg_np = np.array(Y_neg, dtype=np.int32) if Y_neg else np.empty((0, 2), dtype=np.int32)
    implication_np = np.array(implication, dtype=np.int32) if implication else np.empty((0, 2), dtype=np.int32)
    exclusion_np = np.array(exclusion, dtype=np.int32) if exclusion else np.empty((0, 2), dtype=np.int32)
    pos_fig_pairs_np = np.array(positive_figure_pairs, dtype=np.int32) if positive_figure_pairs else np.empty((0, 2), dtype=np.int32)
    neg_fig_pairs_np = np.array(negative_figure_pairs, dtype=np.int32) if negative_figure_pairs else np.empty((0, 2), dtype=np.int32)

    # Add validation print before saving
    if implication:
        print(f"Data Prep Save Check: Min/Max implication indices: {implication_np.min()}, {implication_np.max()}")
    if Y_pos:
        print(f"Data Prep Save Check: Min/Max Y_pos patent indices: {Y_pos_np[:, 1].min()}, {Y_pos_np[:, 1].max()}")
    if Y_neg:
        print(f"Data Prep Save Check: Min/Max Y_neg patent indices: {Y_neg_np[:, 1].min()}, {Y_neg_np[:, 1].max()}")
    if positive_figure_pairs:
        print(f"Data Prep Save Check: Min/Max positive figure pair indices: {pos_fig_pairs_np.min()}, {pos_fig_pairs_np.max()}")
    if negative_figure_pairs:
        print(f"Data Prep Save Check: Min/Max negative figure pair indices: {neg_fig_pairs_np.min()}, {neg_fig_pairs_np.max()}")

    try:
        np.savez_compressed(
            npz_path,
            X_figures=X_figures_only,
            Y_pos=Y_pos_np,
            Y_neg=Y_neg_np,
            implication=implication_np,
            exclusion=exclusion_np,
            positive_figure_pairs=pos_fig_pairs_np,  # New
            negative_figure_pairs=neg_fig_pairs_np   # New
        )
        print(f"Saved array data to {npz_path}")
    except Exception as e:
        print(f"Error saving NPZ file: {e}")

    try:
        with open(json_path, 'w') as f:
            json.dump(label_offsets, f, indent=4)
        print(f"Saved label offsets to {json_path}")
    except Exception as e:
        print(f"Error saving JSON file: {e}")

    return prepared_data
# --- Example Usage (Updated for Saving) ---
if __name__ == '__main__':
    # Define counts
    counts = {
        'figures': 27101,
        'patents': 13552,
        'medium_cpcs': 578,
        'big_cpcs': 126,
        'main_cpcs': 9
    }
    
    total_nodes = sum(counts.values())
    feature_dim = 512
    output_save_directory = './prepared_training_data' # Define where to save
    
    #A_sparse = sp.coo_matrix((data, (rows, cols)), shape=(total_nodes, total_nodes))
    A_sparse = combined_adj
    print(f"A loaded with {A_sparse.nnz} elements.")
    
    X_full_features = combined_features
    print(f"X loaded with {X_full_features.shape}.")

    # Prepare the training data and save it
    training_data = prepare_training_data(
        A_sparse,
        X_full_features,
        counts,
        output_dir=output_save_directory, # Pass the save directory
        neg_ratio=20,
        fig_pair_ratio=15  # New parameter for figure-to-figure pair generation
    )

    # You can still use the returned dictionary if needed
    print("\n--- Returned Data Summary ---")
    print(f"X_figures shape: {training_data['X_figures'].shape}")
    print(f"Number of positive pairs (Y_pos): {len(training_data['Y_pos'])}")
    print(f"Number of negative pairs (Y_neg): {len(training_data['Y_neg'])}")
    print(f"Number of implication pairs: {len(training_data['implication'])}")
    print(f"Number of exclusion pairs: {len(training_data['exclusion'])}")
    print(f"Number of positive figure pairs: {len(training_data['positive_figure_pairs'])}")  # New
    print(f"Number of negative figure pairs: {len(training_data['negative_figure_pairs'])}")  # New
    print(f"Label Offsets: {training_data['label_offsets']}")

    # --- How to load the saved data later ---
    print(f"\n--- To load the data later ---")
    try:
        loaded_npz = np.load(os.path.join(output_save_directory, 'training_data_cross.npz'))
        print("Loaded NPZ data keys:", list(loaded_npz.keys()))
        # Access data like: loaded_npz['X_figures'], loaded_npz['Y_pos']
        
        # Access the new figure pair data
        if 'positive_figure_pairs' in loaded_npz:
            print(f"Loaded {len(loaded_npz['positive_figure_pairs'])} positive figure pairs")
        if 'negative_figure_pairs' in loaded_npz:
            print(f"Loaded {len(loaded_npz['negative_figure_pairs'])} negative figure pairs")

        with open(os.path.join(output_save_directory, 'label_offsets_cross.json'), 'r') as f:
            loaded_offsets = json.load(f)
        print("Loaded JSON offsets:", loaded_offsets)
    except FileNotFoundError:
        print("Saved files not found (this is expected if running the first time).")
    except Exception as e:
        print(f"Error loading saved files: {e}")

A loaded with 186800 elements.
X loaded with (41366, 512).
Data Prep Check: Calculated LABEL_NUM = 14265
Processing adjacency matrix to extract relationships...
Found 27101 positive figure-patent pairs.
Found 45616 hierarchical implication pairs.
Generating figure-to-figure pairs...
Generated 13549 positive figure-figure pairs.
Generated 203235 negative figure-figure pairs.
Generating negative samples (ratio: 20)...
Generated 542020 negative figure-patent pairs.
Saving data to directory: ./prepared_training_data
Data Prep Save Check: Min/Max implication indices: 0, 14264
Data Prep Save Check: Min/Max Y_pos patent indices: 0, 13551
Data Prep Save Check: Min/Max Y_neg patent indices: 0, 13551
Data Prep Save Check: Min/Max positive figure pair indices: 0, 27100
Data Prep Save Check: Min/Max negative figure pair indices: 0, 27100
Saved array data to ./prepared_training_data/training_data.npz
Saved label offsets to ./prepared_training_data/label_offsets.json

--- Returned Data Summary ---
X

In [37]:
import numpy as np
from collections import defaultdict

def extract_figure_to_pos_figures(A, num_figures, num_patents):
    # Build mapping: patent_idx -> set of figure indices
    patent_to_figures = defaultdict(set)
    figure_to_patents = defaultdict(set)

    # Assume A is in COO format
    A_coo = A.tocoo()
    idx_patents_start = num_figures
    idx_patents_end = num_figures + num_patents

    for i, j in zip(A_coo.row, A_coo.col):
        # Figure -> Patent edge
        if 0 <= i < num_figures and idx_patents_start <= j < idx_patents_end:
            patent_idx = j
            figure_idx = i
            patent_to_figures[patent_idx].add(figure_idx)
            figure_to_patents[figure_idx].add(patent_idx)

    # Now, for each figure, collect all other figures sharing any patent
    figure_to_pos_figures = defaultdict(list)
    for figure_idx in range(num_figures):
        pos_set = set()
        for patent_idx in figure_to_patents.get(figure_idx, []):
            pos_set.update(patent_to_figures[patent_idx])
        pos_set.discard(figure_idx)  # Remove self
        if pos_set:
            figure_to_pos_figures[figure_idx] = list(pos_set)

    return figure_to_pos_figures

# Example usage:
num_figures= 27101
num_patents= 13552
figure_to_pos_figures = extract_figure_to_pos_figures(combined_adj, num_figures, num_patents)

In [38]:
import pickle

with open('figure_to_pos_figures.pkl', 'wb') as f:
    pickle.dump(figure_to_pos_figures, f)

In [24]:
import numpy as np
import scipy.sparse as sp
import random
import json # Import json library
import os   # Import os library for path handling
def prepare_training_data(A, X, counts, output_dir, neg_ratio=10):
    # ... (Setup code: num_figures, num_patents, ..., index ranges, total_nodes, validation) ...
    num_figures = counts['figures']
    num_patents = counts['patents']
    num_medium_cpcs = counts['medium_cpcs']
    num_big_cpcs = counts['big_cpcs']
    num_main_cpcs = counts['main_cpcs']

    idx_figures_end = num_figures
    idx_patents_start = num_figures # This is the offset we need to subtract for labels
    idx_patents_end = idx_patents_start + num_patents
    idx_medium_cpcs_start = idx_patents_end
    idx_medium_cpcs_end = idx_medium_cpcs_start + num_medium_cpcs
    idx_big_cpcs_start = idx_medium_cpcs_end
    idx_big_cpcs_end = idx_big_cpcs_start + num_big_cpcs
    idx_main_cpcs_start = idx_big_cpcs_end
    idx_main_cpcs_end = idx_main_cpcs_start + num_main_cpcs

    total_nodes = idx_main_cpcs_end
    # ... (Validations for A and X shapes) ...

    # --- Calculate LABEL_NUM for validation ---
    label_num_check = (idx_main_cpcs_end - idx_patents_start)
    print(f"Data Prep Check: Calculated LABEL_NUM = {label_num_check}")
    # --- ---

    # --- Corrected label_offsets (relative to start of labels) ---
    # Although the training script doesn't seem to use these offsets directly
    # for indexing label_emb, let's keep them as absolute for now,
    # as the training script uses them for evaluation slicing.
    label_offsets = {
        'patents': idx_patents_start,
        'medium_cpcs': idx_medium_cpcs_start,
        'big_cpcs': idx_big_cpcs_start,
        'main_cpcs': idx_main_cpcs_start,
    }
    # --- ---

    Y_pos = []
    implication = []
    exclusion = []
    Y_neg = []

    if not isinstance(A, sp.coo_matrix):
        A_coo = A.tocoo()
    else:
        A_coo = A

    positive_figure_patent_map = {}

    print("Processing adjacency matrix to extract relationships...")
    for i, j, _ in zip(A_coo.row, A_coo.col, A_coo.data):
        # 1. Figure -> Patent (Positive Node-Label Pairs)
        if 0 <= i < idx_figures_end and idx_patents_start <= j < idx_patents_end:
            figure_original_idx = i
            # --- FIX: Save RELATIVE patent index ---
            patent_relative_idx = j - idx_patents_start
            Y_pos.append((figure_original_idx, patent_relative_idx))
            # --- Store absolute index in map for negative sampling logic ---
            if figure_original_idx not in positive_figure_patent_map:
                positive_figure_patent_map[figure_original_idx] = set()
            positive_figure_patent_map[figure_original_idx].add(j) # Still use absolute j here

        # 2. Patent -> Medium CPC (Hierarchical Implication)
        elif idx_patents_start <= i < idx_patents_end and idx_medium_cpcs_start <= j < idx_medium_cpcs_end:
            # --- FIX: Save RELATIVE indices ---
            child_relative_idx = i - idx_patents_start
            parent_relative_idx = j - idx_patents_start
            implication.append((child_relative_idx, parent_relative_idx))

        # 3. Medium CPC -> Big CPC (Hierarchical Implication)
        elif idx_medium_cpcs_start <= i < idx_medium_cpcs_end and idx_big_cpcs_start <= j < idx_big_cpcs_end:
            # --- FIX: Save RELATIVE indices ---
            child_relative_idx = i - idx_patents_start
            parent_relative_idx = j - idx_patents_start
            implication.append((child_relative_idx, parent_relative_idx))

        # 4. Big CPC -> Main CPC (Hierarchical Implication)
        elif idx_big_cpcs_start <= i < idx_big_cpcs_end and idx_main_cpcs_start <= j < idx_main_cpcs_end:
            # --- FIX: Save RELATIVE indices ---
            child_relative_idx = i - idx_patents_start
            parent_relative_idx = j - idx_patents_start
            implication.append((child_relative_idx, parent_relative_idx))

    print(f"Found {len(Y_pos)} positive figure-patent pairs.")
    print(f"Found {len(implication)} hierarchical implication pairs.")

    # 5. Generate Negative Samples (Y_neg)
    print(f"Generating negative samples (ratio: {neg_ratio})...")
    # Use absolute indices for sampling logic
    all_patent_absolute_indices = list(range(idx_patents_start, idx_patents_end))
    # Iterate through the original Y_pos structure before modification if needed,
    # or adapt using the map. Let's use the map.
    figures_with_pos = list(positive_figure_patent_map.keys())
    random.shuffle(figures_with_pos)

    for fig_idx in figures_with_pos:
        positive_patents_for_fig_abs = positive_figure_patent_map.get(fig_idx, set())
        # Generate neg_ratio negatives for *each* positive? Or just per figure?
        # Let's assume per figure for now.
        num_pos_for_fig = len(positive_patents_for_fig_abs)
        target_neg_count = num_pos_for_fig * neg_ratio # Or just neg_ratio if per figure

        neg_count = 0
        max_attempts = target_neg_count * len(all_patent_absolute_indices) # Heuristic limit
        attempts = 0
        current_negatives = []

        while neg_count < target_neg_count and attempts < max_attempts:
            neg_patent_abs_idx = random.choice(all_patent_absolute_indices)
            # Ensure the sampled patent is not a true positive for this figure
            if neg_patent_abs_idx not in positive_patents_for_fig_abs:
                 # --- FIX: Save RELATIVE negative index ---
                neg_patent_relative_idx = neg_patent_abs_idx - idx_patents_start
                current_negatives.append((fig_idx, neg_patent_relative_idx))
                neg_count += 1
            attempts += 1

        Y_neg.extend(current_negatives) # Add all negatives for this figure

        if attempts >= max_attempts and neg_count < target_neg_count:
             print(f"Warning: Could only generate {neg_count}/{target_neg_count} negative samples for figure {fig_idx} after {attempts} attempts.")

    print(f"Generated {len(Y_neg)} negative figure-patent pairs.")

    # ... (Exclusion logic - if added, ensure relative indices are saved) ...

    X_figures_only = X[:num_figures, :]

    # --- Prepare data dictionary (now contains relative indices for labels) ---
    prepared_data = {
        'X_figures': X_figures_only,
        'Y_pos': Y_pos,
        'Y_neg': Y_neg,
        'implication': implication,
        'exclusion': exclusion,
        'label_offsets': label_offsets # Keep absolute offsets for eval slicing
    }

    # --- Saving the data ---
    # ... (Saving logic remains the same, but saves the corrected data) ...
    print(f"Saving data to directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    npz_path = os.path.join(output_dir, 'training_data.npz')
    json_path = os.path.join(output_dir, 'label_offsets.json')

    Y_pos_np = np.array(Y_pos, dtype=np.int32) if Y_pos else np.empty((0, 2), dtype=np.int32)
    Y_neg_np = np.array(Y_neg, dtype=np.int32) if Y_neg else np.empty((0, 2), dtype=np.int32)
    implication_np = np.array(implication, dtype=np.int32) if implication else np.empty((0, 2), dtype=np.int32)
    exclusion_np = np.array(exclusion, dtype=np.int32) if exclusion else np.empty((0, 2), dtype=np.int32)

    # --- Add validation print before saving ---
    if implication:
        print(f"Data Prep Save Check: Min/Max implication indices: {implication_np.min()}, {implication_np.max()}")
    if Y_pos:
        print(f"Data Prep Save Check: Min/Max Y_pos patent indices: {Y_pos_np[:, 1].min()}, {Y_pos_np[:, 1].max()}")
    if Y_neg:
         print(f"Data Prep Save Check: Min/Max Y_neg patent indices: {Y_neg_np[:, 1].min()}, {Y_neg_np[:, 1].max()}")
    # --- ---

    try:
        np.savez_compressed(
            npz_path,
            X_figures=X_figures_only,
            Y_pos=Y_pos_np,
            Y_neg=Y_neg_np,
            implication=implication_np,
            exclusion=exclusion_np
        )
        print(f"Saved array data to {npz_path}")
    except Exception as e:
        print(f"Error saving NPZ file: {e}")

    try:
        with open(json_path, 'w') as f:
            json.dump(label_offsets, f, indent=4)
        print(f"Saved label offsets to {json_path}")
    except Exception as e:
        print(f"Error saving JSON file: {e}")

    return prepared_data
  

# --- Example Usage (Updated for Saving) ---
if __name__ == '__main__':
    # Define counts
    counts = {
        'figures': 17719,
        'patents': 8860,
        'medium_cpcs': 544,
        'big_cpcs': 125,
        'main_cpcs': 9
    }
    
    total_nodes = sum(counts.values())
    feature_dim = 512
    output_save_directory = './prepared_training_data' # Define where to save

    
    
    #A_sparse = sp.coo_matrix((data, (rows, cols)), shape=(total_nodes, total_nodes))
    A_sparse = combined_adj
    print(f" A loaded  with {A_sparse.nnz} elements.")

    
    X_full_features = combined_features
    print(f"X loaded  with {X_full_features.shape}.")

    # Prepare the training data and save it
    training_data = prepare_training_data(
        A_sparse,
        X_full_features,
        counts,
        output_dir=output_save_directory, # Pass the save directory
        neg_ratio=10
    )

    # You can still use the returned dictionary if needed
    print("\n--- Returned Data Summary ---")
    print(f"X_figures shape: {training_data['X_figures'].shape}")
    print(f"Number of positive pairs (Y_pos): {len(training_data['Y_pos'])}")
    print(f"Number of negative pairs (Y_neg): {len(training_data['Y_neg'])}")
    print(f"Number of implication pairs: {len(training_data['implication'])}")
    print(f"Number of exclusion pairs: {len(training_data['exclusion'])}")
    print(f"Label Offsets: {training_data['label_offsets']}")

    # --- How to load the saved data later ---
    print(f"\n--- To load the data later ---")
    try:
        loaded_npz = np.load(os.path.join(output_save_directory, 'training_data_small.npz'))
        print("Loaded NPZ data keys:", list(loaded_npz.keys()))
        # Access data like: loaded_npz['X_figures'], loaded_npz['Y_pos']

        with open(os.path.join(output_save_directory, 'label_offsets.json'), 'r') as f:
            loaded_offsets = json.load(f)
        print("Loaded JSON offsets:", loaded_offsets)
    except FileNotFoundError:
        print("Saved files not found (this is expected if running the first time).")
    except Exception as e:
        print(f"Error loading saved files: {e}")

 

 A loaded  with 123401 elements.
X loaded  with (27257, 512).
Data Prep Check: Calculated LABEL_NUM = 9538
Processing adjacency matrix to extract relationships...
Found 17719 positive figure-patent pairs.
Found 30353 hierarchical implication pairs.
Generating negative samples (ratio: 10)...
Generated 177190 negative figure-patent pairs.
Saving data to directory: ./prepared_training_data
Data Prep Save Check: Min/Max implication indices: 0, 9537
Data Prep Save Check: Min/Max Y_pos patent indices: 0, 8859
Data Prep Save Check: Min/Max Y_neg patent indices: 0, 8859
Saved array data to ./prepared_training_data/training_data.npz
Saved label offsets to ./prepared_training_data/label_offsets.json

--- Returned Data Summary ---
X_figures shape: (17719, 512)
Number of positive pairs (Y_pos): 17719
Number of negative pairs (Y_neg): 177190
Number of implication pairs: 30353
Number of exclusion pairs: 0
Label Offsets: {'patents': 17719, 'medium_cpcs': 26579, 'big_cpcs': 27123, 'main_cpcs': 27248}


In [27]:
def prepare_training_data(A, X, counts, output_dir, neg_ratio=5): 
        
    # Unpack counts

    num_figures = counts['figures'] 
    num_patents = counts['patents'] 
    num_medium_cpcs = counts['medium_cpcs'] 
    num_big_cpcs = counts['big_cpcs'] 
    num_main_cpcs = counts['main_cpcs']

    # Compute index offsets. The patent indices start right after figures.
    idx_figures_end = num_figures
    idx_patents_start = num_figures              # Patent absolute indices start here.
    idx_patents_end = idx_patents_start + num_patents

    idx_medium_cpcs_start = idx_patents_end
    idx_medium_cpcs_end = idx_medium_cpcs_start + num_medium_cpcs

    idx_big_cpcs_start = idx_medium_cpcs_end
    idx_big_cpcs_end = idx_big_cpcs_start + num_big_cpcs

    idx_main_cpcs_start = idx_big_cpcs_end
    idx_main_cpcs_end = idx_main_cpcs_start + num_main_cpcs

    total_nodes = idx_main_cpcs_end
    # (You might validate A and X shapes here.)

    # --- Calculate LABEL_NUM for validation ---
    label_num_check = (idx_main_cpcs_end - idx_patents_start)
    print(f"Data Prep Check: Calculated LABEL_NUM = {label_num_check}")
    # --- ---

    # We keep label_offsets as absolute indices (for evaluation slicing)
    label_offsets = {
        'patents': idx_patents_start,
        'medium_cpcs': idx_medium_cpcs_start,
        'big_cpcs': idx_big_cpcs_start,
        'main_cpcs': idx_main_cpcs_start,
    }

    Y_pos = []
    implication = []
    exclusion = []
    Y_neg = []

    # --- Initialize dictionaries for hierarchical mappings ---
    # These will map (relative) patent id -> medium id, medium -> big, and big -> main.
    # For patents, the relative index is computed as (patent_abs - idx_patents_start)
    patent_to_medium = {}  # For condition: patent -> medium CPC
    medium_to_big = {}     # For condition: medium CPC -> big CPC
    big_to_main = {}       # For condition: big CPC -> main CPC

    if not isinstance(A, sp.coo_matrix):
        A_coo = A.tocoo()
    else:
        A_coo = A

    positive_figure_patent_map = {}

    print("Processing adjacency matrix to extract relationships...")
    for i, j, _ in zip(A_coo.row, A_coo.col, A_coo.data):
        # 1. Figure -> Patent (Positive Node-Label Pairs)
        if 0 <= i < idx_figures_end and idx_patents_start <= j < idx_patents_end:
            figure_original_idx = i
            # Save RELATIVE patent index (for training, [0, num_patents-1])
            patent_relative_idx = j - idx_patents_start
            Y_pos.append((figure_original_idx, patent_relative_idx))
            # Also store the absolute index in order to build the set for negative sampling.
            if figure_original_idx not in positive_figure_patent_map:
                positive_figure_patent_map[figure_original_idx] = set()
            positive_figure_patent_map[figure_original_idx].add(j)  # absolute j

        # 2. Patent -> Medium CPC (Hierarchical Implication)
        elif idx_patents_start <= i < idx_patents_end and idx_medium_cpcs_start <= j < idx_medium_cpcs_end:
            # Save RELATIVE indices
            child_relative_idx = i - idx_patents_start  # in [0, num_patents-1]
            parent_relative_idx = j - idx_patents_start   # medium CPC relative index; note: >= num_patents
            implication.append((child_relative_idx, parent_relative_idx))
            # Also build mapping dictionary for patent -> medium
            patent_to_medium[child_relative_idx] = parent_relative_idx

        # 3. Medium CPC -> Big CPC (Hierarchical Implication)
        elif idx_medium_cpcs_start <= i < idx_medium_cpcs_end and idx_big_cpcs_start <= j < idx_big_cpcs_end:
            child_relative_idx = i - idx_patents_start  # Note: i - idx_patents_start will be in [num_patents, num_patents+num_medium_cpcs-1]
            parent_relative_idx = j - idx_patents_start   # big CPC relative index; >= num_patents+num_medium_cpcs
            implication.append((child_relative_idx, parent_relative_idx))
            # Build mapping for medium -> big (key is medium relative id)
            medium_to_big[child_relative_idx] = parent_relative_idx

        # 4. Big CPC -> Main CPC (Hierarchical Implication)
        elif idx_big_cpcs_start <= i < idx_big_cpcs_end and idx_main_cpcs_start <= j < idx_main_cpcs_end:
            child_relative_idx = i - idx_patents_start  # will be in [num_patents+num_medium_cpcs, num_patents+num_medium_cpcs+num_big_cpcs-1]
            parent_relative_idx = j - idx_patents_start   # main CPC relative index; >= num_patents+num_medium_cpcs+num_big_cpcs
            implication.append((child_relative_idx, parent_relative_idx))
            # Build mapping for big -> main
            big_to_main[child_relative_idx] = parent_relative_idx

    print(f"Found {len(Y_pos)} positive figure-patent pairs.")
    print(f"Found {len(implication)} hierarchical implication pairs.")

    # --- Generate Negative Samples (Y_neg) with exclusivity constraints ---
    print(f"Generating negative samples (ratio: {neg_ratio})...")

    # Use absolute indices for negative sampling for the patent nodes.
    all_patent_absolute_indices = list(range(idx_patents_start, idx_patents_end))
    # Get list of figures for which we have positive assignments.
    figures_with_pos = list(positive_figure_patent_map.keys())
    random.shuffle(figures_with_pos)

    for fig_idx in figures_with_pos:
        positive_patents_for_fig_abs = positive_figure_patent_map.get(fig_idx, set())
        # Build sets of hierarchical category assignments for the positive patents.
        pos_patents_rel = {p_abs - idx_patents_start for p_abs in positive_patents_for_fig_abs}
        pos_big = set()
        pos_main = set()
        for p_rel in pos_patents_rel:
            if p_rel in patent_to_medium:
                medium = patent_to_medium[p_rel]
                if medium in medium_to_big:
                    big = medium_to_big[medium]
                    pos_big.add(big)
                    if big in big_to_main:
                        main = big_to_main[big]
                        pos_main.add(main)
        
        # Decide how many negatives to generate for this figure.
        num_pos_for_fig = len(positive_patents_for_fig_abs)
        target_neg_count = num_pos_for_fig * neg_ratio  # Here, negative samples per figure
        neg_count = 0
        max_attempts = target_neg_count * len(all_patent_absolute_indices)  # heuristic limit
        attempts = 0
        current_negatives = []

        while neg_count < target_neg_count and attempts < max_attempts:
            neg_patent_abs_idx = random.choice(all_patent_absolute_indices)
            candidate_rel = neg_patent_abs_idx - idx_patents_start

            # Basic exclusion: candidate should not be a positive.
            if neg_patent_abs_idx in positive_patents_for_fig_abs:
                attempts += 1
                continue

            # Now check hierarchical exclusivity:
            # Get candidate's medium, big and main (if available)
            candidate_medium = patent_to_medium.get(candidate_rel, None)
            candidate_big = medium_to_big.get(candidate_medium, None) if candidate_medium is not None else None
            candidate_main = big_to_main.get(candidate_big, None) if candidate_big is not None else None

            # Exclude if candidate's big CPC is same as one of the positives, or candidate's main CPC is same.
            if candidate_big is not None and candidate_big in pos_big:
                attempts += 1
                continue
            if candidate_main is not None and candidate_main in pos_main:
                attempts += 1
                continue

            # If passed all checks, add as a negative sample.
            # Save negative sample using relative index for patent.
            neg_patent_relative_idx = candidate_rel
            current_negatives.append((fig_idx, neg_patent_relative_idx))
            neg_count += 1
            attempts += 1

        Y_neg.extend(current_negatives)

        if attempts >= max_attempts and neg_count < target_neg_count:
            print(f"Warning: Could only generate {neg_count}/{target_neg_count} negative samples for figure {fig_idx} after {attempts} attempts.")

    print(f"Generated {len(Y_neg)} negative figure-patent pairs.")

    # --- Continue with rest of processing ---
    X_figures_only = X[:num_figures, :]

    # --- Prepare data dictionary. Note:
    # Y_pos and implication now have relative indices for labels.
    prepared_data = {
        'X_figures': X_figures_only,
        'Y_pos': Y_pos,
        'Y_neg': Y_neg,
        'implication': implication,
        'exclusion': exclusion,
        'label_offsets': label_offsets  # Offsets are kept absolute (for evaluation slicing)
    }

    # --- Save the data ---
    print(f"Saving data to directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    npz_path = os.path.join(output_dir, 'training_data_small.npz')
    json_path = os.path.join(output_dir, 'label_offsets_small.json')

    Y_pos_np = np.array(Y_pos, dtype=np.int32) if Y_pos else np.empty((0, 2), dtype=np.int32)
    Y_neg_np = np.array(Y_neg, dtype=np.int32) if Y_neg else np.empty((0, 2), dtype=np.int32)
    implication_np = np.array(implication, dtype=np.int32) if implication else np.empty((0, 2), dtype=np.int32)
    exclusion_np = np.array(exclusion, dtype=np.int32) if exclusion else np.empty((0, 2), dtype=np.int32)

    # --- Add validation print before saving ---
    if implication:
        print(f"Data Prep Save Check: Min/Max implication indices: {implication_np.min()}, {implication_np.max()}")
    if Y_pos:
        print(f"Data Prep Save Check: Min/Max Y_pos patent indices: {Y_pos_np[:, 1].min()}, {Y_pos_np[:, 1].max()}")
    if Y_neg:
        print(f"Data Prep Save Check: Min/Max Y_neg patent indices: {Y_neg_np[:, 1].min()}, {Y_neg_np[:, 1].max()}")
    # --- ---

    try:
        np.savez_compressed(
            npz_path,
            X_figures=X_figures_only,
            Y_pos=Y_pos_np,
            Y_neg=Y_neg_np,
            implication=implication_np,
            exclusion=exclusion_np
        )
        print(f"Saved array data to {npz_path}")
    except Exception as e:
        print(f"Error saving NPZ file: {e}")

    try:
        with open(json_path, 'w') as f:
            json.dump(label_offsets, f, indent=4)
        print(f"Saved label offsets to {json_path}")
    except Exception as e:
        print(f"Error saving JSON file: {e}")

    return prepared_data

In [28]:

counts = {
        'figures': 17719,
        'patents': 8860,
        'medium_cpcs': 544,
        'big_cpcs': 125,
        'main_cpcs': 9
    }

feature_dim = 512
output_save_directory = './prepared_training_data' # Define where to save

# A_sparse should be provided (e.g., from your combined adjacency matrix)
A_sparse = combined_adj  # combined_adj must be defined earlier in your script
print(f"A loaded with {A_sparse.nnz} elements.")

# X_full_features should be provided (combined_features)
X_full_features = combined_features  # combined_features must be defined earlier
print(f"X loaded with {X_full_features.shape}.")

# Prepare the training data and save it
training_data = prepare_training_data(
    A_sparse,
    X_full_features,
    counts,
    output_dir=output_save_directory,
    neg_ratio=10
)

# Optionally print summary of returned data
print("\n--- Returned Data Summary ---")
print(f"X_figures shape: {training_data['X_figures'].shape}")
print(f"Number of positive pairs (Y_pos): {len(training_data['Y_pos'])}")
print(f"Number of negative pairs (Y_neg): {len(training_data['Y_neg'])}")
print(f"Number of implication pairs: {len(training_data['implication'])}")
print(f"Number of exclusion pairs: {len(training_data['exclusion'])}")
print(f"Label Offsets: {training_data['label_offsets']}")

# --- How to load the saved data later ---
print("\n--- To load the data later ---")
try:
    loaded_npz = np.load(os.path.join(output_save_directory, 'training_data_small.npz'))
    print("Loaded NPZ data keys:", list(loaded_npz.keys()))
    # Access data like: loaded_npz['X_figures'], loaded_npz['Y_pos']
    with open(os.path.join(output_save_directory, 'label_offsets_small.json'), 'r') as f:
        loaded_offsets = json.load(f)
    print("Loaded JSON offsets:", loaded_offsets)
except FileNotFoundError:
    print("Saved files not found (this is expected if running the first time).")
except Exception as e:
    print(f"Error loading saved files: {e}")

A loaded with 123401 elements.
X loaded with (27257, 512).
Data Prep Check: Calculated LABEL_NUM = 9538
Processing adjacency matrix to extract relationships...
Found 17719 positive figure-patent pairs.
Found 30353 hierarchical implication pairs.
Generating negative samples (ratio: 10)...
Generated 177190 negative figure-patent pairs.
Saving data to directory: ./prepared_training_data
Data Prep Save Check: Min/Max implication indices: 0, 9537
Data Prep Save Check: Min/Max Y_pos patent indices: 0, 8859
Data Prep Save Check: Min/Max Y_neg patent indices: 0, 8859
Saved array data to ./prepared_training_data/training_data_small.npz
Saved label offsets to ./prepared_training_data/label_offsets_small.json

--- Returned Data Summary ---
X_figures shape: (17719, 512)
Number of positive pairs (Y_pos): 17719
Number of negative pairs (Y_neg): 177190
Number of implication pairs: 30353
Number of exclusion pairs: 0
Label Offsets: {'patents': 17719, 'medium_cpcs': 26579, 'big_cpcs': 27123, 'main_cpcs'

## Pair computation

In [16]:
import numpy as np
import pickle
from scipy import sparse
from collections import defaultdict
import random

# Load the adjacency matrix
adj_matrix_path = "../data/2018/graph/combined_adj_query_hier_01_3.npz"
combined_adj = sparse.load_npz(adj_matrix_path)
combined_adj = combined_adj.tocsr()

# Load the figure_to_row mapping
with open('image_index_2018.pkl', 'rb') as f:
    figure_to_row = pickle.load(f)

# Define the indices for different entity types
#num_figures = 22924
#num_patents = 11463
#num_medium_cpc = 566
#num_big_cpc = 126
#num_main_cpc = 9
num_figures = 32115
num_patents = 16059
num_medium_cpc = 595
num_big_cpc = 126
num_main_cpc = 9

figure_start = 0
patent_start = num_figures
medium_cpc_start = patent_start + num_patents
big_cpc_start = medium_cpc_start + num_medium_cpc
main_cpc_start = big_cpc_start + num_big_cpc

# Create mappings
figure_to_patent = {}
patent_to_figures = defaultdict(list)
patent_to_medium_cpc = defaultdict(list)
medium_cpc_to_big_cpc = defaultdict(list)
big_cpc_to_main_cpc = defaultdict(list)

print("Building hierarchical connections...")

# Figure to Patent connections
for figure_name, row_idx in figure_to_row.items():
    row_idx = int(row_idx)
    row_data = combined_adj[row_idx, :].nonzero()[1]
    
    for col_idx in row_data:
        if patent_start <= col_idx < medium_cpc_start:
            patent_idx = col_idx - patent_start
            figure_to_patent[figure_name] = patent_idx
            patent_to_figures[patent_idx].append(figure_name)
            break

# Patent to Medium CPC connections
for patent_idx in range(num_patents):
    patent_row = patent_start + patent_idx
    row_data = combined_adj[patent_row, :].nonzero()[1]
    
    for col_idx in row_data:
        if medium_cpc_start <= col_idx < big_cpc_start:
            medium_cpc_idx = col_idx - medium_cpc_start
            patent_to_medium_cpc[patent_idx].append(medium_cpc_idx)

# Medium CPC to Big CPC connections
for medium_cpc_idx in range(num_medium_cpc):
    medium_cpc_row = medium_cpc_start + medium_cpc_idx
    row_data = combined_adj[medium_cpc_row, :].nonzero()[1]
    
    for col_idx in row_data:
        if big_cpc_start <= col_idx < main_cpc_start:
            big_cpc_idx = col_idx - big_cpc_start
            medium_cpc_to_big_cpc[medium_cpc_idx].append(big_cpc_idx)

# Big CPC to Main CPC connections
for big_cpc_idx in range(num_big_cpc):
    big_cpc_row = big_cpc_start + big_cpc_idx
    row_data = combined_adj[big_cpc_row, :].nonzero()[1]
    
    for col_idx in row_data:
        if col_idx >= main_cpc_start:
            main_cpc_idx = col_idx - main_cpc_start
            big_cpc_to_main_cpc[big_cpc_idx].append(main_cpc_idx)

# Sample pairs and determine their connection level
print("Sampling pairs and determining connection levels...")

# Number of pairs to sample
num_samples = 200000  # Sample more than needed to ensure we have enough for each level
figure_names = list(figure_to_row.keys())
connection_levels = {1: 0, 2: 0, 3: 0, 4: 0, 5: 0}
sampled_pairs = []
used_pairs = set()  # To track pairs we've already used

# First, ensure we have at least 1500 Level 1 connections (same patent)
level1_pairs = []
patents_with_multiple_figures = [p for p, figs in patent_to_figures.items() if len(figs) >= 2]
first_level=set()
second_level=set()
third_level=set()
fourth_level=set()
for patent_idx in patents_with_multiple_figures:
    figures = patent_to_figures[patent_idx]
    if len(figures) >= 2:
        # Sample pairs from this patent
        pairs = [(fig1, fig2) for i, fig1 in enumerate(figures) for fig2 in figures[i+1:]]
        level1_pairs.extend(pairs)
        if len(level1_pairs) >= 24000:
            break

# Randomly select 1500 Level 1 pairs
if len(level1_pairs) > 24000:
    level1_pairs = random.sample(level1_pairs, 24000)

# Add these Level 1 pairs to our samples
for fig1, fig2 in level1_pairs:
    pair_key = tuple(sorted([fig1, fig2]))
    used_pairs.add(pair_key)
    first_level.add(tuple(sorted([fig1, fig2])))
    sampled_pairs.append((fig1, fig2, 1))
    connection_levels[1] += 1

# Calculate how many more samples we need
remaining_samples = num_samples - len(level1_pairs)

# Sample the remaining pairs randomly
for _ in range(remaining_samples):
    # Sample two different figures
    fig1, fig2 = random.sample(figure_names, 2)
    pair_key = tuple(sorted([fig1, fig2]))
    
    if pair_key in used_pairs:
        continue
    
    used_pairs.add(pair_key)
    
    # Skip if either figure doesn't have a patent connection
    if fig1 not in figure_to_patent or fig2 not in figure_to_patent:
        sampled_pairs.append((fig1, fig2, 5))
        connection_levels[5] += 1
        continue
    
    patent1 = figure_to_patent[fig1]
    patent2 = figure_to_patent[fig2]
    
    # Level 1: Same patent
    if patent1 == patent2 :
        first_level.add(tuple(sorted([fig1, fig2])))
        sampled_pairs.append((fig1, fig2, 1))
        connection_levels[1] += 1
        continue
    
    # Level 2: Patents share medium CPC
    medium_cpcs1 = set(patent_to_medium_cpc[patent1])
    medium_cpcs2 = set(patent_to_medium_cpc[patent2])
    
    if medium_cpcs1 and medium_cpcs2 and medium_cpcs1.intersection(medium_cpcs2) and tuple(sorted([fig1, fig2])) not in first_level:
        sampled_pairs.append((fig1, fig2, 2))
        second_level.add(tuple(sorted([fig1, fig2])))
        connection_levels[2] += 1
        continue
    
    # Level 3: Medium CPCs share big CPC
    big_cpcs1 = set()
    for medium_cpc in medium_cpcs1:
        big_cpcs1.update(medium_cpc_to_big_cpc[medium_cpc])
    
    big_cpcs2 = set()
    for medium_cpc in medium_cpcs2:
        big_cpcs2.update(medium_cpc_to_big_cpc[medium_cpc])
    
    if big_cpcs1 and big_cpcs2 and big_cpcs1.intersection(big_cpcs2)and tuple(sorted([fig1, fig2])) not in second_level and tuple(sorted([fig1, fig2])) not in first_level:
        sampled_pairs.append((fig1, fig2, 3))
        third_level.add(tuple(sorted([fig1, fig2])))
        connection_levels[3] += 1
        continue
    
    # Level 4: Big CPCs share main CPC
    main_cpcs1 = set()
    for big_cpc in big_cpcs1:
        main_cpcs1.update(big_cpc_to_main_cpc[big_cpc])
    
    main_cpcs2 = set()
    for big_cpc in big_cpcs2:
        main_cpcs2.update(big_cpc_to_main_cpc[big_cpc])
    
    if main_cpcs1 and main_cpcs2 and main_cpcs1.intersection(main_cpcs2)and tuple(sorted([fig1, fig2])) not in third_level:
        sampled_pairs.append((fig1, fig2, 4))
        fourth_level.add(tuple(sorted([fig1, fig2])))
        connection_levels[4] += 1
        continue
    
    # Level 5: No connection
    sampled_pairs.append((fig1, fig2, 5))
    connection_levels[5] += 1

# Organize pairs by level
pairs_by_level = {1: [], 2: [], 3: [], 4: [], 5: []}
for pair in sampled_pairs:
    pairs_by_level[pair[2]].append(pair)


for level in pairs_by_level:
    print(level)
    if len(pairs_by_level[level]) > 28000 and level!=1:
        pairs_by_level[level] = random.sample(pairs_by_level[level], 28000)

# Combine all pairs
all_pairs = []
for level in range(1, 6):
    all_pairs.extend(pairs_by_level[level])
print(len(all_pairs))
# Update connection levels count after capping
connection_levels = {level: len(pairs_by_level[level]) for level in range(1, 6)}

# Print connection level counts
print("\nConnection Level Counts:")
for level, count in connection_levels.items():
    level_name = {
        1: "Same Patent",
        2: "Share Medium CPC",
        3: "Share Big CPC",
        4: "Share Main CPC",
        5: "No Connection"
    }[level]
    print(f"Level {level} ({level_name}): {count} pairs")

# Save the sampled pairs to a file
import json
with open('../src/figure_pair_connections.json', 'w') as f:
    json.dump({
        "connection_levels": connection_levels,
        "sampled_pairs": [(p[0], p[1], p[2]) for p in all_pairs]
    }, f, indent=2)

print("\nSampled pairs saved to 'figure_pair_connections.json'")

Building hierarchical connections...
Sampling pairs and determining connection levels...
1
2
3
4
5
105464

Connection Level Counts:
Level 1 (Same Patent): 16056 pairs
Level 2 (Share Medium CPC): 15302 pairs
Level 3 (Share Big CPC): 18106 pairs
Level 4 (Share Main CPC): 28000 pairs
Level 5 (No Connection): 28000 pairs

Sampled pairs saved to 'figure_pair_connections.json'


In [28]:
first_level.intersection(second_level)

set()

In [13]:
import torch
import pickle
import os

# Function to extract hierarchical pairs from your data
def extract_hierarchical_pairs(figure_to_row, figure_to_patent, patent_to_medium_cpc, 
                              medium_cpc_to_big_cpc, big_cpc_to_main_cpc,
                              patent_start, medium_cpc_start, big_cpc_start, main_cpc_start):
    
    print("Extracting hierarchical pairs...")
    hierarchical_pairs = []

    # Level 1: Figures to Patents
    for fig_name, patent_idx in figure_to_patent.items():
        if fig_name in figure_to_row:
            fig_idx = int(figure_to_row[fig_name])
            hierarchical_pairs.append((fig_idx, patent_start + patent_idx))

    # Level 2: Patents to Medium CPCs
    for patent_idx, medium_cpcs in patent_to_medium_cpc.items():
        for medium_cpc in medium_cpcs:
            hierarchical_pairs.append((patent_start + patent_idx, medium_cpc_start + medium_cpc))

    # Level 3: Medium CPCs to Big CPCs
    for medium_cpc, big_cpcs in medium_cpc_to_big_cpc.items():
        for big_cpc in big_cpcs:
            hierarchical_pairs.append((medium_cpc_start + medium_cpc, big_cpc_start + big_cpc))

    # Level 4: Big CPCs to Main CPCs
    for big_cpc, main_cpcs in big_cpc_to_main_cpc.items():
        for main_cpc in main_cpcs:
            hierarchical_pairs.append((big_cpc_start + big_cpc, main_cpc_start + main_cpc))

    print(f"Extracted {len(hierarchical_pairs)} hierarchical pairs")
    return hierarchical_pairs

# Function to save hierarchical pairs
def save_hierarchical_pairs(hierarchical_pairs, filepath='../src/hierarchical_pairs.pkl'):
    with open(filepath, 'wb') as f:
        pickle.dump(hierarchical_pairs, f)
    print(f"Saved hierarchical pairs to {filepath}")

# Function to load hierarchical pairs
def load_hierarchical_pairs(filepath='../src/hierarchical_pairs.pkl'):
    if os.path.exists(filepath):
        with open(filepath, 'rb') as f:
            hierarchical_pairs = pickle.load(f)
        print(f"Loaded {len(hierarchical_pairs)} hierarchical pairs from {filepath}")
        return hierarchical_pairs
    else:
        print(f"File {filepath} not found. Need to extract hierarchical pairs first.")
        return None

In [36]:
hierarchical_pairs = extract_hierarchical_pairs(
        figure_to_row, figure_to_patent, patent_to_medium_cpc,
        medium_cpc_to_big_cpc, big_cpc_to_main_cpc,
        patent_start, medium_cpc_start, big_cpc_start, main_cpc_start
    )
save_hierarchical_pairs(hierarchical_pairs)


Extracting hierarchical pairs...
Extracted 86171 hierarchical pairs
Saved hierarchical pairs to ../src/hierarchical_pairs.pkl


In [31]:
hierarchical_pairs

tensor([[    0, 32115],
        [    1, 32116],
        [    2, 32117],
        ...,
        [48892, 48896],
        [48893, 48899],
        [48894, 48903]])

In [None]:
hierarchical_pairs

In [None]:
Building hierarchical connections...
Sampling pairs and determining connection levels...
1
2
3
4
5
105372

Connection Level Counts:
Level 1 (Same Patent): 16056 pairs
Level 2 (Share Medium CPC): 15401 pairs
Level 3 (Share Big CPC): 17915 pairs
Level 4 (Share Main CPC): 28000 pairs
Level 5 (No Connection): 28000 pairs

Sampled pairs saved to 'figure_pair_connections.json'

## Hyperbolic data generation

In [17]:
num_figures = 32115
num_patents = 16059
num_medium_cpc = 595
num_big_cpc = 126
num_main_cpc = 9

In [20]:
import torch
import pickle
import os
from tqdm import tqdm
# Load the figure_to_row mapping
with open('image_index_2018.pkl', 'rb') as f:
    figure_to_row = pickle.load(f)
    
def generate_hyperbolic_inputs(
    figure_to_row, 
    figure_to_patent, 
    patent_to_medium_cpc, 
    medium_cpc_to_big_cpc, 
    big_cpc_to_main_cpc,
    num_figures,
    num_patents,
    num_medium_cpcs,
    num_big_cpcs,
    num_main_cpcs
):
    """
    Generate inputs for the hyperbolic multi-label model.
    
    Returns:
        Y_pos: List of (node_idx, label_idx) positive pairs
        Y_neg: List of (node_idx, label_idx) negative pairs
        implication: List of (child_label_idx, parent_label_idx) pairs
        exclusion: List of (label_left_idx, label_right_idx) pairs
    """
    print("Generating hyperbolic model inputs...")
    
    # Initialize empty lists
    Y_pos = []
    Y_neg = []
    
    # Generate positive figure-patent relationships
    print("Generating positive relationships...")
    for fig_idx, fig_name in tqdm(enumerate(figure_to_row.keys())):
        if fig_name in figure_to_patent:
            patent_idx = figure_to_patent[fig_name]
            # Add figure-patent positive relationship
            Y_pos.append((fig_idx, num_figures + patent_idx))
            
            # For each patent, add positive relationships to its medium CPCs
            if patent_idx in patent_to_medium_cpc:
                for medium_cpc_idx in patent_to_medium_cpc[patent_idx]:
                    # Medium CPC indices start after the last patent index
                    Y_pos.append((fig_idx, num_figures + num_patents + medium_cpc_idx))
                    
                    # Continue up the hierarchy for big CPCs and main CPCs
                    if medium_cpc_idx in medium_cpc_to_big_cpc:
                        for big_cpc_idx in medium_cpc_to_big_cpc[medium_cpc_idx]:
                            Y_pos.append((fig_idx, num_figures + num_patents + num_medium_cpcs + big_cpc_idx))
                            
                            if big_cpc_idx in big_cpc_to_main_cpc:
                                for main_cpc_idx in big_cpc_to_main_cpc[big_cpc_idx]:
                                    Y_pos.append((fig_idx, num_figures + num_patents + num_medium_cpcs + num_big_cpcs + main_cpc_idx))
    
        # Generate balanced negative relationships
    print("Generating balanced negative relationships...")
    Y_neg = []
    num_labels=num_figures + num_patents + num_medium_cpcs + num_big_cpcs + num_main_cpcs
    
    neg_samples_per_node = 5 
    for fig_idx in tqdm(range(num_figures)):
        # Sample random labels that aren't positive for this figure
        pos_labels = set(label for node, label in Y_pos if node == fig_idx)
        all_labels = set(range(num_labels))
        neg_candidates = list(all_labels - pos_labels)
        
        # Sample a fixed number of negatives
        if len(neg_candidates) > neg_samples_per_node:
            neg_samples = random.sample(neg_candidates, neg_samples_per_node)
            for label_idx in neg_samples:
                Y_neg.append((fig_idx, label_idx))
        
    # Generate implication relationships
    print("Generating implication relationships...")
    implication = []
    
    # Patent to Medium CPC implications
    for patent_idx, medium_cpcs in patent_to_medium_cpc.items():
        patent_label_idx = num_figures + patent_idx
        for medium_cpc_idx in medium_cpcs:
            medium_cpc_label_idx = num_figures + num_patents + medium_cpc_idx
            implication.append((patent_label_idx, medium_cpc_label_idx))
    
    # Medium CPC to Big CPC implications
    for medium_cpc_idx, big_cpcs in medium_cpc_to_big_cpc.items():
        medium_cpc_label_idx = num_figures + num_patents + medium_cpc_idx
        for big_cpc_idx in big_cpcs:
            big_cpc_label_idx = num_figures + num_patents + num_medium_cpcs + big_cpc_idx
            implication.append((medium_cpc_label_idx, big_cpc_label_idx))
    
    # Big CPC to Main CPC implications
    for big_cpc_idx, main_cpcs in big_cpc_to_main_cpc.items():
        big_cpc_label_idx = num_figures + num_patents + num_medium_cpcs + big_cpc_idx
        for main_cpc_idx in main_cpcs:
            main_cpc_label_idx = num_figures + num_patents + num_medium_cpcs + num_big_cpcs + main_cpc_idx
            implication.append((big_cpc_label_idx, main_cpc_label_idx))
    
    print(f"Generated {len(Y_pos)} positive relationships")
    print(f"Generated {len(Y_neg)} negative relationships")
    print(f"Generated {len(implication)} implication relationships")
    # Generate balanced exclusion relationships
    print("Generating balanced exclusion relationships...")
    exclusion = []

    # Patents are mutually exclusive (a figure can't belong to multiple patents)
    # Instead of all pairs, sample a limited number per patent
    max_exclusions_per_patent = 10  # Adjust this number based on your needs

    for i in range(num_patents):
        # Sample a subset of other patents to create exclusion relationships with
        other_patents = list(range(num_patents))
        other_patents.remove(i)
        
        if len(other_patents) > max_exclusions_per_patent:
            sampled_patents = random.sample(other_patents, max_exclusions_per_patent)
        else:
            sampled_patents = other_patents
        
        patent_i_idx = num_figures + i
        for j in sampled_patents:
            patent_j_idx = num_figures + j
            exclusion.append((patent_i_idx, patent_j_idx))

    # For Medium CPCs, also sample a limited number of exclusion relationships
    max_exclusions_per_medium_cpc = 5  # Adjust based on your needs

    # Create a list of all medium CPCs
    all_medium_cpcs = set()
    for medium_cpcs in patent_to_medium_cpc.values():
        all_medium_cpcs.update(medium_cpcs)

    # For each medium CPC, sample a few others to be mutually exclusive with
    for medium_cpc_i in tqdm(all_medium_cpcs):
        medium_cpc_i_idx = num_figures + num_patents + medium_cpc_i
        
        # Find medium CPCs that are likely to be exclusive with this one
        # (e.g., from different patent branches)
        candidate_exclusions = []
        for medium_cpc_j in all_medium_cpcs:
            if medium_cpc_i != medium_cpc_j:
                # Check if they're not in the same branch (simplified check)
                # You might need a more sophisticated check based on your data
                medium_cpc_i_parents = set()
                medium_cpc_j_parents = set()
                
                for patent_idx, medium_cpcs in patent_to_medium_cpc.items():
                    if medium_cpc_i in medium_cpcs:
                        medium_cpc_i_parents.add(patent_idx)
                    if medium_cpc_j in medium_cpcs:
                        medium_cpc_j_parents.add(patent_idx)
                
                # If they don't share any parents, they're likely exclusive
                if not medium_cpc_i_parents.intersection(medium_cpc_j_parents):
                    candidate_exclusions.append(medium_cpc_j)
        
        # Sample a limited number of exclusions
        if len(candidate_exclusions) > max_exclusions_per_medium_cpc:
            sampled_exclusions = random.sample(candidate_exclusions, max_exclusions_per_medium_cpc)
        else:
            sampled_exclusions = candidate_exclusions
        
        for medium_cpc_j in sampled_exclusions:
            medium_cpc_j_idx = num_figures + num_patents + medium_cpc_j
            exclusion.append((medium_cpc_i_idx, medium_cpc_j_idx))
    print(f"Generated {len(Y_pos)} positive relationships")
    print(f"Generated {len(Y_neg)} negative relationships")
    print(f"Generated {len(implication)} implication relationships")
    print(f"Generated {len(exclusion)} exclusion relationships")
    
    return Y_pos, Y_neg, implication, exclusion

def save_hyperbolic_inputs(Y_pos, Y_neg, implication, exclusion, filepath='hyperbolic_inputs.pkl'):
    """Save the hyperbolic model inputs to a pickle file."""
    data = {
        'Y_pos': Y_pos,
        'Y_neg': Y_neg,
        'implication': implication,
        'exclusion': exclusion
    }
    
    with open(filepath, 'wb') as f:
        pickle.dump(data, f)
    
    print(f"Saved hyperbolic inputs to {filepath}")

def load_hyperbolic_inputs(filepath='hyperbolic_inputs.pkl'):
    """Load the hyperbolic model inputs from a pickle file."""
    if os.path.exists(filepath):
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        
        Y_pos = data['Y_pos']
        Y_neg = data['Y_neg']
        implication = data['implication']
        exclusion = data['exclusion']
        
        print(f"Loaded {len(Y_pos)} positive relationships")
        print(f"Loaded {len(Y_neg)} negative relationships")
        print(f"Loaded {len(implication)} implication relationships")
        print(f"Loaded {len(exclusion)} exclusion relationships")
        
        return Y_pos, Y_neg, implication, exclusion
    else:
        print(f"File {filepath} not found. Need to generate hyperbolic inputs first.")
        return None, None, None, None

In [21]:
# Check if hyperbolic inputs file exists
hyperbolic_inputs_file = 'hyperbolic_inputs.pkl'
Y_pos, Y_neg, implication, exclusion = load_hyperbolic_inputs(hyperbolic_inputs_file)

# If not, generate and save them
if Y_pos is None:
    Y_pos, Y_neg, implication, exclusion = generate_hyperbolic_inputs(
        figure_to_row, 
        figure_to_patent, 
        patent_to_medium_cpc, 
        medium_cpc_to_big_cpc, 
        big_cpc_to_main_cpc,
        num_figures,
        num_patents,
        num_medium_cpc,
        num_big_cpc,
        num_main_cpc
    )
    save_hyperbolic_inputs(Y_pos, Y_neg, implication, exclusion, hyperbolic_inputs_file)


File hyperbolic_inputs.pkl not found. Need to generate hyperbolic inputs first.
Generating hyperbolic model inputs...
Generating positive relationships...


32115it [00:00, 265912.96it/s]


Generating balanced negative relationships...


100%|██████████| 32115/32115 [04:54<00:00, 108.88it/s]


Generating implication relationships...
Generated 352068 positive relationships
Generated 160575 negative relationships
Generated 54056 implication relationships
Generating balanced exclusion relationships...


100%|██████████| 595/595 [16:46<00:00,  1.69s/it]


Generated 352068 positive relationships
Generated 160575 negative relationships
Generated 54056 implication relationships
Generated 163565 exclusion relationships
Saved hyperbolic inputs to hyperbolic_inputs.pkl


In [18]:
len(Y_neg)

TypeError: object of type 'NoneType' has no len()