In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Default size (width, height in inches) for every figure in this notebook.
plt.rcParams['figure.figsize'] = (15, 8)


In [None]:
import os
# Must be set before any cuBLAS work: with deterministic algorithms enabled,
# cuBLAS calls need this workspace config to run deterministically.
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"  # makes some CUDA calls deterministic
import torch
torch.use_deterministic_algorithms(True)
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
# Disable autograd globally; the lstsq / pinv solves below need no gradients.
# Note: a bare `torch.inference_mode(True)` call has no effect (it only builds a
# context manager that is never entered), so it is intentionally not used here.
torch.set_grad_enabled(False)


In [None]:
import random
import time
from typing import NamedTuple

import PIL
import PIL.Image  # `import PIL` alone does not load the PIL.Image submodule used below
import torch.linalg as linalg
from torch.nn.functional import mse_loss
import torchvision.transforms.functional as vF
from tqdm.notebook import tqdm, trange


In [None]:

def coords_tensor(D, n):
  """Return the coordinates of an n^D sample grid spanning [0, 1] per dimension.

  Args:
    D: number of dimensions.
    n: number of samples along each dimension.
  Returns:
    Tensor of shape (n**D, D). Rows are in row-major ('ij') order, so for
    D == 2 row i*n + j is [x_i, x_j] with x_k = k / (n - 1).
  """
  # max(..., 1) keeps a single-sample axis at 0.0 instead of 0/0 = nan;
  # for n >= 2 the divisor is unchanged.
  coord_vals = torch.arange(n) / max(n - 1, 1)
  coord_components = D*[coord_vals]
  X = torch.stack(torch.meshgrid(coord_components, indexing='ij'))
  # Move the stacked-component axis last, e.g. [1, 2, 0] for D == 2.
  s = list(range(1, D+1)) + [0]
  X = X.permute(s).flatten(0, D-1)
  return X
# print(coords_tensor(2, 3))

def pil_to_samples(img):
  """Convert a single-channel PIL image into (coordinates, values) samples.

  A single-row image is treated as a 1-D signal; anything else as a square
  2-D signal with side img.width.

  Returns:
    X: (n**D, D) grid coordinates from coords_tensor(D, n).
    Y: (n**D, 1) pixel values, flattened row-major so row k matches X[k].
  Raises:
    ValueError: if the number of pixel values does not match the grid,
      e.g. for a non-square 2-D image or a multi-channel image.
  """
  D = 1 if (img.height == 1) else 2
  n = img.width
  t = vF.to_tensor(img).squeeze()
  Y = t.flatten().unsqueeze(1)
  X = coords_tensor(D, n)
  # Fail here with a clear message instead of a shape error deep inside a solver.
  if Y.shape[0] != X.shape[0]:
    raise ValueError(
      f'expected {X.shape[0]} pixel values for a {D}-D signal of width {n}, '
      f'got {Y.shape[0]} (image {img.width}x{img.height}, mode {img.mode}); '
      'the image must be single-channel and square (or a single row)')
  return X, Y

def samples_to_pil(X, Y):
  """Inverse of pil_to_samples: reshape sample values back into an image.

  Args:
    X: sample coordinates; only its column count D is used.
    Y: sample values with n**D elements, in the row-major order of X.
  Returns:
    A PIL image of height 1 and width n for D == 1, or n x n for D == 2.
  """
  D = X.shape[1]
  n = round(len(Y)**(1/D))
  # Fitted reconstructions can overshoot [0, 1]. to_pil_image scales float
  # tensors by 255 and casts to uint8 without clamping, so out-of-range values
  # would not saturate; clamp them first.
  t = Y.clamp(0, 1).reshape([1] + D*[n])
  return vF.to_pil_image(t)


In [None]:

