### import packages

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import random_split, ConcatDataset, DataLoader, TensorDataset
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix
from torch import Tensor

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print(f"device is {DEVICE}")

### models

In [None]:
class CVAE(nn.Module):
    """Conditional variational autoencoder made of fully connected layers.

    The encoder receives the sample concatenated with its condition ``y`` and
    produces the mean and log-variance of a diagonal Gaussian over the latent
    code. The decoder receives a latent code concatenated with ``y`` and
    produces a reconstruction with ``x_size`` features.

    Args:
        x_size: Number of features per sample.
        y_size: Number of features of the condition vector.
        latent_size: Dimension of the latent code.
    """

    def __init__(self, x_size, y_size, latent_size):
        super().__init__()
        self.__latent_size = latent_size
        # Encoder: one shared hidden layer feeding two heads.
        self.fc1 = nn.Linear(x_size + y_size, 128)
        self.fc2 = nn.Linear(128, latent_size)  # mean head
        self.fc3 = nn.Linear(128, latent_size)  # log-variance head
        # Decoder: latent code + condition -> reconstruction.
        self.fc4 = nn.Linear(latent_size + y_size, 256)
        self.fc5 = nn.Linear(256, x_size)

    def encoder(self, x):
        """Map an already concatenated ``[x, y]`` batch to ``(mu, log_var)``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden), self.fc3(hidden)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` so gradients reach ``mu`` and ``log_var``."""
        std = (0.5 * log_var).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def sample(self, y):
        """Decode standard-normal latent codes, one per row of ``y``.

        Returns:
            Tuple ``(x_hat, z)`` of the generated samples and the latent codes
            they were decoded from.
        """
        latent = torch.randn(y.size(0), self.__latent_size, device=y.device)
        generated = self.decoder(torch.cat((latent, y), dim=1))
        return generated, latent

    def decoder(self, z):
        """Map an already concatenated ``[z, y]`` batch to a reconstruction."""
        return self.fc5(F.relu(self.fc4(z)))

    def forward(self, x, y):
        """Encode ``x`` given ``y``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_hat, mu, log_var)``.
        """
        mu, log_var = self.encoder(torch.cat((x, y), dim=1))
        latent = self.reparameterize(mu, log_var)
        reconstruction = self.decoder(torch.cat((latent, y), dim=1))
        return reconstruction, mu, log_var

    def train_loss(self, x, y):
        """Return mean-squared reconstruction error plus the mean KL term."""
        x_hat, mu, log_var = self.forward(x, y)
        reconstruction_error = F.mse_loss(x_hat, x)
        # KL divergence to a standard normal, averaged over batch and latent
        # dimensions rather than summed.
        kl_term = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean()
        return reconstruction_error + kl_term
    

