In [None]:
import os
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-{torch._version_}.html
import torch_geometric
print("torch_geometric version:", torch_geometric._version_)
from torch_geometric.data import Data, HeteroData
from torch_geometric.nn import GCNConv, SAGEConv, HeteroConv, GATConv, GAT
from torch_geometric.nn import DataLoader, NeighborLoader
from torch_geometric.utils import to_networkx, to_dense_adj
import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
from sklearn.manifold import TSNE
from transformers import CLIPModel, CLIPProcessor, BertModel, BertTokenizer
from PIL import Image
import requests
from io import BytesIO
from tqdm import tqdm
import time
import seaborn as sns
from collections import Counter
import copy

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

class KGBenchExplorer:
    """Exploratory data analysis (EDA) helper for the dmg777k dataset.

    Every ``analyze_*``/``visualize_*`` method loads the data on first use,
    prints a summary to stdout and saves its figures as PNG files in the
    current working directory.
    """

    def __init__(self, data_path):
        """
        Initialize the explorer with the data path

        Args:
            data_path: Path to the dmg777k dataset folder
        """
        self.data_path = data_path
        self.train_data = None
        self.valid_data = None
        self.test_data = None
        self.entity_types = None
        self.relation_types = None
        self.node_features = None
        self.e2i = None
        self.i2e = None
        # entity id -> coarse type name; filled by _extract_entity_types().
        self.entity_type_map = {}
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ------------------------------------------------------------------ loading

    def _read_json(self, filename):
        """Parse ``<data_path>/<filename>`` as UTF-8 JSON and return the result."""
        with open(os.path.join(self.data_path, filename), 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_data(self):
        """Load the dataset files

        Reads ``train.json``, ``valid.json`` and ``test.json``. The ``e2i`` and
        ``i2e`` mappings are taken from the training file (empty dicts if absent).

        Returns:
            True once all three files have been parsed.
        """
        self.train_data = self._read_json('train.json')
        self.valid_data = self._read_json('valid.json')
        self.test_data = self._read_json('test.json')
        self.e2i = self.train_data.get('e2i', {})
        self.i2e = self.train_data.get('i2e', {})
        print(f"Loaded dataset with {len(self.e2i)} entities")
        self._extract_entity_types()
        return True

    def _extract_entity_types(self):
        """Extract entity types from URIs or other identifiers

        Fills ``self.entity_type_map`` (entity id -> type). A string with more
        than two '/'-separated parts maps to its second-to-last part; otherwise a
        string containing '#' maps to its second-to-last '#'-separated part;
        everything else maps to 'unknown'.
        """
        self.entity_type_map = {}
        for entity, entity_id in self.e2i.items():
            if isinstance(entity, str):
                parts = entity.split('/')
                if len(parts) > 2:
                    entity_type = parts[-2]
                elif '#' in entity:
                    entity_type = entity.split('#')[-2]
                else:
                    entity_type = 'unknown'
            else:
                entity_type = 'unknown'
            self.entity_type_map[entity_id] = entity_type

    # ------------------------------------------------------------ plot helpers

    @staticmethod
    def _collapse_tail(labels, values, max_items):
        """Cap a (labels, values) pair at ``max_items`` entries.

        When there are more than ``max_items`` entries, keep the first
        ``max_items - 1`` and merge the remainder into a single 'Others' entry
        whose value is their sum. Shorter inputs are returned unchanged.
        """
        if len(labels) > max_items:
            keep = max_items - 1
            return labels[:keep] + ['Others'], values[:keep] + [sum(values[keep:])]
        return labels, values

    @staticmethod
    def _save_bar_chart(labels, values, figsize, cmap, title, xlabel, filename):
        """Save an annotated bar chart of ``values`` (y-axis 'Count') to ``filename``."""
        plt.figure(figsize=figsize)
        bars = plt.bar(labels, values, color=cmap(np.linspace(0, 1, len(labels))))
        plt.title(title, fontsize=16)
        plt.xlabel(xlabel, fontsize=12)
        plt.ylabel('Count', fontsize=12)
        plt.xticks(rotation=45, ha='right')
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.1, f'{int(height):,}', ha='center', va='bottom', fontsize=9)
        plt.tight_layout()
        plt.savefig(filename, dpi=300)
        plt.close()

    def _save_pie_chart(self, labels, values, title, filename):
        """Save a pie chart of at most 7 slices (the tail merged into 'Others')."""
        pie_labels, pie_values = self._collapse_tail(labels, values, 7)
        plt.figure(figsize=(10, 10))
        plt.pie(pie_values, labels=pie_labels, autopct='%1.1f%%', startangle=140, colors=plt.cm.tab20(np.linspace(0, 1, len(pie_labels))))
        plt.axis('equal')
        plt.title(title, fontsize=16)
        plt.savefig(filename, dpi=300)
        plt.close()

    # --------------------------------------------------------------- analyses

    def analyze_entity_types(self):
        """Analyze and plot the types and number of entities (nodes)

        NOTE(review): the type rule here (only 'http://' strings are split on
        '/') differs from ``_extract_entity_types``; both are kept as they were.

        Returns:
            dict mapping entity type -> number of entities.
        """
        if not self.train_data:
            self.load_data()
        entity_types = {}
        for entity in self.e2i:
            if isinstance(entity, str):
                if entity.startswith('http://'):
                    parts = entity.split('/')
                    if len(parts) > 2:
                        entity_type = parts[-2]
                    else:
                        entity_type = 'unknown'
                elif '#' in entity:
                    entity_type = entity.split('#')[-2]
                else:
                    entity_type = 'unknown'
            else:
                entity_type = 'unknown'
            entity_types[entity_type] = entity_types.get(entity_type, 0) + 1
        self.entity_types = entity_types

        sorted_types = sorted(entity_types.items(), key=lambda x: x[1], reverse=True)
        labels, values = self._collapse_tail([t[0] for t in sorted_types], [t[1] for t in sorted_types], 15)
        self._save_bar_chart(labels, values, (14, 8), plt.cm.viridis, 'Distribution of Entity Types', 'Entity Type', 'entity_types_distribution.png')
        self._save_pie_chart(labels, values, 'Entity Types Distribution', 'entity_types_pie.png')

        total = sum(values)
        print("\n=== Entity Types Analysis ===")
        print(f"Total unique entity types: {len(entity_types)}")
        print("Top 5 entity types:")
        for i, (type_name, count) in enumerate(sorted_types[:5]):
            print(f"{i+1}. {type_name}: {count:,} entities ({count/total*100:.1f}%)")
        return entity_types

    def analyze_edge_types(self):
        """Analyze and plot the types and number of edges

        Returns:
            dict mapping predicate -> number of training triples using it.
        """
        if not self.train_data:
            self.load_data()
        edge_types = {}
        for _, p, _ in self.train_data.get('triples', []):
            edge_types[p] = edge_types.get(p, 0) + 1
        self.relation_types = edge_types

        sorted_types = sorted(edge_types.items(), key=lambda x: x[1], reverse=True)
        labels, values = self._collapse_tail([t[0] for t in sorted_types], [t[1] for t in sorted_types], 15)
        self._save_bar_chart(labels, values, (16, 10), plt.cm.cool, 'Distribution of Edge Types', 'Edge Type', 'edge_types_distribution.png')
        self._save_pie_chart(labels, values, 'Edge Types Distribution', 'edge_types_pie.png')

        total = sum(values)
        print("\n=== Edge Types Analysis ===")
        print(f"Total unique edge types: {len(edge_types)}")
        print("Top 5 edge types:")
        for i, (type_name, count) in enumerate(sorted_types[:5]):
            print(f"{i+1}. {type_name}: {count:,} edges ({count/total*100:.1f}%)")
        return edge_types

    def analyze_node_features(self):
        """Analyze what type of information is present per node

        Uses ``train_data['texts']`` (entity id -> text; whitespace-separated word
        counts are measured) and the keys of ``train_data['images']``.

        Returns:
            {'text': {entity_id: word_count}, 'image': {entity_id: 1}}
        """
        if not self.train_data:
            self.load_data()
        text_features = {entity_id: len(text.split()) for entity_id, text in self.train_data.get('texts', {}).items()}
        image_features = {entity_id: 1 for entity_id in self.train_data.get('images', {})}
        text_ids = set(text_features)
        image_ids = set(image_features)
        both = len(text_ids & image_ids)

        if text_features:
            plt.figure(figsize=(12, 6))
            text_lengths = list(text_features.values())
            plt.hist(text_lengths, bins=50, alpha=0.75, color='skyblue', edgecolor='black')
            plt.title('Distribution of Text Feature Lengths', fontsize=16)
            plt.xlabel('Text Length (words)', fontsize=12)
            plt.ylabel('Frequency', fontsize=12)
            plt.grid(alpha=0.3)
            stats_text = f"Min: {min(text_lengths)}\nMax: {max(text_lengths)}\nMean: {np.mean(text_lengths):.1f}\nMedian: {np.median(text_lengths):.1f}"
            plt.annotate(stats_text, xy=(0.75, 0.75), xycoords='axes fraction', bbox=dict(boxstyle="round,pad=0.5", fc="white", alpha=0.8))
            plt.savefig('text_length_distribution.png', dpi=300)
            plt.close()

        if text_features or image_features:
            plt.figure(figsize=(10, 10))
            modality_counts = [len(text_ids - image_ids), len(image_ids - text_ids), both]
            plt.bar(['Text Only', 'Image Only', 'Both'], modality_counts, color=['skyblue', 'lightgreen', 'salmon'])
            plt.title('Distribution of Feature Modalities', fontsize=16)
            plt.ylabel('Number of Nodes', fontsize=12)
            for i, v in enumerate(modality_counts):
                plt.text(i, v + 0.1, f"{v:,}", ha='center')
            plt.savefig('modality_distribution.png', dpi=300)
            plt.close()

        self.node_features = {'text': text_features, 'image': image_features}
        print("\n=== Node Features Analysis ===")
        print(f"Nodes with text features: {len(text_features):,}")
        print(f"Nodes with image features: {len(image_features):,}")
        print(f"Nodes with both text and image: {both:,}")
        if text_features:
            text_lengths = list(text_features.values())
            print("\nText feature statistics:")
            print(f"  Min length: {min(text_lengths)} words")
            print(f"  Max length: {max(text_lengths)} words")
            print(f"  Mean length: {np.mean(text_lengths):.1f} words")
            print(f"  Median length: {np.median(text_lengths):.1f} words")
        return self.node_features

    def analyze_label_distribution(self):
        """Analyze the distribution of node labels

        Returns:
            Counter of label -> node count, or None when the training file has
            no labels.
        """
        if not self.train_data:
            self.load_data()
        labels = self.train_data.get('labels', {})
        if not labels:
            print("No labels found in the dataset.")
            return None
        label_counts = Counter(labels.values())

        sorted_labels = sorted(label_counts.items(), key=lambda x: x[1], reverse=True)
        label_names, label_values = self._collapse_tail([l[0] for l in sorted_labels], [l[1] for l in sorted_labels], 20)
        self._save_bar_chart(label_names, label_values, (14, 8), plt.cm.plasma, 'Distribution of Node Labels', 'Label', 'label_distribution.png')

        total = sum(label_counts.values())
        print("\n=== Label Distribution Analysis ===")
        print(f"Total labeled nodes: {total:,}")
        print(f"Number of unique labels: {len(label_counts)}")
        print("Top 5 most common labels:")
        for i, (label, count) in enumerate(sorted_labels[:5]):
            print(f"{i+1}. {label}: {count:,} nodes ({count/total*100:.1f}%)")
        # Use the real per-class counts: the plotted values may end with an
        # aggregated 'Others' bar, which is not a class and would distort the min.
        imbalance_ratio = max(label_counts.values()) / min(label_counts.values())
        print(f"\nClass imbalance ratio (max/min): {imbalance_ratio:.2f}")
        return label_counts

    def visualize_graph_sample(self, max_nodes=150):
        """Visualize a small sample of the graph with node types

        Grows a connected sample from a random training triple until it touches
        ``max_nodes`` nodes (or no connected triples remain); if that yields fewer
        than 50 triples, falls back to a uniform sample of up to ``2 * max_nodes``
        triples. Nodes are coloured by the modalities (text/image) they have.

        Returns:
            The sampled ``networkx.DiGraph``.
        """
        if not self.train_data:
            self.load_data()
        G = nx.DiGraph()
        triples = self.train_data.get('triples', [])
        sampled_triples = []
        if triples:
            start_triple = random.choice(triples)
            sampled_triples.append(start_triple)
            # Hashable mirror of sampled_triples for O(1) membership tests.
            sampled_keys = {tuple(start_triple)}
            included_nodes = {start_triple[0], start_triple[2]}
            candidates = [t for t in triples if t[0] in included_nodes or t[2] in included_nodes]
            while len(included_nodes) < max_nodes and candidates:
                new_triple = random.choice(candidates)
                if tuple(new_triple) not in sampled_keys:
                    sampled_triples.append(new_triple)
                    sampled_keys.add(tuple(new_triple))
                    included_nodes.add(new_triple[0])
                    included_nodes.add(new_triple[2])
                candidates = [t for t in triples if (t[0] in included_nodes or t[2] in included_nodes) and tuple(t) not in sampled_keys]
                if len(candidates) > 1000:
                    candidates = random.sample(candidates, 1000)
        if len(sampled_triples) < min(50, len(triples)):
            sampled_triples = random.sample(triples, min(len(triples), max_nodes * 2))

        texts = self.train_data.get('texts', {})
        images = self.train_data.get('images', {})
        node_modalities = {}
        for s, p, o in sampled_triples:
            G.add_edge(s, o, relation=p)
            for node in (s, o):
                key = str(node)
                modalities = node_modalities.setdefault(key, [])
                if key in texts and 'text' not in modalities:
                    modalities.append('text')
                if key in images and 'image' not in modalities:
                    modalities.append('image')

        node_colors = []
        for node in G.nodes():
            modalities = node_modalities.get(str(node), [])
            if 'text' in modalities and 'image' in modalities:
                node_colors.append('purple')
            elif 'text' in modalities:
                node_colors.append('blue')
            elif 'image' in modalities:
                node_colors.append('green')
            else:
                node_colors.append('gray')

        num_sample_nodes = G.number_of_nodes()
        show_labels = num_sample_nodes <= 50
        node_labels = {node: self.i2e.get(str(node), str(node)).split('/')[-1] if show_labels else '' for node in G.nodes()}

        plt.figure(figsize=(14, 14))
        pos = nx.spring_layout(G, seed=42)
        nx.draw_networkx_nodes(G, pos, node_size=100, node_color=node_colors, alpha=0.8)
        nx.draw_networkx_edges(G, pos, width=0.5, alpha=0.5, arrows=True, arrowsize=10)
        if show_labels:
            nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8)
        legend_elements = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='purple', markersize=10, label='Text & Image'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10, label='Text Only'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='green', markersize=10, label='Image Only'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', markersize=10, label='No Features')
        ]
        plt.legend(handles=legend_elements, loc='upper right')
        plt.title(f'Sample of Graph Structure ({len(G.nodes())} nodes, {len(G.edges())} edges)', fontsize=16)
        plt.axis('off')
        plt.savefig('graph_sample.png', dpi=300)
        plt.close()

        print("\n=== Graph Sample Visualization ===")
        print(f"Sample contains {len(G.nodes())} nodes and {len(G.edges())} edges")
        modality_counts = Counter(tuple(node_modalities.get(str(node), [])) for node in G.nodes())
        print("Node modality distribution in sample:")
        for modality_tuple, count in modality_counts.items():
            modality_str = ', '.join(modality_tuple) if modality_tuple else 'No features'
            print(f"  {modality_str}: {count} nodes")
        # Guard the empty-graph case (no triples) instead of dividing by zero.
        avg_degree = sum(dict(G.degree()).values()) / num_sample_nodes if num_sample_nodes else 0.0
        print(f"Average node degree: {avg_degree:.2f}")
        try:
            density = nx.density(G)
            print(f"Graph density: {density:.6f}")
        except Exception:
            # Best effort only; never let a statistic abort the visualisation.
            print("Could not calculate graph density.")
        connected_components = list(nx.weakly_connected_components(G))
        print(f"Number of weakly connected components: {len(connected_components)}")
        return G

    def analyze_triples(self):
        """Analyze how the triples attribute defines graphs

        A node's degree is the number of triples in which it appears as subject
        plus the number in which it appears as object.

        Returns:
            dict with counts of unique subjects/predicates/objects, the triple
            count and node-degree statistics (all zero when there are no triples).
        """
        if not self.train_data:
            self.load_data()
        triples = self.train_data.get('triples', [])
        subject_degrees = Counter(s for s, _, _ in triples)
        predicate_counts = Counter(p for _, p, _ in triples)
        object_degrees = Counter(o for _, _, o in triples)
        subjects = set(subject_degrees)
        predicates = set(predicate_counts)
        objects = set(object_degrees)
        node_degrees = subject_degrees + object_degrees
        degrees = list(node_degrees.values())

        if degrees:
            degree_stats = {'min': min(degrees), 'max': max(degrees), 'mean': np.mean(degrees), 'median': np.median(degrees)}
            plt.figure(figsize=(10, 6))
            plt.hist(degrees, bins=50, alpha=0.75, color='teal', log=True, edgecolor='black')
            plt.title('Node Degree Distribution (Log Scale)', fontsize=16)
            plt.xlabel('Node Degree', fontsize=12)
            plt.ylabel('Frequency (Log Scale)', fontsize=12)
            plt.grid(alpha=0.3)
            plt.savefig('node_degree_distribution.png', dpi=300)
            plt.close()
        else:
            # No triples: min()/max() of an empty list would raise.
            degree_stats = {'min': 0, 'max': 0, 'mean': 0.0, 'median': 0.0}

        print("\n=== Triple Analysis ===")
        print(f"Total triples: {len(triples):,}")
        print(f"Unique subjects: {len(subjects):,}")
        print(f"Unique predicates: {len(predicates):,}")
        print(f"Unique objects: {len(objects):,}")
        print("\nTop 5 most common predicates:")
        for i, (pred, count) in enumerate(predicate_counts.most_common(5)):
            print(f"{i+1}. {pred}: {count:,} occurrences")
        print("\nNode degree statistics:")
        print(f"  Minimum degree: {degree_stats['min']}")
        print(f"  Maximum degree: {degree_stats['max']}")
        print(f"  Average degree: {degree_stats['mean']:.2f}")
        print(f"  Median degree: {degree_stats['median']}")
        print("\nSample triples (subject, predicate, object):")
        for s, p, o in random.sample(triples, min(5, len(triples))):
            s_entity = self.i2e.get(str(s), f"Entity {s}")
            o_entity = self.i2e.get(str(o), f"Entity {o}")
            print(f"  {s_entity} --[{p}]--> {o_entity}")
        return {
            'subjects': len(subjects),
            'predicates': len(predicates),
            'objects': len(objects),
            'triples': len(triples),
            'degree_stats': degree_stats
        }

    def analyze_e2i_i2e(self):
        """Analyze the e2i and i2e mappings

        Groups string entities by prefix: the first three '/'-separated parts
        when the entity contains '/', else the part before the first '#', else
        'other'.

        Returns:
            {'e2i_count': int, 'i2e_count': int, 'uri_patterns': {prefix: count}}
        """
        if not self.train_data:
            self.load_data()
        e2i = self.train_data.get('e2i', {})
        i2e = self.train_data.get('i2e', {})
        print("\n=== Entity to ID Mapping Analysis ===")
        print(f"Total mappings: {len(e2i):,}")

        def uri_pattern(entity):
            if '/' in entity:
                return '/'.join(entity.split('/')[:3])
            if '#' in entity:
                return entity.split('#')[0]
            return 'other'

        uri_patterns = Counter(uri_pattern(entity) for entity in e2i if isinstance(entity, str))
        print("\nMost common URI patterns:")
        for i, (pattern, count) in enumerate(uri_patterns.most_common(5)):
            print(f"{i+1}. {pattern}: {count:,} entities")
        print("\nSample e2i mappings:")
        samples = random.sample(list(e2i.items()), min(5, len(e2i)))
        for entity, entity_id in samples:
            print(f"  {entity} -> {entity_id}")
        return {'e2i_count': len(e2i), 'i2e_count': len(i2e), 'uri_patterns': dict(uri_patterns)}

    def run_complete_eda(self):
        """Run all EDA functions and write a summary to 'eda_report.txt'

        Returns:
            True when every analysis has run and the report has been written.
        """
        print("=== Starting Exploratory Data Analysis of dmg777k Dataset ===")
        self.load_data()
        entity_types = self.analyze_entity_types()
        edge_types = self.analyze_edge_types()
        node_features = self.analyze_node_features()
        label_distribution = self.analyze_label_distribution()
        triple_stats = self.analyze_triples()
        self.analyze_e2i_i2e()
        self.visualize_graph_sample()

        with open('eda_report.txt', 'w', encoding='utf-8') as f:
            f.write("=== dmg777k Dataset Analysis Report ===\n\n")
            f.write(f"Total entities: {len(self.e2i):,}\n")
            f.write(f"Entity types: {len(entity_types)}\n")
            f.write(f"Edge types: {len(edge_types)}\n")
            f.write(f"Nodes with text: {len(node_features['text']):,}\n")
            f.write(f"Nodes with images: {len(node_features['image']):,}\n")
            f.write(f"Total triples: {triple_stats['triples']:,}\n")
            f.write("\nNode degree statistics:\n")
            for key, value in triple_stats['degree_stats'].items():
                f.write(f"  {key}: {value}\n")
            if label_distribution:
                f.write(f"\nNumber of labeled nodes: {sum(label_distribution.values()):,}\n")
                f.write(f"Number of unique labels: {len(label_distribution)}\n")
        print("\n=== EDA Report ===")
        print("EDA report saved to 'eda_report.txt'.")
        return True