class Config(NamedTuple):
  """Immutable description of one experiment: signal, encoding and solver settings.

  Derived paths (sig_path, dir_path, short_dir_path) are relative to the
  notebook's working directory.
  """
  # signal
  D: int  # number of dimensions
  sig: str  # signal/image file name, minus '.png'
  # encoding
  type: str  # 'simple' (SimpleSolver) or 'complex' (ComplexSolver)
  gen: str  # generator name; see gen_fn_for(name) for options
  N: int  # number of shifted generator copies per input dimension

  # architecture
  H: int = 0
  G: int = 1
  # init and training
  seed: int = 0
  opt: str = 'lstsq'
  mlr: int = 1  # milli learning rate  # DEPRECATED, unused
  epoch_count: int = 1

  @classmethod
  def from_path(cls, path):
    """Load a config from a file holding a torch.save'd dict of field values.

    torch.load unpickles the file, so only use this on trusted files.
    """
    d = torch.load(path)
    # cls rather than Config, so subclasses load as instances of themselves.
    return cls(**d)

  def replace(self, **kwargs):
    """Return a copy with the given fields replaced (public alias of _replace)."""
    return self._replace(**kwargs)

  @property
  def P(self) -> int:
    """Number of encoding features: N*D for 'simple', N**D for any other type."""
    if self.type == 'simple':
      return self.N * self.D
    return self.N**self.D

  @property
  def sig_path(self):
    """Path of the input signal image."""
    return f'../../exps/{self.D}d/{self.sig}.png'

  @property
  def dir_path(self):
    """Per-run results directory; encodes every field except epoch_count."""
    return '/'.join([
      '../../exps/results',
      f'D_{self.D}',
      f'sig_{self.sig}',
      f'type_{self.type} gen_{self.gen} N_{self.N}',
      f'H_{self.H} G_{self.G}',
      f'opt_{self.opt} mlr_{self.mlr} seed_{self.seed}'
    ])

  @property
  def short_dir_path(self):
    """Directory for N-sweep results, keyed only by D, sig, type and gen."""
    return '/'.join([
      '../../exps/solutions',
      f'D_{self.D} sig_{self.sig}',
      f'type_{self.type} gen_{self.gen}',
    ])


In [None]:

def gen_fn_for(name):
  """Look up an elementwise generator (basis) function by name.

  Options: 'ident'; 'rect' (unit box, 0.5 at |x| == 0.5); 'tri' (unit hat);
  'gauss' (exp(-x^2 / 2)); 'sinc' (torch.sinc, i.e. sin(pi x) / (pi x)).
  Any other name raises KeyError.
  """
  def ident(x):
    return x

  def rect(x):
    return 0.5 * (torch.sign(x + 0.5) - torch.sign(x - 0.5))

  def tri(x):
    return torch.max(1 - torch.abs(x), torch.zeros_like(x))

  def gauss(x):
    return torch.exp(-0.5 * x * x)

  generators = {
    'ident': ident,
    'rect': rect,
    'tri': tri,
    'gauss': gauss,
    'sinc': torch.sinc,
  }
  return generators[name]

def onedim_ef(x, gen_fn, N=1):
  """Encode the first coordinate of each sample with N shifted copies of gen_fn.

  Column k of the (num_samples, N) result is gen_fn(x[:, 0] * (N - 1) - k).
  """
  first_coord = x[:, 0, None]
  shifts = torch.arange(N)[None, :]
  return gen_fn(first_coord * (N - 1) - shifts)

def simple_ef(x, gen_fn, N=1):
  """Additive ('simple') encoding: N shifted copies of gen_fn per coordinate.

  Returns a (num_samples, num_dims * N) tensor of the default dtype; columns
  d*N .. d*N + N - 1 hold gen_fn(x[:, d] * (N - 1) - k) for k = 0 .. N-1.
  """
  num_samples, num_dims = x.shape[0], x.shape[1]
  shifts = torch.arange(N)[None, :]
  out = torch.zeros((num_samples, num_dims * N))
  for d in range(num_dims):
    lo = d * N
    out[:, lo:lo + N] = gen_fn(x[:, d, None] * (N - 1) - shifts)
  return out

def complex_ef(x, gen_fn, N=1):
  """Tensor-product encoding of the first two coordinates of each sample.

  Returns a (num_samples, N*N) tensor whose entry j*N + i is
  gen_fn(x[:, 1]*(N-1) - j) * gen_fn(x[:, 0]*(N-1) - i).
  """
  shifts = torch.arange(N)[None, :]
  scale = N - 1
  along0 = gen_fn(x[:, 0, None] * scale - shifts)  # (B, N), indexed by i
  along1 = gen_fn(x[:, 1, None] * scale - shifts)  # (B, N), indexed by j
  outer = along1[:, :, None] * along0[:, None, :]  # (B, N, N), [b, j, i]
  return outer.flatten(1)

def fixargs(fn, **fixedkwargs):
  """Bind default keyword arguments to fn.

  Keyword arguments given at call time override the fixed ones; positional
  arguments are passed through unchanged.
  """
  def bound(*args, **kwargs):
    return fn(*args, **{**fixedkwargs, **kwargs})
  return bound

