<a href="https://colab.research.google.com/github/ShinyaKatoh/Trial_Polarity/blob/main/model_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/ShinyaKatoh/Trial_Polarity
!pip install einops
!pip install torchinfo

Cloning into 'Trial_Polarity'...
remote: Enumerating objects: 16, done.
remote: Total 16 (delta 0), reused 0 (delta 0), pack-reused 16 (from 1)
Receiving objects: 100% (16/16), 46.01 MiB | 15.35 MiB/s, done.
Resolving deltas: 100% (5/5), done.
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
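Colab's local disk is reset when the runtime ends, so anything meant to outlive the session has to be written under the mounted Drive. The notebook later points `save_dir` at `/content/drive/MyDrive/model1` and passes it to `util.fit`, whose saving code is not shown here. As a minimal sketch of one common PyTorch pattern (saving only the `state_dict`), assuming nothing about what `util.fit` actually writes: the helper names `save_checkpoint` / `load_checkpoint` and the file name `model_best.pt` are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, save_dir: str, name: str = "model_best.pt") -> str:
    """Write the model's parameters (its state_dict) to save_dir/name and return the path."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, name)
    torch.save(model.state_dict(), path)
    return path


def load_checkpoint(model: nn.Module, path: str, device: str = "cpu") -> nn.Module:
    """Load parameters written by save_checkpoint into a model with the same architecture."""
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    return model


if __name__ == "__main__":
    # Round trip through a temporary directory instead of Google Drive
    with tempfile.TemporaryDirectory() as tmp:
        original = nn.Linear(128, 3)
        restored = load_checkpoint(nn.Linear(128, 3), save_checkpoint(original, tmp))
        print(torch.equal(original.weight, restored.weight))
```

Reloading requires building the same architecture first (for example `model1.Model()`) and then calling `load_checkpoint` on it; `map_location` lets a checkpoint written on a GPU runtime be loaded on a CPU-only one.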


In [4]:
import sys
sys.path.append('/content/Trial_Polarity')
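Appending the cloned repository to `sys.path` is what lets the next cell import `util` and the model modules. The sources of `util.MyDataset` and `util.fit` are not reproduced in this notebook, so the following is only a minimal sketch of what a paired-tensor dataset and a per-epoch train/validation pass typically look like. It assumes each `.pt` file holds a single tensor, with inputs of shape `(N, 1, 128)` and integer class labels of shape `(N,)`. `TensorPairDataset` and `run_epoch` are hypothetical names, and the loss is the mean `nn.CrossEntropyLoss` per batch on raw logits, so its scale need not match the loss values `util.fit` prints.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class TensorPairDataset(Dataset):
    """Hypothetical stand-in for util.MyDataset: pairs each input sample with its label.

    Assumes data has shape (N, 1, 128) and labels are integer class indices of shape (N,).
    """

    def __init__(self, data: torch.Tensor, labels: torch.Tensor):
        if len(data) != len(labels):
            raise ValueError("data and labels must contain the same number of samples")
        self.data = data
        self.labels = labels

    @classmethod
    def from_files(cls, data_path: str, label_path: str) -> "TensorPairDataset":
        # Assumes each .pt file was written with torch.save(tensor, path)
        return cls(torch.load(data_path), torch.load(label_path))

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        return self.data[idx], self.labels[idx]


def run_epoch(model: nn.Module, loader: DataLoader, device, optimizer=None):
    """One pass over loader: trains when an optimizer is given, otherwise only evaluates.

    Returns (mean per-batch cross-entropy loss, accuracy over all samples).
    """
    criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
    training = optimizer is not None
    model.train(training)  # switches dropout/batch-norm behavior where present
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):  # no gradient tracking during validation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            seen += y.size(0)
    return total_loss / len(loader), correct / seen
```

A training run would then call `run_epoch(model, train_loader, device, optimizer)` followed by `run_epoch(model, valid_loader, device)` once per epoch, with the validation loader built with `shuffle=False`, since batch order does not affect the validation metrics.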

In [5]:
import torch
import torch.nn as nn
from torchinfo import summary
from torch.utils.data import DataLoader, Dataset

import torch.nn.functional as F
import torch.nn.init as init

import os
import glob
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt

import util

In [6]:
# Define the model architecture
#
# model1 : Neural Network (fully connected)
#
# model2 : Convolutional Neural Network
#
# model3 : Convolutional Neural Network + Regularization
#

import model1

model = model1.Model()

# Directory where the trained model is saved (on Google Drive)
save_dir = '/content/drive/MyDrive/model1'

# Create the directory if it does not exist yet
os.makedirs(save_dir, exist_ok=True)
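The torchinfo summary printed by the training cell pins down model1's layer sizes: three `Linear(128, 128)` layers (128 × 128 + 128 = 16,512 parameters each), then `Linear(128, 64)` (8,256), `Linear(64, 32)` (2,080) and `Linear(32, 3)` (99), with ReLU in between and a final Softmax, for 59,971 trainable parameters in total. The class below is a hypothetical reconstruction with those sizes, not the repository's `model1.py`; the original may register its layers individually rather than inside `nn.Sequential`, which changes only the nesting torchinfo shows. `init_weights` is likewise a hypothetical He-initialization stand-in for `util.init_weights`, written so that it can be passed to `model.apply`.

```python
import torch
import torch.nn as nn
import torch.nn.init as init


class MLP(nn.Module):
    """Hypothetical reconstruction of model1.Model from the torchinfo summary.

    Layer widths: 128 -> 128 -> 128 -> 128 -> 64 -> 32 -> 3, ReLU between layers,
    and a final Softmax over the 3 classes.
    """

    def __init__(self, in_features: int = 128, num_classes: int = 3):
        super().__init__()
        sizes = [in_features, 128, 128, 128, 64, 32]
        layers = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        layers += [nn.Linear(sizes[-1], num_classes), nn.Softmax(dim=1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, 1, 128) -> (batch, 128): one feature vector per sample for the Linear layers
        return self.net(x.flatten(start_dim=1))


def init_weights(m: nn.Module) -> None:
    """Hypothetical init for model.apply: He (Kaiming) normal weights, zero bias, Linear only."""
    if isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, nonlinearity="relu")
        init.zeros_(m.bias)
```

One thing worth checking in `util.fit`: `nn.CrossEntropyLoss` applies log-softmax internally and expects unnormalized logits, so if it is the loss used here, the trailing Softmax would normally be dropped from the model.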

In [None]:
#-------------------------------------------------------------------------------------------------------------------------------
# Parameter settings
#-------------------------------------------------------------------------------------------------------------------------------

# Batch size
batch_size = 128

# Number of epochs
num_epochs = 100

#-------------------------------------------------------------------------------------------------------------------------------
# Device assignment
#-------------------------------------------------------------------------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

#-------------------------------------------------------------------------------------------------------------------------------
# Load the data
#-------------------------------------------------------------------------------------------------------------------------------

# pin_memory only speeds up host-to-GPU copies, so enable it only when CUDA is in use
pin_memory = (device.type == 'cuda')

train_dataset = util.MyDataset('/content/Trial_Polarity/data/train_data.pt', '/content/Trial_Polarity/data/train_label_for_CLA.pt')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)

# The validation set does not need shuffling: batch order does not change the metrics
valid_dataset = util.MyDataset('/content/Trial_Polarity/data/valid_data.pt', '/content/Trial_Polarity/data/valid_label_for_CLA.pt')
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)

#-------------------------------------------------------------------------------------------------------------------------------
# Training
#-------------------------------------------------------------------------------------------------------------------------------

# Move the model to the selected device
model.to(device)

# Initialize the parameters
model.apply(util.init_weights)

# Print the model structure
print(summary(model, input_size=(1,1,128)))

# Set up the optimizer (Adam with its default settings)
optimizer = torch.optim.Adam(model.parameters())

# Run training
history = np.zeros((0,3))
history = util.fit(model, optimizer, num_epochs, train_loader, valid_loader, device, history, save_dir)

# Save the training history (written to save_dir/history.npy)
np.save(os.path.join(save_dir, 'history.npy'), history)

cpu




Layer (type:depth-idx)                   Output Shape              Param #
Model                                    [1, 3]                    --
├─Linear: 1-1                            [1, 128]                  16,512
├─ReLU: 1-2                              [1, 128]                  --
├─Linear: 1-3                            [1, 128]                  16,512
├─ReLU: 1-4                              [1, 128]                  --
├─Linear: 1-5                            [1, 128]                  16,512
├─ReLU: 1-6                              [1, 128]                  --
├─Linear: 1-7                            [1, 64]                   8,256
├─ReLU: 1-8                              [1, 64]                   --
├─Linear: 1-9                            [1, 32]                   2,080
├─ReLU: 1-10                             [1, 32]                   --
├─Linear: 1-11                           [1, 3]                    99
├─Softmax: 1-12                          [1, 3]                    

Epoch Train 1/100: 100%|██████████| 793/793 [00:07<00:00, 99.25it/s]
Epoch Valid 1/100: 100%|██████████| 100/100 [00:01<00:00, 83.37it/s]


Epoch [1/100], loss: 28.53585 acc: 0.62355 val_loss: 25.08379 val_acc: 0.67589


Epoch Train 2/100: 100%|██████████| 793/793 [00:08<00:00, 98.62it/s] 
Epoch Valid 2/100: 100%|██████████| 100/100 [00:00<00:00, 143.68it/s]


Epoch [2/100], loss: 23.36857 acc: 0.70501 val_loss: 22.04856 val_acc: 0.74030


Epoch Train 3/100: 100%|██████████| 793/793 [00:08<00:00, 92.52it/s] 
Epoch Valid 3/100: 100%|██████████| 100/100 [00:00<00:00, 155.14it/s]


Epoch [3/100], loss: 19.33633 acc: 0.78523 val_loss: 18.50635 val_acc: 0.80479


Epoch Train 4/100: 100%|██████████| 793/793 [00:07<00:00, 104.31it/s]
Epoch Valid 4/100: 100%|██████████| 100/100 [00:00<00:00, 100.08it/s]


Epoch [4/100], loss: 16.41739 acc: 0.82697 val_loss: 17.64846 val_acc: 0.81993


Epoch Train 5/100: 100%|██████████| 793/793 [00:08<00:00, 95.02it/s] 
Epoch Valid 5/100: 100%|██████████| 100/100 [00:00<00:00, 150.43it/s]


Epoch [5/100], loss: 14.61528 acc: 0.84913 val_loss: 16.26930 val_acc: 0.83467


Epoch Train 6/100: 100%|██████████| 793/793 [00:09<00:00, 86.43it/s] 
Epoch Valid 6/100: 100%|██████████| 100/100 [00:00<00:00, 160.05it/s]


Epoch [6/100], loss: 13.37855 acc: 0.86512 val_loss: 15.58292 val_acc: 0.84461


Epoch Train 7/100: 100%|██████████| 793/793 [00:08<00:00, 96.42it/s]
Epoch Valid 7/100: 100%|██████████| 100/100 [00:01<00:00, 97.45it/s]


Epoch [7/100], loss: 12.30298 acc: 0.87786 val_loss: 14.89559 val_acc: 0.85762


Epoch Train 8/100: 100%|██████████| 793/793 [00:08<00:00, 98.00it/s] 
Epoch Valid 8/100: 100%|██████████| 100/100 [00:00<00:00, 167.17it/s]


Epoch [8/100], loss: 11.30307 acc: 0.88985 val_loss: 14.91629 val_acc: 0.85754


Epoch Train 9/100: 100%|██████████| 793/793 [00:09<00:00, 87.16it/s] 
Epoch Valid 9/100: 100%|██████████| 100/100 [00:00<00:00, 159.30it/s]


Epoch [9/100], loss: 10.52497 acc: 0.89854 val_loss: 14.80032 val_acc: 0.86377


Epoch Train 10/100: 100%|██████████| 793/793 [00:07<00:00, 105.73it/s]
Epoch Valid 10/100: 100%|██████████| 100/100 [00:00<00:00, 161.90it/s]


Epoch [10/100], loss: 9.82553 acc: 0.90663 val_loss: 14.86829 val_acc: 0.86455


Epoch Train 11/100: 100%|██████████| 793/793 [00:08<00:00, 92.19it/s] 
Epoch Valid 11/100: 100%|██████████| 100/100 [00:00<00:00, 175.21it/s]


Epoch [11/100], loss: 9.22614 acc: 0.91360 val_loss: 14.63544 val_acc: 0.86968


Epoch Train 12/100: 100%|██████████| 793/793 [00:08<00:00, 91.36it/s]
Epoch Valid 12/100: 100%|██████████| 100/100 [00:01<00:00, 90.34it/s]


Epoch [12/100], loss: 8.64516 acc: 0.91941 val_loss: 14.74714 val_acc: 0.86582


Epoch Train 13/100: 100%|██████████| 793/793 [00:07<00:00, 100.71it/s]
Epoch Valid 13/100: 100%|██████████| 100/100 [00:00<00:00, 158.73it/s]


Epoch [13/100], loss: 8.17277 acc: 0.92465 val_loss: 14.97236 val_acc: 0.87007


Epoch Train 14/100: 100%|██████████| 793/793 [00:08<00:00, 91.55it/s] 
Epoch Valid 14/100: 100%|██████████| 100/100 [00:00<00:00, 152.68it/s]


Epoch [14/100], loss: 7.74865 acc: 0.92880 val_loss: 15.06662 val_acc: 0.87740


Epoch Train 15/100: 100%|██████████| 793/793 [00:08<00:00, 92.31it/s]
Epoch Valid 15/100: 100%|██████████| 100/100 [00:00<00:00, 100.78it/s]


Epoch [15/100], loss: 7.32640 acc: 0.93389 val_loss: 15.38245 val_acc: 0.87323


Epoch Train 16/100: 100%|██████████| 793/793 [00:08<00:00, 94.68it/s]
Epoch Valid 16/100: 100%|██████████| 100/100 [00:00<00:00, 141.33it/s]


Epoch [16/100], loss: 6.94249 acc: 0.93676 val_loss: 16.16341 val_acc: 0.87236


Epoch Train 17/100: 100%|██████████| 793/793 [00:09<00:00, 87.80it/s] 
Epoch Valid 17/100: 100%|██████████| 100/100 [00:00<00:00, 156.58it/s]


Epoch [17/100], loss: 6.65695 acc: 0.93963 val_loss: 16.25297 val_acc: 0.87283


Epoch Train 18/100: 100%|██████████| 793/793 [00:08<00:00, 97.59it/s]
Epoch Valid 18/100: 100%|██████████| 100/100 [00:01<00:00, 99.61it/s]


Epoch [18/100], loss: 6.32460 acc: 0.94362 val_loss: 17.15697 val_acc: 0.86455


Epoch Train 19/100: 100%|██████████| 793/793 [00:09<00:00, 84.35it/s]
Epoch Valid 19/100: 100%|██████████| 100/100 [00:00<00:00, 132.80it/s]


Epoch [19/100], loss: 6.07469 acc: 0.94521 val_loss: 16.25301 val_acc: 0.87662


Epoch Train 20/100: 100%|██████████| 793/793 [00:09<00:00, 84.95it/s]
Epoch Valid 20/100: 100%|██████████| 100/100 [00:00<00:00, 149.97it/s]


Epoch [20/100], loss: 5.75693 acc: 0.94904 val_loss: 16.91516 val_acc: 0.87307


Epoch Train 21/100: 100%|██████████| 793/793 [00:07<00:00, 100.67it/s]
Epoch Valid 21/100: 100%|██████████| 100/100 [00:01<00:00, 97.54it/s]


Epoch [21/100], loss: 5.55635 acc: 0.95091 val_loss: 17.31847 val_acc: 0.87386


Epoch Train 22/100: 100%|██████████| 793/793 [00:08<00:00, 90.25it/s] 
Epoch Valid 22/100: 100%|██████████| 100/100 [00:00<00:00, 159.24it/s]


Epoch [22/100], loss: 5.21583 acc: 0.95403 val_loss: 17.63197 val_acc: 0.87055


Epoch Train 23/100: 100%|██████████| 793/793 [00:09<00:00, 86.58it/s]
Epoch Valid 23/100: 100%|██████████| 100/100 [00:00<00:00, 159.61it/s]


Epoch [23/100], loss: 5.04537 acc: 0.95591 val_loss: 18.36902 val_acc: 0.87457


Epoch Train 24/100: 100%|██████████| 793/793 [00:07<00:00, 102.29it/s]
Epoch Valid 24/100: 100%|██████████| 100/100 [00:00<00:00, 101.69it/s]


Epoch [24/100], loss: 4.83333 acc: 0.95735 val_loss: 18.11873 val_acc: 0.87354


Epoch Train 25/100: 100%|██████████| 793/793 [00:08<00:00, 90.16it/s] 
Epoch Valid 25/100: 100%|██████████| 100/100 [00:00<00:00, 171.97it/s]


Epoch [25/100], loss: 4.64040 acc: 0.95889 val_loss: 20.48219 val_acc: 0.87401


Epoch Train 26/100: 100%|██████████| 793/793 [00:08<00:00, 94.71it/s]
Epoch Valid 26/100: 100%|██████████| 100/100 [00:00<00:00, 131.29it/s]


Epoch [26/100], loss: 4.48029 acc: 0.96101 val_loss: 18.68951 val_acc: 0.87614


Epoch Train 27/100: 100%|██████████| 793/793 [00:07<00:00, 108.02it/s]
Epoch Valid 27/100: 100%|██████████| 100/100 [00:00<00:00, 161.25it/s]


Epoch [27/100], loss: 4.24308 acc: 0.96363 val_loss: 19.73697 val_acc: 0.87165


Epoch Train 28/100: 100%|██████████| 793/793 [00:09<00:00, 86.16it/s]
Epoch Valid 28/100: 100%|██████████| 100/100 [00:00<00:00, 141.72it/s]


Epoch [28/100], loss: 4.12980 acc: 0.96436 val_loss: 19.99970 val_acc: 0.87157


Epoch Train 29/100: 100%|██████████| 793/793 [00:08<00:00, 96.67it/s]
Epoch Valid 29/100: 100%|██████████| 100/100 [00:01<00:00, 86.32it/s]


Epoch [29/100], loss: 3.93213 acc: 0.96595 val_loss: 21.16569 val_acc: 0.87283


Epoch Train 30/100: 100%|██████████| 793/793 [00:07<00:00, 102.19it/s]
Epoch Valid 30/100: 100%|██████████| 100/100 [00:00<00:00, 168.25it/s]


Epoch [30/100], loss: 3.74839 acc: 0.96796 val_loss: 21.46396 val_acc: 0.87370


Epoch Train 31/100: 100%|██████████| 793/793 [00:08<00:00, 89.26it/s] 
Epoch Valid 31/100: 100%|██████████| 100/100 [00:00<00:00, 167.40it/s]


Epoch [31/100], loss: 3.64026 acc: 0.96855 val_loss: 22.02000 val_acc: 0.87677


Epoch Train 32/100: 100%|██████████| 793/793 [00:07<00:00, 103.50it/s]
Epoch Valid 32/100: 100%|██████████| 100/100 [00:00<00:00, 136.80it/s]


Epoch [32/100], loss: 3.62133 acc: 0.96872 val_loss: 21.19200 val_acc: 0.87386


Epoch Train 33/100:  26%|██▌       | 204/793 [00:02<00:07, 82.28it/s]
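The log ends partway through epoch 33. Over the 32 completed epochs, training loss falls steadily (28.53585 to 3.62133) and training accuracy climbs to 0.96872, while validation loss is lowest at epoch 11 and then drifts upward and validation accuracy levels off around 0.87, the usual signature of overfitting. The short script below parses the per-epoch summary lines (copied verbatim from the log) and reports where the validation metrics peak; on this log it selects epoch 11 for the lowest `val_loss` (14.63544) and epoch 14 for the highest `val_acc` (0.87740). The regular expression assumes exactly the summary format printed above.

```python
import re

# Per-epoch summary lines copied from the training log above (epochs 1-32)
LOG = """\
Epoch [1/100], loss: 28.53585 acc: 0.62355 val_loss: 25.08379 val_acc: 0.67589
Epoch [2/100], loss: 23.36857 acc: 0.70501 val_loss: 22.04856 val_acc: 0.74030
Epoch [3/100], loss: 19.33633 acc: 0.78523 val_loss: 18.50635 val_acc: 0.80479
Epoch [4/100], loss: 16.41739 acc: 0.82697 val_loss: 17.64846 val_acc: 0.81993
Epoch [5/100], loss: 14.61528 acc: 0.84913 val_loss: 16.26930 val_acc: 0.83467
Epoch [6/100], loss: 13.37855 acc: 0.86512 val_loss: 15.58292 val_acc: 0.84461
Epoch [7/100], loss: 12.30298 acc: 0.87786 val_loss: 14.89559 val_acc: 0.85762
Epoch [8/100], loss: 11.30307 acc: 0.88985 val_loss: 14.91629 val_acc: 0.85754
Epoch [9/100], loss: 10.52497 acc: 0.89854 val_loss: 14.80032 val_acc: 0.86377
Epoch [10/100], loss: 9.82553 acc: 0.90663 val_loss: 14.86829 val_acc: 0.86455
Epoch [11/100], loss: 9.22614 acc: 0.91360 val_loss: 14.63544 val_acc: 0.86968
Epoch [12/100], loss: 8.64516 acc: 0.91941 val_loss: 14.74714 val_acc: 0.86582
Epoch [13/100], loss: 8.17277 acc: 0.92465 val_loss: 14.97236 val_acc: 0.87007
Epoch [14/100], loss: 7.74865 acc: 0.92880 val_loss: 15.06662 val_acc: 0.87740
Epoch [15/100], loss: 7.32640 acc: 0.93389 val_loss: 15.38245 val_acc: 0.87323
Epoch [16/100], loss: 6.94249 acc: 0.93676 val_loss: 16.16341 val_acc: 0.87236
Epoch [17/100], loss: 6.65695 acc: 0.93963 val_loss: 16.25297 val_acc: 0.87283
Epoch [18/100], loss: 6.32460 acc: 0.94362 val_loss: 17.15697 val_acc: 0.86455
Epoch [19/100], loss: 6.07469 acc: 0.94521 val_loss: 16.25301 val_acc: 0.87662
Epoch [20/100], loss: 5.75693 acc: 0.94904 val_loss: 16.91516 val_acc: 0.87307
Epoch [21/100], loss: 5.55635 acc: 0.95091 val_loss: 17.31847 val_acc: 0.87386
Epoch [22/100], loss: 5.21583 acc: 0.95403 val_loss: 17.63197 val_acc: 0.87055
Epoch [23/100], loss: 5.04537 acc: 0.95591 val_loss: 18.36902 val_acc: 0.87457
Epoch [24/100], loss: 4.83333 acc: 0.95735 val_loss: 18.11873 val_acc: 0.87354
Epoch [25/100], loss: 4.64040 acc: 0.95889 val_loss: 20.48219 val_acc: 0.87401
Epoch [26/100], loss: 4.48029 acc: 0.96101 val_loss: 18.68951 val_acc: 0.87614
Epoch [27/100], loss: 4.24308 acc: 0.96363 val_loss: 19.73697 val_acc: 0.87165
Epoch [28/100], loss: 4.12980 acc: 0.96436 val_loss: 19.99970 val_acc: 0.87157
Epoch [29/100], loss: 3.93213 acc: 0.96595 val_loss: 21.16569 val_acc: 0.87283
Epoch [30/100], loss: 3.74839 acc: 0.96796 val_loss: 21.46396 val_acc: 0.87370
Epoch [31/100], loss: 3.64026 acc: 0.96855 val_loss: 22.02000 val_acc: 0.87677
Epoch [32/100], loss: 3.62133 acc: 0.96872 val_loss: 21.19200 val_acc: 0.87386
"""

# Matches the "Epoch [e/N], loss: ... acc: ... val_loss: ... val_acc: ..." summary format
PATTERN = re.compile(
    r"Epoch \[(\d+)/\d+\], loss: ([\d.]+) acc: ([\d.]+) val_loss: ([\d.]+) val_acc: ([\d.]+)"
)


def parse_log(text: str):
    """Return a list of (epoch, loss, acc, val_loss, val_acc) tuples, one per summary line."""
    rows = []
    for m in PATTERN.finditer(text):
        loss, acc, val_loss, val_acc = (float(g) for g in m.groups()[1:])
        rows.append((int(m.group(1)), loss, acc, val_loss, val_acc))
    return rows


rows = parse_log(LOG)
best_loss = min(rows, key=lambda r: r[3])  # lowest validation loss
best_acc = max(rows, key=lambda r: r[4])   # highest validation accuracy
print(f"lowest val_loss: epoch {best_loss[0]} ({best_loss[3]})")
print(f"highest val_acc: epoch {best_acc[0]} ({best_acc[4]})")
```

If `util.fit` stores one row per epoch in `history`, the same selection could be applied to the saved `history.npy` instead of the printed log; its column layout is not shown in this notebook, so that remains an assumption.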