In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (15, 8)


In [None]:
import os
# Set before torch is imported, so the variable is already in the environment
# when torch loads.
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"  # makes some CUDA calls deterministic
import torch
# Request deterministic algorithm implementations from PyTorch.
torch.use_deterministic_algorithms(True)
# From here on, tensors created without an explicit dtype/device use this
# default type (CUDA float32), which the rest of the notebook relies on.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# NOTE(review): cudnn.benchmark=True enables cuDNN auto-tuning of algorithms;
# PyTorch's reproducibility notes recommend False when determinism matters --
# confirm whether this setting is intended alongside use_deterministic_algorithms.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True


In [None]:
from typing import NamedTuple, List, Iterator
import random
import time
from collections import defaultdict
import gc
import numpy as np
import PIL
import torch.nn as nn
import torchvision.transforms.functional as vF
from tqdm.notebook import tqdm, trange


In [None]:

def coords_tensor(D, n):
  """Return the sample coordinates of a regular grid on the unit hypercube.

  Args:
    D: number of dimensions (must be >= 1).
    n: number of samples along every axis.

  Returns:
    Tensor of shape (n**D, D). Row r holds the D coordinates of the r-th grid
    point in row-major ('ij') order, i.e. the last coordinate varies fastest.
    Along each axis the coordinates are k / (n-1) for k = 0..n-1; for n == 1
    the single coordinate is 0.

  Raises:
    ValueError: if D < 1.
  """
  if D < 1:
    raise ValueError(f'coords_tensor needs D >= 1, got D={D}')
  # max(..., 1) only matters for n == 1, where n-1 == 0 would give 0/0 = NaN.
  coord_vals = torch.arange(n) / max(n - 1, 1)
  coord_components = D*[coord_vals]
  # Stacked grid has shape (D, n, ..., n): one coordinate plane per dimension.
  X = torch.stack(torch.meshgrid(coord_components, indexing='ij'))
  # Move the "which coordinate" axis to the end, then merge the D grid axes.
  s = list(range(1, D+1)) + [0]  # result is like [1,2,3,0]
  X = X.permute(s).flatten(0, D-1)
  return X
# print(coords_tensor(2, 3))

def pil_to_samples(img):
  """Convert a PIL image into (coordinates, values) training samples.

  An image of height 1 is treated as a 1-D signal, anything else as 2-D.

  Returns:
    (X, Y): X is the coordinate grid from coords_tensor(D, img.width); Y is a
    column tensor holding the flattened pixel values, moved to CUDA.

  NOTE(review): the 2-D grid is built from img.width only, so this presumably
  expects square, single-channel images -- confirm against the data set.
  """
  if img.height == 1:
    dims = 1
  else:
    dims = 2
  pixels = vF.to_tensor(img).squeeze().cuda()
  # One row per pixel, one column for the value.
  targets = pixels.flatten()[:, None]
  grid = coords_tensor(dims, img.width)
  return grid, targets

def samples_to_pil(X, Y):
  """Turn flat sample values back into a PIL image.

  The number of dimensions is read from the width of X; the side length is the
  rounded D-th root of the number of values in Y. Y is reshaped to
  [1, side, ...] (one leading axis of size 1) before conversion.
  """
  dims = X.shape[1]
  side = round(len(Y) ** (1 / dims))
  image_shape = [1]
  image_shape.extend(side for _ in range(dims))
  return vF.to_pil_image(Y.reshape(image_shape))


In [None]:

class Config(NamedTuple):
  """Immutable description of one experiment: signal, encoding, network, training.

  Being a NamedTuple, a Config is hashable and can be used as a dict key.
  """
  # signal
  D: int  # number of dimensions
  sig: str  # signal/image file name, minus '.png'
  # encoding
  type: str  # encoder family; make_encoder() accepts 'simple' and 'complex'
  gen: str  # see gen_fn_for(name) for options
  N: int  # number of encoding offsets per dimension
  # architecture
  H: int = 0  # number of hidden layers
  G: int = 1  # width of every hidden layer
  # init and training
  seed: int = 0
  opt: str = 'tpc'  # see opt_method_by_name(name) for options
  mlr: int = 1  # milli learning rate  # DEPRECATED, unused
  epoch_count: int = 10000

  @classmethod
  def from_path(cls, path):
    """Load a Config from a file written by torch.save(cfg._asdict(), path)."""
    d = torch.load(path)
    return cls(**d)

  def replace(self, **kwargs):
    """Return a copy of this Config with the given fields replaced."""
    return self._replace(**kwargs)

  def compression_ratio(self):
    """Return (number of trainable parameters) / (number of signal samples).

    The signal size is looked up from hard-coded resolutions; later rules
    override earlier ones (default, then D == 1, then 'B' prefix, then 's' prefix).
    """
    # X, Y = pil_to_samples(PIL.Image.open(self.sig_path))
    # uncompressed: int = len(Y)
    # TODO: don't hardcode these resolutions
    uncompressed = 1260**2
    if self.D == 1:
      uncompressed = 2040
    # startswith() instead of sig[0]: an empty signal name must not raise.
    if self.sig.startswith('B'):
      uncompressed = 1023**2
    if self.sig.startswith('s'):
      uncompressed = 512**2
    compressed: int = self.P
    return compressed / uncompressed

  @property
  def P(self) -> int:
    """Number of weights in the (bias-free) network described by this Config."""
    if self.H == 0:
      # Single linear layer from the encoding straight to the scalar output.
      return self.NN_input_width
    else:
      # input->G, then (H-1) hidden G->G layers, then G->1 output.
      return (self.NN_input_width * self.G) + (self.H - 1) * (self.G * self.G) + (self.G * 1)

  def count_params(self, trainer_to_check=None):
    """Return P after cross-checking it.

    P is always compared against the product-sum of consecutive layer widths,
    and, when a trainer is given, against the trainable parameter count of
    trainer_to_check.net.

    Raises:
      AssertionError: if any of the counts disagree.
    """
    P = self.P
    if trainer_to_check is not None:
      actual = sum(p.numel() for p in trainer_to_check.net.parameters() if p.requires_grad)
      assert P == actual
    widths: List[int] = self.NN_layer_widths
    total_by_widths = 0
    for i in range(1, len(widths)):
      total_by_widths += widths[i-1] * widths[i]
    assert P == total_by_widths
    return P

  @property
  def sig_path(self):
    """Path of the signal image file."""
    return f'../../exps/{self.D}d/{self.sig}.png'

  @property
  def dir_path(self):
    """Directory in which the results of this exact configuration are stored."""
    return '/'.join([
      f'../../exps/results',
      f'D_{self.D}',
      f'sig_{self.sig}',
      f'type_{self.type} gen_{self.gen} N_{self.N}',
      f'H_{self.H} G_{self.G}',
      f'opt_{self.opt} mlr_{self.mlr} seed_{self.seed}'
    ])

  @property
  def NN_input_width(self) -> int:
    """Number of encoded features: N*D for 'simple', N**D for any other type."""
    if self.type == 'simple':
      return self.N * self.D
    else:
      return self.N**self.D

  @property
  def NN_layer_widths(self) -> List[int]:
    """Layer widths from input to output: [input, G, ..., G (H times), 1]."""
    return [self.NN_input_width] + self.H * [self.G] + [1]

  @property
  def lr(self):
    """Learning rate derived from the deprecated `mlr` field (mlr / 1000)."""
    return self.mlr / 1000


In [None]:

def gen_fn_for(name):
  """Look up an elementwise generator function by name.

  Known names: 'ident', 'rect', 'tri', 'gauss', 'sinc'.

  Raises:
    KeyError: if the name is not one of the above.
  """
  def ident(x):
    return x

  def rect(x):
    # 1 strictly inside (-0.5, 0.5), 0.5 exactly on the edges, 0 outside.
    rising = torch.sign(x+0.5)
    falling = torch.sign(x-0.5)
    return 0.5 * (rising - falling)

  def tri(x):
    # Triangle of height 1 supported on [-1, 1].
    return torch.max(1 - torch.abs(x), torch.zeros_like(x))

  def gauss(x):
    return torch.exp(-0.5 * x*x)

  generators = {
    'ident': ident,
    'rect': rect,
    'tri': tri,
    'gauss': gauss,
    'sinc': torch.sinc,
  }
  return generators[name]

