In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
plt.rcParams['figure.figsize'] = (15, 8)


In [None]:
import os
# Must be set before cuBLAS is initialised; makes some CUDA (cuBLAS) calls deterministic.
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
import torch
# Raise an error instead of silently running an op that has no deterministic implementation.
torch.use_deterministic_algorithms(True)
if torch.cuda.is_available():
  # New float tensors are created on the GPU by default. Gated so the notebook
  # can still start on a machine without CUDA instead of failing here.
  torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Analysis only: no autograd history is needed. Called as a plain function,
# set_grad_enabled(False) switches grad mode off for subsequent cells.
torch.set_grad_enabled(False)
# torch.inference_mode(True) is deliberately NOT called here: as a bare statement it
# only constructs a context manager without entering it, so it would have no effect.


In [None]:
from typing import NamedTuple, List, Iterator, Optional
import PIL
import torch.nn as nn


In [None]:
class Config(NamedTuple):
  """Hyper-parameters of one training run, as stored in the run's config.pt."""
  # signal
  D: int  # number of dimensions
  sig: str  # signal/image file name, minus '.png'
  # encoding
  type: str  # see complex_ef() for options
  gen: str  # see gen_fn_for(name) for options
  N: int  # encoding size parameter (presumably; meaning is defined by the encoding code — TODO confirm)
  # architecture
  H: int  # number of hidden layers
  G: int  # width of each hidden layer
  # init and training
  seed: int
  opt: str  # training algorithm / optimiser identifier (presumably — TODO confirm)
  mlr: int  # milli learning rate: the learning rate is mlr / 1000
  epoch_count: int = 10000

  @classmethod
  def from_path(cls, path):
    """Load a config that was saved with torch.save as a dict of field values.

    Args:
      path: file path (or anything else torch.load accepts).

    Returns:
      An instance of ``cls``, so subclasses get their own type back.

    Raises:
      TypeError: if the saved dict is missing required fields or has unknown ones.
    """
    # torch.load unpickles the file, which can execute arbitrary code:
    # only load configs you produced yourself.
    d = torch.load(path)
    return cls(**d)


In [None]:
class RunData:
  """Artifacts of one training run, loaded from its results directory.

  The directory is expected to contain config.pt, loss_series.pt, time_series.pt,
  net_state.pt and opt_state.pt, each readable with torch.load.

  Args:
    dirpath: path of the run's results directory.
    map_location: forwarded to torch.load for every tensor file. For example, 'cpu'
      lets a run saved from a GPU be inspected on a machine without one. The default
      (None) is torch.load's own default, so existing callers behave as before.

  Attributes:
    path: the directory passed in.
    config: the run's Config, read from config.pt.
    loss_series: loss values over training (expected to be a 1-D tensor).
    ratio_series: loss_series[i] / loss_series[i + 1]. Values > 1 mark entries where
      the loss went down.
    time_series: per-epoch durations (the charts below divide by 1e9 to get seconds).
    net_state, opt_state: raw contents of net_state.pt / opt_state.pt (presumably
      state_dicts — not inspected here).
  """

  def __init__(self, dirpath: str, map_location=None):
    self.path = dirpath

    def load(filename):
      # torch.load unpickles the file, which can execute arbitrary code:
      # only point RunData at result directories you trust.
      return torch.load(os.path.join(self.path, filename), map_location=map_location)

    self.config = Config.from_path(os.path.join(self.path, 'config.pt'))
    self.loss_series = load('loss_series.pt')
    # Each loss divided by the next one: > 1 wherever the loss decreased.
    self.ratio_series = self.loss_series[:-1] / self.loss_series[1:]
    self.time_series = load('time_series.pt')
    # TODO: visualise these or stop loading them; no chart here uses them.
    self.net_state = load('net_state.pt')
    self.opt_state = load('opt_state.pt')

def print_config(run):
  """Print a human-readable summary of a run's configuration and progress so far.

  Args:
    run: an object exposing ``.path``, ``.config`` (with the Config fields D, sig,
      type, gen, N, H, G, opt, mlr, seed, epoch_count) and ``.time_series``
      (per-epoch durations, scaled by 1e-9 to report whole seconds).
  """
  cfg = run.config
  summary = (
    f'Path: {run.path}',
    'Configuration:',
    f'  Signal: {cfg.D}D, {cfg.sig}.png',
    f'  Encoding: {cfg.type}, {cfg.gen}, N={cfg.N}',
    f'  Architecture: H={cfg.H} hidden layers of width G={cfg.G}',
    # mlr is stored in thousandths of the learning rate.
    f'  Training: {cfg.opt}, lr={cfg.mlr/1000}, seed={cfg.seed}, {cfg.epoch_count} epochs',
    f'  Run so far: {len(run.time_series)} epochs, total time {int(run.time_series.sum() * 1e-9)} seconds.',
    '',  # trailing blank line separates consecutive summaries
  )
  for line in summary:
    print(line)

