<a href="https://colab.research.google.com/github/ShinyaKatoh/Trial_Polarity/blob/main/model_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/ShinyaKatoh/Trial_Polarity
!pip install einops
!pip install torchinfo

In [None]:
# Mount Google Drive at /content/drive; it is used later as the destination for saved models.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
import sys

# Make the cloned repository's modules (util, model1, ...) importable.
repo_dir = '/content/Trial_Polarity'
# Guard so that re-running this cell does not append duplicate entries to sys.path.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [None]:
import glob
import importlib
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from tqdm import tqdm

import util

In [None]:
# Model architecture selection
#
# model1 : Neural Network
#
# model2 : Convolutional Neural Network
#
# model3 : Convolutional Neural Network + Regularization
#
# Change only `model_name` to switch architectures. The save directory is derived from it,
# so checkpoints of different models cannot accidentally overwrite each other.

import importlib
import os

model_name = 'model1'

# Assumes every modelN module exposes a `Model` class, as model1 does.
model_module = importlib.import_module(model_name)
model = model_module.Model()

# Directory where the trained model and the training history are saved
save_dir = os.path.join('/content/drive/MyDrive', model_name)

# exist_ok=True avoids the check-then-create race and makes the cell safe to re-run
os.makedirs(save_dir, exist_ok=True)

In [None]:
#-------------------------------------------------------------------------------------------------------------------------------
# Parameter settings
#-------------------------------------------------------------------------------------------------------------------------------

# Mini-batch size
batch_size = 128

# Number of training epochs
num_epochs = 100

# Seed for weight initialisation and training-data shuffling, so that runs are reproducible
seed = 0

# Directory containing the pre-built data/label tensors shipped with the repository
data_dir = '/content/Trial_Polarity/data'

#-------------------------------------------------------------------------------------------------------------------------------
# Device assignment
#-------------------------------------------------------------------------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Pinned host memory only speeds up host-to-GPU copies; on CPU it is useless and triggers a warning
pin_memory = device.type == 'cuda'

# Cap the worker count at the number of available CPUs; DataLoader warns when more are requested
num_workers = min(4, os.cpu_count() or 1)

# Seed before weight initialisation and before the loaders draw their shuffling order
torch.manual_seed(seed)
np.random.seed(seed)

#-------------------------------------------------------------------------------------------------------------------------------
# Data loading
#-------------------------------------------------------------------------------------------------------------------------------

train_dataset = util.MyDataset(os.path.join(data_dir, 'train_data.pt'),
                               os.path.join(data_dir, 'train_label_for_CLA.pt'))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)

# Validation is not shuffled. Presumably util.fit only aggregates validation metrics,
# so the order is irrelevant — confirm against util.fit.
valid_dataset = util.MyDataset(os.path.join(data_dir, 'valid_data.pt'),
                               os.path.join(data_dir, 'valid_label_for_CLA.pt'))
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, pin_memory=pin_memory)

#-------------------------------------------------------------------------------------------------------------------------------
# Training
#-------------------------------------------------------------------------------------------------------------------------------

# Move the model to the selected device
model.to(device)

# Re-initialise the parameters with the project's initialiser
model.apply(util.init_weights)

# Print the model structure.
# NOTE(review): input_size assumes a single example shaped (1, 1, 128) — confirm against the model definition.
print(summary(model, input_size=(1, 1, 128)))

# Optimizer setup (Adam with default hyperparameters)
optimizer = torch.optim.Adam(model.parameters())

# Run training. `history` starts empty with 3 columns and is returned by util.fit.
# Presumably one row is appended per epoch — confirm in util.fit.
history = np.zeros((0, 3))
history = util.fit(model, optimizer, num_epochs, train_loader, valid_loader, device, history, save_dir)

# Save the training history (np.save writes history.npy)
np.save(os.path.join(save_dir, 'history.npy'), history)