In [7]:
import sys
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import traceback
import torchvision
import os
import gc

import numpy as np # linear algebra
from sklearn.model_selection import train_test_split
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from skimage.io import imread
import pickle
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt

from skimage.transform import rescale, resize, downscale_local_mean
from skimage import img_as_bool


torch.backends.cudnn.benchmark=True


TRAINING_STATS = "UNETv1_checkpoint/progress.csv"
TRAINED_UNET_MODEL = "UNETv1_checkpoint/model.pt"
SHIP_DIR = "/media/shivam/DATA/airbus-tracking/"
TEST_IMAGE_DIR = os.path.join(SHIP_DIR, "test_v2")


VALIDATION_SIZE = 10
VALIDATION_BATCH = 2

TRAIN_BATCH = 10

# Unet

In [4]:
class double_conv(nn.Module):
    ''' conv -> BN -> relu -> conv -> BN -> relu'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)
        return x

class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
    
    def forward(self, x):
        return self.conv(x)

class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )
    
    def forward(self, x):
        return self.mpconv(x)

class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_ch//2, out_ch//2, stride=2)
        
        self.conv = double_conv(in_ch, out_ch)
    
    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        
        x1 = F.pad(x1, (diffX//2, diffX - diffX//2,
                        diffY//2, diffY - diffY//2)
                  )
        
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x
    

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64,  128) #x2
        self.down2 = down(128, 256) #x3
        self.down3 = down(256, 512) #x4
        self.down4 = down(512, 512) #x5
        self.up1   = up(1024,256)
        self.up2   = up(512,128)
        self.up3   = up(256,64)
        self.up4   = up(128,64)
        self.outc  = outconv(64, n_classes)
    
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        x = self.up1(x5,x4) # (x5-512d + x4-512d  = 1024d--> 256d)
        x = self.up2(x,x3)  # (x-256d + x3 - 256d = 512d --> 128d)
        x = self.up3(x, x2) # (x-128d + x2 - 128d = 256d --> 64d)
        x = self.up4(x,x1)  # (x-64d  + x1 - 64d  = 128d --> 64d)
        x = self.outc(x)    # 64d --> n_classes_D
        
        return x


class KaggleDataset(Dataset):
    
    def __init__(self, ship_dir, use_csv=False, ship_count=5):
        self.ship_dir = ship_dir
        self.train_image_dir = os.path.join(self.ship_dir, 'train_v2')
        self.test_image_dir = os.path.join(self.ship_dir, 'test_v2')
        print("Starting preprocess")
        self.preprocess(use_csv, ship_count)
#         self.preprocess_pickle()
        
    def preprocess_pickle(self):
        with open('all_batches_balancedTrain.pickle', 'rb') as f:
            self.all_batches_balancedTrain = pickle.load(f)
        with open('all_batches_balancedValid.pickle', 'rb') as f:
            self.all_batches_balancedValid = pickle.load(f)

    def preprocess(self, use_csv, min_ship_count):
        
        def sample_ships(in_df, base_rep_val=1500):
            if in_df['ships'].values[0]==0:
                return in_df.sample(base_rep_val//3) # even more strongly undersample no ships
            else:
                return in_df.sample(base_rep_val, replaice=(in_df.shape[0]<base_rep_val))
            
        if not use_csv:
            masks = pd.read_csv(os.path.join(self.ship_dir, 'train_ship_segmentations_v2.csv'))

    #         masks = masks.sample(len(masks)/2)
            masks['ships'] = masks['EncodedPixels'].map(lambda c_row: 1 if isinstance(c_row, str) else 0)
            unique_img_ids = masks.groupby('ImageId').agg({'ships': 'sum'}).reset_index()
            print("Reach 1")
            unique_img_ids['has_ship'] = unique_img_ids['ships'].map(lambda x: 1.0 if x>0 else 0.0)
            unique_img_ids['has_ship_vec'] = unique_img_ids['has_ship'].map(lambda x: [x])
            print(len(unique_img_ids))
            # some files are too small/corrupt
            print("Reach 1.2")
            unique_img_ids['file_size_kb'] = unique_img_ids['ImageId'].map(lambda c_img_id: 
                                                                           os.stat(os.path.join(self.train_image_dir, 
                                                                                                c_img_id)).st_size/1024)
            print("Reach 2")
            unique_img_ids = unique_img_ids[unique_img_ids['file_size_kb']>50] # keep only 50kb files
            masks.drop(['ships'], axis=1, inplace=True)
            train_ids, valid_ids = train_test_split(unique_img_ids, 
                             test_size = 0.3, 
                             stratify = unique_img_ids['ships'])


            print("Reach 3")
            train_df = pd.merge(masks, train_ids)
            valid_df = pd.merge(masks, valid_ids)
            
            # train_df['grouped_ship_count'] = train_df['ships'].map(lambda x: (x+1)//2).clip(0, 7)
            train_df['grouped_ship_count'] = train_df['ships'].map(lambda x: (x+1)//2)
            self.train_df = train_df
            self.valid_df = valid_df
            print("Reach 4")
            balanced_train_df = train_df.groupby('grouped_ship_count').apply(sample_ships)
            # TODO; save function 
            balanced_train_df.to_csv("balanced_train_df.csv", index=False)
            
            self.all_batches_balancedTrain = list(balanced_train_df.groupby('ImageId'))
            with open('all_batches_balancedTrain.pickle', 'wb') as f:
                # Pickle the 'data' dictionary using the highest protocol available.
                pickle.dump(self.all_batches_balancedTrain, f, pickle.HIGHEST_PROTOCOL)
            
            self.all_batches_balancedValid = list(valid_df.groupby('ImageId'))
            with open('all_batches_balancedValid.pickle', 'wb') as f:
                # Pickle the 'data' dictionary using the highest protocol available.
                pickle.dump(self.all_batches_balancedValid, f, pickle.HIGHEST_PROTOCOL)
            
        else:
            filename_train = 'all_batches_balancedTrain_me_{0}.pickle'.format(min_ship_count)
            if os.path.exists(filename_train):
                print("Using existing files : ", filename_train)
                with open(filename_train, 'rb') as f:
                    self.all_batches_balancedTrain = pickle.load(f)
                with open('all_batches_balancedValid.pickle', 'rb') as f:
                    self.all_batches_balancedValid = pickle.load(f)
            else:
                print("Creating new files")
                balanced_train_df = pd.read_csv('balanced_train_df.csv')
                ourBalanced_train_df = balanced_train_df[balanced_train_df['grouped_ship_count'] >= min_ship_count]
            
                self.all_batches_balancedTrain = list(ourBalanced_train_df.groupby('ImageId'))
                with open(filename_train, 'wb') as f:
                    # Pickle the 'data' dictionary using the highest protocol available.
                    pickle.dump(self.all_batches_balancedTrain, f, pickle.HIGHEST_PROTOCOL)
                    
                with open('all_batches_balancedValid.pickle', 'rb') as f:
                    self.all_batches_balancedValid = pickle.load(f)
                
    def __len__(self):
        return len(self.all_batches_balancedTrain)
    
    def multi_rle_encode(self, img):
        labels = label(img[:, :, 0])
        return [rle_encode(labels==k) for k in np.unique(labels[labels>0])]

    # ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
    def rle_encode(self, img):
        '''
        img: numpy array, 1 - mask, 0 - background
        Returns run length as string formated
        '''
        pixels = img.T.flatten()
        pixels = np.concatenate([[0], pixels, [0]])
        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
        runs[1::2] -= runs[::2]
        return ' '.join(str(x) for x in runs)

    def rle_decode(self, mask_rle, shape=(768, 768)):
        '''
        mask_rle: run-length as string formated (start length)
        shape: (height,width) of array to return 
        Returns numpy array, 1 - mask, 0 - background
        '''
        s = mask_rle.split()
        starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
        starts -= 1
        ends = starts + lengths
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        return img.reshape(shape).T  # Needed to align to RLE direction

    def masks_as_image(self, in_mask_list):
        # Take the individual ship masks and create a single mask array for all ships
        all_masks = np.zeros((768, 768), dtype = np.int16)
        #if isinstance(in_mask_list, list):
        for mask in in_mask_list:
            if isinstance(mask, str):
                all_masks += self.rle_decode(mask)
        return np.expand_dims(all_masks, -1)

    def __getitem__(self, idx):
        crop_delta = 192
        factor = 5
        rgb_path = os.path.join(self.train_image_dir, self.all_batches_balancedTrain[idx][0])
        c_img = imread(rgb_path)
        c_mask = self.masks_as_image( self.all_batches_balancedTrain[idx][1]['EncodedPixels'].values)
        
        c_img = np.stack(c_img, 0)/255.0
        c_mask = np.stack(c_mask, 0)
        
        h, w, _ = c_mask.shape

        # Random crop selection trick
        c = 0
        while (c < 5): 
            x1 = np.random.randint(0, h-crop_delta)
            x2 = np.random.randint(0, w-crop_delta)
            c_mask_s = c_mask[x1:x1+crop_delta, x2:x2+crop_delta, :]
            c += 1;
            if (np.sum(c_mask_s) > 200):
                break
            
        c_img_s = c_img[x1:x1+crop_delta, x2:x2+crop_delta, :]
        
        # Resizing image trick (not attempted yet)
#         c_img = resize(c_img, (c_img.shape[0] // factor, c_img.shape[1] // factor),
#                        anti_aliasing=False)
        
#         c_mask = resize(c_mask, (c_mask.shape[0] // factor, c_mask.shape[1] // factor),
#                        anti_aliasing=False)
        c_img = c_img_s.transpose(-1, 0, 1)
        c_mask = c_mask_s.transpose(-1, 0, 1)
#         print(c_img_s.shape, c_mask_s.shape)
#         print(c_img.shape, c_mask.shape)
        
        return c_img.astype('f'), c_mask.astype('f')
    
    def extract_image(self, idx, datapath, data):
        rgb_path = os.path.join(datapath, data[idx][0])
        c_img = imread(rgb_path)
        c_mask = self.masks_as_image(data[idx][1]['EncodedPixels'].values)
        
        c_img = c_img.transpose(-1, 0, 1)
        c_mask = c_mask.transpose(-1, 0, 1)
        return c_img.astype('f'), c_mask.astype('f')
        
    def validationset(self, size=10, batch_size=2):
        random_batches = [np.random.randint(0, len(self.all_batches_balancedValid)-batch_size) for _ in range(10)]
        for i in random_batches:
            X = []
            Y = []
            for j in range(batch_size):
                X_temp, y_temp = self.extract_image(i+j, self.train_image_dir, self.all_batches_balancedValid)
                X.append(X_temp)
                Y.append(y_temp)
            X = np.array(X)
            Y = np.array(Y)
            yield X, Y
            
        
    def show(self, x, y):
        f, axarr = plt.subplots(1,2, figsize=(15, 15))

        axarr[0].imshow(x.transpose(-1, 1, 0))
        axarr[1].imshow(y.transpose(-1, 1, 0)[:, :, 0])


    
def train(net, criterion, optimizer, epochs, trainLoader):
    print ('Training has begun ...')
    training_stats = []
    for epoch in range(epochs):
        running_loss = 0
        val_loss = 0;
        # Train with all available data.
        print("Training in epoch: {}".format(epoch+1))
        tcount = 0
        for i, data in enumerate(trainLoader):
            tcount += 1
            X,Y = data
            optimizer.zero_grad()

            Y_   = net(X.cuda())
            loss = criterion(Y_, Y.cuda())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        
        # Validate after each epoch.
        print("Validating in epoch: {}".format(epoch+1))
        with torch.no_grad():
            vcount = 0 
            for X_im, y_im in trainDataLoader.dataset.validationset(VALIDATION_SIZE, VALIDATION_BATCH):
                y_pred = net(torch.from_numpy(X_im).cuda())
                val_loss += criterion(y_pred, torch.from_numpy(y_im).cuda())
            
        # Normalize and save
        running_loss /= len(trainLoader.dataset)
        val_loss /= VALIDATION_SIZE*VALIDATION_BATCH
        training_stats.append([running_loss, val_loss])
        pd.DataFrame(training_stats).to_csv(TRAINING_STATS, header = ['running_loss', 'val_loss'], index = False)
        
        # Save model
        if (epoch%5==0):
            torch.save(net.state_dict(), TRAINED_UNET_MODEL)
        
        # Empty gpu cache
        torch.cuda.empty_cache()
        print("Epoch: {}, running loss: {:.4f}, validation loss: {:.4f}".format(epoch+1, running_loss, val_loss))
        

def dice_coeff(pred, target):
    smooth = 1.
    num = pred.size(0)
    m1 = pred.view(num, -1)  # Flatten
    m2 = target.view(num, -1)  # Flatten
    intersection = (m1 * m2).sum()

    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

In [5]:
# Load in Dataset
ship_dir = '/media/shivam/DATA/airbus-tracking/'
trainDataset = KaggleDataset(ship_dir,use_csv=True, ship_count=2)

Starting preprocess
Creating new files


## Initializing and training the U-Net

In [8]:
# Construct UNet
gc.collect()
reuse = False

net = UNet(3, 1).cuda()
if reuse:
    print("Reusing model from: {}".format(TRAINED_UNET_MODEL))
    net.load_state_dict(torch.load(TRAINED_UNET_MODEL))
    net.eval()
    
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005)
criterion = dice_coeff

# Training the model
trainDataLoader   = torch.utils.data.DataLoader(
        trainDataset
        , batch_size=TRAIN_BATCH,shuffle=True
        , num_workers=1, pin_memory=True)

## Training

In [None]:
train(net, criterion, optimizer, 100, trainDataLoader)

Training has begun ...
Training in epoch: 1




Validating in epoch: 1
Epoch: 1, running loss: 0.0022, validation loss: 0.0034
Training in epoch: 2
Validating in epoch: 2
Epoch: 2, running loss: 0.0022, validation loss: 0.0072
Training in epoch: 3
Validating in epoch: 3
Epoch: 3, running loss: 0.0018, validation loss: 0.0013
Training in epoch: 4
Validating in epoch: 4
Epoch: 4, running loss: 0.0017, validation loss: 0.0039
Training in epoch: 5
Validating in epoch: 5
Epoch: 5, running loss: 0.0013, validation loss: 0.0005
Training in epoch: 6
Validating in epoch: 6
Epoch: 6, running loss: 0.0013, validation loss: 0.0032
Training in epoch: 7
Validating in epoch: 7
Epoch: 7, running loss: 0.0012, validation loss: 0.0012
Training in epoch: 8


## Inference

In [None]:
from tqdm import tqdm_notebook
from skimage.morphology import binary_opening, disk

# Load testing images
test_paths = os.listdir(TEST_IMAGE_DIR)
print(len(test_paths), 'test images found')

# Load inference model
net = UNet(3, 1).cuda()
net.load_state_dict(torch.load(TRAINED_UNET_MODEL))
net.eval()

In [None]:
# TEST RUN
fig, m_axs = plt.subplots(8, 2, figsize = (10, 40))
for (ax1, ax2), c_img_name in zip(m_axs, test_paths):
    c_path = os.path.join(TEST_IMAGE_DIR, c_img_name)
    c_img = imread(c_path)
    first_img = np.expand_dims(c_img, 0)/255.0
    first_seg = fullres_model.predict(first_img)
    ax1.imshow(first_img[0])
    ax1.set_title('Image')
    ax2.imshow(first_seg[0, :, :, 0], vmin = 0, vmax = 1)
    ax2.set_title('Prediction')
fig.savefig('test_predictions.png')

In [None]:
out_pred_rows = []
for c_img_name in tqdm_notebook(test_paths):
    c_path = os.path.join(test_image_dir, c_img_name)
    c_img = imread(c_path)
    c_img = np.expand_dims(c_img, 0)/255.0
    
    with torch.no_grad():
        cur_seg = net(torch.from_numpy(c_img).cuda())
            
    cur_seg = binary_opening(cur_seg>0.5, np.expand_dims(disk(2), -1))
    cur_rles = multi_rle_encode(cur_seg)
    if len(cur_rles)>0:
        for c_rle in cur_rles:
            out_pred_rows += [{'ImageId': c_img_name, 'EncodedPixels': c_rle}]
    else:
        out_pred_rows += [{'ImageId': c_img_name, 'EncodedPixels': None}]
    gc.collect()
    

# Rough

In [None]:
class KaggleDataset(Dataset):
    
    def __init__(self, ship_dir):
        self.ship_dir = ship_dir
        self.train_image_dir = os.path.join(self.ship_dir, 'train_v2')
        self.test_image_dir = os.path.join(self.ship_dir, 'test_v2')
        print("Starting preprocess")
        self.preprocess_pickle()
        
    def preprocess_pickle(self):
        with open('all_batches_balancedTrain.pickle', 'rb') as f:
            self.all_batches_balancedTrain = pickle.load(f)
        with open('all_batches_balancedValid.pickle', 'rb') as f:
            self.all_batches_balancedValid = pickle.load(f)

    def preprocess(self):
        
        def sample_ships(in_df, base_rep_val=1500):
            if in_df['ships'].values[0]==0:
                return in_df.sample(base_rep_val//3) # even more strongly undersample no ships
            else:
                return in_df.sample(base_rep_val, replace=(in_df.shape[0]<base_rep_val))
        masks = pd.read_csv(os.path.join(self.ship_dir, 'train_ship_segmentations_v2.csv'))
        
        masks['ships'] = masks['EncodedPixels'].map(lambda c_row: 1 if isinstance(c_row, str) else 0)
        unique_img_ids = masks.groupby('ImageId').agg({'ships': 'sum'}).reset_index()
        print("Reach 1")
        unique_img_ids['has_ship'] = unique_img_ids['ships'].map(lambda x: 1.0 if x>0 else 0.0)
        unique_img_ids['has_ship_vec'] = unique_img_ids['has_ship'].map(lambda x: [x])
        # some files are too small/corrupt
        print("Reach 1.2")
        unique_img_ids['file_size_kb'] = unique_img_ids['ImageId'].map(lambda c_img_id: 
                                                                       os.stat(os.path.join(self.train_image_dir, 
                                                                                            c_img_id)).st_size/1024)
        print("Reach 2")
        unique_img_ids = unique_img_ids[unique_img_ids['file_size_kb']>50] # keep only 50kb files
        masks.drop(['ships'], axis=1, inplace=True)
        train_ids, valid_ids = train_test_split(unique_img_ids, 
                         test_size = 0.3, 
                         stratify = unique_img_ids['ships'])
        
        
        print("Reach 3")
        train_df = pd.merge(masks, train_ids)
        valid_df = pd.merge(masks, valid_ids)
        train_df['grouped_ship_count'] = train_df['ships'].map(lambda x: (x+1)//2).clip(0, 7)

        
        print("Reach 4")
        balanced_train_df = train_df.groupby('grouped_ship_count').apply(sample_ships)
        print("Creating list")
        self.all_batches_balancedTrain = list(balanced_train_df.groupby('ImageId'))
        self.all_batches_balancedValid = list(valid_df.groupby('ImageId'))
        
        with open('all_batches_balancedTrain.pickle', 'wb') as f:
            # Pickle the 'data' dictionary using the highest protocol available.
            pickle.dump(self.all_batches_balancedTrain, f, pickle.HIGHEST_PROTOCOL)

        with open('all_batches_balancedValid.pickle', 'wb') as f:
            # Pickle the 'data' dictionary using the highest protocol available.
            pickle.dump(self.all_batches_balancedValid, f, pickle.HIGHEST_PROTOCOL)
            
    def __len__(self):
        return len(self.all_batches_balancedTrain)
    
    def multi_rle_encode(self, img):
        labels = label(img[:, :, 0])
        return [rle_encode(labels==k) for k in np.unique(labels[labels>0])]

    # ref: https://www.kaggle.com/paulorzp/run-length-encode-and-decode
    def rle_encode(self, img):
        '''
        img: numpy array, 1 - mask, 0 - background
        Returns run length as string formated
        '''
        pixels = img.T.flatten()
        pixels = np.concatenate([[0], pixels, [0]])
        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
        runs[1::2] -= runs[::2]
        return ' '.join(str(x) for x in runs)

    def rle_decode(self, mask_rle, shape=(768, 768)):
        '''
        mask_rle: run-length as string formated (start length)
        shape: (height,width) of array to return 
        Returns numpy array, 1 - mask, 0 - background
        '''
        s = mask_rle.split()
        starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
        starts -= 1
        ends = starts + lengths
        img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        for lo, hi in zip(starts, ends):
            img[lo:hi] = 1
        return img.reshape(shape).T  # Needed to align to RLE direction

    def masks_as_image(self, in_mask_list):
        # Take the individual ship masks and create a single mask array for all ships
        all_masks = np.zeros((768, 768), dtype = np.int16)
        #if isinstance(in_mask_list, list):
        for mask in in_mask_list:
            if isinstance(mask, str):
                all_masks += self.rle_decode(mask)
        return np.expand_dims(all_masks, -1)

    def __getitem__(self, idx):
        rgb_path = os.path.join(self.train_image_dir, self.all_batches_balancedTrain[idx][0])
        c_img = imread(rgb_path)
        c_mask = self.masks_as_image( self.all_batches_balancedTrain[idx][1]['EncodedPixels'].values)
        
        c_img = np.stack(c_img, 0)/255.0
        c_mask = np.stack(c_mask, 0)
        
        c_img = resize(c_img, (c_img.shape[0] / 2, c_img.shape[1] / 2),
                       anti_aliasing=True)
        
        c_mask = resize(c_mask, (c_mask.shape[0] / 2, c_mask.shape[1] / 2),
                       anti_aliasing=True)
        
        c_img = c_img.transpose(-1, 0, 1)
        c_mask = c_mask.transpose(-1, 0, 1)
        
        
        return c_img, c_mask

    def show(self, x, y):
        f, axarr = plt.subplots(1,2, figsize=(15, 15))

        axarr[0].imshow(x.transpose(-1, 1, 0))
        axarr[1].imshow(y.transpose(-1, 1, 0)[:, :, 0])
            

if(10):
    ship_dir = '/media/shivam/DATA/airbus-tracking/'
    trainDataset = KaggleDataset(ship_dir)

In [None]:

# # x, y = trainDataset[1]

# # yp = np.ones_like(y)
y = torch.from_numpy(np.random.random((4, 1, 153, 153)))
y = torch.from_numpy(np.random.random((4, 1, 153, 153)))
yp = torch.from_numpy(np.random.random((4, 1, 153, 153)))

# print(dice_coeff(y, yp))

In [None]:
idx = np.random.randint(0,len(trainDataset.all_batches_balancedTrain))
# idx = 57
factor = 5
print(idx)

rgb_path = os.path.join(trainDataset.train_image_dir, trainDataset.all_batches_balancedTrain[idx][0])
c_img = imread(rgb_path)
c_mask = trainDataset.masks_as_image( trainDataset.all_batches_balancedTrain[idx][1]['EncodedPixels'].values)



h, w, _ = c_mask.shape
c = 192

x1 = np.random.randint(0, h-c)
x2 = np.random.randint(0, w-c)

c_img_s = c_img[x1:x1+c, x2:x2+c, :]
c_mask_s = c_mask[x1:x1+c, x2:x2+c, :]

# c_img_s = resize(c_img, (c_img.shape[0] // factor, c_img.shape[1] // factor), anti_aliasing=False)
# c_mask_s = img_as_bool(resize(c_mask, (c_mask.shape[0] // factor, c_mask.shape[1] // factor), anti_aliasing=False))
# c_mask_s = resize(c_mask, (c_mask.shape[0] // factor, c_mask.shape[1] // factor), anti_aliasing=False)

# c_mask_s = resize(c_mask, (c_mask.shape[0] // factor, c_mask.shape[1] // factor), anti_aliasing=False)
# c_mask_s[c_mask_s < 0.5] = 0
# c_mask_s[c_mask_s >= 0.5] = 1

# from torchvision.transforms import RandomCrop

# func = RandomCrop(192)

# c_mask_s = func(torch.from_numpy(c_mask))

print(np.sum(c_mask_s))
# print(c_img.shape, c_mask.shape)
# print(c_img_s.shape, c_mask_s.shape)
f, axarr = plt.subplots(2,2, figsize=(15,15))
axarr[0][0].imshow(c_img)
axarr[0][1].imshow(c_mask[:,:,0])
axarr[1][0].imshow(c_img_s)
axarr[1][1].imshow(c_mask_s[:,:,0])

# print (trainDataset.all_batches_balancedTrain[idx][1])