In [None]:
%matplotlib inline
# Notebook plotting setup: render figures inline with a default size of 15x8 inches.
# NOTE(review): no cell in this section plots anything; drop this cell if later cells don't either.
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (15, 8)


In [None]:
import os
# Deterministic cuBLAS kernels need a fixed workspace size; this must be set before
# CUDA/cuBLAS is initialised, i.e. before torch first touches the GPU.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import torch
torch.use_deterministic_algorithms(True)
# New float tensors are created on the GPU by default.
# NOTE(review): set_default_tensor_type is deprecated since PyTorch 2.1; the modern
# equivalent is torch.set_default_dtype(torch.float32) + torch.set_default_device('cuda').
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# This notebook only reads results: disable autograd globally.
# (torch.inference_mode only takes effect inside `with`/as a decorator, so a bare
# call to it does nothing; set_grad_enabled(False) is what actually applies here.)
torch.set_grad_enabled(False)


In [None]:
from typing import Dict
from tqdm.notebook import tqdm
from collections import defaultdict
import math
import random
import csv


In [None]:

def tryload(filepath, missing_ok=False):
  """Load an object previously written with ``torch.save``.

  Args:
    filepath: path of the file to load.
    missing_ok: if True, return None when the file does not exist instead of raising.

  Returns:
    The deserialised object, or None when the file is missing and ``missing_ok`` is set.

  Raises:
    FileNotFoundError: the file does not exist and ``missing_ok`` is False.
    Other errors from ``torch.load`` (permissions, I/O, corrupt data) propagate unchanged.
  """
  try:
    return torch.load(filepath)
  except FileNotFoundError as err:
    # Only a genuinely absent file counts as "missing"; other OSErrors (e.g. permission
    # denied) must not be silently turned into None or mislabelled as missing.
    if not missing_ok:
      raise FileNotFoundError(f'no file found at {filepath}') from err
    return None

class RunData(object):
  """The saved artifacts of a single experiment run, loaded eagerly on construction.

  Attributes:
    path: the run directory passed to the constructor.
    config: contents of ``config.pt`` (required; loading fails if absent).
    summary: contents of ``summary.pt``, or None when ``tryload`` reports it missing.
    loss_series: contents of ``loss_series.pt`` (required).
    time_series: contents of ``time_series.pt`` (required).
  """
  def __init__(self, dirpath: str):
    self.path = dirpath

    def load(name, **kwargs):
      # Every artifact lives directly inside the run directory.
      return tryload(os.path.join(dirpath, name), **kwargs)

    self.config = load('config.pt')
    self.summary = load('summary.pt', missing_ok=True)
    self.loss_series = load('loss_series.pt')
    self.time_series = load('time_series.pt')
    # net_state.pt / opt_state.pt are intentionally not loaded here.

def runpaths(root='../../exps/results'):
  """Yield every directory under ``root`` (inclusive) that contains a ``config.pt``.

  Args:
    root: top of the results tree to search; defaults to the project results folder.

  Yields:
    Directory paths in ``os.walk`` order. A missing ``root`` yields nothing.
  """
  for dirpath, _dirnames, filenames in os.walk(root):
    if 'config.pt' in filenames:
      yield dirpath

def update_summary(r: RunData, recalc: bool = False) -> bool:
  """Fill in or refresh the derived statistics stored in ``r.summary``.

  A statistic is (re)computed when ``recalc`` is set, when any of its keys is
  absent, or when the stored ``time_total`` differs from the sum of the current
  ``time_series``. Statistics whose keys are all present are otherwise left alone.

  Args:
    r: run whose ``summary`` dict is updated in place (created when None).
    recalc: force recomputation of every statistic.

  Returns:
    True if ``r.summary`` was created or modified, False otherwise.
  """
  modified = r.summary is None
  if modified:
    r.summary = {}
  summary = r.summary
  losses, times = r.loss_series, r.time_series

  force = (recalc
           or 'time_total' not in summary
           or summary['time_total'] != torch.sum(times).item())

  def _loss_extreme(reduce):
    # (value, epoch index) of the min/max loss along the epoch axis.
    res = reduce(losses, dim=0)
    return res.values.item(), res.indices.item()

  def _index_count():
    cfg = r.config
    if cfg['type'] == 'simple':
      return cfg['N'] * cfg['D']
    return cfg['N'] ** cfg['D']

  # (keys, producer) pairs, listed in the order keys are first inserted into the
  # summary; that order later becomes CSV column order. Producers run lazily.
  steps = (
    (('last_mse', 'last_epoch'), lambda: (losses[-1].item(), len(losses) - 1)),
    (('best_mse', 'best_epoch'), lambda: _loss_extreme(torch.min)),
    (('worst_mse', 'worst_epoch'), lambda: _loss_extreme(torch.max)),
    (('time_min',), lambda: (torch.min(times).item(),)),
    (('time_max',), lambda: (torch.max(times).item(),)),
    (('time_total',), lambda: (torch.sum(times).item(),)),
    (('I',), lambda: (_index_count(),)),
  )
  for keys, produce in steps:
    if force or any(k not in summary for k in keys):
      summary.update(zip(keys, produce()))
      modified = True
  return modified

