In [None]:
pip install transformers

In [None]:
pip install sentence_transformers

In [None]:
!pip install llama-index==0.9.38

# hej

In [24]:
import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import Document

# Step 1: Use LlamaIndex chunkers for better text splitting
def split_text_with_sentence_splitter(text, chunk_size=512, chunk_overlap=50):
    """
    Break ``text`` into chunks using LlamaIndex's ``SentenceSplitter``.

    Args:
        text: Raw string to split.
        chunk_size: Forwarded as ``SentenceSplitter(chunk_size=...)``.
        chunk_overlap: Forwarded as ``SentenceSplitter(chunk_overlap=...)``.

    Returns:
        List holding the ``.text`` of every node the splitter produced, in
        the order the splitter returned them.
    """
    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        "paragraph_separator": "\n\n",  # Adjust as needed for your text
        "separator": " ",               # Separator used when combining sentences
    }
    splitter = SentenceSplitter(**splitter_options)

    # The splitter consumes Document objects, so wrap the raw string first.
    wrapped_documents = [Document(text=text)]
    parsed_nodes = splitter.get_nodes_from_documents(wrapped_documents)

    return [parsed.text for parsed in parsed_nodes]

# Step 2: Embed each chunk using SentenceTransformer
class TextEmbedder:
    """Small wrapper that owns a ``SentenceTransformer`` and embeds text chunks."""

    def __init__(self, model_name="paraphrase-MiniLM-L6-v2"):  # Using a more compatible model
        # The model is loaded once per embedder instance and kept on ``self.model``.
        encoder = SentenceTransformer(model_name)
        self.model = encoder

    def embed_text(self, chunks):
        """
        Encode ``chunks`` with the wrapped model.

        Returns exactly what ``self.model.encode(chunks)`` returns.
        """
        encoded = self.model.encode(chunks)
        return encoded

# Step 3: Group embeddings by category and document
def process_texts_nested_dict(categories_dict, chunk_size=20, chunk_overlap=20, embedder=None):
    """
    Split, embed and regroup a nested ``category -> document`` dictionary.

    Args:
        categories_dict: Nested dictionary where:
            - First level keys are categories (e.g., "bananas", "apples")
            - Second level keys are document IDs
            - Values are mappings containing a 'doc_text' field
        chunk_size: Passed to ``split_text_with_sentence_splitter``.
        chunk_overlap: Passed to ``split_text_with_sentence_splitter``.
        embedder: Optional object exposing ``embed_text(chunks)``. When
            omitted, a fresh ``TextEmbedder()`` is created for this call
            (which loads the model again); pass an existing embedder to
            reuse one model across several calls.

    Returns:
        Dictionary ``{category: {doc_id: np.ndarray}}`` holding the chunk
        embeddings of each document. Every category key is preserved, but a
        document whose text yields no chunks is omitted from its category.

    Raises:
        KeyError: If a document entry has no 'doc_text' key.
    """
    # NOTE(review): the default overlap equals the default chunk size; confirm
    # these defaults are intended (the helper itself defaults to 512 / 50).
    if embedder is None:
        embedder = TextEmbedder()

    grouped_embeddings = {}
    for category, documents in categories_dict.items():
        category_embeddings = {}
        for doc_id, doc_data in documents.items():
            chunks = split_text_with_sentence_splitter(
                doc_data["doc_text"], chunk_size, chunk_overlap
            )
            if not chunks:
                # Nothing to embed for this document; leave it out.
                continue
            category_embeddings[doc_id] = np.asarray(embedder.embed_text(chunks))
        grouped_embeddings[category] = category_embeddings

    return grouped_embeddings

# Example usage
if __name__ == "__main__":
    # Sample texts; the same strings are reused across the two categories.
    first_text = "This is the first example text. It is quite long and needs to be split into chunks. It is quite long and needs to be split into chunks. It is quite long and needs to be split into chunks. It is quite long and needs to be split into chunks. It is quite long and needs to be split into chunks."
    second_text = "Here is another text that will also be split and embedded. Here is another text that will also be split and embedded. Here is another text that will also be split and embedded."
    third_text = "A third document with some content that will be processed into embeddings. A third document with some content that will be processed into embeddings."

    categories_dict = {
        "bananas": {
            "doc1": {"doc_text": first_text},
            "doc2": {"doc_text": second_text},
            "doc3": {"doc_text": third_text},
        },
        "apples": {
            "doc1": {"doc_text": first_text},
            "doc2": {"doc_text": second_text},
        },
    }

    # Chunk and embed every document with the default chunking settings.
    grouped_embeddings = process_texts_nested_dict(categories_dict)

    # Report, per category, how many chunks each document produced and the
    # width of its embedding vectors.
    for category, documents in grouped_embeddings.items():
        print(f"\nCategory: {category}")
        for doc_id, embeddings in documents.items():
            chunk_count = len(embeddings)
            embedding_width = embeddings.shape[1]
            print(f"  {doc_id}: {chunk_count} chunks with embedding dimension {embedding_width}")

