In [17]:
import sys, os
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
import nibabel as nib
from omegaconf import OmegaConf
import nrrd
import random

sys.path.append('..')
from user_model import UserModel
from utils import *




In [34]:
# Load the base training config ('ae_training.yaml') used by all cells below.
AE_CONFIG_PATH = '../configs/ae_training.yaml'
cfg = OmegaConf.load(AE_CONFIG_PATH)

In [35]:
class PretrainingDataset(Dataset):
    """Map-style dataset over every regular file in ``<data_dir>/<subject>/Diffusion/data``.

    Each item is the full contents of one file, loaded with ``nibabel`` and
    returned as a ``float32`` tensor. Batches observed in this notebook have
    shape ``[16, 288, 145, 145]``, suggesting per-item tensors of shape
    ``[288, 145, 145]`` -- TODO confirm this holds for all subjects.

    Args:
        subjects: Names of subject directories directly under ``data_dir``.
        data_dir: Root directory containing one directory per subject.
    """

    def __init__(
        self,
        subjects: List[str],
        data_dir: str
    ):
        # Keep the root on the instance so item loading does not depend on the
        # notebook-global ``cfg`` (also keeps DataLoader workers self-contained).
        self.data_dir = data_dir

        # Entries are paths relative to ``data_dir``: ``<subject>/Diffusion/data/<file>``.
        # os.listdir returns entries in arbitrary order, so sort them to get a
        # reproducible index across machines and runs.
        self.data_index = []
        for subject in subjects:
            slice_dir = os.path.join(data_dir, subject, 'Diffusion', 'data')
            self.data_index.extend(
                f'{subject}/Diffusion/data/{f}'
                for f in sorted(os.listdir(slice_dir))
                if os.path.isfile(os.path.join(slice_dir, f))
            )

    def __len__(self) -> int:
        return len(self.data_index)

    def _file_path(self, idx: int) -> str:
        """Return the on-disk path of item ``idx`` (``data_dir`` joined with its index entry)."""
        return os.path.join(self.data_dir, self.data_index[idx])

    def __getitem__(self, idx) -> Tensor:
        # get_fdata() returns a fresh float64 array; wrap it without copying and
        # cast to float32 once (torch.tensor would copy the whole volume first).
        volume = nib.load(self._file_path(idx)).get_fdata()
        return torch.from_numpy(volume).float()

In [36]:
def get_train_loader(
    cfg: OmegaConf
):
    """Build the training loader and, when ``cfg.subjects == 'all'``, a validation loader.

    Reads ``cfg.data_dir``, ``cfg.subjects``, ``cfg.batch_size`` and ``cfg.num_workers``.

    * ``cfg.subjects == 'all'``: every directory in ``data_dir`` whose name does
      not contain ``'corrupted'`` is used. After sorting by name, the first
      ``n // 10`` subjects form the validation split and the rest the training split.
    * Otherwise ``cfg.subjects`` is the explicit list of training subjects and
      no validation loader is built.

    Returns:
        ``(trainloader, valloader)``; ``valloader`` is ``None`` when there are no
        validation subjects.

    Raises:
        ValueError: if ``cfg.subjects == 'all'`` and no subject directories are found.
    """
    from torch.utils.data import DataLoader

    data_dir = cfg["data_dir"]

    def _make_loader(subjects, shuffle):
        # Shared DataLoader settings for every split.
        return DataLoader(
            PretrainingDataset(subjects, data_dir),
            batch_size=cfg.batch_size,
            shuffle=shuffle,
            num_workers=cfg.num_workers,
            pin_memory=True,
            drop_last=False,
        )

    if cfg.subjects == 'all':
        # os.listdir order is filesystem-dependent; sort so the split is reproducible.
        subject_dirs = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d))
            and 'corrupted' not in d
        )
        if not subject_dirs:
            raise ValueError(f"No subject directories found in {data_dir!r}")

        n_val = len(subject_dirs) // 10
        train_subjects = subject_dirs[n_val:]
        val_subjects = subject_dirs[:n_val]

        trainloader = _make_loader(train_subjects, shuffle=True)
        # Validation order need not be random; with fewer than 10 subjects the
        # split is empty and a shuffled loader would fail, so return None then.
        valloader = _make_loader(val_subjects, shuffle=False) if val_subjects else None
        return trainloader, valloader

    return _make_loader(cfg.subjects, shuffle=True), None

In [37]:
# Build the loaders from the config; valloader may be None when explicit subjects are configured.
(trainloader, valloader) = get_train_loader(cfg)

In [29]:
# Ad-hoc inspection loader (batch size 7) over the training dataset built above.
# PretrainingDataset takes (subjects, data_dir), and cfg.subjects may be the
# string 'all', so reuse the already-constructed training dataset instead.
from torch.utils.data import DataLoader

test = trainloader.dataset
loader = DataLoader(test, batch_size=7, shuffle=True)

In [38]:
# Pull a single batch to sanity-check the tensor shape the training loader yields.
batch = next(iter(trainloader), None)
if batch is not None:
    print(batch.shape)

torch.Size([16, 288, 145, 145])
