In [None]:
"""
Implementation of the TNC baseline based on the code available on
https://openreview.net/forum?id=8qDwejCuCN
"""

import os
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader,Dataset

import torch
import torch.nn as nn

# NOTE: this changes the process-wide working directory relative to the current
# one, so it is not idempotent -- re-running this cell moves one more level up.
os.chdir("../") #Load from parent directory
# Project-local helpers; they are imported after the chdir above.
from data_utils import Plots,gen_loader,load_datasets
from models import select_encoder

# Shared plotting helper; run_tnc uses it to plot the encoding distributions.
utils_plot=Plots()

In [None]:
class Discriminator(torch.nn.Module):
    """Small MLP that scores a pair of encodings.

    The two encodings are concatenated along the last dimension and passed
    through Linear -> ReLU -> Dropout(0.5) -> Linear, producing one value per
    pair. Both linear layers have their weights re-initialised with Xavier
    uniform initialisation.

    Args:
        input_size: Size of a single encoding (the first layer takes
            ``2 * input_size`` features).
        device: Stored on the instance as ``self.device``; not used by the
            module itself.
    """

    def __init__(self, input_size, device):
        super().__init__()
        self.device = device
        self.input_size = input_size

        hidden_size = 4 * self.input_size
        layers = [
            torch.nn.Linear(2 * self.input_size, hidden_size),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(hidden_size, 1),
        ]
        # Kept as `self.model` (a Sequential) so parameter names stay
        # `model.0.*` / `model.3.*`.
        self.model = torch.nn.Sequential(*layers)

        # Indices 0 and 3 are the two Linear layers.
        for linear_idx in (0, 3):
            torch.nn.init.xavier_uniform_(self.model[linear_idx].weight)

    def forward(self, x, x_tild):
        """
        Score how likely the two inputs are to belong to the same neighbourhood.

        Returns a flat 1-D tensor with one raw (un-squashed) score per pair; no
        sigmoid is applied here.
        """
        pair = torch.cat((x, x_tild), dim=-1)
        scores = self.model(pair)
        return scores.view((-1,))

