In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.distributions.binomial as Binomial
torch.set_default_tensor_type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor)
print(torch.cuda.is_available())

In [8]:
class Model(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Model, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
    
    def forward(self, x, hidden):
        output, hidden = self.lstm(x, hidden)
        return torch.round(torch.sigmoid(output)), hidden

In [9]:
#sample fake input
training_data = []
for i in range(2048):
    size = np.random.randint(1, 30)
    data = np.zeros(512 * size)
    s = np.random.choice(len(data), np.random.randint(len(data)), replace=False)
    data[s] = 1
    data = data.reshape(1, size, 512)
    data = torch.Tensor(data)
    training_data.append(data)


In [10]:
model = Model(512, 256)
model.load_state_dict(torch.load("./models/lstm"))
hidden = (torch.Tensor(np.zeros((1, 1, 256))), torch.Tensor(np.zeros((1, 1, 256))))

In [11]:
optimizer = torch.optim.Adam(list(model.parameters()), lr=1e-4)
criterion = torch.nn.BCELoss()

In [12]:
batch_size = 128
for i in range(0, len(training_data), batch_size):
    optimizer.zero_grad()
    output, hidden = model(training_data[i], hidden)
    batch_output = output[0][-1].unsqueeze(0)
    for batch in range(1, batch_size):
        output, hidden = model(training_data[i+batch], hidden)
        batch_output = torch.cat((batch_output, output[0][-1].unsqueeze(0)))
        
    target = Binomial.Binomial(total_count=1, probs=torch.ones((batch_size, 256))/2)
    
    loss = criterion(batch_output, target.sample())
    loss.backward(retain_graph=True)
    print(loss)
    optimizer.step()
    if i % 128 == 0:
        torch.save(model.state_dict(), "model_checkpoint")

tensor(13.8405, grad_fn=<BinaryCrossEntropyBackward>)
tensor(13.7959, grad_fn=<BinaryCrossEntropyBackward>)


KeyboardInterrupt: 

In [None]:
for p in model.parameters():
    print(p, p.shape)