In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms
import torchvision

from utils.Dataset_And_Transforms import FigrimFillersDataset, Downsampling, ToTensor, SequenceModeling, ExpandTargets
from utils.Create_Datasets import create_datasets   

In [2]:
# Build the train/validation/test loaders with a batch size of 5. Each sample is
# passed through ToTensor, Downsampling(10) and SequenceModeling, in that order.
train_loader, val_loader, test_loader = create_datasets(
    batch_size=5,
    data_transform=transforms.Compose(
        [ToTensor(), Downsampling(10), SequenceModeling()]
    ),
)

In [3]:
# NOTE(review): these two lists are not filled anywhere in this cell; they are
# kept only in case later cells rely on them.
fixations_l = []
fixations_lengths = []

# Pull exactly one batch from the training loader to experiment with below.
# `next(iter(...))` fetches the first batch directly instead of entering a loop
# that is left again after its first iteration.
example = next(iter(train_loader))
image = example["image"]
fixations = example["fixations"]
states = example["states"]
length = example["fixations_length"]

In [4]:
# Shape of the fixation batch.
fixations.shape

torch.Size([5, 15, 2])

In [5]:
# Order the sequence lengths from largest to smallest along the batch dimension.
# `sort_idx` records the permutation so that the fixations and the state labels
# can be rearranged along their batch dimension in the same way.
length_s, sort_idx = length.sort(dim=0, descending=True)
# Turn the sorted lengths into a Python list with one entry per batch element.
length_s = [entry for entry in length_s]

In [6]:
# Apply the sort permutation along the batch dimension. Indexing with `sort_idx`
# introduces an additional dimension, which the `view` removes again right away.
# NOTE: both names are overwritten, so re-running this cell permutes the
# already-permuted tensors once more.
fixations = fixations[sort_idx].view(fixations.shape[0], fixations.shape[-2], fixations.shape[-1])
states = states[sort_idx].view(states.shape[0], states.shape[-1])
# Since the input is always the same context vector, it does not matter that
# the targets get reordered.

In [7]:
# Display the reordered state labels; in the recorded output the positions after
# the end of each sequence hold -1.
states

tensor([[ 0,  1,  1,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]])

In [8]:
# Display the sorted lengths (a list with one entry per batch element).
length_s

[tensor([9.]), tensor([9.]), tensor([7.]), tensor([6.]), tensor([4.])]

In [9]:
# Flattened context vector (random stand-in with 2 features).
context_vector = torch.randn(2)

# Feed the same context vector at every time step of every sequence.
# Batch size and number of steps are read from `length` instead of being
# hard-coded, and the tensor is built with `repeat` rather than an
# element-by-element Python loop.
batch_size = length.size(0)
max_steps = int(length.max().item())  # plain int so that all dims are specified as int
context_vectors_steps = context_vector.repeat(batch_size, max_steps, 1)

In [10]:
# `pack_padded_sequence` documents `lengths` as a 1-D tensor or a list of ints,
# whereas `length_s` holds one-element float tensors. Convert every entry to a
# plain int at the call site so the call matches the documented contract.
packed = torch.nn.utils.rnn.pack_padded_sequence(
    input=context_vectors_steps,
    lengths=[int(n) for n in length_s],
    batch_first=True,
)
packed

PackedSequence(data=tensor([[-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935],
        [-1.5460,  0.7935]])

In [11]:
# Single-layer LSTM: 2 input features per step (the context vector size),
# 20 hidden units, batch-first tensor layout.
lstm = nn.LSTM(
    batch_first=True,
    num_layers=1,
    hidden_size=20,
    input_size=2,
)

In [12]:
# Run the LSTM over the packed batch. Because the input is a PackedSequence,
# `output` is a PackedSequence as well; for nn.LSTM the second return value
# (`hidden` here) is the tuple (h_n, c_n) of final hidden and cell states.
output, hidden = lstm(packed)

In [13]:
# Pad the packed batch back into a dense (batch, steps, features) tensor.
# NOTE(review): this unpacks `packed` (the LSTM *input*), not `output` from the
# LSTM call, so `unpacked` contains the context vectors (feature size 2) rather
# than the hidden states (size 20) — confirm this is intended.
# NOTE(review): the returned lengths are bound to `_`, which IPython also uses
# for the last cell output.
unpacked, _= torch.nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)

In [14]:
# Shape after padding: the step dimension equals the longest sequence in the batch.
unpacked.shape

torch.Size([5, 9, 2])

In [15]:
# Display the padded tensor; in the recorded output the steps past a sequence's
# length are filled with zeros.
unpacked

tensor([[[-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935]],

        [[-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935]],

        [[-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [-1.5460,  0.7935],
         [ 0.0000,  0.0000],
        

In [16]:
# Mark the real (non-padded) entries of `unpacked`.
# Comparing the values against the pad value 0 would also discard any genuine
# entry that happens to be exactly 0, so the mask is derived from the sequence
# lengths that were used for packing instead.
seq_lengths = torch.tensor([int(n) for n in length_s])
step_idx = torch.arange(unpacked.size(1))
# (batch, steps) -> (batch, steps, 1) -> broadcast across the feature dimension
mask = (step_idx.unsqueeze(0) < seq_lengths.unsqueeze(1)).unsqueeze(-1).expand_as(unpacked)

In [17]:
# Keep only the entries selected by `mask`; the result is a flat 1-D tensor.
masked_output = torch.masked_select(unpacked, mask)
masked_output.shape

torch.Size([70])

In [18]:
# Drop every fixation entry equal to -1 and flatten what remains into a 1-D tensor.
mask2 = fixations.ne(-1)
masked_fixations = torch.masked_select(fixations, mask2)
masked_fixations.shape

torch.Size([70])

In [19]:
# Drop every state label equal to -1 and flatten what remains into a 1-D tensor.
mask3 = states.ne(-1)
masked_states = torch.masked_select(states, mask3)
masked_states.shape

torch.Size([35])