In [1]:
import models
import torch
import os
from torch.utils.data import DataLoader
import pandas as pd
from dataclasses import dataclass
from dataset import Data
from tqdm import tqdm
from torcheval.metrics.functional import peak_signal_noise_ratio


@dataclass
class TrainConfigs:
    """Hyper-parameters and file locations for one training run.

    Constructing an instance has a side effect: ``__post_init__`` reads the
    train and test CSV files from disk and stores them as ``train_df`` and
    ``test_df``.
    """
    # --- optimisation -----------------------------------------------------
    learning_rate: float = 1e-4
    epochs: int = 10        # epochs executed by a single Trainer.train() call
    batch_size: int = 1
    save_every: int = 2     # Trainer checkpoints when its epoch counter is a multiple of this
    # Forwarded unchanged to the dataset as its crop size. It holds two ints,
    # so it is annotated as a tuple (the previous ``int`` annotation was wrong).
    crop_size: tuple = (300, 300)

    # --- locations --------------------------------------------------------
    save_path: str = 'train_files'      # root directory for checkpoints
    images_path: str = 'images'         # image directory handed to the dataset
    train_csv_path: str = 'train.csv'
    test_csv_path: str = 'test.csv'

    # Optimizer *class* (not an instance). Annotated as ``type`` because
    # ``torch.optim`` is a module, not a type. The Trainer in this file takes
    # its optimizer class as a separate constructor argument.
    optimizer: type = torch.optim.Adam

    def __post_init__(self):
        """Load the train/test CSV files named by the path fields."""
        # Plain attributes rather than dataclass fields: they are derived from
        # the paths and therefore not part of __init__, repr or equality.
        self.train_df = pd.read_csv(self.train_csv_path)
        self.test_df = pd.read_csv(self.test_csv_path)
        
        
