In [3]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.optim as optim
from torch.autograd import Variable
import os
from tqdm import tqdm
import sys
sys.path.append("/home/wuwenjun/jupyter_code/Shannon/AlphaNet/packages/")

In [4]:
x = torch.randn(4000000,3,108)
y = torch.randn(4000000,1)

In [5]:
train_dataset = Data.TensorDataset(x, y)
batch_size = 1024
train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True
)

In [6]:
class AlphaNet_LSTM_V1(nn.Module):
    def __init__(self, factor_num, fully_connect_layer_neural):
        super(AlphaNet_LSTM_V1, self).__init__()
        self.fc1_neuron = factor_num  # 108
        self.fc2_neuron = fully_connect_layer_neural  # 32

        # Layer
        self.batch = torch.nn.BatchNorm1d(self.fc1_neuron)
        self.lstm = nn.LSTM(self.fc1_neuron, self.fc2_neuron, 5, batch_first=True, bidirectional=True)
        self.batch2 = torch.nn.BatchNorm1d(self.fc2_neuron *2)
        self.dropout = nn.Dropout(0.3)
        #         self.relu = nn.ReLU()
        self.out = nn.Linear(self.fc2_neuron * 2, 1)

    def forward(self, x):
        x = torch.transpose(x, 1, 2)
        x = self.batch(x)
        x = torch.transpose(x, 1, 2)
        r_out, (hn, cn) = self.lstm(x)  # hn.shape: torch.Size([4, 512, 30])
        r_out = r_out[:, -1]
        # hn = hn[-1, :, :]  # torch.Size([512, 30])
        r_out = self.batch2(r_out)  # torch.Size([512, 30])
        #         hn = self.relu(hn)
        r_out = self.dropout(r_out)
        y_pred = self.out(r_out)
        return y_pred

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
alphanet = AlphaNet_LSTM_V1(108, 64)
# if torch.cuda.device_count() > 1:
#     alphanet  = torch.nn.DataParallel(alphanet,device_ids=[i for i in range(torch.cuda.device_count())]) # 默认使用所有的device_ids 
alphanet = alphanet.to(device)
total_length = 4000000

print(alphanet)

LR = 0.01
loss_function = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
epoch_num = 50
loss_list = []

min_loss = float("inf")
for epoch in tqdm(range(epoch_num)):
    total_loss = 0
    for _, (inputs, outputs) in enumerate(train_loader):
        inputs = Variable(inputs).float().to(device)
        outputs = Variable(outputs).float().to(device)
        optimizer.zero_grad() # noticed:  the grad return to zero before starting the loop
        
        # forward + backward +update
        pred = alphanet(inputs)
        pred = pred.to(device)
        loss = loss_function(pred, outputs)
        loss.backward()
        optimizer.step()
        
#         lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        total_loss += loss.item()
    total_loss = total_loss * batch_size / total_length
    print('Epoch: ', epoch + 1, ' loss: ', total_loss)
    loss_list.append(total_loss)

  0%|          | 0/50 [00:00<?, ?it/s]

AlphaNet_LSTM(
  (batch): BatchNorm1d(108, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lstm): LSTM(108, 30, num_layers=2, batch_first=True, bidirectional=True)
  (batch2): BatchNorm1d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=30, out_features=1, bias=True)
)


  2%|▏         | 1/50 [00:36<29:40, 36.34s/it]

Epoch:  1  loss:  1.0020837746582032


  4%|▍         | 2/50 [01:07<26:48, 33.52s/it]

Epoch:  2  loss:  1.0013881094207764


  6%|▌         | 3/50 [01:37<24:54, 31.80s/it]

Epoch:  3  loss:  1.00134922555542


  8%|▊         | 4/50 [02:07<23:53, 31.16s/it]

Epoch:  4  loss:  1.0013752050476075


 10%|█         | 5/50 [02:38<23:15, 31.00s/it]

Epoch:  5  loss:  1.0013266999511719


 12%|█▏        | 6/50 [03:08<22:31, 30.72s/it]

Epoch:  6  loss:  1.0013184588623047


 14%|█▍        | 7/50 [03:37<21:39, 30.21s/it]

Epoch:  7  loss:  1.00127641355896


 16%|█▌        | 8/50 [04:06<20:53, 29.85s/it]

Epoch:  8  loss:  1.0013046231842042


 18%|█▊        | 9/50 [04:35<20:13, 29.60s/it]

Epoch:  9  loss:  1.0012859051513672


 20%|██        | 10/50 [05:04<19:36, 29.40s/it]

Epoch:  10  loss:  1.0013068291320801


 22%|██▏       | 11/50 [05:34<19:02, 29.31s/it]

Epoch:  11  loss:  1.001306768081665


 24%|██▍       | 12/50 [06:03<18:31, 29.26s/it]

Epoch:  12  loss:  1.0012937840118408


 26%|██▌       | 13/50 [06:32<18:05, 29.34s/it]

Epoch:  13  loss:  1.0013226898498535


 28%|██▊       | 14/50 [07:02<17:38, 29.42s/it]

Epoch:  14  loss:  1.0013162078704834


 30%|███       | 15/50 [07:32<17:20, 29.72s/it]

Epoch:  15  loss:  1.0013025900268555


 32%|███▏      | 16/50 [08:02<16:55, 29.86s/it]

Epoch:  16  loss:  1.001305619796753


 34%|███▍      | 17/50 [08:33<16:33, 30.11s/it]

Epoch:  17  loss:  1.0012660859680176


 36%|███▌      | 18/50 [09:04<16:06, 30.21s/it]

Epoch:  18  loss:  1.0013224674072265


 38%|███▊      | 19/50 [09:34<15:34, 30.15s/it]

Epoch:  19  loss:  1.0013520905303954


 40%|████      | 20/50 [10:03<15:01, 30.05s/it]

Epoch:  20  loss:  1.0012619309234618


 42%|████▏     | 21/50 [10:33<14:28, 29.94s/it]

Epoch:  21  loss:  1.001280602508545


 44%|████▍     | 22/50 [11:03<14:00, 30.02s/it]

Epoch:  22  loss:  1.00134472164917


 46%|████▌     | 23/50 [11:33<13:29, 29.99s/it]

Epoch:  23  loss:  1.0012902372894288


 48%|████▊     | 24/50 [12:03<12:56, 29.85s/it]

Epoch:  24  loss:  1.0012722888031005


 50%|█████     | 25/50 [12:33<12:26, 29.87s/it]

Epoch:  25  loss:  1.0013096490631104


 52%|█████▏    | 26/50 [13:02<11:54, 29.77s/it]

Epoch:  26  loss:  1.0012907608642578


 54%|█████▍    | 27/50 [13:32<11:23, 29.72s/it]

Epoch:  27  loss:  1.001314271347046


 56%|█████▌    | 28/50 [14:02<10:55, 29.79s/it]

Epoch:  28  loss:  1.0012949231262207


 58%|█████▊    | 29/50 [14:32<10:30, 30.01s/it]

Epoch:  29  loss:  1.0013233465576172


 60%|██████    | 30/50 [15:02<09:59, 29.97s/it]

Epoch:  30  loss:  1.0013038831939698


 62%|██████▏   | 31/50 [15:32<09:28, 29.93s/it]

Epoch:  31  loss:  1.0013081331176759


 64%|██████▍   | 32/50 [16:02<08:57, 29.87s/it]

Epoch:  32  loss:  1.0013298102416992


 64%|██████▍   | 32/50 [16:12<09:07, 30.40s/it]


KeyboardInterrupt: 