load data
According to the original paper, we need to:
transform the .mrc file into a NumPy array
chunk it into pairs of overlapping boxes of size 60*60*60 with a stride of 30 voxels
augmentation:
random 90-degree rotation
random crop of a 48*48*48 box from each 60*60*60 box

In [2]:
import os
import random
import mrcfile
import numpy as np
import interp_back
import torch
import torchvision
from torch.utils import data
from torch.utils.data import Dataset
from torchvision import transforms
from utils import pad_map, chunk_generator, parse_map

# Directories that are listed below to build the two file lists.
# NOTE(review): absolute paths on one specific machine; presumably they hold
# the deposited and the simulated .mrc maps respectively -- confirm.
depoFolder = "/home/tyche/training_and_validation_sets/depoFiles"
simuFolder = "/home/tyche/training_and_validation_sets/simuFiles"



In [None]:
def indices_of_map(map, box_size, stride, dtype=np.float32, padding=0.0):
    """Return the start corner of every chunk of a sliding-window walk.

    Coordinates are expressed in the frame of ``map`` padded by ``box_size``
    voxels on every side.  The walk starts at ``box_size - stride`` on each
    axis and advances by ``stride`` with x varying fastest, then y, then z.
    An axis wraps back to its start once ``position + stride`` reaches
    ``map.shape[axis] + box_size``.

    Args:
        map: 3-D array-like.  Only its shape is used.
        box_size: Edge length of one chunk, in voxels.
        stride: Step between neighbouring chunks, ``1 <= stride <= box_size``.
        dtype: dtype of the returned index array.  The float32 default is
            kept for backward compatibility; cast to an integer type before
            using the values for slicing.
        padding: Unused; kept so existing call sites keep working.

    Returns:
        ``np.ndarray`` of shape ``(n_chunks, 3)`` holding ``[x, y, z]`` start
        corners, one row per chunk, in traversal order.

    Raises:
        ValueError: If ``map`` is not 3-D or ``stride`` is smaller than 1.
    """
    assert stride <= box_size
    if stride < 1:
        # A non-positive stride would never advance and loop forever.
        raise ValueError(f"stride must be at least 1, got {stride}")
    map_shape = np.shape(map)
    if len(map_shape) != 3:
        raise ValueError(f"expected a 3-D map, got shape {map_shape}")

    # Upper bounds of the walk in padded coordinates.  No padded copy of the
    # map is allocated because only the corner coordinates are returned.
    limit_x, limit_y, limit_z = (size + box_size for size in map_shape)
    start_point = box_size - stride

    indices = []
    cur_x = cur_y = cur_z = start_point
    while cur_z + stride < limit_z:
        # Record the corner BEFORE advancing so that each entry is the
        # position of the chunk it describes (the first corner is included
        # and no entry lies beyond the z bound).
        indices.append((cur_x, cur_y, cur_z))
        cur_x += stride
        if cur_x + stride >= limit_x:
            cur_x = start_point  # Reset X
            cur_y += stride
            if cur_y + stride >= limit_y:
                cur_y = start_point  # Reset Y
                cur_z += stride
    # reshape keeps the (n, 3) layout even when no chunk was produced.
    return np.asarray(indices, dtype=dtype).reshape(-1, 3)

In [None]:
def mrc2Indices(mrcFile):
    """Compute chunk start indices for the map stored in ``mrcFile``.

    The map is read with ``parse_map``; voxels greater than zero are divided
    by the 99.999th percentile of the positive voxels and all other voxels
    are set to 0.  The result is padded with ``pad_map`` (box size 60) and
    handed to ``indices_of_map`` with box size 60 and stride 30.

    Args:
        mrcFile: Path passed straight to ``parse_map``.

    Returns:
        A torch tensor created from the index array of ``indices_of_map``.
    """
    volume, _origin, _nxyz, _voxel_size, _nxyz_origin = parse_map(mrcFile, ignorestart=False)

    # Normalise by a high percentile of the positive voxels.
    positive = volume > 0
    scale = np.percentile(volume[positive], 99.999)
    volume = np.where(positive, volume / scale, 0)
    del positive

    # NOTE(review): indices_of_map also accounts for a box_size border of its
    # own, so the map is effectively padded twice here -- confirm this is
    # intended.
    padded = pad_map(volume, 60, dtype=np.float32, padding=0.0)
    del volume  # drop the unpadded copy before the indices are built

    return torch.from_numpy(indices_of_map(padded, 60, 30))

        