In [None]:
class TNCDataset(Dataset):
    """Dataset of (reference window, neighbouring windows, distant windows) tuples.

    For every item a reference time step ``t`` is drawn at random from one
    series. ``mc_sample_size`` neighbouring windows are centred on time steps
    drawn from a Gaussian around ``t``; ``mc_sample_size`` non-neighbouring
    windows are centred on time steps drawn uniformly from the part of the
    series that is at least ``delta`` away from ``t`` (when such a part exists).

    Args:
        x: Tensor holding the series. Each ``x[i]`` is sliced as
            ``x[i][:, a:b]`` and ``x.shape[-1]`` is taken as the time length,
            i.e. each series is laid out as (features, time).
        mc_sample_size: Number of neighbouring / non-neighbouring windows per item.
        window_size: Nominal window length; windows actually span
            ``2 * (window_size // 2)`` time steps.
        augmentation: Multiplier applied to the dataset length, so every series
            is visited ``augmentation`` times per pass.
        epsilon: Neighbourhood width in units of ``window_size`` (std-dev of the
            Gaussian is ``epsilon * window_size``). Re-estimated per item when
            ``adf`` is true.
        state: Optional per-series label sequence; when given, the rounded mean
            label over the reference window is returned as ``y_t``.
        adf: If true, ``epsilon`` is estimated per item from ADF test p-values.
    """

    def __init__(self, x, mc_sample_size, window_size, augmentation, epsilon = 3, state=None, adf=False):
        super(TNCDataset, self).__init__()
        self.time_series = x
        self.T = x.shape[-1]
        self.window_size = window_size
        # sliding_gap / window_per_sample are not read anywhere in this class;
        # they are kept as attributes for any external reader.
        self.sliding_gap = int(window_size*25.2)
        self.window_per_sample = (self.T-2*self.window_size)//self.sliding_gap
        self.mc_sample_size = mc_sample_size
        self.state = state
        self.augmentation = augmentation
        self.adf = adf
        # Always initialise these so _find_non_neighours never hits a missing
        # attribute; with adf=True they are overwritten by _find_neighours on
        # every item.
        self.epsilon = epsilon
        self.delta = 5*window_size*epsilon

    def __len__(self):
        return len(self.time_series)*self.augmentation

    def __getitem__(self, ind):
        """Return ``(x_t, X_close, X_distant, y_t)`` for a random reference step.

        ``y_t`` is ``-1`` when no ``state`` was supplied.
        """
        ind = ind%len(self.time_series)
        half = self.window_size//2
        # Keep the reference at least two windows away from either end.
        t = np.random.randint(2*self.window_size, self.T-2*self.window_size)
        x_t = self.time_series[ind][:,t-half:t+half]
        # Order matters when adf=True: _find_neighours updates self.delta, which
        # _find_non_neighours reads.
        X_close = self._find_neighours(self.time_series[ind], t)
        X_distant = self._find_non_neighours(self.time_series[ind], t)

        if self.state is None:
            y_t = -1
        else:
            y_t = torch.round(torch.mean(self.state[ind][t-half:t+half]))
        return x_t, X_close, X_distant, y_t

    def _find_neighours(self, x, t):
        """Sample ``mc_sample_size`` windows of ``x`` centred near time step ``t``.

        With ``adf=True`` this first re-estimates ``self.epsilon`` (and
        ``self.delta``) from averaged ADF p-values over growing spans around ``t``.
        """
        T = self.time_series.shape[-1]
        half = self.window_size//2
        if self.adf:
            gap = self.window_size
            corr = []
            for w_t in range(self.window_size,4*self.window_size, gap):
                try:
                    p_val = 0
                    for f in range(x.shape[-2]):
                        segment = x[f, max(0,t - w_t):min(x.shape[-1], t + w_t)]
                        p = adfuller(np.array(segment.reshape(-1, )))[1]
                        p_val += 0.01 if np.isnan(p) else p
                    corr.append(p_val/x.shape[-2])
                except NameError:
                    # NOTE(review): `adfuller` is not imported in any visible
                    # cell (presumably statsmodels.tsa.stattools.adfuller --
                    # confirm). Without it every span uses the fallback below,
                    # so epsilon is always 1. Warn instead of hiding that.
                    import warnings
                    warnings.warn(
                        "adfuller is not defined; the ADF test is skipped and the "
                        "fallback p-value 0.6 is used",
                        RuntimeWarning,
                    )
                    corr.append(0.6)
                except Exception:
                    # Best effort: if the test fails on this span, use a
                    # fallback p-value.
                    corr.append(0.6)
            # epsilon = 1-based index of the first span whose mean p-value
            # reaches 0.01, or the number of spans if none does.
            above = np.where(np.array(corr) >= 0.01)[0]
            self.epsilon = len(corr) if len(above)==0 else (above[0] +1)
            self.delta = 5*self.epsilon*self.window_size

        ## Random from a Gaussian
        t_p = [int(t+np.random.randn()*self.epsilon*self.window_size) for _ in range(self.mc_sample_size)]
        # Clamp the centres so every window lies fully inside the series.
        t_p = [max(half+1,min(t_pp,T-half)) for t_pp in t_p]
        x_p = torch.stack([x[:, t_ind-half:t_ind+half] for t_ind in t_p])
        return x_p

    def _find_non_neighours(self, x, t):
        """Sample ``mc_sample_size`` windows of ``x`` centred far from ``t``.

        Centres are drawn uniformly from the side of the series opposite to
        ``t``. If ``t - delta`` would fall before the first valid centre, the
        centre is clamped to ``window_size // 2`` so the window keeps its full
        length; such a window may then be closer than ``delta`` to ``t``.
        """
        T = self.time_series.shape[-1]
        half = self.window_size//2
        if t>T/2:
            # Previously t - delta was used directly and could be smaller than
            # `half` (even negative), which produced empty / wrong-sized slices.
            anchor = max(t - self.delta, half)
            t_n = np.random.randint(min(half+1, anchor), anchor + 1, self.mc_sample_size)
        else:
            t_n = np.random.randint(min((t + self.delta), (T - self.window_size-1)), (T - half), self.mc_sample_size)
        x_n = torch.stack([x[:, t_ind-half:t_ind+half] for t_ind in t_n])
        return x_n

