In [1]:
from fastai.vision.all import *
from tqdm import tqdm
import albumentations as A
from albumentations import functional as F

In [2]:
DATA_PATH = Path('data')  # relative to the current working directory
TRAIN_PATH = DATA_PATH.joinpath('train_features')  # one folder of band TIFFs per chip
LABEL_PATH = DATA_PATH.joinpath('train_labels')    # one mask TIFF per chip

In [3]:
# Per-chip metadata (chip_id, location, datetime, band/label paths, is_valid split flag).
df = pd.read_csv(DATA_PATH / 'train_ready.csv')

In [4]:
df.head(1)  # peek at the column layout of the metadata table

Unnamed: 0,chip_id,location,datetime,cloudpath,is_valid,B02_path,B03_path,B04_path,B08_path,label_path
0,adwp,Chifunfu,2020-04-29T08:20:47Z,az://./train_features/adwp,False,data/train_features/adwp/B02.tif,data/train_features/adwp/B03.tif,data/train_features/adwp/B04.tif,data/train_features/adwp/B08.tif,data/train_labels/adwp.tif


In [5]:
IMG_SIZE = 512  # image side length in px (note: get_array's zero-fill uses its own hard-coded 512)
BS = 2  # batch size used by every dataloaders() call below
DEBUG = None ## number of chips to subsample for a quick dev run (passed to Chips(debug=...)); None = all

In [6]:
def get_array(chip, band, scale = 27000, size = 512):
    """Load one spectral band of `chip` as a float array divided by `scale`.

    Reads `TRAIN_PATH/<chip>/<band>.tif`. When that file does not exist, a
    `(size, size)` array of zeros is returned instead so a band stack can still be built.

    Args:
        chip: chip id (name of the chip's folder under `TRAIN_PATH`).
        band: band file stem, e.g. 'B02'.
        scale: divisor applied to the raw pixel values. NOTE(review): 27000 is
            presumably an approximate upper bound of the raw values — confirm.
        size: side length of the zero array returned for a missing band.

    Returns:
        numpy float array of the scaled band, or zeros of shape `(size, size)`.
    """
    fn = (TRAIN_PATH / chip / band).with_suffix('.tif')
    if not fn.exists():
        return np.zeros((size, size))
    # Close the file as soon as it has been read: Pillow may otherwise keep the handle
    # open (notably for TIFF) until the image object is garbage-collected.
    with Image.open(fn) as im:
        return np.array(im) / scale

def get_4chan(chip_folder, bands = ('B02', 'B03', 'B04', 'B08')):
    """Stack the given bands of `chip_folder` into one channel-first numpy array.

    The default order is blue (B02), green (B03), red (B04), infrared (B08). Each band
    is loaded with `get_array`, so a missing band file becomes a zero channel.

    Returns:
        array of shape `(len(bands), H, W)`, assuming every band is a 2-D `(H, W)` array.
    """
    return np.stack([get_array(chip_folder, band) for band in bands], axis = 0)

def get_4chan_plain(chip_folder):
    """Same 4-band stack as `get_4chan`, wrapped as a fastai `TensorImage`."""
    # Delegating keeps a single band-stacking implementation for both pipelines.
    return TensorImage(get_4chan(chip_folder))

def get_mask(chip):
    """Load the label mask of `chip` from `LABEL_PATH/<chip>.tif` as a numpy array.

    The TIFF is read inside a context manager so its file handle is released right
    away. Without it, Pillow may keep the handle open (notably for TIFF) until garbage
    collection. A missing file raises from `Image.open`.
    """
    fn = (LABEL_PATH / chip).with_suffix('.tif')
    with Image.open(fn) as im:
        return np.array(im)

def get_mask_plain(chip):
    """Same mask as `get_mask`, wrapped as a fastai `TensorMask`."""
    # Delegating keeps one file-reading implementation for both mask loaders.
    return TensorMask(get_mask(chip))

def get_chips(path):
    """Return the names of the sub-directories (one per chip) of `path` as a fastcore `L`."""
    folders = filter(Path.is_dir, path.iterdir())
    return L(folders).attrgot('name')

def is_valid(chip):
    """Look up the `is_valid` flag of `chip` in the module-level `df`.

    Raises ValueError (from `.item()`) unless exactly one row has that `chip_id`.
    """
    matches = df.loc[df['chip_id'].eq(chip), 'is_valid']
    return matches.item()

---
## <Datasets
---

