In [1]:
import sys
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import traceback
import torchvision
import os
import gc

import numpy as np # linear algebra
from sklearn.model_selection import train_test_split
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from skimage.io import imread
import pickle
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt

from skimage.transform import rescale, resize, downscale_local_mean


torch.backends.cudnn.benchmark=True

# Unet

In [2]:
class double_conv(nn.Module):
    ''' conv -> BN -> relu -> conv -> BN -> relu'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)
        return x

class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
    
    def forward(self, x):
        return self.conv(x)

class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )
    
    def forward(self, x):
        return self.mpconv(x)

class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, out_ch//2, stride=2)
        
        self.conv = double_conv(in_ch, out_ch)
    
    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        
        x1 = F.pad(x1, (diffX//2, diffX - diffX//2,
                        diffY//2, diffY - diffY//2)
                  )
        
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x
    

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64,  128) #x2
        self.down2 = down(128, 256) #x3
        self.down3 = down(256, 512) #x4
        self.down4 = down(512, 512) #x5
        self.up1   = up(1024,256)
        self.up2   = up(512,128)
        self.up3   = up(256,64)
        self.up4   = up(128,64)
        self.outc  = outconv(64, n_classes)
    
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        x = self.up1(x5,x4) # (x5-512d + x4-512d  = 1024d--> 256d)
        x = self.up2(x,x3)  # (x-256d + x3 - 256d = 512d --> 128d)
        x = self.up3(x, x2) # (x-128d + x2 - 128d = 256d --> 64d)
        x = self.up4(x,x1)  # (x-64d  + x1 - 64d  = 128d --> 64d)
        x = self.outc(x)    # 64d --> n_classes_D
        
        return x


class KaggleDataset(Dataset):
    
    def __init__(self, ship_dir):
        self.ship_dir = ship_dir
        self.train_image_dir = os.path.join(self.ship_dir, 'train_v2')
        self.test_image_dir = os.path.join(self.ship_dir, 'test_v2')
        print("Starting preprocess")
        self.preprocess_pickle()
        
    def preprocess_pickle(self):
        with open('all_batches_balancedTrain.pickle', 'rb') as f:
            self.all_batches_balancedTrain = pickle.load(f)
        with open('all_batches_balancedValid.pickle', 'rb') as f:
            self.all_batches_balancedValid = pickle.load(f)

    def preprocess(self):
        
        def sample_ships(in_df, base_rep_val=1500):
            if in_df['ships'].values[0]==0:
                return in_df.sample(base_rep_val//3) # even more strongly undersample no ships
            else:
                return in_df.sample(base_rep_val, replace=(in_df.shape[0]<base_rep_val))
        masks = pd.read_csv(os.path.join(self.ship_dir, 'train_ship_segmentations_v2.csv'))
        
        masks['ships'] = masks['EncodedPixels'].map(lambda c_row: 1 if isinstance(c_row, str) else 0)
        unique_img_ids = masks.groupby('ImageId').agg({'ships': 'sum'}).reset_index()
        print("Reach 1")
        unique_img_ids['has_ship'] = unique_img_ids['ships'].map(lambda x: 1.0 if x>0 else 0.0)
        unique_img_ids['has_ship_vec'] = unique_img_ids['has_ship'].map(lambda x: [x])
        # some files are too small/corrupt
        print("Reach 1.2")
        unique_img_ids['file_size_kb'] = unique_img_ids['ImageId'].map(lambda c_img_id: 
                                                                       os.stat(os.path.join(self.train_image_dir, 
                                                                                            c_img_id)).st_size/1024)
        print("Reach 2")
        unique_img_ids = unique_img_ids[unique_img_ids['file_size_kb']>50] # keep only 50kb files
        masks.drop(['ships'], axis=1, inplace=True)
        train_ids, valid_ids = train_test_split(unique_img_ids, 
                         test_size = 0.3, 
                         stratify = unique_img_ids['ships'])
        
        
        print("Reach 3")
        train_df = pd.merge(masks, train_ids)
        valid_df = pd.merge(masks, valid_ids)
        train_df['grouped_ship_count'] = train_df['ships'].map(lambda x: (x+1)//2).clip(0, 7)

        
        print("Reach 4")
        balanced_train_df = train_df.groupby('grouped_ship_count').apply(sample_ships)
        print("Creating list")
        self.all_batches_balancedTrain = list(balanced_train_df.groupby('ImageId'))
        self.all_batches_balancedValid = list(valid_df.groupby('ImageId'))
        
        with open('all_batches_balancedTrain.pickle', 'wb') as f:
            # Pickle the 'data' dictionary using the highest protocol available.
            pickle.dump(self.all_batches_balancedTrain, f, pickle.HIGHEST_PROTOCOL)

        with open('all_batches_balancedValid.pickle', 'wb') as f:
            # Pickle the 'data' dictionary using the highest protocol available.
            pickle.dump(self.all_batches_balancedValid, f, pickle.HIGHEST_PROTOCOL)
            
    def __len__(self):
        return len(self.all_batches_balancedTrain)
    
    def multi_rle_encode(self, img):
        labels = label(img[:, :, 0])
        return [rle_encode(labels==k) for k in np.unique(labels[labels>0])]

    # ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
    def rle_encode(self, img):
        '''
        img: numpy array, 1 - mask, 0 - background
        Returns run length as string formated
        '''
        pixels = img.T.flatten()
        pixels = np.concatenate([[0], pixels, [0]])
        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
        runs[1::2] -= runs[::2]
        return ' '.join(str(x) for x in runs)

    def rle_decode(self, mask_rle, shape=(768, 768)):
        '''
        mask_rle: run-length as string formated (start length)
        shape: (height,width) of array to return 
        Returns numpy array, 1 - mask, 0 - background
        '''
        s = mask_rle.split()
        starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
        starts -= 1
        ends = starts + lengths
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        return img.reshape(shape).T  # Needed to align to RLE direction

    def masks_as_image(self, in_mask_list):
        # Take the individual ship masks and create a single mask array for all ships
        all_masks = np.zeros((768, 768), dtype = np.int16)
        #if isinstance(in_mask_list, list):
        for mask in in_mask_list:
            if isinstance(mask, str):
                all_masks += self.rle_decode(mask)
        return np.expand_dims(all_masks, -1)

    def __getitem__(self, idx):
        factor = 5
        rgb_path = os.path.join(self.train_image_dir, self.all_batches_balancedTrain[idx][0])
        c_img = imread(rgb_path)
        c_mask = self.masks_as_image( self.all_batches_balancedTrain[idx][1]['EncodedPixels'].values)
        
        c_img = np.stack(c_img, 0)/255.0
        c_mask = np.stack(c_mask, 0)
        
#         c_img = resize(c_img, (c_img.shape[0] / 2, c_img.shape[1] / 2),
#                        anti_aliasing=True)
        
#         c_mask = resize(c_mask, (c_mask.shape[0] / 2, c_mask.shape[1] / 2),
#                        anti_aliasing=True)
        
        c_img = resize(c_img, (c_img.shape[0] // factor, c_img.shape[1] // factor),
                       anti_aliasing=True)
        
        c_mask = resize(c_mask, (c_mask.shape[0] // factor, c_mask.shape[1] // factor),
                       anti_aliasing=True)
        c_img = c_img.transpose(-1, 0, 1)
        c_mask = c_mask.transpose(-1, 0, 1)
        
        
        return c_img.astype('f'), c_mask.astype('f')

    def show(self, x, y):
        f, axarr = plt.subplots(1,2, figsize=(15, 15))

        axarr[0].imshow(x.transpose(-1, 1, 0))
        axarr[1].imshow(y.transpose(-1, 1, 0)[:, :, 0])
            
def train(net, criterion, optimizer, epochs, trainLoader):
    print ('Training has begun ...')
    running_loss = 0
    for epoch in range(epochs):
        for i, data in enumerate(trainLoader):
            try:
                print ('Epoch : {0} || BatchID : '.format(epoch,i)) 
                X,Y = data
                print (X.size(), Y.size())
                optimizer.zero_grad()

                Y_   = net(X.cuda())
                print (Y_.size())
                loss = criterion(Y_, Y.cuda())
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                print("Loss is {}".format(running_loss))
                if (1):    # print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss))
                    running_loss = 0.0 
            except Exception:
                traceback.print_exc()
                sys.exit(1)
            
def dice_coeff(pred, target):
    smooth = 1.
    num = pred.size(0)
    m1 = pred.view(num, -1)  # Flatten
    m2 = target.view(num, -1)  # Flatten
    intersection = (m1 * m2).sum()

    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

In [None]:
if __name__ == "__main__":
    gc.collect()
    net       = UNet(3, 1).cuda()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
    criterion = dice_coeff
    
    ship_dir = '/media/shivam/DATA/airbus-tracking/'
    trainDataset = KaggleDataset(ship_dir)
    
    trainDataLoader   = torch.utils.data.DataLoader(
            trainDataset
            , batch_size=4,shuffle=True
            , num_workers=1, pin_memory=True)
    
    train(net, criterion, optimizer, 2, trainDataLoader)
    
    
    verbose = 0
    if verbose:
        y = net(torch.Tensor(np.random.random((1,3,256,256))))
        print (y.size())
        
        print (net)
        
        for param in net.parameters():
            print (param.size())

Starting preprocess
Training has begun ...


  warn("The default mode, 'constant', will be changed to 'reflect' in "


Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])




torch.Size([4, 1, 153, 153])
Loss is 7.569840818177909e-05
[1,     1] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.846106746001169e-05
[1,     2] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.73114079493098e-05
[1,     3] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.957056368468329e-05
[1,     4] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.486938557121903e-05
[1,     5] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.592514157295227e-05
[1,     6] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1,

Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.720198482275009e-05
[1,    52] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.85633092164062e-05
[1,    53] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.921427459223196e-05
[1,    54] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.644472498213872e-05
[1,    55] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.94798179413192e-05
[1,    56] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.726593321422115e-05
[1,    57] loss: 0.000
Epoch : 0 || Batch

Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.751909288344905e-05
[1,   102] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.367628131760284e-05
[1,   103] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.700241985730827e-05
[1,   104] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.637112867087126e-05
[1,   105] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.888479740358889e-05
[1,   106] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 8.086044545052573e-05
[1,   107] loss: 0.000
Epoch : 0 || Bat

Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 8.009323937585577e-05
[1,   152] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.765371992718428e-05
[1,   153] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.837373414076865e-05
[1,   154] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.448058750014752e-05
[1,   155] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.93573708506301e-05
[1,   156] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.78066532802768e-05
[1,   157] loss: 0.000
Epoch : 0 || Batch

Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.899132469901815e-05
[1,   203] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.81523558543995e-05
[1,   204] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.607301813550293e-05
[1,   205] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 8.413241448579356e-05
[1,   206] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.818326412234455e-05
[1,   207] loss: 0.000
Epoch : 0 || BatchID : 
torch.Size([4, 3, 153, 153]) torch.Size([4, 1, 153, 153])
torch.Size([4, 1, 153, 153])
Loss is 7.517730409745127e-05
[1,   208] loss: 0.000
Epoch : 0 || Batc

In [None]:
# del net
# torch.cuda.empty_cache()

# Rough

In [None]:
class KaggleDataset(Dataset):
    
    def __init__(self, ship_dir):
        self.ship_dir = ship_dir
        self.train_image_dir = os.path.join(self.ship_dir, 'train_v2')
        self.test_image_dir = os.path.join(self.ship_dir, 'test_v2')
        print("Starting preprocess")
        self.preprocess_pickle()
        
    def preprocess_pickle(self):
        with open('all_batches_balancedTrain.pickle', 'rb') as f:
            self.all_batches_balancedTrain = pickle.load(f)
        with open('all_batches_balancedValid.pickle', 'rb') as f:
            self.all_batches_balancedValid = pickle.load(f)

    def preprocess(self):
        
        def sample_ships(in_df, base_rep_val=1500):
            if in_df['ships'].values[0]==0:
                return in_df.sample(base_rep_val//3) # even more strongly undersample no ships
            else:
                return in_df.sample(base_rep_val, replace=(in_df.shape[0]<base_rep_val))
        masks = pd.read_csv(os.path.join(self.ship_dir, 'train_ship_segmentations_v2.csv'))
        
        masks['ships'] = masks['EncodedPixels'].map(lambda c_row: 1 if isinstance(c_row, str) else 0)
        unique_img_ids = masks.groupby('ImageId').agg({'ships': 'sum'}).reset_index()
        print("Reach 1")
        unique_img_ids['has_ship'] = unique_img_ids['ships'].map(lambda x: 1.0 if x>0 else 0.0)
        unique_img_ids['has_ship_vec'] = unique_img_ids['has_ship'].map(lambda x: [x])
        # some files are too small/corrupt
        print("Reach 1.2")
        unique_img_ids['file_size_kb'] = unique_img_ids['ImageId'].map(lambda c_img_id: 
                                                                       os.stat(os.path.join(self.train_image_dir, 
                                                                                            c_img_id)).st_size/1024)
        print("Reach 2")
        unique_img_ids = unique_img_ids[unique_img_ids['file_size_kb']>50] # keep only 50kb files
        masks.drop(['ships'], axis=1, inplace=True)
        train_ids, valid_ids = train_test_split(unique_img_ids, 
                         test_size = 0.3, 
                         stratify = unique_img_ids['ships'])
        
        
        print("Reach 3")
        train_df = pd.merge(masks, train_ids)
        valid_df = pd.merge(masks, valid_ids)
        train_df['grouped_ship_count'] = train_df['ships'].map(lambda x: (x+1)//2).clip(0, 7)

        
        print("Reach 4")
        balanced_train_df = train_df.groupby('grouped_ship_count').apply(sample_ships)
        print("Creating list")
        self.all_batches_balancedTrain = list(balanced_train_df.groupby('ImageId'))
        self.all_batches_balancedValid = list(valid_df.groupby('ImageId'))
        
        with open('all_batches_balancedTrain.pickle', 'wb') as f:
            # Pickle the 'data' dictionary using the highest protocol available.
            pickle.dump(self.all_batches_balancedTrain, f, pickle.HIGHEST_PROTOCOL)

        with open('all_batches_balancedValid.pickle', 'wb') as f:
            # Pickle the 'data' dictionary using the highest protocol available.
            pickle.dump(self.all_batches_balancedValid, f, pickle.HIGHEST_PROTOCOL)
            
    def __len__(self):
        return len(self.all_batches_balancedTrain)
    
    def multi_rle_encode(self, img):
        labels = label(img[:, :, 0])
        return [rle_encode(labels==k) for k in np.unique(labels[labels>0])]

    # ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
    def rle_encode(self, img):
        '''
        img: numpy array, 1 - mask, 0 - background
        Returns run length as string formated
        '''
        pixels = img.T.flatten()
        pixels = np.concatenate([[0], pixels, [0]])
        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
        runs[1::2] -= runs[::2]
        return ' '.join(str(x) for x in runs)

    def rle_decode(self, mask_rle, shape=(768, 768)):
        '''
        mask_rle: run-length as string formated (start length)
        shape: (height,width) of array to return 
        Returns numpy array, 1 - mask, 0 - background
        '''
        s = mask_rle.split()
        starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
        starts -= 1
        ends = starts + lengths
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        return img.reshape(shape).T  # Needed to align to RLE direction

    def masks_as_image(self, in_mask_list):
        # Take the individual ship masks and create a single mask array for all ships
        all_masks = np.zeros((768, 768), dtype = np.int16)
        #if isinstance(in_mask_list, list):
        for mask in in_mask_list:
            if isinstance(mask, str):
                all_masks += self.rle_decode(mask)
        return np.expand_dims(all_masks, -1)

    def __getitem__(self, idx):
        rgb_path = os.path.join(self.train_image_dir, self.all_batches_balancedTrain[idx][0])
        c_img = imread(rgb_path)
        c_mask = self.masks_as_image( self.all_batches_balancedTrain[idx][1]['EncodedPixels'].values)
        
        c_img = np.stack(c_img, 0)/255.0
        c_mask = np.stack(c_mask, 0)
        
        c_img = resize(c_img, (c_img.shape[0] / 2, c_img.shape[1] / 2),
                       anti_aliasing=True)
        
        c_mask = resize(c_mask, (c_mask.shape[0] / 2, c_mask.shape[1] / 2),
                       anti_aliasing=True)
        
        c_img = c_img.transpose(-1, 0, 1)
        c_mask = c_mask.transpose(-1, 0, 1)
        
        
        return c_img, c_mask

    def show(self, x, y):
        f, axarr = plt.subplots(1,2, figsize=(15, 15))

        axarr[0].imshow(x.transpose(-1, 1, 0))
        axarr[1].imshow(y.transpose(-1, 1, 0)[:, :, 0])
            

if(10):
    ship_dir = '/media/shivam/DATA/airbus-tracking/'
    trainDataset = KaggleDataset(ship_dir)

In [None]:

# # x, y = trainDataset[1]

# # yp = np.ones_like(y)
# y = torch.from_numpy(np.random.random((4, 1, 153, 153)))
# yp = torch.from_numpy(np.random.random((4, 1, 153, 153)))

# print(dice_coeff(y, yp))

In [None]:
# a,b = trainDataset[199]
# dataLoader = torch.utils.data.DataLoader(trainDataset
#             , batch_size=4,
#             shuffle=True, num_workers=1, pin_memory=True)