In [25]:
import numpy as np
import matplotlib.pyplot as plt
import torch 
from torch import nn
from torch.utils.data import TensorDataset,Dataset,DataLoader,random_split
import datetime

# prepare data

In [2]:
with open('config.npy', 'rb') as f:
    config = np.load(f)
with open('J.npy', 'rb') as f:
    J = np.load(f)
with open('energy.npy', 'rb') as f:
    energy = np.load(f)

In [16]:
data = TensorDataset(torch.tensor(J).float(),torch.tensor(energy).reshape(len(energy),1).float())

n_train = int(len(data)*0.8)
n_valid = len(data) - n_train
ds_train,ds_valid = random_split(data,[n_train,n_valid])

dl_train,dl_valid = DataLoader(ds_train,batch_size = 8),DataLoader(ds_valid,batch_size = 8)


In [4]:
for features,labels in dl_train:
    print(features.shape)
    print(labels.shape)
    break

torch.Size([8, 32])
torch.Size([8])


# prepare model

In [5]:
class Net(nn.Module): 
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(32,16)
        self.sigmoid1 = nn.Sigmoid()
        self.linear2 = nn.Linear(16,1)
        self.sigmoid2 = nn.Sigmoid()
        
    def forward(self,x):
        x = self.linear1(x)
        y = self.sigmoid1(x)
        x = self.linear2(x)
        y = self.sigmoid2(x)
        return y
        
model = Net()

In [6]:
from torchkeras import summary
summary(model,input_shape= (1000,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1             [-1, 1000, 16]             528
           Sigmoid-2             [-1, 1000, 16]               0
            Linear-3              [-1, 1000, 1]              17
           Sigmoid-4              [-1, 1000, 1]               0
Total params: 545
Trainable params: 545
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.122070
Forward/backward pass size (MB): 0.259399
Params size (MB): 0.002079
Estimated Total Size (MB): 0.383549
----------------------------------------------------------------


In [7]:
model.optimizer = torch.optim.SGD(model.parameters(),lr = 0.01)
model.loss_func = torch.nn.BCELoss()
model.metric_func = lambda y_pred,y_true: torch.mean(1-torch.abs(y_true-y_pred))
model.metric_name = "auc"

In [22]:
def train_step(model,features,labels):
    
    # 训练模式，dropout层发生作用
    model.train()
    
    # 梯度清零
    model.optimizer.zero_grad()
    
    # 正向传播求损失
    predictions = model(features)
    loss = model.loss_func(predictions,labels)
    metric = model.metric_func(predictions,labels)

    # 反向传播求梯度
    loss.backward()
    model.optimizer.step()

    return loss.item(),metric.item()

def valid_step(model,features,labels):
    
    # 预测模式，dropout层不发生作用
    model.eval()
    # 关闭梯度计算
    with torch.no_grad():
        predictions = model(features)
        loss = model.loss_func(predictions,labels)
        metric = model.metric_func(predictions,labels)
    
    return loss.item(), metric.item()


# 测试train_step效果
features,labels = next(iter(dl_train))
labels = labels.reshape(8,1)
train_step(model,features,labels)

(-17.6319637298584, -10.993547439575195)

In [9]:
labels.reshape(8,1).shape

torch.Size([8, 1])

In [27]:
def train_model(model,epochs,dl_train,dl_valid,log_step_freq):

    metric_name = model.metric_name
    print("Start Training...")
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("=========="*8 + "%s"%nowtime)

    for epoch in range(1,epochs+1):  

        # 1，训练循环-------------------------------------------------
        loss_sum = 0.0
        metric_sum = 0.0
        step = 1

        for step, (features,labels) in enumerate(dl_train, 1):

            loss,metric = train_step(model,features,labels)

            # 打印batch级别日志
            loss_sum += loss
            metric_sum += metric
            if step%log_step_freq == 0:   
                print(("[step = %d] loss: %.3f, "+metric_name+": %.3f") %
                      (step, loss_sum/step, metric_sum/step))

        # 2，验证循环-------------------------------------------------
        val_loss_sum = 0.0
        val_metric_sum = 0.0
        val_step = 1

        for val_step, (features,labels) in enumerate(dl_valid, 1):

            val_loss,val_metric = valid_step(model,features,labels)

            val_loss_sum += val_loss
            val_metric_sum += val_metric

        # 3，记录日志-------------------------------------------------
        info = (epoch, loss_sum/step, metric_sum/step, 
                val_loss_sum/val_step, val_metric_sum/val_step)

        # 打印epoch级别日志
        print(("\nEPOCH = %d, loss = %.3f,"+ metric_name + \
              "  = %.3f, val_loss = %.3f, "+"val_"+ metric_name+" = %.3f") 
              %info)
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print("\n"+"=========="*8 + "%s"%nowtime)

    print('Finished Training...')
    
    return info

In [29]:
epochs = 20

history = train_model(model,epochs,dl_train,dl_valid,log_step_freq = 50)

Start Training...
[step = 50] loss: -497.770, auc: -11.168
[step = 100] loss: -495.062, auc: -11.152

EPOCH = 1, loss = -495.062,auc  = -11.152, val_loss = -491.924, val_auc = -11.085

[step = 50] loss: -497.961, auc: -11.168
[step = 100] loss: -495.255, auc: -11.152

EPOCH = 2, loss = -495.255,auc  = -11.152, val_loss = -492.114, val_auc = -11.085

[step = 50] loss: -498.147, auc: -11.168
[step = 100] loss: -495.445, auc: -11.152

EPOCH = 3, loss = -495.445,auc  = -11.152, val_loss = -492.300, val_auc = -11.085

[step = 50] loss: -498.329, auc: -11.168
[step = 100] loss: -495.630, auc: -11.152

EPOCH = 4, loss = -495.630,auc  = -11.152, val_loss = -492.482, val_auc = -11.085

[step = 50] loss: -498.507, auc: -11.168
[step = 100] loss: -495.811, auc: -11.152

EPOCH = 5, loss = -495.811,auc  = -11.152, val_loss = -492.660, val_auc = -11.085

[step = 50] loss: -498.682, auc: -11.168
[step = 100] loss: -495.989, auc: -11.152

EPOCH = 6, loss = -495.989,auc  = -11.152, val_loss = -492.835,

In [30]:
history

(20,
 -498.15185546875,
 -11.152167596817016,
 -494.96343383789065,
 -11.0851762008667)