def simple_ef(x, gen_fn, N=1):
  """Encode every coordinate independently and concatenate the results.

  Args:
    x: coordinates, one row per sample and one column per dimension.
    gen_fn: elementwise generator function (see gen_fn_for).
    N: number of offsets (0..N-1) used per dimension.

  Returns:
    Tensor with x.shape[0] rows and x.shape[1] * N columns; columns
    [i*N, (i+1)*N) hold gen_fn(x[:, i] * (N-1) - offset) for each offset.
  """
  sample_count, dim_count = x.shape[0], x.shape[1]
  shifts = torch.arange(0, N).unsqueeze(0)
  out = torch.zeros((sample_count, dim_count * N))
  for dim, column in enumerate(x.unbind(dim=1)):
    first = dim * N
    out[:, first:first + N] = gen_fn(column.unsqueeze(1) * (N - 1) - shifts)
  return out

def complex_ef(x, gen_fn, N=1):
  """Encode samples as the outer product of the per-dimension encodings.

  Each dimension d is first encoded as gen_fn(x[:, d] * (N-1) - offset) for
  offsets 0..N-1; the encodings of all dimensions are then combined by an
  outer product and flattened, giving N**D features per sample (matching
  Config.NN_input_width for non-'simple' types).

  Works for any number of dimensions D >= 1. For D == 2 the feature order is
  index = j*N + i, with i the offset of dimension 0 and j that of dimension 1;
  every further dimension becomes the new slowest-varying index.

  Args:
    x: coordinates, one row per sample and one column per dimension.
    gen_fn: elementwise generator function (see gen_fn_for).
    N: number of offsets per dimension.

  Returns:
    Tensor of shape (x.shape[0], N ** x.shape[1]).
  """
  offsets = torch.arange(0, N).unsqueeze(0)

  def encode_dim(d):
    # (samples, N) encoding of a single coordinate column.
    return gen_fn(x[:, d].unsqueeze(1) * (N - 1) - offsets)

  encoded = encode_dim(0)
  for d in range(1, x.shape[1]):
    # Broadcast (samples, 1, K) * (samples, N, 1) -> (samples, N, K), then
    # flatten so that the newly added dimension varies slowest.
    encoded = torch.mul(encoded.unsqueeze(1), encode_dim(d).unsqueeze(2)).flatten(1)
  return encoded

def fixargs(fn, **fixedkwargs):
  """Return a wrapper around fn with some keyword arguments pre-filled.

  Keyword arguments passed at call time take precedence over the fixed ones;
  positional arguments are forwarded unchanged.
  """
  def newfn(*args, **kwargs):
    # Later entries win, so call-time kwargs override the fixed defaults.
    return fn(*args, **{**fixedkwargs, **kwargs})
  return newfn

def make_encoder(config: Config):
  """Build the encoding function described by a Config.

  The returned callable takes a coordinate tensor and applies the encoder
  selected by config.type ('simple' or 'complex') with the generator named by
  config.gen and N = config.N.

  Raises:
    KeyError: if config.gen or config.type is not a known name.
  """
  generator = gen_fn_for(config.gen)
  encoders_by_type = {
    'simple': simple_ef,
    'complex': complex_ef,
  }
  encoder = encoders_by_type[config.type]
  return fixargs(encoder, gen_fn=generator, N=config.N)


def plot_enc():
  """Debug helper: show the first feature of a 2-D 'complex'/'tri' encoding.

  Encodes a 41x41 coordinate grid with N=3, prints the shapes of the grid and
  of the encoding, and displays feature 0 as a grayscale image.
  """
  side = 41
  demo_cfg = Config(D=2, sig='', type='complex', gen='tri', N=3, H=0, G=0)
  encode = make_encoder(demo_cfg)
  grid = coords_tensor(2, side)
  print(grid.shape)
  features = encode(grid)
  print(features.shape)
  fig, ax = plt.subplots(ncols=1)
  first_feature = features[:, 0].reshape([side, side]).cpu()
  ax.imshow(first_feature, cmap='gray', vmin=0, vmax=1)
# plot_enc()


In [None]:

class AdamOpt(object):
  """Full-batch Adam training of `net` on fixed inputs, with MSE loss.

  Uses torch.optim.Adam with its default hyper-parameters. Subclasses change
  the optimizer by overriding _make_optimizer().

  Args:
    net: the module to train.
    phix: encoded inputs fed to the network on every step.
    y: target values compared against net(phix).
  """
  def __init__(self, net, phix, y):
    self.net = net
    self.PhiX = phix
    self.Y = y
    self.criterion = nn.MSELoss()
    self.optimizer = self._make_optimizer()

  def _make_optimizer(self):
    # Single place where the optimizer and its hyper-parameters are chosen.
    return torch.optim.Adam(self.net.parameters())  #, lr=<DEFAULT>=1e-3)

  def load_state_dict(self, t):
    """Restore the optimizer state from a dict produced by state_dict()."""
    self.optimizer.load_state_dict(t)

  def state_dict(self):
    """Return the optimizer state (for checkpointing)."""
    return self.optimizer.state_dict()

  def one_epoch(self):
    """Run one full-batch optimisation step.

    Returns:
      The loss (a Python float) measured before the parameter update.
    """
    self.optimizer.zero_grad()
    loss = self.criterion(self.net(self.PhiX), self.Y)
    loss.backward()
    self.optimizer.step()
    return loss.item()


class TpcAdamOpt(AdamOpt):
  """AdamOpt variant with explicit betas and a small weight decay."""
  # Adam, with parameters copied from the TPC paper demo code:
  # https://colab.research.google.com/github/osiriszjq/complex_encoding/blob/main/complex_encoding.ipynb
  def _make_optimizer(self):
    return torch.optim.Adam(self.net.parameters(), betas=(0.9, 0.999), weight_decay=1e-8)


def opt_method_by_name(name: str):
  """Return the optimizer wrapper class registered under `name`.

  Known names: 'adam' (AdamOpt) and 'tpc' (TpcAdamOpt).

  Raises:
    KeyError: for any other name.
  """
  registry = dict(adam=AdamOpt, tpc=TpcAdamOpt)
  return registry[name]


In [None]:
def layers_for(widths: List[int]) -> Iterator[nn.Module]:
  """Yield the layers of a bias-free ReLU MLP with the given layer widths.

  One nn.Linear(bias=False) is produced per consecutive pair of widths, with
  an nn.ReLU between successive linear layers (none before the first or after
  the last). Fewer than two widths yields nothing.
  """
  for index, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
    if index > 0:
      yield nn.ReLU()
    yield nn.Linear(fan_in, fan_out, bias=False)


In [None]:

