In [1]:
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

In [2]:
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import LanguageParser
from langchain_text_splitters import Language
repo_path = "/home/haotian/Splitted_code"

loader = GenericLoader.from_filesystem(
    repo_path,
    glob="**/*",
    suffixes=[".cpp"],
    # exclude=["**/non-utf8-encoding.py"],
    parser=LanguageParser(language=Language.CPP, parser_threshold=500),
)
documents = loader.load()
repo_path_2 = "/home/haotian/book"

loader2 = GenericLoader.from_filesystem(
    repo_path_2,
    glob="**/*",
    suffixes=[".txt"],
    # exclude=["**/non-utf8-encoding.py"],
    parser=LanguageParser(language=Language.CPP, parser_threshold=500),
)
documents2 = loader2.load()



In [6]:
documents = documents + documents2


In [7]:
len(documents)
from langchain_text_splitters import RecursiveCharacterTextSplitter

python_splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.CPP, chunk_size=1000, chunk_overlap=200
)
texts = python_splitter.split_documents(documents)
len(texts)

2452

In [8]:
from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
import numpy as np
np.array(model.encode(texts[0].page_content).shape)

array([768])

In [10]:
txt = []
for text in texts:
    txt.append(np.array(model.encode(text.page_content)))

In [12]:
query = '''
void kernel_adi(int tsteps,int n,float u[60][60],float v[60][60],float p[60][60],float q[60][60])
{
  float DX;
  float DY;
  float DT;
  float B1;
  float B2;
  float mul1;
  float mul2;
  float a;
  float b;
  float c;
  float d;
  float e;
  float f;
  int t;
  int i;
  int j;
  
{
    DX = 1.0 / ((float )60);
    DY = 1.0 / ((float )60);
    DT = 1.0 / ((float )40);
    B1 = 2.0;
    B2 = 1.0;
    mul1 = B1 * DT / (DX * DX);
    mul2 = B2 * DT / (DY * DY);
    a = -mul1 / 2.0;
    b = 1.0 + mul1;
    c = a;
    d = -mul2 / 2.0;
    e = 1.0 + mul2;
    f = d;
    
    
    
    for (t = 1; t <= 40; t++) {
//Column Sweep
      
      
      
      for (i = 1; i < 60 - 1; i++) {
        v[0][i] = 1.0;
        p[i][0] = 0.0;
        q[i][0] = v[0][i];
        
        for (j = 1; j < 60 - 1; j++) {
          p[i][j] = -c / (a * p[i][j - 1] + b);
          q[i][j] = (-d * u[j][i - 1] + (1.0 + 2.0 * d) * u[j][i] - f * u[j][i + 1] - a * q[i][j - 1]) / (a * p[i][j - 1] + b);
        }
        v[60 - 1][i] = 1.0;
        
        for (j = 0; j <= 57; j++) {
          int _in_j_0 = 58 + -1 * j;
          v[_in_j_0][i] = p[i][_in_j_0] * v[_in_j_0 + 1][i] + q[i][_in_j_0];
        }
        j = 1 + -1;
      }
//Row Sweep
      
      
      
      for (i = 1; i < 60 - 1; i++) {
        u[i][0] = 1.0;
        p[i][0] = 0.0;
        q[i][0] = u[i][0];
        
        for (j = 1; j < 60 - 1; j++) {
          p[i][j] = -f / (d * p[i][j - 1] + e);
          q[i][j] = (-a * v[i - 1][j] + (1.0 + 2.0 * a) * v[i][j] - c * v[i + 1][j] - d * q[i][j - 1]) / (d * p[i][j - 1] + e);
        }
        u[i][60 - 1] = 1.0;
        
        for (j = 0; j <= 57; j++) {
          int _in_j = 58 + -1 * j;
          u[i][_in_j] = p[i][_in_j] * u[i][_in_j + 1] + q[i][_in_j];
        }
        j = 1 + -1;
      }
    }
  }
}

'''
# top k search
import numpy as np
k = 5
query_embed = model.encode(query)
score = []
for t in txt:
    score.append(query_embed.T @ t )
    # score.append(np.linalg.norm(query_embed - t))
indexed_list = sorted(enumerate(score), key=lambda x: x[1], reverse= True)
# indexed_list = sorted(enumerate(score), key=lambda x: x[1], reverse= False)
sorted_indices = [index for index, value in indexed_list]
top_k_index = sorted_indices[:k]
top_k_doc = []
for i in top_k_index:
    top_k_doc.append(texts[i])
top_k_doc
r_doc = []
for t in top_k_doc:
    r_doc.append(t.page_content)
r_doc = r_doc[:3]

In [15]:
prompt = f'''Here are some examples of matrix multiplication code with pragma hls lines added
{",".join(r_doc)},
Now try your best to optimize by inserting Pragma HLS lines:{query}
Only output your optimized cod, do not output anything else'''
prompt

'Here are some examples of matrix multiplication code with pragma hls lines added\nvoid krnl_vmult(uint32_t* in1, uint32_t* in2, uint32_t* out, int vSize) {\n    static hls::stream<uint32_t> inStream1("input_stream_1");\n    static hls::stream<uint32_t> inStream2("input_stream_2");\n    static hls::stream<uint32_t> outStream("output_stream");\n#pragma HLS INTERFACE m_axi port = in1 bundle = gmem0 depth = 4096\n#pragma HLS INTERFACE m_axi port = in2 bundle = gmem1 depth = 4096\n#pragma HLS INTERFACE m_axi port = out bundle = gmem0 depth = 4096\n\n#pragma HLS dataflow\n    // dataflow pragma instruct compiler to run following three APIs in parallel\n    read_input(in1, inStream1, vSize);\n    read_input(in2, inStream2, vSize);\n    compute_mult(inStream1, inStream2, outStream, vSize);\n    write_result(out, outStream, vSize);\n}\n},void firI1(data t ∗y, data t x);\nvoid firQ1(data t ∗y, data t x);\nvoid firI2(data t ∗y, data t x);\nvoid firQ2(data t ∗y, data t x);,void firI1(data t ∗y, d