In [None]:
"""
Implementation of the CPC baseline based on the code available on
https://openreview.net/forum?id=8qDwejCuCN
"""

import os
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

# Switch the working directory to the parent folder before importing the
# project-local modules below (presumably that is where data_utils.py and
# models.py live -- TODO confirm). Relative paths used later in this notebook
# (e.g. './results/...') are resolved against this directory as well.
# NOTE: the path is relative, so re-running this cell moves up one more level;
# restart the kernel before re-running it.
os.chdir("../") #Load from parent directory
from data_utils import Plots,gen_loader,load_datasets
from models import select_encoder
# Shared plotting helper; run_cpc() calls utils_plot.plot_distribution on it.
utils_plot=Plots()

In [None]:
def epoch_run(data, encoder, ds_estimator, auto_regressor, device, window_size, n_size, optimizer, train=True):
    """Run one CPC (contrastive predictive coding) pass over ``data``.

    For every sample a random crop is cut into consecutive windows, each window
    is encoded, a context vector is computed by the auto-regressor from up to
    the 11 encodings ending at a randomly chosen window, and the model is asked
    to pick the encoding of the *next* window out of ``n_size`` randomly drawn
    negative windows (InfoNCE-style cross-entropy).

    Args:
        data: Indexable collection of samples supporting ``len()`` and
            iteration. Each sample is sliced as ``sample[:, a:b]`` and its last
            axis is treated as time (assumes samples are (channels, time)
            arrays -- TODO confirm against ``load_datasets``).
        encoder: Module mapping a stack of windows to one encoding per window.
        ds_estimator: Module applied to the context vector to produce the
            prediction that is dotted with every window encoding.
        auto_regressor: Recurrent module called on a ``(1, steps, encoding)``
            tensor; its second return value is used as the context.
        device: Device the windowed sample is moved to.
        window_size: Window length in time steps. ``-1`` selects roughly a
            tenth of ``data.shape[-1]`` (at least 2).
        n_size: Number of negative windows drawn (with replacement) per sample.
        optimizer: Optimizer stepped once per sample when ``train`` is true;
            unused otherwise.
        train: If true, modules are put in training mode and updated. If
            false, modules are put in eval mode and the forward pass runs
            without gradient tracking.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over ``len(data)`` samples.

    Raises:
        ValueError: If ``data`` is empty (there is nothing to average over).
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("epoch_run received an empty dataset; check the train/validation split")

    if window_size == -1:
        window_size = max(2, int(data.shape[-1] / 10))
        # Keep the window count from dividing the series exactly by ten.
        if window_size * 10 == data.shape[-1]:
            window_size = window_size - 1

    for module in (encoder, ds_estimator, auto_regressor):
        if train:
            module.train()
        else:
            module.eval()

    criterion = torch.nn.CrossEntropyLoss()
    epoch_loss = 0.0
    n_correct = 0

    for sample in data:
        # Random crop of at most 40 windows centred on rnd_t.
        series_len = sample.shape[-1]
        rnd_t = np.random.randint(5 * window_size, series_len - 5 * window_size)
        lo = max(0, rnd_t - 20 * window_size)
        hi = min(series_len, rnd_t + 20 * window_size)
        sample = torch.Tensor(sample[:, lo:hi])

        # Cut the crop into non-overlapping windows (dropping the remainder)
        # and stack them along a new leading axis.
        n_windows = sample.shape[-1] // window_size
        windows = torch.split(sample[:, :n_windows * window_size], window_size, dim=-1)
        windowed_sample = torch.stack(windows, 0).to(device)

        # Only track gradients when they are going to be used; validation
        # passes would otherwise build an autograd graph for nothing.
        with torch.set_grad_enabled(train):
            encodings = encoder(windowed_sample)

            # Anchor window; the positive target is the window right after it.
            window_ind = int(torch.randint(2, len(encodings) - 2, size=(1,)))

            context = encodings[max(0, window_ind - 10):window_ind + 1].unsqueeze(0)
            _, c_t = auto_regressor(context)

            prediction = ds_estimator(c_t.squeeze(1).squeeze(0))
            density_ratios = torch.bmm(encodings.unsqueeze(1),
                                       prediction.expand_as(encodings).unsqueeze(-1)).view(-1,)

            # Negatives come from windows at least 3 steps away from the anchor.
            candidates = np.concatenate([np.arange(0, window_ind - 2),
                                         np.arange(window_ind + 3, len(encodings))])
            rnd_n = np.random.choice(candidates, n_size)
            neg_idx = torch.as_tensor(rnd_n, dtype=torch.long, device=density_ratios.device)

            # The positive score is appended last, so the target class is the final index.
            X_N = torch.cat([density_ratios[neg_idx], density_ratios[window_ind + 1].unsqueeze(0)], 0)
            target = len(X_N) - 1

            if int(torch.argmax(X_N)) == target:
                n_correct += 1
            labels = torch.tensor([target], dtype=torch.long, device=X_N.device)
            loss = criterion(X_N.view(1, -1), labels)
            epoch_loss += loss.item()

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    return epoch_loss / n_samples, n_correct / n_samples

In [None]:
def learn_encoder(n_cross_val,data_type,datasets,lr,window_size,n_size,tr_percentage,
                  encoder_type,encoding_size,decay,n_epochs,suffix,device,device_ids,verbose,show_encodings):
    """Train a CPC encoder on every cross-validation fold and checkpoint the best epoch.

    For each fold the training split is shuffled, divided into train/validation
    parts, and the encoder, density estimator and GRU auto-regressor are
    optimised jointly with Adam. Whenever the validation loss improves, the
    encoder weights are written to a checkpoint file; a loss-curve figure is
    saved next to it at the end of the fold.

    Args:
        n_cross_val: Number of folds; fold indices ``0..n_cross_val-1`` are
            passed to ``load_datasets``.
        data_type: Dataset type forwarded to ``load_datasets`` and used in the
            results path.
        datasets: Dataset name forwarded to ``load_datasets`` and used in the
            results path.
        lr: Adam learning rate.
        window_size: Window length forwarded to ``epoch_run``.
        n_size: Number of negative samples forwarded to ``epoch_run``.
        tr_percentage: Fraction of the (shuffled) training split used for
            training; the remainder is used for validation.
        encoder_type: Integer encoder selector forwarded to ``select_encoder``
            (also formatted with ``%d`` into the file names).
        encoding_size: Encoding dimensionality.
        decay: Adam weight decay.
        n_epochs: Number of training epochs per fold.
        suffix: String appended to checkpoint/figure file names.
        device: Device the models are placed on.
        device_ids: Accepted for interface compatibility; not used here.
        verbose: If true, print progress and show the loss figure.
        show_encodings: Accepted for interface compatibility; not used here.

    Returns:
        None. Checkpoints and figures are written under
        ``./results/baselines/<datasets>_cpc/<data_type>/``.
    """
    accuracies = []
    for cv in range(n_cross_val):
        # Load with this function's own arguments rather than a notebook-global
        # `args` dict, so the function works for any caller.
        train_data, train_labels, test_data, test_labels = load_datasets(data_type, datasets, cv)

        #Save Location
        save_dir = './results/baselines/%s_cpc/%s/'%(datasets,data_type)
        os.makedirs(save_dir, exist_ok=True)

        file_stem = 'encoding_%d_encoder_%d_checkpoint_%d%s'%(encoding_size,encoder_type, cv,suffix)
        save_file = save_dir + file_stem + '.pth.tar'

        if verbose:
            print('Saving at: ',save_file)

        #Models
        input_size = train_data.shape[1]
        encoder,_ = select_encoder(device,encoder_type,input_size,encoding_size)
        ds_estimator = torch.nn.Linear(encoder.encoding_size, encoder.encoding_size).to(device)
        auto_regressor = torch.nn.GRU(input_size=encoding_size, hidden_size=encoding_size, batch_first=True).to(device)

        #Training init
        params = list(ds_estimator.parameters()) + list(encoder.parameters()) + list(auto_regressor.parameters())
        optimizer = torch.optim.Adam(params, lr=lr, weight_decay=decay)
        # NOTE(review): step_size equals n_epochs, so the learning rate only
        # changes on the very last scheduler.step() -- confirm this is intended.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, n_epochs, gamma=0.999)

        #Split/Shuffle train and val
        inds = list(range(len(train_data)))
        random.shuffle(inds)
        train_data = train_data[inds]
        n_train = int(tr_percentage*len(train_data))
        best_acc = 0
        best_loss = np.inf
        train_loss, val_loss = [], []

        #Train
        for epoch in range(n_epochs):
            epoch_loss, acc = epoch_run(train_data[:n_train], encoder, ds_estimator,
                                        auto_regressor, device, window_size, n_size, optimizer, train=True)
            epoch_loss_val, acc_val = epoch_run(train_data[n_train:], encoder, ds_estimator,
                                                auto_regressor, device, window_size, n_size, optimizer, train=False)
            scheduler.step()

            if verbose:
                print('\nEpoch ', epoch)
                print('Train ===> Loss: ', epoch_loss)
                print('Validation ===> Loss: ', epoch_loss_val)

            train_loss.append(epoch_loss)
            val_loss.append(epoch_loss_val)

            # Keep only the encoder weights of the best-validation epoch.
            if epoch_loss_val<best_loss:
                state = {
                    'epoch': epoch,
                    'encoder_state_dict': encoder.state_dict()
                }
                best_acc = acc_val
                best_loss = epoch_loss_val
                torch.save(state, save_file)
                if verbose:
                    print('Saving ckpt')

        accuracies.append(best_acc)

        #Loss curves for this fold
        plt.figure()
        plt.plot(np.arange(n_epochs), train_loss, label="Train")
        plt.plot(np.arange(n_epochs), val_loss, label="Validation")
        plt.title("CPC Unsupervised Loss")
        plt.legend()
        plt.savefig(save_dir + file_stem + '.png')
        if verbose:
            plt.show()
        plt.close()

        if verbose:
            print('Best Train ===> Loss: ', np.min(train_loss))
            print('Best Validation ===> Loss: ', np.min(val_loss))
            # Running mean/std over the folds completed so far.
            print('-----Accuracy: %.2f +- %.2f-----' % (100 * np.mean(accuracies), 100 * np.std(accuracies)))
    return

In [None]:
def run_cpc(args):
    """Train the CPC encoder and optionally plot the learned encodings.

    Args:
        args: Dict of experiment settings. It is unpacked as keyword arguments
            into ``learn_encoder``, so its keys must match that function's
            parameters. When ``args['show_encodings']`` is true, the test split
            of every fold is loaded and passed to
            ``utils_plot.plot_distribution``.

    Returns:
        None.
    """
    #Run Process
    learn_encoder(**args)

    if not args['show_encodings']:
        return

    #Plot Features
    title = 'CPC Encoding TSNE for %s'%(args['data_type'])
    for cv in range(args['n_cross_val']):
        # Only the test split is plotted; the train split is discarded.
        _, _, test_data, test_labels = load_datasets(args['data_type'],args['datasets'],cv)
        utils_plot.plot_distribution(test_data, test_labels,args['encoder_type'],
                                     args['encoding_size'],args['window_size'],'cpc',
                                     args['datasets'],args['data_type'],args['suffix'],
                                     args['device'], title, cv,augment=100)
    return

In [None]:
def main(args):
    """Fill in device and experiment hyper-parameters, then run the CPC baseline.

    Args:
        args: Dict of base settings. It is updated in place with the keys
            ``device``, ``device_ids``, ``window_size``, ``encoder_type``,
            ``encoding_size``, ``lr``, ``decay``, ``datasets`` and ``n_size``
            (overwriting any existing values) and then passed to ``run_cpc``.
            Must already contain ``data_type``.

    Returns:
        None.
    """
    #Devices
    # Use the first GPU when one is available, otherwise fall back to the CPU
    # (previously the device was unconditionally forced to "cuda:0").
    args['device'] = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    args['device_ids'] = list(range(torch.cuda.device_count()))
    print('Using', args['device'])

    #Experiment Parameters
    args['window_size'] = 2500
    args['encoder_type'] = 1
    args['encoding_size'] = 128
    args['lr'] = 1e-4
    args['decay'] = 1e-5
    args['datasets'] = args['data_type']
    args['n_size'] = 4

    run_cpc(args)
    return

In [None]:
# Base experiment settings; main() adds the device and model hyper-parameters
# to this dict before launching the run.
args = dict(
    n_cross_val=5,
    data_type='afdb',  #options: afdb, ims, urban
    tr_percentage=0.8,
    n_epochs=100,
    suffix='',
    show_encodings=False,
    verbose=True,
)

main(args)