Load data
Following the original paper, the preprocessing steps are:
- transform each .mrc file into a NumPy array;
- chunk each map into pairs of overlapping 60×60×60 boxes with a stride of 30 voxels;
- augmentation:
  - random 90-degree rotation;
  - randomly crop a 48×48×48 box from each 60×60×60 box.

In [None]:
import os
import random
import mrcfile
import numpy as np
import interp_back
import torch
import torchvision
from torch.utils import data
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
from pytorch_msssim import ssim
from scunet import SCUNet
from utils import pad_map, chunk_generator, parse_map, get_batch_from_generator
import matplotlib.pyplot as plt


# Folders holding the network-input maps (depoFiles) and the target maps
# (simuFiles); both live under the same dataset root.
_dataset_root = "/home/tyche/training_and_validation_sets"
depoFolder = f"{_dataset_root}/depoFiles"
simuFolder = f"{_dataset_root}/simuFiles"



In [None]:
def get_all_files(directory):
    """Return the paths of the regular files directly inside *directory*.

    Sub-directories and other non-file entries are skipped, because every
    returned path is later read as a map file. The order is whatever
    ``os.listdir`` yields; callers sort the result themselves.

    Args:
        directory: folder to list (not searched recursively).

    Returns:
        List of paths built with ``os.path.join(directory, name)``.
    """
    candidates = (os.path.join(directory, name) for name in os.listdir(directory))
    return [path for path in candidates if os.path.isfile(path)]


# Sorting both listings by path lets the training loop pair input and target
# maps positionally when it zips the two lists together.
depoList = sorted(get_all_files(depoFolder))
simuList = sorted(get_all_files(simuFolder))

In [None]:
def mrc2padded(mrcfile, apix):
    """Load a map at grid spacing ``apix``, pad it, and compute its scale.

    Args:
        mrcfile: path of the map file, forwarded to ``parse_map``. (The name
            shadows the ``mrcfile`` module inside this function; it is kept
            for backward compatibility with keyword callers.)
        apix: target grid spacing in Angstrom, forwarded to ``parse_map``.

    Returns:
        (padded_map, maximum): the map passed through
        ``pad_map(..., 60, dtype=np.float32, padding=0.0)``, and the
        99.999th percentile of the map's positive voxels (taken on the
        unpadded map).

    Raises:
        ValueError: if the map has no positive voxels, since the percentile
            would otherwise be taken over an empty array.
    """
    density, _origin, nxyz, _voxel_size, nxyz_origin = parse_map(
        mrcfile, ignorestart=False, apix=apix
    )
    print(f"# Original map dimensions: {nxyz_origin}")
    print(f"# Map dimensions at {apix} Angstrom grid size: {nxyz}")
    padded_map = pad_map(density, 60, dtype=np.float32, padding=0.0)
    positive = density[density > 0]
    if positive.size == 0:
        raise ValueError(
            f"map {mrcfile!r} has no positive voxels; cannot compute its scale"
        )
    # A high percentile rather than the true max keeps a few outlier voxels
    # from dominating the value returned to the caller.
    maximum = np.percentile(positive, 99.999)
    return padded_map, maximum


# Smoke test: pad the last input map, cut it into 60-voxel boxes with a
# 30-voxel stride, and display the shape of the first batch of 10 chunks.
sample_path = depoList[-1]
padded_map, maximum = mrc2padded(sample_path, 1.0)
generator = chunk_generator(padded_map, maximum, 60, 30)
positions, chunks = get_batch_from_generator(generator, 10, dtype=np.float32)
chunks.shape


In [None]:
def transform(tensor, outsize=48, rng=None):
    """Randomly rotate and crop every volume in a batch (data augmentation).

    Each sample is rotated by k*90 degrees (k drawn from {1, 2, 3}) in a
    randomly chosen axis plane. A random ``outsize``-cubed sub-box is then
    cut out of the rotated volume.

    Args:
        tensor: batch of volumes with shape (N, X, Y, Z); every spatial
            dimension must be at least ``outsize``.
        outsize: edge length of the cubic crop.
        rng: randomness source exposing ``choice`` and ``randint``
            (e.g. ``random.Random(seed)``). Defaults to the global ``random``
            module. Passing identically seeded generators gives two paired
            batches (input and target) the same augmentation.

    Returns:
        Tensor of shape (N, outsize, outsize, outsize) with the same dtype
        and device as ``tensor``.

    Raises:
        ValueError: if ``tensor`` is not 4-D or a spatial dimension is
            smaller than ``outsize``.
    """
    if rng is None:
        rng = random
    if tensor.dim() != 4:
        raise ValueError(
            f"expected a 4-D (N, X, Y, Z) tensor, got shape {tuple(tensor.shape)}"
        )
    if min(tensor.shape[1:]) < outsize:
        raise ValueError(
            f"crop size {outsize} exceeds volume shape {tuple(tensor.shape[1:])}"
        )
    planes = [(0, 1), (1, 2), (0, 2)]
    crops = []
    for volume in tensor:
        k = rng.choice([1, 2, 3])
        rotated = torch.rot90(volume, k=k, dims=rng.choice(planes))
        # Rotation permutes the axes of a non-cubic volume, so the crop range
        # must come from the rotated shape, not the input shape.
        sx, sy, sz = (rng.randint(0, n - outsize) for n in rotated.shape)
        crops.append(rotated[sx:sx + outsize, sy:sy + outsize, sz:sz + outsize])
    if not crops:
        return tensor.new_zeros((0, outsize, outsize, outsize))
    return torch.stack(crops)

