# Inference notebook (only for submissions)

In [1]:
from fastai.vision.all import *

In [2]:
# Dataset locations. `path_str` stays a plain string because it is also used
# as a string elsewhere in this notebook (e.g. when reading CSVs).
path_str = '../data'
PATH = Path(path_str)

# Derive sub-paths with pathlib's `/` operator instead of string concatenation.
images_path = PATH / 'train_images'
csv_path = PATH / 'train.csv'

In [11]:
class AlbumentationsTransform(RandTransform):
    """Wrap a pair of albumentations pipelines as a fastai transform.

    ``train_aug`` is applied when the current split index is 0;
    ``valid_aug`` is applied for any other split index (including None).
    """
    # Class-level fastai settings: no split restriction, pipeline order 2.
    split_idx, order = None, 2

    def __init__(self, train_aug, valid_aug):
        # store_attr() copies train_aug / valid_aug onto self.
        store_attr()

    def before_call(self, b, split_idx):
        # Record which split is being processed so `encodes` can choose a pipeline.
        self.idx = split_idx

    def encodes(self, img: PILImage):
        # The PILImage annotation is what routes images to this method; keep it.
        pipeline = self.train_aug if self.idx == 0 else self.valid_aug
        augmented = pipeline(image=np.array(img))['image']
        return PILImage.create(augmented)


def get_train_aug(size):
    """Return the random training augmentation pipeline for ``size`` x ``size`` crops.

    Steps, in order: random resized crop, transpose / horizontal / vertical
    flips, shift-scale-rotate, hue-saturation-value and brightness-contrast
    jitter, then CoarseDropout and Cutout. Every step after the crop is given
    p=0.5; the crop uses the library's default probability.
    """
    # NOTE(review): `albumentations` is not imported in any visible cell
    # (execution counts jump from In [2] to In [11]); confirm it is imported
    # before this function is called.
    # NOTE(review): CoarseDropout and Cutout overlap, and both Cutout and the
    # positional (height, width) form of RandomResizedCrop depend on the
    # installed albumentations version — verify before reusing this.
    geometric = [
        albumentations.RandomResizedCrop(size, size),
        albumentations.Transpose(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.ShiftScaleRotate(p=0.5),
    ]
    colour = [
        albumentations.HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5,
        ),
        albumentations.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.5,
        ),
    ]
    dropout = [
        albumentations.CoarseDropout(p=0.5),
        albumentations.Cutout(p=0.5),
    ]
    return albumentations.Compose(geometric + colour + dropout)

def get_valid_aug(size):
    """Return the validation pipeline for ``size`` x ``size`` inputs.

    Always center-crops to ``size`` x ``size`` (p=1.) and then resizes to
    ``size`` x ``size``; the whole Compose is also applied with p=1.
    """
    # NOTE(review): Resize targets the same size the CenterCrop has just
    # produced, so it looks redundant — was Resize-then-CenterCrop intended?
    # Confirm before changing (it would alter validation/TTA inputs).
    steps = [
        albumentations.CenterCrop(size, size, p=1.),
        albumentations.Resize(size, size),
    ]
    return albumentations.Compose(steps, p=1.)

def get_x(row):
    """Return the training image path for a dataframe row.

    Joins the module-level ``images_path`` directory with the row's
    ``image_id`` value.
    """
    base_dir = images_path
    return base_dir / row['image_id']
def get_y(row):
    """Return the ``label`` value of a dataframe row."""
    target = row['label']
    return target

In [12]:
# Restore the two exported fold models, fold 0 first, then fold 1.
# NOTE(review): load_learner unpickles the file, so only load trusted exports;
# presumably this is also why AlbumentationsTransform / get_x / get_y are
# defined above (unpickling must be able to resolve them) — confirm.
learn, learn1 = [
    load_learner(Path(fname))
    for fname in ('resnet18-fold0.pkl', 'resnet18-fold1.pkl')
]

In [14]:
# Read the sample submission for the test image ids, then map each id to a
# path string under the test_images directory.
submission_df = pd.read_csv(path_str + '/sample_submission.csv')
test_data_path = submission_df['image_id'].apply(
    lambda image_id: path_str + '/test_images/' + image_id
)

# Build a test DataLoader from each learner's own `dls`.
learn_tst_dl, learn1_tst_dl = (
    learn.dls.test_dl(test_data_path),
    learn1.dls.test_dl(test_data_path),
)

# TODO: add item_tfms to .tta
# n=3 test-time-augmentation passes per learner; `tta` returns a tuple and
# only its first element is used below.
learn_predictions, learn1_predictions = (
    learn.tta(dl=learn_tst_dl, n=3),
    learn1.tta(dl=learn1_tst_dl, n=3),
)

# Equal-weight average of the two folds' predictions.
predictions = (learn_predictions[0] + learn1_predictions[0]) / 2

In [15]:
# Turn the averaged per-class scores into one predicted label per row and
# write the submission file without the dataframe index.
# `predictions` is presumably a CPU torch tensor (output of fastai `tta`).
# Calling np.argmax on it directly goes through the tensor's __array_wrap__
# and returns a torch tensor, so convert to a NumPy array first. pandas then
# receives a plain integer array.
submission_df['label'] = np.argmax(np.asarray(predictions), axis=1)

submission_df.to_csv('submission.csv', index=False)