# Legal Knowledge Graph Dataset Usage

This notebook demonstrates how to use the Legal Knowledge Graph Dataset for training Graph Neural Networks (GNNs) on Ukrainian court document data.

## Overview

The dataset processes legal documents containing:
- **Entities**: Named entities extracted from documents (persons, organizations, etc.)
- **Relations**: Relationships between entities (e.g., "filed by", "represented by")
- **Legal References**: Optional legal code references for enhanced context
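The dataset's internal graph construction is not shown in this notebook. The following is a hypothetical sketch of how (head, relation, tail) triplets could become a graph. Each distinct entity string gets a local node id, and each relation becomes one directed edge in a COO `edge_index`, with row 0 holding sources and row 1 holding targets (the PyTorch Geometric convention). The triplet tuple format and the name `triplets_to_graph` are assumptions.

```python
from typing import Dict, List, Tuple

# Hypothetical triplet format: (head_entity, relation, tail_entity)
Triplet = Tuple[str, str, str]


def triplets_to_graph(triplets: List[Triplet]):
    """Map entity strings to local node ids and build a COO edge list."""
    node_ids: Dict[str, int] = {}
    sources: List[int] = []
    targets: List[int] = []
    relations: List[str] = []
    for head, relation, tail in triplets:
        # setdefault assigns the next free id the first time an entity is seen
        h = node_ids.setdefault(head, len(node_ids))
        t = node_ids.setdefault(tail, len(node_ids))
        sources.append(h)
        targets.append(t)
        relations.append(relation)
    # Row 0 = source node ids, row 1 = target node ids
    edge_index = [sources, targets]
    return node_ids, edge_index, relations


# Illustrative (made-up) triplets
triplets = [
    ("Claim No. 1", "filed by", "Plaintiff LLC"),
    ("Plaintiff LLC", "represented by", "Attorney A"),
]
nodes, edge_index, relations = triplets_to_graph(triplets)
print(nodes)
print(edge_index)
```

The entity "Plaintiff LLC" appears in both triplets but becomes a single node, which is what lets relations connect into a graph rather than a set of isolated pairs.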

## Data Sources

The dataset can load data from:
1. **BigQuery**: Direct connection to Google Cloud BigQuery
2. **CSV**: Local CSV files with the same schema

## Features

- Automatic vocabulary building for entities, relations, and legal references
- Feature encoding for GNN training
- Proper batching with PyTorch Geometric compatibility
- Configurable graph size limits
- Legal reference integration
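A minimal sketch of vocabulary building combined with a `max_nodes` limit. It assumes reserved `<pad>`/`<unk>` ids, a `min_freq` cutoff and simple truncation, all of which are hypothetical; `LegalKnowledgeGraphDataset` may handle these differently. `build_vocab` and `encode_nodes` are illustrative names, not part of the dataset's API.

```python
from collections import Counter
from typing import Dict, Iterable, List

PAD, UNK = "<pad>", "<unk>"


def build_vocab(documents: Iterable[List[str]], min_freq: int = 1) -> Dict[str, int]:
    """Assign ids to entities seen at least min_freq times; ids 0 and 1 are reserved."""
    counts = Counter(token for doc in documents for token in doc)
    vocab = {PAD: 0, UNK: 1}
    # Sort by descending frequency, then alphabetically, so ids are deterministic
    for token, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq:
            vocab[token] = len(vocab)
    return vocab


def encode_nodes(entities: List[str], vocab: Dict[str, int], max_nodes: int):
    """Truncate to max_nodes, map to ids (UNK if unseen), pad with PAD, return a mask."""
    ids = [vocab.get(e, vocab[UNK]) for e in entities[:max_nodes]]
    # 1 marks a real node, 0 marks padding
    mask = [1] * len(ids) + [0] * (max_nodes - len(ids))
    ids += [vocab[PAD]] * (max_nodes - len(ids))
    return ids, mask


docs = [["court", "plaintiff", "defendant"], ["court", "plaintiff"], ["court", "notary"]]
vocab = build_vocab(docs, min_freq=2)
ids, mask = encode_nodes(["plaintiff", "notary", "court"], vocab, max_nodes=5)
print(vocab, ids, mask)
```

Keeping the mask alongside the padded ids is what later allows pooling to ignore padding positions.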

## 1. Setup and Imports

In [None]:
import sys
import os
sys.path.append('datasets')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import json
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 2. Import Dataset Classes

In [None]:
from legal_gnn_dataset import LegalKnowledgeGraphDataset, create_dataloader

print("Dataset classes imported successfully!")

## 3. Data Loading and Exploration

Start by loading data from a local CSV file if one is present; otherwise, the notebook falls back to the BigQuery setup described in the next section.

In [None]:
# Check if we have CSV data available
csv_file = 'document_data.csv'
if os.path.exists(csv_file):
    print(f"Found CSV file: {csv_file}")
    
    # Load a small sample to explore the data structure
    df_sample = pd.read_csv(csv_file, nrows=5)
    print("\nCSV Data Structure:")
    print(df_sample.columns.tolist())
    print("\nSample data:")
    print(df_sample.head())
    
    # Initialize dataset from CSV
    dataset_csv = LegalKnowledgeGraphDataset(
        data_source='csv',
        csv_file=csv_file,
        max_nodes=50,
        max_edges=100,
        include_legal_references=True,
        node_features_dim=128,
        edge_features_dim=64
    )
    
    print(f"\nCSV Dataset loaded with {len(dataset_csv)} samples")
    
else:
    print(f"CSV file {csv_file} not found. Will demonstrate with BigQuery setup.")
    dataset_csv = None

## 4. BigQuery Data Loading (Optional)

If you have Google Cloud credentials set up, you can load data directly from BigQuery.
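Before uncommenting the BigQuery cell, it can help to check the two environment variables that cell sets. This is a hedged sketch; `bigquery_credentials_status` is a hypothetical helper. Google's client libraries can also find Application Default Credentials without `GOOGLE_APPLICATION_CREDENTIALS` (for example after `gcloud auth application-default login`), so this check is stricter than strictly necessary.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def bigquery_credentials_status(env: Optional[Mapping[str, str]] = None) -> str:
    """Return "ok" or a short reason why the BigQuery setup is likely to fail."""
    env = os.environ if env is None else env
    key_path = env.get("GOOGLE_APPLICATION_CREDENTIALS")
    project = env.get("GOOGLE_CLOUD_PROJECT")
    if not key_path:
        return "missing GOOGLE_APPLICATION_CREDENTIALS"
    # The variable must point at an existing service-account key file
    if not Path(key_path).is_file():
        return f"credentials file not found: {key_path}"
    if not project:
        return "missing GOOGLE_CLOUD_PROJECT"
    return "ok"


print(bigquery_credentials_status())
```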

In [None]:
# BigQuery setup (uncomment and configure if you have GCP credentials)
# import os
# os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'lab-test-project-1-305710-30eed237388b.json'
# os.environ['GOOGLE_CLOUD_PROJECT'] = 'lab-test-project-1-305710'
# 
# # Initialize dataset from BigQuery
# dataset_bq = LegalKnowledgeGraphDataset(
#     data_source='bigquery',
#     table_id='lab-test-project-1-305710.court_data_2022.processing_doc_links',
#     max_nodes=50,
#     max_edges=100,
#     include_legal_references=True,
#     node_features_dim=128,
#     edge_features_dim=64
# )
# 
# print(f"BigQuery Dataset loaded with {len(dataset_bq)} samples")

# Use the CSV dataset if it was loaded (None otherwise).
# After enabling the BigQuery code above, use: dataset = dataset_bq
dataset = dataset_csv

if dataset is None:
    print("No dataset available. Please ensure you have either CSV data or BigQuery credentials configured.")
    print("\nTo use BigQuery:")
    print("1. Set GOOGLE_APPLICATION_CREDENTIALS environment variable")
    print("2. Set GOOGLE_CLOUD_PROJECT environment variable")
    print("3. Uncomment the BigQuery code above and assign dataset = dataset_bq")
else:
    print(f"Using dataset with {len(dataset)} samples")

## 5. Dataset Exploration

Let's explore the dataset structure and statistics.
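The next cell repeats the same mean/std/min/max block three times. A small helper can condense it. This is an optional sketch (`describe` and `print_stats` are hypothetical names) using the standard library. `statistics.pstdev` is the population standard deviation, which matches `np.std` with its default `ddof=0`.

```python
import statistics
from typing import Dict, Sequence


def describe(values: Sequence[float]) -> Dict[str, float]:
    """Mean, population std (like np.std with ddof=0), min and max."""
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),
        "min": min(values),
        "max": max(values),
    }


def print_stats(name: str, values: Sequence[float]) -> None:
    print(f"{name} statistics:")
    for key, value in describe(values).items():
        print(f"  {key.capitalize()}: {value:.2f}")


print_stats("Node count", [2, 4, 4, 4, 5, 5, 7, 9])
```

With the lists from the exploration cell, this becomes `print_stats("Node count", node_counts)` and so on for edges and triplets.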

In [None]:
if dataset is not None:
    # Get vocabulary information
    vocab_info = dataset.get_vocabulary_info()
    print("Vocabulary Information:")
    for key, value in vocab_info.items():
        print(f"  {key}: {value}")
    
    # Explore a few samples
    print("\n" + "="*50)
    print("SAMPLE EXPLORATION")
    print("="*50)
    
    for i in range(min(3, len(dataset))):
        sample = dataset[i]
        print(f"\nSample {i+1}:")
        print(f"  Document ID: {sample['doc_id']}")
        print(f"  Number of nodes: {sample['num_nodes']}")
        print(f"  Number of edges: {sample['num_edges']}")
        print(f"  Triplets count: {sample['triplets_count']}")
        print(f"  Node features shape: {sample['node_features'].shape}")
        print(f"  Edge features shape: {sample['edge_features'].shape}")
        print(f"  Edge index shape: {sample['edge_index'].shape}")
    
    # Analyze data distribution
    print("\n" + "="*50)
    print("DATA DISTRIBUTION ANALYSIS")
    print("="*50)
    
    node_counts = []
    edge_counts = []
    triplet_counts = []
    
    for i in range(min(100, len(dataset))):  # Sample first 100
        sample = dataset[i]
        node_counts.append(sample['num_nodes'])
        edge_counts.append(sample['num_edges'])
        triplet_counts.append(sample['triplets_count'])
    
    print(f"Node count statistics:")
    print(f"  Mean: {np.mean(node_counts):.2f}")
    print(f"  Std: {np.std(node_counts):.2f}")
    print(f"  Min: {np.min(node_counts)}")
    print(f"  Max: {np.max(node_counts)}")
    
    print(f"\nEdge count statistics:")
    print(f"  Mean: {np.mean(edge_counts):.2f}")
    print(f"  Std: {np.std(edge_counts):.2f}")
    print(f"  Min: {np.min(edge_counts)}")
    print(f"  Max: {np.max(edge_counts)}")
    
    print(f"\nTriplet count statistics:")
    print(f"  Mean: {np.mean(triplet_counts):.2f}")
    print(f"  Std: {np.std(triplet_counts):.2f}")
    print(f"  Min: {np.min(triplet_counts)}")
    print(f"  Max: {np.max(triplet_counts)}")

## 6. Data Visualization

In [None]:
if dataset is not None:
    # Create visualizations
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Node count distribution
    axes[0, 0].hist(node_counts, bins=20, alpha=0.7, color='skyblue')
    axes[0, 0].set_title('Node Count Distribution')
    axes[0, 0].set_xlabel('Number of Nodes')
    axes[0, 0].set_ylabel('Frequency')
    
    # Edge count distribution
    axes[0, 1].hist(edge_counts, bins=20, alpha=0.7, color='lightgreen')
    axes[0, 1].set_title('Edge Count Distribution')
    axes[0, 1].set_xlabel('Number of Edges')
    axes[0, 1].set_ylabel('Frequency')
    
    # Triplet count distribution
    axes[1, 0].hist(triplet_counts, bins=20, alpha=0.7, color='salmon')
    axes[1, 0].set_title('Triplet Count Distribution')
    axes[1, 0].set_xlabel('Number of Triplets')
    axes[1, 0].set_ylabel('Frequency')
    
    # Scatter plot: nodes vs edges
    axes[1, 1].scatter(node_counts, edge_counts, alpha=0.6, color='purple')
    axes[1, 1].set_title('Nodes vs Edges')
    axes[1, 1].set_xlabel('Number of Nodes')
    axes[1, 1].set_ylabel('Number of Edges')
    
    plt.tight_layout()
    plt.show()
    
    # Vocabulary size information
    print(f"\nVocabulary Sizes:")
    print(f"  Unique entities (nodes): {vocab_info['node_vocab_size']}")
    print(f"  Unique relations (edges): {vocab_info['edge_vocab_size']}")
    if dataset.include_legal_references:
        print(f"  Unique legal references: {vocab_info['legal_ref_vocab_size']}")

## 7. Simple GNN Model Definition

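Next, define a deliberately minimal graph-level classifier for legal documents. It encodes each node independently and mean-pools the result. Because it does not yet propagate information along edges, it serves as a baseline rather than a full Graph Neural Network.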
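The model in the next cell applies `conv1` and `conv2` to each node separately and never reads `edge_index`. For contrast, here is a minimal NumPy sketch of one genuine message-passing step. Each node averages the features of the nodes pointing to it, then combines that with its own state through two weight matrices, which is similar in form to PyTorch Geometric's `SAGEConv` with mean aggregation. The function names and weight shapes are illustrative assumptions, not the notebook's model.

```python
import numpy as np


def mean_aggregate(x: np.ndarray, edge_index: np.ndarray) -> np.ndarray:
    """Mean of in-neighbour features for every node (zeros for nodes with no in-edges).

    x: [num_nodes, dim]; edge_index: [2, num_edges] with rows (source, target).
    """
    src, dst = edge_index
    out = np.zeros_like(x, dtype=float)
    # np.add.at accumulates correctly even when a target index repeats
    np.add.at(out, dst, x[src])
    deg = np.bincount(dst, minlength=x.shape[0]).astype(float)
    return out / np.maximum(deg, 1.0)[:, None]


def gnn_layer(x, edge_index, w_self, w_neigh):
    """Combine each node's own state with its aggregated neighbours, then apply ReLU."""
    h = x @ w_self + mean_aggregate(x, edge_index) @ w_neigh
    return np.maximum(h, 0.0)


x = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_index = np.array([[0, 1, 2], [2, 2, 0]])  # edges 0->2, 1->2, 2->0
print(mean_aggregate(x, edge_index))
print(gnn_layer(x, edge_index, np.eye(2), np.eye(2)))
```

Stacking k such layers lets information travel k hops along the entity-relation graph, which the node-wise layers cannot do.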

In [None]:
class SimpleLegalGNN(nn.Module):
    """Minimal graph-level classifier for legal documents.

    This baseline encodes each node independently and mean-pools over the padded
    node dimension. It does not yet use edge_index, edge_features or batch_index,
    so no information is passed along edges.
    """
    
    def __init__(self, node_features_dim, edge_features_dim, hidden_dim=64, num_classes=2):
        super().__init__()
        
        self.node_features_dim = node_features_dim
        self.edge_features_dim = edge_features_dim
        self.hidden_dim = hidden_dim
        
        # Node feature projection
        self.node_encoder = nn.Linear(node_features_dim, hidden_dim)
        
        # Edge feature projection (reserved for future message-passing layers; unused in forward)
        self.edge_encoder = nn.Linear(edge_features_dim, hidden_dim)
        
        # Node-wise layers (a per-node MLP, not true graph convolutions)
        self.conv1 = nn.Linear(hidden_dim, hidden_dim)
        self.conv2 = nn.Linear(hidden_dim, hidden_dim)
        
        # Global pooling and classification
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )
        
    def forward(self, node_features, edge_index, edge_features, batch_index):
        # node_features: [batch_size, max_nodes, node_features_dim] (padded)
        batch_size = node_features.size(0)
        max_nodes = node_features.size(1)
        
        # Project node features
        x = self.node_encoder(node_features.reshape(-1, node_features.size(-1)))
        x = x.view(batch_size, max_nodes, -1)
        
        # Node-wise layers: each node is updated independently of its neighbours
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        
        # Global mean pooling over all max_nodes positions (padding positions included)
        x = x.transpose(1, 2)  # [batch_size, hidden_dim, max_nodes]
        x = self.global_pool(x)  # [batch_size, hidden_dim, 1]
        x = x.squeeze(-1)  # [batch_size, hidden_dim]
        
        # Classification
        out = self.classifier(x)
        return out

print("SimpleLegalGNN model defined successfully!")
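`AdaptiveAvgPool1d` averages over every `max_nodes` position. After the `Linear` and ReLU layers, padded rows are generally not zero (the bias terms make them nonzero), so padding leaks into the graph embedding. A masked mean avoids this. The sketch below uses NumPy and assumes the number of real nodes per graph is available (each sample exposes `num_nodes`; whether a collated batch does is an assumption). The same operations translate directly to PyTorch tensors.

```python
import numpy as np


def masked_mean_pool(x: np.ndarray, num_nodes: np.ndarray) -> np.ndarray:
    """Average node embeddings over real (non-padding) nodes only.

    x: [batch_size, max_nodes, hidden_dim] padded embeddings
    num_nodes: [batch_size] number of real nodes per graph
    """
    max_nodes = x.shape[1]
    # mask[b, n] = 1 if node n of graph b is real, else 0
    mask = (np.arange(max_nodes)[None, :] < num_nodes[:, None]).astype(x.dtype)
    summed = (x * mask[:, :, None]).sum(axis=1)  # [batch_size, hidden_dim]
    counts = np.maximum(mask.sum(axis=1, keepdims=True), 1.0)  # avoid 0/0 for empty graphs
    return summed / counts


x = np.array([
    [[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]],  # graph 0: 2 real nodes, 1 padding row
    [[2.0, 4.0], [7.0, 7.0], [7.0, 7.0]],  # graph 1: 1 real node, 2 padding rows
])
num_nodes = np.array([2, 1])
print(masked_mean_pool(x, num_nodes))
```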

## 8. DataLoader Creation and Training Setup
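The model's `forward` takes a `batch_index`. In PyTorch Geometric, the equivalent vector (named `batch`) records which graph each node belongs to after several graphs are concatenated into one disconnected graph, with each graph's edge indices shifted by the number of nodes that precede it (this is what `Batch.from_data_list` does). How `create_dataloader` actually collates samples is not shown here, so the following is only a hypothetical sketch of that convention; `collate_graphs` is an illustrative name.

```python
from typing import List, Tuple


def collate_graphs(graphs: List[Tuple[int, List[List[int]]]]):
    """Merge (num_nodes, edge_index) pairs into one disconnected graph.

    Edge indices of graph i are shifted by the total node count of graphs 0..i-1,
    and batch_index[n] records which graph node n came from.
    """
    sources, targets, batch_index = [], [], []
    offset = 0
    for graph_id, (num_nodes, (src, dst)) in enumerate(graphs):
        sources.extend(s + offset for s in src)
        targets.extend(d + offset for d in dst)
        batch_index.extend([graph_id] * num_nodes)
        offset += num_nodes
    return [sources, targets], batch_index


edge_index, batch_index = collate_graphs([
    (2, [[0], [1]]),        # graph 0: 2 nodes, edge 0 -> 1
    (3, [[0, 2], [1, 1]]),  # graph 1: 3 nodes, edges 0 -> 1 and 2 -> 1
])
print(edge_index, batch_index)
```

Unlike padding to `max_nodes`, this layout wastes no memory on empty node slots, at the cost of needing `batch_index` for per-graph pooling.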

In [None]:
if dataset is not None:
    # Create DataLoader
    batch_size = 8
    dataloader = create_dataloader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )
    
    print(f"DataLoader created with batch size: {batch_size}")
    print(f"Number of batches: {len(dataloader)}")
    
    # Test a batch
    print("\nTesting batch structure:")
    for batch in dataloader:
        print(f"Batch keys: {batch.keys()}")
        print(f"Node features shape: {batch['node_features'].shape}")
        print(f"Edge features shape: {batch['edge_features'].shape}")
        print(f"Edge index shape: {batch['edge_index'].shape}")
        print(f"Batch index shape: {batch['batch_index'].shape}")
        print(f"Triplets counts: {batch['triplets_counts']}")
        break
    
    # Initialize model
    vocab_info = dataset.get_vocabulary_info()
    # Edge feature width doubles when legal references are included
    edge_dim = vocab_info['edge_features_dim'] * (2 if dataset.include_legal_references else 1)
    model = SimpleLegalGNN(
        node_features_dim=vocab_info['node_features_dim'],
        edge_features_dim=edge_dim,
        hidden_dim=64,
        num_classes=2
    ).to(device)
    
    print("\nModel initialized:")
    print(f"  Node features dim: {vocab_info['node_features_dim']}")
    print(f"  Edge features dim: {edge_dim}")
    print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Training setup
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    print("\nTraining setup completed!")

## 9. Training Loop

In [None]:
if dataset is not None:
    # Training loop
    num_epochs = 5
    model.train()
    
    print(f"Starting training for {num_epochs} epochs...")
    
    training_losses = []
    
    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0
        
        for batch_idx, batch in enumerate(dataloader):
            # Move batch to device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(
                batch['node_features'],
                batch['edge_index'],
                batch['edge_features'],
                batch['batch_index']
            )
            
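            # Proxy binary label: does the document have any extracted triplets?
            # (a placeholder target for demonstrating the pipeline, not a meaningful legal task)
            labels = (torch.as_tensor(batch['triplets_counts'], device=device) > 0).long()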
            
            # Loss
            loss = criterion(outputs, labels)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
            
            # Print progress every 5 batches
            if batch_idx % 5 == 0:
                print(f"  Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
        
        avg_loss = total_loss / num_batches
        training_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
    
    print("Training completed!")
    
    # Plot training loss
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, num_epochs + 1), training_losses, 'b-o')
    plt.title('Training Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True, alpha=0.3)
    plt.show()

## 10. Model Evaluation
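The evaluation cell iterates over the same `dataloader` used for training, so its accuracy measures fit to the training data rather than generalization. A minimal sketch of a reproducible train/validation split by sample index follows. It assumes one sample per document (each sample carries a single `doc_id`), so no document lands in both splits; `split_indices` is a hypothetical helper.

```python
import random
from typing import List, Tuple


def split_indices(n: int, val_fraction: float = 0.2, seed: int = 42) -> Tuple[List[int], List[int]]:
    """Shuffle 0..n-1 with a fixed seed and split into (train, validation) indices."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_val = round(n * val_fraction)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(10)
print(len(train_idx), len(val_idx))
```

The index lists can be wrapped with `torch.utils.data.Subset(dataset, train_idx)` and `Subset(dataset, val_idx)`; passing those subsets to `create_dataloader` assumes it accepts any `Dataset`, which is not confirmed here.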

In [None]:
if dataset is not None:
    # Evaluation (on the training dataloader, so this measures fit to the training data, not generalization)
    model.eval()
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    
    print("Evaluating model...")
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            # Move batch to device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            
            outputs = model(
                batch['node_features'],
                batch['edge_index'],
                batch['edge_features'],
                batch['batch_index']
            )
            
            predicted = outputs.argmax(dim=1)
            labels = (torch.as_tensor(batch['triplets_counts'], device=device) > 0).long()
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            # Evaluate only the first 6 batches (batch_idx 0 to 5) for the demo
            if batch_idx >= 5:
                break
    
    accuracy = 100 * correct / total
    print(f"\nEvaluation Results:")
    print(f"  Accuracy: {accuracy:.2f}% ({correct}/{total})")
    
    # Confusion matrix
    from sklearn.metrics import confusion_matrix, classification_report
    
    # Pass labels=[0, 1] so the matrix stays 2x2 even if one class never appears
    cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['No Triplets', 'Has Triplets'],
                yticklabels=['No Triplets', 'Has Triplets'])
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
    
    # Classification report (zero_division=0 avoids warnings for classes that are never predicted)
    print("\nClassification Report:")
    print(classification_report(all_labels, all_predictions, labels=[0, 1],
                                target_names=['No Triplets', 'Has Triplets'],
                                zero_division=0))

## 11. Advanced Usage Examples
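Changing `max_nodes`, `max_edges` and the feature dimensions changes the size of every padded sample. As worked arithmetic, the sketch below computes the float32 bytes in one sample's node- and edge-feature tensors for the configurations used in this notebook. It ignores `edge_index` and any masks, and it takes the edge width as `edge_features_dim`. If legal references double that width (as the model setup in section 8 assumes), pass twice the value. `padded_bytes_per_sample` is a hypothetical helper.

```python
def padded_bytes_per_sample(max_nodes: int, node_dim: int, max_edges: int, edge_dim: int,
                            bytes_per_value: int = 4) -> int:
    """Bytes in one sample's padded node and edge feature tensors (float32 by default)."""
    return (max_nodes * node_dim + max_edges * edge_dim) * bytes_per_value


# (max_nodes, node_features_dim, max_edges, edge_features_dim)
configs = {
    "small": (20, 64, 40, 32),
    "default": (50, 128, 100, 64),
    "large": (100, 256, 200, 128),
}
for name, (n, nd, e, ed) in configs.items():
    size = padded_bytes_per_sample(n, nd, e, ed)
    print(f"{name}: {size} bytes per sample, {size * 8 / 1024:.1f} KiB per batch of 8")
```

The large configuration needs four times the feature memory of the default one (204,800 vs 51,200 bytes per sample), or 307,200 bytes if the edge width is doubled.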

In [None]:
# Example 1: Different dataset configurations
print("Example 1: Different Dataset Configurations")
print("="*50)

if dataset is not None:
    # Smaller graph size
    dataset_small = LegalKnowledgeGraphDataset(
        data_source='csv',
        csv_file=csv_file,
        max_nodes=20,
        max_edges=40,
        include_legal_references=False,  # Disable legal references
        node_features_dim=64,
        edge_features_dim=32
    )
    
    print(f"Small dataset: {len(dataset_small)} samples")
    print(f"Max nodes: {dataset_small.max_nodes}")
    print(f"Max edges: {dataset_small.max_edges}")
    print(f"Include legal references: {dataset_small.include_legal_references}")
    
    # Test a sample from small dataset
    sample = dataset_small[0]
    print(f"Sample node features shape: {sample['node_features'].shape}")
    print(f"Sample edge features shape: {sample['edge_features'].shape}")
    
    # Larger graph size
    dataset_large = LegalKnowledgeGraphDataset(
        data_source='csv',
        csv_file=csv_file,
        max_nodes=100,
        max_edges=200,
        include_legal_references=True,
        node_features_dim=256,
        edge_features_dim=128
    )
    
    print(f"\nLarge dataset: {len(dataset_large)} samples")
    print(f"Max nodes: {dataset_large.max_nodes}")
    print(f"Max edges: {dataset_large.max_edges}")
    print(f"Node features dim: {dataset_large.node_features_dim}")
    print(f"Edge features dim: {dataset_large.edge_features_dim}")

In [None]:
# Example 2: Custom transform function
print("\nExample 2: Custom Transform Function")
print("="*50)

def custom_transform(sample):
    """Custom transform to add noise to node features"""
    # Add small random noise to node features
    noise = torch.randn_like(sample['node_features']) * 0.01
    sample['node_features'] = sample['node_features'] + noise
    return sample

if dataset is not None:
    dataset_with_transform = LegalKnowledgeGraphDataset(
        data_source='csv',
        csv_file=csv_file,
        max_nodes=50,
        max_edges=100,
        include_legal_references=True,
        node_features_dim=128,
        edge_features_dim=64,
        transform=custom_transform
    )
    
    print(f"Dataset with custom transform: {len(dataset_with_transform)} samples")
    
    # Test the transform
    sample_original = dataset[0]
    sample_transformed = dataset_with_transform[0]
    
    print(f"Original node features norm: {torch.norm(sample_original['node_features']):.4f}")
    print(f"Transformed node features norm: {torch.norm(sample_transformed['node_features']):.4f}")
    print(f"Difference: {torch.norm(sample_transformed['node_features'] - sample_original['node_features']):.4f}")
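Several transforms can be chained in the same way as `torchvision.transforms.Compose`. The plain-Python sketch below gives each step a shallow copy of the sample dict, so reassigning a key (as `custom_transform` does) does not modify the caller's dict. In-place tensor operations such as `+=` would still propagate, because the tensors themselves are shared. That `LegalKnowledgeGraphDataset` accepts any callable on the sample dict as `transform` is assumed from the example above; `compose` is a hypothetical helper.

```python
from typing import Any, Callable, Dict

Sample = Dict[str, Any]


def compose(*transforms: Callable[[Sample], Sample]) -> Callable[[Sample], Sample]:
    """Apply transforms left to right, passing each a shallow copy of the sample."""
    def apply(sample: Sample) -> Sample:
        for transform in transforms:
            sample = transform(dict(sample))  # shallow copy: caller's dict keys are untouched
        return sample
    return apply


def scale(sample: Sample) -> Sample:
    sample["x"] = [v * 2 for v in sample["x"]]
    return sample


def shift(sample: Sample) -> Sample:
    sample["x"] = [v + 1 for v in sample["x"]]
    return sample


original = {"x": [1, 2]}
out = compose(scale, shift)(original)
print(out, original)
```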

## 12. Summary and Next Steps

This notebook demonstrated:

1. **Dataset Loading**: How to load data from CSV or BigQuery
2. **Data Exploration**: Understanding the structure and statistics of legal document graphs
3. **Model Training**: Training a minimal baseline classifier on a proxy label (whether a document has any extracted triplets)
4. **Evaluation**: Assessing model performance with metrics and visualizations
5. **Advanced Usage**: Different configurations and custom transforms

### Next Steps:

- **Advanced GNN Architectures**: Implement more sophisticated GNN layers (GCN, GAT, GraphSAGE)
- **Multi-task Learning**: Train on multiple objectives (entity recognition, relation extraction, document classification)
- **Hyperparameter Tuning**: Use Optuna or similar tools for optimization
- **Model Interpretability**: Analyze which entities and relations are most important for predictions
- **Production Deployment**: Save and serve models for real-world applications
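For reference when moving to GCN layers, the propagation rule from Kipf and Welling's Graph Convolutional Network adds self-loops to the adjacency matrix $A$ and symmetrically normalizes it by node degree. PyTorch Geometric's `GCNConv` applies this normalization by default. In contrast, the baseline model in this notebook corresponds to dropping the $\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ term entirely.

$$
H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2}\, \tilde{A}\, \tilde{D}^{-1/2}\, H^{(l)} W^{(l)}\right),
\qquad \tilde{A} = A + I,
\qquad \tilde{D}_{ii} = \sum_{j} \tilde{A}_{ij}
$$

GAT replaces the fixed normalization coefficients with learned attention weights, and GraphSAGE replaces them with a neighbour aggregator such as the mean.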

### Key Features of the Dataset:

- **Flexible Data Sources**: BigQuery and CSV support
- **Automatic Vocabulary Building**: Handles entity and relation vocabularies
- **Legal Reference Integration**: Optional legal code references for enhanced context
- **Configurable Graph Sizes**: Adjustable node and edge limits
- **PyTorch Geometric Compatibility**: Easy integration with advanced GNN libraries
- **Proper Batching**: Handles variable-sized graphs with padding

The dataset is designed to be flexible and extensible for various legal document analysis tasks.