In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as f
import pandas as pd
import time
from sparkkgml.kg import KG
from sparkkgml.motifWalks import MotifWalks

from pyspark.ml.classification import LinearSVC
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Get the active SparkSession, or create one with default settings.
# NOTE: no driver/executor memory or executor-count options are set here;
# add `.config(...)` calls on the builder if explicit resources are required.
spark = SparkSession.builder.getOrCreate()

In [None]:
# Predicates handed to KG as `skip_predicates`.
# NOTE(review): presumably this is the classification target, excluded so it
# does not appear in the graph — confirm against the KG class documentation.
skipped_predicates = ["http://dl-learner.org/carcinogenesis#isMutagenic"]

# Wrap the carcinogenesis ontology file in a KG helper bound to this Spark session
kg_instance = KG(
    location="./mutag/carcinogenesis.owl",
    skip_predicates=skipped_predicates,
    sparkSession=spark,
)

# Build the graph object that the walk step consumes
graph = kg_instance.createKG()

In [None]:
# Read the tab-separated train/test splits into pandas DataFrames
train_data = pd.read_csv("./mutag/trainingSet.tsv", sep="\t")
test_data = pd.read_csv("./mutag/testSet.tsv", sep="\t")

# Pull the entity identifiers ("bond") and their labels ("label_mutagenic")
# out of each split as plain Python lists
train_entities = train_data["bond"].tolist()
train_labels = train_data["label_mutagenic"].tolist()
test_entities = test_data["bond"].tolist()
test_labels = test_data["label_mutagenic"].tolist()

# Training entries first, then test entries, so entities[i] pairs with labels[i]
entities = [*train_entities, *test_entities]
labels = [*train_labels, *test_labels]

In [None]:
# Experiment grid: Word2Vec embedding sizes and motif-walk depths to evaluate
vector_sizes = [100, 200, 300, 400, 500]
depths = [2, 4, 6]
results = []

# Seed for the stochastic steps. PySpark's SparkContext has no `setSeed`
# method (calling it raises AttributeError), so the seed is passed explicitly
# to each component that accepts one: Word2Vec training and CV fold splitting.
SEED = 42

# Labelled entities as a Spark DataFrame. This does not depend on depth or
# vector size, so it is built once instead of on every inner iteration.
data = list(zip(entities, labels))
data_df = spark.createDataFrame(data, ["word", "label"])

# Evaluator used by cross-validation for model selection. "f1" is the default
# metric of MulticlassClassificationEvaluator; it is spelled out here so the
# selection criterion is visible.
# NOTE(review): selection uses F1 while the reported test metric is accuracy.
cv_evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="f1"
)

# Evaluator used to report accuracy on the held-out test entities
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy"
)

# Iterate over different depths for motif walks
for depth in depths:

    # Initialize MotifWalks with the KG instance and entities
    motifWalks_instance = MotifWalks(kg_instance, entities=entities, sparkSession=spark)
    # Perform motif walks on the graph with the specified depth
    paths_df = motifWalks_instance.motif_walk(graph, depth)

    # Iterate over different vector sizes for embeddings
    for vector_size in vector_sizes:

        # Generate word2Vec embeddings for the paths
        embeddings = motifWalks_instance.word2Vec_embeddings(
            paths_df,
            vector_size=vector_size,
            window_size=5,
            min_count=0,
            max_iter=5,
            step_size=0.025,
            num_partitions=1,
            seed=SEED,
            input_col="paths",
            output_col="vectors"
        )

        # Attach labels to the embeddings. The default join is an inner join,
        # so entities that have no embedding row are dropped here.
        combined_df = embeddings.join(data_df, "word")
        # Split data into training and test sets based on entity membership
        train_embeddings = combined_df.filter(f.col("word").isin(train_entities))
        test_embeddings = combined_df.filter(f.col("word").isin(test_entities))

        # Define the Linear Support Vector Classifier.
        # NOTE(review): `output_col` above is "vectors" but the features column
        # read here is "vector" — confirm this matches the schema returned by
        # `word2Vec_embeddings`.
        svm = LinearSVC(maxIter=10, labelCol="label", featuresCol="vector")

        # Regularization strengths 1e-3 ... 1e3, all as floats so the
        # double-typed regParam never receives a Python int
        reg_params = [10.0 ** i for i in range(-3, 4)]
        param_grid = ParamGridBuilder().addGrid(svm.regParam, reg_params).build()

        # Set up cross-validation to find the best regularization parameter
        crossval = CrossValidator(
            estimator=svm,
            estimatorParamMaps=param_grid,
            evaluator=cv_evaluator,
            numFolds=5,  # Number of folds for cross-validation
            seed=SEED    # Fixed seed so fold assignment is reproducible
        )

        # Fit the CrossValidator and get the best model
        cvModel = crossval.fit(train_embeddings)
        bestModel = cvModel.bestModel

        # Make predictions on the test data
        predictions = bestModel.transform(test_embeddings)

        # Evaluate the accuracy of the predictions
        accuracy = evaluator.evaluate(predictions)

        # Record the rounded accuracy; entries are ordered depth-major, then
        # by vector size, matching the loop nesting above
        results.append(round(accuracy, 3))
        # Include the depth so lines from different depth passes can be told apart
        print(f"Depth {depth}, Vector Size {vector_size}, Accuracy: {accuracy * 100:.4f}%")