In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime
import os

In [39]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(16 * 16, 120) # 16*5*5 input, 120 output
        self.fc2 = nn.Linear(120, 84) # 120 input, 84 output
        self.fc3 = nn.Linear(84, 16 * 16) # 84 input, 10 output

    def forward(self, x):
        x = x.flatten(1) # flatten all dimensions except batch dimension
        x = F.relu(self.fc1(x)) # 400 -> 120
        x = F.relu(self.fc2(x)) # 120 -> 84
        x = self.fc3(x) # 84 -> 10
        x = x.view(-1, 16, 16) # reshape to 4D tensor
        return x


In [40]:
class torchAgent:
    def __init__(self,model, loss_fn, data_path: str = None, valid_path: str = None, optimizer = None, device: str = None, epoch: int = 0, model_path = None, verbose: int = 2, track_amount: int = None, **kwargs):
        # device: cpu / gpu
        if device is None:
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu" # set device
            )
        else:
            self.device = device # set device
        self.model = model.to(self.device) # set model
        self.loss_fn = loss_fn # set loss function
        self.optimizer = optimizer # set optimizer
        self.scheduler = None # set scheduler
        self.epoch = epoch # set epoch
        self.verbose = verbose # set verbose
        if model_path is None:
            self.model_path = f'model_{datetime.now().strftime("%y_%m_%d_%H%M")}' # set model path
        else:
            self.model_path = model_path

        if data_path is not None:
            self.data_path  = 'data' #data path
        else:
            self.data_path = data_path

        if valid_path is None:
            self.valid_path = 'data' #validataion path
        else:
            self.valid_path = valid_path

        if track_amount is None:
            self.track_amount = len(os.listdir(self.data_path))
        else:
            self.track_amount = track_amount

    
    def add_loss_fn(self, loss_fn):
        self.loss_fn = loss_fn

    def add_optimizer(self, optimizer, **kwargs):
        self.optimizer = optimizer(self.model.parameters(), **kwargs)

    def add_scheduler(self, scheduler, **kwargs):
        self.scheduler = scheduler(self.optimizer, **kwargs)

    def load_data(self, path: str):
        data = torch.Tensor(np.random.rand(100, 16, 16)).to(self.device)
        labels = torch.Tensor(np.random.rand(100, 16, 16)).to(self.device)

        return data, labels

    def tracks(self, validate: bool = False):
        if validate:
            self.data_path = self.valid_path
        #find all tracks in data path folder
        for track in os.listdir(self.data_path):
            yield self.load_data(track)

    def train_one_epoch(self, **kwargs):
        self.model.train(True)
        running_loss = 0.

        for i, (data, labels) in enumerate(self.tracks()):
            # Zero your gradients for every batch!
            self.optimizer.zero_grad()

            # calculate loss
            loss = self.loss_fn(self.model(data), labels)

            # backpropagation
            loss.backward()

            # update parameters
            self.optimizer.step()

            # print statistics
            running_loss += loss.item()
            print(f'Batch: [{i+1}] loss: {loss.item():.3f}, loss: {running_loss:.3f}',end='\r')

            # free memory
            del data, labels, loss
            torch.cuda.empty_cache()

        self.model.train(False)
        return running_loss/self.track_amount
    
    def validate(self, **kwargs):
        self.model.train(False)
        running_loss = 0.

        for i, (data, labels) in enumerate(self.tracks(validate=True)):
            # calculate loss
            loss = self.loss_fn(self.model(data), labels)

            # print statistics
            running_loss += loss.item()
            print(f'\nBatch: [{i+1}] Loss: {loss.item():.3f}, Total loss: {running_loss:.3f}',end='\n')

            # free memory
            del data, labels, loss
            torch.cuda.empty_cache()

        return running_loss/self.track_amount

    def train(self, **kwargs):
        best_loss = np.inf
        for epoch in range(self.epoch):
            print(f'Epoch: [{epoch+1}/{self.epoch}]')
            epoch_loss = self.train_one_epoch(**kwargs)
            print(f'Epoch: [{epoch+1}/{self.epoch}] loss: {epoch_loss:.3f}')
            valid_loss = self.validate(**kwargs)
            if best_loss > valid_loss:
                print('Saving model...')
                self.save_model()
                best_loss = valid_loss
            if self.scheduler is not None:
                self.scheduler.step()
        print('Finished Training')

    def save_model(self):
        torch.save(self.model.state_dict(), self.model_path)
        print(f'Model saved at {self.model_path}')

    def load_model(self, model_path: str):
        self.model.load_state_dict(torch.load(model_path))
        print(f'Model loaded from {model_path}')        
    
    

In [41]:
torchAgent = torchAgent(Model(), nn.MSELoss(), epoch=10, verbose=2)
torchAgent.add_optimizer(optim.SGD, lr=0.001, momentum=0.9)
torchAgent.add_scheduler(optim.lr_scheduler.StepLR, step_size=5, gamma=0.1)
torchAgent.train()

Epoch: [1/10]
Epoch: [1/10] loss: 0.339loss: 4.410

Batch: [1] Loss: 0.335, Total loss: 0.335
Saving model...
Model saved at model_23_06_05_1718
Epoch: [2/10]
Epoch: [2/10] loss: 0.026oss: 0.338

Batch: [1] Loss: 0.338, Total loss: 0.338
Epoch: [3/10]
Epoch: [3/10] loss: 0.026oss: 0.336

Batch: [1] Loss: 0.335, Total loss: 0.335
Epoch: [4/10]
Epoch: [4/10] loss: 0.026oss: 0.340

Batch: [1] Loss: 0.339, Total loss: 0.339
Epoch: [5/10]
Epoch: [5/10] loss: 0.026oss: 0.336

Batch: [1] Loss: 0.337, Total loss: 0.337
Epoch: [6/10]
Epoch: [6/10] loss: 0.026oss: 0.340

Batch: [1] Loss: 0.339, Total loss: 0.339
Epoch: [7/10]
Epoch: [7/10] loss: 0.026oss: 0.340

Batch: [1] Loss: 0.337, Total loss: 0.337
Epoch: [8/10]
Epoch: [8/10] loss: 0.026oss: 0.338

Batch: [1] Loss: 0.341, Total loss: 0.341
Epoch: [9/10]
Epoch: [9/10] loss: 0.026oss: 0.342

Batch: [1] Loss: 0.340, Total loss: 0.340
Epoch: [10/10]
Epoch: [10/10] loss: 0.026ss: 0.337

Batch: [1] Loss: 0.337, Total loss: 0.337
Finished Training