In this notebook I'll reconstruct the process and results of the DTCR algorithm presented in the "Learning Representations for Time Series Clustering".

In [1]:
# Imports
import torch
from torch.utils.data import DataLoader
from Utilities.DTCR import DTCRModel, DTCRConfig
from Utilities.UCRParser import read_dataset
from Utilities.DRNN import BidirectionalDRNN

In [2]:
train_ds, test_ds = read_dataset("Two_Patterns")
train_dl = DataLoader(train_ds, batch_size=2, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=2, shuffle=True)

config = DTCRConfig()
config.class_num = train_ds.number_of_labels
config.input_size = train_ds[0][0].shape[1]
config.num_steps = train_ds[0][0].shape[0]
config.batch_size = 2


Loading the Two_Patterns dataset...
The dataset Two_Patterns was loaded.


In [3]:
dtcr_model = DTCRModel(config)


In [4]:
encoder = dtcr_model.encoder
encoder


BidirectionalDRNN(
  (_regular_drnn): DRNN(
    (cells): Sequential(
      (0): GRU(1, 100)
      (1): GRU(100, 50)
      (2): GRU(50, 50)
    )
  )
  (_backwards_drnn): DRNN(
    (cells): Sequential(
      (0): GRU(1, 100)
      (1): GRU(100, 50)
      (2): GRU(50, 50)
    )
  )
)

In [5]:
decoder = dtcr_model.decoder
decoder

DTCRDecoder(
  (_rnn): GRU(400, 400, batch_first=True)
  (_linear): Linear(in_features=400, out_features=1, bias=True)
)

In [6]:
loss = torch.nn.MSELoss()
opt = torch.optim.Adam(list(decoder.parameters()) + list(encoder.parameters()))#, eps=5e-3)
print_interval = 5
for epoch in range(2):
    running_loss = 0.0
    for index, (sample_data, sample_label) in enumerate(train_dl):
        opt.zero_grad()

        _, hidden_outputs = encoder(sample_data)
        latent_representation = dtcr_model.get_latent_representation(hidden_outputs)

        #repr_for_reconstruction = dtcr_model.prepare_representation_for_reconstruction(
        #    latent_representation)
        #reconstructed_inputs, _ = decoder(repr_for_reconstruction)
        reconstructed_inputs, _ = decoder(latent_representation)
        # reconstructed_inputs: (N, seq_len, hidden_size), we need just the last of each
        # hidden_size
        output = loss(reconstructed_inputs, sample_data)
        output.backward()
        opt.step()

        running_loss += output.item()
        if index % print_interval == print_interval - 1:  # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                (epoch + 1, index + 1, running_loss / print_interval))
            running_loss = 0.0


  return F.mse_loss(input, target, reduction=self.reduction)


[1,     5] loss: 1.057
[1,    10] loss: 1.010
[1,    15] loss: 1.001
[1,    20] loss: 0.994
[1,    25] loss: 0.995
[1,    30] loss: 0.993


KeyboardInterrupt: 

In [None]:
reconstructed_inputs[0]

tensor([[-0.0003],
        [-0.0006],
        [-0.0009],
        [-0.0012],
        [-0.0015],
        [-0.0018],
        [-0.0021],
        [-0.0024],
        [-0.0027],
        [-0.0030],
        [-0.0033],
        [-0.0036],
        [-0.0039],
        [-0.0042],
        [-0.0045],
        [-0.0048],
        [-0.0051],
        [-0.0053],
        [-0.0056],
        [-0.0059],
        [-0.0062],
        [-0.0065],
        [-0.0068],
        [-0.0071],
        [-0.0074],
        [-0.0077],
        [-0.0080],
        [-0.0082],
        [-0.0085],
        [-0.0088],
        [-0.0091],
        [-0.0094],
        [-0.0097],
        [-0.0100],
        [-0.0102],
        [-0.0105],
        [-0.0108],
        [-0.0111],
        [-0.0114],
        [-0.0117],
        [-0.0119],
        [-0.0122],
        [-0.0125],
        [-0.0128],
        [-0.0131],
        [-0.0133],
        [-0.0136],
        [-0.0139],
        [-0.0142],
        [-0.0145],
        [-0.0147],
        [-0.0150],
        [-0.

In [None]:
sample_data[0]

tensor([[-0.1375],
        [-0.1239],
        [ 0.4653],
        [ 0.0880],
        [ 0.2117],
        [ 0.5962],
        [ 0.0431],
        [ 0.1223],
        [ 0.3642],
        [-0.5954],
        [ 0.0387],
        [ 0.2478],
        [-0.2279],
        [ 0.3670],
        [ 0.1559],
        [ 0.1879],
        [ 0.6380],
        [-0.0640],
        [-0.3787],
        [-0.2925],
        [ 0.1132],
        [ 0.1191],
        [ 0.2278],
        [ 0.3051],
        [ 0.3733],
        [-0.2286],
        [ 0.0054],
        [ 0.0323],
        [-0.1088],
        [-0.0878],
        [-0.4618],
        [-0.7201],
        [ 0.1494],
        [ 0.3238],
        [ 0.1462],
        [-0.0580],
        [ 0.4726],
        [-0.1032],
        [-0.3556],
        [ 0.0334],
        [-0.6799],
        [ 0.3069],
        [ 0.3108],
        [ 0.1048],
        [-0.2105],
        [ 0.4394],
        [-0.6251],
        [ 0.2421],
        [ 0.3760],
        [-0.7003],
        [-0.1711],
        [ 0.0826],
        [-0.