In [486]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt
import os
import random
import time
import copy
import glob 
from PIL import Image
import cv2 

import torchvision
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

print(os.listdir("../input"))

['train', 'test', 'train.csv', 'sample_submission.csv']


In [487]:
use_gpu = torch.cuda.device_count() > 0
print("{} GPU's available:".format(torch.cuda.device_count()) )

0 GPU's available:


In [489]:
name_label_dict = {
    0:  'Nucleoplasm',
    1:  'Nuclear membrane',
    2:  'Nucleoli',   
    3:  'Nucleoli fibrillar center',
    4:  'Nuclear speckles',
    5:  'Nuclear bodies',
    6:  'Endoplasmic reticulum',   
    7:  'Golgi apparatus',
    8:  'Peroxisomes',
    9:  'Endosomes',
    10:  'Lysosomes',
    11:  'Intermediate filaments',
    12:  'Actin filaments',
    13:  'Focal adhesion sites',   
    14:  'Microtubules',
    15:  'Microtubule ends',  
    16:  'Cytokinetic bridge',   
    17:  'Mitotic spindle',
    18:  'Microtubule organizing center',  
    19:  'Centrosome',
    20:  'Lipid droplets',
    21:  'Plasma membrane',   
    22:  'Cell junctions', 
    23:  'Mitochondria',
    24:  'Aggresome',
    25:  'Cytosol',
    26:  'Cytoplasmic bodies',   
    27:  'Rods & rings' }

In [490]:
class CellsDataset(Dataset):

    def __init__(self, path2data, path2labels, isTest=False, transforms=None):
        
        self.transform = transforms 
        self.path2data = path2data
        self.X = glob.glob(self.path2data + '/*.png')
        self.labels = None
        self.isTest = isTest
        
        if not self.isTest:
            self.labels = pd.read_csv(path2labels).set_index('Id')
            self.labels['Target'] = [[int(i) for i in s.split()] for s in self.labels['Target']]

            
    def open_rgby(self,id): #a function that reads RGBY image
        colors = ['red','green','blue','yellow']
        flags = cv2.IMREAD_GRAYSCALE
        img = [cv2.imread(os.path.join(self.path2data, id+'_'+color+'.png'), flags).astype(np.float32)/255
               for color in colors]
        return np.stack(img, axis=-1)

    
    def __getitem__(self, index):
        
        path2img = self.X[index]
        
        image = self.open_rgby(path2img.split('_')[0])
                
        if not self.isTest:
            labels =np.zeros(len(name_label_dict),dtype=np.int)
        else:
            labels = self.labels.loc[self.fnames[i]]['Target']
            label = np.eye(len(name_label_dict),dtype=np.float)[labels].sum(axis=0)
        
        if self.transform:
            image = self.transform(image)
        return image, labels

    def __len__(self):
        return len(self.X)

In [491]:
class AdjustGamma(object):
    def __call__(self, img):
        return transforms.functional.adjust_gamma(img, 0.8, gain=1)

In [492]:
class AdjustContrast(object):
    def __call__(self, img):
        return transforms.functional.adjust_contrast(img, 2)

In [493]:
class AdjustBrightness(object):
    def __call__(self, img):
        return transforms.functional.adjust_brightness(img, 2)

In [494]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        AdjustGamma(),
        AdjustContrast(),
        AdjustBrightness(),
        transforms.ToTensor()
    ]),
    'valid': transforms.Compose([
        transforms.RandomHorizontalFlip(), 
        transforms.RandomVerticalFlip(),
        AdjustGamma(),
        AdjustContrast(),
        transforms.ToTensor(),
    ]),
}

In [495]:
dsets = {
    'train': CellsDataset('../input/train', '../input/train.csv', transforms=data_transforms['train']),
    'valid': CellsDataset('../input/train', '../input/train.csv', transforms=data_transforms['valid']),
    'test':  CellsDataset('../input/test', None, isTest=True, transforms=data_transforms['valid']),
}

In [496]:
batch_size = 32
random_seed = 3
valid_size = 0.2
shuffle = True

In [497]:
def create_dataLoader(dsets, batch_size, pin_memory=False):
    dset_loaders = {} 
    for key in dsets.keys():
        dset_loaders[key] = DataLoader(dsets[key], batch_size=batch_size, pin_memory=pin_memory) #sampler=sampler[key],
    return dset_loaders

In [498]:
dset_loaders = create_dataLoader(dsets, batch_size, pin_memory=False)

In [499]:
dset_loaders.keys()

dict_keys(['train', 'valid', 'test'])

In [500]:
def plot_volcanos(dset_loaders, is_train = True, preds_test = [], preds_train = []):
    
    X, y = next(iter(dset_loaders))
    X, y = X.numpy(), y.numpy()
    
    plt.figure(figsize=(20,10))
    for i in range(0, 4):
        plt.subplot(1,4,i+1)
        
        rand_img = random.randrange(0, X.shape[0])
        img = X[rand_img,:,:,:]
        plt.imshow((img[0,:,:]*255).astype(np.int))
        #plt.title('Volcano: {}'.format(y[rand_img]))
        plt.axis('off')

In [501]:
image, label = next(iter(dset_loaders['train']))
print(image.size(), label.size())

AttributeError: 'NoneType' object has no attribute 'astype'

In [502]:
plot_volcanos(dset_loaders['train'])

AttributeError: 'NoneType' object has no attribute 'astype'