In [1]:
import json
from pathlib import Path
import os
import re

In [2]:
# Read the JSON mapping (file name -> list of command strings) used by the cells below.
with open("/workdir/onnx-mlir-seeds/command_mapping.json", "r") as mapping_file:
    cmd_string_mapping = json.loads(mapping_file.read())

In [4]:
def get_cmd(text):
    """Extract the leading ``onnx-mlir-opt`` invocation from a command string.

    Args:
        text: Command text. Only text that starts with ``onnx-mlir-opt``
            followed by a single space is recognised.

    Returns:
        A ``(command, args)`` tuple, where ``command`` is ``"onnx-mlir-opt"``
        and ``args`` is the whitespace-split list of everything between the
        command name and the first ``|``, the first newline, or the end of
        ``text`` (whichever comes first). Returns ``None`` when ``text`` does
        not start with ``onnx-mlir-opt ``.
    """
    # Only onnx-mlir-opt is considered for now.
    # The argument list ends at a pipe, a newline, or the end of the string;
    # accepting end-of-string keeps commands that are not piped into another
    # tool from being silently discarded.
    pattern = r"^(onnx-mlir-opt) (.*?)(?:[|\n]|$)"
    # `^` without re.MULTILINE can only match at position 0, so there is at
    # most one match and re.match is sufficient.
    match = re.match(pattern, text)
    if match is None:
        return None
    command, args = match.groups()
    return command, args.split()

# Build a mapping from file name to the parsed (command, args) of its RUN line.
cmd_mapping = dict()
for filename, command_strings in cmd_string_mapping.items():
    # Concatenate the text that follows each `// RUN:` marker into one string.
    full_string = ""
    for cmd_str in command_strings:
        matches = re.findall(r"//\s+RUN:\s+(\S.*)", cmd_str)
        if len(matches) < 1:
            # Stop at the first string that carries no RUN directive.
            break
        if len(matches) > 1:
            raise ValueError("Expected at most one match")
        full_string += matches[0]
    # Strip backslashes (presumably line-continuation markers -- TODO confirm).
    full_string = full_string.replace("\\", "")
    try:
        cmd_mapping[filename] = get_cmd(full_string)
    except Exception:
        # Best effort: report the offending file and keep going. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate so the cell can still be interrupted.
        print(filename)

# Keep only the files for which get_cmd recognised a command.
cmd_mapping = {file: cmdargs for file, cmdargs in cmd_mapping.items() if cmdargs is not None}
cmd_mapping

{'Gemm.mlir': ('onnx-mlir-opt',
  ['--convert-onnx-to-stablehlo', '%s', '-split-input-file']),
 'imperfectly_nested_stmts.mlir': ('onnx-mlir-opt',
  ['-O3', '--convert-krnl-to-affine', '%s', '-split-input-file']),
 'Split_with_canonicalize.mlir': ('onnx-mlir-opt',
  ['--shape-inference',
   '--convert-onnx-to-krnl',
   '--canonicalize',
   '%s',
   '-split-input-file']),
 'Pad_with_canonicalize.mlir': ('onnx-mlir-opt',
  ['--shape-inference',
   '--convert-onnx-to-krnl',
   '--canonicalize',
   '%s',
   '-split-input-file']),
 'onnx_shape_inference_optional.mlir': ('onnx-mlir-opt',
  ['--shape-inference', '%s']),
 'mul-2.mlir': ('onnx-mlir-opt',
  ['--mcpu=z16',
   '--maccel=NNPA',
   '--convert-zhigh-to-onnx',
   '%s',
   '-split-input-file']),
 'relu-2.mlir': ('onnx-mlir-opt',
  ['--mcpu=z16',
   '--maccel=NNPA',
   '--convert-zhigh-to-onnx',
   '%s',
   '-split-input-file']),
 'matmul-1.mlir': ('onnx-mlir-opt',
  ['--mcpu=z16',
   '--maccel=NNPA',
   '--shape-inference',
   '--conve

In [5]:
# Dialect names recognised when a flag's leading word is compared against this list.
dialects = ["onnx", "krnl", "zhigh"]

In [6]:
# Group every flag seen in the parsed commands under the dialect name it mentions,
# keeping first-seen order and skipping duplicates, then save the result as JSON.
dialect_associations = {}
for _tool, flags in cmd_mapping.values():
    for flag in flags:
        # Flags of the form -convert-<name>-to / --convert-<name>-to are filed under <name>.
        conversion = re.match(r"--?convert-([a-z]+)-to", flag)
        if conversion is not None:
            bucket = dialect_associations.setdefault(conversion.group(1), [])
            if flag not in bucket:
                bucket.append(flag)
        # Flags whose leading lowercase word is a listed dialect are filed under that word.
        leading_word = re.match(r"--?([a-z]+)", flag)
        if leading_word is not None and leading_word.group(1) in dialects:
            bucket = dialect_associations.setdefault(leading_word.group(1), [])
            if flag not in bucket:
                bucket.append(flag)

with open("/workdir/mlir-eval/onnx/dialect-associations.json", "w") as out_file:
    json.dump(dialect_associations, out_file)