In [None]:
def epoch_run(loader, disc_model, encoder, device, w=0, optimizer=None, train=True):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_accuracy)``.

    Each batch yields a reference window ``x_t`` of shape
    (batch, features, length) plus ``mc`` neighbouring and ``mc``
    non-neighbouring windows per reference. All windows are encoded, the
    discriminator scores (reference, candidate) pairs, and the loss is

        (BCE(d_p, 1) + w * BCE(d_n, 1) + (1 - w) * BCE(d_n, 0)) / 2

    so ``w`` is the weight with which non-neighbours are treated as positives.

    Args:
        loader: Iterable of ``(x_t, x_p, x_n, y)`` batches; ``y`` is ignored.
        disc_model: Module called as ``disc_model(z_a, z_b)`` returning one
            logit per pair.
        encoder: Module mapping windows to encodings.
        device: Device the models and batches are moved to.
        w: Weight of the "non-neighbour is actually positive" loss term.
        optimizer: Optimizer stepping both models; required when ``train``.
        train: If true, run in training mode and update parameters; otherwise
            run in eval mode without building an autograd graph.

    Raises:
        ValueError: If ``train`` is true and no optimizer was given.
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    if train and optimizer is None:
        # Fail before any forward pass instead of at optimizer.zero_grad().
        raise ValueError("epoch_run(train=True) requires an optimizer")

    if train:
        encoder.train()
        disc_model.train()
    else:
        encoder.eval()
        disc_model.eval()
    # The discriminator outputs raw logits, hence the with-logits loss.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    encoder.to(device)
    disc_model.to(device)
    epoch_loss = 0
    epoch_acc = 0
    batch_count = 0
    for x_t, x_p, x_n, _ in loader:
        mc_sample = x_p.shape[1]
        batch_size, f_size, len_size = x_t.shape
        # Flatten (batch, mc, f, len) -> (batch * mc, f, len).
        x_p = x_p.reshape((-1, f_size, len_size))
        x_n = x_n.reshape((-1, f_size, len_size))
        # Repeat every reference mc times consecutively so row i of x_t lines
        # up with row i of x_p / x_n (same ordering as np.repeat(axis=0), but
        # without the NumPy round trip).
        x_t = x_t.repeat_interleave(mc_sample, dim=0)
        neighbors = torch.ones((len(x_p))).to(device)
        non_neighbors = torch.zeros((len(x_n))).to(device)
        x_t, x_p, x_n = x_t.to(device), x_p.to(device), x_n.to(device)

        # Skip autograd bookkeeping during evaluation.
        with torch.set_grad_enabled(bool(train)):
            z_t = encoder(x_t)
            z_p = encoder(x_p)
            z_n = encoder(x_n)

            d_p = disc_model(z_t, z_p)
            d_n = disc_model(z_t, z_n)

            p_loss = loss_fn(d_p, neighbors)
            n_loss = loss_fn(d_n, non_neighbors)
            n_loss_u = loss_fn(d_n, neighbors)
            loss = (p_loss + w*n_loss_u + (1-w)*n_loss)/2

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Accuracy: neighbours should score above 0.5, non-neighbours below.
        p_acc = torch.sum(torch.sigmoid(d_p) > 0.5).item() / len(z_p)
        n_acc = torch.sum(torch.sigmoid(d_n) < 0.5).item() / len(z_n)
        epoch_acc = epoch_acc + (p_acc+n_acc)/2
        epoch_loss += loss.item()
        batch_count += 1
    return epoch_loss/batch_count, epoch_acc/batch_count


