-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.cpp
69 lines (64 loc) · 2.38 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include "Core/BlockPrep.hpp"
#include "Core/CPUEvaluator.hpp"
#include "Core/Compiler.hpp"
#include "Core/Parser.hpp"
#include "Core/PortableMemPool.hpp"
#include <iostream>
#include <memory>
#include <optional>
#include <string>
using namespace FunGPU;
int main(int argc, char **argv) {
auto memPool = std::make_shared<PortableMemPool>();
try {
cl::sycl::buffer<PortableMemPool> memPoolBuff(memPool,
cl::sycl::range<1>(1));
CPUEvaluator evaluator(memPoolBuff);
Index_t argvIndex = 1;
BlockPrep blockPrep(64, 32, 32, memPoolBuff);
while (true) {
const auto programPath = [&]() -> std::optional<std::string> {
if (argvIndex < argc) {
return std::string(argv[argvIndex++]);
}
std::cout << "Program to run(or q to quit): ";
std::string interactivePath;
std::cin >> interactivePath;
if (interactivePath == "q") {
return std::optional<std::string>();
}
return interactivePath;
}();
if (!programPath) {
break;
}
Parser parser(*programPath);
auto parsedResult = parser.ParseProgram();
Compiler compiler(parsedResult, memPoolBuff);
Compiler::ASTNodeHandle compiledResult;
try {
compiledResult = compiler.Compile();
} catch (const Compiler::CompileException &e) {
std::cerr << "Failed to compile " << *programPath << ": " << e.What()
<< std::endl;
continue;
}
std::cout << "Original compilation without any modifications: "
<< std::endl;
compiler.DebugPrintAST(compiledResult);
std::cout << std::endl;
std::cout << "Updated for block generation: " << std::endl;
compiledResult = blockPrep.PrepareForBlockGeneration(compiledResult);
compiler.DebugPrintAST(compiledResult);
std::cout << "Successfully compiled program " << *programPath
<< std::endl;
Index_t maxConcurrentBlockCount;
const auto programResult =
evaluator.EvaluateProgram(compiledResult, maxConcurrentBlockCount);
std::cout << programResult.m_data.floatVal << std::endl;
std::cout << "Max concurrent blocks: " << maxConcurrentBlockCount
<< std::endl;
compiler.DeallocateAST(compiledResult);
}
} catch (const cl::sycl::exception &e) {
std::cerr << "Sycl exception in main: " << e.what() << std::endl;
}
return 0;
}