Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of hifigan .pths and non hifigan .pths #118

Open
StillTravelling opened this issue May 31, 2024 · 0 comments
Open

Better handling of hifigan .pths and non hifigan .pths #118

StillTravelling opened this issue May 31, 2024 · 0 comments

Comments

@StillTravelling
Copy link

Sorry, I'm hopeless with git, but I think the changes below will help when switching between hifigan and non-hifigan .pth files.

Modules\Tortoise-tts\tortoise\utils\audio.py

def load_voice(voice, extra_voice_dirs=None, load_latents=True, sample_rate=22050, device='cpu', model_hash=None, use_hifigan=False):
    """
    Resolve a voice name into raw audio samples and/or a precomputed
    conditioning-latent file.

    Args:
        voice: voice folder name; 'random' short-circuits to (None, None).
        extra_voice_dirs: additional directories to search (default: none).
        load_latents: when True, prefer a saved latent .pth over audio clips.
        sample_rate: sample rate passed to load_audio for each clip.
        device: map_location for torch.load of the latent file.
        model_hash: autoregressive model hash; its first 8 chars select a
            per-model latent filename.
        use_hifigan: select the 'hifigan_cond_latents*' files instead of the
            plain 'cond_latents*' files, so the two vocoder modes never
            share (incompatible) latents.

    Returns:
        (samples, None) when audio clips were loaded, or (None, latents)
        when a matching latent file was found and load_latents is True.
    """
    if voice == 'random':
        return None, None

    # NOTE: default used to be a shared mutable list ([]); use a None
    # sentinel to avoid the mutable-default-argument pitfall.
    if extra_voice_dirs is None:
        extra_voice_dirs = []

    print(f"hifigan = {use_hifigan}, voice={voice}")
    voices = _get_voices(dirs=[get_voice_dir()] + extra_voice_dirs, load_latents=load_latents)

    paths = voices[voice]
    mtime = 0
    latent = None
    audio_paths = []

    # hifigan latents live under a distinct filename prefix so switching
    # vocoder modes never picks up the wrong kind of .pth.
    prefix = "hifigan_cond_latents" if use_hifigan else "cond_latents"
    expected = f"{prefix}_{model_hash[:8]}.pth" if model_hash else f"{prefix}.pth"

    for path in paths:
        filename = os.path.basename(path)
        if filename == expected:
            latent = path
        elif filename.endswith(".pth") and filename.startswith(("cond_latents", "hifigan_cond_latents")):
            # A latent file for the other mode, or for another model hash:
            # skip it entirely. (Previously other-mode latent .pths fell
            # through to the else branch and were handed to the audio
            # loader as if they were voice clips.)
            continue
        else:
            audio_paths.append(path)
            mtime = max(mtime, os.path.getmtime(path))

    if load_latents and latent is not None:
        # The mtime staleness check was deliberately disabled upstream;
        # an existing latent file is always used.
        print(f"Reading from latent: {latent}")
        return None, torch.load(latent, map_location=device)

    samples = [load_audio(path, sample_rate) for path in audio_paths]
    return samples, None
	
	



src\utils.py

def fetch_voice( voice ):
		"""
		Resolve a voice selection into (voice_samples, conditioning_latents, sample_voice),
		memoised in voice_cache keyed by (voice, autoregressive-model hash).

		NOTE(review): relies on enclosing/module state — tts, voice_cache,
		parameters, progress, args — presumably a closure inside the generate
		pipeline; verify against the full file.
		"""
		cache_key = f'{voice}:{tts.autoregressive_model_hash[:8]}'
		if cache_key in voice_cache:
			return voice_cache[cache_key]

		print(f"Loading voice: {voice} with model {tts.autoregressive_model_hash[:8]}")
		sample_voice = None
		if voice == "microphone":
			# microphone input bypasses the voice folders entirely
			if parameters['mic_audio'] is None:
				raise Exception("Please provide audio from mic when choosing `microphone` as a voice input")
			voice_samples, conditioning_latents = [load_audio(parameters['mic_audio'], tts.input_sample_rate)], None
		elif voice == "random":
			voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
		else:
			if progress is not None:
				notify_progress(f"Loading voice: {voice}", progress=progress)

			# use_hifigan selects which cond_latents*.pth family load_voice reads
			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash, use_hifigan=args.use_hifigan)
			
		if voice_samples and len(voice_samples) > 0:
			if conditioning_latents is None:
				# no saved latents found: compute them from the raw clips
				conditioning_latents = compute_latents(voice=voice, voice_samples=voice_samples, voice_latents_chunks=parameters['voice_latents_chunks'])
				
			# keep a concatenated preview waveform; drop the raw samples
			# since the latents are what generation actually consumes
			sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
			voice_samples = None

		voice_cache[cache_key] = (voice_samples, conditioning_latents, sample_voice)
		return voice_cache[cache_key]
		
		
