In [1]:
from fastai.vision.all import *
from tqdm import tqdm
import albumentations as A
from albumentations import functional as F

In [2]:
# Root folder of the local dataset; every other location is derived from it.
DATA_PATH = Path('data')
# One sub-folder per chip, holding the band images (B02/B03/B04/B08 .tif).
TRAIN_PATH = DATA_PATH.joinpath('train_features')
# One <chip_id>.tif label mask per chip.
LABEL_PATH = DATA_PATH.joinpath('train_labels')

In [3]:
# Chip metadata table (presumably one row per chip): chip id, band/label file paths and
# the `is_valid` flag used for the train/validation split.
# NOTE(review): the preview in the next cell shows an 'Unnamed: 0' column, which suggests the
# CSV was written together with its index; consider `index_col=0` — confirm before changing.
df = pd.read_csv(DATA_PATH / 'train_ready.csv')

In [4]:
# Peek at a single row to check which columns are available.
df.head(1)

Unnamed: 0,chip_id,location,datetime,cloudpath,is_valid,B02_path,B03_path,B04_path,B08_path,label_path
0,adwp,Chifunfu,2020-04-29T08:20:47Z,az://./train_features/adwp,False,data/train_features/adwp/B02.tif,data/train_features/adwp/B03.tif,data/train_features/adwp/B04.tif,data/train_features/adwp/B08.tif,data/train_labels/adwp.tif


In [5]:
# Image side length in pixels; in this notebook it is only used to name exported models.
IMG_SIZE = 512
# Batch size handed to the dataloaders.
BS = 2
# Number of chips sampled for a quick development run (200 batches' worth);
# set to None to work on the full dataset.
DEBUG = 200 * BS

In [6]:
def get_multiband_img(chip_path):
    """Load the four bands of one chip and stack them into a single array.

    The bands are read through `get_array` in the order B02 (blue), B03 (green),
    B04 (red), B08 (infrared) and stacked along a new trailing axis. For 2-D
    band arrays the result therefore has shape (rows, cols, 4).
    """
    band_names = ('B02', 'B03', 'B04', 'B08')
    layers = [get_array(chip_path, name) for name in band_names]
    return np.stack(layers, axis = -1)

def get_array(chip_path, band, scale = 27000, fill_shape = (512, 512)):
    """Read one band of a chip as a scaled float array.

    Parameters
    ----------
    chip_path : Path
        Folder of the chip; the band is looked up at `<chip_path>/<band>.tif`.
    band : str
        Band file stem, e.g. 'B02'.
    scale : number, default 27000
        Divisor applied to the raw pixel values.
    fill_shape : tuple of int, default (512, 512)
        Shape of the all-zero placeholder returned when the band file is missing.

    Returns
    -------
    np.ndarray
        Pixel values divided by `scale`, or zeros of shape `fill_shape` if the
        file does not exist.
    """
    fn = (chip_path / band).with_suffix('.tif')
    if not fn.exists():
        # Missing band: return a placeholder so the caller can still stack all bands.
        return np.zeros(fill_shape)
    # Release the file handle deterministically instead of waiting for garbage collection.
    with Image.open(fn) as img:
        return np.array(img) / scale

def get_mask(chip_path):
    """Return the label mask of a chip as a numpy array.

    The mask file is located from the stem of `chip_path` as
    `<LABEL_PATH>/<stem>.tif`.
    """
    chip_id = chip_path.stem
    mask_file = LABEL_PATH.joinpath(chip_id).with_suffix('.tif')
    return np.array(Image.open(mask_file))