# Input: a torch tensor of shape (batch_size, 60, 60, 60)

In [None]:
# SCUNet configured with in_nc=1 and input_resolution=48, matching the
# 48-voxel crops that transform() produces by default.
scunet_kwargs = dict(
    in_nc=1,
    config=[2] * 7,
    dim=32,
    drop_path_rate=0.0,
    input_resolution=48,
    head_dim=16,
    window_size=3,
)
net = SCUNet(**scunet_kwargs)
torch.cuda.empty_cache()
net = net.cuda()


In [None]:
def loss(X, Y):
    """Smooth-L1 loss between the prediction ``X`` and the target ``Y``.

    Uses ``nn.SmoothL1Loss`` with its defaults (mean reduction, beta=1.0).
    NOTE(review): the file also imports ``ssim`` from pytorch_msssim, so an
    SSIM term may have been intended as well; it is not included here —
    confirm the intended loss.

    Args:
        X: prediction tensor.
        Y: target tensor; must have exactly the same shape as ``X``.

    Returns:
        Scalar tensor, differentiable with respect to ``X``.

    Raises:
        ValueError: if the shapes differ. SmoothL1Loss would otherwise
            broadcast, e.g. (N, 1, D, H, W) against (N, D, H, W) into
            (N, N, D, H, W), and silently compute a wrong loss.
    """
    if X.shape != Y.shape:
        raise ValueError(
            f"prediction shape {tuple(X.shape)} != target shape {tuple(Y.shape)}"
        )
    return nn.SmoothL1Loss()(X, Y)
    

In [None]:
# Adam over every network parameter, learning rate 5e-4.
trainer = torch.optim.Adam(params=net.parameters(), lr=5e-4)

In [None]:
num_epochs = 300
net.train()
loss_values = []  # mean training loss of each epoch
# Batches must live on the same device as the network (moved to CUDA above).
device = next(net.parameters()).device
max_map_bytes = 1024 * 1024 * 512  # skip map files larger than 512 MiB

for epoch in range(num_epochs):
    epoch_sum = 0.0
    num_batches = 0
    # depoList and simuList are sorted, so zip pairs input and target maps
    # by their position in file-name order.
    for depoFile, simuFile in zip(depoList, simuList):
        if (os.path.getsize(depoFile) > max_map_bytes
                or os.path.getsize(simuFile) > max_map_bytes):
            continue
        try:
            depoPadded, depoMax = mrc2padded(depoFile, 1.0)
            simuPadded, simuMax = mrc2padded(simuFile, 1.0)
        except ValueError as err:
            # One unreadable/empty map should not abort a 300-epoch run.
            print(f"# Skipping pair {depoFile} / {simuFile}: {err}")
            continue
        # Chunks only correspond voxel-for-voxel when both maps share a grid.
        if depoPadded.shape != simuPadded.shape:
            print(f"# Skipping pair {depoFile} / {simuFile}: padded shapes "
                  f"{depoPadded.shape} and {simuPadded.shape} differ")
            continue
        depo_generator = chunk_generator(depoPadded, depoMax, 60, 30)
        simu_generator = chunk_generator(simuPadded, simuMax, 60, 30)
        while True:
            _, depo_chunks = get_batch_from_generator(depo_generator, 32, dtype=np.float32)
            _, simu_chunks = get_batch_from_generator(simu_generator, 32, dtype=np.float32)
            # Stop when the streams are exhausted, or as soon as they diverge:
            # after a size mismatch the chunks no longer correspond.
            if depo_chunks.shape[0] == 0 or depo_chunks.shape != simu_chunks.shape:
                break
            # transform() draws its rotation and crop from the global `random`
            # module; replaying the RNG state gives the target batch exactly
            # the same augmentation as the input batch.
            rng_state = random.getstate()
            depo_batch = transform(torch.from_numpy(depo_chunks))
            random.setstate(rng_state)
            simu_batch = transform(torch.from_numpy(simu_chunks))
            # Add a channel axis for the in_nc=1 network: (N, 1, 48, 48, 48).
            # NOTE(review): assumes SCUNet takes (N, C, D, H, W) input — confirm.
            depo_batch = depo_batch.unsqueeze(1).to(device)
            simu_batch = simu_batch.unsqueeze(1).to(device)
            batch_loss = loss(net(depo_batch), simu_batch)
            trainer.zero_grad()
            batch_loss.backward()
            trainer.step()
            # .item() detaches: summing the tensor itself would keep every
            # batch's autograd graph (and its device memory) alive.
            epoch_sum += batch_loss.item()
            num_batches += 1
        # Release this pair's maps before the next pair is loaded.
        del depoPadded, simuPadded, depo_generator, simu_generator
    train_loss = epoch_sum / num_batches if num_batches else float("nan")
    loss_values.append(train_loss)
    plt.plot(epoch, train_loss, 'ro', label='Train' if epoch == 0 else '_nolegend_')
    plt.pause(0.01)