# Load pickle, presizing

In [None]:
class CenterCropTfm(Transform):
  """Center-crop tensors to a fixed shape; pass multi-category targets through untouched."""

  def __init__(self, shape):
    # target shape forwarded to T.center_crop on every call
    self.shape = shape

  # NOTE(review): `encodes` is defined twice with different annotations. This
  # presumably relies on the Transform base class dispatching on the argument
  # type (so both definitions stay active) -- confirm against the base class.
  def encodes(self, t:TensorMultiCategory):
    # labels are returned as-is, never cropped
    return t

  def encodes(self, t:Tensor):
    # any other tensor is cropped around its center
    return T.center_crop(t, shape=self.shape)

In [None]:
def open_pickle(*args):
  """Unpickle every named file found under ``ds_path``.

  Args:
    *args: file names, each resolved as ``ds_path/fname``.

  Returns:
    list of the unpickled objects, in the same order as ``args``.

  NOTE: ``pickle.load`` can execute arbitrary code stored in the file, so this
  should only be pointed at files from a trusted source.
  """
  def _load_one(fname):
    # binary mode is required by pickle; the handle is closed on exit
    with open(ds_path/fname, 'rb') as fh:
      return pickle.load(fh)

  return [_load_one(fname) for fname in args]

In [None]:
# return unique rows
def get_shapes_info(shape_list):
  """Summarise the distinct rows (shapes) found in ``shape_list``.

  Prints every distinct row with its number of occurrences and the index of
  its first occurrence, then returns the same information.

  Args:
    shape_list: 2-D array-like, one shape per row.

  Returns:
    list of ``(unique_row, count, first_index)`` tuples, in the order in
    which ``np.unique`` returns the rows.
  """
  unique, idxs, counts = np.unique(shape_list, axis=0, return_index=True, return_counts=True)
  # Materialise the zip: a bare zip iterator would be consumed by the print
  # below and the caller would get back an already-exhausted iterator.
  d = list(zip(unique, counts, idxs))
  print("Unique Sizes & Counts: ", *d, sep="\n")
  return d

In [None]:
# resize to common shape (for dataloaders)
def get_common_shape(shape_list, f0 = np.min, f1 = np.min, do_print = False):
  """Reduce a list of 2-D shapes to a single shape shared by all items.

  Args:
    shape_list: array-like whose rows are shapes; only columns 0 and 1 are read.
    f0: reduction applied to column 0 (default ``np.min``).
    f1: reduction applied to column 1 (default ``np.min``).
    do_print: when true, print the resulting shape.

  Returns:
    tuple ``(f0(column 0), f1(column 1))``.
  """
  dims = np.asarray(shape_list)
  first_dim, second_dim = dims[:, 0], dims[:, 1]
  common_shape = (f0(first_dim), f1(second_dim))
  if do_print:
    print(f"Common shape: {common_shape}")
  return common_shape

In [None]:
def get_log2_list(l, cutoff=32):
  """Return the sizes obtained by repeatedly halving ``l`` while above ``cutoff``.

  Args:
    l: starting size (integer).
    cutoff: sizes less than or equal to this value are left out.

  Returns:
    list of sizes in decreasing order, beginning with ``l``; empty when
    ``l <= cutoff``.
  """
  # Floor division never takes a non-positive size any further, so the
  # threshold is kept at zero or above; otherwise the loop could not end.
  threshold = max(cutoff, 0)
  szs = []
  curr_sz = l  # start from the argument rather than an unrelated global name
  while curr_sz > threshold:
    szs.append(curr_sz)
    curr_sz = curr_sz // 2
  return szs

# Rescaling px

Histogram that has about the same number of pixels in each bin.

In [None]:
def hist_scaled_pt(self:Tensor, bins=None):
  """Map the values of ``self`` onto [0, 1] using ``bins`` as breakpoints.

  The breakpoints in ``bins`` are paired with evenly spaced targets between
  0 and 1; the flattened tensor is passed through ``interp_1d`` with that
  pairing, reshaped back to the input shape and clamped to [0, 1].

  NOTE(review): ``bins`` defaults to None, but ``len(bins)`` needs a sized
  value, so callers must always supply it. ``interp_1d`` is not defined in
  this section -- presumably a piecewise-linear interpolation patched onto
  Tensor; confirm.
  """
  n_knots = len(bins)
  targets = torch.linspace(0., 1., n_knots)
  flat = self.flatten()
  scaled = flat.interp_1d(bins, targets)
  return scaled.reshape(self.shape).clamp(0., 1.)