def make_encoder(config: Config):
  """Build the coordinate encoder for config.

  Always the additive simple_ef (ComplexSolver also uses it, on 1-D
  coordinates), with gen_fn_for(config.gen) as generator and config.N shifts.
  """
  return fixargs(simple_ef, gen_fn=gen_fn_for(config.gen), N=config.N)


In [None]:

class SimpleSolver(object):
  """Least-squares fit of a signal with the additive ('simple') encoding.

  The encoder from make_encoder(cfg) maps each sample's coordinates to a row
  of the design matrix PhiX; the model is the linear map PhiX @ W.
  """

  def __init__(self, config: Config):
    self.cfg = config
    self.img = PIL.Image.open(self.cfg.sig_path)
    self.X, self.Y = pil_to_samples(self.img)
    self.Phi = make_encoder(self.cfg)
    self.PhiX = self.Phi(self.X)  # design matrix, one row per sample

    # One weight per encoding feature (D*N columns for the simple encoding);
    # sized from PhiX so W always matches the design matrix.
    self.W = torch.zeros((self.PhiX.shape[1], 1))

    self.ns_taken = 0  # nonzero once solve() has run
    self.loss0 = self.loss()  # loss of the all-zero model
    self.loss1 = 1.0  # placeholder until solve() stores the fitted loss

  def Yhat(self):
    """Model prediction, one row per sample (same layout as self.Y)."""
    return self.PhiX @ self.W

  def loss(self):
    """Mean squared error between prediction and samples, as a Python float."""
    return mse_loss(self.Yhat(), self.Y).item()

  def solve(self) -> None:
    """Fit W by linear least squares; record the final loss and elapsed time.

    A second call is a no-op that only prints a warning.
    """
    if self.ns_taken != 0:
      print(f'Warning: already solved {self.cfg.dir_path}')
      return
    # perf_counter_ns is monotonic and intended for interval timing, unlike
    # the wall-clock time_ns, which jumps when the system clock is adjusted.
    start = time.perf_counter_ns()

    solution = linalg.lstsq(self.PhiX, self.Y, driver='gelsy').solution
    self.W.copy_(solution)

    # The final loss evaluation is deliberately inside the timed region.
    self.loss1 = self.loss()
    self.ns_taken += time.perf_counter_ns() - start


In [None]:

class ComplexSolver(object):
  """Least-squares fit of a square 2-D image with a separable (tensor-product) basis.

  With PhiX the (n, N) 1-D encoding of the n pixel coordinates along one axis,
  the model is PhiX @ W @ PhiX.T for a (1, N, N) weight tensor W.
  """

  def __init__(self, config: Config):
    self.cfg = config
    self.img = PIL.Image.open(self.cfg.sig_path)
    self.n = self.img.width  # the image is assumed to be n x n
    self.X, self.Y = pil_to_samples(self.img)

    # (n, 1) coordinates along one image axis: the same values as each column
    # of the 2-D grid built by coords_tensor.
    self.x1 = coords_tensor(1, self.n)
    self.Phi = make_encoder(self.cfg)
    self.PhiX = self.Phi(self.x1)  # (n, N)
    self.W = torch.zeros((1, self.cfg.N, self.cfg.N))

    self.ns_taken = 0  # nonzero once solve() has run
    self.loss0 = self.loss()  # loss of the all-zero model
    self.loss1 = 1.0  # placeholder until solve() stores the fitted loss

  def Yhat(self):
    """Model prediction as an (n*n, 1) column in the same order as self.Y."""
    # (1, n, n) -> (n, n, 1): swaps the two pixel axes, the inverse of the
    # reshaping of Y done in solve().
    yhat = (self.PhiX @ self.W @ self.PhiX.T).transpose(0, 2)
    return yhat.flatten().unsqueeze(1)

  def loss(self):
    """Mean squared error between prediction and samples, as a Python float."""
    return mse_loss(self.Yhat(), self.Y).item()

  def solve(self) -> None:
    """Fit W in closed form; record the final loss and elapsed time.

    For the model PhiX @ W @ PhiX.T the least-squares solution is
    pinv(PhiX) @ Y @ pinv(PhiX).T. A second call is a no-op that only prints
    a warning.
    """
    if self.ns_taken != 0:
      print(f'Warning: already solved {self.cfg.dir_path}')
      return
    # perf_counter_ns is monotonic and intended for interval timing, unlike
    # the wall-clock time_ns, which jumps when the system clock is adjusted.
    start = time.perf_counter_ns()

    ix = linalg.pinv(self.PhiX)  # (N, n)
    # (n*n, 1) -> (1, n, n), with the pixel axes swapped to match Yhat().
    Y_shaped = self.Y.reshape(self.n, self.n, 1).transpose(0, 2)
    self.W = ix @ Y_shaped @ ix.T

    # The final loss evaluation is deliberately inside the timed region.
    self.loss1 = self.loss()
    self.ns_taken += time.perf_counter_ns() - start


