In [None]:
"""
Implementation of the SimCLR baseline based on the original code available on
https://github.com/sthalles/SimCLR
"""

import os
import random
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader,Dataset

import torch
import torch.nn as nn

# The project modules imported below are loaded from the parent directory.
# Only step up when they are not already visible from the current working
# directory: an unconditional chdir is not idempotent, so re-running this cell
# would climb one more level each time and break both the imports and the
# relative './results/...' paths used later.
if not (os.path.isfile("data_utils.py") or os.path.isdir("data_utils")):
    os.chdir("../") #Load from parent directory
from data_utils import Plots,gen_loader,load_datasets,split_series,CustomTensorDataset
from models import select_encoder
utils_plot=Plots()

In [None]:
class NTXentLoss(torch.nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss for SimCLR.

    Given two batches of embeddings (two views of the same samples), every
    embedding is contrasted against all others: its counterpart in the other
    view is the positive, and the remaining ``2N - 2`` embeddings are negatives.

    Args:
        device: Device on which masks and labels are created.
        batch_size: Expected number of samples per view. A mask for this size
            is precomputed; other sizes are handled on the fly in ``forward``.
        temperature: Divisor applied to the similarity logits.
        use_cosine_similarity: If true, use cosine similarity; otherwise the
            plain dot product.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        # Not used by forward(); kept so the attribute set stays unchanged.
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        # Summed here and divided by 2N at the end of forward().
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity=True):
        """Return the pairwise similarity callable selected by the flag."""
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self, batch_size=None):
        """Build the boolean (2N, 2N) mask that selects the negatives.

        Entries are False on the main diagonal (self-similarity) and on the
        two diagonals offset by +/-N (the positive pairs), True elsewhere.

        Args:
            batch_size: N; defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        size = 2 * batch_size
        diag = np.eye(size)
        l1 = np.eye(size, size, k=-batch_size)
        l2 = np.eye(size, size, k=batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        """Pairwise dot products between the rows of ``x`` and the rows of ``y``."""
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (2N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (2N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        """Pairwise cosine similarity between the rows of ``x`` and the rows of ``y``."""
        # x shape: (2N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (2N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        """Compute the mean NT-Xent loss over the 2N embeddings.

        Args:
            zis: Embeddings of the first view, shape (N, C).
            zjs: Embeddings of the second view, shape (N, C).

        Returns:
            Scalar tensor: summed cross-entropy divided by 2N.

        Raises:
            ValueError: If the two views contain a different number of samples.
        """
        # Use the actual batch size so a batch that differs from the configured
        # one does not break the reshapes below.
        batch_size = zis.shape[0]
        if zjs.shape[0] != batch_size:
            raise ValueError(
                "zis and zjs must contain the same number of samples, got %d and %d"
                % (batch_size, zjs.shape[0])
            )

        zis = nn.functional.normalize(zis, dim=1)
        zjs = nn.functional.normalize(zjs, dim=1)

        representations = torch.cat([zjs, zis], dim=0)
        similarity_matrix = self.similarity_function(representations, representations)

        # Positive pairs sit on the diagonals offset by +/-N.
        l_pos = torch.diag(similarity_matrix, batch_size)
        r_pos = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)

        # Reuse the precomputed mask when possible, otherwise build a matching one.
        if batch_size == self.batch_size:
            mask = self.mask_samples_from_same_repr
        else:
            mask = self._get_correlated_mask(batch_size)
        mask = mask.to(similarity_matrix.device)
        # Explicit column count (instead of -1) keeps the N == 1 case, which has
        # no negatives at all, reshapeable.
        negatives = similarity_matrix[mask].view(2 * batch_size, 2 * batch_size - 2)

        logits = torch.cat((positives, negatives), dim=1) / self.temperature

        # The positive is always stored in column 0 of the logits.
        labels = torch.zeros(2 * batch_size, dtype=torch.long, device=logits.device)

        loss = self.criterion(logits, labels)
        return loss / (2 * batch_size)

In [None]:
def _make_batched_loader(dataset, batch_size):
    """Wrap ``dataset`` in a DataLoader that draws random, full-size index batches."""
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        torch.utils.data.sampler.RandomSampler(dataset),
        batch_size=batch_size, drop_last=True)
    # The BatchSampler is passed as ``sampler`` (not ``batch_sampler``), so the
    # dataset is indexed with a whole list of indices and the loader adds a
    # leading dimension of size 1 -- the training loop removes it with squeeze(0).
    # Presumably CustomTensorDataset supports list indexing; verify in data_utils.
    return torch.utils.data.DataLoader(dataset, sampler=batch_sampler)


def train_simclr(device,device_ids,datasets,n_cross_val,verbose,lr,data_type,encoder_type,show_encodings,
                 tr_percentage,window_size,batch_size,encoding_size,n_epochs,suffix):
    """Train a SimCLR encoder once per cross-validation fold.

    For every fold the training series are split into windows, shuffled and
    divided into train/validation parts. The encoder is optimised with the
    NT-Xent loss, and the weights with the lowest validation loss are written
    to ``./results/baselines/<datasets>_simclr/<data_type>/``.

    Args:
        device: Torch device used for the encoder and the loss.
        device_ids: Unused here; accepted so ``train_simclr(**args)`` works.
        datasets: Dataset name; used in the save path and passed to ``load_datasets``.
        n_cross_val: Number of folds to train.
        verbose: If true, print progress and loss values.
        lr: Learning rate for Adam.
        data_type: Data type name; used in the save path and passed to ``load_datasets``.
        encoder_type: Encoder selector passed to ``select_encoder`` (formatted
            with ``%d`` in the checkpoint name).
        show_encodings: Unused here; accepted so ``train_simclr(**args)`` works.
        tr_percentage: Fraction of the shuffled windows used for training; the
            rest is used for validation.
        window_size: Window length passed to ``split_series``.
        batch_size: Samples per batch. A value below 1 is treated as a fraction
            of the smaller of the fold's train/test set sizes.
        encoding_size: Encoder output size passed to ``select_encoder``.
        n_epochs: Number of training epochs per fold.
        suffix: String appended to the checkpoint file name.

    Returns:
        None. Checkpoints are written to disk as a side effect.
    """
    for cv in range(n_cross_val):

        #Save Location
        save_dir = './results/baselines/%s_simclr/%s/'%(datasets,data_type)
        os.makedirs(save_dir, exist_ok=True)

        # Format only the file name, so a '%' in the directory cannot break it.
        save_file = os.path.join(
            save_dir,
            'encoding_%d_encoder_%d_checkpoint_%d%s.pth.tar'%(encoding_size,encoder_type, cv,suffix))

        if verbose:
            print('Saving at: ',save_file)

        best_loss = np.inf

        train_data,train_labels,test_data,test_labels = load_datasets(data_type,datasets,cv)
        x,y = split_series(train_data,train_labels,window_size)

        # Resolve a fractional batch size per fold without overwriting the
        # argument; otherwise only the first fold would use its own data sizes.
        if batch_size<1:
            fold_batch_size = max(1,int(min(len(train_data),len(test_data))*batch_size))
            print('Using batch_size:', fold_batch_size)
        else:
            fold_batch_size = batch_size

        nt_xent_criterion = NTXentLoss(device, fold_batch_size,temperature = 0.5,use_cosine_similarity=True)

        # Shuffle the windows before the train/validation split.
        inds = list(range(len(x)))
        random.shuffle(inds)
        x = x[inds]
        y = y[inds]
        n_train = int(tr_percentage*len(x))

        sample_rate=250
        trainset = CustomTensorDataset([x[:n_train],y[:n_train]],sample_rate=sample_rate,is_transform=True)
        valset = CustomTensorDataset([x[n_train:],y[n_train:]],sample_rate=sample_rate,is_transform=True)

        train_loader = _make_batched_loader(trainset, fold_batch_size)
        val_loader = _make_batched_loader(valset, fold_batch_size)

        input_size = train_data.shape[1]
        encoder,_ = select_encoder(device,encoder_type,input_size,encoding_size)
        encoder = encoder.to(device)
        optimizer = torch.optim.Adam(encoder.parameters(), lr)
        # NOTE(review): step_size == n_epochs, so the learning rate only decays
        # after the final epoch, i.e. it stays constant during training.
        # Kept as in the original; confirm whether this is intended.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, n_epochs, gamma=0.99)

        for epoch in range(n_epochs):
            total_loss = 0
            batch_count = 0
            encoder.train()
            for (xis, xjs),_ in train_loader:
                optimizer.zero_grad()

                # Drop the extra leading dimension added by the loader.
                xis = xis.squeeze(0).to(device)
                xjs = xjs.squeeze(0).to(device)

                # get the representations and the projections
                zis = encoder(xis)  # [N,C]
                zjs = encoder(xjs)  # [N,C]

                loss = nt_xent_criterion(zis, zjs)
                total_loss+=loss.item()
                batch_count += 1
                loss.backward()
                optimizer.step()

            scheduler.step()
            if verbose:
                print('CV ',cv,' Epoch ',epoch,'Train Labels',tr_percentage)

                # With drop_last=True a training set smaller than one batch
                # yields no batches; report nan instead of dividing by zero.
                mean_train_loss = total_loss / batch_count if batch_count else float('nan')
                print('Train Results: ',mean_train_loss)

            val_loss = test_supervised(encoder,device,val_loader,fold_batch_size,calc_auc=False)
            if verbose:
                print('Validation Results: ',val_loss)
            # Keep the weights from the epoch with the lowest validation loss.
            if best_loss>=val_loss:
                best_loss = val_loss
                state = {
                        'epoch': epoch,
                        'encoder_state_dict': encoder.state_dict()
                    }
                torch.save(state, save_file)
                if verbose:
                    print('Saving ckpt')
            if verbose:
                print('')
    return

In [None]:
def test_supervised(encoder,device,data_loader,batch_size,calc_auc=False):
    """Return the mean NT-Xent loss of ``encoder`` over ``data_loader``.

    Args:
        encoder: Model mapping a batch of inputs to embeddings; put in eval mode.
        device: Torch device the batches are moved to.
        data_loader: Iterable yielding ``((xis, xjs), _)`` where ``xis`` and
            ``xjs`` are two views of the same batch with an extra leading
            dimension of size 1.
        batch_size: Batch size handed to ``NTXentLoss``.
        calc_auc: Unused; kept for signature compatibility.

    Returns:
        float: Average of the per-batch losses.

    Raises:
        ValueError: If ``data_loader`` yields no batches (for example when the
            dataset is smaller than one batch and incomplete batches are dropped).
    """
    encoder.eval()

    nt_xent_criterion = NTXentLoss(device, batch_size,temperature = 0.5,use_cosine_similarity=True)

    epoch_loss = 0
    batch_count = 0

    # Evaluation only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for (xis, xjs),_ in data_loader:
            xis = xis.squeeze(0).to(device)
            xjs = xjs.squeeze(0).to(device)

            # get the representations and the projections
            zis = encoder(xis)  # [N,C]
            zjs = encoder(xjs)  # [N,C]

            loss = nt_xent_criterion(zis, zjs)

            epoch_loss += loss.item()
            batch_count += 1

    if batch_count == 0:
        raise ValueError(
            'data_loader produced no batches; the dataset is probably smaller '
            'than batch_size=%s' % (batch_size,)
        )

    return epoch_loss / batch_count

In [None]:
def run_simclr(args):
    """Train the SimCLR baseline and, when requested, plot the test encodings.

    Args:
        args: Dict of keyword arguments forwarded unchanged to ``train_simclr``.
            When ``args['show_encodings']`` is truthy, one distribution plot
            is drawn per cross-validation fold.
    """
    # Training always runs first.
    train_simclr(**args)

    plot_title = 'SimCLR Encoding TSNE for %s'%(args['data_type'])

    # Nothing left to do unless plots were requested.
    if not args['show_encodings']:
        return

    for fold in range(args['n_cross_val']):
        _,_,fold_test_data,fold_test_labels = load_datasets(args['data_type'],args['datasets'],fold)
        utils_plot.plot_distribution(
            fold_test_data, fold_test_labels,
            args['encoder_type'], args['encoding_size'], args['window_size'],
            'simclr',
            args['datasets'], args['data_type'], args['suffix'],
            args['device'], plot_title, fold,
            parallel=False,
        )
    return

In [None]:
def main(args):
    """Fill in device and experiment settings in ``args`` and run SimCLR.

    ``args`` is modified in place: the device entries, the per-dataset
    overrides and the fixed experiment parameters are written into it before
    it is handed to ``run_simclr``.
    """
    # Pick the GPU when one is available.
    cuda_ok = torch.cuda.is_available()
    args['device'] = torch.device("cuda" if cuda_ok else 'cpu')
    args['device_ids'] = [0]
    print('Using', args['device'])

    # Per-dataset overrides, applied when data_type matches the name.
    per_dataset = (
        ('afdb', (('lr', 1e-3),)),
        ('ims', (('lr', 1e-3),)),
        ('urban', (('n_cross_val', 10), ('lr', 1e-4))),
    )
    for dataset_name, overrides in per_dataset:
        if args['data_type']==dataset_name:
            for key, value in overrides:
                args[key] = value

    # Parameters shared by every experiment.
    args['window_size'] = 2500
    args['encoder_type'] = 1
    args['encoding_size'] = 128
    args['datasets'] = args['data_type']

    run_simclr(args)
    return

In [None]:
# Experiment configuration; main() adds the device and model settings.
args = dict(
    n_cross_val=5,
    data_type='afdb',  # options: afdb, ims, urban
    tr_percentage=0.8,
    n_epochs=15,
    batch_size=16,
    suffix='',
    show_encodings=False,
    verbose=True,
)

main(args)