def inv_hist_scaled_pt(self:Tensor, bins=None):
  """Map values in [0, 1] back onto the range spanned by ``bins``.

  Uses the same pairing as ``hist_scaled_pt`` with the roles swapped: evenly
  spaced points between 0 and 1 are the inputs and ``bins`` the targets. The
  result is reshaped to the input shape and clamped to [min(bins), max(bins)].

  NOTE(review): ``bins`` defaults to None, but ``len(bins)`` needs a sized
  value, so callers must always supply it. ``interp_1d`` is not defined in
  this section -- confirm its semantics.
  """
  n_knots = len(bins)
  ys = torch.linspace(0., 1., n_knots)
  lowest, highest = min(bins), max(bins)
  restored = self.flatten().interp_1d(ys, bins)
  return restored.reshape(self.shape).clamp(lowest, highest)

def plot_scale_fn(bins):
  """Plot the rescale curve defined by ``bins`` next to its inverse.

  Left panel: ``bins`` on x against evenly spaced values in [0, 1] on y.
  Right panel: the same two series with the axes swapped.
  Both panels use scientific tick labels on the axis that carries ``bins``.
  """
  fig, axes = plt.subplots(ncols=2)

  ys = torch.linspace(0,1,len(bins))

  # (axes, x data, y data, title, axis formatted in sci notation, scilimits)
  panels = (
    (axes[0], bins, ys, "Rescale Function", "x", (0,0)),
    (axes[1], ys, bins, "Inv Function", "y", (1,0)),
  )
  for ax, xs, vals, title, sci_axis, limits in panels:
    ax.plot(xs, vals)
    ax.set_title(title)
    ax.ticklabel_format(axis=sci_axis, style="sci", scilimits=limits)

# Permutation Tfm

Permutation class; converts a permutation's representation between array and string formats.


In [None]:
from itertools import permutations, combinations 

# arr to/from str
def arr2str(p, sep=""):
  """Join the string form of every element of ``p``, separated by ``sep``."""
  return sep.join(str(item) for item in p)
def str2arr(s, sep=""):
  """Parse a string of integers into a numpy array.

  Args:
    s: string holding the integers.
    sep: separator between entries. With the default empty separator every
      character of ``s`` is read as one (single-digit) entry.

  Returns:
    numpy array with one integer per entry.

  Raises:
    ValueError: if an entry cannot be converted with ``int``.
  """
  # str.split refuses an empty separator (ValueError), so in that case walk
  # the string character by character instead.
  tokens = s.split(sep) if sep else s
  return np.array([int(x) for x in tokens])

class P():
  """Tables linking every ordering of ``n_seq`` positions to a pairwise-order label.

  Attributes set by ``__init__``:
    perms:  all orderings of ``range(n_seq)``, as tuples.
    c:      number of orderings.
    pairs:  all position pairs ``(i, j)`` with ``i < j``.
    labels: one label per entry of ``perms`` (see ``get_label``).
    o2i:    label -> index into ``perms`` / ``labels``.
    rands:  ``rands_len`` random indices into ``perms``.
  """

  def __init__(self, n_seq = 5, rands_len=10_000):
    positions = range(n_seq)

    # every possible shuffle order
    self.perms = list(permutations(positions))
    self.c     = len(self.perms)

    # label each shuffle order by which position pairs it puts out of order
    self.pairs  = list(combinations(positions, r=2))
    self.labels = list(map(self.get_label, self.perms))
    self.o2i    = dict(zip(self.labels, range(len(self.labels))))

    # draw the random shuffle choices once, up front
    self.rands_len = rands_len
    self.set_rands()

  def set_rands(self):
    """(Re)draw ``rands_len`` random indices into ``perms``."""
    self.rands = np.random.randint(self.c, size=self.rands_len)

  def get_label(self, perm):
    """Return a tuple with 1.0 for every pair (i, j) where ``perm[i] > perm[j]``, else 0.0."""
    flags = []
    for first, second in self.pairs:
      flags.append(float(perm[first] > perm[second]))
    return tuple(flags)

  def get_perm(self, i):
    """Return the ``(permutation, label)`` chosen for index ``i``."""
    chosen = self.rands[i]
    return self.perms[chosen], self.labels[chosen]

  @staticmethod
  def get_inv(p):
    """Return the inverse of permutation ``p`` as a numpy array."""
    # value -> position at which that value sits in p
    position_of = {value: pos for pos, value in enumerate(p)}
    return np.array([position_of[value] for value in range(len(p))])

Test permutation

In [None]:
# Round-trip check: indexing `orig` with `perm` must give `input`, and
# indexing `input` with the inverse permutation must give `orig` back.
orig  = np.array(list("abcde"))
input = np.array(list("dabce"))

perm  = np.array([3, 0, 1, 2, 4])
iperm = P.get_inv(perm)

print(
  "Perm : ", arr2str(perm), "\n",
  "iPerm: ", arr2str(iperm),
  sep="",
)

np.testing.assert_array_equal(orig[perm], input)
np.testing.assert_array_equal(input[iperm], orig)