def update_summaries():
  """Bring ``summary.pt`` up to date for every run found by ``runpaths``.

  Only runs whose summary actually changed are rewritten; each rewritten path is
  printed, followed by a final count.
  """
  all_paths = list(runpaths())
  n_updated = 0
  for run_dir in tqdm(all_paths):
    run = RunData(run_dir)
    if update_summary(run):
      n_updated += 1
      summary_file = os.path.join(run.path, 'summary.pt')
      torch.save(run.summary, summary_file)
      print(f'Updated {summary_file}')
    # Drop this run's tensors before the next RunData is constructed.
    del run
  print(f'Updated {n_updated}/{len(all_paths)} runs.')

update_summaries()


In [None]:

def stats_from(dpath: str) -> Dict:
  """Return a run's config merged with its summary.

  Config keys come first; summary keys follow, and on a name clash the summary
  value wins (keeping the config key's position).
  """
  config = torch.load(os.path.join(dpath, 'config.pt'))
  summary = torch.load(os.path.join(dpath, 'summary.pt'))
  return {**config, **summary}

def collate(dst='../stats/results_raw.csv'):
  """Write one CSV row per run (config merged with summary) to ``dst``.

  The header is the union of keys across all runs in first-seen order, so a run
  with extra config keys no longer makes ``DictWriter`` raise; fields a run lacks
  are left empty. All rows are loaded before ``dst`` is opened, so a failing run
  cannot leave a truncated file behind. If no runs exist, nothing is written.

  Args:
    dst: output CSV path; its parent directory is created if needed.
  """
  rows = [stats_from(path) for path in tqdm(list(runpaths()))]
  if not rows:
    print('No runs found; nothing written.')
    return
  # dict.fromkeys preserves insertion order -> ordered union of all row keys.
  fieldnames = list(dict.fromkeys(key for row in rows for key in row))
  os.makedirs(os.path.dirname(dst) or '.', exist_ok=True)
  with open(dst, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames, restval='')
    writer.writeheader()
    writer.writerows(rows)

collate()


In [None]:

def bestpaths():
  """Yield, for each run group, the run directory with the lowest ``best_mse``.

  Runs are grouped by the part of their path before ``/opt_`` so that all
  variants of one configuration compete. Groups whose path contains ``H_0`` are
  skipped, as are groups in which no ``summary.pt`` is found (previously these
  yielded None, which crashed downstream consumers).

  Yields:
    One directory path per qualifying group, in sorted group order.
  """
  groups = sorted({p[:p.index('/opt_')] for p in runpaths()})
  for group in groups:
    if 'H_0' in group:
      continue
    # math.inf instead of a magic bound: any finite MSE can win.
    best_path, best_mse = None, math.inf
    for dirpath, _dirnames, filenames in os.walk(group):
      if 'summary.pt' not in filenames:
        continue
      mse = torch.load(os.path.join(dirpath, 'summary.pt'))['best_mse']
      if mse < best_mse:
        best_path, best_mse = dirpath, mse
    if best_path is not None:
      yield best_path

def psnr(mse, peak=1.0):
  """Peak signal-to-noise ratio in dB for a given mean squared error.

  Args:
    mse: mean squared error (non-negative).
    peak: maximum possible signal value; the default 1.0 suits signals in [0, 1].

  Returns:
    ``10 * log10(peak**2 / mse)``, or the sentinel 1e6 when ``mse`` is exactly 0
    (the true value would be infinite).
  """
  if mse == 0:
    return 1000000.0
  return 10 * math.log10(peak ** 2 / mse)

