In [2]:
import random

import torch
import torch.nn as nn

In [3]:
def get_style(t, dims, eps=1e-6):
    """Return the mean and standard deviation of ``t`` over ``dims``.

    The standard deviation is the biased one (square root of the mean squared
    deviation), with ``eps`` added under the square root so the result is
    never exactly zero and is safe to divide by.

    Args:
        t: Input tensor.
        dims: Dimension (or list/tuple of dimensions) to reduce over.
        eps: Non-negative stabiliser added to the variance before the square
            root. Defaults to ``1e-6``, the value previously hard-coded.

    Returns:
        ``(mu, sigma)``; the reduced dimensions are kept with size 1, so both
        broadcast against ``t``.
    """
    mu = torch.mean(t, dim=dims, keepdim=True)
    diff_mean = torch.mean(torch.square(t - mu), dim=dims, keepdim=True)
    sigma = torch.sqrt(diff_mean + eps)
    return mu, sigma

In [4]:
def change_style(z1, z2, dims=[-1], alpha=-1):
    """Re-normalise ``z1`` with statistics blended between ``z1`` and ``z2``.

    ``z1`` is standardised with its own mean/std over ``dims`` and then
    rescaled and shifted with
    ``alpha * stats(z1) + (1 - alpha) * stats(z2)``.

    Args:
        z1: Tensor whose normalised values are kept.
        z2: Tensor supplying the other set of statistics; its statistics must
            broadcast against those of ``z1``.
        dims: Dimensions over which the statistics are computed.
        alpha: Weight given to the statistics of ``z1``. The default ``-1`` is
            a sentinel meaning "draw one weight from ``random.random()``".

    Returns:
        The re-styled tensor.
    """
    own_mean, own_std = get_style(z1, dims=dims)
    other_mean, other_std = get_style(z2, dims=dims)

    # -1 requests a fresh random weight; any other value is used as given.
    weight = random.random() if alpha == -1 else alpha
    other_weight = 1.0 - weight

    blended_mean = weight * own_mean + other_weight * other_mean
    blended_std = weight * own_std + other_weight * other_std

    standardised = (z1 - own_mean) / own_std
    return blended_std * standardised + blended_mean

In [5]:
def style_aug(z, dims=[-1], ids=None):
    """Build two style-augmented views of ``z`` from a reordering of dim 0.

    Args:
        z: Tensor indexed along dim 0.
        dims: Dimensions over which style statistics are computed.
        ids: Index list selecting each entry's partner along dim 0. When
            ``None``, a random shuffle of ``range(len(z))`` is drawn.

    Returns:
        ``(z_styled, z_adv)``. ``z_styled`` is ``z`` with statistics blended
        with its partner using a randomly drawn weight. ``z_adv`` is the
        partner's normalised values carrying the statistics of ``z``
        (``alpha=0``).
    """
    order = ids
    if order is None:
        order = list(range(len(z)))
        random.shuffle(order)

    partner = z[order, ...]
    z_styled = change_style(z, partner, dims=dims)
    z_adv = change_style(partner, z, dims=dims, alpha=0)
    return z_styled, z_adv

In [6]:
def style_tensors(*tensors, dims=[-1]):
    """Apply ``style_aug`` to several tensors using one shared shuffle.

    The same shuffled index list, drawn from the length of the first tensor,
    is passed to every ``style_aug`` call.

    Args:
        *tensors: Tensors to augment. The index list is built from the first
            tensor and reused to index every tensor along dim 0.
        dims: Dimensions over which style statistics are computed.

    Returns:
        A list with one ``style_aug`` result per input tensor, in input order.
        Empty when no tensors are given.
    """
    # Without this guard, tensors[0] raises IndexError on an empty call.
    if not tensors:
        return []

    ids = list(range(len(tensors[0])))
    random.shuffle(ids)
    return [style_aug(tensor, dims=dims, ids=ids) for tensor in tensors]

In [7]:
# Two random sample batches of shape (16, 3, 1024) to exercise the helpers.
f_video = torch.rand(16, 3, 1024)
f_audio = torch.rand(16, 3, 1024)

# Smoke call on a single tensor; the result is not kept.
style_aug(f_video)

# Augment both tensors with one shared shuffle; results come back in argument order.
(
    (f_a_style, f_a_adv),
    (f_v_style, f_v_adv),
) = style_tensors(f_audio, f_video)

