In [1]:
# Notebook setup: make two project checkouts importable, then pull in their names.
# Both locations are read from environment variables; os.environ[...] raises
# KeyError if KZ_CODE or CCBA is not set in the kernel's environment.
import sys, os
sys.path.insert(0, os.environ['KZ_CODE'])
# Star import: every public name of kz_code lands in the notebook namespace.
from kz_code import *
sys.path.insert(0, os.environ['CCBA'])
# NOTE(review): the cells below use pd, np, sns, plt, quantile and nmf_and_score
# without importing them here -- presumably they arrive via these star imports; confirm.
from library.ccba import *
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

Using the following packages:
matplotlib v1.5.1
numpy v1.10.4
pandas v0.18.0
rpy2 v2.7.9
scikit-learn v0.17.1
scipy v0.17.0
seaborn v0.7.0


In [None]:
# Read gene expression
# Tab-separated series matrix; lines starting with '!' are treated as comments
# and skipped, and the first column becomes the row index.
gene_x_sample_df = pd.read_csv(
    './data/GES24759/GSE24759_series_matrix.txt',
    index_col=0,
    comment='!',
    sep='\t',
)
print('gene_x_sample_df shape: {}'.format(gene_x_sample_df.shape))

gene_x_sample_df shape: (22944, 211)


In [None]:
# NMF and select k
# Candidate numbers of components: every integer from 2 through 20.
KS = list(range(2, 21))  # Ks from which to select k
nmf_results, scores = nmf_and_score(
    gene_x_sample_df,
    KS,
    method='cophenetic_correlation',
    verbose=True,
)
# Alternative scoring method kept for reference:
#nmf_results, scores = nmf_and_score(gene_x_sample_df, KS, method='intra_inter_ratio', verbose=True)

Computing clustering score for k=2 using method cophenetic_correlation ...
Computing the cophenetic correlation coefficient ...
Score for k=2: 0.7884401821498143
Computing clustering score for k=3 using method cophenetic_correlation ...


In [None]:
# Plot reconstruction error
# One point per k in nmf_results, with that run's 'ERROR' entry on the y axis.
ax = sns.pointplot(
    x=list(nmf_results),
    y=[nmf_results[key]['ERROR'] for key in nmf_results],
)
ax.set(xlabel='k', ylabel='Reconstruction Error', title='k vs. Reconstruction Error')
plt.show()

In [None]:
# Plot the clustering score returned by nmf_and_score for each k.
ax = sns.pointplot(x=list(scores.keys()), y=list(scores.values()))
ax.set(xlabel='k', ylabel='Mean', title='k vs. Mean')
plt.show()

In [None]:
# Plot W and H matrices
# k selects which NMF result to inspect; later cells also read this value.
k = 8

# W heatmap: axes are labelled Gene (rows) by Component (columns).
ax = sns.heatmap(nmf_results[k]['W'], yticklabels=False)
ax.set(xlabel='Component', ylabel='Gene')
# Fixed: this title previously said 'H matrix', a copy-paste slip from the plot below.
ax.set_title('W matrix generated using k={}'.format(k))
plt.show()

# H heatmap: axes are labelled Component (rows) by Sample (columns).
ax = sns.heatmap(nmf_results[k]['H'], xticklabels=False)
ax.set(xlabel='Sample', ylabel='Component')
ax.set_title('H matrix generated using k={}'.format(k))
plt.show()

In [None]:
# Make gene-to-info dictionary
# Scans the SOFT file for the block between '!platform_table_begin' and
# '!platform_table_end' and builds, for each tab-separated row of that block:
#   gene_info[id]          -> {'gene_title', 'gene_symbol', 'entrez_gene', 'refseq'}
#   gene_symbol_errors[id] -> raw row whose gene-symbol column is '' or 'NA'
#   parse_errors[id]       -> raw row with too few columns to index
# where id is the first column of the row.
# NOTE(review): the first row after '!platform_table_begin' is presumably the
# column header and is parsed like any other row -- confirm this is harmless.
gene_info = {}
parse_errors = {}
gene_symbol_errors = {}
parse = False
with open('./data/GES24759/GSE24759_family.soft') as f:
    # Iterate the file object directly so the whole file is not loaded into
    # memory at once (the previous f.readlines() did that).
    for i, line in enumerate(f):
        # Strip the line terminator and surrounding spaces but NOT trailing
        # tabs: a plain strip() drops trailing empty columns, which turned
        # rows with blank trailing fields into spurious parse errors.
        line = line.lstrip().rstrip(' \r\n')
        marker = line.strip()  # fully stripped form, used only to detect the table markers

        if not parse:
            if marker == '!platform_table_begin':
                parse = True
                print('Start parsing at line {}'.format(i))
            continue

        if marker == '!platform_table_end':
            print('Stop parsing at line {}'.format(i))
            parse = False
            break

        split = line.split('\t')
        try:
            if (split[8] == '' or split[8] == 'NA'):
                gene_symbol_errors[split[0]] = line
                continue
            elif split[0] in gene_info:
                # Keep the first occurrence; only report the repeat.
                print('Duplicate at {}'.format(split[0]))
            else:
                gene_info[split[0]] = {'gene_title': split[7],
                                       'gene_symbol': split[8].split(' /// '),
                                       'entrez_gene': split[9],
                                       'refseq': split[10].split(' /// ')}
        except IndexError:
            # Row has fewer columns than expected. Only IndexError is caught
            # so that unrelated failures (including KeyboardInterrupt) are no
            # longer silently recorded as parse errors.
            parse_errors[split[0]] = line

print('Parse error: {}'.format(len(parse_errors)))
print('Gene symbol error {}'.format(len(gene_symbol_errors)))

In [None]:
# TODO: figure out why the number of passing values is the same for each col in W
top_genes = {}  # dictionary(key:component name such as 'k8c0'; value:set of top gene symbols)

percentile = 0.99
w_matrix = nmf_results[k]['W']
for i in range(w_matrix.shape[1]):
    name = 'k{}c{}'.format(k, i)
    print('Analyzing {} ...'.format(name))

    col = w_matrix[:, i]
    # NOTE(review): `quantile` is a project helper not visible here; the last
    # element of its result is used as the cut-off -- confirm its semantics.
    threshold = quantile(col, 100/((1-percentile)*100))[-1]
    # Multiplying the boolean mask by col zeroes out entries whose weight is
    # exactly 0, so they are not selected even when they pass the comparison.
    passing_mask = (col >= threshold) * col
    indices_passing_threshold = np.nonzero(passing_mask)[0]
    print('{} genes above {} percentile threshold.'.format(len(indices_passing_threshold), percentile*100))

    symbols = set()
    top_genes[name] = symbols
    for gene in gene_x_sample_df.index[indices_passing_threshold]:
        if gene in gene_info:
            symbols.update(gene_info[gene]['gene_symbol'])
            continue
        # No annotation available: report which bucket the probe fell into.
        if gene in parse_errors:
            message = 'Parse error at {}'
        elif gene in gene_symbol_errors:
            message = 'Gene symbol error at {}'
        else:
            message = 'Unknown error at {}'
        print(message.format(gene))

In [None]:
# Print the gene symbols collected for each component.
# The loop variables are deliberately not named `k`: this cell used to rebind
# the notebook-level `k` (the selected number of components) to a component
# name string, which broke any earlier cell re-run afterwards that indexes
# nmf_results[k].
for component_name, gene_symbols in top_genes.items():
    print('*** {} ***'.format(component_name))
    # Sets have no stable order; sort so the listing is reproducible across runs.
    for symbol in sorted(gene_symbols):
        print('{}'.format(symbol))