In [5]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os
import random
from src.loader import preprocess_audio
from src.umodel import StegoUNet, VoiceCloner
from src.losses import calc_ber, signal_noise_ratio, batch_calc_ber
from src.loader import StegoDataset
from torch.utils.data import DataLoader

from dotenv import load_dotenv
from datasets import load_dataset
import matplotlib.pyplot as plot
from IPython.display import Audio
# Populate os.environ from a local .env file; DATA_PATH and OUT_PATH are read
# from the environment in the model/checkpoint cell below.
load_dotenv()

True

In [6]:
def parse_keyword(keyword):
    """Convert a yes/no style command-line keyword into a ``bool``.

    Intended as an argparse ``type=`` callable so boolean options accept
    values such as ``yes``/``no``, ``true``/``false``, ``t``/``f``, ``y``/``n``
    or ``1``/``0`` (case-insensitive). A value that is already a ``bool``
    (e.g. an argparse default) is passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``keyword`` is not one of the keywords above.
    """
    if isinstance(keyword, bool):
        return keyword

    verdicts = dict.fromkeys(('yes', 'true', 't', 'y', '1'), True)
    verdicts.update(dict.fromkeys(('no', 'false', 'f', 'n', '0'), False))

    verdict = verdicts.get(keyword.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Wrong keyword.')
    return verdict


parser = argparse.ArgumentParser()

# NOTE(review): of the options below, only transform, num_points, n_fft,
# hop_length, mag, num_layers, watermark_len, shift_ratio and dataset_i are
# consumed by the model/dataset/preprocessing cells of this notebook; the
# Fourier/embedding options appear to be carried over from another script.
parser.add_argument('--stft_small',
                    type=parse_keyword,
                    default=True,
                    metavar='BOOL',
                    help='If [fourier], whether to use a small or large container'
                    )
parser.add_argument('--ft_container',
                    type=str,
                    default='mag',
                    metavar='STR',
                    help='If [fourier], container to use: [mag], [phase], [magphase]'
                    )
parser.add_argument('--mp_encoder',
                    type=str,
                    default='single',
                    metavar='STR',
                    help='If [fourier] and [magphase], type of magphase encoder: [single], [double]'
                    )
parser.add_argument('--mp_decoder',
                    type=str,
                    default='unet',
                    metavar='STR',
                    help='If [fourier] and [magphase], type of magphase decoder: [unet], [double]'
                    )
parser.add_argument('--mp_join',
                    type=str,
                    default='mean',
                    metavar='STR',
                    help='If [fourier] and [magphase] and [decoder=double], type of join operation: [mean], [2D], [3D]'
                    )
parser.add_argument('--permutation',
                    type=parse_keyword,
                    default=False,
                    metavar='BOOL',
                    help='Permute the encoded image before adding it to the audio'
                    )
parser.add_argument('--embed',
                    type=str,
                    default='stretch',
                    metavar='STR',
                    help='Method of adding the image into the audio: [stretch], [blocks], [blocks2], [blocks3], [multichannel], [luma]'
                    )
parser.add_argument('--luma',
                    type=parse_keyword,
                    default=False,
                    metavar='BOOL',
                    help='Add luma component as the fourth pixelshuffle value'
                    )
parser.add_argument('--num_points',
                    type=int,
                    default=48000,
                    help="the length model can handle")
parser.add_argument('--n_fft',
                    type=int,
                    default=1000)
parser.add_argument('--hop_length',
                    type=int,
                    default=800)
# `type=bool` is a trap: bool("False") is True, so any explicit value would
# switch the flag on. Parse it with `parse_keyword` like the other boolean
# options (the bool default is passed through unchanged).
parser.add_argument("--mag",
                    type=parse_keyword,
                    default=False,
                    metavar='BOOL')
parser.add_argument("--num_layers",
                    type=int,
                    default=4)
parser.add_argument('--transform',
                    type=str,
                    choices=["ID", "TC", "RS", "VC"],
                    default="ID",
                    )
parser.add_argument("--watermark_len",
                    type=int,
                    default=4)
parser.add_argument("--dataset_i",
                    type=int,
                    choices=[0, 1],
                    default=0)
parser.add_argument("--shift_ratio",
                    type=float,
                    default=0
                    )

def set_reproductibility(seed=2023):
    """Seed every RNG used here and request deterministic cuDNN behaviour.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU, the
    current CUDA device and all CUDA devices), turns off cuDNN autotuning and
    asks cuDNN for deterministic kernels.

    Note: exporting ``PYTHONHASHSEED`` at runtime only influences child
    processes; the running interpreter's hash seed is fixed at start-up.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_reproductibility()

# Fixed bit patterns for hand-picked secrets (used by the commented-out
# `torch.tensor(pattern[:32])` option in the secret cell).
# The 100-bit pattern used to be bound to `pattern` and then immediately
# overwritten; it now has its own name so it is no longer silently shadowed.
pattern_100bit = [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
                  0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
                  1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
                  1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
                  0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0]
# 32-bit pattern; this is the value `pattern` has always ended up holding.
pattern = [0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,
           0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0]

In [7]:
# NOTE(review): DEVICE is only reported; no later cell moves the model or any
# tensor to it (they stay on CPU, and later cells call `.numpy()` on outputs).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

# Parse an explicit, empty argv: inside Jupyter, sys.argv holds the kernel's
# own flags, which the parser would reject. An empty list is the idiomatic
# spelling — a string only works because argparse iterates it char by char.
args = parser.parse_args([])
print(args)

Using device: cuda
Namespace(stft_small=True, ft_container='mag', mp_encoder='single', mp_decoder='unet', mp_join='mean', permutation=False, embed='stretch', luma=False, num_points=48000, n_fft=1000, hop_length=800, mag=False, num_layers=4, transform='ID', watermark_len=4, dataset_i=0, shift_ratio=0)


In [11]:
# load model
# NOTE(review): the model is left on CPU in its default mode; if StegoUNet
# contains dropout/batch-norm layers, `model.eval()` is probably wanted for
# evaluation — TODO confirm.
model = StegoUNet(
    transform=args.transform,
    num_points=args.num_points,
    n_fft=args.n_fft,
    hop_length=args.hop_length,
    mag=args.mag,
    num_layers=args.num_layers,
    watermark_len=args.watermark_len,
    shift_ratio=args.shift_ratio,
)

# load dataset
DATA_FOLDER = os.environ.get('DATA_PATH')
# NOTE(review): AUDIO_FOLDER is not referenced by any later cell here.
AUDIO_FOLDER = f"{DATA_FOLDER}/LibrispeechVoiceClone_"
dataset = StegoDataset(
    audio_root_i=args.dataset_i,
    folder="test",
    # Keep the clip length in sync with the model instead of hard-coding 48000.
    num_points=args.num_points,
    watermark_len=args.watermark_len,
    shift_ratio=args.shift_ratio,
)

# Load checkpoint
# Previously evaluated checkpoint (its name suggests a 32-bit watermark, so it
# would presumably need --watermark_len 32):
# ckpt_path = '1-wavmarkConfig_wl32lr1e-4audioMSElosslam100/9-1-wavmarkConfig_wl32lr1e-4audioMSElosslam100.pt'
ckpt_path = "1-VCwl4lr1e-4audioMSElam100/1-1-VCwl4lr1e-4audioMSElam100.pt"
OUT_PATH = os.environ.get('OUT_PATH')
if OUT_PATH is None:
    # Without this check os.path.join(None, ...) fails with an opaque TypeError.
    raise RuntimeError("OUT_PATH is not set; define it in .env to locate the checkpoint")
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(os.path.join(OUT_PATH, ckpt_path),
                        map_location='cpu')
# if torch.cuda.device_count() > 1:
#     model = nn.DataParallel(model)
# The checkpoint is a dict whose 'state_dict' entry holds the model weights.
model.load_state_dict(checkpoint['state_dict'])
print('Checkpoint loaded')

# voice_cloner = VoiceCloner()

# voice_cloner = VoiceCloner()

DATA LOCATED AT: /home/rz60/data/LibrispeechVoiceClone_test
Set up done
Checkpoint loaded


In [23]:
# load audio
# name = "speech7"
# dataset = load_dataset("librispeech_asr")
# data = dataset["train.clean.100"][10000]
# _, sound, sr = data["audio"].values()
# sound = torch.tensor(sound).unsqueeze(0).float()
# torchaudio.save(f"src/output_test/{name}/origin.wav", sound, sr)


# audio_path = f"src/output_test/{name}/origin.wav"
# audio_path = "VALL-E_1.wav"
# sound, sr = torchaudio.load(audio_path)
# torchaudio.save(f"src/output_test/{name}/watermarked.wav", sound, sr)

# load from dataset
SAMPLE_INDEX = 400  # test-set item used throughout this notebook
(sequence, sequence_binary), sound, transcript, text_prompt, shift_sound = dataset[SAMPLE_INDEX]
print(shift_sound)
sr = 16000  # NOTE(review): assumed to be the dataset audio's sample rate — confirm
start = 8000  # NOTE(review): not referenced by any cell in this notebook
# Add a leading dim and truncate to the model's input length (was a
# hard-coded 48000, which silently diverged from --num_points).
sound = sound[None, :args.num_points]
print(transcript)

print(f"original sound len: {sound.shape}")
sound, _ = preprocess_audio(sound, args.num_points, shift_ratio=args.shift_ratio)
# sound = sound[:,4000:4000+model.num_points]
print(f"Preprocessed sound len: {sound.shape}")

# play the audio
# preprocess_audio returned a 1-D tensor here (see the printed shape); restore
# the leading dim so later cells receive a (1, N) tensor.
sound = sound[None, :]
Audio(sound, rate=sr)

[]
 happened before at some time.
original sound len: torch.Size([1, 48000])
Preprocessed sound len: torch.Size([48000])


In [24]:
# generate secret
# secret = torch.normal(0.4, 0.2, (32,))
# secret = torch.tensor(pattern[:32]).float()
# Soft bits drawn uniformly from [0, 1), shaped (1, watermark_len);
# thresholding at 0.5 gives the hard bits actually embedded.
secret = torch.rand(args.watermark_len).unsqueeze(0)
secret_binary = secret.gt(0.5).int()
print(f"secret: {secret_binary}")

secret: tensor([[0, 1, 0, 1]], dtype=torch.int32)


In [25]:
# run model
# cover_fft, containers_fft, container_wav, revealed = model(secret, sound)

# Inference only: no autograd graph is needed, so skip building one.
# The `.detach()` calls in later cells remain harmless.
with torch.no_grad():
    # encode
    container_wav = model.encode(secret, sound)

    # voice cloning
    # audio_prompt_path = "suno_tts_b.wav"
    # transcript = "I used to work at a fire hydrant factory and you know you couldn't park anywhere near the place."
    # text_prompt = "Hello, my name is Ray"
    # cloned_voice, sr_clone = voice_cloner.clone(audio_prompt_path=audio_prompt_path, text_prompt=text_prompt, transcript=transcript)

    # Placeholder for the cloning step: an identical, detached copy of the
    # watermarked audio at the same sample rate.
    cloned_voice, sr_clone = container_wav.clone().detach(), sr

    # decode
    # NOTE(review): decodes container_wav, not cloned_voice; identical today,
    # but switch if a real cloning step is re-enabled above.
    revealed = model.decode(container_wav)

In [26]:
# play watermarked audio; the array keeps the leading batch dim, e.g. (1, N)
container_wav_np = container_wav.detach().numpy()
print(f"watermarked sound shape: {container_wav_np.shape}")
Audio(data=container_wav_np, rate=sr)

watermarked sound shape: (1, 48000)


In [27]:
# play the "cloned" voice (currently an identical copy of the watermarked audio)
Audio(data=cloned_voice, rate=sr_clone)

In [31]:
# evaluate
def evaluate(snd, snd_wm, srt, srt_rv):
    """Print the SNR of the watermarked audio and the BER of the revealed secret.

    Args:
        snd: original waveform (here a 1-D NumPy array).
        snd_wm: watermarked waveform, same layout as ``snd``.
        srt: embedded soft secret, shape (1, L); values > 0.5 count as 1.
        srt_rv: decoder output compared against ``srt``; also thresholded at 0.5.
    """
    snr = signal_noise_ratio(snd, snd_wm)
    ber = calc_ber(srt_rv, srt)
    print(f"SNR: {snr}, BER: {ber * 100}%")

    secret_mask = srt > 0.5
    print(f"Secret: {secret_mask.int()} (the number of 1: {secret_mask.sum()}/{len(srt[0])})")
    print(f"Revealed: {(srt_rv > 0.5).int()}")

# Drop the leading batch dim so both waveforms are 1-D before scoring.
sound_np = sound.numpy().squeeze(0)
container_wav_np = container_wav.detach().numpy().squeeze(0)
# print(sound_np.shape, container_wav_np.shape)
evaluate(sound_np, container_wav_np, secret, revealed)

(48000,) (48000,)
SNR: 2.83545583486557, BER: 0.0%
Secret: tensor([[0, 1, 0, 1]], dtype=torch.int32) (the number of 1: 2/4)
Revealed: tensor([[0, 1, 0, 1]], dtype=torch.int32)


In [32]:
from utils.prompt_making import make_prompt, make_prompt_train
from utils.generation import SAMPLE_RATE, generate_audio, preload_models, generate_audio_train
import whisper

In [33]:
%%time
# Load the models behind utils.generation (presumably those used by
# generate_audio_train below) and the Whisper "base" model.
preload_models()
# NOTE(review): `whisper_m` is not referenced by any later cell in this notebook.
whisper_m = whisper.load_model("base")

CPU times: user 9min 2s, sys: 6.58 s, total: 9min 8s
Wall time: 17.3 s


In [40]:
# Build the cloning prompt from the (1, N) preprocessed `sound` and its transcript.
print(transcript)
audio_tokens, text_tokens, lang_pr = make_prompt_train(
    name="test",
    audio_prompt=sound,
    sr=sr,
    transcript=transcript,
)

# make_prompt(name="test", audio_prompt_path="suno_tts_b.wav")
# NOTE(review): this rebinding discards the `text_prompt` returned by the
# dataset earlier; the surrounding newlines are part of the generation prompt.
text_prompt = "\nTest Voice Cloning\n"
audio_array = generate_audio_train(
    text_prompt,
    audio_tokens=audio_tokens,
    text_tokens=text_tokens,
    lang_pr=lang_pr,
)

 happened before at some time.


In [41]:
print(audio_array.shape)
# Copy to a detached CPU tensor before handing it to the audio widget.
generated_audio = audio_array.clone().detach().cpu()
Audio(data=generated_audio, rate=SAMPLE_RATE)

torch.Size([33600])
