In [11]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

import time
import logging
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [12]:
# Keep every SUBSAMPLE_STRIDE-th sample to shrink the training set.
SUBSAMPLE_STRIDE = 5
BATCH_SIZE = 65536
NUM_WORKERS = 8

# np.ascontiguousarray copies the strided subsample, so the full
# (un-subsampled) array loaded from disk can be freed instead of being kept
# alive as the base of a view.
rate = torch.from_numpy(
    np.ascontiguousarray(np.load('./data/elementary_reaction_rates.npy')[::SUBSAMPLE_STRIDE, :])
)
x_dot = torch.from_numpy(
    np.ascontiguousarray(np.load('./data/net_production_rates.npy')[::SUBSAMPLE_STRIDE, :])
)
mat = torch.from_numpy(np.load('./data/stoichiometric_matrix.npy')).to('cuda')

# Shapes the model ((rate * w) @ mat) and the loss (reduced_x_dot - x_dot) rely on.
assert rate.shape[0] == x_dot.shape[0], 'rate and x_dot must have the same number of samples'
assert rate.shape[1] == mat.shape[0], 'rate columns must match stoichiometric matrix rows'
assert x_dot.shape[1] == mat.shape[1], 'x_dot columns must match stoichiometric matrix columns'

dataset = TensorDataset(rate, x_dot)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Keep the worker processes alive between epochs instead of re-spawning
    # NUM_WORKERS processes at the start of every one of the training epochs.
    persistent_workers=NUM_WORKERS > 0,
    # Page-locked host memory speeds up the per-batch .to('cuda') copy.
    pin_memory=True,
)

In [25]:
def plot_record(log, epoch, path_prefix='./figs/loss_record_9'):
    """Plot the first ``epoch`` rows of the loss log and save them as PDF and SVG.

    Draws three side-by-side log-scale panels from columns 0, 1 and 2 of
    ``log`` (regression, sparse and total loss, in the order the training cell
    fills them).

    Args:
        log: 2-D array-like (NumPy array or CPU torch tensor without grad) with
            at least three columns; row ``i`` holds the losses of epoch ``i``.
        epoch: number of leading rows to plot, i.e. ``log[:epoch]``.
        path_prefix: output path without extension; ``<prefix>.pdf`` and
            ``<prefix>.svg`` are written. Defaults to the original location.
    """
    history = np.asarray(log[:epoch])
    # Use the real slice length so x and y always agree, even if epoch > len(log).
    epochs = np.arange(history.shape[0])

    panel_names = ('Regression Loss', 'Sparse Loss', 'Total Loss')
    fig, axes = plt.subplots(1, 3, figsize=(24, 6))
    for column, (ax, name) in enumerate(zip(axes, panel_names)):
        ax.plot(epochs, history[:, column])
        ax.set_title(name)
        ax.set_yscale('log')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(name)

    # Same layout treatment as plot_weight, so labels don't overlap.
    fig.tight_layout()
    for ext in ('pdf', 'svg'):
        fig.savefig(f'{path_prefix}.{ext}')
    # Close explicitly: this runs every checkpoint inside a long training loop.
    plt.close(fig)