###		
	def get_info( voice, settings = None, latents = True ):
		"""
		Build a metadata dict describing the current generation: parameters,
		timing, model path/hash, and (optionally) the voice's conditioning
		latents embedded as base64 for export.

		Args:
			voice: unused directly; overwritten below from info['voice'].
			settings: optional dict of overrides merged into the info dict.
			latents: when True, attach the latent file contents as base64.

		NOTE(review): relies on enclosing/module state — parameters, tts,
		args, full_start_time, conditioning_latents — verify against the
		full file.
		"""
		info = {}
		info.update(parameters)

		info['time'] = time.time()-full_start_time
		info['datetime'] = datetime.now().isoformat()

		info['model'] = tts.autoregressive_model_path
		info['model_hash'] = tts.autoregressive_model_hash 

		# 'progress' is a UI handle, not serialisable metadata — ensure the
		# key exists, then drop it
		info['progress'] = None
		del info['progress']

		# keep the delimiter printable in exported metadata
		if info['delimiter'] == "\n":
			info['delimiter'] = "\\n"

		if settings is not None:
			for k in settings:
				if k in info:
					info[k] = settings[k]

			if 'half_p' in settings and 'cond_free' in settings:
				info['experimentals'] = []
				if settings['half_p']:
					info['experimentals'].append("Half Precision")
				if settings['cond_free']:
					info['experimentals'].append("Conditioning-Free")

		if latents and "latents" not in info:
			voice = info['voice']
			model_hash = settings["model_hash"][:8] if settings is not None and "model_hash" in settings else tts.autoregressive_model_hash[:8]

			voice_dir = f'{get_voice_dir()}/{voice}'
			# BUGFIX: this branch was inverted — hifigan mode must read/write
			# the 'hifigan_cond_latents*' files, matching the filenames
			# compute_latents saves, and plain mode the 'cond_latents*' files.
			if args.use_hifigan:
				latents_path = f'{voice_dir}/hifigan_cond_latents_{model_hash}.pth'
			else:
				latents_path = f'{voice_dir}/cond_latents_{model_hash}.pth'

			if voice == "random" or voice == "microphone":
				# NOTE(review): `conditioning_latents` is not defined in this
				# function; it appears to come from an enclosing scope —
				# confirm against the full file.
				if args.use_hifigan:
					if latents and settings is not None and torch.any(settings['conditioning_latents']):
						os.makedirs(voice_dir, exist_ok=True)
						torch.save(conditioning_latents, latents_path)
				else: 
					if latents and settings is not None and settings['conditioning_latents']:
						os.makedirs(voice_dir, exist_ok=True)
						torch.save(conditioning_latents, latents_path)

			if latents_path and os.path.exists(latents_path):
				# best-effort: embed the raw latent file for export/sharing
				try:
					with open(latents_path, 'rb') as f:
						info['latents'] = base64.b64encode(f.read()).decode("ascii")
				except Exception:
					pass

		return info
		

###
settings = get_settings( override=override )
		#print(settings) #This line changed to comment out
		try:
			if args.use_hifigan:
				gen = tts.tts(cut_text, **settings)
			else:
				gen, additionals = tts.tts(cut_text, **settings )
				parameters['seed'] = additionals[0]
		except Exception as e:
			raise RuntimeError(f'Possible latent mismatch: click the "(Re)Compute Voice Latents" button and then try again. Error: {e}')



