In [1]:
from __future__ import print_function, division
import os
import pickle
import pandas as pd
import torch
import time
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.nn import Parameter
from numba import cuda
print(torch.cuda.is_available())
import LocalEnergyVct as le

# Silence all Python warnings for cleaner notebook output
# (this also hides deprecation warnings from torch/pandas).
import warnings
warnings.filterwarnings("ignore")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

False


In [2]:
def get_target(X):
    """Return the training target stored in feature column 9.

    The first three rows of column 9 are compared, in the loss, against the
    three energy columns produced by ``LocalEnergyOpt``. A single un-batched
    sample (2-D features) is treated as a batch of one.

    Args:
        X: sample/batch dict with a ``'features'`` tensor (2-D or 3-D).

    Returns:
        Tensor of shape (batch, 3) moved to the global ``device``.
    """
    features = X['features']
    if features.dim() == 2:
        # Add the batch dimension on a local copy only: the caller's dict is
        # left untouched so its 'features' stays consistent with its 'lengths'.
        features = features.unsqueeze(0)
    return features[:, 0:3, 9].to(device)

In [3]:
class RNASeqDataset(Dataset):
    """RNA sequences dataset, loaded entirely into memory on ``device``.

    The index CSV (``csv_file``) has one row per sequence. Column 0 is the name
    of a per-sequence CSV in ``root_dir`` (without the ``.csv`` suffix). The
    remaining columns are cast to int64 and stored as that sequence's
    ``'lengths'``. All per-sequence CSVs must share the same (rows, cols) shape.
    They are stacked into one float32 ``'features'`` tensor.

    Args:
        device: torch device the lengths and features tensors are moved to.
        csv_file: path of the index CSV.
        root_dir: directory containing the per-sequence CSV files.
        transform: stored as ``self.transform``; not applied by ``__getitem__``.

    Raises:
        ValueError: if the index lists no sequences, or a per-sequence CSV has a
            shape different from the first one (instead of being silently
            broadcast or failing with an opaque tensor-assignment error).
    """

    def __init__(self, device, csv_file='data/SeqCSV/seq_frame.csv', root_dir='data/SeqCSV/', transform=None):
        self.seq_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        # NOTE(review): kept for API compatibility; __getitem__ does not use it yet.
        self.transform = transform
        size = len(self.seq_frame)
        if size == 0:
            raise ValueError(f'{csv_file} lists no sequences')

        lengths = np.array(self.seq_frame.iloc[:, 1:].astype('int64'))
        lengths = torch.from_numpy(lengths).to(device)

        # Read every per-sequence CSV exactly once; the first one fixes the shape.
        features = None
        expected_shape = None
        for i, name in enumerate(self.seq_frame.iloc[:, 0]):
            seq_name = os.path.join(self.root_dir, name + '.csv')
            seq = np.array(pd.read_csv(seq_name))
            if features is None:
                expected_shape = seq.shape
                features = torch.zeros(size, *expected_shape)
            elif seq.shape != expected_shape:
                raise ValueError(f'{seq_name} has shape {seq.shape}, expected {expected_shape}')
            features[i] = torch.from_numpy(seq)
        features = features.to(device)
        self.dataset = {'lengths': lengths, 'features': features}

    def __len__(self):
        """Number of sequences listed in the index CSV."""
        return len(self.seq_frame)

    def __getitem__(self, idx):
        """Return ``{'lengths': ..., 'features': ...}`` for index (or indices) ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        lengths = self.dataset['lengths'][idx]
        features = self.dataset['features'][idx]
        sample = {'lengths': lengths, 'features': features}

        return sample

In [4]:
class LocalEnergyOpt(nn.Module):
    """Local (bond / angle / torsion) energy model with learnable parameters.

    ``forward`` evaluates three energy terms per sample with the
    ``LocalEnergyVct`` helpers (imported as ``le``).

    Args:
        fixed_pars: mapping with keys ``'bond_type'``, ``'angle_type'`` and
            ``'torsion_type'``; each value is converted to a float tensor.
        opt_pars: array-like converted to a float tensor.
    """

    def __init__(self,fixed_pars,opt_pars):
        super(LocalEnergyOpt, self).__init__()
        # nn.Parameter already defaults to requires_grad=True.
        self.opt_pars = Parameter(torch.tensor(opt_pars, dtype=torch.float, device=device))
        # NOTE(review): although they come from `fixed_pars`, these three tables are
        # registered as trainable Parameters, so they appear in model.parameters()
        # and are updated by an optimizer built from it. They are also part of any
        # gradient penalty taken over model.parameters(). If they are meant to stay
        # fixed, use self.register_buffer(...) instead (confirm intent).
        self.bond_type = Parameter(torch.tensor(fixed_pars['bond_type'], dtype=torch.float, device=device))
        self.angle_type = Parameter(torch.tensor(fixed_pars['angle_type'], dtype=torch.float, device=device))
        self.tor_type = Parameter(torch.tensor(fixed_pars['torsion_type'], dtype=torch.float, device=device))

    def forward(self,X):
        """Compute the bond, angle and torsion energies of every sample.

        Args:
            X: dict with ``'lengths'`` (1-D for one sample, 2-D for a batch) and
                ``'features'`` (2-D for one sample, 3-D for a batch). Only the
                first ``lengths[k]`` entries of feature column ``k`` are used.

        Returns:
            Tensor of shape (batch, 3): columns are bond, angle, torsion energy.
        """
        X_lengths = X['lengths']
        X_features = X['features']

        # Un-batched sample: give both tensors a leading batch dimension.
        if len(X_lengths.shape) == 1:
            X_lengths = X_lengths.unsqueeze(0)
            X_features = X_features.unsqueeze(0)

        # Allocate on the inputs' device rather than the global `device`.
        energy = torch.zeros(X_lengths.shape[0], 3, device=X_features.device)

        for i in range(X_lengths.shape[0]):
            lengths = X_lengths[i]
            features = X_features[i]
            if torch.is_tensor(lengths):
                lengths = lengths.tolist()
            atoms = features[:lengths[0],0].long()
            # columns 1-4 (res_labels, res_pointer, mass, charge) are not used here
            coords = features[:lengths[5],5].view(-1,3)
            # bonds / angles / torsions hold integer indices, hence .long()
            bonds = features[:lengths[6],6].long().view(-1,3)
            angles = features[:lengths[7],7].long().view(-1,4)
            tors = features[:lengths[8],8].long().view(-1,5)
            energy[i,0] = le.bonds_energy(coords,bonds,self.bond_type,self.opt_pars)
            energy[i,1] = le.angles_energy(atoms,coords,angles,self.angle_type,self.opt_pars)
            energy[i,2] = le.torsions_energy(atoms,coords,tors,self.tor_type,self.opt_pars)

        return energy

In [51]:
def loss_with_grad(pred,target,model,lc=0.1):
    """Squared-error loss plus a per-element gradient-norm penalty.

    For every scalar entry ``e`` of ``pred``, the squared L2 norm of
    ``de/dtheta`` over all ``model.parameters()`` is added to the penalty::

        loss = (sum((pred - target)**2) + lc * sum_e ||de/dtheta||^2) / pred.shape[0]

    ``create_graph=True`` keeps the penalty differentiable, so
    ``loss.backward()`` also trains against it. One ``autograd.grad`` call is
    made per entry of ``pred``.

    Args:
        pred: model output; its first dimension is the batch size.
        target: tensor that can be subtracted from ``pred``.
        model: module whose parameters ``pred`` depends on.
        lc: weight of the gradient penalty.
    """
    squared_grad_norms = (
        g.pow(2).sum()
        for entry in pred.view(-1,)
        for g in torch.autograd.grad(entry, model.parameters(), create_graph=True)
    )
    penalty = sum(squared_grad_norms, 0.)
    residual = (pred - target).pow(2).sum()
    return (residual + lc * penalty) / pred.shape[0]

In [155]:
def loss_with_grad2(pred,target,model,lc=0.05):
    """Squared-error loss plus a penalty on the gradient of the summed prediction.

    All scalar entries of ``pred`` are passed to a single ``autograd.grad``
    call. Autograd accumulates their gradients, so the penalty is
    ``||d(sum(pred))/dtheta||^2`` over all ``model.parameters()``::

        loss = (sum((pred - target)**2) + lc * ||d sum(pred) / dtheta||^2) / pred.shape[0]

    NOTE: this differs from ``loss_with_grad``, which sums the squared gradient
    norm of each entry separately. Here, cross terms between entries that share
    parameters are included.

    Args:
        pred: model output; its first dimension is the batch size.
        target: tensor that can be subtracted from ``pred``.
        model: module whose parameters ``pred`` depends on.
        lc: weight of the gradient penalty.
    """
    entries = torch.split(pred.view(-1,), 1)
    # create_graph=True keeps the penalty differentiable for loss.backward().
    grads = torch.autograd.grad(entries, model.parameters(), create_graph=True)
    penalty = sum((g.pow(2).sum() for g in grads), 0.)
    residual = (pred - target).pow(2).sum()
    return (residual + lc * penalty) / pred.shape[0]

In [132]:
def reshape_parameters(model):
    """Flatten all parameters of ``model`` into a single 1-D tensor.

    Parameters are reshaped to 1-D and concatenated in ``model.parameters()``
    order. The result stays connected to the autograd graph.

    The model is deliberately NOT modified: assigning the tensor to
    ``model.parameters`` would shadow the ``nn.Module.parameters`` method and
    break every later ``model.parameters()`` call (e.g. optimizers, losses).

    Returns:
        1-D tensor of all parameters; an empty tensor if the model has none.
    """
    flat = [p.reshape(-1) for p in model.parameters()]
    if not flat:
        return torch.zeros(0)
    return torch.cat(flat)

In [146]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return the mean loss over valid batches.

    The target of each batch comes from ``get_target``. Batches whose loss is
    NaN are skipped without an optimizer step.

    Args:
        dataloader: iterable of sample dicts accepted by ``model``.
        model: module called as ``model(X)``.
        loss_fn: callable ``loss_fn(pred, target, model)`` returning a scalar tensor.
        optimizer: torch optimizer over the model's parameters.

    Returns:
        float: mean loss over non-NaN batches, or NaN if there were none
        (instead of raising ZeroDivisionError).
    """
    model.train()
    num_batches = 0
    train_loss = 0.

    for X in dataloader:
        # Compute prediction error
        pred = model(X)
        target = get_target(X)
        loss = loss_fn(pred, target, model)

        # Keep NaN losses out of both the average and the parameters.
        if torch.isnan(loss):
            continue
        num_batches += 1
        train_loss += loss.item()

        # Backpropagation
        optimizer.zero_grad()
        # NOTE(review): retain_graph=True looks unnecessary (each batch builds its
        # own graph) and keeps graph buffers alive longer; kept because the
        # LocalEnergyVct internals are not visible here. Confirm before removing.
        loss.backward(retain_graph=True)
        optimizer.step()

    train_loss = train_loss / num_batches if num_batches else float('nan')
    print(f'Avg loss = {train_loss:>0.4f}, valid batches = {num_batches}')

    return train_loss

