In [None]:
import sys
import os
import pathlib
import math
import re
import copy
import random
import pickle
import numpy as np
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (20.0, 30.0)
#plt.ioff()

# Make the project package `dl` importable BEFORE importing from it.
# Each known checkout location that exists on this machine is appended to
# sys.path and becomes `root` (if both exist the second one wins).
# NOTE(review): `root` stays undefined when neither directory exists and the
# data-loading cell will then fail with a NameError.
if os.path.exists('C:/Users/Tianle/Documents/cs231n/spring1617/my-scripts'):
    sys.path.append('C:/Users/Tianle/Documents/cs231n/spring1617/my-scripts')
    root = 'C:/Users/Tianle/Documents/cs231n/spring1617/my-scripts/'
if os.path.exists('/projects/academic/azhang/tianlema/deeplearning'):
    sys.path.append('/projects/academic/azhang/tianlema/deeplearning')
    root = '/projects/academic/azhang/tianlema/deeplearning/'

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
    import torch.optim
    
    from dl.utils.solver import Solver
    from dl.utils.utils import dist
    from dl.models.dense_factor_conv import *
except ImportError as err:
    # Still best-effort (the notebook keeps running), but no longer silent:
    # a failure here leaves names such as `nn` or `FactorBlock` undefined and
    # would otherwise only surface later as a confusing NameError.
    sys.stderr.write('WARNING: import failed, later cells may break: {0}\n'.format(err))

from dl.utils.gen_conv_params import reduce_projections, get_itemset
from dl.utils.sampler import *

In [None]:
class Unsqueeze(nn.Module):
    """Layer that inserts a size-1 axis at position `dim` (default: 1)."""

    def __init__(self, dim=1):
        nn.Module.__init__(self)
        self.dim = dim

    def forward(self, x):
        axis = self.dim
        return torch.unsqueeze(x, axis)

    
class Squeeze(nn.Module):
    """Layer that drops axis `dim` (default: the last one) when it has size 1."""

    def __init__(self, dim=-1):
        nn.Module.__init__(self)
        self.dim = dim

    def forward(self, x):
        axis = self.dim
        return torch.squeeze(x, axis)

    
# Based on nn.Linear
class Linear(nn.Module):
    """Factory with the same call interface as (Dense)FactorBlock that builds a
    plain stack of `num_layers` fully connected layers.

    Calling an instance does not transform data: it returns a freshly built
    nn.Sequential model.
    """

    def __init__(self, num_layers=1):
        super(Linear, self).__init__()
        self.num_layers = num_layers

    def forward(self, projections, out_features, in_channels, out_channels, kernel_size, 
                stride, threshold, force_square, bias, nonlinearity, use_batchnorm):
        """Build the comparison model.

        Only `projections[0]` (for the input width), the last `num_layers`
        entries of `out_features`, `out_channels`, `bias`, `nonlinearity` and
        `use_batchnorm` are used; the other arguments exist for interface
        compatibility.
        """
        assert self.num_layers <= len(out_features)
        assert out_channels == 1, ('This module is only used for compare (Dense)FactorBlock '
                                   'model with out_channesl=1. Implement the same interface')
        first_projection = projections[0]
        fan_in = len(get_itemset(first_projection.keys(), first_projection))
        # Only the trailing `num_layers` widths are realised as layers.
        widths = out_features[len(out_features) - self.num_layers:]
        final_layer = self.num_layers - 1
        model = nn.Sequential()
        for layer, width in enumerate(widths):
            model.add_module('linear{0}'.format(layer), nn.Linear(fan_in, width, bias=bias))
            if use_batchnorm:
                # (Dense)FactorBlock uses nn.BatchNorm2d; the 1d variant is used here
                model.add_module('batchnorm{0}'.format(layer), nn.BatchNorm1d(width))
            # No activation after the last layer (differs from (Dense)FactorBlock)
            if layer < final_layer:
                model.add_module('activation{0}'.format(layer), nonlinearity)
            fan_in = width
        # Match the (Dense)FactorBlock output layout (N, 1, d)
        model.add_module('unsqueeze', Unsqueeze())
        return model

    
# Baseline model factories used for comparison: calling one of them with the
# (Dense)FactorBlock arguments returns a 1- or 2-layer fully connected model.
Linear1 = Linear(num_layers=1)
Linear2 = Linear(num_layers=2)


# Based on FactorConv
# Does not use domain knowledge (every output unit is connected to every input), which significantly increases the number of parameters
class Conv(nn.Module):
    """Factory with the (Dense)FactorBlock call interface that builds a
    FactorBlock whose projections are fully connected.

    Calling an instance returns a new FactorBlock model; it does not
    transform data.
    """

    def __init__(self, num_layers=1):
        super(Conv, self).__init__()
        self.num_layers = num_layers

    def forward(self, projections, out_features, in_channels, out_channels, kernel_size, 
                stride, threshold, force_square, bias, nonlinearity, use_batchnorm):
        """Replace the given projections by dense ones and build a FactorBlock.

        Only the last `num_layers` entries of `out_features` are used as layer
        widths; the input width is taken from `projections[0]`.
        """
        assert self.num_layers <= len(out_features)
        assert out_channels == 1, ('This module is only used for compare (Dense)FactorBlock '
                                   'model with out_channesl=1. Implement the same interface')
        fan_in = len(get_itemset(projections[0].keys(), projections[0]))
        out_feats = out_features[(len(out_features) - self.num_layers):]
        # "Destroy" the projections as if no factor were known: each output unit
        # of a layer is mapped to every input index of that layer.
        dense_projections = []
        for width in out_feats:
            dense_projections.append({unit: list(range(fan_in)) for unit in range(width)})
            fan_in = width
        return FactorBlock(dense_projections, out_feats, in_channels, out_channels, kernel_size, 
                           stride, threshold, force_square, bias, nonlinearity, use_batchnorm)
    
    
