In [112]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms
import torchvision

from utils.Dataset_And_Transforms import FigrimFillersDataset, Downsampling, ToTensor, SequenceModeling, ExpandTargets
from utils.Create_Datasets import create_datasets   

In [113]:
train_loader, val_loader, test_loader = create_datasets(batch_size=5, data_transform=transforms.Compose([ToTensor(),Downsampling(10), SequenceModeling()]))

In [114]:
fixations_l = []
fixations_lengths = []
for i, example in enumerate(train_loader): #start at index 0
            # get the inputs
            image = example["image"]
            fixations = example["fixations"]
            states = example["states"]
            length = example["fixations_length"]

            
            if i == 0:
                break

In [115]:
fixations.size()

torch.Size([5, 15, 2])

In [116]:
#Sort the sequence lengths in descending order, keep track of the old indices, as the fixations' and token-indices'
#batch-dimension needs to be rearranged in that way
length_s, sort_idx = torch.sort(length, 0, descending=True)
sort_idx
#make length_s to list
length_s = list(length_s)

In [117]:
#rearrange batch-dimensions (directly getting rid of the additional dimension this introduces)
fixations = fixations[sort_idx].view(fixations.size(0), fixations.size(-2), fixations.size(-1))
states = states[sort_idx].view(states.size(0), states.size(-1))
#Da der Input immer derselbe Kontextvektor ist, macht es nichts, wenn die Targets umsortiert werden

In [118]:
states

tensor([[ 0,  1,  1,  1,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  2, -1, -1, -1, -1, -1, -1, -1, -1]])

In [119]:
length_s

[tensor([10.]), tensor([8.]), tensor([8.]), tensor([8.]), tensor([7.])]

In [120]:
#flattened context vector
context_vector = torch.randn(2)
context_vectors_steps = torch.empty(5, int(max(length).item()), context_vector.size(0)) #so that all dims are specified as int
for i in range(5):
    for j in range(int(max(length).item())):
        context_vectors_steps[i,j] = context_vector

In [48]:
#packed = torch.nn.utils.rnn.pack_padded_sequence(input=context_vectors_steps, lengths=length_s, batch_first=True)
packed

PackedSequence(data=tensor([[ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, -0.6856,  1.2264],
        [ 0.6166, 

In [49]:
#lstm = nn.LSTM(input_size=3, hidden_size=20, num_layers=1, batch_first=True)

In [121]:
import torch.nn.utils.rnn as rnn_utils

input_size = 2 #CNN context vector (eg 100x100, so flattened out 10000)
hidden_size = 20 #vllt eher 10-50 Dimensionen

class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, gpu):
        super(MyRNN, self).__init__()
        self.rec_layer = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.fc_fix = nn.Linear(in_features=hidden_size, out_features=2) #x- and y-coordinate
        self.fc_state = nn.Linear(in_features=hidden_size, out_features=3) #sos, eos, during sequence
        self.gpu = gpu
        if gpu:
            if torch.cuda.is_available():
                device = torch.device("cuda")
                self.cuda()
                
    def forward(self, inputs, length):
        packed_inputs = rnn_utils.pack_padded_sequence(input=inputs, lengths=length, batch_first=True)
        output, (hidden, cell) = self.rec_layer(packed_inputs)
        unpacked_output, _ = rnn_utils.pad_packed_sequence(output, batch_first=True, padding_value=-1)
        out_fix = self.fc_fix(unpacked_output)
        out_state = self.fc_state(unpacked_output)
        return out_fix, out_state

rnn_model = MyRNN(input_size=input_size, hidden_size=hidden_size, gpu=False) #lstm, gru

In [122]:
output, state = rnn_model(context_vectors_steps, length_s)

In [123]:
output.size()

torch.Size([5, 10, 2])

In [124]:
#unpacked, _= torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

In [125]:
fixations = fixations[:,:output.size(1),:]
mask = (fixations != -1)
masked_output = output[mask]
masked_fixations = fixations[mask]
print(masked_fixations.size())
masked_output.size()

torch.Size([82])


torch.Size([82])

In [17]:
masked_output = unpacked[mask]
masked_output.size()

torch.Size([70])

In [18]:
mask2 = (fixations != -1)
masked_fixations = fixations[mask2]
masked_fixations.size()

torch.Size([70])

In [19]:
mask3 = (states != -1)
masked_states = states[mask3]
masked_states.size()

torch.Size([35])

In [126]:
state = state.permute(0,2,1)
state.size()

torch.Size([5, 3, 10])

In [127]:
states = states[:,:output.size(1)]
loss = nn.CrossEntropyLoss(ignore_index=-1)
loss(state, states)

tensor(1.2000, grad_fn=<NllLoss2DBackward>)

In [128]:
states

tensor([[ 0,  1,  1,  1,  1,  1,  1,  1,  1,  2],
        [ 0,  1,  1,  1,  1,  1,  1,  2, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  1,  2, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  1,  2, -1, -1],
        [ 0,  1,  1,  1,  1,  1,  2, -1, -1, -1]])

In [94]:
states.size()

torch.Size([5, 15])