In [7]:
import numpy as np
import mrcfile
import matplotlib.pyplot as plt
# %matplotlib notebook

In [94]:
# Read a single micrograph from disk and show its dimensions as the cell output.
outdir = '../data/full/'
fname = 'stack_0003_2x_dfpt.mrc'

# permissive=True is passed through to mrcfile.open unchanged.
with mrcfile.open(outdir + fname, permissive=True) as mrc:
    data = mrc.data

data.shape

(3710, 3838)

In [225]:
class SubgraphLoader():
    """
    Cuts micrographs into a regular grid of fixed-size tiles ("subgraphs")
    and groups particle annotations by the tile that contains them.

    Parameters
    ----------
    inputShape : tuple
        (height, width) used to size the tile grid.
    sampleShape : tuple
        (height, width) of a single tile.
    keys : iterable
        Micrograph keys to process.
    keyToPath : callable
        Maps a micrograph key to the path of its MRC file.
    labelFile : str
        Path of the particle file handed to ``_parseParticles``.

    Only whole tiles are produced: ``inputShape[0] // sampleShape[0]`` rows by
    ``inputShape[1] // sampleShape[1]`` columns. Pixels and particles in the
    leftover strip along the far edges are not part of any tile.
    """

    def __init__(self, 
                 inputShape, 
                 sampleShape, 
                 keys, 
                 keyToPath,
                 labelFile):
        self.inputShape = inputShape
        self.sampleShape = sampleShape
        self.keys = keys
        self.keyToPath = keyToPath
        # Parsed once at construction: {micrograph key: [(c0, c1), ...]}.
        self.particles = self._parseParticles(labelFile)
    
    def getMicrograph(self, key):
        """ Load the micrograph for `key` and return its data array. """
        with mrcfile.open(self.keyToPath(key), permissive=True) as mrc:
            data = mrc.data
        return data
    
    def boxContains(self, point, sampleId):
        """
        Return True when `point` lies inside the tile `sampleId`.

        point    - pair of coordinates as stored in ``self.particles``.
        sampleId - (idxh, idxw) grid index of the tile.

        The first coordinate is tested against the tile's extent along the
        height axis and the second against its extent along the width axis,
        which matches the offsets subtracted in ``getSubgraphAnnotation``.
        Intervals are half-open (lower edge included, upper edge excluded) so
        a particle sitting exactly on a tile border belongs to exactly one tile.

        NOTE(review): if the label file stores (x, y) with x running along the
        image width, this pairing is transposed relative to the pixel tiles
        returned by ``getSubgraphs`` - confirm against the label file header.
        """
        dimH, dimW = self.sampleShape
        first, second = point
        idxh, idxw = sampleId
        lowH = idxh*dimH
        lowW = idxw*dimW
        return (lowH <= first < lowH + dimH) and (lowW <= second < lowW + dimW)
    
    def _generateSubgraph(self, key):
        """
        Generate the subimages for a given micrograph.

        Returns a dict mapping (idxh, idxw) to the slice of the micrograph
        covering rows idxh*dimH .. idxh*dimH+dimH and columns
        idxw*dimW .. idxw*dimW+dimW.
        """
        retDict = {}
        h, w = self.inputShape
        dimH, dimW = self.sampleShape
        data = self.getMicrograph(key)
        # Floor division: only tiles that fit completely are produced.
        for idxh in range(h // dimH):
            for idxw in range(w // dimW):
                retDict[(idxh,idxw)] = data[idxh*dimH:idxh*dimH+dimH, 
                                            idxw*dimW:idxw*dimW+dimW]
        return retDict

    def _parseParticles(self, file):
        """ 
        Read in the particles for all micrographs. 
        This will need to be edited if you change your key types. 

        The first 17 lines and the final line of `file` are ignored
        (presumably the header and a trailing blank line - confirm for your
        files). For every remaining line the first three whitespace-separated
        fields are used: characters 18..21 of the first field are parsed as
        the integer micrograph key and the next two fields as floats.

        Returns a dict {key: [(c0, c1), ...]} in file order.
        """
        with open(file, "r") as f:
            particles = f.readlines()

        particleDict = {}
        for line in particles[17:-1]:
            fields = line.split()[0:3]
            if len(fields) < 3:
                # Blank or truncated line: nothing usable to record.
                continue
            key = int(fields[0][18:22])
            value = tuple(map(float, fields[1:]))
            particleDict.setdefault(key, []).append(value)
        return particleDict

    def getSubgraphAnnotation(self, shift = True):
        """ 
        Searches through the particle list and builds a dictionary which maps

        (micrographKey, idxh, idxw) -> array of the particles in that subgraph

        (idxh, idxw) is the position of a subgraph within the grid formed by
        the subgraphs over the micrograph. Subgraphs that contain no particle
        have no entry in the result, and a micrograph key with no parsed
        particles contributes no entries.

        shift specifies whether you want relative or absolute positions:
        True (default) subtracts the subgraph origin (idxh*dimH, idxw*dimW)
        from every particle, False leaves the coordinates as parsed.
        """
        subDict = {}
        h, w = self.inputShape
        dimH, dimW = self.sampleShape
        for micrograph in self.keys:
            micrographParticles = self.particles.get(micrograph, [])
            for idxh in range(h // dimH):
                for idxw in range(w // dimW):
                    inside = [
                        particle for particle in micrographParticles
                        if self.boxContains(particle, (idxh, idxw))
                    ]
                    if not inside:
                        # Empty subgraphs are left out of the result.
                        continue
                    coords = np.array(inside)
                    if shift:
                        coords = coords - np.array([idxh*dimH, idxw*dimW])
                    subDict[(micrograph, idxh, idxw)] = coords
        return subDict
    
    def getSubgraphs(self):
        """
        Tile every micrograph in ``self.keys``.

        Returns a dict mapping (micrographKey, idxh, idxw) to the
        corresponding subimage produced by ``_generateSubgraph``.
        """
        subDict = {}
        for micrograph in self.keys:
            subgraphs = self._generateSubgraph(micrograph)
            for k,v in subgraphs.items():
                subDict[(micrograph, *k)] = v
        return subDict

In [220]:
# NOTE(review): parse_particle_centers and generate_subgraph are not defined in
# any cell visible here - presumably they come from an earlier kernel session or
# a removed cell. Confirm they still exist before relying on Restart & Run All.
pdict = parse_particle_centers('../data/full/particles.star')
sdict = generate_subgraph(data, 512, 512)

In [226]:
# Build the loader for micrographs 1-3, tiled into 512 x 512 subgraphs.
# keyToPath zero-pads the key to four digits to form the stack file name.
loader = SubgraphLoader(
    inputShape=(3710, 3838),
    sampleShape=(512, 512),
    keys=[1, 2, 3],
    keyToPath=lambda x: f'../data/full/stack_{str(x).zfill(4)}_2x_dfpt.mrc',
    labelFile='../data/full/particles.star',
)

In [227]:
# Tile every micrograph in loader.keys; the cell output is the full
# {(micrograph, idxh, idxw): subimage} dictionary returned by getSubgraphs.
loader.getSubgraphs()



{(1,
  0,
  0): array([[11.063994 , 11.445369 , 11.335733 , ..., 11.296578 , 11.358143 ,
         12.807101 ],
        [11.481459 , 10.988326 , 11.1371975, ..., 10.948187 , 11.896396 ,
         11.789806 ],
        [11.38492  , 10.875472 , 12.121873 , ..., 11.317797 , 11.845706 ,
         11.468467 ],
        ...,
        [11.968819 , 11.628999 , 11.62763  , ..., 11.131636 , 10.930626 ,
         11.460574 ],
        [11.018502 , 11.428985 , 11.313138 , ..., 10.8780365, 10.881426 ,
         11.092923 ],
        [11.055786 , 11.600703 , 11.459327 , ..., 10.966662 , 11.703842 ,
         10.973822 ]], dtype=float32),
 (1,
  0,
  1): array([[11.829643 , 11.229397 , 11.305838 , ..., 11.95431  , 10.874928 ,
         11.680517 ],
        [11.78418  , 11.760414 , 12.252817 , ..., 11.336928 , 10.7646675,
         11.647305 ],
        [11.313099 , 11.412874 , 11.614745 , ..., 10.979542 , 11.0251465,
         11.221465 ],
        ...,
        [11.342705 , 10.735772 , 11.632245 , ..., 11.242066 , 1

In [228]:
# Display the particle coordinates grouped per (micrograph, idxh, idxw) subgraph,
# using the default shift=True.
loader.getSubgraphAnnotation()

{(1,
  0,
  0): array([[299., 440.],
        [266., 203.]]),
 (1,
  0,
  1): array([[326., 369.],
        [449., 279.],
        [356., 501.]]),
 (1, 0, 2): array([[344., 343.]]),
 (1,
  0,
  3): array([[221., 104.],
        [407., 140.],
        [278., 269.],
        [356.,  41.]]),
 (1,
  0,
  4): array([[203.,  63.],
        [359.,  48.],
        [206., 381.]]),
 (1,
  0,
  5): array([[290., 226.],
        [395., 283.],
        [419., 502.]]),
 (1,
  0,
  6): array([[239., 341.],
        [278.,  41.]]),
 (1,
  1,
  0): array([[381., 377.],
        [138., 467.],
        [ 66., 359.]]),
 (1,
  1,
  1): array([[300., 261.],
        [147., 129.],
        [ 18., 180.]]),
 (1,
  1,
  2): array([[486.,  13.],
        [159., 445.],
        [366., 433.]]),
 (1,
  1,
  3): array([[177., 323.],
        [ 63., 185.]]),
 (1,
  1,
  4): array([[141.,  78.],
        [318.,  72.],
        [198., 444.]]),
 (1,
  1,
  5): array([[111.,  34.],
        [153., 310.],
        [510., 280.],
        [ 36., 