In [7]:
class Chips:
    """Collection of chip folders together with their train/validation split.

    Parameters
    ----------
    src_path : Path
        Folder that contains one sub-folder per chip id.
    df_src : pd.DataFrame, Path or str
        Chip metadata with at least the columns `chip_id` and `is_valid`,
        given either as a dataframe or as the path of a CSV file.
    debug : int or None
        If truthy, only a random sample of at most `debug` rows is kept.

    Attributes set in `__init__`: `src_path`, `df` (the possibly sampled
    metadata) and `paths` (an `L` of `src_path / chip_id`).
    """
    def __init__(self, src_path, df_src, debug = None):
        self.src_path = src_path
        self.df = self._get_df(df_src)
        self.paths = self._get_paths(debug)

    def _is_valid(self, chip):
        """Return the `is_valid` flag of the row whose `chip_id` equals `chip.stem`."""
        # `.item()` raises unless exactly one row matches, so duplicate or
        # unknown chip ids fail loudly here.
        return self.df.loc[self.df['chip_id'] == chip.stem, 'is_valid'].item()

    def _get_paths(self, debug):
        """Build the list of chip folders, optionally restricted to a debug sample."""
        if debug:
            # Cap the sample size: `sample` raises when asked for more rows than exist.
            self.df = self.df.sample(n = min(debug, len(self.df)))
        chips = self.df['chip_id'].tolist()
        return L([self.src_path / chip for chip in chips])

    def _get_df(self, src):
        """Return `src` as a dataframe; raise `TypeError` for unsupported sources."""
        if isinstance(src, pd.DataFrame):
            return src
        if isinstance(src, (Path, str)):
            return pd.read_csv(src)
        # Fail here with a clear message instead of returning None and
        # crashing later with an unrelated error.
        raise TypeError(
            'Can not load dataframe, should be pd.DataFrame or path to .csv, '
            f'got {type(src).__name__}'
        )

    def get_paths(self):
        """Return the chip folders (the same list as `self.paths`)."""
        return self.paths

    def _split_idx(self, valid):
        """Indices into `self.paths` of the validation (`valid=True`) or training chips."""
        return self.paths.argwhere(self._is_valid, negate = not valid)

    def get_train_chips(self):
        """Recompute `train_idx` and return the training chip folders."""
        self.train_idx = self._split_idx(valid = False)
        return self.paths[self.train_idx]

    def get_valid_chips(self):
        """Recompute `valid_idx` and return the validation chip folders."""
        self.valid_idx = self._split_idx(valid = True)
        return self.paths[self.valid_idx]

    def get_splits(self):
        """Return `[train_idx, valid_idx]`, computing either one only if not yet set."""
        if not hasattr(self, 'train_idx'):
            self.train_idx = self._split_idx(valid = False)
        if not hasattr(self, 'valid_idx'):
            self.valid_idx = self._split_idx(valid = True)
        return [self.train_idx, self.valid_idx]

In [8]:
# Wrap the metadata table; with DEBUG set, only a random subset of chips is kept.
chips = Chips(src_path = TRAIN_PATH, df_src = df, debug = DEBUG)
# x pipeline: chip folder -> stacked band array; y pipeline: chip folder -> label mask.
dsets = Datasets(
    items = chips.paths,
    tfms = ([get_multiband_img], [get_mask]),
    splits = chips.get_splits(),
)

In [9]:
class SegmentationAlbumentationsTransform(ItemTransform):
    """Apply an albumentations pipeline jointly to each (image, mask) pair.

    `aug` is called as `aug(image=..., mask=...)` and must return a mapping
    with the keys 'image' and 'mask'. `split_idx = 0` follows the fastai
    convention for transforms that run on the training split only.
    """
    split_idx = 0
    def __init__(self, aug):
        self.aug = aug
    def encodes(self, x):
        # `x` is iterated as (image, mask) pairs.
        augs = []
        for img, mask in x:
            out = self.aug(image=img, mask=mask)
            # Read the results by key rather than by position, so the pair
            # stays (image, mask) whatever else the result mapping contains.
            augs.append((out['image'], out['mask']))
        return augs
    
class TransposeTransform(ItemTransform):
    """Turn each (image, mask) pair into channel-first fastai tensors."""
    def encodes(self, x):
        # `transpose(2,0,1)` moves the last image axis to the front, e.g.
        # (rows, cols, bands) -> (bands, rows, cols); the image becomes a float
        # `TensorImage`, the mask a long `TensorMask`.
        return [
            (TensorImage(img.transpose(2,0,1)).float(), TensorMask(mask).long())
            for img, mask in x
        ]
    