class MultimodalGNN(nn.Module):
    """Two-layer GCN node classifier with optional text/image feature fusion.

    Each node starts from a learned ``nn.Embedding`` row. When text and/or image
    feature matrices are given, row ``i`` of each matrix is linearly projected to
    ``hidden_dim`` and added to node ``i``'s embedding before message passing.
    """

    def __init__(self, num_nodes, hidden_dim, num_classes, text_embeddings=None, image_embeddings=None):
        """Multimodal GNN for node classification

        Args:
            num_nodes: number of learned node embeddings.
            hidden_dim: width of the embeddings, projections and GCN layers.
            num_classes: number of output logits per node.
            text_embeddings: optional 2-D array-like (numpy array or tensor);
                row ``i`` is fused into node ``i``.
            image_embeddings: optional 2-D array-like, same convention.
        """
        super().__init__()
        self.node_emb = nn.Embedding(num_nodes, hidden_dim)
        self.has_text = text_embeddings is not None
        if self.has_text:
            text = torch.as_tensor(text_embeddings, dtype=torch.float32)
            self.text_projection = nn.Linear(text.shape[1], hidden_dim)
            self.register_buffer('text_embeddings', text)
            # A buffer (not a plain attribute) so .to(device) moves the index as
            # well; non-persistent so it stays out of state_dict, as before.
            self.register_buffer('text_nodes', torch.arange(text.shape[0]), persistent=False)
        self.has_image = image_embeddings is not None
        if self.has_image:
            image = torch.as_tensor(image_embeddings, dtype=torch.float32)
            self.image_projection = nn.Linear(image.shape[1], hidden_dim)
            self.register_buffer('image_embeddings', image)
            self.register_buffer('image_nodes', torch.arange(image.shape[0]), persistent=False)
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(0.3)

    def _fused_node_features(self):
        """Return node embeddings with projected modality features added.

        Uses out-of-place ``index_add``: writing into ``node_emb.weight`` directly
        fails on a leaf parameter during training and, under ``no_grad``,
        would silently overwrite the learned embeddings.
        """
        x = self.node_emb.weight
        if self.has_text:
            x = x.index_add(0, self.text_nodes, self.text_projection(self.text_embeddings))
        if self.has_image:
            x = x.index_add(0, self.image_nodes, self.image_projection(self.image_embeddings))
        return x

    def forward(self, edge_index):
        """Forward pass through the GNN; returns per-node class logits."""
        x = self._fused_node_features()
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = F.relu(self.conv2(x, edge_index))
        return self.classifier(x)