# Fully connected FactorBlock factories with 1 and 2 layers (comparison models).
Conv1 = Conv(num_layers=1)
Conv2 = Conv(num_layers=2)

    
class LinearEmbedding(nn.Module):
    """Per-position linear embedding implemented as a kernel-size-1 Conv1d.

    Args:
        out_channels: int, number of embedding channels produced
        in_channels: int, number of input channels (default 1)

    When in_channels == 1 the input may be given "lazily" as (N,) or (N, d);
    the missing channel axis is inserted so the convolution sees (N, 1, d).
    """
    def __init__(self, out_channels, in_channels=1):
        super(LinearEmbedding, self).__init__()
        self.in_channels = in_channels
        self.embed = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x):
        # "lazy" input x of size (N, d)
        if self.in_channels == 1:
            if x.dim() == 1:
                # (N,) -> (N, 1); the next step turns it into (N, 1, 1)
                x = x.unsqueeze(1)
            if x.dim() == 2:
                x = x.unsqueeze(1)
        assert x.size(1) == self.in_channels
        # Call the module itself rather than .forward() so that hooks
        # registered on the convolution are honoured.
        return self.embed(x)

    
class Embedding(nn.Module):
    """ Used for categorical features
    Args:
        num_features: int, should be num_levels of feature
        hidden_dim: int

    forward(x) takes a 1-D LongTensor of level indices in
    [0, num_features - 1] and returns the matching rows of the
    (num_features, hidden_dim) embedding matrix.
    """
    def __init__(self, num_features, hidden_dim):
        super(Embedding, self).__init__()
        self.num_features = num_features
        assert num_features >= 2, '{0} < 2'.format(num_features)
        # Assigning an nn.Parameter attribute registers it with the module,
        # so no explicit register_parameter call is needed.
        self.W_embed = nn.Parameter(torch.Tensor(num_features, hidden_dim))
        # Use the in-place initialiser where available; fall back to the
        # legacy name on old PyTorch releases.
        init_normal = getattr(nn.init, 'normal_', None) or nn.init.normal
        init_normal(self.W_embed)

    def forward(self, x):
        # Valid indices are 0 .. num_features - 1, so the upper bound is strict
        # (the previous `<=` let an out-of-range index through to the lookup).
        assert (x.dim() == 1 and isinstance(x.data, torch.LongTensor) 
                and int(x.data.min()) >= 0 and int(x.data.max()) < self.num_features)
        return self.W_embed[x]

    
class LossAvg(nn.Module):
    """Learned weighted average of several losses.

    Args:
        num: how many losses are combined

    The weights are exp(w) normalised to sum to one, where `w` is a learnable
    vector of length `num`.
    """

    def __init__(self, num):
        super(LossAvg, self).__init__()
        self.num = num
        self.w = nn.Parameter(torch.Tensor(num))
        nn.init.normal(self.w)
        self.register_parameter('w', self.w)

    def forward(self, x):
        """Return the weighted average of the 1-D tensor `x` of `num` losses."""
        assert x.dim() == 1 and x.size(0) == self.num
        unnormalised = self.w.exp()
        weighted = x * unnormalised
        return (weighted / unnormalised.sum()).sum()
    
    
class Flatten(nn.Module):
    """Removes axis 1 when it has size 1 (despite the name, nothing else is flattened)."""

    def __init__(self):
        nn.Module.__init__(self)

    def forward(self, x):
        return torch.squeeze(x, 1)

    
class AverageMeter(object):
    """Keeps the latest value together with a running sum, count and mean."""

    def __init__(self):
        self._reset()

    def _reset(self):
        """Set all statistics back to zero."""
        self.val = self.sum = self.cnt = self.avg = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.cnt += n
        self.avg = self.sum / self.cnt
        
def cal_acc(output, target, topk=(1,)):
    """Top-k accuracy in percent.

    Args:
        output: N x C score tensor
        target: N class labels (cast to long)
        topk: iterable of k values
    Returns:
        list with one single-value tensor per k in `topk`
    """
    labels = target.contiguous().long()
    batch = labels.size(0)
    _, top_pred = output.topk(max(topk), 1)
    # hits[n, j] is set when the (j+1)-th ranked class of sample n is the label
    hits = top_pred.eq(labels.view(-1, 1).expand(top_pred.size()))
    res = []
    for k in topk:
        acc = hits[:, :k].float().view(-1).sum(0)
        acc.mul_(100 / batch)
        res.append(acc)
    return res