class MultiView(nn.Module):
    """Per-view feature scorer followed by a two-class classifier.

    ``multiview_mask`` has shape ``(view_count, input_size)``; row ``v`` marks
    the features that belong to view ``v``. It is expected to be a 0/1
    membership matrix (values are not validated). Each view gets its own small
    MLP that turns the masked, batch-normalised input into a score in
    ``(0, 1)`` per feature; the per-view scores are masked again and summed
    over views. The classifier is applied to the input multiplied by these
    scores.

    Args:
        multiview_mask: Tensor (or array-like) of shape
            ``(view_count, input_size)``.
    """

    # Weight of the mean-score penalty added to the classification loss.
    MASK_PENALTY = 2.2e-3

    def __init__(self, multiview_mask):
        super().__init__()
        # Non-persistent buffer: the mask follows the module through
        # ``.to()`` / ``.cuda()`` while ``state_dict()`` keeps the same keys
        # as when the mask was a plain attribute.
        self.register_buffer("multiview_mask", torch.as_tensor(multiview_mask), persistent=False)
        view_count, input_size = self.multiview_mask.shape
        self.bn = nn.BatchNorm1d(input_size)
        # Performance optimization: parallel compute
        # https://stackoverflow.com/a/58389075/318557
        # Grouped 1x1 convolutions evaluate one independent MLP per view in a
        # single call (``groups=view_count`` keeps the views separate).
        self.mlp_parallel_cnn = nn.Sequential(
            nn.Conv1d(input_size * view_count, 64 * view_count, kernel_size=1, stride=1, padding=0, groups=view_count),
            nn.ReLU(),
            nn.Conv1d(64 * view_count, 128 * view_count, kernel_size=1, stride=1, padding=0, groups=view_count),
            nn.ReLU(),
            nn.Conv1d(128 * view_count, input_size * view_count, kernel_size=1, stride=1, padding=0, groups=view_count),
            nn.Sigmoid()
        )
        self.mlp_classifier = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Classify ``x`` after re-weighting its features by their scores.

        Args:
            x: Batch of shape ``(batch, input_size)``.

        Returns:
            Tuple ``(y_hat, mask)``: class probabilities of shape
            ``(batch, 2)`` and the feature scores of shape
            ``(batch, input_size)``.
        """
        mask = self.get_score(x)
        y_hat = self.mlp_classifier(mask * x)
        return y_hat, mask

    def get_score(self, x):
        """Return the per-sample feature scores, shape ``(batch, input_size)``."""
        view_count = self.multiview_mask.shape[0]
        x = self.bn(x)
        # Broadcasting (batch, 1, features) against (views, features) gives
        # one masked copy of the input per view.
        x = self.multiview_mask * x.unsqueeze(1)
        # Conv1d expects (batch, channels, length): all views are laid out
        # along the channel axis with a length of 1.
        x = torch.flatten(x, start_dim=1).unsqueeze(2)
        x = self.mlp_parallel_cnn(x)
        x = x.view(x.shape[0], view_count, -1)
        # Keep only the scores each view produced for its own features, then
        # merge the views.
        x = self.multiview_mask * x
        return x.sum(dim=1)

    def train_loss(self, x, y):
        """Return the classification loss plus the mean-score penalty.

        The classifier already ends in a softmax, so the negative
        log-likelihood is taken on the log of its probabilities.
        ``F.cross_entropy`` expects unnormalised logits and would apply a
        second softmax to them.

        Args:
            x: Batch of shape ``(batch, input_size)``.
            y: Class indices (0 or 1), any shape that flattens to ``(batch,)``.
        """
        y_hat, mask = self.forward(x)
        # Clamp so that a probability that underflowed to 0 does not give -inf.
        log_prob = torch.log(y_hat.clamp_min(1e-12))
        return F.nll_loss(log_prob, y.view(-1).long()) + self.MASK_PENALTY * mask.mean()

### load data

In [None]:
from datas import load_data
LOAD_PATH = "./data/"
# samples size: (47, 392)
# labels size: (47,)
samples, labels = load_data(LOAD_PATH)
# Standardise every feature (column) to zero mean and unit variance.
std, mean = torch.std_mean(samples, dim=0)
# A constant feature has zero standard deviation; dividing by it would turn
# the whole column into NaN. Dividing by 1 instead leaves it at 0.
safe_std = torch.where(std > 0, std, torch.ones_like(std))
samples = ((samples - mean) / safe_std).to(DEVICE)
labels = labels.to(DEVICE)

### utils

In [None]:
def split_dataset(samples: Tensor, labels: Tensor, split_ratio=(0.6, 0.4), batch_size=None):
    """Split the data into stratified train and test loaders.

    Positive (label 1) and negative (label 0) samples are split separately
    with ``random_split`` and then concatenated, positives first, so both
    splits keep the class proportions. The loaders do not shuffle.

    Args:
        samples: Tensor whose first dimension indexes the samples.
        labels: Tensor of 0/1 labels, one per sample (any shape that flattens
            to ``(n_samples,)``).
        split_ratio: ``(train, test)`` fractions passed to ``random_split``.
        batch_size: Batch size of both loaders. ``None`` (default) uses the
            whole split as a single batch, so ``next(iter(loader))`` returns
            the entire split.

    Returns:
        Tuple ``(train_loader, test_loader)`` yielding ``(x, y)`` batches
        where ``y`` has shape ``(batch, 1)``.
    """
    flat_labels = labels.view(-1)
    # Build the label columns on the same device as the samples so that a
    # batch never mixes devices.
    device = samples.device
    data_p = samples[flat_labels == 1]
    data_n = samples[flat_labels == 0]
    train_p, test_p = random_split(
        TensorDataset(data_p, torch.ones(len(data_p), 1, device=device)), split_ratio
    )
    train_n, test_n = random_split(
        TensorDataset(data_n, torch.zeros(len(data_n), 1, device=device)), split_ratio
    )

    def _loader(dataset):
        # DataLoader rejects batch_size=0, hence the lower bound of 1.
        size = max(len(dataset), 1) if batch_size is None else batch_size
        return DataLoader(dataset, batch_size=size, shuffle=False)

    train_loader = _loader(ConcatDataset([train_p, train_n]))
    test_loader = _loader(ConcatDataset([test_p, test_n]))
    return train_loader, test_loader


def print_classification_result(cm: torch.Tensor) -> None:
    """Print mean and 95% confidence half-width of the usual binary metrics.

    Args:
        cm: Stack of 2x2 confusion matrices, shape ``(runs, 2, 2)``, indexed
            as ``cm[run, true_class, predicted_class]`` (the layout produced
            by ``sklearn.metrics.confusion_matrix``).

    Each metric is computed per run. The printed interval is
    ``1.96 * std / sqrt(runs)`` around the mean over runs.
    """
    tn, fp = cm[:, 0, 0], cm[:, 0, 1]
    fn, tp = cm[:, 1, 0], cm[:, 1, 1]
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    per_run = (
        ("acc", (tn + tp) / cm.sum(dim=(1, 2))),
        ("precision", precision),
        ("recall", recall),
        ("fpr", fp / (fp + tn)),
        ("f1", 2 * precision * recall / (precision + recall)),
    )
    # Compute every statistic before printing anything.
    summary = [(label, *torch.std_mean(values)) for label, values in per_run]
    sqrt_n = torch.sqrt(torch.tensor(len(cm), dtype=float))
    for label, spread, centre in summary:
        print(f"{label}: {centre * 100:.2f}%±{1.96 * spread * 100 / sqrt_n:.2f}%.")


def data_augmentation(train_loader, input_size, da_count=256, epochs=500, learning_rate=4.7e-3, latent_size=16):
    """Train a fresh CVAE on ``train_loader`` and sample synthetic data from it.

    Args:
        train_loader: Iterable of ``(x, y)`` batches, where ``y`` is a
            single-column label tensor (the CVAE is built with a condition
            size of 1).
        input_size: Number of features per sample.
        da_count: Number of synthetic samples to generate.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        latent_size: Dimension of the CVAE latent code.

    Returns:
        Tuple ``(x_, y_)``: generated samples of shape
        ``(da_count, input_size)`` and their labels of shape
        ``(da_count, 1)``, drawn as 0 or 1 with equal probability. Both are
        on ``DEVICE``.
    """
    cvae = CVAE(input_size, 1, latent_size).to(DEVICE)
    optimizer = torch.optim.Adam(cvae.parameters(), lr=learning_rate)
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = cvae.train_loss(x, y)
            loss.backward()
            optimizer.step()
    with torch.no_grad():
        # Draw the condition labels first, then decode random latent codes
        # under those labels.
        y_ = torch.bernoulli(torch.ones(da_count, 1, device=DEVICE) * 0.5)
        x_, _ = cvae.sample(y_)
    return x_, y_


def feature_cluster(x, n_clusters=8):
    """Group the features (columns) of ``x`` with k-means and one-hot encode the groups.

    Args:
        x: Tensor of shape ``(n_samples, n_features)``.
        n_clusters: Number of k-means clusters, i.e. rows of the result.

    Returns:
        CPU ``int64`` tensor of shape ``(n_clusters, n_features)`` in which
        entry ``[c, f]`` is 1 when feature ``f`` was assigned to cluster
        ``c`` and 0 otherwise.
    """
    # Transpose so that every feature becomes one point to cluster.
    features = x.detach().cpu().numpy().T
    assignments = KMeans(n_clusters=n_clusters).fit_predict(features)
    # Fix the number of classes explicitly: by default one_hot infers it from
    # the largest label present, which would give fewer rows if the
    # highest-numbered cluster received no feature.
    k_mean_mask = F.one_hot(torch.as_tensor(assignments, dtype=torch.long), num_classes=n_clusters)
    return k_mean_mask.T


def check_features(samples, labels, selector, use_da=False, repeat=10):
    """Repeatedly train and test a small MLP on the selected features.

    For every repetition the data is split afresh, a new classifier is
    trained for 16 epochs and then evaluated on the test split. The metrics
    summarised over all repetitions are printed.

    Args:
        samples: Tensor of shape ``(n_samples, n_features)``.
        labels: 0/1 label per sample.
        selector: Index or boolean mask applied to the feature axis, or
            ``None`` to use every feature.
        use_da: When true, train on 256 CVAE-generated samples instead of
            the real training split.
        repeat: Number of split/train/test repetitions.

    Returns:
        Tuple ``(cm, y_softmax)``: the stacked 2x2 confusion matrices, shape
        ``(repeat, 2, 2)``, and the stacked test-set class probabilities.
    """
    cm = []
    y_softmax = []
    for _ in range(repeat):
        train_loader, test_loader = split_dataset(samples, labels)
        if use_da:
            train_samples, train_labels = data_augmentation(train_loader, samples.shape[1], da_count=256)
        else:
            train_samples, train_labels = next(iter(train_loader))
        if selector is not None:
            train_samples = train_samples[:, selector]
        train_loader = DataLoader(TensorDataset(train_samples, train_labels), batch_size=32, shuffle=True)
        classifier = nn.Sequential(
            nn.Linear(train_samples.shape[1], 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2),
            nn.Softmax(dim=1)
        ).to(device=DEVICE)
        optimizer = Adam(classifier.parameters(), lr=1.0e-3, weight_decay=1.0e-5)
        for _ in range(16):
            for x, y in train_loader:
                y_hat = classifier(x)
                # The classifier already outputs probabilities, so use the
                # negative log-likelihood of their log; F.cross_entropy
                # expects logits and would apply a second softmax. The clamp
                # guards against log(0).
                loss = F.nll_loss(torch.log(y_hat.clamp_min(1e-12)), y.view(-1).long())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        with torch.no_grad():
            x, y = next(iter(test_loader))
            if selector is not None:
                x = x[:, selector]
            y = y.view(-1).long()
            y_hat = classifier(x)
            y_softmax.append(y_hat)
            # labels=[0, 1] guarantees a 2x2 matrix even when one class is
            # absent from both the targets and the predictions.
            cm.append(torch.tensor(confusion_matrix(
                y.cpu().numpy(), y_hat.argmax(dim=1).cpu().numpy(), labels=[0, 1]
            )))
    cm = torch.stack(cm)
    y_softmax = torch.stack(y_softmax)
    print_classification_result(cm)
    print()
    return cm, y_softmax


def chekc_stablility(marker: torch.Tensor):
    """Print the mean pairwise Hamming distance between the rows of ``marker``.

    Args:
        marker: Tensor whose rows are compared element-wise; every unordered
            pair of rows contributes the number of positions where they
            differ.

    The printed interval is ``1.96 * std / sqrt(number of rows)``.
    """
    count = len(marker)
    distances = torch.tensor(
        [
            (marker[a] != marker[b]).sum().item()
            for a in range(count)
            for b in range(a + 1, count)
        ],
        dtype=torch.float,
    )
    spread, centre = torch.std_mean(distances)
    print(f"dh: {centre:.2f}±{1.96 * spread / np.sqrt(len(marker)):.2f}")


### training

In [None]:
# Feature-selection experiment: train `repeat` independent MultiView models on
# CVAE-generated data, collect their mean feature scores, then evaluate the
# features whose average score exceeds `threshold`.
repeat = 300
# Views are the k-means clusters of the features; computed once and shared by
# every repetition.
multiview_mask = feature_cluster(samples).to(DEVICE)
scores = []
for _ in range(repeat):
    # Fresh random split per repetition; only the training part is used here.
    train_loader, test_loader = split_dataset(samples, labels)
    # The model is trained on the generated samples only.
    train_samples, train_labels = data_augmentation(train_loader, samples.shape[1], da_count=256)
    train_loader = DataLoader(TensorDataset(train_samples, train_labels), batch_size=32, shuffle=True)
    mvfs = MultiView(multiview_mask).to(DEVICE)
    # Per-module weight decay: the classifier is regularised more strongly
    # than the scoring network and the batch norm.
    params = [{
        "params": mvfs.mlp_parallel_cnn.parameters(),
        "weight_decay": 1.0e-6
    }, {
        "params": mvfs.mlp_classifier.parameters(),
        "weight_decay": 1.0e-4
    }, {
        "params": mvfs.bn.parameters(),
        "weight_decay": 1.0e-6
    }]
    optimizer = torch.optim.Adam(params, lr=1.0e-3)
    mvfs.train(True)
    epochs = 500
    for epoch in range(epochs):
        for x, y in train_loader:
            loss = mvfs.train_loss(x, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    with torch.no_grad():
        # Evaluation mode so batch norm uses its running statistics.
        mvfs.train(False)
        # One score per feature: the average over all generated samples.
        score = mvfs.get_score(train_samples).mean(dim=0)
        scores.append(score)
threshold = 0.6
# Shape (repeat, n_features).
scores = torch.stack(scores)
# How much the thresholded selections differ between repetitions.
chekc_stablility(scores > threshold)
score_std, score_mean = torch.std_mean(scores, dim=0)
# Final selection: features whose score averaged over repetitions exceeds the
# threshold.
pick = score_mean > threshold
check_features(samples, labels, pick, use_da=True, repeat=repeat)
torch.save(scores, "save.pth")