class Trainer(object):
  """Trains one network on one signal, as described by a Config.

  On construction the signal is loaded and encoded, the network and optimizer
  are created from a seeded RNG, and -- when `saveload` is true -- the latest
  saved state for this configuration is restored if it exists on disk.

  Bookkeeping invariant: loss_series[k] is the loss after k optimisation steps
  and time_series[k] the duration (ns) of step k, so a trainer at rest has
  len(loss_series) == epoch + 1 and len(time_series) == epoch.
  """
  epoch: int  # number of optimisation steps performed so far
  loss_series: torch.Tensor
  time_series: torch.Tensor
  net: torch.nn.Module

  def __init__(self, config: 'Config', saveload: bool = True):
    """
    Args:
      config: the experiment description.
      saveload: when False, nothing is read from or written to disk.
    """
    if config.compression_ratio() > 1:
      print(f'Warning: Overlarge model {config.dir_path}')
    self.cfg = config
    self.saveload = saveload
    # common init
    self.img = PIL.Image.open(self.cfg.sig_path)
    self.X, self.Y = pil_to_samples(self.img)
    self.Phi = make_encoder(self.cfg)
    self.PhiX = self.Phi(self.X)
    # epoch 0 init
    # Seed before the network is built so its initial weights depend only on cfg.seed.
    torch.manual_seed(self.cfg.seed)
    np.random.seed(self.cfg.seed)
    self.epoch = 0
    self.summary = {}
    # Pre-allocated buffers; grown by ensure_space() and trimmed at the end of train().
    self.loss_series = torch.zeros((500 + 1,))
    self.time_series = torch.zeros((500,))
    self.net = nn.Sequential(*layers_for(self.cfg.NN_layer_widths))
    self.opt = opt_method_by_name(self.cfg.opt)(self.net, self.PhiX, self.Y)
    # load latest epoch if possible
    self.load()

  def load(self):
    """Restore the saved state for this configuration, if there is one.

    Does nothing when saveload is off or when any of the files cannot be
    opened (OSError); in that case the freshly initialised state is kept.
    """
    if not self.saveload:
      return
    try:
      # Read everything first so a missing file leaves the trainer untouched.
      summa = torch.load(f'{self.cfg.dir_path}/summary.pt')
      losss = torch.load(f'{self.cfg.dir_path}/loss_series.pt')
      times = torch.load(f'{self.cfg.dir_path}/time_series.pt')
      netsd = torch.load(f'{self.cfg.dir_path}/net_state.pt')
      optsd = torch.load(f'{self.cfg.dir_path}/opt_state.pt')
    except OSError:
      return
    self.summary = summa
    self.loss_series = losss
    self.time_series = times
    # One timing entry is stored per completed step.
    self.epoch = len(self.time_series)
    self.net.load_state_dict(netsd)
    self.opt.load_state_dict(optsd)
    self.net.train()

  def update_summary(self):
    """Refresh the summary dict; also cross-checks the parameter count."""
    self.summary['P'] = self.cfg.count_params(self)

  def save(self):
    """Write config, summary, series, network and optimizer state to cfg.dir_path.

    Does nothing when saveload is off.
    """
    if not self.saveload:
      return
    os.makedirs(self.cfg.dir_path, exist_ok=True)
    self.update_summary()
    torch.save(self.cfg._asdict(),           f'{self.cfg.dir_path}/config.pt')
    torch.save(self.summary,                 f'{self.cfg.dir_path}/summary.pt')
    torch.save(self.loss_series,             f'{self.cfg.dir_path}/loss_series.pt')
    torch.save(self.time_series,             f'{self.cfg.dir_path}/time_series.pt')
    torch.save(self.net.state_dict(),        f'{self.cfg.dir_path}/net_state.pt')
    torch.save(self.opt.state_dict(),        f'{self.cfg.dir_path}/opt_state.pt')

  def Yhat(self):
    """Return the network's prediction for every sample."""
    return self.net(self.PhiX)

  def loss(self):
    """Return the current MSE between prediction and target as a float."""
    # Pure evaluation: no autograd graph is needed for a value that is
    # immediately converted to a Python float.
    with torch.no_grad():
      return nn.functional.mse_loss(self.Yhat(), self.Y).item()

  def ensure_space(self):
    """Grow both series buffers when the current epoch no longer fits.

    Both buffers are extended by len(loss_series) zeros, which roughly doubles
    their size and keeps len(loss_series) == len(time_series) + 1.
    """
    if self.epoch >= len(self.time_series):
      extra = len(self.loss_series)
      self.loss_series = torch.cat([self.loss_series, torch.zeros((extra,))])
      self.time_series = torch.cat([self.time_series, torch.zeros((extra,))])

  def train(self, stopper) -> bool:  # returns whether progress was made
    """Train until stopper.stopnow() is true, then record the final loss and save.

    Args:
      stopper: object with a stopnow() method, polled before every step.

    Returns:
      False if the stopper fired before any step was taken (nothing is saved),
      True otherwise.
    """
    if stopper.stopnow():
      return False
    self.net.train()
    while not stopper.stopnow():
      start = time.time_ns()
      loss = self.opt.one_epoch()
      self.ensure_space()
      # one_epoch() reports the loss measured before its update, i.e. the loss
      # after `epoch` completed steps.
      self.loss_series[self.epoch] = loss
      self.time_series[self.epoch] = time.time_ns() - start
      self.epoch += 1
    self.ensure_space()
    # Final entry: loss after the last step; then trim the unused buffer tail.
    self.loss_series[self.epoch] = self.loss()
    self.loss_series = self.loss_series[:self.epoch + 1]
    self.time_series = self.time_series[:self.epoch]
    self.save()  # TODO: save periodically
    return True


