# AutoMap

## DataLoader

In [None]:
import torch
from torch.utils.data import Dataset
import nibabel as nib
import numpy as np


class AutoMap_train(Dataset):
    """Slice-wise dataset over a single 3D NIfTI volume for AutoMap training.

    Every 2D slice along each of the three volume axes appears ``self.t`` (= 5)
    times, once per augmentation mode:

        0 = unchanged, 1/2/3 = rotated by 90/180/270 degrees, 4 = random crop.

    Items are ordered as all x-slices, then all y-slices, then all z-slices,
    with the ``t`` augmentation modes of one slice stored consecutively.

    Each item is ``(k_space, img)``. ``img`` is the augmented 2D slice.
    ``k_space`` is the orthonormal 2D FFT of ``img``, flattened to shape
    ``(1, H*W)``. It holds the magnitude ``|FFT|`` when
    ``modulation == "Magnitude"`` and the phase ``angle(FFT)`` otherwise.

    NOTE(review): default batch collation needs every item to have the same
    shape. Slices from different axes, and 90/270-degree rotations of
    non-square slices, differ in shape unless the volume is cubic. This class
    presumably expects a cubic volume; confirm against the data actually used.

    Args:
        path: Path of the NIfTI file, loaded with nibabel.
        modulation: ``"Magnitude"`` selects the FFT magnitude. Any other value
            selects the FFT phase.
    """

    def __init__(self, path, modulation):
        # nibabel deprecated and later removed ``get_data()``. Its documented
        # replacement, which returns the same array, is ``np.asanyarray(dataobj)``.
        self.nii = np.asanyarray(nib.load(path).dataobj)  # 3D NIfTI volume
        self.xlen, self.ylen, self.zlen = self.nii.shape[:3]
        self.t = 5  # augmentation modes per slice (see class docstring)
        self.modulation = modulation

    def __getitem__(self, index):
        """Return ``(k_space, img)`` for dataset position ``index``.

        Negative indices count from the end. Raises ``IndexError`` when
        ``index`` is out of range.
        """
        n_items = len(self)
        if index < 0:
            index += n_items
        if not 0 <= index < n_items:
            raise IndexError(
                f"index {index} out of range for dataset of length {n_items}"
            )

        # Each slice owns ``t`` consecutive positions, one per augmentation
        # mode. Convert the position into (global slice number, mode), then
        # map the slice number onto the axis it belongs to.
        slice_idx, mode = divmod(index, self.t)
        if slice_idx < self.xlen:
            img = self.nii[slice_idx, :, :]
        elif slice_idx < self.xlen + self.ylen:
            img = self.nii[:, slice_idx - self.xlen, :]
        else:
            img = self.nii[:, :, slice_idx - self.xlen - self.ylen]

        # Augmentation: 0 -> unchanged, 1..3 -> rotate by mode*90, 4 -> random crop.
        if mode == 0:
            img = self.origin(img)
        elif mode == 4:
            img = self.randomcrop(img)
        else:
            img = self.rotate(img, mode * 90)

        # raw is the k-space of img. ``norm`` must be given by keyword: the
        # second positional parameter of fft2 is the output shape ``s``.
        raw = np.fft.fft2(img, norm="ortho")

        # Flatten to the (1, H*W) magnitude or phase input AutoMap consumes.
        if self.modulation == "Magnitude":
            return np.abs(raw).reshape(1, -1), img
        return np.angle(raw).reshape(1, -1), img

    def __len__(self):
        """Total number of items: ``t`` augmented copies of every slice on every axis."""
        return self.t * (self.xlen + self.ylen + self.zlen)

    def origin(self, img):
        """Identity augmentation (mode 0): return ``img`` unchanged."""
        return img

    def resize(self, img):
        """Placeholder for a resize augmentation (not used by ``__getitem__``).

        Raises:
            NotImplementedError: always. The original stub silently returned
                ``None``, which would only fail later, far from the cause.
        """
        raise NotImplementedError("resize augmentation is not implemented")

    def randomcrop(self, img, padding=4):
        """Randomly crop ``img`` without changing its shape (mode 4).

        Pads the image with ``padding`` zeros on every side, then cuts out a
        window of the original size at a uniformly random offset. This yields
        a random shift of up to ``padding`` pixels per axis, with zero fill.
        The output keeps the input shape, so the flattened k-space length stays
        fixed.

        Uses NumPy's global RNG.
        NOTE(review): seed each DataLoader worker if reproducibility matters.

        Args:
            img: 2D slice.
            padding: Maximum shift in pixels; must be >= 0. ``0`` returns a
                copy of ``img``.

        Raises:
            ValueError: if ``padding`` is negative.
        """
        if padding < 0:
            raise ValueError(f"padding must be >= 0, got {padding}")
        height, width = img.shape
        padded = np.pad(img, padding, mode="constant")
        top = np.random.randint(0, 2 * padding + 1)
        left = np.random.randint(0, 2 * padding + 1)
        return np.ascontiguousarray(padded[top:top + height, left:left + width])

    def rotate(self, img, degree):
        """Rotate ``img`` counter-clockwise by ``degree`` (a multiple of 90).

        The result is made contiguous because ``np.rot90`` returns a view with
        negative strides, which ``torch.as_tensor`` (used by default batch
        collation) rejects.

        Raises:
            ValueError: if ``degree`` is not a multiple of 90.
        """
        if degree % 90:
            raise ValueError(f"degree must be a multiple of 90, got {degree}")
        return np.ascontiguousarray(np.rot90(img, k=(degree // 90) % 4))


## Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class AutoMap(nn.Module):
    def __init__(self):
        super(AutoMap, self).__init__()