In [3]:
from __future__ import print_function, division
import numpy as np
import sys
sys.path.append('..')

import time
import torch
import pickle
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
from torch.nn import Parameter
from src.data_classes import Model, RNADataset
from src.optim_functions import get_target, loss_fn, train, test
from sklearn import decomposition, manifold
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import seaborn as sns
import math
from cmath import rect, phase
# plt.style.use('bmh')

from matplotlib import rc
# Global plot styling for the figures drawn in this notebook: select the
# 'default' matplotlib style first, then override the font family with serif.
plt.style.use('default')
plt.rcParams["font.family"] = "serif"
# Put the cmbright package into the LaTeX preamble used for text rendering.
# NOTE(review): text.usetex is not switched on in this cell, so this presumably
# has no visible effect unless it is enabled elsewhere - confirm.
rc('text.latex', preamble=r'\usepackage{cmbright}')

In [4]:
def plot_hist(idx_dict,energies):
    """Overlay Amber and HiRE energy histograms, one panel per energy term.

    A single row of three panels is created; each key of ``idx_dict`` picks
    both the panel to draw on and the entry of the two energy collections.

    Parameters
    ----------
    idx_dict : dict
        Maps a panel index to the name of the energy term, which is used in
        the x-axis label and in the panel title.
    energies : dict
        Holds the keys ``'amber'`` and ``'hire'``; each value is indexable by
        the keys of ``idx_dict`` and yields the values to histogram.

    Returns
    -------
    int
        Always ``0``.
    """
    _, panels = plt.subplots(1, 3, figsize=(18, 5))
    for term, name in idx_dict.items():
        panel = panels[term]
        # Both histograms are normalised so the two samples are comparable;
        # the HiRE one is translucent so the Amber one stays visible behind it.
        panel.hist(energies['amber'][term], bins=30, density=True, label='Amber')
        panel.hist(energies['hire'][term], bins=30, alpha=0.6, density=True, label='HiRE')
        panel.set_xlabel(name + ' energy', fontsize=15)
        panel.set_ylabel('Prob. distribution', fontsize=15)
        panel.set_title(name + ' energy distribution', fontsize=18)
        panel.legend(fontsize=15)
    return 0


def compare_energies(dataset,model,plot=False):
    """Compare reference (Amber) and model (HiRE) energies over a dataset.

    For each of the three energy terms (bonds, angles, torsions) the reference
    value is read from ``sample['features'][i, 9]`` and the model value from
    entry ``i`` of the squeezed model output.

    Every sample is loaded and pushed through the model exactly once; the
    three terms are then read from that single output instead of re-running
    the model once per term.

    Parameters
    ----------
    dataset : sequence
        Supports ``len()`` and integer indexing; each item is a mapping with a
        ``'features'`` entry that supports ``[i, 9].item()``.
    model : callable
        Called as ``model(sample)``; the result must support
        ``.squeeze()[i].item()`` for ``i`` in 0, 1, 2.
    plot : bool, optional
        When true, the per-term histograms are drawn with ``plot_hist``.

    Returns
    -------
    energies : dict
        ``{'amber': [...], 'hire': [...]}``; each list holds one 1-D array per
        energy term, with one value per sample.
    stats : dict
        ``{'amber': arr, 'hire': arr}``; each array has one row per energy term
        containing ``[mean, variance]``.
    """
    idx_dict = {
        0: 'Bonds',
        1: 'Angles',
        2: 'Torsions'
    }

    # Collect the per-term values in one sweep over the dataset so that each
    # sample is fetched once and the model is evaluated once per sample.
    amber_values = {i: [] for i in idx_dict}
    hire_values = {i: [] for i in idx_dict}
    for j in range(len(dataset)):
        sample = dataset[j]
        predicted = model(sample).squeeze()
        features = sample['features']
        for i in idx_dict:
            amber_values[i].append(features[i,9].item())
            hire_values[i].append(predicted[i].item())

    energies = {'amber': [], 'hire': []}
    stats = {'amber': [], 'hire': []}
    for i, name in idx_dict.items():
        amber_en = np.array(amber_values[i])
        hire_en = np.array(hire_values[i])
        energies['amber'].append(amber_en)
        energies['hire'].append(hire_en)
        stats['amber'].append([amber_en.mean(), amber_en.var()])
        stats['hire'].append([hire_en.mean(), hire_en.var()])
        print(name+' energy computed')
    stats['amber'] = np.array(stats['amber'])
    stats['hire'] = np.array(stats['hire'])

    if plot:
        plot_hist(idx_dict,energies)

    return energies,stats


