In [None]:
import os
import javalang
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset, Data, Dataset
import javalang

In [None]:
import antlr4
from JavaLexer import JavaLexer
from JavaParser import JavaParser
from antlr4.tree.Tree import ParseTreeVisitor

# This visitor will walk the AST and create a graph representation
class ASTGraphCreator(ParseTreeVisitor):
    """Flatten a parse tree into a node list and a parent -> child edge list.

    Only nodes that have at least one child are recorded; childless nodes
    are skipped entirely. Recorded nodes are numbered 0, 1, 2, ... in
    pre-order, and every index stored in ``edge_index`` refers to an entry
    of ``graph``.
    """

    def __init__(self):
        # (node_index, node_type) tuples, appended in pre-order.
        self.graph = []
        # Index of the most recently recorded node; -1 means "none yet".
        self.node_index = -1
        # (parent_index, child_index) tuples.
        self.edge_index = []

    def visit(self, node):
        """Record ``node`` and, recursively, every descendant that has children.

        Args:
            node: A tree node exposing ``getChildCount()`` and ``getChild(i)``.

        Returns:
            None. Results accumulate in ``self.graph`` and ``self.edge_index``.
        """
        if node.getChildCount() == 0:
            return

        # Enter the node
        self.node_index += 1
        current_node_index = self.node_index
        node_type = type(node).__name__.replace('Context', '')
        self.graph.append((current_node_index, node_type))

        for i in range(node.getChildCount()):
            child = node.getChild(i)

            # A childless child is skipped by visit() and never receives an
            # index. Emitting an edge for it would point at whatever node is
            # created next (a wrong or duplicate target) or past the end of
            # the node list, so only link children that become nodes.
            if child.getChildCount() == 0:
                continue

            # The child is recorded first thing inside visit(), so it takes
            # the next free index.
            self.edge_index.append((current_node_index, self.node_index + 1))
            self.visit(child)

    def get_graph(self):
        """Return the ``(graph, edge_index)`` lists built so far."""
        return self.graph, self.edge_index


def parse_java_to_graph(java_code):
    """Parse Java source text and return its graph representation.

    The text is lexed with ``JavaLexer``, parsed with ``JavaParser`` starting
    at the ``compilationUnit`` rule, and the resulting tree is walked by an
    ``ASTGraphCreator``.

    Args:
        java_code: Java source code as a string.

    Returns:
        The ``(graph, edge_index)`` pair returned by
        ``ASTGraphCreator.get_graph()``.
    """
    char_stream = antlr4.InputStream(java_code)
    token_stream = antlr4.CommonTokenStream(JavaLexer(char_stream))
    parse_tree = JavaParser(token_stream).compilationUnit()

    graph_builder = ASTGraphCreator()
    graph_builder.visit(parse_tree)
    return graph_builder.get_graph()

# Example usage
# java_code_example = """
# public class HelloWorld {
#     public static void main(String[] args) {
#         System.out.println("Hello, world!");
#     }
# }
# """
# graph, edges = parse_java_to_graph(java_code_example)
# print(graph)
# print(edges)

In [None]:
import glob


