# Standard

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm_notebook

In [None]:
class TrainWrapper:
    """Minimal training / evaluation loop around a PyTorch model.

    Subclasses must override :meth:`prepare_data` to turn one batch yielded by
    the loaders into an ``(inputs, labels)`` pair.

    Args:
        model: The ``nn.Module`` to train.
        train_dl: Iterable of training batches (e.g. a ``DataLoader``).
        test_dl: Iterable of evaluation batches.
        epoch: Number of passes over ``train_dl``.
        save_path: File that ``torch.save`` writes the best state dict to.
        verbose: When true, print the epoch number and wrap loaders in
            ``tqdm_notebook`` progress bars.
        **karg: Must contain ``optimizer`` and ``loss_func``.
    """

    def __init__(self, model, train_dl, test_dl, epoch, save_path, verbose=True, **karg):
        self.model = model
        self.train_dl = train_dl
        self.test_dl = test_dl
        self.verbose = verbose
        self.save_path = save_path
        self.epoch = epoch
        self.optimizer = karg["optimizer"]
        self.loss_func = karg["loss_func"]
        # Summed training loss of each completed epoch, in order.
        self.loss_val_list = []
        # Lowest epoch loss seen so far; a checkpoint is written whenever it drops.
        self.min_loss = float("inf")

    def _progress(self, dl):
        """Return *dl* wrapped in a notebook progress bar when verbose, else *dl* itself."""
        # Build a fresh bar for every pass: a tqdm bar is closed once its
        # iteration finishes, so reusing one across epochs shows nothing.
        return tqdm_notebook(dl) if self.verbose else dl

    def train(self):
        """Train for ``self.epoch`` epochs, checkpointing on the lowest training loss.

        After each epoch the summed batch loss is appended to ``loss_val_list``.
        If it is below ``min_loss``, ``min_loss`` is updated and the model's
        state dict is saved to ``save_path``. :meth:`evaluate` runs at the end
        of every epoch.
        """
        for epoch in range(self.epoch):
            if self.verbose:
                print(f"Epoch: {epoch}")
            self.model.train()
            epoch_loss = 0.0
            for data_list in self._progress(self.train_dl):
                vectors, labels = self.prepare_data(data_list)
                # backward() accumulates into .grad, so clear the previous step's gradients.
                self.optimizer.zero_grad()
                output = self.model(vectors)

                loss_val = self.loss_func(output, labels)
                loss_val.backward()
                self.optimizer.step()
                epoch_loss += loss_val.item()

            self.loss_val_list.append(epoch_loss)
            if epoch_loss < self.min_loss:
                self.min_loss = epoch_loss
                torch.save(self.model.state_dict(), self.save_path)

            self.evaluate()

    def prepare_data(self, data):
        """Convert one loader batch into an ``(inputs, labels)`` pair.

        Must be overridden by subclasses; the base class does not know the batch layout.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            "Subclasses must implement prepare_data(data) -> (inputs, labels)"
        )

    @staticmethod
    def _macro_f1(y_true, y_pred):
        """Return the unweighted mean of per-class F1 over every label in either list.

        Intended to match ``sklearn.metrics.f1_score(..., average="macro")``.
        A class with no true positives scores 0. Empty input returns 0.0.

        Raises:
            ValueError: if the two sequences differ in length.
        """
        from collections import Counter

        if len(y_true) != len(y_pred):
            raise ValueError("y_true and y_pred must have the same length")
        true_counts = Counter(y_true)
        pred_counts = Counter(y_pred)
        hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
        labels = set(true_counts) | set(pred_counts)
        if not labels:
            return 0.0
        # F1 = 2TP / (2TP + FP + FN) = 2TP / (#true + #predicted) for each class.
        per_class = [
            2 * hits[label] / (true_counts[label] + pred_counts[label])
            for label in labels
        ]
        return sum(per_class) / len(per_class)

    def evaluate(self):
        """Run the model over ``test_dl``, then print and return the macro F1 score.

        The prediction for each sample is the arg-max over the last dimension
        of the model output.

        Returns:
            float: macro-averaged F1 over all evaluation samples.
        """
        pred_list = []
        true_label = []
        self.model.eval()

        with torch.no_grad():
            for data_list in self._progress(self.test_dl):
                vectors, labels = self.prepare_data(data_list)
                output = self.model(vectors)
                # Arg-max per sample (last dim) so batches larger than 1 work.
                pred_list.extend(torch.argmax(output, dim=-1).reshape(-1).tolist())
                if torch.is_tensor(labels):
                    labels = labels.reshape(-1).tolist()
                true_label.extend(labels)

        score = self._macro_f1(true_label, pred_list)
        print(f"F1 macro: {score}")
        return score

    def load_model(self, map_location=None):
        """Load the state dict stored at ``save_path`` into ``self.model``.

        Args:
            map_location: Forwarded to ``torch.load`` (for example ``"cpu"``).

        Returns:
            The same ``self.model`` instance, now holding the loaded weights.
        """
        state_dict = torch.load(self.save_path, map_location=map_location)
        # load_state_dict updates the module in place and returns a report of
        # missing/unexpected keys, so its result must not replace self.model.
        self.model.load_state_dict(state_dict)
        return self.model