#### Importing the other packages (helper notebooks in `Packages/`)

In [None]:
%run Packages/libraries.ipynb
%run Packages/data_loader.ipynb
%run Packages/Network.ipynb
%run Packages/training.ipynb
%run Packages/generate.ipynb

In [None]:
def write_loss_on_file(loss, filename):
    """Write a loss history to a tab-separated text file.

    Each line has the form ``<index>\\t<value>\\n`` where the index starts at 1.
    The file is truncated and fully rewritten on every call, so it can be
    called after every epoch with the growing history.

    Args:
        loss: iterable of loss values; each is rendered with ``str()``.
        filename: path of the output file; overwritten if it already exists.
    """
    # The context manager closes the file even if a write raises,
    # which the previous open()/close() pair did not guarantee.
    with open(filename, 'w') as f:
        f.writelines(f'{i}\t{value!s}\n' for i, value in enumerate(loss, start=1))

#### Parameters

In [None]:
# Global run parameters:
#   n_chars  - characters per dataset sample (passed as n_char to LewisCarrollDataset)
#   n_epochs - number of passes made by the training loop
n_chars, n_epochs = 180, 600

#### Creation of the Dataset

In [None]:
# Character-level dataset built from the full Alice text, n_chars characters per sample.
LWds = LewisCarrollDataset('Alice_total.txt', n_char=n_chars)
alphabet_len = len(LWds.alphabet)

# Plain pipeline: one-hot encode, then convert to tensors.
# Not used by the cells shown here; kept for other cells / experiments.
trans = transforms.Compose([OneHotEncoder(alphabet_len), ToTensor()])

# Pipeline actually attached to the dataset: one-hot -> training_data_conv -> tensor.
# NOTE(review): training_data_conv presumably produces the 'x_rnn' / 'y_conv' /
# 'speech' fields read by the training loop - TODO confirm.
trans_conv = transforms.Compose([
    OneHotEncoder(alphabet_len),
    training_data_conv(),
    ToTensor(),
])
LWds.transform = trans_conv

#### Initialization of the Network

In [None]:
# Network hyper-parameters. The input width equals the alphabet size, the same
# width used by OneHotEncoder(len(alphabet)) in the dataset transform.
input_size = len(LWds.alphabet)
hidden_units, layers_num = 280, 2
linear = 140
dropout_prob = 0.25
net = lstm_double(input_size, hidden_units, layers_num, linear, dropout_prob)

# Objective and optimiser: cross-entropy loss, RMSprop with learning rate 4e-3.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(net.parameters(), lr=0.004)

In [None]:
# Shuffled mini-batches of 256 samples drawn from the dataset each epoch.
dataloader = DataLoader(LWds, shuffle=True, batch_size=256)

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA (moving the network to "cuda" raises there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

#### Training

In [None]:
loss_log = []              # mean training loss of each finished epoch, in order
best_loss = float('inf')   # lowest epoch loss so far; inf so the first finite epoch loss is always checkpointed
load_params = False        # if True, reload the checkpoint at the start of EVERY epoch (including the first)
load_best_params = True    # if True, from the second epoch on restart each epoch from the best checkpoint
ckpt_path = 'net/ckpt/best.ckpt'
# Only reload the "best" checkpoint once this run has written one; otherwise the
# load would pick up a stale file from an earlier run (or fail if none exists).
saved_this_run = False

times_per_epoch = len(dataloader)
time_start = time.time()
# Prompt text for generate_text_multi (currently commented out at the end of the loop).
seed = 'alice fell in the hole and desperately cried out for help. luckily, a white rabbit was passing by ' 
seed += 'there and heard her call. ‘what are you doing down there little girl?’ asked the white rabbit. '
seed += '‘i fell down, can you help me?’ cried out alice.'

for epoch in range(n_epochs):
    if load_params or (load_best_params and epoch > 0 and saved_this_run):
        # map_location lets a checkpoint written on one device load onto `device`.
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
        net.to(device)
        print('Loaded!')
    # Put the module in training mode every epoch, regardless of what other cells left it in.
    net.train()
    print('##################################')
    print('## EPOCH %d' % (epoch + 1))
    print('##################################')
    loss_placeholder = []  # per-batch losses of this epoch
    for counter, batch_sample in enumerate(dataloader, start=1):
        # Extract batch
        x_rnn = batch_sample['x_rnn'].to(device)
        y = batch_sample['y_conv'].to(device)
        speech = batch_sample['speech']
        # Update network
        batch_loss = train_batch(net, x_rnn, y, speech, loss_fn, optimizer)
        loss_placeholder.append(batch_loss)
        if counter % 25 == 0:
            elapsed = time.time() - time_start
            progress = int(100 * counter / times_per_epoch)  # percent of this epoch done
            batches_done = counter + times_per_epoch * epoch
            # Remaining time = elapsed * (total batches / batches done - 1), in minutes.
            eta = elapsed * (times_per_epoch * n_epochs / batches_done - 1) / 60
            print('Epoch ', epoch + 1)
            print('[' + '#' * progress + ' ' * (100 - progress) + ']')
            print('\t Training loss (single batch):', batch_loss)
            print('\t Approximately %4.2f minutes left' % (eta))
            print('\t Time elapsed: %4.2f minutes' % (elapsed / 60))
            clear_output(wait=True)
    epoch_loss = np.mean(loss_placeholder)
    loss_log.append(epoch_loss)
    write_loss_on_file(loss_log, 'net/loss.txt')
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        # NOTE(review): assumes net.save() writes to ckpt_path ('net/ckpt/best.ckpt') - TODO confirm.
        net.save()
        saved_this_run = True
        print('Saved!')
    if epoch > 1:
        plt.plot(np.arange(1, epoch + 2), loss_log)
        plt.show()
    #generate_text_multi(net, seed, LWds, file=str(epoch)+'.txt', temperature=0.3, max_char=1000)
    