In [None]:

def vary_N_and_save(cfg: Config, Nlim=None):
  """Solve cfg for N = 2 .. Nlim + 1 and save the per-N results.

  Writes N.pt, P.pt, mse.pt and time.pt into cfg.short_dir_path. Each is a
  length-Nlim tensor whose entry i belongs to N = i + 2. The N field of cfg is
  ignored.

  Args:
    cfg: experiment config; cfg.type selects SimpleSolver or ComplexSolver.
    Nlim: number of N values to try; defaults to the signal width + 10.
  Raises:
    NotImplementedError: for an unknown cfg.type.
  """
  # Only the header is needed for the width, so close the file right away
  # instead of leaving the handle open until garbage collection.
  with PIL.Image.open(cfg.sig_path) as img:
    n = img.width
  if Nlim is None:
    Nlim = n + 10
  Ns = torch.zeros((Nlim,))
  Ps = torch.zeros_like(Ns)
  mses = torch.zeros_like(Ns)
  times = torch.zeros_like(Ns)
  # Largest N first (presumably so the most expensive solves surface problems early).
  for i in reversed(range(Nlim)):
    cfg = cfg.replace(N=i + 2)
    if cfg.type == 'simple':
      solver = SimpleSolver(cfg)
    elif cfg.type == 'complex':
      solver = ComplexSolver(cfg)
    else:
      raise NotImplementedError
    solver.solve()
    Ns[i] = cfg.N
    Ps[i] = cfg.P
    mses[i] = solver.loss1
    times[i] = solver.ns_taken
    del solver
  dst: str = cfg.short_dir_path
  os.makedirs(dst, exist_ok=True)
  torch.save(Ns, os.path.join(dst, 'N.pt'))
  torch.save(Ps, os.path.join(dst, 'P.pt'))
  torch.save(mses, os.path.join(dst, 'mse.pt'))
  torch.save(times, os.path.join(dst, 'time.pt'))

def filenames_1d():
  """Base names (without '.png') of the signals used by the 1-D sweeps.

  filenames_2d derives the 2-D signal names from these as well.
  """
  names = [
    '0801_penguin',
    '0809_lion',
    '0823_greece',
    '0872_walnuts',
    '0887_castle',
  ]
  return tuple(names)

def filenames_s1d(dirpath='../../exps/1d/'):
  """Yield the names (minus '.png') of the small 1-D signals in dirpath.

  Small signals are the PNG files whose names start with 's'. They are
  yielded in sorted order so repeated runs iterate deterministically
  (os.listdir order is arbitrary).
  """
  suffix = '.png'
  for fname in sorted(os.listdir(dirpath)):
    if fname.startswith('s') and fname.endswith(suffix):
      # Strip only the trailing extension; str.replace would also remove any
      # '.png' occurring earlier in the name.
      yield fname[:-len(suffix)]

def filenames_2d(smallonly):
  """Names of the 2-D signals derived from filenames_1d().

  For each base name the small 's'-prefixed variant is listed first, followed
  by the base name itself unless smallonly is truthy.
  """
  names = []
  for base in filenames_1d():
    small = 's' + base
    names.extend([small] if smallonly else [small, base])
  return names

def save_all_1d():  # takes ~10min
  """Simple-encoding N-sweeps for every generator over filenames_1d()."""
  generators = ['sinc', 'gauss', 'tri', 'rect']
  for gen in tqdm(generators, desc='gen'):
    signals = tqdm(filenames_1d(), desc='sig', leave=False)
    for sig in signals:
      vary_N_and_save(Config(D=1, sig=sig, type='simple', gen=gen, N=0))

def save_all_s1d():
  """Simple-encoding N-sweeps for every small 1-D signal and every generator."""
  signals = list(filenames_s1d())
  for sig in tqdm(signals, desc='file'):
    for gen in ('sinc', 'gauss', 'tri', 'rect'):
      vary_N_and_save(Config(D=1, sig=sig, type='simple', gen=gen, N=0))

def _save_all_2d(solver_type, smallonly):
  # Shared 2-D sweep: every generator x every 2-D signal, for one Config.type.
  for gen in tqdm(['sinc', 'gauss', 'tri', 'rect'], desc='gen'):
    for sig in tqdm(filenames_2d(smallonly), desc='sig', leave=False):
      vary_N_and_save(Config(D=2, sig=sig, type=solver_type, gen=gen, N=0))

