In [18]:
import os  
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pickle as pkl
import random 
import numpy as np

import time

class PatchDataset(Dataset):
    def __init__(self, files, pos="S19_", neg="S29_", transform=None):
        super().__init__()
        self.pos = pos
        self.neg = neg
        self.files = files
        self.transform = transform
 
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        
        sample_name = self.files[idx].split("/")[-1]
        true_class = self.get_true_class(sample_name, idx)

        with open(self.files[idx], "rb") as fin:
            patch = pkl.load(fin)
            masked = self.get_masked(patch)
            
            sample = {"image": masked,
                      "true_class": true_class}
            
            if self.transform:
                sample = transform(sample)
                
            return sample
            
            
    def get_masked(self, patch):
        tmp = patch.RGB
        tmp[~patch.mask]=0
        return tmp
    
    def get_true_class(self, sample_name, idx):
        
        if self.pos in sample_name:
            label = 1
        elif self.neg in sample_name:
            label = 0
            
        return label
        
class ToTensor():
    
    def __call__(self, sample):
        
        image, true_class = sample['image'], sample['true_class']        
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'true_class': torch.tensor(true_class)}

class RandomAffine():
    
    def __init__(self):
        self.transformer = transforms.RandomAffine(degrees=(0,360), translate=(0.1, 0.1), scale=(0.9, 1.1), fill=0)

    def __call__(self, sample):
        
        image, true_class = sample['image'], sample['true_class'] 
        
        return {'image': self.transformer(image),
                'true_class': true_class}
  
class RandomFlip():
    
    def __init__(self):
        self.vert = transforms.RandomVerticalFlip()
        self.hor = transforms.RandomHorizontalFlip()
    
    def __call__(self, sample):
        
        image, true_class = sample['image'], sample['true_class'] 
        
        image = self.vert(image)
        image = self.hor(image)
        
        return {'image': image,
                'true_class': true_class}
       
class Normalize():
    
    def __init__(self):
        self.means = torch.tensor([0.019045139, 0.033935968, 0.06551865])
        self.stds = torch.tensor([0.05771165, 0.11356566, 0.17648968])
        self.transformer = transforms.Normalize(self.means, self.stds)
        
    def __call__(self, sample):
        
        image, true_class = sample['image'], sample['true_class'] 
        
        return {'image': self.transformer(image),
                'true_class': true_class}
  
  
class MYCNTrainingSet(Dataset):
    
    def __init__(self, base, pos="S19_", neg="S29_", split=[0.8, 0.1, 0.1]):
        super().__init__()
        self.base = base
        self.pos = pos
        self.neg = neg
        self.split = split
        self.files = np.array(self._get_files(self.pos) + self._get_files(self.neg))
    
        assert sum(self.split) == 1, "Split must sum up to 1"
        
        train_files, val_files, test_files = self.get_split_indices()
        
        self.train_dataset = PatchDataset(train_files)
        self.val_dataset = PatchDataset(val_files)
        self.test_dataset = PatchDataset(test_files)
        
    def get_split_indices(self):
        
        total = len(self.files)
        train_num = int(len(self.files)*self.split[0])
        val_num = int(len(self.files)*self.split[1])
        test_num = len(self.files)-(train_num+val_num)
        idxs = list(range(len(self.files)))
        random.shuffle(idxs)
        
        train_idxs = idxs[:train_num]
        val_idxs = idxs[train_num:train_num+val_num]
        test_idxs = idxs[-test_num:]
        
        return self.files[train_idxs], self.files[val_idxs], self.files[test_idxs]
          
    def _get_files(self, pattern):
        return [os.path.join(self.base, x) for x in os.listdir(self.base) if pattern in x]

In [19]:
import matplotlib.pyplot as plt
base = "/data_isilon_main/isilon_images/10_MetaSystems/MetaSystemsData/MYCN_SpikeIn/results/single_patches"
transform=transforms.Compose([ToTensor(), RandomAffine(), RandomFlip(), Normalize()])
data = MYCNTrainingSet(base)
loader = DataLoader(data.train_dataset, batch_size=256, shuffle=True, num_workers=2)

for sample in loader:
    im = sample["image"]


torch.Size([256, 128, 128, 3])
torch.Size([256, 128, 128, 3])
torch.Size([256, 128, 128, 3])
torch.Size([256, 128, 128, 3])
torch.Size([256, 128, 128, 3])
torch.Size([256, 128, 128, 3])


KeyboardInterrupt: 

In [None]:
for x in loader:
    print(x)

In [None]:
base = "/data_isilon_main/isilon_images/10_MetaSystems/MetaSystemsData/MYCN_SpikeIn/results/patches"
out = "/data_isilon_main/isilon_images/10_MetaSystems/MetaSystemsData/MYCN_SpikeIn/results/single_patches"

import pickle as pkl
for pattern in ["S19_", "S29_"]:
    files = [os.path.join(base, x) for x in os.listdir(base) if pattern in x]
    for file in files[:200]:
        with open(file, "rb") as fin:
            n = file.split("/")[-1].split(".")[0]
            if os.path.isfile(os.path.join(out, f"{n}_0.ptch")):
                print(n + " already exists")
                continue
            print(n)
            dat = pkl.load(fin)
            for i, d in enumerate(dat):
                with open(os.path.join(out, f"{n}_{i}.ptch"), "wb+") as fout:
                    pkl.dump(d, fout)


In [None]:
import numpy as np
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

%matplotlib inline

out = "/data_isilon_main/isilon_images/10_MetaSystems/MetaSystemsData/MYCN_SpikeIn/results/single_patches"

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1) 

images = []
masks = []
for n, file in enumerate(tqdm(os.listdir(out)[:10000])):
    with open(os.path.join(out, file), "rb") as fin:
        patch = pkl.load(fin)
        tmp = patch.RGB
        tmp[~patch.mask] = 0
        images.append(tmp)

    images_arr = np.array(images)
    masks_arr = np.array(masks)

    R_means = np.mean(images_arr[..., 0])
    G_means = np.mean(images_arr[..., 1])
    B_means = np.mean(images_arr[..., 2])
    R_stds = np.std(images_arr[..., 0])
    G_stds = np.std(images_arr[..., 1])
    B_stds = np.std(images_arr[..., 2])
 
print(R_means[-1],G_means[-1], B_means[-1], R_stds[-1], G_stds[-1], B_stds[-1])

In [None]:
R_mean = np.mean(images[masks, 0])
G_mean = np.mean(images[masks, 1])
B_mean = np.mean(images[masks, 2])
R_std = np.std(images[masks, 0])
G_std = np.std(images[masks, 1])
B_std = np.std(images[masks, 2])

In [112]:
a = np.array([0,10,2,3,4,5,6])