class Trainer(object):
    """Trains an image-upscaling model, evaluates it each epoch and checkpoints it.

    The trainer owns the train/test datasets, their data loaders and the
    optimizer, and records one loss value and one PSNR value per completed
    epoch for each split.
    """

    def __init__(self,
                 configs: TrainConfigs,
                 model: models.BaseSRCNN,
                 optim: type):
        """Trainer object.

        Args:
            configs (TrainConfigs): configurations of training.
            model (models.BaseSRCNN): model to be trained.
            optim (type): optimizer class (not an instance); it is
                instantiated here with the model parameters and
                ``configs.learning_rate``.
        """
        self.configs = configs
        self.model = model

        # Both splits use the same image directory, scale and crop size; the
        # dataset receives the inverse of the model's upscale factor.
        scale = 1 / self.model.args.upscale_factor
        self.train_dataset = Data(self.configs.images_path,
                                  self.configs.train_df,
                                  scale,
                                  self.configs.crop_size)
        self.test_dataset = Data(self.configs.images_path,
                                 self.configs.test_df,
                                 scale,
                                 self.configs.crop_size)

        # Dataloaders: only the training split is shuffled.
        self.train_loader = DataLoader(self.train_dataset,
                                       self.configs.batch_size,
                                       shuffle=True)
        self.test_loader = DataLoader(self.test_dataset,
                                      self.configs.batch_size,
                                      shuffle=False)

        # Create optimizer object.
        self.optim = optim(self.model.parameters(), lr=self.configs.learning_rate)

        # Per-epoch histories, kept for plotting later.
        self.train_history = []
        self.test_history = []
        self.psnr_train = []
        self.psnr_test = []

        # Epoch tracker; also names the checkpoint directory (epoch_<n>).
        self.epoch = 0

    @staticmethod
    def _mean(total, count):
        """Return ``total / count``, or NaN when no images were seen."""
        return total / count if count else float('nan')

    def _train_one_epoch(self):
        """Run one optimisation pass over the train loader.

        Returns:
            tuple: (mean loss, mean PSNR) over all images seen in the pass.
        """
        loss_sum = 0.0
        psnr_sum = 0.0
        seen = 0
        # Wrapping the loader itself (not enumerate(loader)) lets tqdm read
        # len(loader) and show the total and an ETA.
        loop = tqdm(self.train_loader, desc=f'epoch {self.epoch}')
        for image, image_downscaled in loop:
            loss, upscaled_image = self.model.train_step(image_downscaled, image)
            self.model.zero_grad()
            loss.backward()
            self.optim.step()

            batch_loss = loss.item()
            # PSNR is a metric only: detach the prediction so the autograd
            # graph is not involved. Arguments follow the documented
            # (input, target) order; with an explicit data range the value is
            # the same either way.
            batch_psnr = peak_signal_noise_ratio(upscaled_image.detach(),
                                                 image, 1.0).item()
            loop.set_postfix(loss=batch_loss, psnr=batch_psnr)

            # Weight by batch size so a smaller final batch does not skew the mean.
            batch = image.shape[0]
            loss_sum += batch_loss * batch
            psnr_sum += batch_psnr * batch
            seen += batch
        return self._mean(loss_sum, seen), self._mean(psnr_sum, seen)

    def _evaluate(self):
        """Compute loss and PSNR over the test loader without gradients.

        Returns:
            tuple: (mean loss, mean PSNR) over all test images.
        """
        loss_sum = 0.0
        psnr_sum = 0.0
        seen = 0
        # NOTE(review): evaluation reuses model.train_step and does not switch
        # the model to eval mode - confirm the model has no dropout or
        # batch-norm layers that would behave differently.
        with torch.no_grad():
            for image, image_downscaled in self.test_loader:
                loss, upscaled_image = self.model.train_step(image_downscaled, image)
                batch = image.shape[0]
                loss_sum += loss.item() * batch
                psnr_sum += peak_signal_noise_ratio(upscaled_image, image, 1.0).item() * batch
                seen += batch
        return self._mean(loss_sum, seen), self._mean(psnr_sum, seen)

    def train(self):
        """Train for ``configs.epochs`` epochs, evaluating after each one.

        After every epoch the train and test loss/PSNR are appended to the
        histories, and a checkpoint is written whenever the epoch counter is
        a multiple of ``configs.save_every``.
        """
        for _ in range(self.configs.epochs):
            train_loss, train_psnr = self._train_one_epoch()
            self.train_history.append(train_loss)
            self.psnr_train.append(train_psnr)

            test_loss, test_psnr = self._evaluate()
            self.test_history.append(test_loss)
            self.psnr_test.append(test_psnr)

            # Checkpointing happens before the counter is advanced, so the
            # directory epoch_<k> holds k + 1 history entries.
            if self.epoch % self.configs.save_every == 0:
                self.save()

            # Update epoch number
            self.epoch += 1

    def save(self):
        """Write model, optimizer, configs and histories to ``<save_path>/epoch_<epoch>``.

        An existing checkpoint directory is never overwritten; a message is
        printed instead and nothing is saved.
        """
        checkpoint_dir = os.path.join(self.configs.save_path, f'epoch_{self.epoch}')
        if os.path.exists(checkpoint_dir):
            print("This version already exists.")
            return

        # makedirs also creates save_path itself, including nested parents.
        os.makedirs(checkpoint_dir)

        # Save the model first (together with the optimizer).
        self.model.save(os.path.join(checkpoint_dir, 'model.pt'),
                        self.optim)
        torch.save({
            'configs': self.configs,
            'history': [self.train_history,
                        self.test_history,
                        self.psnr_train,
                        self.psnr_test]
            }, os.path.join(checkpoint_dir, 'trainer.pkl'))

    @classmethod
    def load(cls, path: str, epoch: int, type):
        """Loads trainer with the model.

        Args:
            path (str): path to the directory that contains the ``epoch_<n>``
                checkpoint folders.
            epoch (int): number in the name of the checkpoint folder to load.
            type (any): type of srcnn used (the class itself, not an
                instance); its ``load`` must return ``(model, optimizer)``.

        Returns:
            Trainer: trainer with model, optimizer, histories and epoch
            counter restored.
        """
        checkpoint_path = os.path.join(path, f'epoch_{epoch}')
        model, optim = type.load(os.path.join(checkpoint_path, 'model.pt'))

        # SECURITY: trainer.pkl stores a pickled TrainConfigs object, so it
        # cannot be read in weights-only mode. Unpickling can execute
        # arbitrary code - only load checkpoints you created yourself.
        load_dict = torch.load(os.path.join(checkpoint_path, 'trainer.pkl'),
                               weights_only=False)
        configs = load_dict['configs']

        # Nested history list, in the order written by save().
        his = load_dict['history']

        # The constructor needs an optimizer class; the instance it builds is
        # a placeholder that is replaced by the restored optimizer right away.
        trainer = cls(configs, model, torch.optim.Adam)
        trainer.optim = optim

        # Extract loss and metric histories.
        trainer.train_history = his[0]
        trainer.test_history = his[1]
        trainer.psnr_train = his[2]
        trainer.psnr_test = his[3]

        # Resume the epoch counter at the number of completed epochs.
        # Otherwise it restarts at 0 and later checkpoints collide with
        # directories that already exist, so they are skipped.
        trainer.epoch = len(trainer.train_history)

        return trainer
            
        

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [5]:
# Model arguments with the loss set to 'perceptual'; every other ModelArgs field keeps its default.
args = models.ModelArgs(loss='perceptual')

In [6]:
# Build the SRCNN model from the arguments defined above.
model = models.SRCNN(args)

In [7]:
# Show the network held by the model's loss object (its layer listing is the cell output).
model.loss.model

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_si

In [None]:
# NOTE(review): unexecuted scratch cell - presumably layer indices into the loss network shown above; confirm or delete.
(3,8,15,29)

In [25]:
# Default training configuration; constructing it reads train.csv and test.csv from the working directory.
configs = TrainConfigs()
# Trainer for the model above, using Adam as the optimizer class.
trainer = Trainer(configs,model,torch.optim.Adam)

In [26]:
# Run the training loop (the recorded run below was stopped with a KeyboardInterrupt).
trainer.train()

15it [00:25,  1.71s/it, loss=46.2, psnr=5.95]


KeyboardInterrupt: 

In [5]:
# Checkpoint model, optimizer, configs and histories under <save_path>/epoch_<trainer.epoch>.
trainer.save()

In [6]:
# Restore a trainer from the checkpoint stored in train_files/epoch_0.
trainer_2 = Trainer.load('train_files',0,models.SRCNN)

In [7]:
# Continue training with the restored trainer (the recorded run below was also interrupted).
trainer_2.train()

19it [00:17,  1.07it/s, loss=2.74, psnr=14.1]


KeyboardInterrupt: 

In [55]:
# Per-epoch training losses; empty in the recorded run, presumably because no epoch finished before the interrupt - verify.
trainer_2.train_history

[]

In [3]:
# NOTE(review): raised NameError in the recorded run - `trainer` did not exist in that kernel session; run the cell that creates it first.
trainer.model.loss

NameError: name 'trainer' is not defined