def binsort(n, bins, start=0):
    """Return the index of the bin that `n` falls into.

    Args:
        n: value to place
        bins: ascending sequence of bin edges (list, tuple or array)
        start: offset added to the result
    Returns:
        `start` plus the number of edges strictly smaller than `n`, so a value
        equal to an edge belongs to the bin below that edge.
    Raises:
        AssertionError: if `bins` is not sorted in ascending order
    """
    import bisect
    edges = list(bins)
    # Validate once instead of re-sorting a slice at every recursion level.
    assert sorted(edges) == edges
    return start + bisect.bisect_left(edges, n)

In [None]:
# Experiment configuration. Most values are copied into lower-case working
# variables or passed to the model constructors in later cells.
EXP_TYPE='rnaseq'  # data set suffix: selects data['exp_'+EXP_TYPE]; 'rnaseq' and 'ma' are used below
INPUT_SUFFIX='3rand'  # split scheme: '<n>rand', '<n>rand<m>', '<n>balanced' or '3natural'
ALPHA=0.5  # default mixing weight in eval_one_iter: alpha*avg + (1-alpha)*conv
TRAIN=1  # index of the partition whose loss is back-propagated
NORM_TYPE=0  # normalisation mode for prep_input (1, 2 or 3 normalise; other values leave x as is)
USE_ALL_DATA=True #if True, unbalanced; otherwise, balanced
BATCH_SIZE=50
NUM_ITER=100  # number of training iterations run by this notebook execution
PRINT_EVERY=1  # print progress every this many iterations
MODEL=Linear1  # model factory for the expression branch
KERNEL_SIZE=10
THRESHOLD=10
WHICH_LOSS='ce'
WEIGHT_DECAY=1e-3  # Adam weight decay; also part of the checkpoint file name
LOSS_INDEX=4
SEED=1  # seed for the random partitioning; also part of the checkpoint file name
dim_age = 10  # embedding width for the binned age feature
dim_iss = 15  # embedding width for the ISS feature
# The leading character of INPUT_SUFFIX is the number of partitions.
num_partitions = int(INPUT_SUFFIX[0])
# A trailing digit (e.g. '3rand2') is the number of extra partitions that get
# merged into the first one; suffixes ending in a letter mean none.
try:
    extra_partitions = int(INPUT_SUFFIX[-1])
except ValueError:
    extra_partitions = 0
check_acc_idx = [4]  # loss/target indices for which accuracy is computed (4 = HR_FLAG column)

In [None]:
# Map the selected model factory to a short tag used in file names.
# NOTE(review): FactorBlock / DenseFactorBlock are presumably provided by the
# star import from dl.models.dense_factor_conv — confirm.
if MODEL == FactorBlock:
    model_name = 'fb'
elif MODEL == DenseFactorBlock:
    model_name = 'dense'
elif MODEL == Linear1:
    model_name = 'Linear1'
elif MODEL == Linear2:
    model_name = 'Linear2'
elif MODEL == Conv1:
    model_name = 'Conv1'
elif MODEL == Conv2:
    model_name = 'Conv2'
else:
    raise ValueError('MODEL {0} undefined'.format(MODEL.__repr__()))

balance = 'unbal' if USE_ALL_DATA else 'bal'

# Checkpoint file name encodes the experiment settings.
# NOTE(review): the directory '../checkpoints/' is not created anywhere in this
# notebook; torch.save will presumably fail if it does not exist.
ckp_path = '../checkpoints/'
script_file_prefix = ('{0}_{1}_seed{2}_norm{3}-{4}-train{5}-alpha{6}-reg{7}'.format(
        EXP_TYPE, INPUT_SUFFIX, SEED, NORM_TYPE, model_name, TRAIN, ALPHA, math.log10(WEIGHT_DECAY)))
ckp_file_prefix = ckp_path + script_file_prefix
ckp_file = ckp_file_prefix + '.pt'

In [None]:
# Load the preprocessed feature dictionary from the project directory.
input_path = root + 'mm-dream/processed_data/'
fileName = 'exp_features_unnormalized.pkl'
input_file = input_path + fileName
# SECURITY NOTE: pickle.load executes arbitrary code from the file; only load
# files from a trusted source.
with open(input_file, 'rb') as f:
    data = pickle.load(f)
# `res` is indexed below as res[0][k] (projections) and res[1][1] (saved as 'entrezIds').
res = data['res']
bins_age = [35, 50, 60, 70, 80]  # age bin edges; binsort maps an age to one of len(bins_age)+1 levels
levels_age = len(bins_age) + 1
levels_iss = 3  # number of ISS levels (stored zero-based as 0, 1, 2 by prep_input)