In [7]:
class Chips:
    """Index chip folders and split them into train/valid sets.

    Chip names are the sub-directory names of `src_path`. Whether a chip belongs to the
    validation set comes from the `is_valid` column of the CSV at `df_path`, looked up
    by `chip_id`.

    Args:
        src_path: `Path` whose sub-directories are the chips.
        df_path: CSV with (at least) `chip_id` and `is_valid` columns.
        debug: if truthy, keep only a random subset of this many *distinct* chips
            (capped at the number available) for quick development runs.
        seed: optional seed for that subset; `None` draws from the global `random` state.

    Attributes:
        names: `L` of chip names (sorted folder names, or a random subset when `debug` is set).
        train_idx, valid_idx: `L`s of indices into `names` for each split.
    """
    def __init__(self, src_path, df_path, debug = None, seed = None):
        self.src_path = src_path
        self.debug = debug
        self.seed = seed
        self.names = self._get_chips(src_path)
        self.df = pd.read_csv(df_path)
        self.train_idx = self.names.argwhere(self._is_valid, negate = True)
        self.valid_idx = self.names.argwhere(self._is_valid)
        # Every chip must land in exactly one of the two splits.
        assert len(self.train_idx) + len(self.valid_idx) == len(self.names)
        assert not set(self.train_idx) & set(self.valid_idx)

    def _is_valid(self, chip):
        """Return the `is_valid` flag of `chip`; `.item()` raises ValueError unless exactly one row matches."""
        return self.df.loc[self.df['chip_id'] == chip, 'is_valid'].item()

    def _get_chips(self, path):
        """Return the names of the chip folders in `path` as an `L`.

        Folders are sorted so the order (and hence the split indices) does not depend on
        the file system's listing order. When `self.debug` is truthy, a random subset of
        at most `self.debug` distinct chips is kept.
        """
        chips_paths = sorted(p for p in path.iterdir() if p.is_dir())
        if self.debug:
            # Sample WITHOUT replacement (`random.choices` could return duplicate chips);
            # cap k because `sample` raises ValueError when k exceeds the population.
            rng = random if self.seed is None else random.Random(self.seed)
            chips_paths = rng.sample(chips_paths, k = min(self.debug, len(chips_paths)))
        return L(chips_paths).attrgot('name')

    def describe(self):
        """Print the sizes of the validation/training splits and the total item count."""
        print(f'Number of validation items: {len(self.valid_idx)}, number of training items: {len(self.train_idx)}\nTotal number of items: {len(self.names)}')

    def get_train_chips(self):
        """Return the names of the training chips."""
        return self.names[self.train_idx]

    def get_valid_chips(self):
        """Return the names of the validation chips."""
        return self.names[self.valid_idx]

    def get_splits(self):
        """Return `[train_idx, valid_idx]`, as passed to `Datasets(splits=...)`."""
        return [self.train_idx, self.valid_idx]

In [8]:
# Index all chip folders and derive the train/valid split from train_ready.csv.
chips = Chips(TRAIN_PATH, DATA_PATH / 'train_ready.csv', debug = DEBUG)
chips.describe()

Number of validation items: 2442, number of training items: 9306
Total number of items: 11748


In [9]:
# Same items and split, two pipelines: `plain_dsets` yields fastai TensorImage/TensorMask,
# `tfms_dsets` yields raw numpy arrays. The tuple vs. list wrapping of the two
# pipelines differs only cosmetically; both forms are accepted by `Datasets`.
plain_dsets = Datasets(chips.names, ([get_4chan_plain], [get_mask_plain]), splits = chips.get_splits())
tfms_dsets = Datasets(chips.names, [[get_4chan], [get_mask]], splits = chips.get_splits())

In [None]:
# Sanity check: iterating the Datasets yields the same items, in order, as indexing it.
it = iter(plain_dsets)
print((next(it)[0] == plain_dsets[0][0]).all())
print((next(it)[0] == plain_dsets[1][0]).all())

---
## Datasets>
---

In [None]:
# Sanity check: each Datasets item equals calling its item functions directly on the chip name.
print((plain_dsets[0][0] == get_4chan_plain(chips.names[0])).all())
print((plain_dsets[0][1] == get_mask_plain(chips.names[0])) .all())

print((tfms_dsets[0][0] == get_4chan(chips.names[0])).all())
print((tfms_dsets[0][1] == get_mask(chips.names[0])) .all())

---
## < Transforms
---

### Plain/no tfms:

In [11]:
# DataLoaders over the plain pipeline with no extra transforms
# (the after_item/after_batch lines below are left disabled).
plain_dls = plain_dsets.dataloaders(
    bs = BS, 
    num_workers = 6, 
    pin_memory = True,
    #after_item = [ToTensor],
    #after_batch = [IntToFloatTensor, RandomResizedCropGPU(440)]
)

In [None]:
# Inspect one batch: shapes and tensor types of images (x) and masks (y).
plain_b = plain_dls.one_batch()
plain_xb = plain_b[0]
plain_yb = plain_b[1]
print(plain_xb.shape, plain_yb.shape)
print(plain_xb.type(), plain_yb.type())

