In [1]:
import torch
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random
import time
from torch.utils.data import Dataset, DataLoader
from dataset import *
from model import *

In [15]:
class Args:
    """Run configuration for training the traffic sampler.

    Plain attribute bag; fields are grouped as system/logging, model
    architecture, forecasting (query optimisation), sampler training and
    predictor training.
    """

    def __init__(self):
        # --- system / logging ---------------------------------------------
        self.gpu = 3                                # CUDA device index
        self.run_label = 0                          # distinguishes repeated runs with identical settings
        self.verbose = False
        self.log_root = '/data/hdim-forecast/log5/' # parent of the per-run log directories

        # --- model architecture (shared by all modes) ---------------------
        self.x_dim = 2                  # input features per step
        self.y_dim = 131                # predicted values per step
        self.n_past = 32                # conditioning steps taken from the past
        self.n_future = 10              # steps to predict into the future
        self.feat_size = 131            # number of basis features
        self.lstm_hidden_dim = 256
        self.lstm_layers = 2
        self.lstm_dropout = 0.5
        self.prediction_model = 'lstm'

        # --- forecasting only ---------------------------------------------
        self.q_learning_rate = 1e-3     # step size when optimising queries
        self.q_batch_size = 8           # batch size each query is evaluated on
        self.n_sample = 1024            # number of simulated samples

        # --- sampler training only ----------------------------------------
        self.s_test_len = 256           # length of the test sequence that gets visualised
        self.s_learning_rate = 1e-4
        self.s_batch_size = 8

        # --- predictor training only --------------------------------------
        self.p_learning_rate = 1e-4
        
args = Args()
# Target device for the model and every batch; kept both on args and as a global.
args.device = torch.device('cuda:%d' % args.gpu)
device = args.device

In [16]:
# Choose the first unused run directory. The run name encodes the key
# hyper-parameters plus args.run_label; the label is bumped until creating the
# directory succeeds. Creating it inside try/except is atomic, so two runs
# launched at the same time cannot claim the same directory.
while True:
    args.name = '%s-%d-%d-%d-%d-%d-%d-%d-%.1f-%.4f-%d' % (
        args.prediction_model, args.x_dim, args.y_dim, args.n_past, args.n_future, args.feat_size,
        args.lstm_layers, args.lstm_hidden_dim, args.lstm_dropout, args.s_learning_rate, args.run_label)
    args.log_dir = os.path.join(args.log_root, args.name)
    try:
        os.makedirs(args.log_dir)
        break
    except FileExistsError:
        args.run_label += 1
print("Run number = %d" % args.run_label)
writer = SummaryWriter(args.log_dir)
log_writer = open(os.path.join(args.log_dir, 'results.txt'), 'w')

start_time = time.time()
global_iteration = 0
# Seed every RNG in use (python, numpy, torch) with the run label, so each run
# is reproducible yet different runs draw different randomness.
random.seed(args.run_label)
np.random.seed(args.run_label)
torch.manual_seed(args.run_label)
    
def log_scalar(name, value, epoch):
    """Record a scalar to TensorBoard and append it to the results text file.

    Args:
        name: TensorBoard tag.
        value: number to record.
        epoch: global step passed to TensorBoard.

    The text file receives only the value, formatted with '%f' and followed by
    a single space (no tag, no newline).
    """
    writer.add_scalar(name, value, epoch)
    formatted = '%f ' % value
    log_writer.write(formatted)
    
    
def message(epoch):
    """Print a progress line: the epoch number and seconds elapsed since start_time."""
    elapsed = time.time() - start_time
    print("Finished epoch %d, time elapsed %.1f" % (epoch, elapsed))

Run number = 4


In [22]:
# Training windows span the n_past conditioning steps plus the n_future target steps.
train_dataset = TrafficDataset(train=True, max_len=args.n_past + args.n_future)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.s_batch_size,
    shuffle=True,
    num_workers=4,
)

# Test sequences are s_test_len long so that sampled continuations can be plotted.
test_dataset = TrafficDataset(train=False, max_len=args.s_test_len)
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [23]:
        # Time features


# def next_time(week, num_steps):
#     start_time = week.item() * 3600 * 24 * 7
#     time_new = torch.linspace(start_time, start_time + 5*60*(num_steps-1), num_steps, device=week.device)
#     hour_of_day = (time_new % (3600 * 24)) / (3600. * 24)
#     week_day = (time_new % (3600 * 24 * 7)) / (3600 * 24. * 7)
#     return torch.stack([hour_of_day, week_day], dim=-1)

# print(next_time(bx[0, 0, 1], 500).shape)
# print(bx.shape)
# bx_expanded = []
# for i in range(args.batch_size):
#     bx_expanded.append(torch.cat([bx[i, :32], next_time(bx[i, -1, 1], 500)]))
# bx_expanded = torch.stack(bx_expanded)
    
# print(bx_expanded.shape)

In [24]:

def loss_fn_nll(mu: torch.Tensor, sigma: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    '''
    Gaussian negative log-likelihood, averaged over the observed label entries.

    All inputs are flattened, so any matching shape works (documented use is
    [batch_size, y_dim] for a single time step). Entries whose label is <= 0
    are treated as missing and excluded from the average.

    Args:
        mu: predicted means; same number of elements as labels.
        sigma: predicted standard deviations; must be > 0 on observed entries.
        labels: observed targets; values <= 0 mark missing observations.
    Returns:
        Scalar tensor equal to -mean(log N(label | mu, sigma)) over observed
        entries, i.e. a loss to be minimised. NaN when no entry is observed;
        callers can detect this with torch.isnan and skip the update.
    '''
    mu, sigma, labels = mu.flatten(), sigma.flatten(), labels.flatten()
    valid_index = labels > 0
    if not bool(valid_index.any()):
        # Nothing observed: return NaN explicitly instead of averaging an empty tensor.
        return mu.new_tensor(float('nan'))
    distribution = torch.distributions.normal.Normal(mu[valid_index], sigma[valid_index])
    likelihood = distribution.log_prob(labels[valid_index])
    return -torch.mean(likelihood)

In [25]:
# Sampler network on the training device, optimised with Adam at s_learning_rate.
net = LSTMSampler(args)
net = net.to(device)
optimizer = optim.Adam(params=net.parameters(), lr=args.s_learning_rate)
# No learning-rate scheduler is used in this run.

In [None]:

# Sampler training. Each batch is fed step by step through the first
# args.n_past time steps; the Gaussian predicted after the last conditioning
# step is scored against the label at index args.n_past.
# Once per epoch a test batch is sampled and plotted; every 10 epochs the
# weights are checkpointed.
os.makedirs('pretrained', exist_ok=True)  # torch.save below does not create directories

for epoch in range(1000):
    for data in train_loader:
        bx, by = data[0].to(device), data[1].to(device)  # dim 0 = batch, dim 1 = time
        optimizer.zero_grad()

        hidden, cell = net.init_hidden(bx.shape[0]), net.init_cell(bx.shape[0])
        for t in range(args.n_past):
            # unsqueeze(0) adds a leading length-1 axis (presumably the time axis,
            # matching the time-first test inputs below — TODO confirm)
            mu, sigma, hidden, cell = net(bx[:, t].unsqueeze(0), by[:, t].unsqueeze(0), hidden, cell)

        # NOTE(review): only the prediction made after the final conditioning step
        # contributes to the loss; mu/sigma from earlier steps are discarded.
        loss = loss_fn_nll(mu=mu, sigma=sigma, labels=by[:, args.n_past])
        if torch.isnan(loss):
            # NaN loss (e.g. no observed labels in this batch): skip the update.
            continue

        writer.add_scalar('loss', loss.item(), global_iteration)
        global_iteration += 1
        loss.backward()
        optimizer.step()

    # Visualise sampled continuations for one test batch.
    # NOTE(review): net is still in training mode here, so any dropout inside
    # LSTMSampler (args.lstm_dropout) stays active while sampling — consider
    # wrapping this in net.eval() / net.train().
    with torch.no_grad():
        # next(iter(...)) rather than iter(...).next(): DataLoader iterators in
        # current PyTorch have no .next() method. Each iter() call starts a fresh
        # iterator (and its worker processes) just to fetch this one batch.
        tx, ty = next(iter(test_loader))
        # Swap dims 0 and 1 so that time leads: (time, batch, feature).
        tx, ty = tx.to(device).permute(1, 0, 2), ty.to(device).permute(1, 0, 2)
        samples = net.sample(tx, ty[:args.n_past], args.s_test_len - args.n_past)
        # Prepend the observed history so the curve covers all s_test_len steps.
        joined_samples = torch.cat([ty[:args.n_past], samples], dim=0)

        fig = plt.figure(figsize=(20, 10))
        for plot_id in range(16):  # first 16 output features of batch element 0
            plt.subplot(8, 2, plot_id + 1)
            plt.plot(range(args.s_test_len), ty[:, 0, plot_id].cpu(), c='g', label='true')
            plt.plot(range(args.s_test_len), joined_samples[:, 0, plot_id].cpu(), c='r', label='pred')
            plt.ylim([0, ty[:, 0, plot_id].max().item() * 1.5])
        plt.legend()
        plt.tight_layout()
        # One file per 10-epoch window; later epochs in the window overwrite it.
        plt.savefig(os.path.join(args.log_dir, 'result_%d.png' % (epoch // 10)))
        plt.close(fig)

    if epoch % 10 == 0:
        # Save CPU-resident weights (legacy, non-zip serialisation), then move back.
        torch.save(net.cpu().state_dict(), 'pretrained/sampler_traffic_%s.pt' % args.name, _use_new_zipfile_serialization=False)
        net.to(device)
        message(epoch)

Finished epoch 0, time elapsed 530.2
Finished epoch 10, time elapsed 5405.2
Finished epoch 20, time elapsed 10247.6
Finished epoch 30, time elapsed 15095.5
Finished epoch 40, time elapsed 19974.8
Finished epoch 50, time elapsed 24880.1
Finished epoch 60, time elapsed 29782.7
Finished epoch 70, time elapsed 34679.3
Finished epoch 80, time elapsed 39573.9
Finished epoch 100, time elapsed 49407.0
Finished epoch 110, time elapsed 54290.6
Finished epoch 120, time elapsed 59217.5
Finished epoch 130, time elapsed 64121.8
Finished epoch 140, time elapsed 68946.6
Finished epoch 150, time elapsed 73666.0
Finished epoch 160, time elapsed 78409.7
Finished epoch 170, time elapsed 83046.7
Finished epoch 180, time elapsed 87798.5
Finished epoch 190, time elapsed 92535.3
Finished epoch 200, time elapsed 97257.7
Finished epoch 210, time elapsed 101994.8
Finished epoch 220, time elapsed 106817.1
Finished epoch 230, time elapsed 111544.3
Finished epoch 240, time elapsed 116312.1
Finished epoch 250, time 

Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/util.py", line 262, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.6/multiprocessing/util.py", line 186, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/shutil.py", line 490, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/usr/lib/python3.6/shutil.py", line 488, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-skv4n7ls'


Finished epoch 280, time elapsed 135317.3
Finished epoch 290, time elapsed 140017.5
Finished epoch 300, time elapsed 144773.5
Finished epoch 310, time elapsed 149523.0
Finished epoch 320, time elapsed 154267.1
