In [1]:
import torch
import torch.nn as nn
from src.models.frame_detector import Transformer, YNetEncoder
from src.models.fuvai import YNet
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence


this_path = Path().resolve()
data_path = this_path.parent / 'data/preprocessed'
assert data_path.exists()

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
hidden_dim = 768
# pretrained encoder
ckpt_path = data_path.parent / 'fuvai_weights.pt'
pretrained_model = YNet(1, 64, 1)
ckpt = torch.load(ckpt_path)
pretrained_model.load_state_dict(ckpt)
# pretrained_model.to(DEVICE)

encoder = YNetEncoder(pretrained_model=pretrained_model)
encoder.eval()
encoder.to(DEVICE)

proj = nn.Linear(encoder.out_channels,
                              hidden_dim)

# transformer
transformer = Transformer(hidden_dim=hidden_dim)
transformer.to(DEVICE)

input = torch.randn(1, 840, 1, 256, 256).to(DEVICE)



In [11]:
torch.unbind(input, dim=1)[0].shape

torch.Size([1, 1, 256, 256])

In [12]:
with torch.no_grad():
    output = encoder(input)
    print(output.shape)

torch.Size([840, 512])


In [10]:
input = [torch.ones(i, 4) for i in range(2, 7)]
print(len(input))
input = pad_sequence(input, batch_first=True, padding_value=0)
print(input.shape)

5
torch.Size([5, 6, 4])


In [11]:
input

tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])

In [12]:
def create_mask(src, padding_value=0):
    src_seq_len = src.shape[1]
    src_mask = torch.zeros((src_seq_len, src_seq_len)).type(torch.bool)
    src_padding_mask = (src == padding_value)
   
    return src_mask, src_padding_mask

src_mask, src_padding_mask = create_mask(input)
print(src_mask.shape, src_padding_mask.shape)

torch.Size([6, 6]) torch.Size([5, 6, 4])