In [None]:
def transform(tensor, outsize=48):
    """Augment a batch of volumes with a random rotation and a random crop.

    Every volume ``tensor[i]`` is rotated by 90, 180 or 270 degrees in a
    randomly chosen plane of its three axes and then cropped to a cube of
    edge ``outsize`` at a random offset.

    Args:
        tensor: Batch indexed as ``tensor[i]`` with three spatial axes per
            item (assumes layout (N, X, Y, Z) -- TODO confirm with caller).
        outsize: Edge length of the cropped cube.

    Returns:
        Tensor of shape ``(N, outsize, outsize, outsize)`` on the same device
        as ``tensor``, in torch's default dtype.

    Raises:
        ValueError: From ``random.randint`` when ``outsize`` is larger than a
            spatial dimension of the rotated volume.
    """
    N = tensor.shape[0]
    axes_options = ((0, 1), (1, 2), (0, 2))
    # Size the buffer from outsize; a fixed 48 broke every other crop size.
    output = torch.zeros(N, outsize, outsize, outsize, device=tensor.device)
    for i, volume in enumerate(tensor):
        k = random.choice([1, 2, 3])
        rotated = torch.rot90(volume, k=k, dims=random.choice(axes_options))
        # Draw the offsets from the ROTATED shape: a quarter turn swaps two
        # axes, so for non-cubic volumes the pre-rotation sizes are wrong.
        start_x, start_y, start_z = (
            random.randint(0, size - outsize) for size in rotated.shape[:3]
        )
        output[i] = rotated[
            start_x:start_x + outsize,
            start_y:start_y + outsize,
            start_z:start_z + outsize,
        ]
    # Kept from the original implementation; the caller still owns `tensor`,
    # so this only releases cached blocks that are already unused.
    torch.cuda.empty_cache()
    return output


In [None]:
def get_all_files(directory):
    """Return the paths of all entries in ``directory``, sorted by name.

    ``os.listdir`` returns entries in arbitrary order.  Sorting makes the
    result deterministic, which matters when two such lists are paired
    element by element (as ``zip(depoList, simuList)`` does in this file).

    Args:
        directory: Directory to list.

    Returns:
        List of ``"{directory}/{entry}"`` strings, one per directory entry
        (files and sub-directories alike), in ascending name order.
    """
    return [f"{directory}/{entry}" for entry in sorted(os.listdir(directory))]


In [None]:
# Collect the paths of every entry in each of the two data folders.
depoList = get_all_files(depoFolder)
simuList = get_all_files(simuFolder)

In [None]:
def data_iter(batch_size,depoList, simuList):
    """Walk two file lists in lockstep (unfinished).

    NOTE(review): this function is incomplete.  ``batch_size`` and
    ``num_examples`` are never used, nothing is yielded or returned, and the
    bare name ``simu`` on the last line is a truncated statement that raises
    NameError when reached unless a global called ``simu`` exists.
    Presumably the intent is to produce batches of paired chunks from the
    two lists -- confirm before completing.
    """
    num_examples = len(depoList)
    # zip pairs the lists by position and stops at the shorter one.
    for depofile, simufile in zip(depoList, simuList):
        depoIndices = mrc2Indices(depofile)
        simu


In [3]:
# Parse one map from depoFolder and print everything parse_map returns.
# NOTE(review): apix=1.0 presumably requests resampling to a 1.0 voxel size
# -- confirm against parse_map.  The name `map` shadows the Python builtin.
map, origin, nxyz, voxel_size, nxyz_origin = parse_map(f"{depoFolder}/7KHA_deposited.mrc", ignorestart=False, apix=1.0)
print(map, origin, nxyz, voxel_size, nxyz_origin)

# Rescale voxel size from [1.047 1.047 1.047] to [1. 1. 1.]
[[[ 0.03967627 -0.05195117 -0.21086168 ... -0.12607902 -0.11112486
   -0.03566635]
  [-0.01963407 -0.06368623 -0.17821681 ... -0.10512558 -0.04603753
    0.01077917]
  [-0.03778876  0.00466018 -0.01873241 ... -0.18926582 -0.10191675
   -0.04391088]
  ...
  [-0.12737338 -0.21426395 -0.23344398 ...  0.1642514   0.0864611
   -0.01059788]
  [-0.02199879 -0.13732263 -0.18568696 ...  0.16201054  0.1087267
    0.04579527]
  [ 0.08057188 -0.0203341  -0.13713813 ... -0.04864865 -0.0647044
   -0.02367403]]

 [[-0.1243434  -0.04380324 -0.10431768 ... -0.07594126 -0.12551868
   -0.19107004]
  [-0.1598873  -0.0527913  -0.05477556 ... -0.00420102 -0.05281496
   -0.1415216 ]
  [-0.14302589 -0.00871784  0.05809193 ... -0.04513825 -0.06213009
   -0.12889107]
  ...
  [-0.22429416 -0.22255464 -0.21681365 ... -0.02432631 -0.0419833
   -0.13384932]
  [-0.162963   -0.11976589 -0.11873236 ...  0.06514268  0.02082419
   -0.10892078]
  [-0.09283847 -0

In [4]:
# Parse the matching map from simuFolder (no apix argument this time) and
# print everything parse_map returns.  The name `map` shadows the builtin.
map, origin, nxyz, voxel_size, nxyz_origin = parse_map(f"{simuFolder}/7KHA_simulated.mrc", ignorestart=False)
print(map, origin, nxyz, voxel_size, nxyz_origin)

[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 ...

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]] [-102.8255