Perm : 30124
iPerm: 12304


Testing permutation label

In [None]:
# Compare the generated label for permutation 30124 with one worked out by hand.
p_tfm = P(n_seq=5, rands_len=10)

orig  = np.array(list("abcde"))
input = np.array(list("dabce"))

perm = np.array([3, 0, 1, 2, 4])

# only the first three pairs -- (0,1), (0,2), (0,3) -- are out of order
hand_label = (1,) * 3 + (0,) * 7
auto_label = p_tfm.get_label(perm)

print(
  "Label (by hand): ", arr2str(hand_label, sep = "   "), "\n",
  "Label (autogen): ", arr2str(auto_label, sep=" "),
  sep="",
)

np.testing.assert_array_equal(hand_label, auto_label)

Label (by hand): 1   1   1   0   0   0   0   0   0   0
Label (autogen): 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0


# Sandwich Tfm

Given an index, return a sequence of slices (a "sandwich") corresponding to that index.

Setup/init: 
- pass a list of filenames + attributes, (fname, attribute_dict).
- pass the desired sequence length, nseq.

Encode: given index $i$,
- look up the correct file to open,
- find the corresponding slice sequence ($x$-2, $x$-1, $x$, $x$+1, $x$+2),
- return a shuffled version + the target label of which slice pairs are shuffled.

In [None]:
# indexes into array of kspace slices
class ImSandwP(Transform, P):
  """Map a dataset-wide sandwich index to a shuffled run of ``n_seq`` consecutive slices.

  Args:
    fn2attr_items: sequence of ``(fname, attr)`` pairs; ``attr["n_slices"]``
      gives the number of slices in that file.
    objs: optional pre-loaded volumes, one per file and in the same order.
      When falsy, every file is loaded with ``get_obj``.
    n_seq: number of consecutive slices per sandwich.
  """

  def __init__(self, fn2attr_items, objs = None, n_seq = 5):
    self.n_seq = n_seq
    self.fn2attr_items = fn2attr_items

    # use the volumes handed in, otherwise load every listed file into RAM
    if objs:
      self.objs = objs
    else:
      self.objs = [self.get_obj(fname) for fname, _ in fn2attr_items]

    # per-file counts: a file with n slices holds n - (n_seq - 1) sandwiches
    slices_lost = self.n_seq - 1
    self.n_slices = [attrs["n_slices"] for _, attrs in self.fn2attr_items]
    self.n_sandws = [count - slices_lost for count in self.n_slices]

    # running totals, used to find the file that owns a dataset-wide index
    self.cumsum_n_slices = np.cumsum(self.n_slices)
    self.cumsum_n_sandws = np.cumsum(self.n_sandws)

    # dataset-wide totals (for reference)
    self.total_n_slices   = self.cumsum_n_slices[-1]
    self.total_n_sandws   = self.cumsum_n_sandws[-1]

    # draw one permutation choice per sandwich up front, so encodes never
    # has to call the random generator
    P.__init__(self, n_seq=self.n_seq, rands_len=self.total_n_sandws)

  def get_obj(self, fname):
    """Read the 'kspace' dataset of an h5 file fully into memory and convert it.

    NOTE(review): ``C``, ``T`` and ``CenteredTfms`` are defined outside this
    section; presumably the result is a magnitude image volume -- confirm.
    """
    with h5py.File(fname, 'r') as h5:
      kspace = h5['kspace'][()]
      return C.apply(kspace, CenteredTfms.k2im(), pre=T.to_tensor, post=C.complex2mgn)

  def sandw2fn_idx(self, i):
    """Return the file index whose cumulative sandwich count first reaches ``i``."""
    return np.searchsorted(self.cumsum_n_sandws, i)

  def encodes(self, sandw_idx):
    """Return ``(shuffled sandwich, label)`` for the dataset-wide index ``sandw_idx``."""
    # the lookup works on 1-based sandwich numbers (index 0 is sandwich #1)
    fn_idx = self.sandw2fn_idx(sandw_idx + 1)

    # subtract the sandwiches held by all earlier files to get the in-file index
    if fn_idx == 0:
      n_before = 0
    else:
      n_before = self.cumsum_n_sandws[fn_idx - 1]
    start = sandw_idx - n_before

    # n_seq consecutive slices beginning at the in-file index
    imsandw = self.objs[fn_idx][start:start + self.n_seq]

    # reorder the slices with the permutation pre-drawn for this sandwich
    perm, label = self.get_perm(sandw_idx)
    return imsandw[np.array(perm)], TensorMultiCategory(label)

  def decodes(self, o):
    """Return the sandwich with the permutation tuple recovered from its label."""
    imsandw, label = o
    key = tuple(label.numpy())
    perm_idx = self.o2i[key]
    return imsandw, self.perms[perm_idx]