def clean_stats_from(dpath: str) -> Dict:
  """Reduce a run's merged config+summary to the columns used in the clean CSVs.

  Returns a dict with keys N, I, H, G, P, K (= G / I), psnr (from ``best_mse``),
  total_time and epoch_time. Times are ``time_total * 1e-9``, presumably
  nanoseconds -> seconds (TODO confirm the unit of ``time_series``).
  """
  d = stats_from(dpath)
  total_seconds = d['time_total'] * 1e-9
  return {
    'N': d['N'],
    'I': d['I'],
    'H': d['H'],
    'G': d['G'],
    'P': d['P'],
    'K': d['G'] / d['I'],
    'psnr': psnr(d['best_mse']),
    'total_time': total_seconds,
    # last_epoch is 0 for a single-entry loss series; avoid ZeroDivisionError.
    'epoch_time': total_seconds / max(d['last_epoch'], 1),
  }

def clean_collate(out_dir='../stats'):
  """Write one ``H_H <fname>.csv`` per function name, one row per best run.

  ``fname`` is the path segment from ``D_`` up to `` N_`` with slashes turned
  into spaces. Groups for which ``bestpaths`` produced no directory are skipped.

  Args:
    out_dir: directory receiving the CSV files; created if needed.
  """
  paths_by_fname = defaultdict(list)
  for p in bestpaths():
    if p is None:
      continue
    fname = p[p.index('D_'):p.index(' N_')].replace('/', ' ')
    paths_by_fname[fname].append(p)
  os.makedirs(out_dir, exist_ok=True)
  for fname, paths in tqdm(paths_by_fname.items(), desc='fnames'):
    # Load each run once; every list here is non-empty by construction.
    rows = [clean_stats_from(path) for path in paths]
    with open(os.path.join(out_dir, f'H_H {fname}.csv'), 'w', newline='') as f:
      writer = csv.DictWriter(f, fieldnames=list(rows[0]))
      writer.writeheader()
      writer.writerows(rows)

clean_collate()


In [None]:

def solution_paths(root='../../exps/solutions'):
  """Yield every directory under ``root`` (inclusive) that contains an ``mse.pt``.

  Args:
    root: top of the solutions tree to search; defaults to the project folder.

  Yields:
    Directory paths in ``os.walk`` order. A missing ``root`` yields nothing.
  """
  for dirpath, _dirnames, filenames in os.walk(root):
    if 'mse.pt' in filenames:
      yield dirpath

def collate_solutions_from(dpath: str):
  """Convert the per-solution tensors stored in ``dpath`` into CSV-ready rows.

  Reads ``mse.pt``, ``N.pt``, ``P.pt`` and ``time.pt`` (indexed in parallel) and
  returns one dict per entry with the same columns as the clean run CSVs; H, G
  and K are fixed at 0, I equals P, and both time columns are ``time * 1e-9``.
  """
  def _load(name):
    return torch.load(os.path.join(dpath, name))

  mses, Ns, Ps, times = (_load(n) for n in ('mse.pt', 'N.pt', 'P.pt', 'time.pt'))

  def _row(i):
    mse, N, P, t = (series[i].item() for series in (mses, Ns, Ps, times))
    return {
      'N': N,
      'I': P,
      'H': 0,
      'G': 0,
      'P': P,
      'K': 0,
      'psnr': psnr(mse),
      'total_time': t * 1e-9,
      'epoch_time': t * 1e-9,
    }

  return [_row(i) for i in range(len(mses))]

def collate_solutions(out_dir='../stats'):
  """Write one ``H_0 <fname>.csv`` per solutions directory.

  ``fname`` is the directory path from ``D_`` onward with slashes turned into
  spaces. Directories whose tensors contain no entries are reported and skipped
  instead of raising IndexError.

  Args:
    out_dir: directory receiving the CSV files; created if needed.
  """
  os.makedirs(out_dir, exist_ok=True)
  for p in solution_paths():
    ds = collate_solutions_from(p)
    if not ds:
      print(f'No solutions in {p}; skipped.')
      continue
    fname = p[p.index('D_'):].replace('/', ' ')
    with open(os.path.join(out_dir, f'H_0 {fname}.csv'), 'w', newline='') as f:
      writer = csv.DictWriter(f, fieldnames=list(ds[0]))
      writer.writeheader()
      writer.writerows(ds)

# collate_solutions()
