Update iOS to use latest APIs in llm chat (octoml#178)
* ios downloader

* use dist as optional provided dir

* Update iOS app to new reload api

---------

Co-authored-by: Yaxing Cai <caiyaxing666@gmail.com>
tqchen and cyx-6 committed May 18, 2023
1 parent ff81bdb commit de7b5ab
Showing 19 changed files with 612 additions and 91 deletions.
5 changes: 4 additions & 1 deletion build.py
@@ -67,7 +67,7 @@ def _parse_args():

parsed.export_kwargs = {}
parsed.lib_format = "so"

parsed.system_lib_prefix = None
parsed = _setup_model_path(parsed)

parsed.db_path = parsed.db_path or os.path.join("log_db", parsed.model)
@@ -248,6 +248,9 @@ def dump_default_mlc_chat_config(args):

def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
target_kind = args.target_kind
if args.system_lib_prefix:
mod_deploy = mod_deploy.with_attrs({"system_lib_prefix": args.system_lib_prefix})

debug_dump_script(mod_deploy, "mod_before_build.py", args)
if target_kind != "cpu":
if os.path.exists(args.db_path):
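Note: system_lib_prefix exists because iOS statically links the model library into the app binary. As a hedged Swift-side sketch of the assumed connection (the backend bridge object and its reload method appear in the ChatState.swift changes below; the rule that the lib name must match the prefix is an assumption, not shown in this diff):

    // Assumption: the lib name passed to reload corresponds to the
    // system_lib_prefix baked into the static library by build.py.
    let modelLib = "RedPajama-INCITE-Chat-3B-v1-q4f16_0"
    let modelPath = Bundle.main.bundlePath + "/dist/" + modelLib
    backend.reload(modelLib, modelPath: modelPath)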
52 changes: 31 additions & 21 deletions cpp/llm_chat.cc
@@ -467,7 +467,7 @@ class LLMChat {
*/
std::string RuntimeStatsText() {
std::ostringstream os;
os << "encode: " << std::setprecision(1) << std::fixed
os << "prefill: " << std::setprecision(1) << std::fixed
<< this->encode_total_tokens / this->encode_total_time << " tok/s"
<< ", decode: " << std::setprecision(1) << std::fixed
<< this->decode_total_tokens / this->decode_total_time << " tok/s";
@@ -518,6 +518,13 @@ class LLMChat {
ICHECK(fload_params) << "Cannot find env function vm.builtin.param_array_from_cache";
params_ = (*fload_params)("param", -1);

// after we get params, it is safe to simply clear the cached version
// as these params are referenced by params_
const PackedFunc* fclear_ndarray_cache =
tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.clear");
ICHECK(fclear_ndarray_cache) << "Cannot find env function vm.builtin.ndarray_cache.clear";
(*fclear_ndarray_cache)();

// Step 4. KV cache creation.
kv_cache_ = vm_->GetFunction("create_kv_cache")();

@@ -1103,63 +1110,68 @@ class LLMChatModule : public ModuleNode {
ICHECK_EQ(args.size(), 2);
chat_ = nullptr;
chat_ = std::make_unique<LLMChat>(LLMChat(device_));
(*fclear_ndarray_cache_)();
chat_->Reload(args[0], args[1]);
});
}

ICHECK(chat_ != nullptr);
if (name == "evaluate") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { chat_->Evaluate(); });
} else if (name == "unload") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { chat_ = nullptr; });
} else if (name == "evaluate") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->Evaluate(); });
} else if (name == "try_tokenizer") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { chat_->TryTokenizer(); });
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->TryTokenizer(); });
} else if (name == "encode") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 1);
chat_->EncodeStep(args[0]);
GetChat()->EncodeStep(args[0]);
});
} else if (name == "decode") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { chat_->DecodeStep(); });
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->DecodeStep(); });
} else if (name == "init_chat_legacy") {
// TODO: remove the legacy initialization func after updating app and web sides.
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 5);
chat_->InitChatLegacy(args[0], args[1], args[2], args[3], args[4]);
GetChat()->InitChatLegacy(args[0], args[1], args[2], args[3], args[4]);
});
} else if (name == "reset_chat") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
ICHECK_EQ(args.size(), 0);
chat_->ResetChat();
GetChat()->ResetChat();
});
} else if (name == "get_role0") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
*rv = chat_->conversation_.roles[0];
*rv = GetChat()->conversation_.roles[0];
});
} else if (name == "get_role1") {
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
*rv = chat_->conversation_.roles[1];
*rv = GetChat()->conversation_.roles[1];
});
} else if (name == "stopped") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = chat_->Stopped(); });
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = GetChat()->Stopped(); });
} else if (name == "get_message") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = chat_->GetMessage(); });
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = GetChat()->GetMessage(); });
} else if (name == "runtime_stats_text") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = chat_->RuntimeStatsText(); });
return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
*rv = GetChat()->RuntimeStatsText();
});
} else if (name == "reset_runtime_stats") {
return PackedFunc(
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { chat_->ResetRuntimeStats(); });
[this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->ResetRuntimeStats(); });
} else {
return PackedFunc(nullptr);
}
}

void Init(DLDevice device) { device_ = device; }

LLMChat* GetChat() {
ICHECK(chat_ != nullptr) << "Chat is not initialized via reload";
return chat_.get();
}

// TODO: legacy function to be removed
void InitLegacy(tvm::runtime::Module executable, std::unique_ptr<Tokenizer> tokenizer,
const tvm::runtime::String& param_path, DLDevice device) {
Expand Down Expand Up @@ -1217,8 +1229,6 @@ class LLMChatModule : public ModuleNode {
const char* type_key() const final { return "mlc.llm_chat"; }

private:
const PackedFunc* fclear_ndarray_cache_ =
tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.clear");
std::unique_ptr<LLMChat> chat_ = nullptr;
DLDevice device_;
};
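Note: the module now routes every chat function through GetChat(), so calling into the backend before a successful reload fails with "Chat is not initialized via reload" instead of dereferencing a null chat_, and the new "unload" function releases the current model explicitly. A rough sketch of the resulting contract from the Swift side (ChatModule is a stand-in name for the Objective-C++ bridge; the method names are the ones ChatState.swift calls below):

    // Sketch only, under the assumptions above.
    func switchModel(backend: ChatModule, modelLib: String, modelPath: String) -> String {
        backend.unload()   // chat_ becomes nullptr and the old weights are released
        // an encode()/decode() call here would now trip the GetChat() check
        backend.reload(modelLib, modelPath: modelPath)  // fresh LLMChat; the NDArray cache is cleared once params load
        return backend.runtimeStatsText()  // now reads "prefill: ... tok/s, decode: ... tok/s"
    }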
28 changes: 27 additions & 1 deletion ios/MLCChat.xcodeproj/project.pbxproj
@@ -7,6 +7,12 @@
objects = {

/* Begin PBXBuildFile section */
1453A4C92A1353F1001B909F /* ModelConfig in CopyFiles */ = {isa = PBXBuildFile; fileRef = 1453A4C72A1353D7001B909F /* ModelConfig */; };
1453A4CF2A1354B9001B909F /* StartView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CA2A1354B9001B909F /* StartView.swift */; };
1453A4D02A1354B9001B909F /* ModelView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CB2A1354B9001B909F /* ModelView.swift */; };
1453A4D12A1354B9001B909F /* StartState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CC2A1354B9001B909F /* StartState.swift */; };
1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CD2A1354B9001B909F /* ModelConfig.swift */; };
1453A4D32A1354B9001B909F /* ModelState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CE2A1354B9001B909F /* ModelState.swift */; };
C06A74E429F99E5500BC4BE6 /* LLMChat.mm in Sources */ = {isa = PBXBuildFile; fileRef = C06A74E329F99E5500BC4BE6 /* LLMChat.mm */; };
C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */ = {isa = PBXBuildFile; fileRef = C06A74E029F99C9F00BC4BE6 /* dist */; };
C06A74F429F9BE7A00BC4BE6 /* ThreadWorker.swift in Sources */ = {isa = PBXBuildFile; fileRef = C06A74F329F9BE7A00BC4BE6 /* ThreadWorker.swift */; };
@@ -25,6 +31,7 @@
dstPath = "";
dstSubfolderSpec = 7;
files = (
1453A4C92A1353F1001B909F /* ModelConfig in CopyFiles */,
C06A74F229F9A78800BC4BE6 /* dist in CopyFiles */,
);
runOnlyForDeploymentPostprocessing = 0;
@@ -42,6 +49,12 @@
/* End PBXCopyFilesBuildPhase section */

/* Begin PBXFileReference section */
1453A4C72A1353D7001B909F /* ModelConfig */ = {isa = PBXFileReference; lastKnownFileType = folder; name = ModelConfig; path = MLCChat/ModelConfig; sourceTree = "<group>"; };
1453A4CA2A1354B9001B909F /* StartView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = StartView.swift; path = MLCChat/StartView.swift; sourceTree = "<group>"; };
1453A4CB2A1354B9001B909F /* ModelView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = ModelView.swift; path = MLCChat/ModelView.swift; sourceTree = "<group>"; };
1453A4CC2A1354B9001B909F /* StartState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = StartState.swift; path = MLCChat/StartState.swift; sourceTree = "<group>"; };
1453A4CD2A1354B9001B909F /* ModelConfig.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = ModelConfig.swift; path = MLCChat/ModelConfig.swift; sourceTree = "<group>"; };
1453A4CE2A1354B9001B909F /* ModelState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; name = ModelState.swift; path = MLCChat/ModelState.swift; sourceTree = "<group>"; };
C06A74E029F99C9F00BC4BE6 /* dist */ = {isa = PBXFileReference; lastKnownFileType = folder; path = dist; sourceTree = "<group>"; };
C06A74E229F99E5500BC4BE6 /* MLCChat-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "MLCChat-Bridging-Header.h"; sourceTree = "<group>"; };
C06A74E329F99E5500BC4BE6 /* LLMChat.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = LLMChat.mm; sourceTree = "<group>"; };
@@ -70,6 +83,12 @@
C0D643A629F99A7F004DDAA4 = {
isa = PBXGroup;
children = (
1453A4CD2A1354B9001B909F /* ModelConfig.swift */,
1453A4CE2A1354B9001B909F /* ModelState.swift */,
1453A4CB2A1354B9001B909F /* ModelView.swift */,
1453A4CC2A1354B9001B909F /* StartState.swift */,
1453A4CA2A1354B9001B909F /* StartView.swift */,
1453A4C72A1353D7001B909F /* ModelConfig */,
C06A74E029F99C9F00BC4BE6 /* dist */,
C0D643B129F99A7F004DDAA4 /* MLCChat */,
C0D643B029F99A7F004DDAA4 /* Products */,
@@ -192,10 +211,15 @@
files = (
C06A74E429F99E5500BC4BE6 /* LLMChat.mm in Sources */,
C06A74F429F9BE7A00BC4BE6 /* ThreadWorker.swift in Sources */,
1453A4D12A1354B9001B909F /* StartState.swift in Sources */,
C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */,
C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */,
C0D643C329F99B07004DDAA4 /* ChatState.swift in Sources */,
1453A4D32A1354B9001B909F /* ModelState.swift in Sources */,
C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */,
1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */,
1453A4D02A1354B9001B909F /* ModelView.swift in Sources */,
1453A4CF2A1354B9001B909F /* StartView.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
@@ -381,10 +405,11 @@
ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements;
CODE_SIGN_IDENTITY = "Apple Development";
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_ASSET_PATHS = "\"MLCChat/Preview Content\"";
DEVELOPMENT_TEAM = 3D7CQ3D634;
DEVELOPMENT_TEAM = 3FR42MXLK9;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
"HEADER_SEARCH_PATHS[arch=*]" = (
@@ -420,6 +445,7 @@
);
PRODUCT_BUNDLE_IDENTIFIER = mlc.Chat;
PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE_SPECIFIER = "";
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_OBJC_BRIDGING_HEADER = "MLCChat/MLCChat-Bridging-Header.h";
SWIFT_VERSION = 5.0;
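Note: the new CopyFiles entry bundles the ModelConfig folder next to dist, so the new Swift views can read model configs at runtime. A small sketch, assuming both folders land at the bundle root the way the existing dist usage in ChatState.swift suggests:

    let distDir = Bundle.main.bundlePath + "/dist"           // model weights and libraries
    let configDir = Bundle.main.bundlePath + "/ModelConfig"  // bundled model configs (assumed location)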
82 changes: 50 additions & 32 deletions ios/MLCChat/ChatState.swift
@@ -17,8 +17,9 @@ struct MessageData: Hashable, Identifiable {
}

class ChatState : ObservableObject {
@Published var messages = [MessageData]()
@Published var infoText = ""
@Published var messages = [MessageData]();
@Published var infoText = "";
@Published var modelName = "";
@Published var inProgress = false;
@Published var unfinishedRespondRole = MessageRole.bot;
@Published var unfinishedRespondMessage = "";
@@ -30,32 +31,53 @@ class ChatState : ObservableObject {
private var stopRequested = false;
private var gpuVRAMDetectionPass = false;


init() {
threadWorker.qualityOfService = QualityOfService.userInteractive;
threadWorker.start()
self.systemInit()
// TODO(change to state dependent)
let modelLib = "RedPajama-INCITE-Chat-3B-v1-q4f16_0";
let modelPath = Bundle.main.bundlePath + "/dist/RedPajama-INCITE-Chat-3B-v1-q4f16_0";
self.mainReload(modelName: "RedPajama-3B", modelLib: modelLib, modelPath: modelPath)
}
func systemInit() {
let vram = os_proc_available_memory()
if (vram < (4000000000)) {
print(vram)
let errMsg = (
"Sorry, the system do not have 4GB memory as requested, " +
"so we cannot initialize chat module on this device."
)
self.messages.append(MessageData(role: MessageRole.bot, message: errMsg))
self.gpuVRAMDetectionPass = false
self.inProgress = true
return

// reset all chat state
func mainResetChat() {
self.messages = [MessageData]()
self.infoText = ""
self.unfinishedRespondMessage = ""
self.inProgress = false;
self.requestedReset = false;
}

func mainReload(modelName: String, modelLib: String, modelPath: String, estimatedMemReq : Int64 = 4000000000) {
if (self.inProgress) {
return;
}
self.gpuVRAMDetectionPass = true
self.mainResetChat()
self.gpuVRAMDetectionPass = false;
self.inProgress = true;
self.modelName = modelName;

threadWorker.push {[self] in
self.updateReply(role: MessageRole.bot, message: "[System] Initialize...")
backend.initialize()
backend.unload();
let vram = os_proc_available_memory()
if (vram < estimatedMemReq) {
let reqMem = String(
format: "%.1fGB", Double(estimatedMemReq) / Double(1 << 30)
)
let errMsg = (
"Sorry, the system does not have " + reqMem + " of memory as required, " +
"so we cannot initialize this model on this device."
)
self.messages.append(MessageData(role: MessageRole.bot, message: errMsg))
self.gpuVRAMDetectionPass = false
self.inProgress = true
return
}
self.gpuVRAMDetectionPass = true
backend.reload(modelLib, modelPath: modelPath)
self.updateReply(role: MessageRole.bot, message: "[System] Ready to chat")
self.commitReply()
self.markFinish()
@@ -80,18 +102,18 @@
}
self.commitReply()
self.reportSpeed(encodingSpeed: 1000, decodingSpeed: 1000)

self.markFinish()
}
}
}

func backendGenerate(prompt: String) {
assert(self.inProgress);
// generation needs to run on thread worker
threadWorker.push {[self] in
self.appendMessage(role: MessageRole.user, message: prompt)

backend.encode(prompt);
while (!backend.stopped()) {
assert(self.inProgress);
Expand All @@ -105,7 +127,7 @@ class ChatState : ObservableObject {
break;
}
}

self.commitReply()
let runtimeText: String = self.backend.runtimeStatsText()
DispatchQueue.main.sync { [runtimeText] in
@@ -123,7 +145,7 @@
self.stopRequested = false
self.backendGenerate(prompt: prompt)
}

func requestStop() {
if (!self.gpuVRAMDetectionPass) {
return
@@ -150,23 +172,19 @@
threadWorker.push {
self.backend.reset()
DispatchQueue.main.sync {
self.messages = [MessageData]()
self.infoText = ""
self.unfinishedRespondMessage = ""
self.inProgress = false;
self.requestedReset = false;
self.mainResetChat();
}
}
}

func reportSpeed(encodingSpeed: Float, decodingSpeed: Float) {
DispatchQueue.main.sync { [self, encodingSpeed, decodingSpeed] in
self.infoText = String(
format: "encode: %.1f tok/s, decode: %.1f tok/s", encodingSpeed, decodingSpeed
format: "prefill: %.1f tok/s, decode: %.1f tok/s", encodingSpeed, decodingSpeed
)
}
}

func markFinish() {
DispatchQueue.main.sync { [self] in
self.inProgress = false
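Note: mainReload replaces the old hard-wired systemInit, folding the memory check into every (re)load. A usage sketch mirroring the values hard-coded in init() above (estimatedMemReq defaults to 4000000000 bytes):

    let chat = ChatState()   // init() already schedules a reload of RedPajama-3B
    // Reloading another bundled model would look like this; note that
    // mainReload returns early while a previous load is still in progress.
    let modelLib = "RedPajama-INCITE-Chat-3B-v1-q4f16_0"
    chat.mainReload(modelName: "RedPajama-3B",
                    modelLib: modelLib,
                    modelPath: Bundle.main.bundlePath + "/dist/" + modelLib)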
2 changes: 1 addition & 1 deletion ios/MLCChat/ChatView.swift
@@ -57,7 +57,7 @@ struct ChatView: View {
}.bold().opacity(state.inProgress ? 0.5 : 1)
}.frame(minHeight: CGFloat(70)).padding()
}
.navigationBarTitle("MLC Chat", displayMode: .inline)
.navigationBarTitle("MLC Chat: " + state.modelName, displayMode: .inline)
.toolbar{
ToolbarItem(placement: .navigationBarLeading) {
Button("Reset") {
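Note: the title updates because modelName is @Published on ChatState. A minimal sketch of the assumed observation wiring (the diff only shows the title line; the @ObservedObject binding is an assumption):

    import SwiftUI

    struct TitleSketchView: View {
        @ObservedObject var state: ChatState
        var body: some View {
            Text("MLC Chat: " + state.modelName)
        }
    }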