def _load_rgb_image(image_path, data_path, timeout):
    """Open ``image_path`` and return it as an RGB PIL image.

    Paths starting with 'http' are downloaded with ``requests`` (raising on HTTP
    error statuses); anything else is treated as relative to ``data_path``.
    """
    if image_path.startswith('http'):
        response = requests.get(image_path, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = os.path.join(data_path, image_path)
    # The context manager closes the file handle; convert() returns a new image.
    with Image.open(source) as img:
        return img.convert('RGB')


def extract_features(kg_explorer, request_timeout=30):
    """Extract features from text and images using pre-trained models

    Text: last-layer [CLS] embedding from ``bert-base-uncased``.
    Images: ``CLIPModel.get_image_features`` from ``openai/clip-vit-base-patch32``.
    Images that fail to download/open/encode are reported and skipped.

    Args:
        kg_explorer: explorer whose ``train_data`` holds 'texts' and 'images'
            dicts keyed by entity id.
        request_timeout: seconds to wait for each remote image download.

    Returns:
        (text_features, image_features): dicts of entity id -> 1-D numpy vector.
    """
    print("Extracting features from text and images...")
    device = kg_explorer.device
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
    bert_model.eval()
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_model.eval()

    text_features = {}
    for entity_id, text in tqdm(kg_explorer.train_data.get('texts', {}).items()):
        # No character pre-cut: the tokenizer already truncates to BERT's
        # 512-*token* limit, and 512 characters is only a fraction of that.
        with torch.no_grad():
            inputs = bert_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
            outputs = bert_model(**inputs)
        text_features[entity_id] = outputs.last_hidden_state[:, 0, :].cpu().numpy()[0]

    image_features = {}
    for entity_id, image_path in tqdm(kg_explorer.train_data.get('images', {}).items()):
        try:
            image = _load_rgb_image(image_path, kg_explorer.data_path, request_timeout)
            with torch.no_grad():
                inputs = clip_processor(images=image, return_tensors="pt").to(device)
                outputs = clip_model.get_image_features(**inputs)
            image_features[entity_id] = outputs.cpu().numpy()[0]
        except Exception as e:
            # Best effort: one bad image must not abort feature extraction.
            print(f"Error processing image {image_path}: {e}")

    print(f"Extracted features for {len(text_features)} texts and {len(image_features)} images")
    return text_features, image_features

def prepare_graph_data(kg_explorer, text_features=None, image_features=None):
    """Prepare graph data for GNN training

    Args:
        kg_explorer: explorer with ``train_data`` ('triples', 'labels', optional
            'nodes'), optional ``valid_data``/``test_data`` ('nodes') and ``e2i``.
        text_features: optional dict of node id -> 1-D embedding.
        image_features: optional dict of node id -> 1-D embedding.

    Returns:
        (edge_index, labels, train_idx, valid_idx, test_idx, text_embeddings,
        image_embeddings, label_to_idx). The index tensors are int64. The
        embeddings are dense float32 ``(num_nodes, dim)`` arrays with zero rows
        for nodes that have no feature, or None when no features were given.
    """
    print("Preparing graph data for GNN...")
    train_data = kg_explorer.train_data
    triples = train_data.get('triples', [])
    src_nodes = [s for s, _, _ in triples]
    dst_nodes = [o for _, _, o in triples]
    edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)

    labels_dict = train_data.get('labels', {})
    unique_labels = sorted(set(labels_dict.values()))
    label_to_idx = {label: idx for idx, label in enumerate(unique_labels)}
    num_nodes = len(kg_explorer.e2i)
    # Unlabelled nodes keep class 0; only nodes in the split indices are scored.
    labels = torch.zeros(num_nodes, dtype=torch.long)
    for node_id, label in labels_dict.items():
        labels[int(node_id)] = label_to_idx[label]

    def split_index(split):
        # Explicit dtype: torch.tensor([]) would be float32 and unusable as an index.
        return torch.tensor([int(idx) for idx in split.get('nodes', [])], dtype=torch.long)

    if kg_explorer.valid_data and kg_explorer.test_data and train_data.get('nodes'):
        train_idx = split_index(train_data)
        valid_idx = split_index(kg_explorer.valid_data)
        test_idx = split_index(kg_explorer.test_data)
    else:
        # No predefined splits: random 68/12/20 split of the labelled nodes.
        nodes = np.array([int(node) for node in labels_dict])
        train_nodes, test_nodes = train_test_split(nodes, test_size=0.2, random_state=42)
        train_nodes, valid_nodes = train_test_split(train_nodes, test_size=0.15, random_state=42)
        train_idx = torch.as_tensor(train_nodes, dtype=torch.long)
        valid_idx = torch.as_tensor(valid_nodes, dtype=torch.long)
        test_idx = torch.as_tensor(test_nodes, dtype=torch.long)

    def dense_matrix(features):
        if not features:
            return None
        dim = next(iter(features.values())).shape[0]
        # float32 matches the encoders' output and halves memory vs. float64.
        matrix = np.zeros((num_nodes, dim), dtype=np.float32)
        for node_id, embedding in features.items():
            matrix[int(node_id)] = embedding
        return matrix

    text_embeddings = dense_matrix(text_features)
    image_embeddings = dense_matrix(image_features)

    return edge_index, labels, train_idx, valid_idx, test_idx, text_embeddings, image_embeddings, label_to_idx

