### Import libraries

In [1]:
import sys
# Add the parent directory to the import path; presumably this is what makes the
# project-local `preprocess` and `utils` packages imported below resolvable
# (TODO confirm the notebook is expected to run from a sub-directory of the repo).
sys.path.append("../")

In [2]:
import os
from os import environ
import numpy as np
from random import choices
import pandas as pd
from tqdm.notebook import tqdm

In [3]:
import skimage.io as io

In [4]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.transforms as mtrans
%matplotlib inline

In [5]:
from preprocess.common import load_nii

In [6]:
from torch.utils.data import DataLoader
from torchcontrib.optim import SWA

# ---- My utils ----
from utils.data_augmentation import data_augmentation_selector
from utils.dataload import *
from utils.training import *

In [7]:
# "none" selects the no-augmentation pipeline (see the message printed by this cell).
# The two 224s are presumably the target image height/width — TODO confirm against
# data_augmentation_selector's signature.
train_aug, train_aug_img, val_aug = data_augmentation_selector("none", 224, 224)

Using None Data Augmentation


In [11]:
# Dataset settings; fold_system, normalization and label_type are reused by the
# other dataset cell in this notebook.
data_partition = "train"
general_aug, img_aug = train_aug, train_aug_img
normalization = "standardize"
fold_system = "all"
label_type = "mask"

# Build the dataset from the settings above instead of repeating the literals, so
# that editing a setting actually changes what gets loaded.
mnms_dataset = MMsDataset(
    mode=data_partition,
    transform=general_aug,
    img_transform=img_aug,
    folding_system=fold_system,
    normalization=normalization,
    label_type=label_type,
)


-------------------------
USING ALL DATA FOR TRAINING
-------------------------



In [12]:
# Size of mnms_dataset's `df` attribute (presumably one row per sample — TODO confirm).
len(mnms_dataset.df)

3284

In [15]:
# Second dataset (MMsWeaklyDataset — presumably the weakly-labelled samples, TODO
# confirm), built with the same transforms, folding, normalization and label type
# as mnms_dataset.
weakly_dataset = MMsWeaklyDataset(
    mode="train",
    transform=train_aug,
    img_transform=train_aug_img,
    folding_system=fold_system,
    normalization=normalization,
    label_type=label_type,
)


-------------------------
USING ALL DATA FOR TRAINING
-------------------------



In [16]:
# Size of weakly_dataset's `df` attribute (presumably one row per sample — TODO confirm).
len(weakly_dataset.df)

47059

In [79]:
from itertools import chain

In [80]:
def unpackTuple(tup):
    """Flatten one level of nesting.

    Args:
        tup: An iterable whose elements are themselves iterables.

    Returns:
        A new list with the elements of every inner iterable, in order.
        Deeper nesting is left untouched.
    """
    return list(chain(*tup))

In [235]:
class BalancedConcatDataset(torch.utils.data.Dataset):
    """Pair several datasets index-by-index, cycling through the shorter ones.

    Item ``i`` is a tuple holding one sample from every wrapped dataset, where
    dataset ``d`` contributes ``d[i % len(d)]``. The length of this dataset is
    the length of the longest wrapped dataset, so shorter datasets are repeated
    until the longest one has been traversed once.

    Args:
        *datasets: One or more sized, indexable datasets. The same dataset may
            be passed more than once; every copy is indexed with the same ``i``.

    Raises:
        ValueError: If no dataset is given, or if an empty dataset is mixed
            with non-empty ones (it could never supply a sample).
    """

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError("BalancedConcatDataset requires at least one dataset")
        self.datasets = datasets
        lengths = [len(d) for d in self.datasets]
        self.max_len = max(lengths)
        self.min_len = min(lengths)
        # An empty dataset would otherwise surface later as a ZeroDivisionError
        # from the modulo in __getitem__.
        if self.min_len == 0 and self.max_len > 0:
            raise ValueError("BalancedConcatDataset cannot mix empty and non-empty datasets")

    def __getitem__(self, i):
        """Return a tuple with one sample per wrapped dataset for index ``i``.

        Raises:
            IndexError: If ``i >= len(self)``. Without this, plain iteration
                (``for item in dataset``), which relies on ``IndexError`` to
                detect the end of a sequence, would never stop.

        Negative indices are not range-checked; they wrap modulo each wrapped
        dataset's own length, as before.
        """
        if i >= self.max_len:
            raise IndexError(
                "index {} is out of range for BalancedConcatDataset of length {}".format(i, self.max_len)
            )
        return tuple(d[i % len(d)] for d in self.datasets)

    def masks_collate(self, batch):
        """Collate a batch of items into stacked ``(images, masks)`` tensors.

        Args:
            batch: Sequence of items as produced by ``__getitem__``; each item
                is a sequence of samples, and each sample holds the image at
                position 0 and the mask at position 1. Any further sample
                fields are ignored.

        Returns:
            Tuple ``(images, masks)``, each built with ``torch.stack`` and
            holding one entry per sample, ordered item by item and, within an
            item, in wrapped-dataset order.
        """
        # Flatten (item, dataset) nesting into a single list of samples.
        samples = [sample for item in batch for sample in item]
        images = torch.stack([sample[0] for sample in samples])
        masks = torch.stack([sample[1] for sample in samples])
        return images, masks

    def __len__(self):
        """Length of the longest wrapped dataset."""
        return self.max_len

In [236]:
# NOTE(review): weakly_dataset is passed twice, and BalancedConcatDataset indexes
# every wrapped dataset with the same i, so for a given item both copies request
# the same index of weakly_dataset — confirm this duplication is intended.
train_dataset = BalancedConcatDataset(mnms_dataset, weakly_dataset, weakly_dataset)

# Each item carries 3 samples, so batch_size=2 collates to 6 images and 6 masks.
train_loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    drop_last=True,
    collate_fn=train_dataset.masks_collate,
)

In [237]:
# Smoke test: pull a single collated batch (a = images, b = masks, in the order
# returned by the loader's collate function). next(iter(...)) fails loudly on an
# empty loader, whereas `for ...: break` would silently leave a and b undefined.
a, b = next(iter(train_loader))