In [None]:
"""
Implementation of our Contrastive Neural Process approach
Resources for neural processes are from: https://github.com/YannDubs/Neural-Process-Family
"""

import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
import matplotlib
os.chdir("../") #Load from parent directory
from data_utils import gen_loader,load_datasets,split_series
from sklearn.metrics import roc_auc_score,balanced_accuracy_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score,davies_bouldin_score
from models import select_encoder
import pickle
import logging
import warnings
import random
warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")
logging.disable(logging.ERROR)

In [None]:
from npf import ConvCNP,CNP
from functools import partial
from utils.helpers import count_parameters,set_seed
from npf.architectures import CNN, MLP, ResConvBlock, SetConv, discard_ith_arg,merge_flat_input
from npf.utils.datasplit import GridCntxtTrgtGetter,RandomMasker,no_masker
from utils.data import cntxt_trgt_collate


def define_model(device,encoder_type,input_size,encoding_size,range_indcs,density_induced=64,is_contrastive=True):
    """Build the (contrastive) CNP model factory together with its batch collate function.

    Args:
        device: unused here; kept for interface compatibility.
        encoder_type: unused here; kept for interface compatibility.
        input_size: used as the CNP ``y_dim``.
        encoding_size: hidden width of the MLPs and the representation size ``r_dim``
            (the XY-encoder uses twice this width).
        range_indcs: forwarded to ``RandomMasker`` as ``range_indcs``.
        density_induced: unused here.
        is_contrastive: forwarded to ``CNP`` and used as ``is_duplicate_batch`` when collating.

    Returns:
        (model, get_cntxt_trgt): ``model`` is a ``functools.partial`` of ``CNP`` (call it to
        build a fresh network); ``get_cntxt_trgt`` is the context/target collate function.
    """
    # Context points: RandomMasker(a=0.0, b=0.6) restricted to range_indcs; targets: no masking.
    context_masker = partial(
        RandomMasker(a=0.0, b=0.6, range_indcs=range_indcs),
        is_range=False,
    )
    getter = GridCntxtTrgtGetter(context_masker=context_masker, target_masker=no_masker)
    get_cntxt_trgt = cntxt_trgt_collate(getter, is_duplicate_batch=is_contrastive)

    # The MLPs take a single input, so the (x, R) and (x, y) pairs are merged
    # with merge_flat_input(..., is_sum_merge=True).
    decoder = merge_flat_input(
        partial(MLP, n_hidden_layers=4, hidden_size=encoding_size),
        is_sum_merge=True,
    )
    xy_encoder = merge_flat_input(
        partial(MLP, n_hidden_layers=2, hidden_size=encoding_size * 2),
        is_sum_merge=True,
    )

    model = partial(
        CNP,
        x_dim=1,
        y_dim=input_size,
        is_contrastive=is_contrastive,
        XYEncoder=xy_encoder,
        XEncoder=partial(MLP, n_hidden_layers=1, hidden_size=encoding_size),
        Decoder=decoder,
        r_dim=encoding_size,
    )

    print(f"Number Parameters (1D): {count_parameters(model()):,d}")
    return model, get_cntxt_trgt

In [None]:
import skorch
from npf.utils.predict import SamplePredictor
from npf import Contrastive_CNPFLoss,CNPFLoss
from utils.ntbks_helpers import add_y_dim
from utils.train import train_models
import pandas as pd