class JavaASTGraphDataset(Dataset):
    """Dataset of ``*.java`` files, each converted to a graph on access.

    Files are discovered recursively under ``root_dir``. Every item is a
    ``Data`` object with:

    * ``x``: one constant feature (1.0) per node, shape ``[num_nodes, 1]``
    * ``edge_index``: parent -> child edges, shape ``[2, num_edges]``
    * ``y``: a single label, 1 if the substring ``"plagiarized"`` occurs
      anywhere in the file path (directories included), otherwise 0

    Args:
        root_dir: Directory that is searched recursively for ``*.java`` files.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # glob makes no promise about ordering; sort so that a given index
        # maps to the same file on every run and platform.
        self.file_paths = sorted(
            glob.glob(os.path.join(root_dir, '**', '*.java'), recursive=True)
        )

    def __len__(self):
        """Return the number of Java files found."""
        return len(self.file_paths)

    def len(self):
        """Same as ``__len__``.

        Provided because the PyTorch Geometric ``Dataset`` API names its size
        hook ``len()`` (it is abstract in some releases).
        """
        return len(self)

    def __getitem__(self, idx):
        """Read, parse and convert the file at position ``idx``.

        Args:
            idx: Index into ``self.file_paths``.

        Returns:
            A ``Data`` object as described in the class docstring.
        """
        file_path = self.file_paths[idx]
        with open(file_path, 'r', encoding='utf-8') as java_file:
            java_code = java_file.read()

        # Parse the Java code into a tree and then into node/edge lists.
        nodes, edges = parse_java_to_graph(java_code)

        node_features = torch.ones((len(nodes), 1), dtype=torch.float)

        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            # torch.tensor([]) would be 1-D with shape [0]; keep the
            # [2, num_edges] layout even when the graph has no edges.
            edge_index = torch.empty((2, 0), dtype=torch.long)

        label = torch.tensor([1 if "plagiarized" in file_path else 0], dtype=torch.long)

        return Data(x=node_features, edge_index=edge_index, y=label)

    def get(self, idx):
        """Same as ``__getitem__``.

        Provided because the PyTorch Geometric ``Dataset`` API names its item
        hook ``get()`` (it is abstract in some releases).
        """
        return self[idx]



In [None]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class SimpleGNN(torch.nn.Module):
    """Two-layer GCN followed by mean pooling and a linear classifier.

    Produces one row of log-probabilities per graph, so the output lines up
    with graph-level labels (one entry of ``data.y`` per graph).

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of both GCN layers.
        output_dim: Number of classes.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Compute per-graph log-probabilities.

        Args:
            data: Object exposing ``x`` and ``edge_index`` and, for a
                mini-batch, ``batch`` (node -> graph assignment). When
                ``batch`` is missing or ``None`` all nodes are treated as one
                graph.

        Returns:
            Tensor of shape ``[num_graphs, output_dim]`` with log-softmax
            scores.
        """
        # Imported here so this block does not depend on edits to other cells.
        from torch_geometric.nn import global_mean_pool

        x, edge_index = data.x, data.edge_index

        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)
            num_graphs = 1
        else:
            # Passing the graph count keeps one output row per graph even if
            # the last graph in the batch has no nodes.
            num_graphs = getattr(data, 'num_graphs', None)

        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, edge_index))

        # Collapse node embeddings to one vector per graph. Without this the
        # model emits one prediction per node, which cannot be compared with
        # one label per graph.
        x = global_mean_pool(x, batch, size=num_graphs)

        x = F.dropout(x, training=self.training)
        x = self.lin(x)

        return F.log_softmax(x, dim=1)


In [None]:

# Location of the Java corpus. Set the JAVA_DATASET_ROOT environment variable
# to use a different copy; otherwise the original local path is used.
dataset_root = os.environ.get('JAVA_DATASET_ROOT', r'D:\hdu\gnn\dataset')
dataset = JavaASTGraphDataset(root_dir=dataset_root)


In [None]:
# Loader over the full dataset: shuffled mini-batches of 4 items.
# NOTE(review): no later cell visible here uses this loader (training builds
# its own train/val loaders) -- confirm whether it is still needed.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
def train(model, loader, optimizer):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: Module called as ``model(data)``; must return log-probabilities
            suitable for ``F.nll_loss``.
        loader: Iterable of batches, each exposing the targets as ``data.y``.
        optimizer: Optimizer updating ``model``'s parameters.

    Returns:
        The sum of the per-batch losses divided by the number of batches, or
        ``0.0`` if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in loader:
        optimizer.zero_grad()
        out = model(data)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Counting batches (rather than calling len(loader)) also supports plain
    # iterables, and the guard avoids a ZeroDivisionError on an empty loader.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def evaluate(model, loader):
    """Return the fraction of correct predictions over ``loader``.

    Args:
        model: Module called as ``model(data)``; the predicted class is the
            argmax over dimension 1 of its output.
        loader: Iterable of batches exposing targets as ``data.y``; must also
            have a ``dataset`` attribute whose length is the denominator.

    Returns:
        ``correct / len(loader.dataset)`` as a float in ``[0, 1]``, or ``0.0``
        when the dataset is empty.
    """
    model.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for data in loader:
            pred = model(data).argmax(dim=1)
            correct += pred.eq(data.y).sum().item()

    total = len(loader.dataset)
    if total == 0:
        return 0.0
    return correct / total



In [None]:

# Hyperparameters
input_dim = 1    # features per node
hidden_dim = 64
output_dim = 2   # number of classes

# Fix the seed so weight initialisation, the train/validation split,
# shuffling and dropout are repeatable from a fresh kernel.
seed = 42
torch.manual_seed(seed)

model = SimpleGNN(input_dim, hidden_dim, output_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# 80/20 train/validation split.
num_train = int(len(dataset) * 0.8)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [num_train, len(dataset) - num_train],
    generator=torch.Generator().manual_seed(seed),
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Sanity check: every edge endpoint must refer to an existing node.
for data in train_loader:
    if data.edge_index.numel() == 0:
        # .max() raises on an empty tensor, and there is nothing to check.
        continue
    print("Max node index:", data.edge_index.max().item())
    print("Number of nodes:", data.x.size(0))
    assert data.edge_index.max().item() < data.x.size(0), "Edge index out of bounds!"

# Per-epoch history, kept so the curves can be plotted afterwards.
train_losses = []
val_accuracies = []

for epoch in range(20):  # Number of epochs
    train_loss = train(model, train_loader, optimizer)
    val_acc = evaluate(model, val_loader)
    train_losses.append(train_loss)
    val_accuracies.append(val_acc)
    print(f'Epoch: {epoch+1}, Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f}')

In [None]:
import numpy as np
import matplotlib.pyplot as plt


def drawPlots(epochs, accuracies, losses):
    """Show accuracy and loss curves side by side.

    Args:
        epochs: Number of epochs; the x-axis runs from 1 to ``epochs``.
        accuracies: One accuracy value per epoch (left panel).
        losses: One loss value per epoch (right panel).

    NOTE(review): the left y-axis is labelled 'Accuracy (%)'; confirm that
    callers pass percentages rather than fractions.
    """
    # (subplot position, series, line label, colour, y-axis label, title)
    panels = (
        (1, accuracies, 'Accuracy', 'blue', 'Accuracy (%)', 'Accuracy over Epochs'),
        (2, losses, 'Loss', 'red', 'Loss', 'Loss over Epochs'),
    )

    plt.figure(figsize=(15, 5))

    for position, series, line_label, colour, y_label, title in panels:
        plt.subplot(1, 2, position)
        plt.plot(range(1, epochs+1), series, label=line_label, color=colour)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.grid(True)

    plt.tight_layout()
    plt.show()


#