# Scanner-Generalization Neural Networks (SGNN)

In [1]:
import os
import gc
import time
import math
import random
import numpy as np
import pandas as pd
import scipy.stats as stats
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import torch.optim as optim
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [3]:
import warnings
warnings.filterwarnings('ignore')

## Assigning GPU for training

In [4]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

In [5]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(use_cuda)

True


## Controling seed 

In [6]:
def seed_everything(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
seed_everything(0)

## Loading dataset

In [9]:
data = np.load("Sample_RSFC_p_factor_scanner.npz", allow_pickle=True)
X = data["X"]
y = data["y"]
print(X.shape, y.shape)

(20, 61776) (20, 2)


In [10]:
p_factor_idx, scanner_idx = 0, 1

In [11]:
# Training dataset
class TrainDataset(Dataset): 
    def __init__(self, X_train, y_train):
        self.X_train = X_train
        self.y_train = y_train
        
    def __len__(self):
        return len(self.X_train)
    
    def __getitem__(self, idx): 
        X_train = torch.from_numpy(self.X_train[idx]).type(torch.FloatTensor)
        y_train = torch.from_numpy(self.y_train[idx]).type(torch.FloatTensor)

        return X_train, y_train

In [12]:
# Test dataset
class ValidDataset(Dataset): 
    def __init__(self, X_valid, y_valid):
        self.X_valid = X_valid
        self.y_valid = y_valid
        
    def __len__(self):
        return len(self.X_valid)
    
    def __getitem__(self, idx): 
        X_valid = torch.from_numpy(self.X_valid[idx]).type(torch.FloatTensor)
        y_valid = torch.from_numpy(self.y_valid[idx]).type(torch.FloatTensor)
        
        return X_valid, y_valid

In [13]:
# Test dataset
class TestDataset(Dataset): 
    def __init__(self, X_test, y_test):
        self.X_test = X_test
        self.y_test = y_test
        
    def __len__(self):
        return len(self.X_test)
    
    def __getitem__(self, idx): 
        X_test = torch.from_numpy(self.X_test[idx]).type(torch.FloatTensor)
        y_test = torch.from_numpy(self.y_test[idx]).type(torch.FloatTensor)
        
        return X_test, y_test

# Function for gradient reversal layer (Ganin et al., 2015)

In [14]:
class GradRevFunc(Function):

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        lambda_ = ctx.lambda_
        lambda_ = grads.new_tensor(lambda_)
        dx = lambda_ * grads.neg()
        return dx, None

In [15]:
class GradRev(torch.nn.Module):
    def __init__(self, lambda_=0.0):
        super(GradRev, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        return GradRevFunc.apply(x, self.lambda_)

# Defining SGNN

In [16]:
class SGNN(nn.Module):
    def __init__(self, fe_hidden, prd_hidden, dsc_hidden,
                 dropout_fe, dropout_prd, dropout_dsc, act_func_name, lambda_):
        super(SGNN, self).__init__()
        self.fe_1 = nn.Linear(input_dim, fe_hidden)
        self.fe_bn_1 = nn.BatchNorm1d(fe_hidden)
        
        self.prd_1 = nn.Linear(fe_hidden, prd_hidden)
        self.prd_bn_1 = nn.BatchNorm1d(prd_hidden)
        self.prd_2 = nn.Linear(prd_hidden, output_prd_dim)
        
        self.dsc_1 = nn.Linear(fe_hidden, dsc_hidden)
        self.dsc_bn_1 = nn.BatchNorm1d(dsc_hidden)
        self.dsc_2 = nn.Linear(dsc_hidden, output_dsc_dim)

        self.dropout_fe = nn.Dropout(p=dropout_fe)
        self.dropout_prd = nn.Dropout(p=dropout_prd)
        self.dropout_dsc = nn.Dropout(p=dropout_dsc)
        
        self.act_func = get_act_func(act_func_name)
        self.GradRev = GradRev(lambda_)
        self.weights_init()
    
    def forward(self, x):
        x_ftr = self.fe_1(x)
        x_ftr = self.fe_bn_1(x_ftr)
        x_ftr = self.act_func(x_ftr)
        x_ftr = self.dropout_fe(x_ftr)
        
        x_prd = self.prd_1(x_ftr)
        x_prd = self.prd_bn_1(x_prd)
        x_prd = self.act_func(x_prd)
        x_prd = self.dropout_prd(x_prd)
        x_prd = self.prd_2(x_prd)
        
        x_rev = self.GradRev(x_ftr)
        x_dsc = self.dsc_1(x_rev)
        x_dsc = self.dsc_bn_1(x_dsc)
        x_dsc = self.act_func(x_dsc)
        x_dsc = self.dropout_dsc(x_dsc)
        x_dsc = self.dsc_2(x_dsc)
        
        return x_prd, x_dsc
    
    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                nn.init.normal_(m.bias, std=0.01)

In [17]:
def get_optimizer(model, opt_name, learning_rate=None, l2_param=None):
    lower_opt_name = opt_name.lower()
    if lower_opt_name == 'momentum':
        return optim.SGD(model.parameters(), lr=learning_rate, 
                         momentum=momentum, weight_decay=l2_param)
    elif lower_opt_name == 'nag':
        return optim.SGD(model.parameters(), lr=learning_rate, 
                         momentum=momentum, weight_decay=l2_param, nesterov=True)
    elif lower_opt_name == 'adam':
        return optim.Adam(model.parameters(), lr=learning_rate, 
                          weight_decay=l2_param)
    else:
        sys.exit("Illegal arguement for optimizer type")

In [18]:
def get_act_func(act_func_name):
    act_func_name = act_func_name.lower()
    if act_func_name == 'relu':
        return nn.ReLU()
    elif act_func_name == 'prelu':
        return nn.PReLU()
    elif act_func_name == 'elu':
        return nn.ELU()
    elif act_func_name == 'silu':
        return nn.SiLU()
    elif act_func_name == 'leakyrelu':
        return nn.LeakyReLU()
    elif act_func_name == 'tanh':
        return nn.Tanh()
    else:
        sys.exit("Illegal arguement for activation function type")

# Functions for weight sparsity control with Hoyer's sparsness

In [19]:
def init_hsp(n_wsc, epochs):
    hsp_val = torch.zeros(n_wsc)
    beta_val = torch.clone(hsp_val)
    hsp_list = torch.zeros((n_wsc, epochs))
    beta_list = torch.zeros((n_wsc, epochs))
    
    return hsp_val, beta_val, hsp_list, beta_list

In [20]:
# Weight sparsity control with Hoyer's sparsness (Layer wise)
def calc_hsp(w, beta, max_beta, beta_lr, tg_hsp):
    
    # Get value of weight
    [dim, n_nodes] = w.shape
    num_elements = dim * n_nodes
    norm_ratio = torch.norm(w.detach(), 1) / torch.norm(w.detach(), 2)

    # Calculate hoyer's sparsity level
    num = math.sqrt(num_elements) - norm_ratio
    den = math.sqrt(num_elements) - 1
    hsp = torch.tensor(num / den).to(device)

    # Update beta
    beta = beta.clone() + beta_lr * torch.sign(torch.tensor(tg_hsp).to(device) - hsp)
    
    # Trim value
    beta = 0 if beta < 0 else beta
    beta = max_beta if beta > max_beta else beta

    return [hsp, beta]

In [21]:
def calc_l1(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp):
    l1_reg = None
    layer_idx = 0
    wsc_idx = 0

    for name, param in model.named_parameters():
        if "weight" in name and "bn" not in name:
            if "fe" in name or "prd_1" in name or "dsc_1" in name:
                temp_w = param
                
                if wsc_flag[layer_idx] != 0:
                    hsp_val[wsc_idx], beta_val[wsc_idx] = calc_hsp(
                        temp_w, beta_val[wsc_idx], max_beta[wsc_idx], 
                        beta_lr[wsc_idx], tg_hsp[wsc_idx]
                    )
                    hsp_list[wsc_idx, epoch - 1] = hsp_val[wsc_idx]
                    beta_list[wsc_idx, epoch - 1] = beta_val[wsc_idx]
                    layer_reg = torch.norm(temp_w, 1) * beta_val[wsc_idx].clone()
                    wsc_idx += 1
                else:
                    layer_reg = torch.norm(temp_w, 1) * l1_param

                if l1_reg is None:
                    l1_reg = layer_reg
                else:
                    l1_reg = l1_reg + layer_reg
                layer_idx += 1
        
    return l1_reg

In [22]:
def calc_pearsonr(x, y):
    x_mean = torch.mean(x)
    y_mean = torch.mean(y)
    xx = x.sub(x_mean)
    yy = y.sub(y_mean)
    num = xx.dot(yy)
    den = torch.norm(xx, 2) * torch.norm(yy, 2)
    corr = num / den
    return corr

In [23]:
def calc_mae(x, y):
    return torch.abs(x - y).mean().data

In [24]:
def train(model, epoch, train_loader, optimizer, criterion_prd, criterion_dsc, 
          hsp_val, beta_val, hsp_list, beta_list, tg_hsp, lambda_, X_train, y_train):
    model.train()
    prd_loss = 0
    dsc_loss = 0
    dsc_acc = 0
    cost = 0
    total = 0
    correct = 0
    y_train_true = []
    y_train_pred = []
    
    for batch_idx, (input, target) in enumerate(train_loader):
        optimizer.zero_grad(set_to_none=True)
        input, target = input.to(device), target.to(device)
        output_prd, output_dsc = model(input)
        target_prd = target[:, p_factor_idx].view(-1, 1)
        target_dsc = target[:, scanner_idx].long().view(-1)
        running_prd_loss = criterion_prd(output_prd, target_prd)
        running_dsc_loss = criterion_dsc(output_dsc, target_dsc)
        l1_norm = calc_l1(model, epoch, hsp_val, beta_val, hsp_list, beta_list, tg_hsp)
        running_loss = running_prd_loss + running_dsc_loss + l1_norm.clone()

        cost = running_loss
        cost.backward()
        optimizer.step()
        
        prd_loss += running_prd_loss.detach()
        total += output_prd.size(0)
        true_batch = torch.flatten(target_prd.detach())
        pred_batch = torch.flatten(output_prd.detach())
        y_train_true.append(true_batch)
        y_train_pred.append(pred_batch)
        
    X_train = torch.from_numpy(X_train).type(torch.FloatTensor).to(device)
    _, output_dsc = model(X_train)
    _, scnr_pred = torch.max(output_dsc.data, 1)
    scnr_pred = scnr_pred.detach().cpu().numpy().ravel()
    scnr_true = y_train[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)
    
    prd_loss /= total
    dsc_loss /= total
    y_train_true = torch.flatten(torch.stack(y_train_true))
    y_train_pred = torch.flatten(torch.stack(y_train_pred))
    train_corr = calc_pearsonr(y_train_true, y_train_pred)
    train_mae = calc_mae(y_train_true, y_train_pred)
    torch.cuda.empty_cache()

    return prd_loss, dsc_loss, dsc_acc, train_corr, train_mae

In [25]:
def valid(model, epoch, valid_loader, criterion_prd, criterion_dsc, X_valid, y_valid):
    model.eval()
    prd_loss = 0
    dsc_loss = 0
    dsc_acc = 0
    correct = 0
    total = 0
    y_valid_true = []
    y_valid_pred = []
    
    with torch.no_grad():
        for input, target in valid_loader:
            input, target = input.to(device), target.to(device)
            output_prd, output_dsc = model(input)
            target_prd = target[:, p_factor_idx].view(-1, 1)
            target_dsc = target[:, scanner_idx].long().view(-1)
            running_dsc_loss = criterion_dsc(output_dsc, target_dsc)
            running_prd_loss = criterion_prd(output_prd, target_prd)
            dsc_loss += running_dsc_loss.detach()
            prd_loss += running_prd_loss.detach()
            total += output_prd.size(0)
            _, pred = torch.max(output_dsc.data, 1)
            correct += (pred.view(-1, 1) == target).sum().detach()
            true_batch = torch.flatten(target_prd.detach())
            pred_batch = torch.flatten(output_prd.detach())
            y_valid_true.append(true_batch)
            y_valid_pred.append(pred_batch)

    X_valid = torch.from_numpy(X_valid).type(torch.FloatTensor).to(device)
    _, output_dsc = model(X_valid)
    _, scnr_pred = torch.max(output_dsc.data, 1)
    scnr_pred = scnr_pred.detach().cpu().numpy().ravel()
    scnr_true = y_valid[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)

    y_valid_true = torch.flatten(torch.stack(y_valid_true))
    y_valid_pred = torch.flatten(torch.stack(y_valid_pred))
    valid_corr = calc_pearsonr(y_valid_true, y_valid_pred)
    valid_mae = calc_mae(y_valid_true, y_valid_pred)
    torch.cuda.empty_cache()

    return prd_loss, dsc_loss, dsc_acc, valid_corr, valid_mae

In [26]:
def test(model, epoch, test_loader, criterion_prd, criterion_dsc, X_test, y_test):
    model.eval()
    prd_loss = 0
    total = 0
    y_test_true = []
    y_test_pred = []
    
    with torch.no_grad():
        for input, target in test_loader:
            input, target = input.to(device), target.to(device)
            output_prd, output_clf = model(input)
            target_prd = target[:, p_factor_idx].view(-1, 1)
            running_prd_loss = criterion_prd(output_prd, target_prd)
            prd_loss += running_prd_loss.detach()
            total += output_prd.size(0)
            true_batch = torch.flatten(target_prd.detach())
            pred_batch = torch.flatten(output_prd.detach())
            y_test_true.append(true_batch)
            y_test_pred.append(pred_batch)

    X_test = torch.from_numpy(X_test).type(torch.FloatTensor).to(device)
    _, output_dsc = model(X_test)
    _, scnr_pred = torch.max(output_dsc.data, 1)
    scnr_pred = scnr_pred.detach().cpu().numpy().ravel()
    scnr_true = y_test[:, scanner_idx].ravel()
    dsc_acc = balanced_accuracy_score(scnr_true, scnr_pred)

    y_test_true = torch.flatten(torch.stack(y_test_true))
    y_test_pred = torch.flatten(torch.stack(y_test_pred))
    test_corr = calc_pearsonr(y_test_true, y_test_pred)
    test_mae = calc_mae(y_test_true, y_test_pred)
    torch.cuda.empty_cache()

    return prd_loss, test_corr, test_mae, dsc_acc

In [27]:
act_func_name = "elu"
optimizer_name = "nag"

fe_hidden = 100
pp_hidden = 100
scd_hidden = 100

dropout_fe = 0.5
dropout_pp = 0.5
dropout_scd = 0.5

batch_size = 4
learning_rate = 1e-04
epochs = 300

l1_param = 0
l2_param = 5e-02
lmd = 0.1

In [28]:
momentum = 0.90
input_dim = 61776
n_classes = 2
output_prd_dim = 1
output_dsc_dim = n_classes

wsc_flag = [1, 1, 1]
tg_hsp_list = [[0.7], [0.3], [0.3]]
beta_lr = [1e-04, 1e-03, 1e-03]
max_beta = [1e-02, 5e-02, 5e-02]
n_wsc = wsc_flag.count(1)

In [29]:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y[:, scanner_idx], random_state=1000
)

train_dataset = TrainDataset(X_train, y_train)
test_dataset = TestDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True, num_workers=4, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=len(y_test), pin_memory=True, shuffle=True, num_workers=4, drop_last=True)