def train_gnn(kg_explorer, epochs=30, lr=0.001, hidden_dim=128, batch_size=512):
    """Train a Graph Neural Network for node classification

    Training is full-batch: every epoch runs the model over the whole graph.
    ``batch_size`` is accepted for API compatibility but is currently unused.
    The model with the best validation accuracy is kept; if validation accuracy
    never exceeds 0 (e.g. an empty validation split), the final model is used.

    Returns:
        (best_model, test_acc)
    """
    print("Starting GNN training...")
    text_features, image_features = extract_features(kg_explorer)
    edge_index, labels, train_idx, valid_idx, test_idx, text_embeddings, image_embeddings, label_to_idx = \
        prepare_graph_data(kg_explorer, text_features, image_features)

    device = kg_explorer.device
    num_nodes = len(kg_explorer.e2i)
    # One logit per known label (at least one so the classifier is well-formed).
    num_classes = max(len(label_to_idx), 1)
    model = MultimodalGNN(num_nodes, hidden_dim, num_classes, text_embeddings, image_embeddings).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

    edge_index = edge_index.to(device)
    labels = labels.to(device)
    train_idx = train_idx.to(device)
    valid_idx = valid_idx.to(device)
    test_idx = test_idx.to(device)

    best_val_acc = 0
    best_model = None
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(edge_index)
        loss = F.cross_entropy(out[train_idx], labels[train_idx])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            pred = model(edge_index).argmax(dim=1)
            train_acc = (pred[train_idx] == labels[train_idx]).float().mean().item()
            val_acc = (pred[valid_idx] == labels[valid_idx]).float().mean().item()
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model = copy.deepcopy(model)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    if best_model is None:
        # No epoch improved on 0 validation accuracy (or epochs == 0).
        best_model = model

    best_model.eval()
    with torch.no_grad():
        pred = best_model(edge_index).argmax(dim=1)
        test_acc = (pred[test_idx] == labels[test_idx]).float().mean().item()
    y_true = labels[test_idx].cpu().numpy()
    y_pred = pred[test_idx].cpu().numpy()
    test_f1 = f1_score(y_true, y_pred, average='weighted')
    # labels= is required: with target_names alone sklearn raises whenever a
    # class is missing from the test split.
    class_report = classification_report(
        y_true,
        y_pred,
        labels=list(range(num_classes)),
        target_names=[f"Class {i}" for i in range(num_classes)],
        zero_division=0,
    )
    print(f"\nTest Accuracy: {test_acc:.4f}")
    print(f"Test F1 Score (weighted): {test_f1:.4f}")
    print(class_report)

    visualize_results(kg_explorer, best_model, edge_index, labels, test_idx, label_to_idx)
    return best_model, test_acc