In [None]:
# Private PyTorch DataLoader iterator classes (leading underscore: not public API and may
# change between versions), imported to inspect how the loader's `fake_l` is iterated.
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)

In [None]:
# True if the loader's `fake_l` is configured with zero worker processes.
n_workers = plain_dls.fake_l.num_workers
print(n_workers == 0)

In [None]:
# Peek at only the first few batches. Without the break, the loop walks the whole epoch,
# loading every chip and printing one line per batch.
N_PEEK_BATCHES = 3
faky = _MultiProcessingDataLoaderIter(plain_dls.fake_l)
for i, b in enumerate(faky):
    print(type(b))
    if i + 1 >= N_PEEK_BATCHES:
        break
# Drop the half-consumed iterator so its worker processes are shut down
# (presumably via the iterator's __del__ — confirm for your PyTorch version).
del faky

In [None]:
len(plain_dls.loaders[0])  # number of batches in the first (training) loader

### Findings:
* `before_batch` receives (and here returns unchanged) a list of `(img, mask)` tuples, one tuple per sample in the batch.

In [None]:
def _print_stage(stage, x):
    """Print `x`'s type at pipeline `stage`, then its shape (array/tensor) or length (list/tuple).

    Returns `x` unchanged so it can be used as a pass-through transform body.
    """
    print(f'At {stage}: {type(x)}')
    # isinstance (not `type(x) in [...]`) so tensor subclasses such as fastai's
    # TensorImage/TensorMask also get their shape printed.
    if isinstance(x, (np.ndarray, torch.Tensor)):
        print(x.shape)
    elif isinstance(x, (list, tuple)):
        print(len(x))
    return x

@Transform
def print_after_item(x):
    """Debug hook for `after_item`: report each item component, then pass it through."""
    return _print_stage('after_item', x)

@Transform
def print_before_batch(x):
    """Debug hook for `before_batch`: report the list of samples, then pass it through."""
    return _print_stage('before_batch', x)

@Transform
def print_after_batch(x):
    """Debug hook for `after_batch`: report each collated batch component, then pass it through."""
    return _print_stage('after_batch', x)

In [None]:
state = True  # True: build `tfms_dls` with the print_* debug transforms; False: without them

In [None]:
# Loader settings shared by both branches (the same values as `plain_dls`).
tfms_dl_kwargs = dict(bs = BS, num_workers = 6, pin_memory = True)
if state:
    print('Use tfms:')
    # x = (x_0,…,x_bs) | y = (y_0,…,y_bs) | b = ((x_0,y_0),…,(x_bs,y_bs))
    tfms_dl_kwargs.update(
        after_item = [print_after_item],      ## Applied to all x_i, y_i indiv. for all i's
        before_batch = [print_before_batch],  ## Applied to [(x_0, y_0),…,(x_bs,y_bs)]
        ## Is turned from np.array to tensor
        after_batch = [print_after_batch],    ## Applied to [x, y]
    )
else:
    print("Don't use tfms")
tfms_dls = tfms_dsets.dataloaders(**tfms_dl_kwargs)

### With tfms:

In [None]:
# Fetch one batch through the tfms pipeline (runs any print_* debug transforms), then report shapes/types.
tfms_b = tfms_dls.one_batch()
tfms_xb = tfms_b[0]
tfms_yb = tfms_b[1]
print('\nPrint shapes and types:')
print(tfms_xb.shape, tfms_yb.shape)
print(tfms_xb.type(), tfms_yb.type())

### -------------------------------------------------------

### Compare:

In [None]:
# First sample of each batch.
# NOTE(review): both batches come from separate `one_batch()` calls on (presumably
# shuffled) training loaders, so these are likely different chips — confirm before comparing.
plain_x = plain_xb[0]
tfms_x = tfms_xb[0]

In [None]:
plain_x[0,:5,:5]  # top-left 5x5 pixels of channel 0 (blue, B02)

In [None]:
# NOTE(review): this shows the *bottom-right* 5x5 of channel 0, not the same corner as the
# plain_x cell — confirm that is intended (e.g. to check a flip).
tfms_x[0,-5:,-5:]

In [None]:
def show_visual(img, ax = None, rgb_idx = (2, 1, 0)):
    """Display the true-colour RGB composite of a channel-first image tensor.

    Args:
        img: tensor of shape (C, H, W) with C >= 3, e.g. a sample built by
            `get_4chan_plain` (channel order B02 blue, B03 green, B04 red, B08 infrared).
        ax: matplotlib Axes to draw into; a new figure/axes is created when None.
        rgb_idx: indices of the red, green and blue channels in `img`. The default maps
            the (blue, green, red, ...) stack to the (red, green, blue) order imshow expects.
    """
    rgb = img.cpu()[list(rgb_idx)].numpy().transpose(1, 2, 0)
    # Scaled bands can exceed 1.0; imshow only accepts [0, 1] for float RGB, so clip
    # explicitly instead of relying on its warning-emitting implicit clipping.
    if np.issubdtype(rgb.dtype, np.floating):
        rgb = np.clip(rgb, 0, 1)
    if ax is None:
        _, ax = plt.subplots(1)
    ax.imshow(rgb)

