In [61]:
import torch
import torch.nn as nn
import pickle

In [62]:
#%%
#Analogue of the nn.RNN module
class MyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, nonlinearity='tanh'):
        super(MyRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Initialize parameters
        self.weight_ih = nn.Parameter(torch.Tensor(num_layers, hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.Tensor(num_layers, hidden_size, hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.Tensor(num_layers, hidden_size))
            self.bias_hh = nn.Parameter(torch.Tensor(num_layers, hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        self.nonlinearity = nonlinearity

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / (self.hidden_size ** 0.5)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx=None):
        '''
        This function defines a forward RNN pass  

        Input: tensor of shape (batch_size, sequence_length, input_size)'
        Output: (output, hx) where output is a list of tensors oh  cell
        predictions, shape (num_layers, batch_size, hidden_size)
        '''
        # Initializes the hidden state if not provided
        if hx is None:
            hx = torch.zeros(self.num_layers, input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)

        outputs = []

        #iterate over each time step
        for i in range(input.size(1)):
            hx = self.rnn_cell(input[:, i, :], hx)
            outputs.append(hx.unsqueeze(1))

        output = torch.cat(outputs, dim=1)
        return output, hx

    def rnn_cell(self, input, hx):
        '''
        Defines a run of one RNN batch for one time step

        Inputs: 
            input tensor of hape (batch_size, 1, input_size)
            hx tensor of shape (num_layers, batch_size, hidden_size)
        Output:
            tensor of shape (num_layers, batch_size, hidden_size)

        '''
        # Apply RNN cell computation  --> tensor (batch_size, hidden_size)
        gates = torch.matmul(input, self.weight_ih.transpose(0, 1)) + torch.matmul(hx, self.weight_hh.transpose(0, 1))
        if self.bias_ih is not None:
            gates += self.bias_ih.unsqueeze(0)
            gates += self.bias_hh.unsqueeze(0)
        if self.nonlinearity == 'tanh':
            return torch.tanh(gates)
        elif self.nonlinearity == 'relu':
            return torch.relu(gates)
        else:
            raise ValueError("Unsupported nonlinearity. Choose from 'tanh' or 'relu'.")



In [66]:
# %%
class fullRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(fullRNN, self).__init__()
        self.hidden_size=hidden_size
        self.rnn_cell=nn.RNN(input_size, hidden_size)
        self.output_layer=nn.Linear(hidden_size, output_size)

    def forward(self, ordered_text):
        '''
        This functions defines forward prop through our RNN network.
        The input is a tensor of shape (seq_length, batch_size, input_size)
        The seq_length is number of examples
        '''
        #Initiates the hidden layer for the whole text
        print(ordered_text.shape())
        ordered_text = ordered_text.unsqueeze(1)  # Adds a batch dimension
        hidden = torch.zeros(1, ordered_text.size(1), self.hidden_size)
        # hidden=torch.zeros (ordered_text.size(1), self.hidden_size)
        rnn_output, hidden = self.rnn_cell(ordered_text, hidden)
        output=self.output_layer(rnn_output[-1, :, :])
        return output

input_size = 128
hidden_size = 100
output_size = 100

# Step 1 - Create RNN for Query Tower + for Doc tower 
queryRNN = fullRNN(input_size, hidden_size, output_size)

# Step 2 - Load input data - pickle files have been tokenised by sentence piece and embedded by
# Data in format - [tesnor([tensor(query), tensor(rel_docs), tensor(irr_docs)]), .... ]
# For query
testData = []
trainingData = [] 
# To prep the data
validationData = []

with open('tokenised_triplets/test.pkl', 'rb') as file:
    testData = pickle.load(file)
with open('tokenised_triplets/training.pkl', 'rb') as file:
    trainingData = pickle.load(file)
with open('tokenised_triplets/validation.pkl', 'rb') as file:
    validationData = pickle.load(file)

    




In [64]:
# Take the query out of the triplet 
query_list = []
for (query, _, _ ) in trainingData:
#  This gives tensor([w1, w2, w3,...wn]) for each individual query
# Take the query out of a tensor form and keep as a list
    query_as_list = query.tolist()
# Then iterate over all of the triplets and pull them all out 
    query_list.append(query_as_list)
# Put them all in one list 
# Make this list a tensor
tensor_query_list = torch.tensor(query_list)
print(tensor_query_list)
# Put into model 

tensor([[ 0.0450, -0.1094,  0.2346,  ...,  0.1490, -0.1885, -0.0019],
        [ 0.0450, -0.1094,  0.2346,  ...,  0.1490, -0.1885, -0.0019],
        [ 0.0450, -0.1094,  0.2346,  ...,  0.1490, -0.1885, -0.0019],
        ...,
        [-0.0671, -0.1579,  0.3436,  ...,  0.0970, -0.0428, -0.2825],
        [-0.0671, -0.1579,  0.3436,  ...,  0.0970, -0.0428, -0.2825],
        [-0.0671, -0.1579,  0.3436,  ...,  0.0970, -0.0428, -0.2825]])


In [71]:
# Step 3 - Pass data into model 
criterion = nn.NLLLoss()
learning_rate = 0.005 # param, play around with to learn
optimizer = torch.optim.SGD(queryRNN.parameters(), lr=learning_rate) #stochastic gradient descent

print(tensor_query_list.shape)
print(tensor_query_list[0].shape)
print(tensor_query_list[0])
output = queryRNN.forward(tensor_query_list)
# output.forward(trainintensor_query_listData[0][0])

print(output)

# Step 4 - Evaluate loss function 


torch.Size([814, 128])
torch.Size([128])
tensor([ 4.5037e-02, -1.0937e-01,  2.3464e-01,  2.8838e-01,  2.1045e-01,
        -3.6086e-02, -3.6875e-02, -1.3356e-01, -1.8880e-02, -3.9864e-02,
         2.0278e-01, -9.6033e-04, -1.8266e-01,  4.3697e-02,  7.3484e-02,
         1.6414e-01, -1.6777e-01,  1.0787e-01, -2.3564e-01,  3.5150e-01,
         5.3646e-02, -1.8614e-01, -2.8010e-01, -3.0639e-01, -1.3082e-01,
         6.8598e-02, -1.4798e-01,  3.2496e-02,  1.7219e-01, -8.9243e-02,
        -1.1667e-01,  7.7255e-03,  6.9057e-02, -2.2016e-04,  5.5669e-02,
         1.9851e-01,  8.6366e-02, -1.1740e-01, -6.8780e-02, -1.2199e-01,
        -1.3029e-01,  4.2058e-01,  9.6036e-02,  7.6770e-02,  2.1897e-01,
        -1.2299e-01, -8.5546e-02,  3.8775e-02,  1.7612e-01,  2.7531e-01,
        -3.0188e-01,  6.1773e-02,  2.6999e-01,  1.7055e-01,  3.0800e-01,
         1.4514e-01,  1.6264e-02, -2.5426e-01, -2.3492e-01,  1.1858e-01,
        -5.6707e-02, -4.4231e-02,  1.8520e-01, -8.9252e-02,  1.6647e-01,
         8

TypeError: 'torch.Size' object is not callable