Skip to content

Commit

Permalink
fix: replace panics with results & better option types (#437)
Browse files Browse the repository at this point in the history
* Replace panics with results & better option types

* Apply suggestions from code review

* Remove Reqwest from AWClient::new and return a Result instead

* Run cargo fmt

---------

Co-authored-by: NathanM <2955071+nathanmerrill@users.noreply.github.com>
Co-authored-by: Erik Bjäreholt <erik.bjareholt@gmail.com>
  • Loading branch information
3 people committed Nov 17, 2023
1 parent a144746 commit dc70318
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 137 deletions.
12 changes: 6 additions & 6 deletions aw-client-rust/src/blocking.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;
use std::future::Future;
use std::vec::Vec;
use std::{collections::HashMap, error::Error};

use chrono::{DateTime, Utc};

Expand All @@ -10,7 +10,7 @@ use super::AwClient as AsyncAwClient;

pub struct AwClient {
client: AsyncAwClient,
pub baseurl: String,
pub baseurl: reqwest::Url,
pub name: String,
pub hostname: String,
}
Expand Down Expand Up @@ -38,15 +38,15 @@ macro_rules! proxy_method
}

impl AwClient {
pub fn new(ip: &str, port: &str, name: &str) -> AwClient {
let async_client = AsyncAwClient::new(ip, port, name);
pub fn new(host: &str, port: u16, name: &str) -> Result<AwClient, Box<dyn Error>> {
let async_client = AsyncAwClient::new(host, port, name)?;

AwClient {
Ok(AwClient {
baseurl: async_client.baseurl.clone(),
name: async_client.name.clone(),
hostname: async_client.hostname.clone(),
client: async_client,
}
})
}

proxy_method!(get_bucket, Bucket, bucketname: &str);
Expand Down
22 changes: 13 additions & 9 deletions aw-client-rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ extern crate tokio;

pub mod blocking;

use std::collections::HashMap;
use std::vec::Vec;
use std::{collections::HashMap, error::Error};

use chrono::{DateTime, Utc};
use serde_json::Map;
Expand All @@ -17,7 +17,7 @@ pub use aw_models::{Bucket, BucketMetadata, Event};

pub struct AwClient {
client: reqwest::Client,
pub baseurl: String,
pub baseurl: reqwest::Url,
pub name: String,
pub hostname: String,
}
Expand All @@ -28,20 +28,24 @@ impl std::fmt::Debug for AwClient {
}
}

fn get_hostname() -> String {
return gethostname::gethostname().to_string_lossy().to_string();
}

impl AwClient {
pub fn new(ip: &str, port: &str, name: &str) -> AwClient {
let baseurl = format!("http://{ip}:{port}");
pub fn new(host: &str, port: u16, name: &str) -> Result<AwClient, Box<dyn Error>> {
let baseurl = reqwest::Url::parse(&format!("http://{}:{}", host, port))?;
let hostname = get_hostname();
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(120))
.build()
.unwrap();
let hostname = gethostname::gethostname().into_string().unwrap();
AwClient {
.build()?;

Ok(AwClient {
client,
baseurl,
name: name.to_string(),
hostname,
}
})
}

pub async fn get_bucket(&self, bucketname: &str) -> Result<Bucket, reqwest::Error> {
Expand Down
6 changes: 3 additions & 3 deletions aw-client-rust/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ mod test {

#[test]
fn test_full() {
let ip = "127.0.0.1";
let port: String = PORT.to_string();
let clientname = "aw-client-rust-test";
let client: AwClient = AwClient::new(ip, &port, clientname);

let client: AwClient =
AwClient::new("127.0.0.1", PORT, clientname).expect("Client creation failed");

let shutdown_handler = setup_testserver();

Expand Down
2 changes: 1 addition & 1 deletion aw-models/src/bucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn test_bucket() {
id: "id".to_string(),
_type: "type".to_string(),
client: "client".to_string(),
hostname: "hostname".to_string(),
hostname: "hostname".into(),
created: None,
data: json_map! {},
metadata: BucketMetadata::default(),
Expand Down
13 changes: 8 additions & 5 deletions aw-sync/src/dirs.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use dirs::home_dir;
use std::error::Error;
use std::fs;
use std::path::PathBuf;

// TODO: This could be refactored to share logic with aw-server/src/dirs.rs
// TODO: add proper config support
#[allow(dead_code)]
pub fn get_config_dir() -> Result<PathBuf, ()> {
let mut dir = appdirs::user_config_dir(Some("activitywatch"), None, false)?;
pub fn get_config_dir() -> Result<PathBuf, Box<dyn Error>> {
let mut dir = appdirs::user_config_dir(Some("activitywatch"), None, false)
.map_err(|_| "Unable to read user config dir")?;
dir.push("aw-sync");
fs::create_dir_all(dir.clone()).expect("Unable to create config dir");
fs::create_dir_all(dir.clone())?;
Ok(dir)
}

Expand All @@ -21,7 +23,8 @@ pub fn get_server_config_path(testing: bool) -> Result<PathBuf, ()> {
}))
}

pub fn get_sync_dir() -> Result<PathBuf, ()> {
pub fn get_sync_dir() -> Result<PathBuf, Box<dyn Error>> {
// TODO: make this configurable
home_dir().map(|p| p.join("ActivityWatchSync")).ok_or(())
let home_dir = home_dir().ok_or("Unable to read home_dir")?;
Ok(home_dir.join("ActivityWatchSync"))
}
108 changes: 42 additions & 66 deletions aw-sync/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ extern crate serde;
extern crate serde_json;

use std::error::Error;
use std::path::Path;
use std::path::PathBuf;

use chrono::{DateTime, Datelike, TimeZone, Utc};
use chrono::{DateTime, Utc};
use clap::{Parser, Subcommand};

use aw_client_rust::blocking::AwClient;
Expand All @@ -40,7 +39,7 @@ struct Opts {

/// Port of instance to connect to.
#[clap(long)]
port: Option<String>,
port: Option<u16>,

/// Convenience option for using the default testing host and port.
#[clap(long)]
Expand All @@ -58,8 +57,8 @@ enum Commands {
/// Pulls remote buckets then pushes local buckets.
Sync {
/// Host(s) to pull from, comma separated. Will pull from all hosts if not specified.
#[clap(long)]
host: Option<String>,
#[clap(long, value_parser=parse_list)]
host: Option<Vec<String>>,
},

/// Sync subcommand (advanced)
Expand All @@ -73,57 +72,64 @@ enum Commands {
/// If not specified, start from beginning.
/// NOTE: might be unstable, as count cannot be used to verify integrity of sync.
/// Format: YYYY-MM-DD
#[clap(long)]
start_date: Option<String>,
#[clap(long, value_parser=parse_start_date)]
start_date: Option<DateTime<Utc>>,

/// Specify buckets to sync using a comma-separated list.
/// If not specified, all buckets will be synced.
#[clap(long)]
buckets: Option<String>,
#[clap(long, value_parser=parse_list)]
buckets: Option<Vec<String>>,

/// Mode to sync in. Can be "push", "pull", or "both".
/// Defaults to "both".
#[clap(long, default_value = "both")]
mode: String,
mode: sync::SyncMode,

/// Full path to sync directory.
/// If not specified, exit.
#[clap(long)]
sync_dir: String,
sync_dir: PathBuf,

/// Full path to sync db file
/// Useful for syncing buckets from a specific db file in the sync directory.
/// Must be a valid absolute path to a file in the sync directory.
#[clap(long)]
sync_db: Option<String>,
sync_db: Option<PathBuf>,
},
/// List buckets and their sync status.
List {},
}

fn parse_start_date(arg: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
chrono::NaiveDate::parse_from_str(arg, "%Y-%m-%d")
.map(|nd| nd.and_time(chrono::NaiveTime::MIN).and_utc())
}

fn parse_list(arg: &str) -> Result<Vec<String>, clap::Error> {
Ok(arg.split(',').map(|s| s.to_string()).collect())
}

fn main() -> Result<(), Box<dyn Error>> {
let opts: Opts = Opts::parse();
let verbose = opts.verbose;

info!("Started aw-sync...");

aw_server::logging::setup_logger("aw-sync", opts.testing, verbose)
.expect("Failed to setup logging");
aw_server::logging::setup_logger("aw-sync", opts.testing, verbose)?;

let port = opts
.port
.or_else(|| Some(crate::util::get_server_port(opts.testing).ok()?.to_string()))
.unwrap();
.map(|a| Ok(a))
.unwrap_or_else(|| util::get_server_port(opts.testing))?;

let client = AwClient::new(opts.host.as_str(), port.as_str(), "aw-sync");
let client = AwClient::new(&opts.host, port, "aw-sync")?;

match &opts.command {
match opts.command {
// Perform basic sync
Commands::Sync { host } => {
// Pull
match host {
Some(host) => {
let hosts: Vec<&str> = host.split(',').collect();
Some(hosts) => {
for host in hosts.iter() {
info!("Pulling from host: {}", host);
sync_wrapper::pull(host, &client)?;
Expand All @@ -137,8 +143,7 @@ fn main() -> Result<(), Box<dyn Error>> {

// Push
info!("Pushing local data");
sync_wrapper::push(&client)?;
Ok(())
sync_wrapper::push(&client)
}
// Perform two-way sync
Commands::SyncAdvanced {
Expand All @@ -148,60 +153,31 @@ fn main() -> Result<(), Box<dyn Error>> {
sync_dir,
sync_db,
} => {
let sync_directory = if sync_dir.is_empty() {
error!("No sync directory specified, exiting...");
std::process::exit(1);
} else {
Path::new(&sync_dir)
};
info!("Using sync dir: {}", sync_directory.display());

if let Some(sync_db) = &sync_db {
info!("Using sync db: {}", sync_db);
if !sync_dir.is_absolute() {
Err("Sync dir must be absolute")?
}

let start: Option<DateTime<Utc>> = start_date.as_ref().map(|date| {
println!("{}", date.clone());
chrono::NaiveDate::parse_from_str(&date.clone(), "%Y-%m-%d")
.map(|nd| {
Utc.with_ymd_and_hms(nd.year(), nd.month(), nd.day(), 0, 0, 0)
.single()
.unwrap()
})
.expect("Date was not on the format YYYY-MM-DD")
});

// Parse comma-separated list
let buckets_vec: Option<Vec<String>> = buckets
.as_ref()
.map(|b| b.split(',').map(|s| s.to_string()).collect());

let sync_db: Option<PathBuf> = sync_db.as_ref().map(|db| {
let db_path = Path::new(db);
info!("Using sync dir: {}", &sync_dir.display());

if let Some(db_path) = &sync_db {
info!("Using sync db: {}", &db_path.display());

if !db_path.is_absolute() {
panic!("Sync db path must be absolute");
Err("Sync db path must be absolute")?
}
if !db_path.starts_with(sync_directory) {
panic!("Sync db path must be in sync directory");
if !db_path.starts_with(&sync_dir) {
Err("Sync db path must be in sync directory")?
}
db_path.to_path_buf()
});
}

let sync_spec = sync::SyncSpec {
path: sync_directory.to_path_buf(),
path: sync_dir,
path_db: sync_db,
buckets: buckets_vec,
start,
};

let mode_enum = match mode.as_str() {
"push" => sync::SyncMode::Push,
"pull" => sync::SyncMode::Pull,
"both" => sync::SyncMode::Both,
_ => panic!("Invalid mode"),
buckets,
start: start_date,
};

sync::sync_run(&client, &sync_spec, mode_enum)
sync::sync_run(&client, &sync_spec, mode)
}

// List all buckets
Expand Down
Loading

0 comments on commit dc70318

Please sign in to comment.