In [None]:
show_visual(plain_x)  # visualise the first sample of the plain batch

In [None]:
#def show_with_mask()

In [None]:
def show_seg_batch(b, figsize = (20, 15)):
    """Plot every (image, mask) pair of a segmentation batch, one row per sample.

    Args:
        b: `(imgb, maskb)` — a batch of channel-first image tensors and their mask
           tensors (e.g. `plain_dls.one_batch()`).
        figsize: overall matplotlib figure size.
    """
    imgb, maskb = b
    bs = imgb.shape[0]
    # squeeze=False keeps `axs` 2-D even when bs == 1, so axs[i][j] always works.
    fig, axs = plt.subplots(bs, 2, figsize = figsize, squeeze = False)
    for i in range(bs):
        # Images are (C, H, W); show_visual converts them to (H, W, 3) for imshow.
        show_visual(imgb[i], ax = axs[i][0])
        axs[i][1].imshow(maskb[i].cpu().numpy())

In [None]:
# NOTE(review): `b` is not defined in this section; it is presumably the batch left over
# from the `for b in faky` debug loop — confirm, or use `plain_b` / `tfms_b` explicitly.
b[0].shape

---
## Transforms >
---

In [12]:
# Use the plain (transform-free) loaders for the model and training cells below.
dls = plain_dls

class BananaLoss(CrossEntropyLossFlat):
    """`CrossEntropyLossFlat` that casts non-floatified targets to int64 class indices.

    Re-implements the parent's `__call__` so that, whenever `floatify` is off (the
    cross-entropy default), the target is cast with `.long()` whatever its incoming
    dtype. NOTE(review): presumably needed because the masks arrive in a dtype the
    parent does not upcast (e.g. uint8) — confirm.
    """
    def __call__(self, inp, targ, **kwargs):
        # fastai helper: transpose `axis` to the last dim and make the tensors contiguous.
        inp, targ = map(self._contiguous, (inp, targ))
        if self.floatify:
            # Same rule as the parent: float16 targets are left untouched.
            if targ.dtype != torch.float16: targ = targ.float()
        else:
            # Cast only when not floatifying, so float targets are never truncated to
            # integers first; this also covers every integer dtype (uint8, bool, ...).
            targ = targ.long()
        if self.flatten: inp = inp.view(-1, inp.shape[-1]) if self.is_2d else inp.view(-1)
        return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)

In [13]:
# Infer input channels and spatial size from one real batch (x presumably (bs, C, H, W)).
n_channels, *img_size = dls.one_batch()[0].shape[-3:]

# U-Net with a ResNet-34 encoder taking `n_channels` inputs and producing 2 output classes.
# NOTE(review): unlike `unet_learner`, no Normalize transform is added for the
# (by default pretrained) encoder; inputs are only divided by 27000 in get_array — confirm intended.
model = create_unet_model(
    arch = resnet34,
    n_out = 2,
    img_size = img_size,
    n_in = n_channels
)

# Look up fastai's per-architecture metadata for resnet34 to get its parameter-group
# splitter; `_default_meta` is the fallback when the arch is not registered.
_default_meta    = {'cut':None, 'split':default_split}
meta = model_meta.get(resnet34, _default_meta)

learn = Learner(
    dls,
    model,
    loss_func = BananaLoss(axis = 1),  # axis=1: class/channel axis of the logits
    metrics = [Dice, JaccardCoeff],
    splitter = meta['split']
)

# Train only the last parameter group until unfrozen.
learn.freeze()

---

In [14]:
#learn.lr_find()

In [18]:
# Time fetching one batch from `dls` (10 runs x 10 loops).
%timeit -r 10 -n 10 dls.one_batch()

36.6 ms ± 1.17 ms per loop (mean ± std. dev. of 10 runs, 10 loops each)


In [15]:
# One epoch of one-cycle training with max LR 1e-4 (learner frozen by learn.freeze() above).
learn.fit_one_cycle(1, 1e-4)

epoch,train_loss,valid_loss,dice,jaccard_coeff,time
0,0.248852,0.195848,0.936604,0.880767,22:29


# Training-time overview (1 epoch):
* Without transforms: 22:30 min (__baseline__)
* With flip transforms: 8:15 min — suspiciously fast; something is probably wrong (TODO: investigate)

In [None]:
#learn.fine_tune(5, 1e-4)

In [None]:
#result_metric = None
#learn.export(f'res34_{IMG_SIZE}_j{905}')