def visualize_results(kg_explorer, model, edge_index, labels, test_idx, label_to_idx):
    """Visualize the results of the GNN model

    Saves a t-SNE scatter of the test nodes' final hidden representations
    ('node_embeddings_tsne.png') and a test-set confusion matrix
    ('confusion_matrix.png'). The model's parameters are not modified.

    Args:
        kg_explorer: unused; kept for API compatibility.
        model: a trained MultimodalGNN.
        edge_index: graph connectivity passed to the GCN layers.
        labels: per-node class indices.
        test_idx: indices of the test nodes.
        label_to_idx: label name -> class index, used for legend text.
    """
    model.eval()
    with torch.no_grad():
        # Fuse features out-of-place. Assigning into x while x *is*
        # model.node_emb.weight would permanently overwrite the trained embeddings.
        x = model.node_emb.weight
        if model.has_text:
            text_nodes = model.text_nodes.to(x.device)
            x = x.index_add(0, text_nodes, model.text_projection(model.text_embeddings))
        if model.has_image:
            image_nodes = model.image_nodes.to(x.device)
            x = x.index_add(0, image_nodes, model.image_projection(model.image_embeddings))
        x = F.relu(model.conv1(x, edge_index))
        x = F.relu(model.conv2(x, edge_index))
        # In eval mode forward() is classifier(x) on these activations (dropout
        # is a no-op), so reuse them; this also keeps predictions out of autograd.
        predictions = model.classifier(x).argmax(dim=1).cpu().numpy()
        embeddings = x.cpu().numpy()

    test_nodes = test_idx.cpu().numpy()
    test_embeddings = embeddings[test_nodes]
    test_labels = labels[test_idx].cpu().numpy()
    test_predictions = predictions[test_nodes]
    n_test = len(test_nodes)
    idx_to_label = {v: k for k, v in label_to_idx.items()}

    if n_test >= 2:
        # sklearn requires perplexity < n_samples; 30.0 is TSNE's default.
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30.0, n_test - 1))
        reduced_embeddings = tsne.fit_transform(test_embeddings)

        plt.figure(figsize=(12, 10))
        for class_idx in sorted(set(test_labels)):
            mask = test_labels == class_idx
            plt.scatter(reduced_embeddings[mask, 0], reduced_embeddings[mask, 1], label=str(idx_to_label.get(class_idx, f'Class {class_idx}')), alpha=0.7)
        plt.title("t-SNE Visualization of Node Embeddings", fontsize=16)
        plt.xlabel("Dimension 1", fontsize=12)
        plt.ylabel("Dimension 2", fontsize=12)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.savefig('node_embeddings_tsne.png', dpi=300)
        plt.close()
    else:
        print("Skipping t-SNE plot: fewer than 2 test nodes.")

    if n_test == 0:
        print("Skipping confusion matrix: no test nodes.")
        return

    plt.figure(figsize=(10, 8))
    conf_matrix = confusion_matrix(test_labels, test_predictions)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
    plt.title("Confusion Matrix", fontsize=16)
    plt.xlabel("Predicted Label", fontsize=12)
    plt.ylabel("True Label", fontsize=12)
    plt.savefig('confusion_matrix.png', dpi=300)
    plt.close()