In [None]:
def prep_input(data, exp_type=EXP_TYPE, norm_type=0, bins=bins_age, levels_iss=3):
    """Build model inputs and targets from the loaded data dictionary.

    Args:
        data: dict with keys 'exp_<exp_type>', 'clinical_<exp_type>' and 'var_names'
        exp_type: data set suffix used to select the expression/clinical entries
        norm_type: 1 = divide each row by its sum, 2 = row-wise softmax,
            3 = row-wise standardisation, anything else = no normalisation
        bins: ascending age bin edges passed to binsort
        levels_iss: number of ISS levels; missing ISS values are drawn from
            range(levels_iss) according to the observed frequencies
    Returns:
        (x_age, x_iss, x_exp, y_truth) as float Variables. The columns of
        y_truth are log2(D_PFS), log2(D_OS), D_OS_FLAG, D_PFS_FLAG and HR_FLAG
        (1 when the field equals 'TRUE').

    NOTE: missing ISS values are imputed with np.random.choice on the global
    NumPy RNG, which is not seeded by this function, so repeated calls can
    return different x_iss.
    """
    x = data['exp_'+exp_type]
    # how to normalize?
    if norm_type == 1:
        x = x / np.sum(x, axis=1, keepdims=True)
    elif norm_type == 2:
        x = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True)
    elif norm_type == 3:
        x = (x - np.mean(x, axis=1, keepdims=True)) / np.std(x, axis=1, keepdims=True)
    clinical = data['clinical_'+exp_type]
    var_names = data['var_names']
    y_truth = np.array([[np.log2(float(v[var_names['D_PFS']])), np.log2(float(v[var_names['D_OS']])),
                       float(v[var_names['D_OS_FLAG']]), float(v[var_names['D_PFS_FLAG']]),
                       1 if v[var_names['HR_FLAG']] == 'TRUE' else 0] for v in clinical])

    # Age: replace 'NA' by the mean of the known ages, then bin.
    col_age = var_names['D_Age']
    known_age = [float(v[col_age]) for v in clinical if v[col_age] != 'NA']
    mean_age = np.mean(known_age) if known_age else float('nan')
    age = np.array([binsort(float(v[col_age]) if v[col_age] != 'NA' else mean_age, bins)
                    for v in clinical])

    # ISS: shift to zero-based levels; replace 'NA' by a draw from the
    # empirical level distribution. `minlength` keeps the probability vector
    # at levels_iss entries even when the highest level never occurs.
    col_iss = var_names['D_ISS']
    known_iss = np.array([int(v[col_iss]) - 1 for v in clinical if v[col_iss] != 'NA'], dtype=int)
    level_freq = np.bincount(known_iss, minlength=levels_iss)
    level_freq = level_freq / np.sum(level_freq)
    iss = np.array([int(v[col_iss]) - 1 
              if v[col_iss] != 'NA' else np.random.choice(levels_iss, p=level_freq)
              for v in clinical])

    x_age = Variable(torch.from_numpy(age).float())
    x_iss = Variable(torch.from_numpy(iss).float())
    x_exp = Variable(torch.from_numpy(x).float())
    y_truth = Variable(torch.from_numpy(y_truth).float())
    return x_age, x_iss, x_exp, y_truth

In [None]:
# Inputs/targets of the configured data set; these are partitioned below.
x_age, x_iss, x_exp, y_truth = prep_input(data, exp_type=EXP_TYPE, norm_type=NORM_TYPE)

In [None]:
# Build `split`: a list of sample-index lists, one per partition.
if INPUT_SUFFIX == '{0}rand'.format(num_partitions):
    # Random permutation cut into num_partitions nearly equal parts.
    np.random.seed(SEED)
    split = [e.tolist() for e in np.array_split(np.random.permutation(y_truth.size(0)), num_partitions)]
elif INPUT_SUFFIX == '{0}rand{1}'.format(num_partitions, extra_partitions):
    # Cut into num_partitions+extra_partitions parts, then merge the extra
    # parts into the first one so it ends up larger than the others.
    np.random.seed(SEED)
    split = [e.tolist() for e in np.array_split(np.random.permutation(y_truth.size(0)), 
                                                num_partitions+extra_partitions)]
    for i in range(1, 1+extra_partitions):
        split[0] += split[i]
    # Drop the parts that were just merged (they sit at index 1 after each delete).
    for i in range(extra_partitions):
        del split[1]
elif INPUT_SUFFIX == '{0}balanced'.format(num_partitions):
    # Split the two classes of the last target column separately so each
    # partition gets a share of both (no shuffling is applied).
    # NOTE(review): y_truth is a torch Variable here; confirm np.where works on
    # this comparison result with the PyTorch version in use.
    idx0 = np.where(y_truth[:,-1] == 0)[0]
    idx0 = np.array_split(idx0, num_partitions)
    idx1 = np.where(y_truth[:,-1] == 1)[0]
    idx1 = np.array_split(idx1, num_partitions)
    split = [np.concatenate([idx0[i], idx1[i]]).tolist() for i in range(num_partitions)]
elif INPUT_SUFFIX == '3natural':
    # Fixed contiguous index ranges per data set.
    if EXP_TYPE == 'ma':
        a = list(range(133))
        b = list(range(133, 403))
        c = list(range(403, 957))
    elif EXP_TYPE == 'rnaseq':
        a = list(range(142))
        b = list(range(142, 283))
        c = list(range(283, 424))
    else:
        raise ValueError('datatype should be either ma or rnaseq')
    split = [a, b, c]
else:
        raise ValueError('split_type: {0} currently undefined'.format(INPUT_SUFFIX))

