In [1]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.optim as optim
from torch.autograd import Variable
import os
from tqdm import tqdm
import sys
sys.path.append("/home/wuwenjun/jupyter_code/Shannon/AlphaNet/packages/")

In [2]:
# Synthetic stand-in data used to exercise the training pipeline: both the
# features and the targets are independent standard-normal noise, so there is
# no learnable signal and the best achievable MSE is about 1.0.
N_SAMPLES = 4_000_000  # number of (window, target) pairs
SEQ_LEN = 3            # time steps per sample
FACTOR_NUM = 108       # features per time step
DATA_SEED = 0          # fixed seed so re-running the notebook yields the same data

# A dedicated generator keeps the data reproducible without touching the
# global RNG state that other cells (e.g. weight initialisation) draw from.
_data_rng = torch.Generator().manual_seed(DATA_SEED)

x = torch.randn(N_SAMPLES, SEQ_LEN, FACTOR_NUM, generator=_data_rng)  # (samples, time, factors)
y = torch.randn(N_SAMPLES, 1, generator=_data_rng)                    # (samples, 1) regression target

In [3]:
# Mini-batch size; also read by the training cell when reporting the loss.
batch_size = 1024

# Pair every feature window in `x` with its target in `y`.
train_dataset = Data.TensorDataset(x, y)

# Shuffled loader: 16 worker processes assemble batches and pinned host
# memory is requested for the batches they return.
train_loader = Data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [4]:
class AlphaNet_LSTM_V1(nn.Module):
    """Bidirectional-LSTM regressor over short windows of factor values.

    Pipeline: per-factor batch normalisation -> 5-layer bidirectional LSTM ->
    output at the last time step -> batch normalisation -> dropout -> linear
    head producing one value per sample.

    Args:
        factor_num: Number of features per time step (input size of the LSTM
            and channel count of the first batch norm).
        fully_connect_layer_neural: LSTM hidden size per direction; the
            layers after the LSTM work on twice this width.
    """

    def __init__(self, factor_num, fully_connect_layer_neural):
        super().__init__()
        self.fc1_neuron = factor_num
        self.fc2_neuron = fully_connect_layer_neural

        # Forward and backward hidden states are concatenated by the LSTM.
        lstm_out_width = 2 * self.fc2_neuron

        # Submodules are created in a fixed order so that seeded weight
        # initialisation stays reproducible.
        self.batch = nn.BatchNorm1d(self.fc1_neuron)
        self.lstm = nn.LSTM(
            input_size=self.fc1_neuron,
            hidden_size=self.fc2_neuron,
            num_layers=5,
            batch_first=True,
            bidirectional=True,
        )
        self.batch2 = nn.BatchNorm1d(lstm_out_width)
        self.dropout = nn.Dropout(p=0.3)
        self.out = nn.Linear(lstm_out_width, 1)

    def forward(self, x):
        """Map a batch of windows to one prediction per sample.

        Args:
            x: Tensor of shape (batch, seq_len, factor_num).

        Returns:
            Tensor of shape (batch, 1).
        """
        # BatchNorm1d normalises over dim 1, so move the factor axis there
        # and restore the (batch, seq_len, factor_num) layout afterwards.
        normed = self.batch(x.transpose(1, 2)).transpose(1, 2)

        # seq_out: (batch, seq_len, 2 * hidden); the final states are unused.
        seq_out, _ = self.lstm(normed)

        # NOTE(review): at the last time step the backward direction has only
        # processed one element; if a full-sequence summary of both
        # directions was intended, the final hidden states may be the better
        # choice -- confirm before changing.
        last_step = seq_out[:, -1]

        # No non-linearity is applied between the norm and the linear head.
        features = self.dropout(self.batch2(last_step))
        return self.out(features)

In [6]:
# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
alphanet = AlphaNet_LSTM_V1(108, 64)  # 108 factors in, 64 hidden units per LSTM direction
# Multi-GPU training is currently disabled. The snippet below would wrap the
# model in DataParallel across every visible device (which is also what
# DataParallel does by default when device_ids is omitted).
# if torch.cuda.device_count() > 1:
#     alphanet  = torch.nn.DataParallel(alphanet,device_ids=[i for i in range(torch.cuda.device_count())])
alphanet = alphanet.to(device)
# Take the sample count from the loader instead of repeating the dataset size
# as a literal, so the reported loss stays correct if the data changes.
total_length = len(train_loader.dataset)

print(alphanet)

LR = 0.01
loss_function = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(alphanet.parameters(), lr=LR)
epoch_num = 50
loss_list = []  # mean training loss of every completed epoch

# Initialised here but never updated in this cell.
min_loss = float("inf")

# Make sure dropout / batch-norm are in training mode even if the model was
# switched to eval() before this cell is re-run.
alphanet.train()
for epoch in tqdm(range(epoch_num)):
    total_loss = 0.0
    for inputs, outputs in train_loader:
        inputs = inputs.float().to(device)
        outputs = outputs.float().to(device)
        # Gradients accumulate by default, so clear them before every step.
        optimizer.zero_grad()

        # forward + backward + update
        pred = alphanet(inputs)
        loss = loss_function(pred, outputs)
        loss.backward()
        optimizer.step()

        # MSELoss returns the batch mean. Weight it by the real batch size so
        # a shorter final batch does not skew the epoch average.
        total_loss += loss.item() * inputs.size(0)
    # Exact mean loss per sample over the whole epoch.
    total_loss = total_loss / total_length
    print('Epoch: ', epoch + 1, ' loss: ', total_loss)
    loss_list.append(total_loss)

  0%|          | 0/50 [00:00<?, ?it/s]

AlphaNet_LSTM_V1(
  (batch): BatchNorm1d(108, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lstm): LSTM(108, 64, num_layers=5, batch_first=True, bidirectional=True)
  (batch2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=128, out_features=1, bias=True)
)


  2%|▏         | 1/50 [00:40<33:03, 40.49s/it]

Epoch:  1  loss:  0.9992210764007569


  4%|▍         | 2/50 [01:39<41:16, 51.60s/it]

Epoch:  2  loss:  0.9990657085418702


  6%|▌         | 3/50 [02:45<45:28, 58.06s/it]

Epoch:  3  loss:  0.9991338519744873


  8%|▊         | 4/50 [03:51<46:53, 61.16s/it]

Epoch:  4  loss:  0.9990270543060302


 10%|█         | 5/50 [04:57<47:03, 62.75s/it]

Epoch:  5  loss:  0.9990491723022461


 12%|█▏        | 6/50 [06:02<46:40, 63.64s/it]

Epoch:  6  loss:  0.9990332359008789


 14%|█▍        | 7/50 [07:07<45:57, 64.12s/it]

Epoch:  7  loss:  0.9991233188781738


 16%|█▌        | 8/50 [08:12<45:00, 64.31s/it]

Epoch:  8  loss:  0.9990963634185791


 18%|█▊        | 9/50 [09:17<44:03, 64.48s/it]

Epoch:  9  loss:  0.9990708125915527


 20%|██        | 10/50 [10:21<43:01, 64.54s/it]

Epoch:  10  loss:  0.9990506929473877


 22%|██▏       | 11/50 [11:26<41:57, 64.56s/it]

Epoch:  11  loss:  0.9990495448455811


 24%|██▍       | 12/50 [12:31<40:54, 64.59s/it]

Epoch:  12  loss:  0.9990593813171387


 24%|██▍       | 12/50 [13:33<42:56, 67.80s/it]


KeyboardInterrupt: 