Metadata length (0) is close to chunk size (20). Resulting chunks are less than 50 tokens. Consider increasing the chunk size or decreasing the size of your metadata to avoid this.
Metadata length (0) is close to chunk size (20). Resulting chunks are less than 50 tokens. Consider increasing the chunk size or decreasing the size of your metadata to avoid this.
Metadata length (0) is close to chunk size (20). Resulting chunks are less than 50 tokens. Consider increasing the chunk size or decreasing the size of your metadata to avoid this.
Metadata length (0) is close to chunk size (20). Resulting chunks are less than 50 tokens. Consider increasing the chunk size or decreasing the size of your metadata to avoid this.
Metadata length (0) is close to chunk size (20). Resulting chunks are less than 50 tokens. Consider increasing the chunk size or decreasing the size of your metadata to avoid this.

Category: bananas
  doc1: 6 chunks with embedding dimension 384
  doc2: 3 chunks with embeddin

In [18]:
grouped_embeddings

{'bananas': {'doc1': array([[-0.547143  ,  0.13099997, -0.29798925, ...,  0.3127594 ,
           0.7349615 ,  0.58360195],
         [-0.28905192,  0.11623957, -0.1416125 , ...,  0.17979482,
           0.4187386 ,  0.14725946],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233]], dtype=float32),
  'doc2': array([[-0.50558054,  0.26027247, -0.09289418, ...,  0.05310394,
           0.35045525,  0.15219437],
         [-0.39704907,  0.25346982,  0.10586175, ..., -0.00925015,
           0.19879848,  0.12182148],
         [-0.39704907,  0.25346982,  0.10586175, ..., -0.00925015,
           0.19879848,  0.12182148]], dtype=float32),
  '

# JSON

In [None]:
grouped_embeddings

In [None]:
pip install torch-geometric

In [None]:
master_embedding.size(0)  # Replace 0 with the dimension you want to check

In [25]:
import torch
import torch.nn as nn
import numpy as np
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# Step 1: Define the GNN models
class ChunkGNN(nn.Module):
    """Turns a graph of chunk embeddings into one embedding per graph.

    Pipeline: two ``GCNConv`` layers (each followed by ReLU), mean pooling
    over the nodes of each graph, then a linear projection to ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        """Run node features ``x`` through both conv layers, pool per graph
        (as assigned by ``batch``) and project the pooled vector."""
        hidden = x
        for conv_layer in (self.conv1, self.conv2):
            hidden = torch.relu(conv_layer(hidden, edge_index))

        # Mean over the nodes of each graph -> one row per document.
        pooled = global_mean_pool(hidden, batch)
        return self.fc(pooled)

class MasterGNN(nn.Module):
    """Graph neural network for aggregating document embeddings into a fixed-size master embedding"""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.output_dim = output_dim

    def forward(self, document_embeddings):
        """
        Create a fixed-size master embedding from document embeddings using a graph approach

        The documents are treated as the nodes of a single graph: every pair
        of documents is connected in both directions (a lone document gets a
        self-loop instead). After two GCN layers the node states are
        mean-pooled and projected by the final linear layer.

        Args:
            document_embeddings: Tensor of shape [num_documents, embedding_dim]

        Returns:
            Tensor of shape [1, output_dim], regardless of the number of
            input documents.
        """
        num_documents = document_embeddings.size(0)
        # Build the helper tensors on the same device as the input so the
        # model also works when the embeddings live on a GPU.
        device = document_embeddings.device

        # Every document belongs to graph 0 (one graph for the whole set).
        batch = torch.zeros(num_documents, dtype=torch.long, device=device)

        if num_documents > 1:
            # All unordered pairs (i, j) with i < j, shaped [2, num_pairs] ...
            pairs = torch.combinations(
                torch.arange(num_documents, device=device), r=2
            ).t()
            # ... plus the reversed pairs so edges are bidirectional.
            edge_index = torch.cat([pairs, pairs.flip(0)], dim=1)
        else:
            # Self-loop for a single document
            edge_index = torch.tensor([[0], [0]], dtype=torch.long, device=device)

        # Process through GNN
        x = torch.relu(self.conv1(document_embeddings, edge_index))
        x = torch.relu(self.conv2(x, edge_index))

        # Global pooling to get fixed-size representation regardless of document count
        x = global_mean_pool(x, batch)
        return self.fc(x)

# Step 2: Prepare data for chunk-level GNN
def prepare_chunk_data(grouped_embeddings):
    """Build one PyTorch Geometric ``Data`` graph per entry of ``grouped_embeddings``.

    Each value is converted to a float tensor of node features. Graphs with
    more than one node are fully connected in both directions; otherwise a
    single self-loop on node 0 is used. The dictionary keys are not used.

    Returns:
        List of ``Data`` objects, in the iteration order of the dictionary.
    """
    graphs = []
    for embeddings in grouped_embeddings.values():
        features = torch.tensor(embeddings, dtype=torch.float)
        node_count = features.size(0)

        if node_count <= 1:
            # Handle single node case with self-loop
            edges = torch.tensor([[0], [0]], dtype=torch.long)
        else:
            # All i < j pairs, then the same pairs reversed for two-way message passing.
            one_way = torch.combinations(torch.arange(node_count), r=2).t()
            edges = torch.cat([one_way, one_way.flip(0)], dim=1)

        graphs.append(Data(x=features, edge_index=edges))

    return graphs

# Step 3: Process grouped embeddings through the GNNs
def process_with_gnns(grouped_embeddings, chunk_gnn, master_gnn):
    """Reduce a ``{key: chunk_embeddings}`` mapping to one master embedding.

    Each entry becomes a chunk graph, ``chunk_gnn`` turns every graph into a
    document embedding, and ``master_gnn`` aggregates the stacked document
    embeddings. Returns whatever ``master_gnn`` returns.
    """
    chunk_graphs = prepare_chunk_data(grouped_embeddings)

    # batch_size=1 keeps one document per step; squeeze(0) drops that
    # singleton batch dimension from the chunk-level output.
    per_document = [
        chunk_gnn(graph.x, graph.edge_index, graph.batch).squeeze(0)
        for graph in DataLoader(chunk_graphs, batch_size=1, shuffle=False)
    ]

    stacked_documents = torch.stack(per_document)
    return master_gnn(stacked_documents)

# # Example usage
# if __name__ == "__main__":
#     # For testing with different numbers of documents
#     test_cases = [
#         {"text_0": np.random.rand(5, 384)},  # 1 document
#         {"text_0": np.random.rand(5, 384), "text_1": np.random.rand(3, 384)},  # 2 documents
#         {"text_0": np.random.rand(5, 384), "text_1": np.random.rand(3, 384), 
#          "text_2": np.random.rand(4, 384), "text_3": np.random.rand(4, 384)}  # 4 documents
#     ]
    
#     # Define fixed embedding dimensions
#     embedding_dim = 384
#     output_dim = 64
    
#     # Initialize models
#     chunk_gnn = ChunkGNN(input_dim=embedding_dim, hidden_dim=128, output_dim=output_dim)
#     master_gnn = MasterGNN(input_dim=output_dim, hidden_dim=32, output_dim=16)
    
#     # Test with different numbers of documents
#     for i, grouped_embeddings in enumerate(test_cases):
#         print(f"\nTest case {i+1}: {len(grouped_embeddings)} documents")
        
#         # Process embeddings
#         master_embedding = process_with_gnns(grouped_embeddings, chunk_gnn, master_gnn)
        
#         # Verify shape is consistent
#         print(f"Master embedding shape: {master_embedding.shape}")

In [26]:
def convert_nested_dict_to_simple_format(nested_embeddings):
    """
    Re-key the documents of every category as "text_0", "text_1", ...

    The original document IDs are discarded; documents are numbered in the
    iteration order of each category's dictionary. The embedding objects
    themselves are passed through unchanged (not copied). Nothing is
    printed.

    Args:
        nested_embeddings: Nested dictionary with categories as first level
                          and documents as second level, containing numpy arrays

    Returns:
        Dictionary where each key is a category and each value is a simple dict
        with keys like "text_0", "text_1", etc. mapping to the same values
        the input held for that category.
    """
    return {
        category: {
            f"text_{position}": embeddings
            for position, embeddings in enumerate(documents.values())
        }
        for category, documents in nested_embeddings.items()
    }

# Example usage
if __name__ == "__main__":
    # Convert the nested structure to the simplified format
    converted_embeddings = convert_nested_dict_to_simple_format(grouped_embeddings)

    # Instantiate the two GNNs here so this cell does not depend on objects
    # left over from an earlier kernel session. The chunk embeddings produced
    # above are 384-dimensional; the master embedding has 16 dimensions.
    embedding_dim = 384
    doc_embedding_dim = 64
    chunk_gnn = ChunkGNN(input_dim=embedding_dim, hidden_dim=128, output_dim=doc_embedding_dim)
    master_gnn = MasterGNN(input_dim=doc_embedding_dim, hidden_dim=32, output_dim=16)

    # Process each category with the existing process_with_gnns function
    category_master_embeddings = {}

    for category, simple_dict in converted_embeddings.items():
        print(f"Processing category: {category} with {len(simple_dict)} documents")
        master_embedding = process_with_gnns(simple_dict, chunk_gnn, master_gnn)
        category_master_embeddings[category] = master_embedding

    # Print results for each category
    for category, embedding in category_master_embeddings.items():
        print(f"\nCategory: {category}")
        print(f"  Master embedding shape: {embedding.shape}")
        print(f"  First few values: {embedding.flatten()[:5].tolist()}")

hej
hej
Processing category: bananas with 3 documents
Processing category: apples with 2 documents

Category: bananas
  Master embedding shape: torch.Size([1, 16])
  First few values: [-0.16262401640415192, -0.03874002769589424, -0.07605313509702682, -0.04566062614321709, 0.05373515188694]

Category: apples
  Master embedding shape: torch.Size([1, 16])
  First few values: [-0.1765163689851761, -0.02803163416683674, -0.07624045014381409, -0.034578364342451096, 0.04556804150342941]


In [27]:
grouped_embeddings

{'bananas': {'doc1': array([[-0.547143  ,  0.13099997, -0.29798925, ...,  0.3127594 ,
           0.7349615 ,  0.58360195],
         [-0.28905192,  0.11623957, -0.1416125 , ...,  0.17979482,
           0.4187386 ,  0.14725946],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233],
         [-0.07625767, -0.05811209, -0.08879285, ..., -0.20441213,
           0.17478123, -0.22232233]], dtype=float32),
  'doc2': array([[-0.50558054,  0.26027247, -0.09289418, ...,  0.05310394,
           0.35045525,  0.15219437],
         [-0.39704907,  0.25346982,  0.10586175, ..., -0.00925015,
           0.19879848,  0.12182148],
         [-0.39704907,  0.25346982,  0.10586175, ..., -0.00925015,
           0.19879848,  0.12182148]], dtype=float32),
  '

# Model 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report
from tqdm.notebook import tqdm
import os

# Load all datasets from the 'WOS_embeddings' folder under the current working directory
embeddings_dir = os.path.join(os.getcwd(), 'WOS_embeddings')


def _load_npy(file_name):
    """Load one ``.npy`` file located in ``embeddings_dir``."""
    return np.load(os.path.join(embeddings_dir, file_name))


# Embedding matrices for the three splits
train_embeddings = _load_npy('train_embeddings.npy')
val_embeddings = _load_npy('val_embeddings.npy')
test_embeddings = _load_npy('test_embeddings.npy')

# Level-1 / level-2 labels for the three splits
train_labels_l1 = _load_npy('train_labels_l1.npy')
train_labels_l2 = _load_npy('train_labels_l2.npy')

val_labels_l1 = _load_npy('val_labels_l1.npy')
val_labels_l2 = _load_npy('val_labels_l2.npy')

test_labels_l1 = _load_npy('test_labels_l1.npy')
test_labels_l2 = _load_npy('test_labels_l2.npy')

# Check shapes
for split_name, split_array in (
    ("Training", train_embeddings),
    ("Validation", val_embeddings),
    ("Test", test_embeddings),
):
    print(f"{split_name} embeddings shape: {split_array.shape}")

# Class counts are the number of distinct labels seen in the training split.
# NOTE(review): using these as output sizes assumes labels are 0..K-1 — confirm.
num_classes_l1 = len(np.unique(train_labels_l1))
num_classes_l2 = len(np.unique(train_labels_l2))

print(f"Number of L1 categories: {num_classes_l1}")
print(f"Number of L2 categories: {num_classes_l2}")

# 1. Create the hierarchical category graph with just 2 levels
def create_category_hierarchy_graph(labels_l1=None, labels_l2=None):
    """
    Create a graph representing the hierarchical structure of categories:
    Root -> L1 categories -> L2 categories

    Node order is: the root (index 0), then the L1 categories in sorted
    order, then the L2 categories grouped by sorted parent and sorted within
    each parent. Node features are 2-dimensional: ``[1, 0]`` for the root,
    ``[0, 1]`` for L1 nodes and ``[0, 0]`` for L2 nodes. Every parent/child
    link is stored as two directed edges.

    An L2 label that occurs under several L1 parents is given a single node
    that is linked to each of those parents, so ``node_to_idx`` and
    ``idx_to_node`` always describe the same set of nodes.

    Args:
        labels_l1: Optional iterable of level-1 labels, one per sample.
            Defaults to the module-level ``train_labels_l1``.
        labels_l2: Optional iterable of level-2 labels aligned with
            ``labels_l1``. Defaults to the module-level ``train_labels_l2``.

    Returns:
        Tuple ``(edge_index, node_features, node_to_idx, idx_to_node)``:
        a long tensor of shape [2, num_edges], a float tensor of shape
        [num_nodes, 2], and the two mappings between ``(level, label)``
        tuples and node indices.
    """
    if labels_l1 is None:
        labels_l1 = train_labels_l1
    if labels_l2 is None:
        labels_l2 = train_labels_l2

    # Collect, for every L1 label, the set of L2 labels seen together with it.
    children_by_parent = {}
    for l1, l2 in zip(labels_l1, labels_l2):
        children_by_parent.setdefault(l1, set()).add(l2)
    parents = sorted(children_by_parent)

    node_to_idx = {}
    idx_to_node = {}
    node_features = []

    def add_node(key, features):
        """Register ``key`` as the next node and store its feature row."""
        idx = len(node_features)
        node_to_idx[key] = idx
        idx_to_node[idx] = key
        node_features.append(features)

    # Root node (index 0)
    add_node(('root', 0), [1.0, 0.0])

    # L1 categories
    for l1 in parents:
        add_node(('l1', l1), [0.0, 1.0])

    # L2 categories (final categories); skip labels that already have a node
    # so a label shared between parents does not leave an orphan duplicate.
    for l1 in parents:
        for l2 in sorted(children_by_parent[l1]):
            if ('l2', l2) not in node_to_idx:
                add_node(('l2', l2), [0.0, 0.0])

    # Create edges (root -> L1 -> L2), each stored in both directions.
    root_idx = node_to_idx[('root', 0)]
    edges = []
    for l1 in parents:
        l1_idx = node_to_idx[('l1', l1)]
        edges.append((root_idx, l1_idx))
        edges.append((l1_idx, root_idx))

    for l1 in parents:
        l1_idx = node_to_idx[('l1', l1)]
        for l2 in sorted(children_by_parent[l1]):
            l2_idx = node_to_idx[('l2', l2)]
            edges.append((l1_idx, l2_idx))
            edges.append((l2_idx, l1_idx))

    # Convert to PyTorch Geometric format. With no labels there are no edges,
    # and an explicit [2, 0] tensor keeps the expected edge_index layout.
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    node_features = torch.tensor(node_features, dtype=torch.float)

    return edge_index, node_features, node_to_idx, idx_to_node

# 2. Build the category graph once and keep its tensors and lookup tables at module level
(
    cat_edge_index,
    cat_node_features,
    node_to_idx,
    idx_to_node,
) = create_category_hierarchy_graph()
num_cat_nodes = cat_node_features.size(0)

print(f"Category hierarchy graph created with {num_cat_nodes} nodes and {cat_edge_index.size(1)} edges")

# 3. Define the GNN model for category hierarchy embedding
class CategoryGNN(nn.Module):
    """Two-layer graph attention network producing one embedding per category node."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        attention_heads = 4
        self.conv1 = GATConv(in_channels, hidden_channels, heads=attention_heads)
        # The second layer's input width is hidden_channels times the head count.
        self.conv2 = GATConv(hidden_channels * attention_heads, out_channels, heads=1)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> dropout (p=0.2, training only) -> conv2."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        hidden = F.dropout(hidden, p=0.2, training=self.training)
        return self.conv2(hidden, edge_index)

# 4. Define the classifier that combines category and text embeddings
class HierarchicalClassifier(nn.Module):
    """Two-headed classifier over text embeddings plus a category-graph GNN.

    The category GNN is evaluated on every forward pass and its node
    embeddings are returned, but they are not yet fed into the classifier:
    each text embedding is padded with a zero vector of the category
    embedding width as a placeholder.

    Reads the module-level ``cat_node_features`` (for the GNN input width)
    and ``node_to_idx`` at construction time.
    """

    def __init__(self, text_dim, cat_embedding_dim, hidden_dim, num_classes_l1, num_classes_l2):
        super().__init__()
        self.category_gnn = CategoryGNN(cat_node_features.shape[1], 32, cat_embedding_dim)

        # Fully connected layers for classification
        combined_dim = text_dim + cat_embedding_dim

        # Shared layers
        self.fc1 = nn.Linear(combined_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)

        # Output heads for hierarchical classification
        self.out_l1 = nn.Linear(hidden_dim // 2, num_classes_l1)
        self.out_l2 = nn.Linear(hidden_dim // 2, num_classes_l2)

        # Store category node mappings
        self.node_to_idx = node_to_idx

    def forward(self, text_embeddings, edge_index, node_features):
        """
        Classify a batch of text embeddings at both hierarchy levels.

        Args:
            text_embeddings: Tensor of shape [batch_size, text_dim].
            edge_index: Edge index of the category graph.
            node_features: Node feature matrix of the category graph.

        Returns:
            Tuple ``(out_l1, out_l2, cat_embeddings)``: the level-1 logits,
            the level-2 logits, and the category GNN's node embeddings.
        """
        # Get category embeddings from GNN
        cat_embeddings = self.category_gnn(node_features, edge_index)

        # Placeholder until real per-sample category information is wired in:
        # pad every text embedding with zeros of the category-embedding
        # width. Built for the whole batch at once instead of row by row.
        placeholder = torch.zeros(
            text_embeddings.shape[0],
            cat_embeddings.shape[1],
            device=text_embeddings.device,
        )
        combined_features = torch.cat([text_embeddings, placeholder], dim=1)

        # Shared layers
        x = F.relu(self.fc1(combined_features))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.3, training=self.training)

        # Output heads
        out_l1 = self.out_l1(x)
        out_l2 = self.out_l2(x)

        return out_l1, out_l2, cat_embeddings

# 5. Pick the compute device and move all data onto it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


def _tensor_on_device(array, dtype):
    """Copy ``array`` into a tensor of ``dtype`` and move it to ``device``."""
    return torch.tensor(array, dtype=dtype).to(device)


# Category graph tensors
cat_edge_index = cat_edge_index.to(device)
cat_node_features = cat_node_features.to(device)

# Features are float tensors, labels are long tensors
train_x = _tensor_on_device(train_embeddings, torch.float)
train_y_l1 = _tensor_on_device(train_labels_l1, torch.long)
train_y_l2 = _tensor_on_device(train_labels_l2, torch.long)

val_x = _tensor_on_device(val_embeddings, torch.float)
val_y_l1 = _tensor_on_device(val_labels_l1, torch.long)
val_y_l2 = _tensor_on_device(val_labels_l2, torch.long)

test_x = _tensor_on_device(test_embeddings, torch.float)
test_y_l1 = _tensor_on_device(test_labels_l1, torch.long)
test_y_l2 = _tensor_on_device(test_labels_l2, torch.long)

# Model hyperparameters; the text width is taken from the training matrix
text_dim = train_embeddings.shape[1]
cat_embedding_dim = 64
hidden_dim = 256

model_config = {
    "text_dim": text_dim,
    "cat_embedding_dim": cat_embedding_dim,
    "hidden_dim": hidden_dim,
    "num_classes_l1": num_classes_l1,
    "num_classes_l2": num_classes_l2,
}
model = HierarchicalClassifier(**model_config).to(device)

# 6. Adam with weight decay, plus a scheduler that halves the learning rate
#    when the metric passed to scheduler.step() stops decreasing
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
# NOTE(review): `verbose` is reportedly deprecated in newer PyTorch releases — confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    verbose=True,
)

# 7. Training function
def train_epoch(model, optimizer, batch_x, batch_y_l1, batch_y_l2):
    """Run one optimisation step on the given batch.

    Uses the module-level ``cat_edge_index`` / ``cat_node_features`` as the
    category graph. The objective is ``0.4 * CE(level 1) + 0.6 * CE(level 2)``.

    Returns:
        Tuple of Python floats ``(total_loss, loss_l1, loss_l2)``.
    """
    model.train()
    optimizer.zero_grad()

    # The third model output (category embeddings) is not needed for the loss.
    logits_l1, logits_l2, _ = model(batch_x, cat_edge_index, cat_node_features)

    level_losses = (
        F.cross_entropy(logits_l1, batch_y_l1),
        F.cross_entropy(logits_l2, batch_y_l2),
    )
    weighted_loss = 0.4 * level_losses[0] + 0.6 * level_losses[1]

    weighted_loss.backward()
    optimizer.step()

    return weighted_loss.item(), level_losses[0].item(), level_losses[1].item()

# 8. Evaluation function
def evaluate(model, x, y_l1, y_l2):
    """Score ``model`` on ``x`` against both label levels.

    Puts the model in eval mode, runs a gradient-free forward pass with the
    module-level category graph, and takes the argmax of each head.

    Returns:
        Tuple ``(acc_l1, acc_l2, pred_l1, pred_l2)`` where the predictions
        are NumPy arrays.
    """
    model.eval()
    with torch.no_grad():
        logits_l1, logits_l2, _ = model(x, cat_edge_index, cat_node_features)

    # Move everything to CPU NumPy arrays for scikit-learn.
    pred_l1 = logits_l1.argmax(dim=1).cpu().numpy()
    pred_l2 = logits_l2.argmax(dim=1).cpu().numpy()
    acc_l1 = accuracy_score(y_l1.cpu().numpy(), pred_l1)
    acc_l2 = accuracy_score(y_l2.cpu().numpy(), pred_l2)

    return acc_l1, acc_l2, pred_l1, pred_l2

# 9. Training loop: one full-batch step per epoch on the whole training set
epochs = 400
losses = []

print("\nTraining the hierarchical GNN model...")
for epoch in range(epochs):
    total_loss, loss_l1, loss_l2 = train_epoch(
        model, optimizer, train_x, train_y_l1, train_y_l2
    )
    losses.append(total_loss)

    # Report on the first epoch and then every 15th epoch.
    is_report_epoch = epoch == 0 or (epoch + 1) % 15 == 0
    if not is_report_epoch:
        continue

    val_acc_l1, val_acc_l2, _, _ = evaluate(model, val_x, val_y_l1, val_y_l2)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, "
          f"Val Acc L1: {val_acc_l1:.4f}, L2: {val_acc_l2:.4f}")

    # The scheduler is stepped only on report epochs and is driven by the
    # training loss of that epoch (not by a validation metric).
    scheduler.step(total_loss)

