diff --git a/PULSE.py b/PULSE.py index d499aaa..13b8dd1 100644 --- a/PULSE.py +++ b/PULSE.py @@ -19,9 +19,14 @@ def __init__(self, cache_dir, verbose=True): cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok = True) - if self.verbose: print("Loading Synthesis Network") - with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f: - self.synthesis.load_state_dict(torch.load(f)) + if self.verbose: + print("Loading Synthesis Network") + synthesis_cached = f"{cache_dir}/synthesis.pt" + if Path(synthesis_cached).exists(): + self.synthesis.load_state_dict(torch.load(synthesis_cached)) + else: + with open_url("https://drive.google.com/uc?id=1TCViX1YpQyRsklTVYEJwdbmK91vklCo8", cache_dir=cache_dir, verbose=verbose) as f: + self.synthesis.load_state_dict(torch.load(f)) for param in self.synthesis.parameters(): param.requires_grad = False @@ -31,10 +36,15 @@ def __init__(self, cache_dir, verbose=True): if Path("gaussian_fit.pt").exists(): self.gaussian_fit = torch.load("gaussian_fit.pt") else: - if self.verbose: print("\tLoading Mapping Network") + if self.verbose: + print("\tLoading Mapping Network") mapping = G_mapping().cuda() - with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f: + mapping_cached = f"{cache_dir}/mapping.pt" + if Path(mapping_cached).exists(): + mapping.load_state_dict(torch.load(mapping_cached)) + else: + with open_url("https://drive.google.com/uc?id=14R6iHGf5iuVx3DMNsACAl7eBr7Vdpd0k", cache_dir=cache_dir, verbose=verbose) as f: mapping.load_state_dict(torch.load(f)) if self.verbose: print("\tRunning Mapping Network") diff --git a/align_face.py b/align_face.py index c15aa9b..0ab4c07 100644 --- a/align_face.py +++ b/align_face.py @@ -30,9 +30,15 @@ output_dir = Path(args.output_dir) output_dir.mkdir(parents=True,exist_ok=True) -print("Downloading Shape Predictor") -f=open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True) -predictor = dlib.shape_predictor(f) +predictor_cached = f"{cache_dir}/shape_predictor_68_face_landmarks.dat" + +if Path(predictor_cached).exists(): + print("Using cached Shape Predictor") + predictor = dlib.shape_predictor(predictor_cached) +else: + print("Downloading Shape Predictor") + with open_url("https://drive.google.com/uc?id=1huhv8PYpNNKbGCLOaYUjOgR1pY5pmbJx", cache_dir=cache_dir, return_path=True) as f: + predictor = dlib.shape_predictor(f) for im in Path(args.input_dir).glob("*.*"): faces = align_face(str(im),predictor)