In [None]:
import sys
# Put the parent directory on the import path (presumably where the `src`
# package imported in the next cell lives -- confirm). The membership check
# keeps repeated runs of this cell from stacking duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import torch
from src.pipe import MultiModalAuthPipeline, ImagePreprocessor, AudioPreprocessor, AdaFace, ReDimNet, ClassifierHead
from synthweave.utils.datasets import get_datamodule, SWAN_DF_Dataset
from synthweave.utils.fusion import get_fusion
from pathlib import Path
import json

class dotdict(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``d.key`` is equivalent to ``d["key"]``. Only the top level is wrapped:
    nested dictionaries stay plain ``dict`` objects.
    """

    def __getattr__(self, name):
        """Return the item stored under ``name``.

        Raises:
            AttributeError: if ``name`` is not a key. Raising ``AttributeError``
                (not ``KeyError``) is what ``hasattr()`` and
                ``getattr(obj, name, default)`` rely on to work.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __setattr__(self, name, value):
        """Store ``value`` under the key ``name``."""
        self[name] = value

    def __delattr__(self, name):
        """Remove the key ``name``; raise ``AttributeError`` if it is absent."""
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

### DATASET

In [None]:
# Arguments forwarded to the dataset class. The absolute paths are
# machine-specific: point them at the local SWAN-Idiap / SWAN-DF copies and at
# the directory holding the model files.
ds_kwargs = dict(
    root_real="/home/woleek/SynthWeave/data/SWAN-Idiap",
    root_df="/home/woleek/SynthWeave/data/SWAN-DF",
    resolutions=["320x320"],
    # Preprocessors are handed to the dataset here, which is why the pipeline
    # itself is later built without any.
    video_processor=ImagePreprocessor(
        window_len=4,
        step=1,
        estimate_quality=False,
        models_dir="/home/woleek/SynthWeave/models",
        quality_model_type="ir50",
    ),
    audio_processor=AudioPreprocessor(window_len=4, step=1, use_vad=True),
)

dm = get_datamodule(
    "SWAN_DF",
    dataset_cls=SWAN_DF_Dataset,
    dataset_kwargs=ds_kwargs,
    batch_size=1,  # NOTE: currently single window fusions don't ignore padding
    sample_mode="sequence",  # single, sequence
    clip_mode=None,
    pad_mode="zeros",
    encode_ids=False,
)

# Prepare the datamodule so that its dataloaders can be requested below.
dm.setup()

### PIPELINE

In [None]:
# Directory with the artifacts of the trained fusion module (machine-specific path).
fusion_module_dir = Path("/home/woleek/SynthWeave/SynthWeave/examples/multimodal_auth/CAFF")

# Saved configuration, wrapped so its entries read as attributes
# (args.fusion, args.emb_dim, args.proj_dim, args.iil_mode are used below).
args = dotdict(json.loads(fusion_module_dir.joinpath("args.json").read_text()))

# Checkpoint file that holds the trained weights.
weights_path = fusion_module_dir.joinpath("detection_module.ckpt")

In [None]:
# The pipeline is built without preprocessors: the ImagePreprocessor and
# AudioPreprocessor instances are already passed to the dataset through the
# DataModule's dataset kwargs.
# Alternative kept for reference (presumably lets the pipeline run the
# preprocessing itself -- confirm against MultiModalAuthPipeline):
# preprocessors = {
#     "video": ImagePreprocessor,
#     "audio": AudioPreprocessor
# }

preprocessors = None # passed in DataModule

In [None]:
# One backbone per modality, keyed by the modality name used throughout the
# notebook ("video" / "audio").
models = dict(
    video=AdaFace(path="/home/woleek/SynthWeave/models", model_type="ir50"),
    audio=ReDimNet(),
)

In [None]:
# Input width per modality (presumably the embedding sizes produced by the
# video and audio backbones -- confirm). The key order doubles as the modality
# order handed to the fusion module.
modality_dims = {"video": 512, "audio": 192}

fusion = get_fusion(
    fusion_name=args.fusion,
    output_dim=args.emb_dim,
    modality_keys=list(modality_dims),
    input_dims=modality_dims,
    out_proj_dim=args.proj_dim,
)

# Single-output head on top of the fused embedding; its logits are passed
# through a sigmoid when predictions are computed further down.
detection_head = ClassifierHead(num_classes=1, input_dim=args.emb_dim)

In [None]:
# Assemble the pipeline from the components created in the previous cells.
pipe = MultiModalAuthPipeline(
    processors=preprocessors,
    models=models,
    fusion=fusion,
    detection_head=detection_head,
    freeze_backbone=True,
    iil_mode=args.iil_mode,
)

# SECURITY NOTE: torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(weights_path, map_location="cpu")

# Checkpoint keys carry a leading "pipeline." that the bare pipeline does not
# use. Strip it only as a prefix: a plain str.replace would also rewrite keys
# that merely contain "pipeline." somewhere in the middle.
ckpt_prefix = "pipeline."
state_dict = {
    (k[len(ckpt_prefix):] if k.startswith(ckpt_prefix) else k): v
    for k, v in checkpoint["state_dict"].items()
}

# strict=False tolerates keys that do not line up, so report the mismatches
# instead of discarding them -- otherwise a checkpoint that matches nothing
# would load "successfully" and leave the modules with their initial weights.
load_result = pipe.load_state_dict(state_dict, strict=False)
for kind, keys in (
    ("missing", load_result.missing_keys),
    ("unexpected", load_result.unexpected_keys),
):
    if keys:
        print(f"[load_state_dict] {len(keys)} {kind} key(s), e.g. {list(keys[:5])}")

# Move to the GPU and switch to inference mode; the trailing semicolon
# suppresses the module repr in the cell output.
pipe = pipe.cuda()
pipe.eval();

### EXAMPLE RUN

In [None]:
# Pull the first batch of the test split to use as the demo input.
test_loader = dm.test_dataloader()
first_batch_iter = iter(test_loader)
sample = next(first_batch_iter)

In [None]:
# Show the metadata that accompanies the selected sample.
sample["metadata"]

In [None]:
# Build a GPU-side copy of the batch instead of overwriting `sample` in place.
# Mutating `sample` made this cell unsafe to re-run: a second squeeze(0) would
# also drop the leading dimension of an already-squeezed tensor whenever that
# dimension has size 1.
batch = dict(sample)
for modality in ("video", "audio"):
    # remove batch dim (the loader uses batch_size=1), then place on GPU
    batch[modality] = sample[modality].squeeze(0).cuda()

with torch.no_grad():
    out = pipe(batch)

In [None]:
def label_name(class_index):
    """Map a class index to its display name: 0 is bonafide, anything else is a deepfake."""
    if class_index == 0:
        return "Bonafide"
    return "DeepFake"


# Logits -> probabilities, then average them into a single score for the clip
# (presumably one logit per window -- confirm).
probs = out["logits"].sigmoid().cpu()
prob_per_clip = probs.mean()
preds_per_clip = (prob_per_clip >= 0.5).long()  # NOTE: set threshold

gt_index = sample['metadata']["label"].cpu().item()
print("GT:", label_name(gt_index))
print("Pred:", label_name(preds_per_clip.item()))

In [None]:
# Embeddings of the first window, moved to the CPU.
video_emb = out["video"][0].cpu()
audio_emb = out["audio"][0].cpu()

with torch.no_grad():
    sim = pipe.verify({
        "video": video_emb,  # example for 1st window
        "audio": audio_emb,
        # NOTE: put reference embeddings here. The random placeholders are
        # created from the CPU copies so that each query and its reference live
        # on the same device (rand_like on out[...] would follow out's device).
        "video_ref": torch.rand_like(video_emb),
        "audio_ref": torch.rand_like(audio_emb),
    })

In [None]:
# Similarity scores as plain Python numbers, in (video, audio) order.
tuple(sim[modality].item() for modality in ("video", "audio"))