class Stopper(object):
  """Decides when a Trainer should stop, and drives one training session.

  Training stops as soon as any of these holds:
    * the run is "stuck" (see stuck()),
    * max_progratio is set and progratio() has reached it,
    * the trainer has completed `epoch` steps,
    * `dtime` seconds have passed since start() was called.

  Args:
    trainer: the Trainer to drive.
    epoch: stop once trainer.epoch reaches this value (default: 1e9).
    dtime: wall-clock budget in seconds for one start() call, or None.
    max_progratio: stop once loss_now / loss_at_half_the_epochs is at least
      this value (i.e. progress has slowed down), or None to disable.
  """
  def __init__(self, trainer: 'Trainer', epoch: int = None, dtime: float = None, max_progratio: float = None):
    self.tr = trainer
    self.stopepoch = epoch or 1000000000
    self.dtime = dtime
    self.max_progratio = max_progratio
    self.stoptime = None

  def start(self) -> bool:  # returns whether progress was made
    """Start the wall-clock budget (if any) and train until stopnow() fires."""
    if self.dtime:
      self.stoptime = time.time_ns() + self.dtime * 1e9
    return self.tr.train(self)

  def stuck(self):
    """True if at least 1000 steps were made and the loss at step 1000 exceeds 0.3."""
    return self.tr.epoch >= 1000 and float(self.tr.loss_series[1000]) > 0.3

  def progratio(self):
    """Return latest_loss / loss_at_half_the_epochs (0.0 before 1000 steps).

    Values close to 1 mean the loss has barely improved over the second half
    of training. Returns 1.0 when the reference loss is exactly zero.
    """
    if self.tr.epoch < 1000:
      return 0.0
    l_half = float(self.tr.loss_series[self.tr.epoch//2])
    l_now = float(self.tr.loss_series[self.tr.epoch])
    if l_now == 0:
      # While the training loop is running, slot `epoch` has not been written
      # yet and still holds the buffer's zero fill; reading it would make the
      # ratio 0 and silently disable the max_progratio criterion. Fall back to
      # the most recently recorded loss instead.
      l_now = float(self.tr.loss_series[self.tr.epoch - 1])
    if l_half == 0:
      return 1.0
    return l_now / l_half

  def stopnow(self):
    """Return True if any stopping criterion is currently met."""
    if self.stuck():
      return True
    if self.max_progratio and self.progratio() >= self.max_progratio:
      return True
    if self.stopepoch and self.tr.epoch >= self.stopepoch:
      return True
    if self.stoptime and time.time_ns() > self.stoptime:
      return True
    return False


In [None]:

# def run_many():
#   configs = []
#   for N in range(2, 12+1):
#     sizes = []
#     comp = Config(D=2, sig='', type='complex', gen='sinc', N=N, H=0, G=1)
#     simp = Config(D=2, sig='', type='simple', gen='sinc', N=N, H=1, G=1)
#     for H in [1, 2, 4, 8]:
#       hsimp = simp.replace(H=H)
#       best = None
#       for G in range(2, 1000+1):
#         gsimp = hsimp.replace(G=G)
#         if gsimp.P > comp.P:
#           break
#         best = gsimp
#       if best is None:
#         break
#       configs.append(best)
#       sizes.append(best.G)
#     print(N, sizes)
#
#   configs = [c for c in configs if c.compression_ratio() <= 1]  # remove overlarge models
#   configs.sort(key=lambda x: x.P, reverse=True)
#   # for c in configs:
#   #   print(c.dir_path, c.P)
#
#   for newseed in range(1000):
#     for config in tqdm(configs):
#       # for fname in ['0801_penguin', '0809_lion', '0823_greece', '0872_walnuts', '0887_castle']:
#       config = config.replace(seed=newseed, sig='0823_greece')
#       trainer = Trainer(config)
#       fin = trainer.is_finished
#       if fin:
#         print(f'{config.dir_path} already complete')
#       else:
#         trainer.train()
#         print(f'{config.P}\t{trainer.loss_series[-1].item()}\t{config.dir_path}')


def run_comparable_N_P():
  """Train 'simple'-encoded MLPs whose size is comparable to a 'complex' baseline.

  For every N and every depth H in 1..8, the widest 'simple' network (largest
  G below 2N) whose parameter count does not exceed that of the hidden-layer-
  free 'complex' model with the same N is selected. Three trials are then run;
  each trial:
    1. trains every model to 1000 epochs, moving on to the next seed whenever
       the run turns out to be stuck,
    2. for an increasing schedule of progress-ratio thresholds, repeatedly
       gives every model a 60 s training slice until none of them makes
       progress any more,
    3. increments every model's seed for the next trial.
  """
  # --- select the configurations ---
  configs = []
  for N in [8, 12, 16, 24, 32, 48, 64, 96, 128, 192, 256, 384]:
    comp = Config(D=2, sig='s0823_greece', type='complex', gen='sinc', N=N, H=0, G=1)
    for H in range(1, 8+1):
      widest = None
      for G in range(1, 2*N):
        candidate = comp.replace(type='simple', H=H, G=G)
        if candidate.P > comp.P:
          break
        widest = candidate
      if widest is not None:
        configs.append(widest)

  # configs.reverse()
  # configs.sort(key=lambda x: x.P, reverse=True)
  # for c in configs:
  #   print(c.dir_path, c.P)

  # Current seed of every (seed-less) configuration; starts at 0.
  seed_by_cfg = defaultdict(int)

  def warm_up(config):
    # Train the currently seeded model up to epoch 1000; report whether it is stuck.
    # The stopper (and with it the trainer) is released when this returns.
    c = config.replace(seed=seed_by_cfg[config])
    stopper = Stopper(Trainer(c), epoch=1000)
    stopper.start()
    return c, stopper.stuck()

  def advance(config, req_progratio):
    # Give the model one 60 s slice; returns whether any step was taken.
    stopper = Stopper(
      Trainer(config.replace(seed=seed_by_cfg[config])),
      dtime=60.0,
      max_progratio=req_progratio
    )
    return stopper.start()

  for trial in range(1, 3+1):
    print(f'Starting trial #{trial}')
    for config in configs:
      c, stuck = warm_up(config)
      while stuck:
        print(f'Skipping bad seed: {c}')
        seed_by_cfg[config] += 1
        c, stuck = warm_up(config)
    for req_progratio in [0.5, 0.6, 0.7, 0.8, 0.86, 0.9, 0.93, 0.96]:
      print(f'Progratio: {req_progratio}')
      prevprogcount = -1
      while True:
        progcount = 0
        for config in configs:
          if advance(config, req_progratio):
            progcount += 1
        # Only report when the number of still-progressing models changes.
        if progcount != prevprogcount:
          prevprogcount = progcount
          print(f'Number of models making progress: {progcount}/{len(configs)}')
        if progcount == 0:
          break
    print(f'Trial #{trial} finished. Incrementing seeds.')
    for config in configs:
      seed_by_cfg[config] += 1


run_comparable_N_P()


In [None]:

def run_1D():
  """Train a sweep of 1-D 'simple'-encoded models over 1000 seeds.

  For every N in 10, 20, ..., 100 and every depth H in 1..10, the widest
  network (G in 2..2039) whose parameter count does not exceed that of the
  reference model `cf` is selected. Overlarge models (compression ratio > 1)
  are dropped and the rest are trained largest-first, each for
  config.epoch_count epochs. Runs that already reached that epoch count (as
  restored from disk by Trainer) are skipped.
  """
  configs = []
  cf = Config(D=1, sig='0823_greece', type='simple', gen='sinc', N=1000, H=0, G=1)
  periodN = 10
  for N in range(periodN, 100+1, periodN):
    for H in range(1, 10+1):
      most_P = None
      for G in range(2, 2040):
        c = cf.replace(type='simple', N=N, H=H, G=G)
        if c.P > cf.P:
          break
        most_P = c
      if most_P:
        configs.append(most_P)

  # configs.reverse()

  configs = [c for c in configs if c.compression_ratio() <= 1]  # remove overlarge models
  # configs = [c for c in configs if c.P <= 2000]  # remove overlarge models
  configs.sort(key=lambda x: x.P, reverse=True)
  # for c in configs:
  #   print(c.dir_path, c.P)

  for newseed in range(1000):#, desc='seeds'):
    for config in tqdm(configs, desc='configs', leave=True):
      seeded = config.replace(seed=newseed)
      trainer = Trainer(seeded)
      # Trainer has no `is_finished` attribute and train() requires a stopper,
      # so completion is expressed through the epoch counter and a Stopper.
      if trainer.epoch >= seeded.epoch_count:
        continue
      Stopper(trainer, epoch=seeded.epoch_count).start()
      # print(f'{config.dir_path}\t{trainer.loss_series[-1].item()}\t{config.P}')


# run_1D()


In [None]:

def find_prop(rootdir: str = '../../exps/results'):
  """Print every saved run whose compression ratio is >= 1.

  Walks `rootdir` for 'config.pt' files, loads each as a Config and prints
  "<dir_path> <ratio>" for the offending ones, sorted in descending order.
  """
  config_files = [
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(rootdir)
    for filename in filenames
    if filename == 'config.pt'
  ]
  cfgs = [Config.from_path(config_file) for config_file in config_files]
  overlarge = []
  for c in tqdm(cfgs):
    ratio = c.compression_ratio()
    if ratio >= 1:
      overlarge.append((c.dir_path, ratio))
  for p, r in sorted(overlarge, reverse=True):
    print(p, r)

# find_prop()


In [None]:

def resave_path(path: str, verbose: bool = True):
  """Load the run stored next to the given config.pt and save it again.

  Args:
    path: path of a 'config.pt' file.
    verbose: print the run directory (path minus the trailing 'config.pt').
  """
  if verbose:
    print(f'resaving path: {path[:-9]} ...')
  # Constructing the Trainer restores the saved state; save() writes it back.
  Trainer(Config.from_path(path)).save()

def resave_all(rootdir: str = '../../exps/results'):
  """Re-save every run found (via its 'config.pt') anywhere below `rootdir`."""
  paths = [
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(rootdir)
    for filename in filenames
    if filename == 'config.pt'
  ]
  for path in tqdm(paths):
    resave_path(path, verbose=False)
  print(f'resaved {len(paths)} runs')

# resave_path('/home/alex/hp/exps/results/D_1/sig_0801_penguin/type_simple gen_sinc N_2/H_0 G_1/opt_adam mlr_1 seed_0/config.pt')
# resave_all()


In [None]:

def verify_path(path: str, epochs: int = 1, verbose: bool = True):
  """Check a saved run by re-training its first epochs from scratch.

  A fresh, non-persisting Trainer is trained for `epochs` steps and its loss
  series is compared bit-for-bit with the beginning of the saved one. The
  saved series are also checked against the configured epoch count.

  Args:
    path: path of the run's 'config.pt' file.
    epochs: number of epochs to re-train and compare.
    verbose: print progress messages.

  Raises:
    AssertionError: if the saved series have an unexpected length or the
      re-trained losses differ from the saved ones.
  """
  if verbose:
    print(f'checking path: {path}')
  cfg = Config.from_path(path)
  rerun = Trainer(cfg.replace(epoch_count=epochs), saveload=False)
  # Trainer.train() needs a stopper; stop after exactly `epochs` steps.
  Stopper(rerun, epoch=epochs).start()
  rerun_loss_series = rerun.loss_series[:epochs+1].detach()
  del rerun
  prevrun = Trainer(cfg)
  # simple checks
  # NOTE(review): these assume the saved run was trained for exactly
  # cfg.epoch_count epochs -- confirm this holds for stopper-driven runs.
  assert len(prevrun.loss_series) == cfg.epoch_count + 1
  assert len(prevrun.time_series) == cfg.epoch_count
  # rerun some training
  assert torch.equal(prevrun.loss_series[:epochs+1], rerun_loss_series[:epochs+1])
  del prevrun
  if verbose:
    print(f'all checks out')

def verify_all(rootdir: str = '../../exps/results'):
  """Verify every run found below `rootdir`, in random order.

  Directories whose path contains 'lstsq' are skipped.
  """
  paths = []
  for dirpath, _dirnames, filenames in os.walk(rootdir):
    if 'lstsq' in dirpath:
      continue
    paths.extend(
      os.path.join(dirpath, filename)
      for filename in filenames
      if filename == 'config.pt'
    )
  random.shuffle(paths)  # so that there is no run that is always checked last
  for path in tqdm(paths):
    verify_path(path, verbose=False)
  print(f'verified {len(paths)} runs')

# verify_path('/home/alex/hp/exps/results/D_1/sig_0801_penguin/type_simple gen_sinc N_2/H_0 G_1/opt_adam mlr_1 seed_0/config.pt', epochs=10)
# verify_all()