###
def compute_latents(voice=None, voice_samples=None, voice_latents_chunks=0, original_ar=False, original_diffusion=False):
	"""
	Compute conditioning latents for `voice` and save them to disk.

	When `voice_latents_chunks` is 0 and a training dataset exists under
	./training/<voice>/train.txt, the dataset's clips are used; otherwise the
	voice folder's audio is loaded via load_voice(). The result is saved as
	'hifigan_cond_latents_<hash>.pth' or 'cond_latents_<hash>.pth' depending
	on args.use_hifigan, so the two vocoder modes keep separate latents.

	Returns:
		The computed latents tuple, or None when no samples were available
		(or the bark backend handled voice creation itself).
	"""
	global tts
	global args
	
	unload_whisper()
	unload_voicefixer()

	if not tts:
		if tts_loading:
			raise Exception("TTS is still initializing...")
		load_tts()

	if hasattr(tts, "loading") and tts.loading:
		raise Exception("TTS is still initializing...")

	# bark manages its own voice representation; nothing to save here
	if args.tts_backend == "bark":
		tts.create_voice( voice )
		return

	if args.autoregressive_model == "auto":
		tts.load_autoregressive_model(deduce_autoregressive_model(voice))

	if voice:
		load_from_dataset = voice_latents_chunks == 0

		if load_from_dataset:
			dataset_path = f'./training/{voice}/train.txt'
			if not os.path.exists(dataset_path):
				load_from_dataset = False
			else:
				with open(dataset_path, 'r', encoding="utf-8") as f:
					lines = f.readlines()

				print("Leveraging dataset for computing latents")

				voice_samples = []
				max_length = 0
				for line in lines:
					filename = f'./training/{voice}/{line.split("|")[0]}'
					
					waveform = load_audio(filename, 22050)
					max_length = max(max_length, waveform.shape[-1])
					voice_samples.append(waveform)

				# pad every clip to the longest so they can be batched
				for i in range(len(voice_samples)):
					voice_samples[i] = pad_or_truncate(voice_samples[i], max_length)

				voice_latents_chunks = len(voice_samples)
				if voice_latents_chunks == 0:
					print("Dataset is empty!")
					# BUGFIX: was `True`, which skipped the load_voice()
					# fallback below and proceeded with an empty sample
					# list; fall back to the voice folder instead.
					load_from_dataset = False
		if not load_from_dataset:
			# honour the hifigan flag so the matching latent family is used
			voice_samples, _ = load_voice(voice, load_latents=False, use_hifigan=args.use_hifigan)

	if voice_samples is None:
		return

	conditioning_latents = tts.get_conditioning_latents(voice_samples, return_mels=not args.latents_lean_and_mean, slices=voice_latents_chunks, force_cpu=args.force_cpu_for_conditioning_latents, original_ar=original_ar, original_diffusion=original_diffusion)

	# when four elements come back, blank the mel slot before saving
	if len(conditioning_latents) == 4:
		conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)

	# hifigan latents get their own filename prefix so the modes never collide
	if args.use_hifigan:
		outfile = f'{get_voice_dir()}/{voice}/hifigan_cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
	else:
		outfile = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
	torch.save(conditioning_latents, outfile)
	print(f'Saved voice latents: {outfile}')

	return conditioning_latents

###
def reload_tts():
	"""Tear down and re-create the TTS backend (e.g. after a settings change)."""
	unload_tts()
	load_tts()
	
def change_hifigan(newvalue=True):
	"""Toggle hifigan mode, persist the setting, and reload the TTS backend.

	Returns the new value of args.use_hifigan so the UI can reflect it.
	"""
	args.use_hifigan=newvalue
	save_args_settings()
	# free memory before reinitialising the backend in the new mode
	do_gc()
	reload_tts()
	return args.use_hifigan
	
def get_hifigan():
	"""Return the current hifigan setting (read by the UI on select)."""
	return args.use_hifigan


src\webui.py

EXEC_SETTINGS['autoregressive_model'].change(
					fn=update_autoregressive_model,
					inputs=EXEC_SETTINGS['autoregressive_model'],
					outputs=None,
					api_name="set_autoregressive_model"
				)
				
				EXEC_SETTINGS['use_hifigan'].change( #newsection
					fn=change_hifigan,
					inputs=EXEC_SETTINGS['use_hifigan'],
					outputs=EXEC_SETTINGS['use_hifigan'],
					api_name="use_hifigan"
				)
				
				EXEC_SETTINGS['use_hifigan'].select(
					fn=get_hifigan,
					outputs=EXEC_SETTINGS['use_hifigan'],
					api_name="get_hifigan"
				) #endnewsection
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant