## Introduction

This notebook provides code for mixup contrastive learning on time series. The method was originally illustrated on the GunPoint dataset (more datasets are available; see https://github.com/alan-turing-institute/sktime/tree/master/sktime/datasets/data for more info), but the experiment at the end of this notebook trains on ECG data (PTB-XL by default) loaded with `load_data`.

The first two code blocks install sktime (the `pip install` line is commented out) and load the necessary packages.

In [11]:
#@title Clone Git repos

# !pip install sktime



In [None]:
#@title Load packages and data


import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import matplotlib.pyplot as plt

import os
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from scipy.signal import butter, lfilter
from itertools import repeat

from IPython.display import clear_output
from sktime.datasets import load_gunpoint
from torch.utils.data import Dataset, DataLoader
from sklearn.neighbors import KNeighborsClassifier


def to_np(x):
    ''' Return ``x`` as a NumPy array: detached from autograd and copied to host memory. '''
    return x.detach().cpu().numpy()

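## Load data and create PyTorch dataset

The following two code blocks (currently commented out) load the GunPoint data, convert it to NumPy arrays, and wrap it in a PyTorch dataset class. The third block defines `load_data` and its helpers, which load the ECG datasets, split them by patient, normalise each trial, and optionally cut the trials into fixed-length windows.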

In [4]:
# #@title load data and convert to numpy array

# x_tr, y_tr = load_gunpoint(split='train', return_X_y=True)

# x_tr = pd.DataFrame(x_tr).to_numpy()
# y_tr = pd.DataFrame(y_tr).to_numpy()

# x_tr = np.array(np.ndarray.tolist(x_tr), dtype=np.float32)
# y_tr = np.array(np.ndarray.tolist(y_tr), dtype=np.int32)

# x_te, y_te = load_gunpoint(split='test', return_X_y=True)


# x_te = pd.DataFrame(x_te).to_numpy()
# y_te = pd.DataFrame(y_te).to_numpy()

# x_te = np.array(np.ndarray.tolist(x_te), dtype=np.float32)
# y_te = np.array(np.ndarray.tolist(y_te), dtype=np.int32)

In [None]:
# #@title create dataset

# class MyDataset(Dataset):
#     def __init__(self, x, y):

#         device = 'cuda'
#         self.x = torch.tensor(x, dtype=torch.float, device=device)
#         self.y = torch.tensor(y, dtype=torch.long, device=device)

#     def __len__(self):
#         return len(self.x)

#     def __getitem__(self, idx):
#         return self.x[idx], self.y[idx]

In [None]:
def load_data(root='dataset', name='chapman', length=None, overlap=0, norm=True):
    ''' load and preprocess data

    Every file in ``<root>/<name>/feature`` (in sorted filename order) is loaded
    with ``np.load`` and each trial it contains is assigned to the train,
    validation or test split using the patient-id lists from
    ``load_label_split``. File ``i`` (0-based) is matched against id ``i + 1``
    and labelled with ``labels[i]`` -- this assumes the sorted file order lines
    up with the rows of ``label.npy`` (TODO confirm against the preprocessing).

    Args:
        root (str): Directory that contains one sub-folder per dataset.
        name (str): Dataset name; must be supported by ``load_label_split``.
        length (int or None): If truthy, cut every trial into windows of this
            many timestamps with ``split_data_label``.
        overlap (float): Fraction of overlap between consecutive windows.
        norm (bool): Whether to standardise each trial with ``process_batch_ts``.

    Returns:
        X_train, X_val, X_test, y_train, y_val, y_test (numpy.ndarray)
    '''
    data_path = os.path.join(root, name, 'feature')
    labels, train_ids, valid_ids, test_ids = load_label_split(root, name)

    # Sets give O(1) membership tests; the id lists can hold thousands of
    # patients and used to be scanned once for every trial.
    train_ids, valid_ids, test_ids = set(train_ids), set(valid_ids), set(test_ids)

    filenames = sorted(os.listdir(data_path))

    train_trials, train_labels = [], []
    valid_trials, valid_labels = [], []
    test_trials, test_labels = [], []

    for i, fn in enumerate(tqdm(filenames, desc=f'=> Loading {name}')):
        # The split only depends on the file, so decide it once per file.
        pid = i + 1
        if pid in train_ids:
            trials, trial_labels = train_trials, train_labels
        elif pid in valid_ids:
            trials, trial_labels = valid_trials, valid_labels
        elif pid in test_ids:
            trials, trial_labels = test_trials, test_labels
        else:
            continue  # file belongs to no split; no need to load it

        label = labels[i]
        feature = np.load(os.path.join(data_path, fn))
        for trial in feature:
            trials.append(trial)
            trial_labels.append(label)

    X_train = np.array(train_trials)
    X_val = np.array(valid_trials)
    X_test = np.array(test_trials)
    y_train = np.array(train_labels)
    y_val = np.array(valid_labels)
    y_test = np.array(test_labels)

    if norm:
        X_train = process_batch_ts(X_train, normalized=True, bandpass_filter=False)
        X_val = process_batch_ts(X_val, normalized=True, bandpass_filter=False)
        X_test = process_batch_ts(X_test, normalized=True, bandpass_filter=False)

    if length:
        X_train, y_train = split_data_label(X_train, y_train, sample_timestamps=length, overlapping=overlap)
        X_val, y_val = split_data_label(X_val, y_val, sample_timestamps=length, overlapping=overlap)
        X_test, y_test = split_data_label(X_test, y_test, sample_timestamps=length, overlapping=overlap)

    return X_train, X_val, X_test, y_train, y_val, y_test


def load_label_split(root='dataset', name='chapman'):
    '''
    load labels for dataset and split information

    ``<root>/<name>/label/label.npy`` is loaded with ``np.load``. Column 0 is
    used as the class index and column 1 as the patient id. For every class the
    last patients (in file order) are held out for validation and test; the
    per-class hold-out sizes are fixed per dataset below.

    Args:
        root (str): Directory that contains one sub-folder per dataset.
        name (str): One of ``'chapman'``, ``'ptb'``, ``'ptbxl'``.

    Returns:
        labels (numpy.ndarray): The raw label array.
        train_ids, val_ids, test_ids (list): Patient ids of each split.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    '''
    label_path = os.path.join(root, name, 'label', 'label.npy')
    labels = np.load(label_path)

    def pids_of(cls):
        # patient ids (column 1) of every row whose class (column 0) is `cls`
        return list(labels[labels[:, 0] == cls][:, 1])

    if name == 'chapman':
        pids_sb = pids_of(0)
        pids_af = pids_of(1)
        pids_gsvt = pids_of(2)
        pids_sr = pids_of(3)

        train_ids = pids_sb[:-500] + pids_af[:-500] + pids_gsvt[:-500] + pids_sr[:-500]
        val_ids = pids_sb[-500:-250] + pids_af[-500:-250] + pids_gsvt[-500:-250] + pids_sr[-500:-250]
        test_ids = pids_sb[-250:] + pids_af[-250:] + pids_gsvt[-250:] + pids_sr[-250:]

    elif name == 'ptb':
        pids_neg = pids_of(0)
        pids_pos = pids_of(1)

        train_ids = pids_neg[:-14] + pids_pos[:-42]  # specify patient ID for training, validation, and test set
        val_ids = pids_neg[-14:-7] + pids_pos[-42:-21]   # 28 patients, 7 healthy and 21 positive
        test_ids = pids_neg[-7:] + pids_pos[-21:]  # 28 patients, 7 healthy and 21 positive

    elif name == 'ptbxl':
        pids_norm = pids_of(0)
        pids_mi = pids_of(1)
        pids_sttc = pids_of(2)
        pids_cd = pids_of(3)
        # HYP is assumed to be class 4 (the fifth class). It used to be read
        # from class 3, which duplicated every CD patient into the HYP slices
        # and leaked the same patients across train/val/test.
        pids_hyp = pids_of(4)

        train_ids = pids_norm[:-1200] + pids_mi[:-600] + pids_sttc[:-600] + pids_cd[:-400] + pids_hyp[:-200]
        val_ids = pids_norm[-1200:-600] + pids_mi[-600:-300] + pids_sttc[-600:-300] + pids_cd[-400:-200] + pids_hyp[-200:-100]
        test_ids = pids_norm[-600:] + pids_mi[-300:] + pids_sttc[-300:] + pids_cd[-200:] + pids_hyp[-100:]

    # TODO: CPSC2018, etc.
    else:
        raise ValueError(f'Unknown dataset: {name}')

    return labels, train_ids, val_ids, test_ids


def butter_bandpass(lowcut, highcut, fs, order=5):
    ''' Design a Butterworth band-pass filter.

    The cut-off frequencies are expressed as fractions of the Nyquist
    frequency (``fs / 2``), as ``scipy.signal.butter`` expects.

    Args:
        lowcut (float): Lower cut-off frequency.
        highcut (float): Upper cut-off frequency.
        fs (float): Sampling frequency.
        order (int): Filter order.

    Returns:
        tuple: ``(b, a)`` numerator and denominator coefficients.
    '''
    nyquist = fs / 2
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band_edges, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    ''' Band-pass filter ``data`` along axis 0 with a Butterworth filter.

    The filter is applied in a single forward pass with ``scipy.signal.lfilter``.
    See https://stackoverflow.com/questions/12093594/how-to-implement-band-pass-butterworth-filter-with-scipy-signal-butter

    Args:
        data (numpy.ndarray): Signal to filter; axis 0 is filtered.
        lowcut (float): Lower cut-off frequency.
        highcut (float): Upper cut-off frequency.
        fs (float): Sampling frequency.
        order (int): Filter order.

    Returns:
        numpy.ndarray: The filtered signal, same shape as ``data``.
    '''
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data, axis=0)


def process_ts(ts, fs, normalized=True, bandpass_filter=False):
    ''' Clean up a single time-series.

    The optional steps run in a fixed order. First, when ``bandpass_filter`` is
    true, a 0.5-50 band-pass Butterworth filter of order 5 is applied. Then, when
    ``normalized`` is true, each feature column is standardised with a
    ``StandardScaler`` fitted on this series alone.

    Args:
        ts (numpy.ndarray): The input time-series in shape (timestamps, feature).
        fs (float): The sampling frequency for bandpass filtering.
        normalized (bool): Whether to standardise the time-series.
        bandpass_filter (bool): Whether to band-pass filter the time-series first.

    Returns:
        numpy.ndarray: The processed series (``ts`` itself when both steps are off).
    '''
    series = butter_bandpass_filter(ts, 0.5, 50, fs, 5) if bandpass_filter else ts
    if not normalized:
        return series
    return StandardScaler().fit(series).transform(series)


def process_batch_ts(batch, fs=256, normalized=True, bandpass_filter=False):
    ''' Apply ``process_ts`` to every series of a batch.

    Args:
        batch (numpy.ndarray): A batch of input time-series in shape (n_samples, timestamps, feature).
        fs (float): Sampling frequency forwarded to ``process_ts``.
        normalized (bool): Forwarded to ``process_ts``.
        bandpass_filter (bool): Forwarded to ``process_ts``.

    Returns:
        numpy.ndarray: The processed series stacked back into one array.
    '''
    processed = [process_ts(series, fs, normalized, bandpass_filter) for series in batch]
    return np.array(processed)


def split_data_label(X_trial, y_trial, sample_timestamps, overlapping):
    ''' split a batch of time-series trials into samples and adding trial ids to the label array y

    Args:
        X_trial (numpy.ndarray): It should have a shape of (n_trials, trial_timestamps, features) B_trial x T_trial x C.
        y_trial (numpy.ndarray): It should have a shape of (n_trials, 2). The first column is the label and the second column is patient id.
        sample_timestamps (int): The length for sample-level data (T_sample).
        overlapping (float): Fraction of overlap between consecutive samples of a trial.

    Returns:
        X_sample (numpy.ndarray): It should have a shape of (n_samples, sample_timestamps, features) B_sample x T_sample x C. The B_sample = B_trial x sample_num.
        y_sample (numpy.ndarray): It should have a shape of (n_samples, 3). The three columns are the label, patient id, and trial id.
    '''
    X_sample, trial_ids, sample_num = split_data(X_trial, sample_timestamps, overlapping)
    # np.repeat rejects float repeat counts, and sample_num may be computed
    # with true division; make it an exact integer.
    sample_num = int(round(sample_num))
    # all samples from same trial should have same label and patient id
    y_sample = np.repeat(y_trial, repeats=sample_num, axis=0)
    label_num = y_sample.shape[0]
    if label_num != len(trial_ids):
        raise ValueError(f'{label_num} labels for {len(trial_ids)} samples; '
                         'sample_num does not match the windows produced')
    # append trial ids. Segments split from same trial should have same trial ids
    y_sample = np.hstack((y_sample.reshape((label_num, -1)), trial_ids.reshape((label_num, -1))))
    return X_sample, y_sample


def split_data(X_trial, sample_timestamps=256, overlapping=0.5):
    ''' split a batch of trials into samples and mark their trial ids

    Every trial is cut into windows of ``sample_timestamps`` timestamps whose
    start positions are ``sample_timestamps * (1 - overlapping)`` apart. For
    example, one trial of 1280 timestamps becomes 9 half-overlapping windows of
    256 timestamps. The windows must tile each trial exactly.

    Args:
        See split_data_label() function

    Returns:
        X_sample (numpy.ndarray): (n_samples, sample_timestamps, feature), ordered
            trial by trial and, within a trial, by start position.
        trial_ids (numpy.ndarray): (n_samples,) 1-based index of the source trial.
        sample_num (int): one trial splits into sample_num of samples

    Raises:
        AssertionError: If ``overlapping`` >= 1 or the windows do not tile the
            trial length exactly.
    '''
    length = X_trial.shape[1]
    step = sample_timestamps * (1 - overlapping)
    assert step > 0, 'overlapping must be smaller than 1'
    assert length >= sample_timestamps, 'trials are shorter than one sample'
    # The gap between the first and the last window start must be a whole
    # number of steps; otherwise the tail of every trial would be dropped.
    # (The previous formula was only correct for overlapping == 0.5.)
    n_steps = (length - sample_timestamps) / step
    assert abs(n_steps - round(n_steps)) < 1e-6, (
        f'trial length {length} is not compatible with sample length '
        f'{sample_timestamps} and overlapping {overlapping}')
    sample_num = int(round(n_steps)) + 1

    starts = [int(k * step) for k in range(sample_num)]
    # (n_trials, sample_num, sample_timestamps, ...) -> trial-major flat batch
    windows = np.stack([X_trial[:, s:s + sample_timestamps] for s in starts], axis=1)
    X_sample = windows.reshape((-1, sample_timestamps) + X_trial.shape[2:])
    trial_ids = np.repeat(np.arange(1, len(X_trial) + 1), sample_num)

    return X_sample, trial_ids, sample_num

## Define neural network

In this block we define the neural network used: an encoder (`TSEncoder`, imported from the local `model.encoder` module) followed by a small MLP projection head. The architecture is based on the fully convolutional network from https://arxiv.org/abs/1611.06455, but with dilation added to each convolutional layer.

In [None]:
from model.encoder import TSEncoder

In [None]:
#@title Define FCN

class FCN(nn.Module):
    ''' Encoder plus projection head used for mixup contrastive learning.

    ``forward(x)`` returns ``(out, h)``. ``h`` is the pooled ``TSEncoder``
    representation (``output_dims`` wide), and ``out`` is its projection
    (``proj_dims`` wide), on which the contrastive loss is computed.

    Args:
        input_dims (int): Number of input channels given to ``TSEncoder`` (12 by default).
        output_dims (int): Width of the encoder representation ``h``.
        proj_dims (int): Width of the projection head output ``out``.
        pool (str): Pooling mode passed to the encoder's forward call.
    '''

    def __init__(self, input_dims=12, output_dims=320, proj_dims=128, pool='max'):
        super().__init__()

        self.pool = pool
        self.encoder = TSEncoder(input_dims=input_dims, output_dims=output_dims)

        self.proj_head = nn.Sequential(
            nn.Linear(output_dims, proj_dims),
            nn.BatchNorm1d(proj_dims),
            nn.ReLU(),
            nn.Linear(proj_dims, proj_dims)
        )

    def forward(self, x):
        ''' Return ``(projection, representation)`` for a batch ``x``.

        NOTE(review): ``x`` is presumably (batch, timestamps, channels), matching
        the windows built by ``load_data`` -- confirm against ``TSEncoder``.
        '''
        h = self.encoder(x, pool=self.pool)
        out = self.proj_head(h)

        return out, h

## Define loss, training function and evaluation function.

The following three code blocks implement the mixup contrastive loss, the training function, and the evaluation function.

In [None]:
#@title define MixUp Loss

class MixUpLoss(torch.nn.Module):

    def __init__(self, device, batch_size):
        super(MixUpLoss, self).__init__()
        
        self.tau = 0.5
        self.device = device
        self.batch_size = batch_size
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, z_aug, z_1, z_2, lam):

        z_1 = nn.functional.normalize(z_1)
        z_2 = nn.functional.normalize(z_2)
        z_aug = nn.functional.normalize(z_aug)

        labels_lam_0 = lam*torch.eye(self.batch_size, device=self.device)
        labels_lam_1 = (1-lam)*torch.eye(self.batch_size, device=self.device)

        labels = torch.cat((labels_lam_0, labels_lam_1), 1)

        logits = torch.cat((torch.mm(z_aug, z_1.T),
                         torch.mm(z_aug, z_2.T)), 1)

        loss = self.cross_entropy(logits / self.tau, labels)

        return loss

    def cross_entropy(self, logits, soft_targets):
        return torch.mean(torch.sum(- soft_targets * self.logsoftmax(logits), 1))


In [None]:
#@title mixup model trainer per epoch


def train_mixup_model_epoch(model, training_set, test_set, optimizer, alpha, epochs,
                            batch_size=None):
    ''' Train ``model`` with the mixup contrastive loss, evaluating after every epoch.

    Each step draws ``lam ~ Beta(alpha, alpha)`` and mixes every sample with a
    randomly permuted partner from the same batch. It then minimises
    ``MixUpLoss`` on the three projections. After every epoch the model is scored
    with ``test_model``, and the encoder weights are saved to
    ``model_{epoch}.pth`` in the working directory.

    Args:
        model (nn.Module): Returns ``(projection, representation)`` from forward
            and has an ``encoder`` sub-module.
        training_set (Dataset): Map-style dataset yielding ``(x, y)``; ``y`` is
            not used for training.
        test_set (Dataset): Dataset passed to ``test_model`` for evaluation.
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        alpha (float): Beta-distribution parameter for the mixing coefficient.
        epochs (int): Number of passes over ``training_set``.
        batch_size (int or None): Mini-batch size. ``None`` (default) trains on
            the whole training set as one batch.

    Returns:
        LossList (list[float]): Loss of every optimisation step.
        AccList (list[float]): Evaluation accuracy after every epoch.
    '''
    # Follow the model's device so batches from CPU datasets (e.g. TensorDataset)
    # are moved to wherever the weights live.
    device = next(model.parameters()).device
    # len() works for any map-style Dataset; TensorDataset has no `.x` attribute.
    n_train = len(training_set)
    batch_size_tr = n_train if batch_size is None else min(batch_size, n_train)

    LossList, AccList = [], []
    criterion = MixUpLoss(device, batch_size_tr)

    # drop_last keeps every batch at exactly batch_size_tr samples.
    training_generator = DataLoader(training_set, batch_size=batch_size_tr,
                                    shuffle=True, drop_last=True)

    for epoch in range(epochs):
        # test_model() may leave the model in eval mode, so switch back each epoch.
        model.train()

        for x, _ in training_generator:
            x = x.to(device)

            optimizer.zero_grad()

            # Mix every sample with a randomly chosen partner from the same batch.
            x_1 = x
            x_2 = x[torch.randperm(len(x), device=x.device)]

            lam = float(np.random.beta(alpha, alpha))

            x_aug = lam * x_1 + (1 - lam) * x_2

            z_1, _ = model(x_1)
            z_2, _ = model(x_2)
            z_aug, _ = model(x_aug)

            loss = criterion(z_aug, z_1, z_2, lam)
            loss.backward()
            optimizer.step()
            LossList.append(loss.item())

        AccList.append(test_model(model, training_set, test_set))

        print(f"Epoch number: {epoch}")
        print(f"Loss: {LossList[-1]}")
        print(f"Accuracy: {AccList[-1]}")
        print("-"*50)

        torch.save(model.encoder.state_dict(), f'model_{epoch}.pth')
        # Keep the notebook output short: clear the log every 10 epochs.
        if epoch % 10 == 0 and epoch != 0:
            clear_output()

    return LossList, AccList

In [None]:
#@title model evaluation


def test_model(model, training_set, test_set, batch_size=1):
    ''' Evaluate the learned representations with a 1-nearest-neighbour classifier.

    Both datasets are encoded with ``model``, using its second output, the
    representation ``h``. The representations are L2-normalised, a 1-NN
    classifier is fitted on the training set, and its accuracy on ``test_set``
    is returned.

    Labels may be 1-D, or 2-D with the class in the first column (as in the
    ``(label, patient id, trial id)`` rows built by ``split_data_label``).
    The model's train/eval mode is restored before returning.

    Args:
        model (nn.Module): Returns ``(projection, representation)`` from forward.
        training_set (Dataset): Map-style dataset of ``(x, y)`` pairs used to fit the classifier.
        test_set (Dataset): Map-style dataset of ``(x, y)`` pairs used for scoring.
        batch_size (int): Encoding batch size (default 1, as before).

    Returns:
        float: Mean accuracy of the 1-NN classifier on ``test_set``.
    '''
    # Remember the caller's mode so evaluation does not leave the model in eval().
    was_training = model.training
    model.eval()
    try:
        H_tr, y_tr = _encode_dataset(model, training_set, batch_size)
        H_te, y_te = _encode_dataset(model, test_set, batch_size)
    finally:
        model.train(was_training)

    clf = KNeighborsClassifier(n_neighbors=1).fit(H_tr, y_tr)

    return clf.score(H_te, y_te)


def _encode_dataset(model, dataset, batch_size):
    ''' Encode ``dataset`` with ``model``; return L2-normalised features and class labels as NumPy arrays. '''
    device = next(model.parameters()).device
    # No shuffling: order is irrelevant for 1-NN and keeps evaluation deterministic.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)

    feats, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            _, h = model(x.to(device))
            # Collect whatever width the model emits instead of a hard-coded 128
            # (the FCN above produces a 320-wide representation by default).
            feats.append(h.cpu())
            # Keep only the class column when labels carry extra metadata.
            labels.append(y.reshape(len(y), -1)[:, 0].long())

    H = nn.functional.normalize(torch.cat(feats))
    return H.numpy(), torch.cat(labels).numpy()


## Block for training the model

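This block loads the data, trains the neural network using mixup contrastive learning, and plots the training loss (per optimisation step) and the 1-NN accuracy on the validation split (per epoch).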

In [None]:
#@title Experiment number of epochs

# Seed every RNG the experiment touches (mixing coefficients and permutations
# from numpy/torch, weight initialisation) so runs are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_ROOT = '/root/auto-tmp/dataset'  # directory with one sub-folder per dataset
DATASET = 'ptbxl'
WINDOW_LENGTH = 300  # timestamps per training window

device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 100

alpha = 1.0

X_train, X_val, X_test, y_train, y_val, y_test = load_data(root=DATA_ROOT, name=DATASET, length=WINDOW_LENGTH)
training_set = torch.utils.data.TensorDataset(torch.from_numpy(X_train).to(torch.float32), torch.from_numpy(y_train).to(torch.long))
# The validation split is used for the per-epoch evaluation; X_test stays held out.
test_set = torch.utils.data.TensorDataset(torch.from_numpy(X_val).to(torch.float32), torch.from_numpy(y_val).to(torch.long))

model = FCN().to(device)

optimizer = torch.optim.Adam(model.parameters())
LossListM, AccListM = train_mixup_model_epoch(model, training_set, test_set,
                                              optimizer, alpha, epochs)


print(f"Score for alpha = {alpha}: {AccListM[-1]}")


fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
ax_loss.plot(LossListM)
ax_loss.set(title='Loss', xlabel='Optimisation step', ylabel='Mixup contrastive loss')
ax_acc.plot(AccListM)
ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='1-NN accuracy (validation)')
plt.show()