def train(device,get_cntxt_trgt,trainset,testset,lr,decay,data_type,datasets,batch_size,model,
          epochs,is_retrain=True,is_contrastive=True,val_split=0.2):
    """Pretrain ``model`` with the contrastive CNP loss through ``train_models``.

    Args:
        device: device passed to the loss and to ``train_models``.
        get_cntxt_trgt: collate function used for both the training and validation iterators.
        trainset: training dataset, registered under the key ``data_type``.
        testset: test dataset, registered under the key ``data_type``.
        lr: learning rate forwarded to ``train_models``.
        decay: forwarded to ``train_models`` as ``decay_lr``.
        data_type: dataset key; the model is registered as ``"ContrCNP_<data_type>"``.
        datasets: name used for the checkpoint directory ``results/pretrained/<datasets>/``.
        batch_size: batch size used for training and by the loss.
        model: model factory handed to ``train_models`` (e.g. the ``partial`` from ``define_model``).
        epochs: maximum number of training epochs.
        is_retrain: forwarded to ``train_models``.
        is_contrastive: forwarded to ``Contrastive_CNPFLoss``.
        val_split: fraction of ``trainset`` held out by ``skorch.dataset.CVSplit`` for validation.

    Returns:
        Whatever ``train_models`` returns (indexed by ``"<data_type>/ContrCNP_<data_type>/run_0"``
        in this notebook).
    """
    # NOTE(review): the checkpoint directory depends only on `datasets`, not on the CV fold,
    # so different folds share the same checkpoint location — confirm this is intended.
    save_dir = "results/pretrained/%s/" % (datasets)
    # exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_dir, exist_ok=True)
    print('Saving at: ', save_dir)

    criterion = partial(
        Contrastive_CNPFLoss,
        is_contrastive=is_contrastive,
        device=device,
        batch_size=batch_size,
        lreg=0.001,
    )

    trainers = train_models(
        {data_type: trainset},
        {"ContrCNP_%s" % (data_type): model},
        train_split=skorch.dataset.CVSplit(val_split),
        test_datasets={data_type: testset},
        iterator_train__collate_fn=get_cntxt_trgt,
        iterator_valid__collate_fn=get_cntxt_trgt,
        max_epochs=epochs,
        is_retrain=is_retrain,
        criterion=criterion,
        chckpnt_dirname=save_dir,
        device=device,
        lr=lr,
        decay_lr=decay,
        seed=123,
        batch_size=batch_size,
    )

    return trainers

In [None]:
def find_range(data_len,trainset):
    """Compute the allowed context-index range and the default number of context points.

    Args:
        data_len: series length (an int or numeric string), or the string ``'Vary'`` when
            series lengths differ; in that case the shortest ``x.shape[1]`` over ``trainset``
            is used as the length.
        trainset: iterable of ``(x, y)`` pairs; only read when ``data_len == 'Vary'``.

    Returns:
        ``([a, b], s)`` with ``a = L // 100 + 1``, ``b = L - L // 100 - 1`` and ``s = L // 5``,
        where ``L`` is the effective series length.

    Raises:
        ValueError: if ``data_len == 'Vary'`` and ``trainset`` is empty.
    """
    if data_len == 'Vary':
        lengths = [x.shape[1] for x, _ in trainset]
        if not lengths:
            raise ValueError("find_range: trainset is empty, cannot infer a length for 'Vary'.")
        data_len = min(lengths)

    length = int(data_len)
    # Keep ~1% of the series as a margin at both ends of the context range.
    margin = length // 100
    return [margin + 1, length - margin - 1], length // 5

In [None]:
def get_posterior_samples(data,data_len,model,is_uniform_grid=False,img_indcs=None,n_plots=4,seed=123,
                          n_samples=1,is_select_different=False,n_cntxt=100,device='cuda'):
    """Plot one posterior sample of ``model`` next to the ground truth for a test series.

    Exactly ``n_cntxt`` context points are drawn (``RandomMasker(a=n_cntxt, b=n_cntxt)``) and
    the full series is the target. Only the first selected series is plotted: the top panel
    shows the sampled prediction, the bottom panel the true signal, and both mark the
    context points in red.

    Args:
        data: indexable dataset of ``(x, y)`` samples (e.g. a ``TensorDataset``).
        data_len: series length, used to map context inputs to time indices via
            ``(x + 1) * data_len / 2`` (assumes inputs lie in [-1, 1] — TODO confirm).
        model: trained CNP; switched to eval mode and moved to ``device``.
        is_uniform_grid: forwarded as ``is_return_masks`` to ``cntxt_trgt_collate``.
        img_indcs: indices of the samples to use; ``n_plots`` random indices when None.
        n_plots: number of random samples drawn when ``img_indcs`` is None.
        seed: seed passed to ``set_seed`` before sampling.
        n_samples: number of posterior samples to keep and draw (int), or None.
        is_select_different: if True, calls ``keep_most_different_samples_`` instead of truncating.
        n_cntxt: number of context points.
        device: device used for the forward pass (default ``'cuda'``).

    Returns:
        dict with ``'mean_y'`` (the plotted posterior *sample*), ``'ids'`` (context time
        indices), ``'Y_cntxt'`` and ``'Y_trgt'``.

    Raises:
        ValueError: if ``n_samples`` is neither an int nor None.
    """
    upscale_factor = 1
    print('Using %s Points' % (n_cntxt))
    get_cntxt_trgt = GridCntxtTrgtGetter(
        context_masker=RandomMasker(a=n_cntxt, b=n_cntxt),
        target_masker=no_masker,
        is_add_cntxts_to_trgts=False,
        upscale_factor=upscale_factor,
    )
    set_seed(seed)
    model.eval()

    if img_indcs is None:
        img_indcs = [random.randint(0, len(data) - 1) for _ in range(n_plots)]
    imgs = [data[i] for i in img_indcs]

    cntxt_trgt = cntxt_trgt_collate(get_cntxt_trgt, is_return_masks=is_uniform_grid)(imgs)[0]
    X_cntxt = cntxt_trgt["X_cntxt"]
    Y_cntxt = cntxt_trgt["Y_cntxt"]
    X_trgt = cntxt_trgt["X_trgt"]
    Y_trgt = cntxt_trgt["Y_trgt"]

    y_pred = SamplePredictor(model.to(device), is_dist=True)(
        X_cntxt.to(device), Y_cntxt.to(device), X_trgt.to(device)
    )

    if is_select_different:
        # NOTE(review): keep_most_different_samples_ is neither defined nor imported in this
        # notebook, so this branch raises NameError unless it is provided elsewhere.
        keep_most_different_samples_(y_pred, n_samples)
    elif isinstance(n_samples, int):
        # Keep only the first n_samples entries of the predictive distribution.
        y_pred.base_dist.loc = y_pred.base_dist.loc[:n_samples, ...]
        y_pred.base_dist.scale = y_pred.base_dist.scale[:n_samples, ...]
    elif n_samples is not None:
        raise ValueError(f"Unknown n_samples={n_samples}.")

    # Map the context inputs back to integer time indices for plotting.
    ids = ((X_cntxt[0] + 1) * data_len / 2).int()
    # .sample((n,)) is the non-deprecated spelling of Distribution.sample_n(n).
    mean_ys = y_pred.sample((n_samples,))[:, 0, ...]
    mean_y = mean_ys[0]

    fig, ax = plt.subplots(2, 1)
    fig.set_size_inches(8, 5)

    ax[0].plot(mean_y.to('cpu').numpy()[0, :, 0])
    ax[0].scatter(ids, Y_cntxt[0, :, 0], color='r')
    ax[0].set_ylabel('Amplitude')
    ax[1].set_ylabel('Amplitude')
    ax[1].plot(Y_trgt[0, :, 0])
    ax[1].scatter(ids, Y_cntxt[0, :, 0], color='r')
    ax[1].set_xlabel('Number of points')

    plt.show()

    return {'mean_y': mean_y, 'ids': ids, 'Y_cntxt': Y_cntxt, 'Y_trgt': Y_trgt}


In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from torchvision import datasets
import seaborn as sns
from sklearn.metrics import silhouette_score,davies_bouldin_score

def plot_enc(x_test,y_test,encoding_size,test_labels,batch_size,device,enc_model,n_cntxt,
             upscale_factor,get_cntxt_trgt,window_size=2500,augment=100):
    """Encode random windows of the test series and show a 2-D t-SNE of the encodings.

    Draws ``len(x_test) * augment`` windows of length ``window_size`` at random offsets
    (window ``i`` is cut from series ``i % len(x_test)``), labels each window with the
    rounded mean of ``y_test`` over it, encodes every full batch with ``get_encodings``
    and scatter-plots the t-SNE embedding coloured by that label.

    Args:
        x_test: test series indexed as ``x_test[series, channel, time]``.
        y_test: per-step labels indexed as ``y_test[series, time]``.
        encoding_size, test_labels, n_cntxt, upscale_factor: unused; kept for interface compatibility.
        batch_size: encoding batch size; windows that do not fill a last batch are dropped.
        device: device the windows, labels and ``enc_model`` are moved to.
        enc_model: trained model passed to ``get_encodings``.
        get_cntxt_trgt: context/target getter passed to ``get_encodings``.
        window_size: window length.
        augment: number of random windows per test series.

    Returns:
        (encodings, windows_state): the detached encodings of the used windows (on ``device``)
        and a CPU tensor with the matching window labels.

    Raises:
        ValueError: if the series are not longer than ``window_size`` or there are fewer
            windows than ``batch_size``.
    """
    n_test = len(x_test)
    max_start = x_test.shape[-1] - window_size
    if max_start <= 0:
        raise ValueError(
            f"plot_enc: series length {x_test.shape[-1]} must exceed window_size={window_size}.")

    starts = np.random.randint(0, max_start, n_test * augment)
    windows = np.array([x_test[i % n_test, :, s:s + window_size] for i, s in enumerate(starts)])
    window_labels = [np.round(np.mean(y_test[i % n_test, s:s + window_size], axis=-1))
                     for i, s in enumerate(starts)]

    testset = torch.utils.data.TensorDataset(torch.Tensor(windows).to(device),
                                             torch.Tensor(window_labels).to(device))
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)

    enc_model = enc_model.to(device)
    encoding_chunks, state_chunks = [], []
    for x, y in test_loader:
        batch = [(x[i], y[i]) for i in range(batch_size)]
        encoding_chunks.append(get_encodings(device, batch, enc_model, get_cntxt_trgt).detach())
        state_chunks.append(y.cpu())
    if not encoding_chunks:
        raise ValueError(
            f"plot_enc: {len(testset)} windows are fewer than batch_size={batch_size}; nothing to encode.")

    # Concatenate once instead of re-allocating a growing tensor on every batch.
    encodings = torch.cat(encoding_chunks)
    windows_state = torch.cat(state_chunks)

    embedding = TSNE(n_components=2).fit_transform(encodings.cpu().numpy())
    df_encoding = pd.DataFrame({"f1": embedding[:, 0], "f2": embedding[:, 1], "state": windows_state.numpy()})

    # Set the style before creating the axes so it applies to them.
    sns.set_style("white")
    fig, ax = plt.subplots()
    sns.scatterplot(x="f1", y="f2", data=df_encoding, hue="state", palette="deep", ax=ax)
    ax.set_title("t-SNE of test-window encodings")
    plt.show()
    return encodings, windows_state

In [None]:
def clusters(x_test,y_test,enc_model,device,window_size,n_cv,datasets,data_type,encoding_size,
             encoder_type,suffix,get_cntxt_trgt,n_classes,batch_size = 8):
    """Cluster encodings of non-overlapping test windows with K-means and report scores.

    Each test series is cut into ``T // window_size`` consecutive windows (any trailing
    remainder is dropped) and each window is labelled with its most frequent per-step label;
    the labels are passed to ``get_encodings`` with the inputs but are not used by the
    scores. All windows that fill a whole batch are encoded, clustered into ``n_classes``
    groups, and scored with the Silhouette and Davies-Bouldin indices. This is repeated
    ``n_cv`` times (presumably to average over random context sampling — TODO confirm).

    Args:
        x_test: test series indexed as ``x_test[series, channel, time]``.
        y_test: per-step labels indexed as ``y_test[series, time]``.
        enc_model: trained model passed to ``get_encodings``.
        device: device used for encoding.
        window_size: window length.
        n_cv: number of encode-and-score repetitions.
        datasets, data_type, encoding_size, encoder_type, suffix: unused; kept for interface compatibility.
        get_cntxt_trgt: context/target getter passed to ``get_encodings``.
        n_classes: number of K-means clusters.
        batch_size: encoding batch size; a final partial batch is dropped.

    Returns:
        (s_score, db_score): per-repetition Silhouette and Davies-Bouldin scores.

    Raises:
        ValueError: if the series are shorter than ``window_size`` or yield fewer windows
            than ``batch_size``.
    """
    T = x_test.shape[-1]
    n_windows = T // window_size
    if n_windows == 0:
        raise ValueError(f"clusters: series length {T} is shorter than window_size={window_size}.")
    usable = window_size * n_windows

    x_windows = np.concatenate(np.split(x_test[:, :, :usable], n_windows, -1), 0)
    y_windows = np.concatenate(np.split(y_test[:, :usable], n_windows, -1), 0).astype(int)
    window_labels = np.array([np.bincount(yy).argmax() for yy in y_windows])

    testset = torch.utils.data.TensorDataset(torch.Tensor(x_windows), torch.Tensor(window_labels))
    loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)
    if len(loader) == 0:
        raise ValueError(
            f"clusters: {len(testset)} windows are fewer than batch_size={batch_size}; nothing to encode.")

    enc_model = enc_model.to(device)
    s_score = []
    db_score = []
    for _ in range(n_cv):
        chunks = []
        for x, y in loader:
            batch = [(x[i].to(device), y[i].to(device)) for i in range(batch_size)]
            chunks.append(get_encodings(device, batch, enc_model, get_cntxt_trgt).detach().cpu())
        # Concatenate once instead of re-allocating a growing tensor on every batch.
        encodings = torch.cat(chunks).numpy()

        cluster_labels = KMeans(n_clusters=n_classes, random_state=1).fit(encodings).labels_
        # Compute each metric once (silhouette is quadratic in the number of windows).
        sil = silhouette_score(encodings, cluster_labels)
        dbi = davies_bouldin_score(encodings, cluster_labels)
        print(sil, dbi)
        s_score.append(sil)
        db_score.append(dbi)

    print('Silhouette score: ', np.mean(s_score),'+-', np.std(s_score))
    print('Davies Bouldin score: ', np.mean(db_score),'+-', np.std(db_score))

    return s_score, db_score

In [None]:
def sup_train(batch_size,window_size,data_type,device,enc_model,encoder_type,datasets,
              encoding_size,lr,n_epochs,tr_percentage,n_classes,get_cntxt_trgt,classifier_type,cv):
    """Train a classifier on frozen encodings of ``enc_model`` and evaluate it every epoch.

    The first ``tr_percentage`` of the (windowed) training data trains the classifier and
    the remainder validates it; the test split is scored after every epoch.

    Args:
        batch_size: batch size for all loaders (partial batches are dropped).
        window_size: window length used by ``split_series`` for 'afdb', 'ims' and 'urban'.
        data_type, datasets, cv: forwarded to ``load_datasets``.
        device: device for the classifier, the encoder and the labels.
        enc_model: trained encoder; only used through ``get_encodings`` and never updated.
        encoder_type, encoding_size, n_classes, classifier_type: forwarded to ``select_encoder``.
        lr: Adam learning rate for the classifier.
        n_epochs: number of training epochs.
        tr_percentage: fraction of the training windows used to fit the classifier.
        get_cntxt_trgt: context/target getter passed to ``get_encodings``.

    Returns:
        The test accuracy from the epoch with the highest validation accuracy
        (0 if no epoch produced a comparable validation accuracy).

    Raises:
        ValueError: if the classifier training split holds fewer samples than ``batch_size``.
    """
    train_data, train_labels, test_data, test_labels = load_datasets(data_type, datasets, cv)

    if data_type in ['afdb', 'ims', 'urban']:
        train_data, train_labels = split_series(train_data, train_labels, window_size)
        test_data, test_labels = split_series(test_data, test_labels, window_size)

    # The first tr_percentage of the training windows fit the classifier; the rest validate it.
    tr = int(tr_percentage * len(train_data))
    print(train_data.shape)
    trainset = torch.utils.data.TensorDataset(train_data[:tr], train_labels[:tr])
    valset = torch.utils.data.TensorDataset(train_data[tr:], train_labels[tr:])
    testset = torch.utils.data.TensorDataset(test_data, test_labels)

    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)
    # Evaluation loaders are not shuffled, so with drop_last=True every epoch scores the same samples.
    val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, drop_last=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)
    if len(train_loader) == 0:
        raise ValueError(
            f"sup_train: {len(trainset)} training samples are fewer than batch_size={batch_size}.")

    # Batches have shape (batch, *train_data.shape[1:]), so this equals a batch's dim 1.
    input_size = train_data.shape[1]
    _, classifier = select_encoder(device, encoder_type, input_size, encoding_size, n_classes, classifier_type)
    classifier = classifier.to(device)
    enc_model = enc_model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=lr)
    # NOTE(review): step_size == n_epochs and step() runs once per epoch, so the LR is only
    # decayed after the final epoch — confirm this is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, n_epochs, gamma=0.99)

    best_val_acc = float('-inf')
    best_acc = 0

    for e in range(n_epochs):
        enc_model.eval()
        classifier.train()
        epoch_acc = 0
        batch_count = 0
        epoch_loss = 0
        for x, y in train_loader:
            optimizer.zero_grad()

            batch = [(x[i], y[i]) for i in range(batch_size)]
            encodings = get_encodings(device, batch, enc_model, get_cntxt_trgt)
            # The encoder is frozen: only the classifier's parameters are optimised.
            prediction = classifier(encodings.detach())
            state_prediction = torch.argmax(prediction, dim=1)
            loss = loss_fn(prediction, y.long().to(device))

            loss.backward()
            optimizer.step()

            epoch_acc += torch.eq(state_prediction.to('cpu'), y).sum().item() / len(x)
            epoch_loss += loss.item()
            batch_count += 1

        scheduler.step()
        print(' Epoch ', e, 'Train Labels', tr_percentage)
        train_loss, train_acc = epoch_loss / batch_count, epoch_acc / batch_count
        print('Train Results: ', train_loss, train_acc)

        val_loss, val_acc, _ = test_supervised(enc_model, classifier, device, val_loader, get_cntxt_trgt, batch_size)
        print('Val Results: ', val_loss, val_acc)

        test_loss, test_acc, test_auc = test_supervised(enc_model, classifier, device, test_loader,
                                                        get_cntxt_trgt, batch_size, calc_auc=True)
        print('Test Results: ', test_loss, test_acc, test_auc)

        # Model selection on validation accuracy; report the matching test accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_acc = test_acc

    return best_acc

def test_supervised(enc_model,classifier,device,data_loader,get_cntxt_trgt,batch_size,calc_auc=False):
    enc_model.eval()
    classifier.eval()
    
    loss_fn = torch.nn.CrossEntropyLoss()
    
    epoch_loss = 0
    epoch_acc = 0
    epoch_auc = 0
    batch_count = 0
    y_all, prediction_all = [], []
    
    for x, y in data_loader:
        temp = [(x[i],y[i]) for i in range(batch_size)]
        encodings = get_encodings(device,temp,enc_model,get_cntxt_trgt)

        prediction = classifier(encodings.detach())
        
        state_prediction = torch.argmax(prediction, -1)
        loss = loss_fn(prediction, y.long().to(device))

        epoch_acc += torch.eq(state_prediction.to('cpu'), y).sum().item()/len(x)
        epoch_loss += loss.item()
        batch_count += 1
        
        y_all.append(y)
        prediction_all.append(prediction.detach().cpu().numpy())

    if calc_auc:
        y_all = np.concatenate(y_all, 0)
        prediction_all = np.concatenate(prediction_all, 0)

        prediction_class_all = np.argmax(prediction_all, -1)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all.astype(int)] = 1
        epoch_auc = roc_auc_score(y_onehot_all,prediction_all,multi_class='ovo')

    return epoch_loss / batch_count, epoch_acc / batch_count , epoch_auc