In [None]:
def _seed_worker(worker_id):
    """Seed NumPy and `random` inside a DataLoader worker from the worker's torch seed.

    TNCDataset draws its windows with the global NumPy RNG. Without this,
    worker processes can start from the same NumPy state and produce the same
    draws as each other.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def learn_encoder(n_cross_val,data_type,tr_percentage,n_epochs,suffix,show_encodings,verbose,device,device_ids,batch_size,
                  window_size,encoder_type,encoding_size,lr,decay,datasets,w,mc_sample_size,augmentation):
    """Train a TNC encoder/discriminator pair for every cross-validation fold.

    For each fold the training data is shuffled, split into train/validation
    parts by ``tr_percentage``, and trained for ``n_epochs`` epochs. The
    checkpoint with the lowest validation loss is written to
    ``./results/baselines/<datasets>_tnc/<data_type>/`` together with loss and
    accuracy curves. With ``verbose`` a per-epoch log, the figures and a final
    mean/std summary over folds are shown.

    Args:
        n_cross_val: Number of folds to run.
        data_type, datasets: Passed to ``load_datasets`` and used in the output path.
        tr_percentage: Fraction of the fold's training data used for training;
            the rest is used for validation.
        n_epochs: Epochs per fold.
        suffix: Appended to checkpoint / figure file names.
        show_encodings, device_ids: Accepted so the function can be called with
            the full argument dict; not used here.
        verbose: Print progress and show figures.
        device: Device used for training.
        batch_size: DataLoader batch size.
        window_size, mc_sample_size, augmentation: Forwarded to ``TNCDataset``.
        encoder_type, encoding_size: Forwarded to ``select_encoder``.
        lr, decay: Adam learning rate and weight decay.
        w: Loss weight forwarded to ``epoch_run``.

    Returns:
        None.
    """

    accuracies, losses = [], []

    for cv in range(n_cross_val):
        train_data,train_labels,test_data,test_labels = load_datasets(data_type,datasets,cv)
        #Save Location
        save_dir = './results/baselines/%s_tnc/%s/'%(datasets,data_type)
        os.makedirs(save_dir, exist_ok=True)

        save_file = str((save_dir +'encoding_%d_encoder_%d_checkpoint_%d%s.pth.tar')
               %(encoding_size,encoder_type, cv,suffix))

        input_size = train_data.shape[1]
        encoder,_ = select_encoder(device,encoder_type,input_size,encoding_size)
        encoder = encoder.to(device)

        disc_model = Discriminator(encoder.encoding_size, device)
        params = list(disc_model.parameters()) + list(encoder.parameters())
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)
        # NOTE(review): step_size == n_epochs, and the scheduler is stepped once
        # per epoch, so the learning rate stays at `lr` for the whole run.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, n_epochs, gamma=0.999)
        # Shuffle before splitting so the validation part is a random subset.
        inds = list(range(len(train_data)))
        random.shuffle(inds)
        train_data = train_data[inds]
        n_train = int(tr_percentage*len(train_data))
        performance = []
        best_acc = 0
        best_loss = np.inf

        # The datasets do not change between epochs (windows are drawn at random
        # on every __getitem__), so build them and their loaders once per fold
        # instead of re-copying the data every epoch.
        trainset = TNCDataset(torch.Tensor(train_data[:n_train]), mc_sample_size=mc_sample_size,
                              window_size=window_size, augmentation=augmentation, adf=True)
        train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=3,
                                  worker_init_fn=_seed_worker)

        validset = TNCDataset(torch.Tensor(train_data[n_train:]), mc_sample_size=mc_sample_size, adf=True,
                              window_size=window_size, augmentation=augmentation)
        valid_loader = DataLoader(validset, batch_size=batch_size, shuffle=True)

        for epoch in range(n_epochs):

            epoch_loss, epoch_acc = epoch_run(train_loader, disc_model, encoder, optimizer=optimizer,
                                              w=w, train=True, device=device)

            val_loss, val_acc = epoch_run(valid_loader, disc_model, encoder, train=False, w=w, device=device)
            performance.append((epoch_loss, val_loss, epoch_acc, val_acc))
            scheduler.step()

            if verbose:
                print('\nEpoch ', epoch)
                print('Train ===> Loss: ', epoch_loss)
                print('Validation ===> Loss: ', val_loss)

            # Checkpoint on validation-loss improvement.
            if best_loss > val_loss:
                best_acc = val_acc
                best_loss = val_loss
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict(),
                    'discriminator_state_dict': disc_model.state_dict(),
                    'best_accuracy': val_acc
                }
                torch.save(state, save_file)
                if verbose:
                    print('Saving ckpt')

        accuracies.append(best_acc)
        losses.append(best_loss)

        # Save performance plots

        train_loss = [t[0] for t in performance]
        val_loss = [t[1] for t in performance]
        train_acc = [t[2] for t in performance]
        val_acc = [t[3] for t in performance]

        plt.figure()
        plt.plot(np.arange(n_epochs), train_loss, label="Train")
        plt.plot(np.arange(n_epochs), val_loss, label="Validation")
        plt.title("TNC Unsupervised Loss")
        plt.legend()
        plt.savefig(save_dir +'encoding_%d_encoder_%d_checkpoint_%d%s_loss.png'%(encoding_size,encoder_type, cv,suffix))
        if verbose:
            plt.show()
        plt.close()

        plt.figure()
        plt.plot(np.arange(n_epochs), train_acc, label="Train")
        plt.plot(np.arange(n_epochs), val_acc, label="Validation")
        plt.title("TNC Unsupervised Accuracy")
        plt.legend()
        plt.savefig(save_dir +'encoding_%d_encoder_%d_checkpoint_%d%s_acc.png'%(encoding_size,encoder_type, cv,suffix))

        if verbose:
            plt.show()
        plt.close()

    if verbose:
        print('=======> Performance Summary:')
        print('Accuracy: %.2f +- %.2f'%(100*np.mean(accuracies), 100*np.std(accuracies)))
        print('Loss: %.4f +- %.4f'%(np.mean(losses), np.std(losses)))

    return

In [None]:
def run_tnc(args):
    """Train the TNC encoder and, if requested, plot the test-set encodings.

    Args:
        args: Dict of settings. It is unpacked into ``learn_encoder``; the keys
            ``data_type``, ``show_encodings``, ``n_cross_val``, ``datasets``,
            ``encoder_type``, ``encoding_size``, ``window_size``, ``suffix``
            and ``device`` are also read here for plotting.
    """
    # Training stage: one encoder per cross-validation fold.
    learn_encoder(**args)

    # Plotting stage (optional).
    title = 'TNC Encoding TSNE for %s'%(args['data_type'])
    if not args['show_encodings']:
        return

    for fold in range(args['n_cross_val']):
        # Only the test split of each fold is plotted.
        _, _, test_data, test_labels = load_datasets(args['data_type'], args['datasets'], fold)
        utils_plot.plot_distribution(
            test_data, test_labels,
            args['encoder_type'], args['encoding_size'], args['window_size'],
            'tnc',
            args['datasets'], args['data_type'], args['suffix'],
            args['device'], title, fold,
        )
    return

In [None]:
def main(args):
    """Fill in device, per-dataset and experiment settings, then run TNC.

    ``args`` is updated in place: the device entries are added, the settings
    for the chosen ``data_type`` (if it is one of afdb / ims / urban) are
    applied, and the fixed experiment parameters are set before ``run_tnc``
    is called.
    """
    # Devices
    device_name = "cuda" if torch.cuda.is_available() else 'cpu'
    args['device'] = torch.device(device_name)
    args['device_ids'] = list(range(torch.cuda.device_count()))
    print('Using', args['device'])

    # Per-dataset settings; an unlisted data_type leaves args untouched here.
    dataset_presets = {
        'afdb': {'lr': 1e-3, 'w': .05},
        'ims': {'lr': 1e-4, 'w': .1},
        'urban': {'n_cross_val': 10, 'lr': 1e-4, 'w': .05},
    }
    args.update(dataset_presets.get(args['data_type'], {}))

    # Experiment parameters shared by all datasets.
    args.update({
        'window_size': 2500,
        'encoder_type': 1,
        'encoding_size': 128,
        'decay': 1e-5,
        'datasets': args['data_type'],
        'mc_sample_size': 10,
        'augmentation': 7,
    })
    run_tnc(args)

    return

In [None]:
# Run configuration; main() adds the device, per-dataset and experiment settings.
args = dict(
    n_cross_val=5,
    data_type='afdb',  # options: afdb, ims, urban
    tr_percentage=0.8,
    n_epochs=100,
    batch_size=8,
    suffix='',
    show_encodings=False,
    verbose=True,
)

main(args)