In [None]:
# Per-partition inputs and targets: X[p] = [age, iss, expression] and
# Y[p] = target matrix for the samples listed in split[p].
X = []
Y = []
for idx in split:
    idx = torch.LongTensor(idx)
    X.append([x_age[idx], x_iss[idx], x_exp[idx]])
    Y.append(y_truth[idx])

In [None]:
num_features = len(X[0])  # number of input feature groups (age, iss, expression)
num_losses = Y[0].size(1)  # number of target columns, one loss each

In [None]:
# Lower-case working copies of the configuration constants.
train = TRAIN
which_loss = WHICH_LOSS
# solver params
num_iter = NUM_ITER
batch_size = BATCH_SIZE
use_all_data = USE_ALL_DATA
print_every = PRINT_EVERY
lr = 1e-2  # initial Adam learning rate
weight_decay = WEIGHT_DECAY
adj_lr_every = 1000  # multiply the learning rate by lr_decay every this many iterations
lr_decay = 4e-1

# model params
# These are forwarded to MODEL(...) when the expression-branch model is built.
threshold=THRESHOLD
kernel_size=KERNEL_SIZE 
in_channels=1
out_channels=1
num_layers=2  # also used to pick res[0][num_layers-1] when sizing the output projection
stride=2 
force_square=False
bias=False
nonlinearity=nn.ReLU6() 
use_batchnorm=True

In [None]:
# One loss per target column: smooth-L1 for the two log2 time targets,
# cross-entropy for the three flag targets. The list repeats the same two
# loss objects.
loss_fns = [nn.SmoothL1Loss()]*2 + [nn.CrossEntropyLoss()]*3

In [None]:
# models[l] = [model_age, model_iss, model_exp, model_avg] for target column l;
# every sub-model gets its own Adam optimizer, collected in `optimizers`.
models = []
optimizers = []

# Output width per target: 1 for the two regression targets, 2 (class scores)
# for the three flag targets; this matches the order of `loss_fns`.
for out_dim in [1]*2 + [2]*3:
    # Connect every unit of the previous layer to each output unit.
    pathway_to_output = {k: list(range(len(res[0][num_layers-1]))) for k in range(out_dim)}
    projections = [res[0][1], pathway_to_output]
    
    model_age = nn.Sequential(Embedding(levels_age, dim_age), nn.Linear(dim_age, out_dim))
    optimizer_age = torch.optim.Adam(model_age.parameters(), lr=lr, weight_decay=weight_decay)
    model_iss = nn.Sequential(Embedding(levels_iss, dim_iss), nn.Linear(dim_iss, out_dim))
    optimizer_iss = torch.optim.Adam(model_iss.parameters(), lr=lr, weight_decay=weight_decay)
    model_exp = MODEL(projections, out_features=[len(v) for v in projections], in_channels=in_channels, 
                        out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                        threshold=threshold, force_square=force_square, bias=bias,
                        nonlinearity=nonlinearity, use_batchnorm=use_batchnorm)
    optimizer_exp = torch.optim.Adam(model_exp.parameters(), lr=lr, weight_decay=weight_decay)
    models.append([model_age, model_iss, model_exp])
    optimizers += [optimizer_age, optimizer_iss, optimizer_exp]
    # Learned combination of the per-feature predictions: a Conv1d whose kernel
    # spans all feature models (N, out_dim, num_features) -> (N, out_dim).
    model_avg = nn.Sequential(nn.Conv1d(out_dim, out_dim, len(models[-1])), Squeeze(-1))
    models[-1].append(model_avg)
    optimizers.append(torch.optim.Adam(model_avg.parameters(), lr=lr, weight_decay=weight_decay))
    
model_loss_avg = []
for i in range(num_losses):
    # models.shape[0] = num_losses, models.shape[1] = num_features + 1 (the last one is for model_avg) 
    model_loss_avg.append(LossAvg(num_features))
# The last one is for the final loss
model_loss_avg.append(LossAvg(num_losses))
# extend() instead of a list comprehension used only for its side effects
# (which also left a list of None as the cell output).
optimizers.extend(torch.optim.Adam(m.parameters(), lr=lr, weight_decay=weight_decay)
                  for m in model_loss_avg)

In [None]:
# Resume from the best-so-far checkpoint when one exists, otherwise start fresh.
if os.path.exists(ckp_file):
    # NOTE(review): the checkpoint stores whole pickled modules; only load
    # files you trust. Newer PyTorch releases may need extra arguments to
    # torch.load for such files — confirm for the version in use.
    checkpoint = torch.load(ckp_file)
    models = checkpoint['models']
    model_loss_avg = checkpoint['model_loss_avg']
    # The optimizers built earlier hold the parameters of the freshly
    # constructed (now discarded) models. Rebuild them for the loaded modules,
    # otherwise optimizer.step() would never update the models being trained.
    optimizers = [torch.optim.Adam(m.parameters(), lr=lr, weight_decay=weight_decay)
                  for m in [sub for group in models for sub in group] + list(model_loss_avg)]
    losses_history = checkpoint['losses_history']
    losses_avg_history = checkpoint['losses_avg_history']
    acc_history = checkpoint['acc_history']
    max_min_acc = checkpoint['max_min_acc']
    # One history entry is appended per iteration, so its length is the
    # number of iterations already run.
    num_iter_done = len(losses_avg_history)
