In [2]:
import sys
sys.path.append('..')

In [3]:
import torch
from src.pipe import MultiModalAuthPipeline, ImagePreprocessor, AudioPreprocessor, AdaFace, ReDimNet, ClassifierHead
import json
from pathlib import Path
from synthweave.utils.fusion import get_fusion
from synthweave.utils.tools import read_video

class dotdict(dict):
    def __getattr__(self, name):
        return self[name]

    def __setattr__(self, name, value):
        self[name] = value

### PIPELINE

In [5]:
fusion_module_dir = Path("/home/woleek/SynthWeave/models/CAFF")

# config
args = json.loads((fusion_module_dir / "args.json").read_text())
args = dotdict(args)

# weights
weights_path = fusion_module_dir / "detection_module.ckpt"

In [6]:
preprocessors = {
    "video": ImagePreprocessor(
        window_len=4,
        step=1,
        estimate_quality=False,
        models_dir="/home/woleek/SynthWeave/models",
        quality_model_type="ir50"
    ),
    "audio": AudioPreprocessor(
        window_len=4,
        step=1,
        use_vad=True,
    )
}

Using cache found in /home/woleek/.cache/torch/hub/snakers4_silero-vad_master


In [7]:
models = {
    "video": AdaFace(
        path="/home/woleek/SynthWeave/models",
        model_type="ir50",
    ), 
    "audio": ReDimNet()
}

/home/woleek/.cache/torch/hub/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


Using cache found in /home/woleek/.cache/torch/hub/IDRnD_ReDimNet_master


In [8]:
fusion = get_fusion(
    fusion_name=args.fusion,
    output_dim=args.emb_dim,
    modality_keys=["video", "audio"],
    input_dims={"video":512, "audio":192},
    out_proj_dim=args.proj_dim,
)

detection_head = ClassifierHead(input_dim=args.emb_dim, num_classes=1)

[INFO] This fusion expects embeddings of shape (batch_size, embed_dim).


In [9]:
pipe = MultiModalAuthPipeline(
    processors=preprocessors,
    models=models,
    fusion=fusion,
    detection_head=detection_head,
    freeze_backbone=True,
    iil_mode=args.iil_mode,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

state_dict = torch.load(weights_path, map_location="cpu")['state_dict']
state_dict = {k.replace("pipeline.", ""): v for k, v in state_dict.items()}
pipe.load_state_dict(state_dict, strict=False)

pipe = pipe.cuda()
pipe.eval();

### SAMPLE

In [18]:
df_sample = Path("../demo/samples/john_face_fake.mp4").resolve()
bf_sample = Path("../demo/samples/john_real.mp4").resolve()

In [19]:
vid, aud, meta = read_video(df_sample)
df = {
    "video": [vid, meta['video_fps']],
    "audio": [aud, meta['audio_fps']],
    "metadata": meta
}

vid, aud, meta = read_video(bf_sample)
bf = {
    "video": [vid, meta['video_fps']],
    "audio": [aud, meta['audio_fps']],
    "metadata": meta
}

Infer pipeline

In [20]:
with torch.no_grad():
    df_out = pipe(df) # Sample for pipeline inference
    bf_out = pipe(bf) # Simulates database

Threshold based on quality, face/voice detection etc.

In [21]:
print("Original len: ", df_out['org_len'])
print("Processed len: ", df_out['valid_len']) # NOTE: Can threshold based on % of dropped windows

Original len:  4
Processed len:  3


Check DeepFake module prediction

In [22]:
probs = torch.sigmoid(df_out["logits"]).cpu()
prob_per_clip = probs.mean()
preds_per_clip = (prob_per_clip >= 0.5).long() # NOTE: set threshold for DeepFake detection

print("Pred:", "Bonafide" if preds_per_clip.item() == 0 else "DeepFake")

Pred: DeepFake


Run verification

In [25]:
with torch.no_grad():
    sim = pipe.verify({
        "video": df_out["video"].mean(dim=0).unsqueeze(0).cpu(),
        "audio": df_out["audio"].mean(dim=0).unsqueeze(0).cpu(),
        "video_ref": bf_out["video"].mean(dim=0).unsqueeze(0).cpu(), # NOTE: put refference embeddings here
        "audio_ref": bf_out["audio"].mean(dim=0).unsqueeze(0).cpu()
    })

In [30]:
sim['video'].item()

0.6023023724555969

Face ver

In [32]:
# NOTE: Select aggregation methods across windows
# vid_sim = sim["video"].max(dim=1).values                # Max similarity to any reference
vid_sim = torch.tensor(sim["video"]).item()          # Mean similarity to reference

vid_th = 60.0 # NOTE: set threshold for video similarity

passed = vid_sim > vid_th
print("Face verified:", passed) # True if all windows passed the threshold

Face verified: False


  vid_sim = torch.tensor(sim["video"]).item()          # Mean similarity to reference


Audio ver

In [34]:
# NOTE: Select aggregation methods across windows
# aud_sim = sim["audio"].max(dim=1).values                # Max similarity to any reference
aud_sim = torch.tensor(sim["audio"]).item()          # Mean similarity to reference

aud_th = 60.0 # NOTE: set threshold for video similarity

passed = aud_sim > aud_th
print("Voice verified:", passed) # True if all windows passed the threshold

Voice verified: False


  aud_sim = torch.tensor(sim["audio"]).item()          # Mean similarity to reference
