In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
import numpy as np

In [2]:
def _segments_to_indices(self, segments, N, pad=0):
    """Convert fractional segments along one axis into sorted pixel indices.

    Each ``(a, b)`` pair in *segments* is a fractional interval that is mapped
    linearly from ``[0, 1]`` onto pixel indices ``[0, N - 1]`` (rounded to the
    nearest integer), clipped to ``[pad, N - 1 - pad]``, and merged with the
    other segments.

    Args:
        self: Unused; kept so the function can be bound as a method.
        segments: Iterable of ``(a, b)`` fraction pairs, or ``None``.
        N: Number of pixels along the axis.
        pad: Number of pixels excluded at each end of the axis.

    Returns:
        1-D ``torch.long`` tensor of unique, ascending indices. If no segment
        produces a non-empty range (including when *segments* is ``None`` or
        empty), the full padded range ``[pad, N - pad)`` is returned instead;
        that fallback is an empty tensor when ``2 * pad >= N``.
    """
    # `is None` rather than truthiness so numpy arrays of segments still work.
    if segments is None:
        segments = ()

    lo, hi = pad, (N - 1) - pad
    idxs = []
    for a, b in segments:
        # map fraction [0,1] to pixel indices [0, N-1], then clip to the padded range
        i0 = max(lo, int(round(a * (N - 1))))
        i1 = min(hi, int(round(b * (N - 1))))
        if i1 >= i0:
            idxs.append(torch.arange(i0, i1 + 1, dtype=torch.long))

    if not idxs:
        # Fallback to the full padded range. Clamp the end so an oversized pad
        # yields an empty tensor instead of torch.arange raising on start > end.
        return torch.arange(pad, max(pad, N - pad), dtype=torch.long)

    # torch.unique sorts by default, so overlapping segments merge into one
    # ascending index list (dtype stays torch.long).
    return torch.unique(torch.cat(idxs))
    

def __getitem__(self, index):
    if not type(self.steps) == np.ndarray:
        step = np.random.randint(1,200)
    else:
        step = self.steps[index]

    # Create tensor for the target
    t = torch.tensor(get_all(self.sims[index], step), dtype=torch.float32)

    # Create 0 matrix
    z = torch.zeros_like(t)

    # build a boolean mask of revealed pixels, shape (H,W)
    mask = torch.zeros((self.H, self.W), dtype=torch.bool)

    chans = self._chan_idx()

    if self.reveal_strategy == "block":
        # choose top-left for the block
        if not type(self.points) == np.ndarray:
            i0 = np.random.randint(0, self.H - self.block_size + 1)
            j0 = np.random.randint(0, self.W - self.block_size + 1)
        else:
            i0, j0 = self.points[index]
            i0 = max(0, min(i0, self.H - self.block_size))
            j0 = max(0, min(j0, self.W - self.block_size))
        mask[i0:i0+self.block_size, j0:j0+self.block_size] = True

    elif self.reveal_strategy == "disks":
        row_fracs = self.reveal_dim[0] # e.g, [(0, 1)]
        col_fracs = self.reveal_dim[1]

        row_allowed = self._segments_to_indices(row_fracs, self.H, pad=self.radius)
        col_allowed = self._segments_to_indices(col_fracs, self.W, pad=self.radius)

        # choose grid shape close to aspect ratio works with non-squares
        Hspan = (row_allowed[-1] - row_allowed[0] + 1) if len(row_allowed) > 0 else self.H
        Wspan = (col_allowed[-1] - col_allowed[0] + 1) if len(col_allowed) > 0 else self.W
        ratio = Wspan / max(1, Hspan)

        ny = int(max(1, round(np.sqrt(self.n_points / max(1e-8, ratio)))))
        nx = int(max(1, round(self.n_points / ny)))

        while nx * ny < self.n_points:
            nx += 1

        # pick evenly spaced indices from rows/cols allowed
        def pick_lin_indices(allowed, k):
            if k <= 1:
                return allowed[len(allowed)//2]
            # linespace over positions
            pos = torch.linspace(0, len(allowed)-1, steps=k)
            idx = torch.round(pos).long()
            return allowed[idx]
        
        row_picks = pick_lin_indices(row_allowed, ny)
        col_picks = pick_lin_indices(col_allowed, nx)

        yy, xx = torch.meshgrid(row_picks, col_picks, indexing="ij")
        points = torch.stack([yy.reshape(-1), xx.reshape(-1)], dim=1) # (ny*nx, 2)
        
        # if more than n_points, subselect
        if points.shape[0] > self.n_points:
            sel_pos = torch.linspace(0, points.shape[0]-1, steps=self.n_points)
            sel_idx = torch.round(sel_pos).long()
            points = points[sel_idx]

        ii = points[:][0]
        jj = points[:][1]