def jointly_finetune(kg_explorer, gnn_model, edge_index, labels, train_idx, valid_idx, test_idx, epochs=5):
    """Jointly fine-tune the CLIP and BERT models with GNN

    Steps:

    * Bundle ``gnn_model`` with pre-trained ``bert-base-uncased`` and
      ``openai/clip-vit-base-patch32`` models in a single module.
    * Train full-batch for ``epochs`` steps. The GNN uses lr 1e-3; the BERT
      encoder and CLIP vision encoder use lr 3e-5; both embedding layers are
      frozen.
    * Keep the weights with the best validation accuracy and print their
      test accuracy.

    NOTE(review): ``JointModel.forward`` only runs the GNN, so BERT and CLIP
    receive no gradients and ``optimizer.step()`` leaves them unchanged. Real
    joint training would need to re-encode the raw text/images inside
    ``forward``. Confirm the intended design.

    Args:
        kg_explorer: Provides ``device``, on which training runs.
        gnn_model: Module called as ``gnn_model(edge_index)`` that returns
            per-node class logits. It is trained in place.
        edge_index: Graph connectivity tensor.
        labels: Tensor of per-node class indices.
        train_idx, valid_idx, test_idx: Node-index tensors for each split.
        epochs: Number of full-batch optimisation steps.

    Returns:
        The joint model loaded with the best-validation weights, or with the
        current weights when ``epochs`` is 0.
    """
    print("Starting joint fine-tuning of pre-trained models...")
    device = kg_explorer.device
    bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

    class JointModel(nn.Module):
        """Holds the GNN together with the BERT and CLIP encoders."""

        def __init__(self, gnn_model, bert_model, clip_model):
            super().__init__()
            self.gnn = gnn_model
            self.bert = bert_model
            self.clip = clip_model

        def forward(self, edge_index, text_inputs=None, image_inputs=None):
            # text_inputs / image_inputs are accepted but ignored: only the
            # GNN contributes to the output.
            return self.gnn(edge_index)

    joint_model = JointModel(gnn_model, bert_model, clip_model).to(device)

    # Freeze the embedding layers; only the encoders are handed to Adam.
    for param in bert_model.embeddings.parameters():
        param.requires_grad = False
    for param in clip_model.vision_model.embeddings.parameters():
        param.requires_grad = False

    optimizer = torch.optim.Adam([
        {'params': gnn_model.parameters(), 'lr': 0.001},
        {'params': bert_model.encoder.parameters(), 'lr': 0.00003},
        {'params': clip_model.vision_model.encoder.parameters(), 'lr': 0.00003},
    ])

    edge_index = edge_index.to(device)
    labels = labels.to(device)
    train_idx = train_idx.to(device)
    valid_idx = valid_idx.to(device)
    test_idx = test_idx.to(device)

    # Start below any achievable accuracy so the first epoch is always
    # recorded (starting at 0 with ">" could leave no best model at all).
    best_val_acc = -1.0
    best_state = None
    for epoch in range(epochs):
        joint_model.train()
        optimizer.zero_grad()
        out = joint_model(edge_index)
        loss = F.cross_entropy(out[train_idx], labels[train_idx])
        loss.backward()
        optimizer.step()

        joint_model.eval()
        with torch.no_grad():
            pred = joint_model(edge_index).argmax(dim=1)
            train_acc = (pred[train_idx] == labels[train_idx]).float().mean().item()
            val_acc = (pred[valid_idx] == labels[valid_idx]).float().mean().item()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Snapshot the weights on CPU. Deep-copying the module would
            # duplicate BERT and CLIP on the training device.
            best_state = {k: v.detach().cpu().clone() for k, v in joint_model.state_dict().items()}
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    if best_state is not None:
        joint_model.load_state_dict(best_state)
    best_model = joint_model

    best_model.eval()
    with torch.no_grad():
        pred = best_model(edge_index).argmax(dim=1)
        test_acc = (pred[test_idx] == labels[test_idx]).float().mean().item()
    print(f"\nJoint Fine-tuning Test Accuracy: {test_acc:.4f}")
    return best_model

