# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class TemporalMask(nn.Module):
    """Predict a single-channel mask from a 3-channel video clip.

    The clip is sub-sampled into four views that trade temporal resolution
    for spatial resolution (temporal stride / spatial stride):

        x1: 6 / 1    x2: 3 / 2    x3: 2 / 4    x4: 1 / 8

    Each view is encoded by its own 3x3x3 convolution into one channel.
    The encodings are then merged from the spatially coarsest view (x4) to
    the finest (x1): at each step the running result is max-pooled along
    time, upsampled x2 spatially with a transposed convolution, and added
    to the next finer encoding.

    For an input of shape (B, 3, T, H, W) the output has shape
    (B, 1, T // 6, H, W). No activation is applied, so the returned values
    are unbounded.

    Args:
        verbose: when True, print the intermediate tensor shapes on every
            forward pass (useful when debugging the pyramid). Defaults to
            False so that the forward pass does not write to stdout.
    """

    # Input sizes must be multiples of these for the pyramid levels to line
    # up: 6 is the coarsest temporal stride, 8 the coarsest spatial stride.
    _TEMPORAL_MULTIPLE = 6
    _SPATIAL_MULTIPLE = 8

    def __init__(self, verbose=False):
        super().__init__()
        self.verbose = verbose

        # Mask Encoding: one 3 -> 1 channel, shape-preserving conv per view.
        self.l1 = nn.Conv3d(3, 1, (3, 3, 3), padding=(1, 1, 1), bias=True)
        self.l2 = nn.Conv3d(3, 1, (3, 3, 3), padding=(1, 1, 1), bias=True)
        self.l3 = nn.Conv3d(3, 1, (3, 3, 3), padding=(1, 1, 1), bias=True)
        # Kept inside a Sequential so the parameter names ("l4.0.*") stay
        # compatible with previously saved state dicts.
        self.l4 = nn.Sequential(
            nn.Conv3d(3, 1, (3, 3, 3), padding=(1, 1, 1), bias=True),
        )

        # Downsampling Temporally, Upsampling Spatially
        # u4: T -> T/2, (H, W) -> (2H, 2W); maps the x4 level onto the x3 level.
        self.u4 = nn.Sequential(
            nn.MaxPool3d((2, 1, 1), stride=(2, 1, 1)),
            nn.ConvTranspose3d(1, 1, (1, 3, 3), padding=(0, 1, 1),
                               stride=(1, 2, 2), output_padding=(0, 1, 1),
                               bias=True),
        )
        # u3: T -> T/3 by pooling, then x2 in every dimension, i.e. net
        # T -> 2T/3 and (H, W) -> (2H, 2W); maps the x3 level onto the x2 level.
        self.u3 = nn.Sequential(
            nn.MaxPool3d((3, 1, 1), stride=(3, 1, 1)),
            nn.ConvTranspose3d(1, 1, (3, 3, 3), padding=(1, 1, 1),
                               stride=(2, 2, 2), output_padding=(1, 1, 1),
                               bias=True),
        )
        # u2: T -> T/2, (H, W) -> (2H, 2W); maps the x2 level onto the x1 level.
        self.u2 = nn.Sequential(
            nn.MaxPool3d((2, 1, 1), stride=(2, 1, 1)),
            nn.ConvTranspose3d(1, 1, (1, 3, 3), padding=(0, 1, 1),
                               stride=(1, 2, 2), output_padding=(0, 1, 1),
                               bias=True),
        )

    def forward(self, x):
        """Compute the mask for a batch of clips.

        Args:
            x: tensor of shape (B, 3, T, H, W) with T a multiple of 6 and
                H, W multiples of 8.

        Returns:
            Tensor of shape (B, 1, T // 6, H, W).

        Raises:
            ValueError: if ``x`` is not 5-D or its sizes do not satisfy the
                divisibility requirements above.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D input (B, C, T, H, W), got {x.dim()}-D "
                f"tensor of shape {tuple(x.shape)}"
            )
        frames, height, width = x.shape[2:]
        # Without this check a bad size either fails deep inside the network
        # with an opaque size-mismatch error or, for size-1 spatial dims,
        # silently broadcasts into a mask of the wrong size.
        if (frames % self._TEMPORAL_MULTIPLE
                or height % self._SPATIAL_MULTIPLE
                or width % self._SPATIAL_MULTIPLE):
            raise ValueError(
                f"T must be a multiple of {self._TEMPORAL_MULTIPLE} and H, W "
                f"multiples of {self._SPATIAL_MULTIPLE}, got "
                f"T={frames}, H={height}, W={width}"
            )

        # Four views of the clip: fewer frames at full resolution down to
        # every frame at 1/8 resolution.
        x1 = x[:, :, ::6]
        x2 = x[:, :, ::3, ::2, ::2]
        x3 = x[:, :, ::2, ::4, ::4]
        x4 = x[:, :, :, ::8, ::8]

        if self.verbose:
            print(x1.shape, x2.shape, x3.shape, x4.shape)

        y1 = self.l1(x1)
        y2 = self.l2(x2)
        y3 = self.l3(x3)
        y4 = self.l4(x4)

        if self.verbose:
            print(y1.shape, y2.shape, y3.shape, y4.shape)

        # Merge from the spatially coarsest level up to the finest one.
        z3 = self.u4(y4) + y3
        z2 = self.u3(z3) + y2
        z1 = self.u2(z2) + y1

        if self.verbose:
            print(z3.shape, z2.shape, z1.shape)

        return z1

def maskMultiply(x, mask):
    """Multiply a video tensor by a mask resized to the video's resolution.

    The mask is trilinearly interpolated to ``x``'s (T, H, W) size and then
    multiplied element-wise with ``x``; a single-channel mask broadcasts
    over ``x``'s channels.

    Args:
        x: tensor of shape (B, C, T, H, W).
        mask: tensor of shape (B, 1, T', H', W'). Typically T' = T // N for
            a temporal downsampling factor N with H' = H and W' = W, but any
            (T', H', W') is accepted and resized to match ``x``.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` or ``mask`` is not 5-D.
    """
    if x.dim() != 5 or mask.dim() != 5:
        raise ValueError(
            "x and mask must both be 5-D (B, C, T, H, W); got shapes "
            f"{tuple(x.shape)} and {tuple(mask.shape)}"
        )

    # Resize along time AND space so the result really matches x's shape;
    # when H and W already agree this is the same as resizing time only.
    target_size = tuple(x.shape[2:])
    interpolated_mask = F.interpolate(
        mask, size=target_size, mode='trilinear', align_corners=False
    )  # TODO: Check align_corners performance
    return x * interpolated_mask


# In [None]:
# Mask network with freshly initialised (untrained) weights.
maskModel = TemporalMask()

# Synthetic clip in (B, C, T, H, W) layout: one 300-frame, 3-channel,
# 128x128 video of uniform noise.
randBatch = torch.rand(1, 3, 300, 128, 128)
# Set rows 0-63 of the first frame's first channel to 1 so that frame has a
# recognisable structure.
randBatch[0, 0, 0, :64] = 1

# Run the clip through the network to obtain its mask.
sampleMask = maskModel(randBatch)

# In [None]:
# Apply the predicted mask to the clip it was computed from.
new_Vid = maskMultiply(x=randBatch, mask=sampleMask)

# In [None]:
# Show, each in its own figure, the first frame (batch 0, channel 0) of the
# input clip, of the predicted mask, and of the masked clip, in that order.
for _frame in (randBatch[0, 0, 0], sampleMask[0, 0, 0], new_Vid[0, 0, 0]):
    plt.figure()
    # detach() drops the autograd graph so the tensor can be viewed as NumPy.
    plt.imshow(_frame.detach().numpy())
    plt.show()