else:
    losses_history = []
    losses_avg_history = []
    acc_history = []
    # maximize min(acc_train, acc_val, acc_test) during iteration
    # [max_acc, iteration, (acc_train, acc_val, acc_test)]
    max_min_acc = [0, 0, (0, 0, 0)]
    num_iter_done = 0

In [None]:
def get_loss(x, y_truth, model, loss_fn=nn.SmoothL1Loss(), check_acc=False, 
             use_all_data=True, batch_size=None):
    """Run `model` on `x` and compute the loss against `y_truth`.

    Args:
        x: input tensor/Variable, samples along dim 0
        y_truth: 1-D targets, one per sample
        model: callable producing (N, d) or (N, 1, d) output
        loss_fn: loss applied to (prediction, target)
        check_acc: if True also compute top-1 accuracy with cal_acc
        use_all_data: if True use every sample; otherwise draw one balanced
            batch with balanced_sampler (NOT tested and not used)
        batch_size: batch size for the balanced sampler
    Returns:
        (loss, y, acc): the loss, the (N, d) prediction and the accuracy
        (None when check_acc is False)
    """
    if use_all_data:
        idx = torch.LongTensor(range(x.size(0)))
    else:
        # NOT tested and not used
        assert isinstance(batch_size, int)
        idx = torch.LongTensor(
            balanced_sampler(y_truth.data.long(), batch_size=batch_size, num_iter=1)[0])
    y = model(x[idx])
    if y.dim() == 3 and y.size(1) == 1:
        # y.size() = (N, 1, d) as out_channels=1 for model_exp
        y = y.squeeze(1)
    assert y.dim() == 2, 'y.dim()={0}'.format(y.dim())
    target = y_truth[idx]
    if y.size(1) > 1:
        # For classification tasks, output is N x C and targets are class indices
        target = target.long()
    elif target.dim() == 1:
        # Regression: give the target the same (N, 1) shape as the prediction.
        # With an (N,) target, broadcasting PyTorch versions would compare
        # every prediction with every target (an N x N loss).
        target = target.view_as(y)
    loss = loss_fn(y, target)
    if check_acc:
        acc = cal_acc(y, target)[0].data[0]
    else:
        acc = None
    return loss, y, acc

In [None]:
def eval_one_iter(x, y, models, model_loss_avg, loss_fns, embedding_idx=[0,1], 
                  check_acc_idx=[4], alpha=ALPHA):
    """ eval the whole model with data (x, y)

    For every target column the per-feature models are evaluated, their
    predictions are combined in two ways (a learned Conv1d and a softmax-weighted
    average), and the resulting losses are mixed with `alpha`. The per-target
    losses are finally merged by model_loss_avg[-1].

    Args:
        x: list of Variables, one per feature group
        y: Variable; column idx_loss is the target of loss idx_loss
        models: num_losses * (num_features + 1) models
        model_loss_avg: num_losses + 1 models, the last one is for the final node
        loss_fns: a list of loss functions; len(loss_fns) = num_losses
        embedding_idx: indices of the feature groups that are cast to long
            before being fed to their model (not mutated here)
        check_acc_idx: acc will be calculated for these indices (not mutated here)
        alpha: float number in [0, 1]; y_pred = alpha*y_pred_avg + (1-alpha)*y_pred_conv
    Returns:
        loss: Variable
        loss_history: two-dimensional list: num_losses * (num_features + 2)
        loss_avg_history: one-dimensional list: num_losses + 1 (final loss)
        acc_history: one-dimensional list: len(check_acc_idx) * (num_features + 3)
        y_pred: combined prediction of the LAST target column only
    """
    
    loss_history = []
    acc_history = []
    num_losses = len(models)
    num_features = len(models[0]) - 1 # models[][-1] is nn.Conv1d layer
    assert len(model_loss_avg) == num_losses + 1 # models_loss_avg[-1] is for the final loss
    assert len(loss_fns) == num_losses
    assert isinstance(check_acc_idx, (list, tuple))
    assert alpha >= 0 and alpha <= 1
    
    loss_all = []
    for idx_loss in range(num_losses):
        loss_feature = []
        y_pred_feature = []
        check_acc = (idx_loss in check_acc_idx)
        # 1) evaluate each feature-specific model on its own input
        for idx_feature in range(num_features):
            x_tmp = x[idx_feature] 
            if idx_feature in embedding_idx:
                # embedding models index a weight matrix, so they need long indices
                x_tmp = x[idx_feature].long()
            loss, y_pred, acc = get_loss(x_tmp, y[:, idx_loss], 
                                         models[idx_loss][idx_feature],
                                         loss_fn=loss_fns[idx_loss], 
                                         check_acc=check_acc)
            
            if check_acc:
                acc_history.append(acc)
            loss_feature.append(loss)
            y_pred_feature.append(y_pred)
        # loss_avg is the weighted loss of loss_feature
        # NOTE(review): torch.cat presumably requires each loss to have at
        # least one dimension — confirm for the PyTorch version in use.
        loss_avg = model_loss_avg[idx_loss](torch.cat(loss_feature))
        # After the next line, y_pred_feature will be: N * out_dim * num_features
        y_pred_feature = torch.stack(y_pred_feature, -1)
        # 2) learned combination of the per-feature predictions
        loss_conv, y_pred_conv, acc = get_loss(y_pred_feature, y[:, idx_loss], 
                                               models[idx_loss][num_features],
                                               loss_fn=loss_fns[idx_loss], 
                                               check_acc=check_acc)
       
        # 3) softmax of the LossAvg weights, reused to average the predictions
        w_tmp = model_loss_avg[idx_loss].w.exp()
        w_tmp = w_tmp/w_tmp.sum()
        # y_pred_avg: weighted sum; after next line, it will be: N * out_dim
        y_pred_avg = (y_pred_feature * w_tmp).sum(-1)   
        # y_pred: weighted mean as final prediction
        y_pred = alpha*y_pred_avg + (1-alpha)*y_pred_conv
        if check_acc:
            # +=[acc_conv, acc_avg, acc_combined]
            acc_history = acc_history + [acc, 
                                         cal_acc(y_pred_avg, y[:, idx_loss])[0].data[0], 
                                         cal_acc(y_pred, y[:, idx_loss])[0].data[0]]
        # combine loss from features
        loss_all.append(alpha*loss_avg + (1-alpha)*loss_conv)
        # The comprehension variable `loss` is local to the comprehension
        # (Python 3 scoping); the outer `loss` is not modified.
        loss_history.append([loss.data[0] for loss in loss_feature] + 
                            [loss_avg.data[0], loss_conv.data[0]])
    # final loss
    loss = model_loss_avg[-1](torch.cat(loss_all))
    # Relies on Python 3 comprehension scoping: after this line `loss` is still
    # the final loss computed above, not the last element of loss_all.
    loss_avg_history = [loss.data[0] for loss in loss_all]
    loss_avg_history.append(loss.data[0])
    # add y_pred for mm-dream submission
    return loss, loss_history, loss_avg_history, acc_history, y_pred

