In [39]:
%pip install wandb

Defaulting to user installation because normal site-packages is not writeable
Collecting wandb
  Downloading wandb-0.10.25-py2.py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 3.7 MB/s 
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting sentry-sdk>=0.4.0
  Downloading sentry_sdk-1.0.0-py2.py3-none-any.whl (131 kB)
[K     |████████████████████████████████| 131 kB 18.9 MB/s 
[?25hCollecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting subprocess32>=3.5.3
  Downloading subprocess32-3.5.4.tar.gz (97 kB)
[K     |████████████████████████████████| 97 kB 5.8 MB/s 
Building wheels for collected packages: subprocess32, pathtools
  Building wheel for subprocess32 (setup.py) ... [?25ldone
[?25h  Created wheel for subprocess32: filename=subprocess32-3.5.4-py3-none-any.whl size=6488 sha256=42f2826b8e6101c079c4520c929135262837b78c254702a3eec634a52d345b8e
  Stored in directory: /home/jupyter/

In [60]:
from collections import defaultdict
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.utils as utils
from torch.utils.data import DataLoader
from torch import optim
from tqdm import tqdm
import wandb

from model_base import ModelBase

In [61]:
class Trainer():
    """Minimal training loop with checkpointing and Weights & Biases logging.

    The model is expected to provide:
      * ``training_step(batch=..., step=...)`` returning a dict with a
        ``"loss"`` tensor to back-propagate; any ``"train"`` / ``"val"``
        entries are dicts of metrics that get appended to ``self.history``;
      * ``configure_optimizers()`` (only used when no optimizers are passed);
      * ``sample(batch)`` returning a batch of images (only used by ``fit``).

    Recognised ``config`` keys:
      * ``"n_epochs"`` (required by ``fit``) - number of epochs to train;
      * ``"save_period"`` (default 1) - checkpoint/sample every N epochs,
        a falsy value disables it;
      * ``"max_grad_norm"`` (default 10) - gradient clipping threshold.
    """

    # Channel-wise scale and shift applied to sampled images before logging.
    # NOTE(review): values are kept exactly as in the original code. They look
    # like ImageNet mean/std in reversed channel order with scale and shift
    # swapped - confirm against the normalisation used by the dataset.
    _IMAGE_SCALE = (0.406, 0.456, 0.485)
    _IMAGE_SHIFT = (0.225, 0.224, 0.229)

    def __init__(self,
                 model: nn.Module,
                 config: dict,
                 optimizers: list, # [{"label": ..., "value": opt}] or [(label, opt)]
                 train_loader: DataLoader,
                 val_loader: DataLoader=None,
                 scheduler=None):
        """Store the training components.

        Args:
            model: module to train.
            config: run configuration (see class docstring for used keys).
            optimizers: optimizers to step, one entry per training phase.
                When empty or ``None`` the result of
                ``model.configure_optimizers()`` is used instead.
            train_loader: iterable of training batches.
            val_loader: stored for callers; not used by this class yet.
            scheduler: stored for callers; not used by this class yet.
        """
        self.model = model
        self.config = config
        # Honour the explicitly passed optimizers; the model hook is only a
        # fallback (previously the argument was silently ignored).
        if not optimizers:
            optimizers = self.model.configure_optimizers()
        self.optimizers = self._normalize_optimizers(optimizers)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.scheduler = scheduler
        self.save_period = config.get("save_period", 1)
        self.max_grad_norm = config.get("max_grad_norm", 10)
        self.history = {"train": defaultdict(list),
                        "val": defaultdict(list)}

    @staticmethod
    def _normalize_optimizers(optimizers) -> list:
        """Return optimizers as a list of ``{"label": ..., "value": ...}`` dicts.

        Accepts entries that are already such dicts as well as
        ``(label, optimizer)`` pairs.
        """
        normalized = []
        for entry in optimizers:
            if isinstance(entry, dict):
                normalized.append(entry)
            else:
                label, optimizer = entry
                normalized.append({"label": label, "value": optimizer})
        return normalized

    @staticmethod
    def _make_grid(images: torch.Tensor, nrow: int = 4) -> torch.Tensor:
        """Tile a ``(N, C, H, W)`` batch into one ``(C, rows*H, cols*W)`` image.

        Images are laid out row by row with at most ``nrow`` images per row;
        missing cells in the last row are filled with zeros. No padding is
        inserted between tiles.

        Raises:
            ValueError: if ``images`` is not a non-empty 4-D tensor.
        """
        if images.dim() != 4 or images.shape[0] == 0:
            raise ValueError("expected a non-empty (N, C, H, W) tensor, "
                             f"got shape {tuple(images.shape)}")
        n, c, h, w = images.shape
        cols = min(nrow, n)
        rows = -(-n // cols)  # ceil division
        missing = rows * cols - n
        if missing:
            images = torch.cat([images, images.new_zeros(missing, c, h, w)])
        # (rows, cols, C, H, W) -> (C, rows, H, cols, W) -> (C, rows*H, cols*W)
        tiles = images.reshape(rows, cols, c, h, w).permute(2, 0, 3, 1, 4)
        return tiles.reshape(c, rows * h, cols * w)

    def save_checkpoint(self,
                        epoch: int,
                        checkpoint_path: Path,
                        ) -> None:
        """Write model/optimizer state dicts and the epoch to ``checkpoint_path``.

        Only state dicts are stored (not the pickled model/optimizer objects),
        so the file stays loadable independently of the class definitions and
        with the safe ``torch.load`` defaults of recent PyTorch versions.
        """
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "epoch": epoch,
        }

        for opt in self.optimizers:
            label = opt["label"]
            optimizer = opt["value"]
            checkpoint[f"optimizer_{label}_state_dict"] = optimizer.state_dict()

        torch.save(checkpoint, checkpoint_path)

    def load_checkpoint(self,
                        checkpoint_path: Path,
                        ) -> None:
        """Restore model and optimizer state from a ``save_checkpoint`` file.

        Raises:
            KeyError: if the checkpoint lacks the model state or the state of
                one of the configured optimizers.
        """
        # Load onto CPU first; load_state_dict copies into the live parameters'
        # devices, so this also works on machines without the saving device.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint["model_state_dict"])

        for opt in self.optimizers:
            label = opt["label"]
            optimizer = opt["value"]
            optimizer.load_state_dict(checkpoint[f"optimizer_{label}_state_dict"])

    @torch.enable_grad()
    def train_epoch(self,
                    pbar: "tqdm"
                    ) -> None:
        """Run one pass over ``train_loader``, stepping every optimizer per batch.

        Args:
            pbar: progress bar advanced once per optimizer step.
        """
        self.model.train()

        for batch in self.train_loader:
            for opts in self.optimizers:
                step = opts["label"]
                optimizer = opts["value"]

                # backward() also fills gradients of parameters owned by the
                # other optimizers; clear everything so each step only sees
                # the gradients of its own loss.
                for other in self.optimizers:
                    other["value"].zero_grad()

                info = self.model.training_step(batch=batch,
                                                step=step)
                loss = info['loss']
                loss.backward()
                utils.clip_grad_norm_(parameters=self.model.parameters(),
                                      max_norm=self.max_grad_norm)

                # 'loss' is the tensor used for back-propagation, not a
                # history section, so it is not forwarded to the history.
                self._update_history(
                    {key: value for key, value in info.items() if key != 'loss'})
                self._update_logs(pbar)

                optimizer.step()
                optimizer.zero_grad()

    def _update_logs(self, pbar: "tqdm"):
        """Advance ``pbar`` by one and publish the latest metrics to it and wandb."""
        pbar.update(1)

        postfix = {}
        for split in ("train", "val"):
            for key, values in self.history[split].items():
                # A metric list can exist but be empty (defaultdict lookup).
                if values:
                    postfix[f"{key}_{split}"] = values[-1]

        pbar.set_postfix(postfix)
        wandb.log(postfix)

    def _update_history(self, info):
        """Append metrics from ``info`` to ``self.history``.

        Args:
            info: mapping of history section (``"train"`` / ``"val"``) to a
                dict of ``metric name -> value``. Single-element tensors are
                converted to Python scalars; other tensors are detached and
                moved to CPU so no autograd graph is kept alive.
        """
        for key in info:
            if key not in self.history or not isinstance(info[key], dict):
                print(f"Warning: not valid key in history - {key}")
                continue
            for inner_key, value in info[key].items():
                if isinstance(value, torch.Tensor):
                    if value.numel() == 1:
                        value = value.item()
                    else:
                        value = value.detach().cpu()
                self.history[key][inner_key].append(value)

    def fit(self):
        """Train for ``config["n_epochs"]`` epochs with periodic checkpoints.

        Every ``save_period`` epochs a checkpoint is written to
        ``<cwd>/models`` and a grid of images sampled from the model is
        logged to wandb. The progress bar and the wandb run are closed even
        if training raises.
        """
        n_epochs = self.config["n_epochs"]
        # The bar is advanced once per optimizer step, so size it accordingly.
        try:
            total_steps = n_epochs * len(self.train_loader) * len(self.optimizers)
        except TypeError:
            total_steps = None  # loader without a length
        pbar = tqdm(total=total_steps, position=0, leave=True)

        try:
            wandb.init(project="test-drive", config=self.config)
            wandb.watch(self.model)

            for epoch in range(n_epochs):
                self.train_epoch(pbar)

                if self.save_period and epoch % self.save_period == 0:
                    losses = self.history["train"]["loss"]
                    loss = losses[-1] if losses else float("nan")
                    checkpoint_dir = Path.cwd() / "models"
                    checkpoint_dir.mkdir(parents=True, exist_ok=True)
                    checkpoint_path = \
                        checkpoint_dir / f'loss={loss},e={epoch}.pt'
                    self.save_checkpoint(epoch, checkpoint_path)

                    with torch.no_grad():
                        batch = next(iter(self.train_loader))
                        sample = self.model.sample(batch)
                    # assumes `sample` is a float (N, 3, H, W) batch - TODO confirm
                    grid = self._make_grid(sample.detach().cpu(), nrow=4)
                    images = (grid.permute(1, 2, 0)
                              * torch.tensor(self._IMAGE_SCALE)
                              + torch.tensor(self._IMAGE_SHIFT)).numpy()
                    wandb.log({"generated images": [wandb.Image(images)]})
        finally:
            pbar.close()
            wandb.finish()
