In [2]:
import torch

from unimodal import AdaFace, ReDimNet
from pipe import MultiModalAuthPipeline, ImagePreprocessor, AudioPreprocessor
from synthweave.utils.datasets import get_datamodule
from synthweave.utils.fusion import get_fusion

  check_for_updates()


### DATASET

In [3]:
# Video and audio preprocessors are configured with the same windowing values
# (window_len=4, step=2).
vid_proc = ImagePreprocessor(window_len=4, step=2)
aud_proc = AudioPreprocessor(window_len=4, step=2)

# Keyword arguments passed through to the dataset via `dataset_kwargs`.
ds_kwargs = dict(
    video_processor=vid_proc,
    audio_processor=aud_proc,
    mode='minimal',
)

# Build the DeepSpeak_v1 datamodule; the trailing comments list the other
# values each option accepts.
dm = get_datamodule(
    "DeepSpeak_v1",
    batch_size=1,
    dataset_kwargs=ds_kwargs,
    sample_mode='single',    # 'single' or 'sequence'
    clip_mode='id',          # 'id' or 'idx'
    clip_to=1,               # 'min' or an int
    clip_selector='first',   # 'first' or 'random'
)

# Prepare the datamodule for the 'fit' stage before requesting any loader.
dm.setup('fit')

In [4]:
# Training dataloader from the datamodule that was set up with stage 'fit'.
train_loader = dm.train_dataloader()

### FEATURE EXTRACTORS

In [5]:
# Audio-side and image-side feature extractors, both created with freeze=True.
# AdaFace additionally receives the path '../../../models'
# (presumably the directory holding its weights — TODO confirm).
aud_model = ReDimNet(freeze=True)
img_model = AdaFace(path='../../../models', freeze=True)

Using cache found in /home/woleek/.cache/torch/hub/IDRnD_ReDimNet_master


/home/woleek/.cache/torch/hub/IDRnD_ReDimNet_master
load_res : <All keys matched successfully>


### FUSION

In [6]:
# Fusion method name and the output size requested from the fusion module.
FUSION, EMB_DIM = "CFF", 256

# Fusion over the "video" and "audio" modality keys.
# Optional extra argument, currently disabled:
#     num_att_heads=4,
fusion = get_fusion(
    fusion_name=FUSION, output_dim=EMB_DIM,
    modality_keys=["video", "audio"], out_proj_dim=256,
)

[INFO] This fusion expects embeddings of shape (batch_size, embed_dim).


### PIPELINE

In [12]:
# Assemble the pipeline from the two backbones, the fusion module and a small
# detection head: one Linear layer mapping EMB_DIM features to a single value,
# followed by a Sigmoid.
pipe = MultiModalAuthPipeline(
    models=dict(audio=aud_model, video=img_model),
    fusion=fusion,
    detection_head=torch.nn.Sequential(
        torch.nn.Linear(EMB_DIM, 1),
        torch.nn.Sigmoid(),
    ),
    freeze_backbone=True,
)

# Switch to evaluation mode; the trailing semicolon hides the module repr.
pipe.eval();

In [27]:
# Take the first batch produced by the training loader.
sample = next(iter(train_loader))

# Print the video shape followed by the audio shape on one line.
print(*(sample[key].shape for key in ('video', 'audio')))

# Forward pass through the pipeline with gradient tracking disabled.
with torch.no_grad():
    out = pipe(sample)

torch.Size([1, 3, 112, 112]) torch.Size([1, 1, 64000])


: 

In [14]:
# Shapes of the 'embedding', 'video' and 'audio' entries of the pipeline output.
tuple(out[key].shape for key in ('embedding', 'video', 'audio'))

(torch.Size([1, 256]), torch.Size([1, 512]), torch.Size([1, 192]))

In [20]:
# Turn the single score in out['logits'] into a label using a 0.5 cut-off.
# The detection head passed to the pipeline ends in a Sigmoid, so 0.5 is the midpoint
# of its output range.
# NOTE(review): assumes out['logits'] is that head's output and that higher scores
# mean 'fake' — confirm against the pipeline. No training step is visible in this
# notebook, so the label shown here is presumably a smoke test, not a real prediction.
'fake' if out['logits'].item() > 0.5 else 'real'

'fake'

In [22]:
# Add random reference tensors shaped like the extracted 'video' and 'audio'
# entries (video is drawn first, then audio).
for ref_key, src_key in (('video_ref', 'video'), ('audio_ref', 'audio')):
    out[ref_key] = torch.rand_like(out[src_key])

# Alternative (disabled): use copies of the sample's own entries as references.
# out['video_ref'] = out['video'].clone()
# out['audio_ref'] = out['audio'].clone()

# Run verification on the augmented output with gradient tracking disabled.
with torch.no_grad():
    sim = pipe.verify(out)

In [23]:
# First-element verification value for video, then audio, as Python floats.
tuple(sim[key][0].item() for key in ('video', 'audio'))

(-4.457475662231445, 1.972063422203064)