In [None]:
# Accuracy of the current models on both complete data sets before the
# training loop below runs.
# NOTE: this rebinds the globals x_age, x_iss, x_exp and y_truth (last to the
# 'ma' data); the partitions X and Y were already built and are not affected.
init_acc_repr = []
for exp_type in ['rnaseq', 'ma']:
    x_age, x_iss, x_exp, y_truth = prep_input(data, exp_type, NORM_TYPE)
    _, _, _, acc, y_pred = eval_one_iter([x_age, x_iss, x_exp], y_truth, models, model_loss_avg, 
                             loss_fns, check_acc_idx=check_acc_idx)
    acc_repr = '{0}_acc={1}; '.format(exp_type, acc)
    init_acc_repr.append(acc_repr)

In [None]:
# Announce which partition is trained on and the iteration range of this run.
print('Train partition {0}, start at iteration {1}, end at {2}'
      .format(train, num_iter_done, num_iter_done + num_iter))

def zero_grad(optimizer, i, adj_lr_every, lr_decay):   
    """Clear the gradients of `optimizer`, first applying step-wise LR decay.

    When (i + 1) is a multiple of `adj_lr_every`, the learning rate of every
    parameter group is multiplied by `lr_decay`.
    """
    decay_now = (i + 1) % adj_lr_every == 0
    if decay_now:
        for group in optimizer.param_groups:
            group['lr'] = group['lr'] * lr_decay
    optimizer.zero_grad()
    
# Training loop: every iteration evaluates all partitions, back-propagates the
# loss of partition `train` only, and checkpoints whenever the worst
# per-partition accuracy improves.
# NOTE(review): no .eval()/.train() switch is visible, so modules such as
# batch norm presumably run in training mode on every partition — confirm
# this is intended.
for i in range(num_iter_done, num_iter_done + num_iter):
    res_loss = [eval_one_iter(x, y, models, model_loss_avg, loss_fns, check_acc_idx=check_acc_idx) 
                for x, y in zip(X, Y)]
    # Transpose res_loss: each of the five names below is a list with one
    # entry per partition (len(X) = len(Y) = num_partitions). The inner `i` is
    # local to the comprehension and does not clobber the loop counter.
    losses, loss_history, loss_avg_history, acc, _ = [[v[i] for v in res_loss] 
                                                   for i in range(len(res_loss[0]))]
    loss = losses[train]
    losses_history.append(loss_history)
    losses_avg_history.append(loss_avg_history)
    acc_history.append(acc)
    # Last accuracy entry of each partition = accuracy of the combined prediction.
    acc_current = [a[-1] for a in acc] 
    if i == num_iter_done:
        # NOTE: only defined when at least one iteration runs; later cells read it.
        initial_acc = acc_current
    if min(acc_current) > max_min_acc[0]:
        max_min_acc = [min(acc_current), i, acc_current]
        print('{0}: New max_min_acc={1}'.format(i+1, max_min_acc))
        # Saved before the optimizer step, so the stored weights are the ones
        # that produced acc_current.
        model_states = {'models': [[m.state_dict() for m in v] for v in models], 
                       'model_loss_avg': [m.state_dict() for m in model_loss_avg]}
        checkpoint = {'model_states': model_states,
                      'models': models,
                      'model_loss_avg': model_loss_avg,
                      'losses_history': losses_history,
                      'losses_avg_history': losses_avg_history,
                      'acc_history': acc_history,
                      'max_min_acc': max_min_acc}
        torch.save(checkpoint, ckp_file) 
        
    # Clear gradients (with scheduled LR decay), back-propagate the training
    # partition's loss, then step every optimizer.
    [zero_grad(optimizer, i, adj_lr_every, lr_decay) for optimizer in optimizers]
    # NOTE(review): backward is called once per freshly built graph, so
    # retain_graph=True is presumably unnecessary — confirm before removing.
    loss.backward(retain_graph=True)
    [optimizer.step() for optimizer in optimizers]
    
    if (i + 1) % print_every == 0:
        # a[-1], a[-2], a[-3] are the combined, weighted-average and conv accuracies
        print('{0}: acc={1}, acc.avg={2}, acc.conv={3}'.format(i+1, [a[-1] for a in acc],
                                                               [a[-2] for a in acc], 
                                                               [a[-3] for a in acc]))
        # loss_history is NOT losses_history
        print('{0}: hr_risk losses = {1}'.format(i+1, [v[-1] for v in loss_history]))
        print('{0}: final loss = {1}'.format(i+1, [v[-1] for v in loss_avg_history]))
        sys.stdout.flush() 