In [None]:
def run_contrnp(n_cross_val,data_type,tr_percentage,val_percentage,n_epochs,ablation,classifier_type,
               batch_size,suffix,verbose,show_encodings,device,device_ids,window_size,encoder_type,
               encoding_size,lr,decay,datasets,n_classes,is_retrain):
    """Run pretraining and evaluation for every cross-validation fold.

    For each fold ``cv`` the data are loaded and windowed (for 'afdb', 'ims', 'urban'). The
    function then:

    - pretrains a contrastive CNP and plots one posterior sample;
    - optionally plots a t-SNE of the encodings;
    - reports K-means clustering scores;
    - trains and evaluates a classifier on the frozen encodings.

    Args:
        batch_size: an int, or a value < 1 interpreted as a fraction of
            ``min(len(train_data), len(test_data))``. The fraction is resolved separately
            for each fold, with a minimum of 1.
        val_percentage, ablation, verbose, device_ids: unused here.
        The remaining arguments are forwarded to the helpers that take parameters of the same name.

    Returns:
        list: the value returned by ``sup_train`` for each fold, in fold order.
    """
    fold_accs = []
    for cv in range(n_cross_val):
        train_data, train_labels, test_data, test_labels = load_datasets(data_type, datasets, cv)
        # Keep the un-windowed test series; plot_enc and clusters do their own windowing.
        test_data_init = test_data
        test_labels_init = test_labels

        # Resolve a fractional batch size against this fold's own data sizes.
        fold_batch_size = batch_size
        if fold_batch_size < 1:
            fold_batch_size = max(1, int(min(len(train_data), len(test_data)) * batch_size))
            print('Using batch_size:', fold_batch_size)

        if data_type in ['afdb', 'ims', 'urban']:
            train_data, train_labels = split_series(train_data, train_labels, window_size)
            test_data, test_labels = split_series(test_data, test_labels, window_size)

        trainset = torch.utils.data.TensorDataset(train_data, train_labels)
        testset = torch.utils.data.TensorDataset(test_data, test_labels)

        data_len = train_data.shape[2]
        input_size = train_data.shape[1]

        range_indcs, n_cntxt = find_range(data_len, trainset)
        print(range_indcs, n_cntxt)

        model, train_cntxt_trgt = define_model(device, encoder_type, input_size, encoding_size, range_indcs)
        # NOTE(review): the checkpoint directory passed to train() depends on `datasets`, not on
        # `cv`; with is_retrain=False later folds may reload an earlier fold's model — confirm.
        trainers = train(device, train_cntxt_trgt, trainset, testset, lr, decay, data_type, datasets,
                         fold_batch_size, model, n_epochs, is_retrain=is_retrain)
        trained_model = trainers["%s/ContrCNP_%s/run_0" % (data_type, data_type)].module_.cpu()

        upscale_factor = 1
        get_posterior_samples(testset, data_len, trained_model, n_cntxt=n_cntxt)
        # Evaluation-time getter: exactly n_cntxt context points, full series as target.
        eval_cntxt_trgt = GridCntxtTrgtGetter(
            context_masker=RandomMasker(a=n_cntxt, b=n_cntxt),
            target_masker=no_masker,
            is_add_cntxts_to_trgts=False,
            upscale_factor=upscale_factor,
        )
        if show_encodings:
            plot_enc(test_data_init, test_labels_init, encoding_size, test_labels, fold_batch_size, device,
                     trained_model, n_cntxt, upscale_factor, eval_cntxt_trgt)

        clusters(test_data_init, test_labels_init, trained_model, device, window_size, 1, datasets, data_type,
                 encoding_size, encoder_type, suffix, eval_cntxt_trgt, n_classes)

        acc = sup_train(fold_batch_size, window_size, data_type, device, trained_model, encoder_type, datasets,
                        encoding_size, lr, n_epochs, tr_percentage, n_classes, eval_cntxt_trgt,
                        classifier_type, cv)

        print(acc)
        fold_accs.append(acc)

    return fold_accs