In [30]:
model = SGNN(
    fe_hidden, scd_hidden, pp_hidden, dropout_fe, dropout_pp, dropout_scd, act_func_name, lmd
).to(device)

In [31]:
optimizer = get_optimizer(model, optimizer_name, learning_rate, l2_param)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,  step_size=10, gamma=0.5)
criterion_prd = nn.MSELoss()
criterion_dsc = nn.CrossEntropyLoss()

In [32]:
hsp_val, beta_val, hsp_list, beta_list = init_hsp(n_wsc, epochs)

for epoch in range(1, epochs + 1):
    
    train_prd_loss, train_dsc_loss, train_acc, train_corr, train_mae = train(
        model, epoch, train_loader, optimizer, criterion_prd, criterion_dsc,
        hsp_val, beta_val, hsp_list, beta_list, tg_hsp_list, lmd,
        X_train, y_train
    )

    test_prd_loss, test_corr, test_mae, test_acc = test(
        model, epoch, test_loader, criterion_prd, criterion_dsc, X_test, y_test
    )

    if epoch % 300 == 0:
        print("\nEpoch [{:d}/{:d}]".format(epoch, epochs), end=" ")
        print("Train corr: {:.4f}, Test corr: {:.4f}, Train loss: {:.4f}, Test loss: {:.4f}"
              .format(train_corr, test_corr, train_prd_loss, test_prd_loss))
        for i in range(len(wsc_flag)):
            if wsc_flag[i] != 0:
                print("Layer {:d}: [{:.4f}/{:.4f}]".
                      format( i + 1, hsp_val[i], tg_hsp_list[i][0]), end=" ")
        print("Scanner classification accuracy : {:.2f}%".format(train_acc * 100))


Epoch [300/300] Train corr: 0.8956, Test corr: 0.6678, Train loss: 0.0529, Test loss: 1.0377
Layer 1: [0.7067/0.7000] Layer 2: [0.3037/0.3000] Layer 3: [0.3005/0.3000] Scanner classification accuracy : 43.65%