class FormatTransform(ItemTransform):
    """Format a collated (images, masks) batch as fastai tensors.

    The 4-D image batch is permuted with (0,3,1,2), i.e. its last axis becomes
    axis 1, and `return_type` (e.g. `tuple`) wraps the resulting pair.
    """
    def __init__(self, return_type):
        self.return_type = return_type

    def encodes(self, x):
        batch_imgs, batch_masks = x[0], x[1]
        imgs = TensorImage(batch_imgs.permute(0,3,1,2)).float()
        masks = TensorMask(batch_masks).long()
        return self.return_type([imgs, masks])

# Augmentation pipeline; it is applied jointly to image and mask by
# SegmentationAlbumentationsTransform.
augs_list = A.Compose(
    [
        A.Flip(),
        #A.RandomCrop(440, 440),
        A.CoarseDropout(),
    ]
)

# Transform instances referenced when building the dataloaders.
aug_tfms      = SegmentationAlbumentationsTransform(augs_list)
transpose_tfm = TransposeTransform()
format_tfm    = FormatTransform(tuple)

In [10]:
# Build the train/validation dataloaders.
# The device follows GPU availability so the notebook also runs on CPU-only
# machines; with CUDA present the settings are 'cuda' and pinned memory.
dls = dsets.dataloaders(
    bs = BS,
    num_workers = 6,
    pin_memory = torch.cuda.is_available(),  # pinning only helps host -> GPU copies
    device = 'cuda' if torch.cuda.is_available() else 'cpu',
    #after_item = [],
    # Both transforms receive the list of (image, mask) samples before collation.
    before_batch = [aug_tfms, transpose_tfm],
    #after_batch = [format_tfm],
)

In [11]:
# Read the channel count and spatial size off one real batch: the last three
# axes of the input batch are unpacked as (channels, *spatial size).
n_channels, *img_size = dls.one_batch()[0].shape[-3:]

# U-Net on a resnet34 encoder with `n_channels` input channels and 2 output classes.
model = create_unet_model(arch = resnet34, n_out = 2, img_size = img_size, n_in = n_channels)

# Look up the layer-group splitter registered for resnet34, falling back to the default split.
_default_meta = dict(cut = None, split = default_split)
meta = model_meta.get(resnet34, _default_meta)

learn = Learner(
    dls,
    model,
    # The class scores are expected on axis 1 of the prediction,
    # presumably shaped (batch, 2, rows, cols) — TODO confirm.
    loss_func = CrossEntropyLossFlat(axis = 1),
    metrics = [Dice, JaccardCoeff],
    splitter = meta['split'],
)

# Start with the model frozen.
learn.freeze()

---

In [12]:
#learn.lr_find()

In [13]:
# Fine-tune for 1 epoch with a base learning rate of 1e-4.
# NOTE(review): the two result tables below presumably correspond to fastai's
# frozen warm-up epoch followed by the unfrozen epoch — confirm against the
# `fine_tune` defaults of the installed fastai version.
learn.fine_tune(1, 1e-4)

epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.469546,0.409175,0.843298,0.729054,01:14


epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.349062,0.304411,0.87775,0.782133,01:14


### Speed tests for the transforms (tfms):

|tfms|bs|No. imgs|time frozen|time unfrozen|
|----|--|--------|-----------|-------------|
|standard|2|600|01:05|01:07|
|w/o augs|2|600|01:09|01:12|
|w/ crop|2|600|00:51|00:52|
|w/ dropout|2|600|01:04|01:06|
|perm @np|2|600|00:49|00:49|
|@np dropout|2|600|00:47|00:47|
|@np dropout|3|600|01:14|01:13|

|tfms|bs|No. imgs|time frozen|time unfrozen|
|:----|--|--------|-----------|-------------|
|@ batch|||01:07|01:09|
|@ list|||00:49|00:49|
|@ loading|2||00:48|00:48|
|@ loading|3||01:14|01:14|

In [14]:
#learn.export(f'res34_{IMG_SIZE}_j{881}_v5')
#learn.export('test_debug_v6')