Skip to content

Commit

Permalink
fix(download): loose download host check - when TABBY_DOWNLOAD_HOST i… (
Browse files Browse the repository at this point in the history
#2183)

* fix(download): loose download host check - when TABBY_DOWNLOAD_HOST is not specified, will pick the first url from model url list

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] committed May 20, 2024
1 parent ee8b811 commit 107cdca
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ use tokio_retry::{
};
use tracing::{info, warn};

fn download_host() -> String {
std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned())
fn select_by_download_host(url: &String) -> bool {
if let Ok(host) = std::env::var("TABBY_DOWNLOAD_HOST") {
url.contains(&host)
} else {
true
}
}

async fn download_model_impl(
Expand Down Expand Up @@ -50,18 +54,13 @@ async fn download_model_impl(
return download_split_model(model_info, &model_path).await;
}

let registry = download_host();
let Some(model_url) = model_info
.urls
.iter()
.flatten()
.find(|x| x.contains(&registry))
.find(|x| select_by_download_host(x))
else {
return Err(anyhow!(
"Invalid mirror <{}> for model urls: {:?}",
registry,
model_info.urls
));
return Err(anyhow!("No valid url for model <{}>", model_info.name));
};

let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
Expand All @@ -79,17 +78,12 @@ async fn download_split_model(model_info: &ModelInfo, model_path: &Path) -> Resu
}
let mut paths = vec![];
let partition_urls = model_info.partition_urls.clone().unwrap_or_default();
let mirror = download_host();

let Some(urls) = partition_urls
.iter()
.find(|urls| urls.iter().all(|url| url.contains(&mirror)))
.find(|urls| urls.iter().all(select_by_download_host))
else {
return Err(anyhow!(
"Invalid mirror <{}> for model urls: {:?}",
mirror,
partition_urls
));
return Err(anyhow!("No valid url for model <{}>", model_info.name));
};

for (index, url) in urls.iter().enumerate() {
Expand Down

0 comments on commit 107cdca

Please sign in to comment.