def chart_loss(run, xmin=1e1):
  """Plot a run's loss curve on log-log axes.

  Args:
    run: an object with a 1-D ``loss_series`` tensor (any device).
    xmin: left limit of the x axis. The default (10) crops the first entries
      out of view, as the original chart did.
  """
  losses = run.loss_series.cpu()
  # Number entries from 1. A log-scaled x axis cannot show position 0, and with
  # 0-based indices every point would sit one step left of its entry number.
  steps = range(1, len(losses) + 1)
  fig, ax0 = plt.subplots(nrows=1)
  ax0.plot(steps, losses)
  ax0.set_xscale('log')
  ax0.set_yscale('log')
  ax0.set_xlim(xmin)
  ax0.grid(which='both')
  ax0.set_xlabel('Epoch')  # assumes one loss entry per epoch — TODO confirm
  ax0.set_ylabel('Loss (MSE)')
  ax0.set_title('Training loss')

# Results directory of the run to inspect. Set the RUN_DIR environment variable to
# look at another run without editing the notebook; the default is the run this
# notebook was written against.
RUN_DIR = os.environ.get(
  'RUN_DIR',
  '/home/alex/hp/exps/results/D_2/sig_s0823_greece/type_simple gen_sinc N_256/H_4 G_85/opt_tpc mlr_1 seed_0/')
data = RunData(RUN_DIR)
chart_loss(data)


In [None]:
# Plateau check: compare the best loss reached in the second half of the series
# with the best reached in the first half. A ratio near 1 means the second half
# barely improved on the first, so further training is unlikely to help.
STOP_RATIO = 0.9  # flag the run once the second half's best exceeds 90% of the first half's

n = len(data.loss_series)
# torch.min on an empty half would fail with a much less helpful message.
assert n >= 2, f'need at least two loss values for the plateau check, got {n}'
half1 = data.loss_series[:n//2]
half2 = data.loss_series[n//2:]
best1 = float(torch.min(half1))
best2 = float(torch.min(half2))
r = best2/best1
print(best1, best2, r)
if r > STOP_RATIO:
  print('stop')


In [None]:
def percentile_chart(series):
  """Plot the sorted values of ``series`` against their percent rank.

  The smallest value is drawn at 0% and the largest at 100%, so reading the curve
  at x% gives (approximately) the x-th percentile of the series.

  Args:
    series: tensor to summarise, expected to be 1-D; it is sorted along its first
      dimension and copied to the CPU for plotting.

  Returns:
    (fig, ax): the new figure and its axes, so callers can add scales, reference
    lines, labels and titles.
  """
  ordered = series.msort()
  fig, ax = plt.subplots(ncols=1)
  ax.set_xlabel("Percent rank")
  ax.set_ylabel("Percentile value")
  # Point i of n sits at x = i, so its percent rank is i / (n - 1); the max()
  # keeps a single-point series from dividing by zero.
  ax.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=max(len(ordered) - 1, 1)))
  ax.grid(which='both')
  ax.plot(ordered.cpu())
  return fig, ax


In [None]:

def chart_ratios(run):
  """Chart the distribution of entry-to-entry loss ratios for a run.

  ``run.ratio_series`` holds loss[i] / loss[i + 1]. Points above the dashed line
  at 1 are where the loss went down; points below it are where the loss went up.
  """
  _, ratio_ax = percentile_chart(run.ratio_series)
  ratio_ax.set_yscale('log')
  ratio_ax.axhline(y=1, color='black', ls='--')  # "no change" reference line
  ratio_ax.set_title('MSE change ratio by percentile')
  ratio_ax.set_ylabel('MSE change ratio')

chart_ratios(data)


In [None]:

def chart_time(run):
  """Chart the distribution of per-epoch training time for a run, in seconds."""
  seconds = run.time_series / 1e9  # raw durations are scaled by 1e9 to seconds
  _, time_ax = percentile_chart(seconds)
  time_ax.set_title('Training time by percentile')
  time_ax.set_ylabel('Time to train one epoch (seconds)')
  time_ax.set_ylim(0)  # anchor the axis at zero so relative costs read honestly

chart_time(data)
