In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import argparse
import os
from bs4 import BeautifulSoup
import requests
import json
import time
import random
from torch.utils.data import Dataset, DataLoader

from model import LabelSmoothing,NoamOpt,SimpleLossCompute,make_model,run_epoch
from util import Batch
from util import greedy_decode

In [2]:
def _str2bool(value):
    """Convert a command-line string such as 'true'/'false' into a bool.

    ``type=bool`` must not be used with argparse: it applies ``bool()`` to the
    raw string, so every non-empty value (including 'False') becomes True.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False for backward compatibility with bool('').
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Transformer Task')
parser.add_argument('--dataset', type=str, default='phi', choices=['RMSD', 'phi', 'psi'])
parser.add_argument('--unidirection', type=_str2bool, default=False)
parser.add_argument('--gpu_id', type=str, default='cuda:3')
# parse_known_args (rather than parse_args) tolerates the extra argv entries
# that a Jupyter kernel is launched with; unrecognised ones land in `unknown`.
args, unknown = parser.parse_known_args()

In [3]:
# Data are provided with 1.0 ps as the saving interval.
# Each dataset name maps to its (train, validation) text files, which are
# read as integer arrays with np.loadtxt.
DATA_FILES = {
    'RMSD': ('data/alanine/train', 'data/alanine/valid'),
    'phi': ('data/phi-psi/train_phi_1.0ps', 'data/phi-psi/valid_phi_1.0ps'),
    'psi': ('data/phi-psi/train_psi_1.0ps', 'data/phi-psi/valid_psi_1.0ps'),
}
train_path, valid_path = DATA_FILES[args.dataset]
train = np.loadtxt(train_path, dtype=int)
valid = np.loadtxt(valid_path, dtype=int)
if args.dataset == 'RMSD':
    # The RMSD data are regrouped into rows of 100 values.
    train = train.reshape(-1, 100)
    valid = valid.reshape(-1, 100)

# Lag in ps; it only enters the output directory names below.
lag = 1
log_dir = "logs/fit/{}_{}ps/".format(args.dataset, lag)
os.makedirs(log_dir, exist_ok=True)
save_dir = 'result/{}_{}ps/'.format(args.dataset, lag)
os.makedirs(save_dir, exist_ok=True)
ckpt_dir = 'ckpt/training_checkpoints_{}_{}ps/'.format(args.dataset, lag)
os.makedirs(ckpt_dir, exist_ok=True)

# Honour the requested device, but fall back to the CPU (with a warning) when
# a CUDA device was requested on a machine without CUDA, so the notebook still
# runs instead of failing on the first .to(device).
if args.gpu_id.startswith('cuda') and not torch.cuda.is_available():
    print('CUDA is not available; using the CPU instead of {}'.format(args.gpu_id))
    device = torch.device('cpu')
else:
    device = torch.device(args.gpu_id)

In [4]:
def data_generator(fulldata, batch, pad):
    """Yield the mini-batches of one epoch in a random order.

    The first axis of ``fulldata`` is cut into ``fulldata.shape[0] // batch``
    consecutive chunks of ``batch`` rows; trailing rows that do not fill a
    complete chunk are dropped. The chunks are visited in a shuffled order,
    while the rows inside a chunk keep their original order.

    Args:
        fulldata: NumPy array accepted by ``torch.from_numpy``.
        batch: number of rows per mini-batch.
        pad: padding value forwarded unchanged to ``Batch``.

    Yields:
        ``Batch`` objects built with the same chunk as both source and target
        (copy task), moved to the module-level ``device`` and using
        ``args.unidirection`` for the ``uni`` flag.
    """
    nbatches = int(fulldata.shape[0] // batch)
    order = list(range(nbatches))
    random.shuffle(order)
    for i in order:
        chunk = fulldata[i * batch:(i + 1) * batch]
        # torch.from_numpy already returns a tensor with requires_grad=False,
        # so the deprecated Variable wrapper is not needed.
        src = torch.from_numpy(chunk)
        yield Batch(src.to(device), src.to(device), pad, uni=args.unidirection)

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Vocabulary size = number of distinct values that occur in the training data.
V = np.unique(train).size
# Padding id; presumably chosen so it cannot collide with a real token id
# (assumes token ids are below V + 1 -- TODO confirm against the data files).
pad = V + 1

# Loss without smoothing (smoothing=0.0).
# NOTE(review): padding_idx (V + 1) lies outside size=V; confirm that
# LabelSmoothing tolerates this.
criterion = LabelSmoothing(size=V, padding_idx=pad, smoothing=0.0)

# Model with N=2, using the same vocabulary size for both arguments.
model = make_model(V, V, N=2)
model.to(device)

# Adam starts at lr=0 and is wrapped in NoamOpt, which receives the model
# width plus the constants 1 and 400 (presumably factor and warm-up steps --
# verify against model.NoamOpt).
adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400, adam)

# TensorBoard writer that logs into log_dir.
writer = SummaryWriter(log_dir)


In [None]:
# Train the simple copy task.
# The model, loss, optimizer and writer are (re)built here so that re-running
# this cell always starts training from scratch.
from torch.utils.tensorboard import SummaryWriter

BATCH_SIZE = 32
NUM_EPOCHS = 201

# Vocabulary size = number of distinct values that occur in the training data.
V = len(np.unique(train))
# Padding id; NOTE(review): it lies outside size=V passed to LabelSmoothing --
# confirm that LabelSmoothing tolerates this.
pad = V + 1
criterion = LabelSmoothing(size=V, padding_idx=pad, smoothing=0.0)
model = make_model(V, V, N=2)
model.to(device)
model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
writer = SummaryWriter(log_dir)

for epoch in range(NUM_EPOCHS):
    start = time.time()
    model.train()
    loss_train = run_epoch(data_generator(train, BATCH_SIZE, pad), model,
                           SimpleLossCompute(model.generator, criterion, model_opt, device))
    elapsed = time.time() - start
    print("Epoch  %d  Time: %f" %
                    (epoch, elapsed))
    model.eval()
    # Validation uses the same generator as training (the original called the
    # undefined name `data_gen`); passing None as optimizer means no parameter
    # update is requested from SimpleLossCompute.
    loss_valid = run_epoch(data_generator(valid, BATCH_SIZE, pad), model,
                           SimpleLossCompute(model.generator, criterion, None, device))
    # Checkpoint each of the first five epochs, then every fifth epoch.
    if epoch < 5 or epoch % 5 == 0:
        torch.save(model.state_dict(), ckpt_dir + 'epoch{}.pt'.format(epoch))
    writer.add_scalar('Loss/train', loss_train, epoch)
    # The 'Loss/test' tag carries the validation loss.
    writer.add_scalar('Loss/test', loss_valid, epoch)

# Make sure all pending scalars reach the event file on disk.
writer.flush()

In [None]:
# Generate short continuations with the models saved at selected epochs.
model.eval()
num_generate = 10
# NOTE(review): this hard-coded directory differs from the one the training
# cell writes to ('ckpt/training_checkpoints_{dataset}_{lag}ps/'). It is kept
# under its own name so the global `ckpt_dir` used for saving checkpoints is
# no longer overwritten; confirm which directory is intended.
gen_ckpt_dir = 'ckpt/training_checkpoints_Ldata_phi_1ps/'
SEGMENT_LEN = 80000   # length of each training segment -- for phi!
CONTEXT_LEN = 5000    # only the tail of each segment is fed to the model
NUM_SEGMENTS = 100

flat_train = train.reshape(-1)
for epoch in [2, 5, 90, 100]:
    # map_location lets checkpoints saved on another device load onto `device`.
    state = torch.load(gen_ckpt_dir + 'epoch{}.pt'.format(epoch), map_location=device)
    model.load_state_dict(state)
    save = save_dir + 'epoch{}/'.format(epoch)
    os.makedirs(save, exist_ok=True)
    for i in range(NUM_SEGMENTS):
        text4activation = flat_train[i * SEGMENT_LEN:(i + 1) * SEGMENT_LEN]
        if text4activation.size == 0:
            # The training data hold fewer than NUM_SEGMENTS segments; an empty
            # source would otherwise fail when reading the start symbol.
            print('segment {} is empty; stopping generation for epoch{}'.format(i, epoch))
            break
        start0 = time.time()
        # Shape (1, context): a single sequence made of the segment's tail.
        src = torch.from_numpy(text4activation[-CONTEXT_LEN:]).unsqueeze(0).to(device)
        src_mask = (src != pad).unsqueeze(-2)
        # Decoding starts from the last value of the source sequence.
        start_symbol = src[-1][-1]

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            prediction = greedy_decode(model, src, src_mask, max_len=num_generate, start_symbol=start_symbol, pad=pad)
        print ('Time taken for total {} sec\n'.format(time.time() - start0))
        np.savetxt(save + 'prediction_' + str(i), prediction.cpu().numpy(), fmt='%i')
    print('epoch{} done'.format(epoch))