In [26]:
def plot_weight(weight):
    """Save a bar chart of the per-reaction sparse weights as PDF and SVG.

    One bar per leading-axis entry of ``weight``; the y-axis is fixed to
    [0, 1], the range of the sigmoid gates the training cell passes in.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    reaction_index = np.arange(weight.shape[0])
    ax.bar(reaction_index, weight)
    ax.set_xlabel('Reaction Index')
    ax.set_ylabel('Sparse Weight')
    ax.set_ylim([0., 1.])

    fig.tight_layout()
    for out_path in ('./figs/weight_9.pdf', './figs/weight_9.svg'):
        fig.savefig(out_path)
    plt.close(fig)

In [27]:
class sparse_model(torch.nn.Module):
    """Gate each column of ``rate`` by a learnable sigmoid weight, then project with ``mat``.

    ``forward`` computes ``(rate * sigmoid(sparse_weight)) @ mat``, where
    ``sparse_weight`` holds one logit per row of ``mat``.

    Args:
        mat: 2-D tensor whose first dimension equals the last dimension of the
            ``rate`` passed to ``forward``. Presumably the
            (n_reactions, n_species) stoichiometric matrix from the data cell;
            confirm against the data files.
    """

    def __init__(self, mat):
        super().__init__()
        # ``mat`` is fixed data, not a trainable quantity. Registering it as a
        # buffer keeps it on the module's device (.to / DataParallel replication)
        # and in state_dict under the same key 'mat', while autograd no longer
        # computes and accumulates an unused gradient for it on every backward.
        self.register_buffer('mat', mat)
        # One logit per row of ``mat``, drawn from U(-1, 1), so every gate
        # starts within sigmoid's (0.27, 0.73) band around 0.5.
        self.sparse_weight = torch.nn.Parameter(torch.empty(mat.shape[0]).uniform_(-1, 1))

    def forward(self, rate):
        """Return ``(rate * sigmoid(sparse_weight)) @ mat``."""
        # ``*`` and ``@`` share precedence and associate left-to-right; the
        # parentheses just make the original evaluation order explicit.
        return (rate * torch.sigmoid(self.sparse_weight)) @ self.mat

In [28]:
# Seed torch so the sparse_weight initialisation and the DataLoader shuffle
# order are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = sparse_model(mat).to('cuda')
# device_ids must start from 0 (0 is the 1st GPU, 1 is the 2nd GPU, ...); the
# indices refer to the GPUs made visible by CUDA_VISIBLE_DEVICES in the first cell.
model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])

# Only the gates are optimized; the stoichiometric matrix stays fixed.
optimizer = torch.optim.Adam(
    [model.module.sparse_weight],
    lr=5e-3,
    betas=(0.9, 0.99),
    weight_decay=0,
    eps=1e-15,
    amsgrad=False,
)

num_epochs = 2000
lambda_1 = 0.9          # weight of the sparsity term; (1 - lambda_1) weights the regression term
checkpoint_every = 100  # save weights / loss log and redraw figures every N epochs
# Columns: regression, sparse, total loss, averaged over the batches of each epoch.
loss_record = torch.zeros((num_epochs, 3))
num_batches = max(len(data_loader), 1)
start_optim = time.time()
for epoch in range(num_epochs):
    # Distinct names so the full dataset tensors `rate` / `x_dot` from the data
    # cell are not overwritten by the last GPU batch.
    for batch_idx, (rate_batch, x_dot_batch) in enumerate(data_loader):
        # Reset gradients before every step. Zeroing only once per epoch made
        # each step reuse the accumulated gradients of all earlier batches.
        optimizer.zero_grad()
        rate_batch = rate_batch.to('cuda')
        x_dot_batch = x_dot_batch.to('cuda')
        reduced_x_dot = model(rate_batch)  # (rate * sigmoid(sparse_weight)) @ mat

        # Each sample's residual is scaled by the largest |x_dot| in its row.
        # NOTE(review): an all-zero x_dot row would divide by zero here.
        row_scale = x_dot_batch.abs().max(dim=1).values.unsqueeze(1)
        regression_loss = ((reduced_x_dot - x_dot_batch) / row_scale).norm(p=2, dim=1).mean() * 100
        sparse_loss = torch.sigmoid(model.module.sparse_weight).mean() * 100
        loss = (1 - lambda_1) * regression_loss + lambda_1 * sparse_loss

        loss_record[epoch, 0] += regression_loss.item()
        loss_record[epoch, 1] += sparse_loss.item()
        loss_record[epoch, 2] += loss.item()

        loss.backward()
        optimizer.step()

        print(f"Epoch : {epoch:3}, batch : {batch_idx+1:2}, "
              f"loss : {loss.item():>10.5f}, "
              f"regression loss : {regression_loss.item():>10.5f}, "
              f"sparse loss : {sparse_loss.item():>10.5f}, "
              f"time : {time.time()-start_optim:>10.5f}.")

    # Store the per-epoch mean rather than a sum that grows with the batch count.
    loss_record[epoch] /= num_batches

    if (epoch + 1) % checkpoint_every == 0:
        gate = torch.sigmoid(model.module.sparse_weight).detach().cpu().numpy()
        np.save('./data/sparse_weight_9.npy', gate)
        np.save('./data/loss_record_9.npy', loss_record.numpy())
        plot_weight(gate)
        # plot_record plots log[:epoch]; pass epoch + 1 so the epoch that just
        # finished is included in the figure.
        plot_record(loss_record, epoch + 1)

Epoch :   0, batch :  1, loss :   53.36107, regression loss :   77.67509, sparse loss :   50.65952, time :    3.43388.
Epoch :   1, batch :  1, loss :   53.24986, regression loss :   77.54098, sparse loss :   50.55084, time :    7.00328.
Epoch :   2, batch :  1, loss :   53.13866, regression loss :   77.40723, sparse loss :   50.44216, time :   10.42152.
Epoch :   3, batch :  1, loss :   53.02751, regression loss :   77.27373, sparse loss :   50.33349, time :   13.73198.
Epoch :   4, batch :  1, loss :   52.91641, regression loss :   77.14036, sparse loss :   50.22486, time :   16.93625.
Epoch :   5, batch :  1, loss :   52.80532, regression loss :   77.00706, sparse loss :   50.11625, time :   20.14756.
Epoch :   6, batch :  1, loss :   52.69430, regression loss :   76.87392, sparse loss :   50.00767, time :   23.41027.
Epoch :   7, batch :  1, loss :   52.58330, regression loss :   76.74098, sparse loss :   49.89912, time :   26.73761.
Epoch :   8, batch :  1, loss :   52.47234, regr