In [147]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the mean loss over valid batches.

    Batches whose loss is NaN are skipped. Gradient tracking is left enabled
    because gradient-penalty losses (``loss_with_grad`` / ``loss_with_grad2``)
    call ``torch.autograd.grad`` on the predictions.

    Args:
        dataloader: iterable of sample dicts accepted by ``model``.
        model: module called as ``model(X)``.
        loss_fn: callable ``loss_fn(pred, target, model)`` returning a scalar tensor.

    Returns:
        float: mean loss over non-NaN batches, or NaN if there were none.
    """
    model.eval()
    num_batches = 0
    test_loss = 0.
    for X in dataloader:
        pred = model(X)
        target = get_target(X)
        loss = loss_fn(pred, target, model)
        if torch.isnan(loss):
            continue
        num_batches += 1
        # .item() detaches the value so each batch's autograd graph can be freed
        # (accumulating the tensor itself kept every graph alive).
        test_loss += loss.item()
    test_loss = test_loss / num_batches if num_batches else float('nan')
    print(f'Avg test_loss = {test_loss:>0.4f}, valid batches = {num_batches}')
    return test_loss

In [160]:
seq_data = RNASeqDataset(device=device)
print(f'dataset allocated on {device}')

# 80/20 train/test split, reproducible through a fixed generator seed.
tot_length = len(seq_data)
set_length = int(0.2*tot_length)
train_set, test_set = random_split(seq_data, [tot_length - set_length, set_length], generator=torch.Generator().manual_seed(42))
print(len(train_set))
print(len(test_set))

# The whole dataset already lives on `device`. pin_memory only accepts dense CPU
# tensors (pinning CUDA tensors raises), and CUDA tensors cannot be used from
# forked worker processes, so both options are enabled only for CPU-resident data.
on_cpu = device.type == 'cpu'
loader_kwargs = dict(num_workers=1 if on_cpu else 0, pin_memory=on_cpu)

batch_size = 1
train_dataloader = DataLoader(train_set,batch_size=batch_size,shuffle=True,**loader_kwargs)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_dataloader = DataLoader(test_set,batch_size=batch_size,shuffle=False,**loader_kwargs)

# SECURITY: pickle.load can execute arbitrary code; only load these files from a trusted source.
with open('data/SeqCSV/fixed_pars.p', 'rb') as fh:
    fixed_pars = pickle.load(fh)
with open('data/SeqCSV/pars.p', 'rb') as fh:
    opt_pars = pickle.load(fh)

model = LocalEnergyOpt(fixed_pars,opt_pars).to(device)

dataset allocated on cpu
1327
331


In [161]:
lr = 1e-7
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# NOTE: `verbose` is deprecated in recent PyTorch releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor = 0.5, patience = 500, cooldown = 1000, threshold = 1e-12, verbose = True)
loss_fn = loss_with_grad2

epochs = 100
train_loss = []
test_loss = []
for index_epoch in range(epochs):
    print(f'epoch {index_epoch+1}/{epochs} \n-------------------------')
    t0 = time.time()
    train_tmp = train(train_dataloader, model, loss_fn, optimizer)
    # float() drops any autograd history so the stored history does not keep graphs alive.
    test_tmp = float(test(test_dataloader, model, loss_fn))
    # ReduceLROnPlateau only acts when it is stepped with the monitored metric.
    scheduler.step(test_tmp)
    train_loss.append(train_tmp)
    test_loss.append(test_tmp)
    tf = time.time()
    print(f'time for epoch: {tf-t0} \n')

for p in model.parameters():
    print(p.data)

epoch 1/100 
-------------------------
Avg loss = 18799.4356, valid batches = 1313
Avg test_loss = 9049.5068, valid batches = 329
time for epoch: 5.540524959564209 

epoch 2/100 
-------------------------
Avg loss = 8572.6342, valid batches = 1313
Avg test_loss = 7402.9292, valid batches = 329
time for epoch: 5.205782651901245 

epoch 3/100 
-------------------------
Avg loss = 7505.0285, valid batches = 1313
Avg test_loss = 6858.7407, valid batches = 329
time for epoch: 5.604363203048706 

epoch 4/100 
-------------------------
Avg loss = 6863.6918, valid batches = 1313
Avg test_loss = 6449.8589, valid batches = 329
time for epoch: 5.48005747795105 

epoch 5/100 
-------------------------
Avg loss = 6390.5934, valid batches = 1313
Avg test_loss = 5983.9355, valid batches = 329
time for epoch: 5.252687454223633 

epoch 6/100 
-------------------------
Avg loss = 6047.6292, valid batches = 1313
Avg test_loss = 5701.5430, valid batches = 329
time for epoch: 5.051347017288208 

epoch 7/10

Avg test_loss = 2742.8506, valid batches = 329
time for epoch: 5.205306529998779 

epoch 51/100 
-------------------------
Avg loss = 2753.7656, valid batches = 1313
Avg test_loss = 2720.9739, valid batches = 329
time for epoch: 5.505064487457275 

epoch 52/100 
-------------------------
Avg loss = 2727.1018, valid batches = 1313
Avg test_loss = 2695.9497, valid batches = 329
time for epoch: 5.206534385681152 

epoch 53/100 
-------------------------
Avg loss = 2704.6116, valid batches = 1313
Avg test_loss = 2672.0627, valid batches = 329
time for epoch: 5.154438018798828 

epoch 54/100 
-------------------------
Avg loss = 2680.2198, valid batches = 1313
Avg test_loss = 2645.6323, valid batches = 329
time for epoch: 5.213413715362549 

epoch 55/100 
-------------------------
Avg loss = 2655.8350, valid batches = 1313
Avg test_loss = 2624.6897, valid batches = 329
time for epoch: 5.293657064437866 

epoch 56/100 
-------------------------
Avg loss = 2632.8357, valid batches = 1313
Avg 

Avg loss = 2075.3333, valid batches = 1313
Avg test_loss = 2071.6975, valid batches = 329
time for epoch: 5.345065593719482 

tensor([2.0742e-01, 9.6970e-01, 1.4935e+00, 2.2974e+00, 4.1144e+00, 3.1898e-01,
        1.6127e+00, 3.6594e+00, 5.1664e-02, 2.0307e+00, 1.5223e+01, 1.0000e+00,
        2.8000e+00, 2.5050e+00, 1.8260e+00, 3.9320e+00, 4.3090e+00, 4.7750e+00,
        4.5460e+00, 2.8210e+00, 3.8130e+00, 3.0100e+00, 9.0800e-01, 3.0000e+00,
        4.0000e+00, 2.2570e+00, 4.8000e-01, 5.0000e-01, 3.6814e+00, 1.0790e+01,
        1.0912e+01, 4.9256e+00, 3.9731e-01, 6.4252e-01, 4.2243e-01, 4.8386e-01,
        4.1908e-01, 3.3597e-01, 1.2000e+00, 1.5000e+00, 4.0000e-01, 1.8000e+00,
        8.0000e-01, 1.4231e+02, 1.0000e+00, 0.0000e+00, 2.6810e-01])
tensor([[ 29.9668,   3.5824],
        [199.9992,   2.3451],
        [199.9999,   2.6474],
        [200.0000,   2.6466],
        [200.0000,   3.0695],
        [200.0000,   3.0095],
        [200.0000,   2.4636],
        [200.0000,   2.1910],
     

In [154]:
# NOTE(review): this cell was executed (In [154]) before the training cell above
# (In [161]), so the saved file may not match the displayed run; re-run after training.
MODEL_PATH = 'data/Results/model_withgrad_pars.pth'
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)


tensor([ 1.2747e-01,  6.5985e-01,  1.4806e+00,  2.2628e+00,  4.0585e+00,
         8.3665e-02,  9.9526e-01,  2.8479e+00, -1.9098e-01,  2.3038e+00,
         1.5223e+01,  1.0000e+00,  2.8000e+00,  2.5050e+00,  1.8260e+00,
         3.9320e+00,  4.3090e+00,  4.7750e+00,  4.5460e+00,  2.8210e+00,
         3.8130e+00,  3.0100e+00,  9.0800e-01,  3.0000e+00,  4.0000e+00,
         2.2570e+00,  4.8000e-01,  5.0000e-01,  3.4612e+00,  1.0763e+01,
         1.0823e+01,  4.0378e+00,  3.7675e-01,  6.0168e-01,  4.3745e-01,
         4.9668e-01,  4.2787e-01,  3.3753e-01,  1.2000e+00,  1.5000e+00,
         4.0000e-01,  1.8000e+00,  8.0000e-01,  1.4231e+02,  1.0000e+00,
         0.0000e+00,  2.4391e-01])
tensor([[ 29.9731,   3.6991],
        [199.9973,   2.3438],
        [199.9998,   2.6474],
        [200.0000,   2.6459],
        [200.0000,   3.0669],
        [199.9999,   3.0081],
        [199.9999,   2.4630],
        [200.0000,   2.1899],
        [200.0000,   1.5242],
        [200.0000,   1.6067],
        