In [None]:
def get_encodings(device,data,enc_model,get_cntxt_trgt):
    """Encode a list of ``(x, y)`` samples with ``enc_model``.

    The samples are collated with ``cntxt_trgt_collate(get_cntxt_trgt)`` (no batch
    duplication, no masks). The four context/target tensors are moved to ``device`` and
    passed, with a trailing ``True``, to ``SamplePredictor(enc_model, is_enc=True)``.

    Args:
        device: device the context/target tensors are moved to.
        data: list of ``(x, y)`` pairs.
        enc_model: model to encode with; switched to eval mode.
        get_cntxt_trgt: context/target getter used for collation.

    Returns:
        The predictor's output. Callers in this notebook ``.detach()`` it and ``torch.cat``
        it along dim 0.
    """
    enc_model.eval()

    collate = cntxt_trgt_collate(get_cntxt_trgt, is_duplicate_batch=False, is_return_masks=False)
    batch = collate(data)[0]

    keys = ("X_cntxt", "Y_cntxt", "X_trgt", "Y_trgt")
    X_cntxt, Y_cntxt, X_trgt, Y_trgt = (batch[key].to(device) for key in keys)

    predictor = SamplePredictor(enc_model, is_enc=True)
    return predictor(X_cntxt, Y_cntxt, X_trgt, Y_trgt, True)

In [None]:
def main(args):
    """Configure and run the experiment described by ``args``.

    A copy of ``args`` is extended with the device, the visible CUDA device ids and
    dataset-specific hyper-parameters, then passed to ``run_contrnp``; the caller's dict is
    left unchanged. When ``args['ablation']`` is truthy, ``run_contrnp`` is repeated for
    several training-label fractions with ``show_encodings`` disabled.

    Args:
        args: experiment configuration; must contain at least ``'data_type'`` and
            ``'ablation'`` plus the remaining ``run_contrnp`` arguments.

    Returns:
        ``run_contrnp``'s result for a normal run, or a dict mapping each training-label
        fraction to ``run_contrnp``'s result for an ablation run.

    Raises:
        ValueError: if ``args['data_type']`` is not one of 'afdb', 'ims', 'urban'.
    """
    args = dict(args)

    # Devices
    args['device'] = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    args['device_ids'] = list(range(torch.cuda.device_count()))
    print('Using', args['device'])

    data_type = args['data_type']
    n_classes_by_type = {'afdb': 4, 'ims': 5, 'urban': 10}
    if data_type not in n_classes_by_type:
        raise ValueError(
            f"Unsupported data_type={data_type!r}; expected one of {sorted(n_classes_by_type)}.")

    # Experiment parameters
    args.update(
        window_size=2500,
        encoder_type=1,
        encoding_size=128,
        lr=1e-3,
        decay=10,
        datasets=data_type,
        n_classes=n_classes_by_type[data_type],
    )
    if data_type == 'urban':
        args['n_cross_val'] = 10

    if not args['ablation']:
        return run_contrnp(**args)

    # Ablation: repeat the experiment with increasing fractions of labelled training data.
    args['show_encodings'] = False
    ablation_results = {}
    for train_per in (0.01, 0.1, 0.2, 0.3, 0.5, 0.7, 0.8):
        args['tr_percentage'] = train_per
        ablation_results[train_per] = run_contrnp(**args)
    return ablation_results

In [None]:
# Experiment configuration; main() adds the device and the dataset-specific hyper-parameters.
args = dict(
    n_cross_val=5,         # number of cross-validation folds (main() forces 10 for 'urban')
    data_type='afdb',      # one of 'afdb', 'ims', 'urban'
    tr_percentage=0.8,     # fraction of training windows used to fit the classifier
    val_percentage=0.2,    # not used by run_contrnp
    n_epochs=100,
    ablation=False,        # True: repeat the run over several tr_percentage values
    classifier_type=0,
    batch_size=8,          # a value < 1 is treated as a fraction of the dataset size
    suffix='',
    verbose=True,
    show_encodings=False,  # plot a t-SNE of test-window encodings
    is_retrain=False,      # forwarded to train_models(is_retrain=...)
)

main(args)