def main():
    """Run the full pipeline: EDA, GNN training and optional joint fine-tuning.

    The dataset is read from ``./dmg777k``. The bonus joint fine-tuning step
    only runs when the interactive prompt is answered with ``y``
    (case-insensitive).
    """
    explorer = KGBenchExplorer("./dmg777k")

    print("=== Running Exploratory Data Analysis ===")
    explorer.run_complete_eda()

    print("\n=== Training Multimodal GNN ===")
    trained_gnn, _gnn_test_acc = train_gnn(explorer, epochs=10)

    reply = input("\nDo you want to run the BONUS joint fine-tuning? (y/n): ")
    wants_bonus = reply.lower() == 'y'
    if wants_bonus:
        print("\n=== Running Bonus Task: Joint Fine-tuning ===")
        text_feats, image_feats = extract_features(explorer)
        (graph_edges, node_labels, train_split, valid_split, test_split,
         _text_emb, _image_emb, _label_map) = prepare_graph_data(explorer, text_feats, image_feats)
        jointly_finetune(
            explorer, trained_gnn, graph_edges, node_labels,
            train_split, valid_split, test_split, epochs=2,
        )
        print("Joint fine-tuning completed!")

    print("\n=== Task Completed ===")

# Run the pipeline only when executed as a script, not on import.
# (Dunder names are required: `_name_` is undefined and raises NameError.)
if __name__ == "__main__":
    main()

Looking in links: https://data.pyg.org/whl/torch-{torch._version_}.html
Collecting torch-scatter
  Using cached torch_scatter-2.1.2.tar.gz (108 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-sparse
  Using cached torch_sparse-0.6.18.tar.gz (209 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-cluster
  Using cached torch_cluster-1.6.3.tar.gz (54 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-spline-conv
  Using cached torch_spline_conv-1.2.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[?25hBuildin

AttributeError: module 'torch_geometric' has no attribute '_version_'