def save_all_simple_2d(smallonly=False):
  """Simple-encoding N-sweeps over the 2-D signals (small ones only if smallonly)."""
  _save_all_2d('simple', smallonly)

def save_all_complex_2d(smallonly=False):
  """Complex-encoding N-sweeps over the 2-D signals (small ones only if smallonly)."""
  _save_all_2d('complex', smallonly)

def save_greece():
  """Complex-encoding N-sweeps of the 2-D 0823_greece image for three generators."""
  for gen in tqdm(('gauss', 'tri', 'rect'), desc='gen'):
    greece = Config(D=2, sig='0823_greece', type='complex', gen=gen, N=0)
    vary_N_and_save(greece)


# Experiment drivers: uncomment the sweep(s) to run. Each one writes its
# results under Config.short_dir_path via vary_N_and_save. Currently only the
# small 1-D sweep is enabled.
# vary_N_and_save(Config(D=2, sig='0823_greece', type='simple', gen='sinc', N=0), Nlim=50)
# save_all_1d()
save_all_s1d()
# save_all_simple_2d(True)
# save_all_complex_2d(True)
# save_greece()


In [None]:

def chart(cfgs=None, nmax=50):
  """Plot saved MSE-vs-N curves (log-scale MSE) for several configs on one axes.

  Args:
    cfgs: configs whose short_dir_path holds the N.pt and mse.pt written by
      vary_N_and_save. Defaults to the simple and complex sinc fits of the
      2-D 0823_greece image.
    nmax: plot at most the first nmax entries of each curve.
  Returns:
    The matplotlib Axes.
  """
  if cfgs is None:
    cfgs = [
      Config(D=2, sig='0823_greece', type='simple', gen='sinc', N=0),
      Config(D=2, sig='0823_greece', type='complex', gen='sinc', N=0),
    ]
  fig, ax = plt.subplots()
  for cfg in cfgs:
    dst: str = cfg.short_dir_path
    Ns = torch.load(os.path.join(dst, 'N.pt'))
    mses = torch.load(os.path.join(dst, 'mse.pt'))
    ax.plot(Ns[:nmax].cpu(), mses[:nmax].cpu(),
            label=f'{cfg.sig} {cfg.type} {cfg.gen}')

  ax.set_yscale('log')
  ax.set_xlabel('N')
  ax.set_ylabel('MSE')
  ax.set_title('Reconstruction MSE vs. N')
  ax.legend()
  return ax

chart();


In [None]:

def check_lstsq_path(dirpath: str):
  """Re-solve a saved run and print it if its loss no longer matches.

  Reads config.pt and loss_series.pt from dirpath. Entry 1 of the loss series
  is taken as the recorded loss, and a difference above 1e-9 is reported.
  """
  cfg = Config.from_path(os.path.join(dirpath, 'config.pt'))
  recorded = torch.load(os.path.join(dirpath, 'loss_series.pt'))[1].item()

  solver_cls = {'simple': SimpleSolver, 'complex': ComplexSolver}.get(cfg.type)
  if solver_cls is None:
    raise NotImplementedError
  solver = solver_cls(cfg)
  solver.solve()
  fresh = solver.loss1
  del solver

  delta = recorded - fresh
  if abs(delta) > 1e-9:
    print(dirpath)
    print(delta, recorded, fresh)


def check_lstsq_all(etype=None, root='/home/alex/hp/exps/results/D_2'):
  """Run check_lstsq_path on every saved lstsq run found under root.

  Args:
    etype: optional substring (e.g. 'type_complex') that a run directory's
      path must contain to be checked.
    root: directory tree to search for run folders containing config.pt.
      NOTE(review): the default is a machine-specific absolute path; pass
      root explicitly on other machines.
  """
  dirpaths = [
    dirpath
    for dirpath, _, filenames in os.walk(root)
    if (not etype or etype in dirpath)
    and 'lstsq' in dirpath
    and 'config.pt' in filenames
  ]
  random.shuffle(dirpaths)  # so that there is no run that is always checked last
  for dpath in tqdm(dirpaths):
    check_lstsq_path(dpath)
  print(f'verified {len(dirpaths)} runs')


# check_lstsq_path('/home/alex/hp/exps/results/D_2/sig_0823_greece/type_complex gen_sinc N_4/H_0 G_1/opt_lstsq mlr_1 seed_0')
# check_lstsq_all()
