In [None]:
# default_exp abs_model

In [None]:
# hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Abstract Model

> Contains the general structure for all models in this library

In [None]:
# export
import torch
from torch import nn, optim
from deeptool.utils import Tracker

In [None]:
# exports


class AbsModel(nn.Module):
    """
    This class contains the general architecture and functionality shared by all models in this library.
    It contains:
    Tracker -> to visualize the progress
    Prep-input -> to handle the input depending on the dataset smoothly
    """

    def __init__(self, args):
        """Init the abstract model.

        Args:
            args: configuration object. This class reads ``args.dataset_type``
                (selects the input-prep function stored in ``self.prep``) and
                ``args.track`` (if truthy, a ``Tracker(args)`` is created as
                ``self.tracker``).
        """
        super().__init__()

        # Setup the input loader
        self.prep = self.select_prep(args.dataset_type)

        # Setup the tracker to visualize the progress
        if args.track:
            self.tracker = Tracker(args)

    def select_prep(self, mode):
        """Return the input-prep function for the dataset type ``mode``.

        Known modes are "MRNet" and "KneeXray". For any other mode, a function
        is returned that raises ``ValueError`` when called. Constructing the
        model still succeeds, but the first attempt to prepare input fails with
        a message naming the bad mode. (The previous zero-argument fallback
        failed with a confusing ``TypeError`` instead.)
        """
        switcher = {
            "MRNet": self.prep_mrnet_input,
            "KneeXray": self.prep_kneexray_input,
        }

        def invalid_prep(data):
            # Accepts one argument like the real prep functions, so the call
            # site reaches this clear error instead of an arity TypeError.
            raise ValueError(
                "Invalid Dataset Type: {!r} (expected one of {})".format(
                    mode, sorted(switcher)
                )
            )

        # Look up the prep function for this dataset type
        return switcher.get(mode, invalid_prep)

    def prep_mrnet_input(self, data):
        """
        This function deals with the MRNet input.
        data = {"img": x, ...} -> returns x
        """
        return data["img"]

    def prep_kneexray_input(self, data):
        """
        This function deals with the KneeXray input.
        data = [x, y] -> returns x
        """
        return data[0]

    @torch.no_grad()
    def watch_progress(self, test_data, iteration):
        """Delegate progress tracking to ``self.tracker.track_progress``.

        Runs under ``torch.no_grad()``. A tracker exists only if the model was
        constructed with ``args.track`` truthy.

        Raises:
            AttributeError: if the model was built without tracking.
        """
        # getattr with a default also covers the case where Tracker is itself
        # an nn.Module (stored in _modules rather than __dict__).
        tracker = getattr(self, "tracker", None)
        if tracker is None:
            raise AttributeError(
                "watch_progress requires a Tracker; construct the model with args.track=True"
            )
        tracker.track_progress(self, test_data, iteration)

In [None]:
# hide
from nbdev.export import *

# Convert the project's notebooks into library modules via nbdev
# (the "Converted ....ipynb" lines in this cell's output are its report).
notebook2script()

Converted 00_dataloader.ipynb.
Converted 01_architecture.ipynb.
Converted 02_utils.ipynb.
Converted 03_parameters.ipynb.
Converted 04_train_loop.ipynb.
Converted 05_abstract_model.ipynb.
Converted 10_diagnosis.ipynb.
Converted 20_dcgan.ipynb.
Converted 21_introvae.ipynb.
Converted 22_vqvae.ipynb.
Converted 23_bigan.ipynb.
Converted 24_mocoae.ipynb.
Converted 33_rnn_vae.ipynb.
Converted 99_index.ipynb.