# 10. Plot the per-epoch training loss and save the figure next to the data
loss_plot_path = os.path.join(embeddings_dir, 'hierarchical_gnn_loss.png')
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.savefig(loss_plot_path)
plt.show()

# 11. Final evaluation on the test set
test_acc_l1, test_acc_l2, pred_l1, pred_l2 = evaluate(model, test_x, test_y_l1, test_y_l2)

print("\nTest Results:")
print(f"Level 1 Accuracy: {test_acc_l1:.4f}")
print(f"Level 2 Accuracy: {test_acc_l2:.4f}")

# 12. Persist the trained weights
model_path = os.path.join(embeddings_dir, 'hierarchical_gnn_model.pt')
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

# 13. Per-class precision / recall / F1 for both levels
true_l1 = test_y_l1.cpu().numpy()
true_l2 = test_y_l2.cpu().numpy()

print("\nClassification Report (Level 1):")
print(classification_report(true_l1, pred_l1))

print("\nClassification Report (Level 2):")
print(classification_report(true_l2, pred_l2))

# 14. Confusion-matrix heatmap for Level 2 (optional)
from sklearn.metrics import confusion_matrix
import seaborn as sns

confusion_plot_path = os.path.join(embeddings_dir, 'l2_confusion_matrix.png')
plt.figure(figsize=(12, 10))
cm = confusion_matrix(true_l2, pred_l2)
sns.heatmap(cm, annot=False, fmt='d', cmap='Blues')
plt.title('Confusion Matrix for Level 2 Categories')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.savefig(confusion_plot_path, dpi=300, bbox_inches='tight')
plt.show()

