
ak-yoshi/EasyTorch


EasyTorch

Introduction

This is template code that reduces the implementation burden of writing PyTorch programs. You can run training and evaluation simply by preparing a model and training data.

Requirements

  • PyTorch 1.0 or later

Installation

You can get the source code with the following command:

git clone git@github.com:ak-yoshi/EasyTorch.git

Usage

1. Setting data in the data manager

Set up the DataManager class by passing it your user-defined train/eval/test dataset classes. You can also set the properties related to PyTorch's DataLoader class (batch_size and num_workers).

from framework import DataManager  # assumed to sit next to Data in framework/

data_manager = DataManager.DataManager(train_data=ud_train_data, test_data=ud_test_data,
                                       batch_size=batch_size, num_workers=num_workers)
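The batch_size and num_workers properties correspond to the arguments of the same name on PyTorch's torch.utils.data.DataLoader. As a rough illustration only, the sketch below uses a hypothetical SimpleDataManager (not EasyTorch's actual class) to show how such a manager might wrap each dataset in a DataLoader; the eval_data parameter name and the shuffling policy are assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class SimpleDataManager:
    """Hypothetical stand-in for DataManager: wraps datasets into DataLoaders."""

    def __init__(self, train_data=None, eval_data=None, test_data=None,
                 batch_size=1, num_workers=0):
        # Shuffle only the training split; keep eval/test order deterministic.
        self.train_loader = self._make(train_data, batch_size, num_workers, shuffle=True)
        self.eval_loader = self._make(eval_data, batch_size, num_workers, shuffle=False)
        self.test_loader = self._make(test_data, batch_size, num_workers, shuffle=False)

    @staticmethod
    def _make(dataset, batch_size, num_workers, shuffle):
        # A split that was not provided simply has no loader.
        if dataset is None:
            return None
        return DataLoader(dataset, batch_size=batch_size,
                          shuffle=shuffle, num_workers=num_workers)


# Toy data: 10 samples with 3 features each and integer class labels.
x = torch.randn(10, 3)
y = torch.randint(0, 2, (10,))
manager = SimpleDataManager(train_data=TensorDataset(x, y),
                            test_data=TensorDataset(x, y),
                            batch_size=4, num_workers=0)
batch_x, batch_y = next(iter(manager.test_loader))
print(batch_x.shape)  # torch.Size([4, 3])
```

With 10 samples and batch_size=4, each loader yields three batches (4, 4 and 2 samples).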

Your dataset class must inherit from the Data class and must define the following member variables: the data size (_len), the data contents (_data), and the target contents (_target).

from framework import Data

class UserDefinedDataSet(Data.Data):

  # initializer
  def __init__(self, filepath):
    self._len = 0      # data size
    self._data = []    # data contents
    self._target = []  # target contents
    ...
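A filled-in dataset might look like the sketch below. It assumes, without confirmation from this README, that Data serves samples through __len__ and __getitem__ backed by _len, _data and _target. To stay runnable on its own it subclasses torch.utils.data.Dataset instead of Data, and it reads a hypothetical CSV layout in which each row holds feature values followed by an integer label.

```python
import csv
import os
import tempfile

import torch
from torch.utils.data import Dataset


class CsvDataSet(Dataset):
    """Hypothetical dataset: each CSV row holds feature values followed by an integer label."""

    def __init__(self, filepath):
        self._len = 0      # data size
        self._data = []    # data contents (one feature tensor per sample)
        self._target = []  # target contents (one label per sample)
        with open(filepath, newline="") as f:
            for row in csv.reader(f):
                *features, label = row
                self._data.append(torch.tensor([float(v) for v in features]))
                self._target.append(int(label))
        self._len = len(self._data)

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return self._data[index], self._target[index]


# Usage: write a tiny CSV file, load it, then clean up.
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") as f:
    f.write("0.1,0.2,0\n0.3,0.4,1\n")
    path = f.name
dataset = CsvDataSet(path)
os.remove(path)
print(len(dataset))  # 2
```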

2. Declaring the network

Before declaring the network, define your model, criterion (loss function), evaluator (optional), and optimizer. Your model class must inherit from nn.Module and define a forward() method.

import torch.nn as nn

class UserDefinedModel(nn.Module):

  # initializer
  def __init__(self, **kwargs):
    super().__init__()  # must run before any layers are assigned
    ...
  
  # forward function
  def forward(self, x):
    ...
    return x
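As an illustration, the sketch below defines a small classifier together with a criterion, an optimizer and an evaluator. The layer sizes are arbitrary, and the evaluator signature accuracy(output, target) is an assumption, since this README does not say what Network expects from it.

```python
import torch
import torch.nn as nn


class UserDefinedModel(nn.Module):
    """Small fully connected classifier (illustrative only)."""

    def __init__(self, in_features=3, hidden=16, num_classes=2):
        super().__init__()  # required before registering submodules
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        # Returns raw logits; CrossEntropyLoss applies log-softmax internally.
        return self.layers(x)


def accuracy(output, target):
    """Hypothetical evaluator: fraction of correctly classified samples."""
    return (output.argmax(dim=1) == target).float().mean().item()


model = UserDefinedModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
evaluator = accuracy

logits = model(torch.randn(5, 3))
print(logits.shape)  # torch.Size([5, 2])
```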

Then pass all the components needed for training into the Network class.

from framework import Network  # assumed to sit next to Data in framework/

net = Network.Network(model=model, criterion=criterion, optimizer=optimizer, evaluator=evaluator,
                      data_manager=data_manager, device=device, device_ids=device_ids, non_blocking=True)
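The snippet above uses device and device_ids without defining them. A common way to choose them is sketched below; it assumes device_ids is a list of GPU indices for nn.DataParallel (or None on a CPU-only machine), which this README does not confirm, and it shows what non_blocking=True means for tensor copies.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# All visible GPU indices, or None on a CPU-only machine (assumed convention).
device_ids = list(range(torch.cuda.device_count())) or None

model = nn.Linear(3, 2)
if device_ids is not None and len(device_ids) > 1:
    # Replicate the model across GPUs; each batch is split along dimension 0.
    model = nn.DataParallel(model, device_ids=device_ids)
model = model.to(device)

# non_blocking=True lets a host-to-GPU copy run asynchronously when the
# source tensor is in pinned memory; for CPU-only runs it changes nothing.
x = torch.randn(4, 3).to(device, non_blocking=True)
y = model(x)
```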

3. Training

Run training by calling the train() method. Each call runs one epoch, so the number of epochs is controlled by the number of loop iterations.

for _ in range(num_epoch):
  net.train()
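This README does not show what train() does internally. A conventional PyTorch epoch loop is sketched below for orientation; train_one_epoch is a hypothetical helper, not EasyTorch's implementation, and returns the sample-weighted mean loss of the epoch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, criterion, optimizer, loader, device, non_blocking=False):
    """Hypothetical sketch of one epoch; returns the mean training loss."""
    model.train()  # enable training behavior (dropout, batch-norm statistic updates)
    total_loss, total_samples = 0.0, 0
    for data, target in loader:
        data = data.to(device, non_blocking=non_blocking)
        target = target.to(device, non_blocking=non_blocking)
        optimizer.zero_grad()                   # clear gradients from the previous step
        loss = criterion(model(data), target)
        loss.backward()                         # backpropagate
        optimizer.step()                        # update parameters
        total_loss += loss.item() * data.size(0)
        total_samples += data.size(0)
    return total_loss / total_samples


torch.manual_seed(0)
x = torch.randn(32, 3)
y = (x[:, 0] > 0).long()  # toy labels that a linear model can learn
loader = DataLoader(TensorDataset(x, y), batch_size=8, shuffle=True)
model = nn.Linear(3, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
device = torch.device("cpu")

num_epoch = 20
losses = [train_one_epoch(model, criterion, optimizer, loader, device)
          for _ in range(num_epoch)]
```

Calling the helper once per loop iteration mirrors the for loop above: num_epoch iterations run num_epoch epochs.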

4. Evaluation/testing

Run evaluation and testing by calling the eval() and test() methods.

net.eval()
net.test()
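Likewise, a typical evaluation pass is sketched below. The evaluate helper and the accuracy metric are hypothetical; the key differences from training are model.eval() and torch.no_grad(), which disable training-only layer behavior and gradient tracking.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradient tracking during evaluation
def evaluate(model, criterion, loader, device, evaluator=None):
    """Hypothetical sketch: mean loss and, optionally, a mean metric over a dataset."""
    model.eval()  # switch dropout / batch-norm to inference behavior
    total_loss, total_metric, total_samples = 0.0, 0.0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        n = data.size(0)
        # Weight per-batch values by batch size so the last, smaller batch counts correctly.
        total_loss += criterion(output, target).item() * n
        if evaluator is not None:
            total_metric += evaluator(output, target) * n
        total_samples += n
    mean_metric = total_metric / total_samples if evaluator is not None else None
    return total_loss / total_samples, mean_metric


def accuracy(output, target):
    """Hypothetical evaluator: fraction of correctly classified samples."""
    return (output.argmax(dim=1) == target).float().mean().item()


x = torch.randn(20, 3)
y = torch.randint(0, 2, (20,))
loader = DataLoader(TensorDataset(x, y), batch_size=6)
model = nn.Linear(3, 2)
loss, acc = evaluate(model, nn.CrossEntropyLoss(), loader, torch.device("cpu"), accuracy)
```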