def amber_dist(dataset, save_path='Images/amber_histo.png'):
    """Plot the distribution of the reference (Amber) energies of a dataset.

    Draws one normalised histogram per energy term (bonds, angles, torsions)
    in a single row of three panels, annotates each panel with the mean and
    the standard deviation, and optionally writes the figure to disk.

    Parameters
    ----------
    dataset : sequence
        Supports ``len()`` and integer indexing; each item is a mapping with a
        ``'features'`` entry that supports ``[i, 9].item()`` for ``i`` in
        0, 1, 2.
    save_path : str or None, optional
        Where the figure is written. The parent directory is created when it
        does not exist, so a fresh checkout without an ``Images`` folder no
        longer fails at save time. Pass ``None`` to skip saving.

    Returns
    -------
    energies : list
        One 1-D array per energy term, with one value per sample.
    stats : list
        One ``[mean, variance]`` pair per energy term.
    """
    import os  # local import: only needed to prepare the output directory

    idx_dict = {
        0: 'Bonds',
        1: 'Angles',
        2: 'Torsions'
    }

    # Read every sample once and pick the three energy terms from it, rather
    # than indexing the dataset again for each term.
    per_term = {i: [] for i in idx_dict}
    for j in range(len(dataset)):
        features = dataset[j]['features']
        for i in idx_dict:
            per_term[i].append(features[i,9].item())
    energies = [np.array(per_term[i]) for i in idx_dict]
    stats = [[amber_en.mean(), amber_en.var()] for amber_en in energies]

    # these are matplotlib.patch.Patch properties (same for every panel)
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.3)

    fig,ax = plt.subplots(1,3,figsize=(18,5.5))
    for i, name in idx_dict.items():
        # The first panel uses a finer binning than the other two.
        n_bins = 50 if i == 0 else 25
        ax[i].hist(energies[i], bins=n_bins, density=True, color='steelblue', lw=0)
        ax[i].set_xlabel(name+r' energy (kcal/mol)', fontsize=16, fontname='serif')
        ax[i].set_title(name+' energy distribution', fontsize=18, fontname='serif')
        ax[i].grid(linewidth=0.2)
        # stats holds the variance, so take the square root to show sigma.
        textstr = '\n'.join((
            r'$\mu=%.2f$' % (stats[i][0], ),
            r'$\sigma=%.2f$' % (stats[i][1]**0.5, )))

        # place the text box in the upper right region, in axes coords
        ax[i].text(0.65, 0.9, textstr, transform=ax[i].transAxes, fontsize=16,
                verticalalignment='top', bbox=props, fontname='serif')

    ax[0].set_ylabel('Prob. distribution', fontsize=16, fontname='serif')
    ax[0].set_xlim([5,30])

    if save_path is not None:
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches='tight', dpi=100)
    return energies,stats


In [6]:
# Load the trained weights and compare the model energies with the reference.
# NOTE(review): the recorded run of this cell stopped with FileNotFoundError for
# this checkpoint path - confirm where the b4_iv_e4 results are stored.
checkpoint_path = '../results/b4_iv_e4/200ep.pth'

dataset = RNADataset()
model = Model()
# map_location='cpu' lets a checkpoint written on a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
# Inference only: switch layers such as dropout / batch norm to evaluation mode.
model.eval()
# Trailing ';' keeps the notebook from dumping the full (energies, stats) tuple;
# the histograms are the intended output of this cell.
compare_energies(dataset,model,plot=True);

FileNotFoundError: [Errno 2] No such file or directory: '../results/b4_iv_e4/200ep.pth'

In [10]:
# Loss array stored under the b4_all10_e43 results folder (the 1000-epoch file,
# judging by its name).
loss_file = '../results/b4_all10_e43/loss_1000ep.npy'
loss = np.load(loss_file)

(2, 5)