# Combine the two models

In [None]:
# Combined architecture that integrates MasterGNN with the hierarchical classifier
class CombinedHierarchicalModel(nn.Module):
    """Two-level classifier combining text features with a document "master" embedding.

    The model owns three sub-networks defined elsewhere in this file:
    ``ChunkGNN`` and ``MasterGNN`` (used together by ``process_documents``)
    and ``CategoryGNN`` (run over the category graph in ``forward``). Their
    outputs feed two shared fully connected layers followed by one linear
    head per hierarchy level.

    Note:
        Construction reads the module-level names ``cat_node_features``
        (its second dimension is the CategoryGNN input width) and
        ``node_to_idx``; both must exist before the class is instantiated.

    Args:
        text_dim: Width of each row of ``text_embeddings`` passed to ``forward``.
        cat_embedding_dim: Output width of the category GNN.
        hidden_dim: Width of the first shared layer (the second is half of it).
        num_classes_l1: Number of outputs of the level-1 head.
        num_classes_l2: Number of outputs of the level-2 head.
        chunk_input_dim: Input width handed to ``ChunkGNN``.
        doc_embedding_dim: Output width of ``ChunkGNN`` / input of ``MasterGNN``.
        master_embedding_dim: Output width of ``MasterGNN``.
    """

    def __init__(self, text_dim, cat_embedding_dim, hidden_dim, num_classes_l1, num_classes_l2,
                 chunk_input_dim=384, doc_embedding_dim=64, master_embedding_dim=16):
        super().__init__()

        # Document hierarchy processing components
        self.chunk_gnn = ChunkGNN(input_dim=chunk_input_dim,
                                  hidden_dim=128,
                                  output_dim=doc_embedding_dim)

        self.master_gnn = MasterGNN(input_dim=doc_embedding_dim,
                                    hidden_dim=32,
                                    output_dim=master_embedding_dim)

        # Category hierarchy processing (input width comes from the
        # module-level category node feature matrix)
        self.category_gnn = CategoryGNN(cat_node_features.shape[1], 32, cat_embedding_dim)

        # The classifier input is the concatenation of text features, a
        # category-embedding slot and the document master embedding
        combined_dim = text_dim + cat_embedding_dim + master_embedding_dim

        # Shared layers
        self.fc1 = nn.Linear(combined_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)

        # Output heads for hierarchical classification
        self.out_l1 = nn.Linear(hidden_dim // 2, num_classes_l1)
        self.out_l2 = nn.Linear(hidden_dim // 2, num_classes_l2)

        # Keep a reference to the module-level category node mapping
        self.node_to_idx = node_to_idx

    def process_documents(self, grouped_embeddings):
        """Run ``grouped_embeddings`` through ``process_with_gnns`` with this model's chunk and master GNNs."""
        return process_with_gnns(grouped_embeddings, self.chunk_gnn, self.master_gnn)

    def forward(self, text_embeddings, edge_index, node_features, raw_document_chunks=None, master_embedding=None):
        """Compute level-1 and level-2 outputs for a batch of text embeddings.

        Args:
            text_embeddings: 2-D tensor, one row of text features per sample.
            edge_index: Category graph edge indices (passed to the category GNN).
            node_features: Category graph node features (passed to the category GNN).
            raw_document_chunks: Optional input for ``process_documents``; when
                given it takes precedence over ``master_embedding``.
            master_embedding: Optional pre-computed master embedding. Either a
                single vector / single row (shared by every sample) or one row
                per sample.

        Returns:
            Tuple ``(out_l1, out_l2, cat_embeddings)``: the outputs of the two
            linear heads and the category GNN output.

        Raises:
            ValueError: If neither ``raw_document_chunks`` nor
                ``master_embedding`` is supplied, or if the document embedding
                has more than one row but fewer rows than the batch.
        """
        # Get category embeddings from GNN
        cat_embeddings = self.category_gnn(node_features, edge_index)

        # Process documents if provided, otherwise use the pre-computed master embedding
        if raw_document_chunks is not None:
            doc_master_embedding = self.process_documents(raw_document_chunks)
        elif master_embedding is not None:
            doc_master_embedding = master_embedding
        else:
            raise ValueError(
                "forward() needs either raw_document_chunks or master_embedding"
            )

        # A single embedding vector becomes a batch of one, whichever path
        # produced it, so the broadcasting logic below sees 2-D input
        if doc_master_embedding.dim() == 1:
            doc_master_embedding = doc_master_embedding.unsqueeze(0)

        batch_size = text_embeddings.shape[0]
        num_doc_rows = doc_master_embedding.shape[0]

        if num_doc_rows == 1:
            # One document embedding shared by every sample in the batch
            doc_features = doc_master_embedding.expand(batch_size, -1)
        elif num_doc_rows >= batch_size:
            # Batch-aligned embeddings: row i belongs to sample i
            doc_features = doc_master_embedding[:batch_size]
        else:
            raise ValueError(
                f"Got {num_doc_rows} document embeddings for a batch of "
                f"{batch_size} samples; expected 1 or at least {batch_size}"
            )

        # In a real implementation, we would use actual category info.
        # For now a zero block stands in for the per-sample category embedding,
        # so cat_embeddings does not influence the classification outputs.
        cat_placeholder = torch.zeros(
            batch_size, cat_embeddings.shape[1], device=text_embeddings.device
        )

        # Combine text, category placeholder and document features per sample
        combined_features = torch.cat(
            [text_embeddings, cat_placeholder, doc_features], dim=1
        )

        # Shared layers
        x = F.relu(self.fc1(combined_features))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=0.3, training=self.training)

        # Output heads
        out_l1 = self.out_l1(x)
        out_l2 = self.out_l2(x)

        return out_l1, out_l2, cat_embeddings

# Train the two models

In [None]:
# Example of how to use the combined model
if __name__ == "__main__":
    # Hyper-parameters: the text width is read from the training embeddings,
    # the other two sizes are fixed here
    text_dim = train_embeddings.shape[1]
    cat_embedding_dim = 64
    hidden_dim = 256

    # Build the combined model, then move it onto the target device
    combined_model = CombinedHierarchicalModel(
        text_dim=text_dim,
        cat_embedding_dim=cat_embedding_dim,
        hidden_dim=hidden_dim,
        num_classes_l1=num_classes_l1,
        num_classes_l2=num_classes_l2,
        chunk_input_dim=384,
        doc_embedding_dim=64,
        master_embedding_dim=16,
    )
    combined_model = combined_model.to(device)

    # Adam with weight decay; the plateau scheduler halves the learning rate
    # when the value passed to scheduler.step() stops decreasing
    optimizer = torch.optim.Adam(
        combined_model.parameters(),
        lr=0.001,
        weight_decay=5e-4,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=5,
        verbose=True,
    )

    # Pre-compute one master embedding per training sample so the document
    # GNN pipeline is not re-run on every epoch
    print("Pre-computing document master embeddings...")

    # NOTE: raw_document_chunks[i] is a placeholder; a real run needs an
    # explicit mapping from each training sample to its document chunks
    train_master_embeddings = []
    with torch.no_grad():  # embeddings are computed without gradient tracking
        for i in range(len(train_x)):
            sample_documents = raw_document_chunks[i]
            master_embedding = combined_model.process_documents(sample_documents)
            train_master_embeddings.append(master_embedding)

    train_master_embeddings = torch.stack(train_master_embeddings).to(device)
    
    # Modified training function to use pre-computed master embeddings
    def train_epoch(model, optimizer, batch_x, batch_y_l1, batch_y_l2, master_embeddings):
        """Perform one optimisation step on a batch using pre-computed master embeddings.

        The category graph (``cat_edge_index`` / ``cat_node_features``) is
        taken from the enclosing scope rather than passed in.

        Args:
            model: Model called as ``model(x, edge_index, node_features, master_embedding=...)``
                and returning a 3-tuple whose first two items are the level-1
                and level-2 outputs.
            optimizer: Optimizer updating ``model``'s parameters.
            batch_x: Input features for the batch.
            batch_y_l1: Level-1 target class indices.
            batch_y_l2: Level-2 target class indices.
            master_embeddings: Pre-computed document master embeddings.

        Returns:
            Tuple of Python floats ``(total_loss, loss_l1, loss_l2)`` where
            ``total_loss = 0.4 * loss_l1 + 0.6 * loss_l2``.
        """
        model.train()
        optimizer.zero_grad()

        # Forward pass; the third return value (category embeddings) is unused
        logits_l1, logits_l2, _ = model(
            batch_x,
            cat_edge_index,
            cat_node_features,
            master_embedding=master_embeddings,
        )

        # Cross-entropy per hierarchy level, weighted towards level 2
        level1_loss = F.cross_entropy(logits_l1, batch_y_l1)
        level2_loss = F.cross_entropy(logits_l2, batch_y_l2)
        weighted_loss = 0.4 * level1_loss + 0.6 * level2_loss

        # Backpropagate the weighted sum and apply the update
        weighted_loss.backward()
        optimizer.step()

        return tuple(
            loss.item() for loss in (weighted_loss, level1_loss, level2_loss)
        )
    
    # Training loop for the combined model. Each epoch makes a single
    # train_epoch call over the whole training set, i.e. one optimizer step
    # per epoch, using the pre-computed train_master_embeddings.
    print("\nTraining the combined hierarchical GNN model...")
    epochs = 200
    losses = []  # weighted total training loss, one entry per epoch
    
    for epoch in range(epochs):
        # One optimisation step on the training data with master embeddings
        total_loss, loss_l1, loss_l2 = train_epoch(
            combined_model, optimizer, train_x, train_y_l1, train_y_l2, train_master_embeddings
        )
        
        losses.append(total_loss)
        
        # Validate and report on the first epoch and on every 15th epoch
        if (epoch + 1) % 15 == 0 or epoch == 0:
            # Evaluate on validation data.
            # NOTE(review): this is a placeholder - `evaluate_combined` and
            # `val_master_embeddings` are not defined in the code shown here;
            # val_master_embeddings would be computed the same way as
            # train_master_embeddings. Confirm both exist before running.
            val_acc_l1, val_acc_l2, _, _ = evaluate_combined(
                combined_model, val_x, val_y_l1, val_y_l2, val_master_embeddings
            )
            
            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, "
                  f"Val Acc L1: {val_acc_l1:.4f}, L2: {val_acc_l2:.4f}")
            
            # Step the LR scheduler on the *training* loss (not a validation
            # metric). NOTE(review): this runs only on evaluation epochs, so
            # the scheduler's patience counts these calls rather than
            # individual epochs - confirm this is intended.
            scheduler.step(total_loss)