In [8]:
# groups == in_channels, so every output channel convolves a single input channel.
depthwise_conv = nn.Conv3d(256, 256, 3, padding=1, groups=256)

# Print the shape of each parameter tensor of the layer.
for para in depthwise_conv.parameters():
    print(para.shape)

torch.Size([256, 1, 3, 3, 3])
torch.Size([256])


In [9]:
def fuse_audio_video_with_p(f_video, f_audio):
    """Concatenate video and audio features, re-pairing ``B // 2`` rows.

    A random permutation of the batch indices is split into two disjoint
    groups of ``B // 2`` indices. Rows in the first group keep their own
    video features but receive the audio features of the second group, so
    they are never paired with their own audio. All other rows keep their own
    audio. With an odd batch size one index is left out of both groups.

    Args:
        f_video: Tensor indexed by sample along dim 0.
        f_audio: Tensor that can be concatenated with ``f_video`` along the
            last dimension.

    Returns:
        ``(res, shuffle_ids)``: the fused tensor and the list of row indices
        whose audio half was replaced.
    """
    batch_size = f_video.shape[0]
    half = batch_size // 2
    perm = random.sample(range(batch_size), batch_size)

    shuffle_ids = perm[:half]
    # Take exactly `half` donors: an open-ended slice has one extra entry for
    # odd batch sizes, which makes the concatenation below fail.
    donor_ids = perm[half : 2 * half]

    res = torch.cat([f_video, f_audio], dim=-1)
    if half:
        res[shuffle_ids, :] = torch.cat(
            [f_video[shuffle_ids, :], f_audio[donor_ids, :]], dim=-1
        )
    return res, shuffle_ids

In [10]:
# Demo call; the function name was misspelled ("aduio"), which raises NameError on a fresh kernel.
fuse_audio_video_with_p(f_video, f_audio)

(tensor([[[0.3621, 0.8210, 0.2526,  ..., 0.3885, 0.5031, 0.7670],
          [0.9842, 0.4927, 0.3066,  ..., 0.0049, 0.9713, 0.4688],
          [0.8444, 0.4985, 0.1776,  ..., 0.9809, 0.3983, 0.7339]],
 
         [[0.2346, 0.9648, 0.3303,  ..., 0.5126, 0.6680, 0.1643],
          [0.1991, 0.4186, 0.5656,  ..., 0.1372, 0.3342, 0.4348],
          [0.3077, 0.7347, 0.4348,  ..., 0.5088, 0.9867, 0.4648]],
 
         [[0.8383, 0.6945, 0.9600,  ..., 0.3152, 0.6015, 0.0760],
          [0.7534, 0.6611, 0.9517,  ..., 0.8275, 0.2450, 0.8317],
          [0.5968, 0.1178, 0.5724,  ..., 0.9501, 0.6088, 0.0088]],
 
         ...,
 
         [[0.2669, 0.7850, 0.7322,  ..., 0.2213, 0.7151, 0.2965],
          [0.1696, 0.9538, 0.1899,  ..., 0.2122, 0.5088, 0.1797],
          [0.7217, 0.2593, 0.4761,  ..., 0.9330, 0.7914, 0.9525]],
 
         [[0.6970, 0.9671, 0.4331,  ..., 0.5650, 0.8449, 0.0608],
          [0.2549, 0.5054, 0.5485,  ..., 0.7795, 0.6867, 0.1536],
          [0.8304, 0.1958, 0.0025,  ..., 0.5618,

In [19]:
def fuse_audio_video_with_shuffle(f_video, f_audio):
    """Concatenate video features with a randomly permuted copy of the audio.

    Row ``i`` of the result holds ``f_video[i]`` followed by
    ``f_audio[perm[i]]``, where ``perm`` is a random permutation of the batch
    indices.

    Args:
        f_video: Tensor indexed by sample along dim 0.
        f_audio: Tensor that can be concatenated with ``f_video`` along the
            last dimension.

    Returns:
        ``(res, shuffle_ids)``. ``shuffle_ids`` lists the audio source indices
        that ended up on a different row. Because ``perm`` is a permutation,
        these are the same indices as the rows whose audio changed, in
        permutation order.
    """
    batch_size = f_video.shape[0]
    indices = list(range(batch_size))
    perm = random.sample(indices, batch_size)

    fused = torch.cat([f_video, f_audio[perm, :]], dim=-1)

    displaced = []
    for position, source in enumerate(perm):
        if source != position:
            displaced.append(source)
    return fused, displaced