In [None]:
# Save the state after the last iteration to a separate file whose name
# carries the total iteration count (ckp_file itself holds the best
# max-min-accuracy state saved during training).
model_states = {'models': [[m.state_dict() for m in v] for v in models], 
                'model_loss_avg': [m.state_dict() for m in model_loss_avg]}
checkpoint = {'model_states': model_states,
              'models': models,
              'model_loss_avg': model_loss_avg,
              'losses_history': losses_history,
              'losses_avg_history': losses_avg_history,
              'acc_history': acc_history,
              'max_min_acc': max_min_acc}
ckp_file2 = ('{0}-n{1}.pt'.format(ckp_file_prefix, num_iter_done+num_iter))
torch.save(checkpoint, ckp_file2)

In [None]:
# Reload the best checkpoint and export only the models (plus the entrez ids)
# to a compact file in the current directory.
if os.path.exists(ckp_file):
    checkpoint = torch.load(ckp_file)
    models = checkpoint['models']
    model_loss_avg = checkpoint['model_loss_avg']
    
    # Use a separate name for the export path: rebinding `ckp_file` made this
    # cell non-idempotent and broke re-running the checkpoint-loading cell
    # (it would have tried to resume from the export file).
    models_file = '{0}_{1}_seed{2}_models.pt'.format(EXP_TYPE, INPUT_SUFFIX, SEED)
    torch.save({'models': models, 'model_loss_avg': model_loss_avg, 
                'entrezIds': res[1][1]}, models_file)

In [None]:
# Evaluate the (reloaded) models on both complete data sets and compare with
# the accuracies recorded before training.
final_acc_repr = []
for exp_type in ['rnaseq', 'ma']:
    x_age, x_iss, x_exp, y_truth = prep_input(data, exp_type, NORM_TYPE)
    _, _, _, acc, y_pred = eval_one_iter([x_age, x_iss, x_exp], y_truth, models, model_loss_avg, 
                             loss_fns, check_acc_idx=check_acc_idx)
    acc_repr = '{0}_acc={1}; '.format(exp_type, acc)
    final_acc_repr.append(acc_repr)
    print('{0}: y_pred.size()={1}'.format(exp_type, y_pred.size()))
print('Before optimization: {0}\n{1}'.format(*init_acc_repr))
print('After optimization: {0}\n{1}'.format(*final_acc_repr))

# NOTE: initial_acc is only set inside the training loop (needs NUM_ITER >= 1).
print('Initial acc =', initial_acc)
print('Final max_min_acc =', max_min_acc)
# Plot only the iterations run in this session.
start = num_iter_done
end = num_iter_done + num_iter
# One marker per partition; zip() silently drops partitions beyond the third.
markers = ['ro', 'b+', 'g^']
# Panel 1: final combined loss per partition.
plt.subplot(311)
plt.title(' max_min_acc={0}\n initial_acc={1}'.format(max_min_acc, initial_acc))
plt.ylabel('loss_final')
for idx, marker in zip(range(num_partitions), markers):
    plt.plot([v[idx][-1] for v in losses_avg_history][start:end], marker)
# Panel 2: loss of the last target column (HR flag) per partition.
plt.subplot(312)
plt.title('Initial acc:{0}\n{1}'.format(*init_acc_repr))
plt.ylabel('loss_hr_risk')
for idx, marker in zip(range(num_partitions), markers):
    plt.plot([v[idx][-2] for v in losses_avg_history][start:end], marker)
# Panel 3: accuracy of the combined prediction per partition.
plt.subplot(313)
plt.title('Final acc:{0}\n{1}'.format(*final_acc_repr))
plt.ylabel('acc_hr_risk')
for idx, marker in zip(range(num_partitions), markers):
    plt.plot([v[idx][-1] for v in acc_history][start:end], marker)
plt.xlabel('iteration')
plt.show()