In [1]:
import torch
import torch.nn as nn
import os
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datetime import datetime
torch.set_default_dtype(torch.float64)

In [2]:
class TrainData():
    def __init__(self, data_path, nsub=3, length=5):
        self.data = torch.load(data_path)
        self.nsub = nsub
        self.length = length

    def getData(self,):
        rho = torch.clone(self.data['rho'])
        ans = torch.clone(self.data['ans'])
        output = torch.clone(self.data['output_a'])
        
        return (rho, ans, output)

In [3]:
class CustomLoss(nn.Module):
    def __int__(self, ):
        super(CustomLoss,self).__init__()
    def forward(self, a, b):
        loss = torch.norm(a-b,p = 'fro')
        return loss

In [6]:
class Predictor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Predictor,self).__init__()
        self.input = nn.Linear(input_dim,output_dim)
        self.hidden =  nn.ModuleList([ nn.Sequential( nn.Linear(output_dim, output_dim) ) for i in range(4) ])
        self.pred = nn.Linear(output_dim,output_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
    def forward(self, x):
        x0 = self.input(x)
        x = x0
        x = x / torch.norm(x)
        for layer in self.hidden:
            x = self.tanh(layer(x) + x0)
            x = x / torch.norm(x)
        x = self.pred(x)
        x = self.pred(x)
        return x
    
class ConvBlock(nn.Module):
    def __init__(self, cin, cout):
        super().__init__() # necessary
        self.conv = nn.Conv2d(cin, cout, (3, 3), padding=1)
        # self.bn = nn.BatchNorm2d(cout)
        # self.relu = nn.LeakyReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class Net(nn.Module):
    def __init__(self,nsub,length):
        super().__init__()
        self.features = nn.Sequential(
            ConvBlock(1,64),
            ConvBlock(64,64),
            ConvBlock(64,64),
            ConvBlock(64,64),
            ConvBlock(64,64),
        )

        self.regression = nn.Sequential(
            nn.Flatten(),
            Predictor((length**nsub)**2,length*nsub),
        )

    def forward(self, x):
        # features = self.features(x)
        y = self.regression(x)
        return y

In [7]:
class Trainer:
    def __init__(self, epochs=10000, lr=1e-2, nsub=3, length=5):
        self.log_dir = './runs/'+datetime.now().strftime('%b.%d %H-%M-%S')
        os.makedirs(self.log_dir, exist_ok=True)
        self.data = TrainData(r'C:\workspace\1_Michael\Research\NN_solveQuantumProblem\DataSet\rankone_R3_64_1.pt').getData()
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.nsub = nsub
        self.length = length
        self.model = Net(nsub,length).to(self.device)
        self.criterion = CustomLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.95, patience=8,eps = 10e-32, verbose=True)
        self.max_epoch = epochs

    def run(self):
        training_result_dir = self.log_dir + '/training_result'
        os.makedirs(training_result_dir)
        metrics = {'train_loss': []}
        for self.epoch in range(self.max_epoch): # epochs
            train_loss = self.train() # train 1 epoch
            print('lr:',self.get_lr(self.optimizer))
            print(f'Epoch {self.epoch:03d}:')
            print('train loss:', train_loss)
            metrics['train_loss'].append(train_loss)

            if torch.tensor(metrics['train_loss']).argmin() == self.epoch:
                torch.save(self.model.state_dict(), str(training_result_dir + '/model.pth'))
        # fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=100)
        # ax.set_title('Loss')
        # ax.plot(range(self.epoch + 1), metrics['train_loss'], label='Train')
        # ax.plot(range(self.epoch + 1), metrics['valid_loss'], label='Valid')
        # ax.legend()
        # plt.show()
        # fig.savefig(str(training_result_dir / 'metrics.jpg')) 
        # plt.close()
    def process(self,output):
        output = output.view(self.nsub, self.length)
        for i in range(output.shape[0]):
            a = torch.clone(output[i]).view(1,self.length)
            A = torch.matmul(a.t(),a)
            if i == 0:
                kron = A
            else:
                kron = torch.kron(kron,A)

        return kron
    def train(self):
        loss_steps = []
        rho,ans,output_a = self.data
        rho = rho.to(self.device)
        ans = ans.to(self.device)
        self.optimizer.zero_grad()
        out = self.model(rho)
        out = self.process(out)
        loss = self.criterion(out, ans)
        loss.backward()
        # print(loss.item())
        self.optimizer.step()
        loss_steps.append(loss.detach().item())

        avg_loss = sum(loss_steps) / len(loss_steps)
        self.scheduler.step(avg_loss) #加入scheduler
        return avg_loss

    def get_lr(self,optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

print(Trainer().run())

lr: 0.01
Epoch 000:
train loss: 0.9999210476883259
lr: 0.01
Epoch 001:
train loss: 0.9999839479524231
lr: 0.01
Epoch 002:
train loss: 0.9997117832660655
lr: 0.01
Epoch 003:
train loss: 0.9985071008292374
lr: 0.01
Epoch 004:
train loss: 0.9948215549363022
lr: 0.01
Epoch 005:
train loss: 0.9853308059684515
lr: 0.01
Epoch 006:
train loss: 0.9629633253540503
lr: 0.01
Epoch 007:
train loss: 0.9132382628628949
lr: 0.01
Epoch 008:
train loss: 0.8203938037205513
lr: 0.01
Epoch 009:
train loss: 0.7421777555414035
lr: 0.01
Epoch 010:
train loss: 0.6485674064336671
lr: 0.01
Epoch 011:
train loss: 0.48510640058614646
lr: 0.01
Epoch 012:
train loss: 0.3379050462308718
lr: 0.01
Epoch 013:
train loss: 0.26063687269941477
lr: 0.01
Epoch 014:
train loss: 0.19920713674140525
lr: 0.01
Epoch 015:
train loss: 0.19000004992278202
lr: 0.01
Epoch 016:
train loss: 0.19839083862496706
lr: 0.01
Epoch 017:
train loss: 0.21579638788262323
lr: 0.01
Epoch 018:
train loss: 0.2259403935147001
lr: 0.01
Epoch 019:
train