In [16]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import os
import random
from src.loader import preprocess_audio
from src.umodel import StegoUNet, VoiceCloner
from src.losses import calc_ber
from src.losses import calc_ber, signal_noise_ratio

from dotenv import load_dotenv
from datasets import load_dataset

from IPython.display import Audio
load_dotenv()

True

In [17]:
def parse_keyword(keyword):
    if isinstance(keyword, bool): return keyword
    if keyword.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif keyword.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Wrong keyword.')


parser = argparse.ArgumentParser()

parser.add_argument('--transform',
                    type=str,
                    default='fourier',
                    metavar='STR',
                    help='Which transform to use: [cosine] or [fourier]'
                    )
parser.add_argument('--stft_small',
                    type=parse_keyword,
                    default=True,
                    metavar='BOOL',
                    help='If [fourier], whether to use a small or large container'
                    )
parser.add_argument('--ft_container',
                    type=str,
                    default='mag',
                    metavar='STR',
                    help='If [fourier], container to use: [mag], [phase], [magphase]'
                    )
parser.add_argument('--mp_encoder',
                    type=str,
                    default='single',
                    metavar='STR',
                    help='If [fourier] and [magphase], type of magphase encoder: [single], [double]'
                    )
parser.add_argument('--mp_decoder',
                    type=str,
                    default='unet',
                    metavar='STR',
                    help='If [fourier] and [magphase], type of magphase encoder: [unet], [double]'
                    )
parser.add_argument('--mp_join',
                    type=str,
                    default='mean',
                    metavar='STR',
                    help='If [fourier] and [magphase] and [decoder=double], type of join operation: [mean], [2D], [3D]'
                    )
parser.add_argument('--permutation',
                    type=parse_keyword,
                    default=False,
                    metavar='BOOL',
                    help='Permute the encoded image before adding it to the audio'
                    )
parser.add_argument('--embed',
                    type=str,
                    default='stretch',
                    metavar='STR',
                    help='Method of adding the image into the audio: [stretch], [blocks], [blocks2], [blocks3], [multichannel], [luma]'
                    )
parser.add_argument('--luma',
                    type=parse_keyword,
                    default=False,
                    metavar='BOOL',
                    help='Add luma component as the fourth pixelshuffle value'
                    )
parser.add_argument('--num_points',
                    type=int,
                    default=64000 - 400,
                    help="the length model can handle")
parser.add_argument('--n_fft',
                    type=int,
                    default=1022)
parser.add_argument('--hop_length',
                    type=int,
                    default=400)
parser.add_argument("--mag",
                   type=bool,
                   default=False)

pattern = [1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
           0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
           1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
           1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
           0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0]
pattern = [0, 1, 1, 1, 1, 0, 1, 1, 0 ,1, 1, 1, 1, 1, 1, 1, 1 ,1, 0, 0, 1,
           0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0]

In [18]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {DEVICE}')

args = parser.parse_args("")
print(args)

Using device: cuda
Namespace(transform='fourier', stft_small=True, ft_container='mag', mp_encoder='single', mp_decoder='unet', mp_join='mean', permutation=False, embed='stretch', luma=False, num_points=63600, n_fft=1022, hop_length=400, mag=False)


In [93]:
# load model
model = StegoUNet(
        transform=args.transform,
        stft_small=args.stft_small,
        ft_container=args.ft_container,
        mp_encoder=args.mp_encoder,
        mp_decoder=args.mp_decoder,
        mp_join=args.mp_join,
        permutation=args.permutation,
        embed=args.embed,
        luma=args.luma,
        num_points=args.num_points,
        n_fft=args.n_fft,
        hop_length=args.hop_length,
        mag=args.mag
    )

# Load checkpoint
ckpt_path = '1-Test_tf_MagPhase/8-1-Test_tf_MagPhase.pt'
checkpoint = torch.load(os.path.join(os.environ.get('OUT_PATH'), ckpt_path),
                        map_location='cpu')
# if torch.cuda.device_count() > 1:
#     model = nn.DataParallel(model)
model.load_state_dict(checkpoint['state_dict'])
print('Checkpoint loaded')

voice_cloner = VoiceCloner()

Checkpoint loaded


In [109]:
# load audio
dataset = load_dataset("librispeech_asr")
data = dataset["train.clean.100"][0]
_, sound, sr = data["audio"].values()
sound = torch.tensor(sound).unsqueeze(0).float()

print(f"original sound len: {sound.shape}")
sound = preprocess_audio(sound,args.num_points)
print(f"Preprocessed sound len: {sound.shape}")

# play the audio
Audio(sound, rate=sr)

original sound len: torch.Size([1, 232480])
Preprocessed sound len: torch.Size([63600])


In [110]:
# generate secret
# secret = torch.normal(0.4, 0.2, (32,))
# secret = torch.tensor(pattern[:32]).float()
secret = torch.rand(32)

secret_binary = (secret > 0.5).int()
secret, secret_binary = secret.unsqueeze(0), secret_binary.unsqueeze(0)
print(f"secret: {secret_binary}")

secret: tensor([[0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0,
         0, 1, 1, 0, 1, 1, 0, 0]], dtype=torch.int32)


In [115]:
# run model
# cover_fft, containers_fft, container_wav, revealed = model(secret, sound)

# encode
container_wav = model.encode(secret, sound)

# voice cloning
# audio_prompt_path = "suno_tts_b.wav"
# transcript = "I used to work at a fire hydrant factory and you know you couldn't park anywhere near the place."
# text_prompt = "Hello, my name is Ray"
# cloned_voice, sr_clone = voice_cloner.clone(audio_prompt_path=audio_prompt_path, text_prompt=text_prompt, transcript=transcript)

cloned_voice, sr_clone = container_wav.clone().detach(), sr
# decode
revealed = model.decode(cloned_voice)

VALL-E EOS [645 -> 786]
False


In [112]:
# play watermarked_audio
container_wav_np = container_wav.detach().numpy()
print(f"watermarked sound shape: {container_wav_np.shape}")
Audio(container_wav_np, rate=sr)

watermarked sound shape: (1, 63600)


In [113]:
# play clone voice
Audio(cloned_voice, rate=sr_clone)

In [114]:
# evaluate
def evaluate(snd, snd_wm, srt, srt_rv):
    snr = signal_noise_ratio(snd, snd_wm)
    ber = calc_ber(srt_rv, srt)

    print(f"SNR: {snr}, BER: {ber * 100}%")
    print(f"Secret: {(srt > 0.5).int()} (the number of 1: {(srt > 0.5).sum()}/{len(srt[0])})")
    print(f"Revealed: {(srt_rv > 0.5).int()}")

sound_np = sound.numpy()
container_wav_np = container_wav.detach().numpy()
container_wav_np = container_wav_np.squeeze(0)
evaluate(sound_np, container_wav_np, secret, revealed)

SNR: 35.7042121887207, BER: 0.0%
Secret: tensor([[0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0,
         0, 1, 1, 0, 1, 1, 0, 0]], dtype=torch.int32) (the number of 1: 18/32)
Revealed: tensor([[0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0,
         0, 1, 1, 0, 1, 1, 0, 0]], dtype=torch.int32)


In [91]:
random_acc = 0
epoch = 10000
for i in range(epoch):
    random_acc += calc_ber(secret, torch.rand(32))
random_acc /= epoch
random_acc

tensor(0.5008)