In [None]:
from fastai.vision.all import *
from fastai.vision.widgets import *
import wandb

# Params
# Raise PIL's decompression-bomb limit (max pixels per image) so very large
# training images can be opened; presumably the full-size slides exceed PIL's
# default. Written as an int because PIL treats this as an integer pixel count.
Image.MAX_IMAGE_PIXELS = 100_000_000_000
CFG = {
    'base_model': 'resnet18',   # resnet18/34/50, efficientnet_v2_s/m/l
    'batch_size': 32,
    'whole_img_size': 700,      # side of the square produced by the item transforms
    'aug_img_size': 512,        # output size of the batch augmentations
    'aug_min_scale': 1.0,
    'freeze_epochs': 1,
    'epochs': 10,
    'seed': 42,
    'tissuecrop': False         # True: TissueCrop item transform; False: SideCrop + Resize
}

# Wandb
# Never hard-code the API key: notebooks end up in version control. The key is
# read from WANDB_API_KEY; when unset, wandb.login falls back to its usual
# credential lookup / interactive prompt.
wandb.login(key=os.environ.get('WANDB_API_KEY'))
run = wandb.init(project='kaggle-ubc-ocean', config=CFG, tags=['fastai', 'baseline'])

# Paths (the project root can be overridden with the UBC_OCEAN_ROOT env variable)
root = os.environ.get('UBC_OCEAN_ROOT', '/media/latlab/MR/projects/kaggle-ubc-ocean')
data_dir = os.path.join(root, 'data')
results_dir = os.path.join(root, 'results')
train_filename = 'train.csv'
train_img_dir = os.path.join(data_dir, 'train_images')
train_thumbnail_dir = os.path.join(data_dir, 'train_thumbnails')

# Functions
def seed_everything(seed, deterministic=True):
    """Seed Python, NumPy and PyTorch random generators so runs can be repeated.

    Args:
        seed: Integer seed applied to every generator.
        deterministic: If True (default), restrict cuDNN to deterministic
            algorithms AND disable its autotuner; PyTorch's reproducibility
            notes require both for repeatable GPU results. Pass False to leave
            both cuDNN flags untouched. Callers that prefer speed over
            reproducibility can re-enable `cudnn.benchmark` after this call.
    """
    random.seed(seed)
    # Exported for Python subprocesses started later; it cannot change the hash
    # seed of the interpreter that is already running.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        # The autotuner may pick different kernels from run to run, which breaks
        # bit-for-bit reproducibility even when each kernel is deterministic.
        torch.backends.cudnn.benchmark = False

def get_file_path(image_id, thumbnail_dir=None, img_dir=None):
    """Return the image path to load for a training `image_id`.

    Prefers `<thumbnail_dir>/<image_id>_thumbnail.png` when that file exists,
    otherwise returns `<img_dir>/<image_id>.png` (the fallback path is NOT
    checked for existence).

    Args:
        image_id: Value from the `image_id` column of the training CSV.
        thumbnail_dir: Directory of thumbnails; defaults to the notebook's
            `train_thumbnail_dir`.
        img_dir: Directory of full images; defaults to the notebook's
            `train_img_dir`.

    Returns:
        The chosen path as a string.
    """
    # Resolve defaults at call time so the function keeps working with
    # `Series.apply(get_file_path)` and picks up later changes to the globals.
    if thumbnail_dir is None:
        thumbnail_dir = train_thumbnail_dir
    if img_dir is None:
        img_dir = train_img_dir
    thumbnail_path = os.path.join(thumbnail_dir, f'{image_id}_thumbnail.png')
    if os.path.exists(thumbnail_path):
        return thumbnail_path
    return os.path.join(img_dir, f'{image_id}.png')

# Seed
seed_everything(CFG['seed'])
# cuDNN autotuning (benchmark=True) is faster for fixed input sizes but can pick
# different kernels on every run, defeating the determinism requested above.
# Default to reproducible; set CFG['cudnn_benchmark'] = True to trade it for speed.
torch.backends.cudnn.benchmark = CFG.get('cudnn_benchmark', False)

# Load descriptive data
df = pd.read_csv(os.path.join(data_dir, train_filename))

# Add image path (thumbnail if one exists, otherwise the full image)
df['image_path'] = df['image_id'].apply(get_file_path)

# Fail fast here instead of deep inside a DataLoader worker during training.
missing_mask = ~df['image_path'].map(os.path.exists)
assert not missing_mask.any(), (
    f"{int(missing_mask.sum())} image file(s) not found, "
    f"e.g. {df.loc[missing_mask, 'image_path'].iloc[0]}" if missing_mask.any() else ""
)
df

In [None]:
class SideCrop(Transform):
    "Keep the largest square anchored at the image's top-left corner (no resizing)."
    def __init__(self):
        # Takes no configuration.
        pass

    def encodes(self, image: PILImage):
        width, height = image.size
        side = min(width, height)
        # Square of `side` pixels whose top-left corner is the image origin.
        return image.crop_pad((side, side), (0, 0))
    
class TissueCrop(Transform):
    """Crop/pad a `size` x `size` window centred on the tissue in an image.

    Background is assumed to be pure black (zero), as the original intent of
    this transform suggests. The window is centred on the bounding box of all
    rows/columns that contain at least one non-zero pixel. If the image is
    entirely black, the window is centred on the image itself. Parts of the
    window outside the image are padded by `crop_pad`.

    Args:
        size: Side length in pixels of the output image.
    """
    def __init__(self, size):
        super().__init__()
        self.size = size

    @staticmethod
    def _nonzero_span(profile):
        "Half-open `(start, stop)` covering the non-zero entries of 1-D `profile`; the whole range if there are none."
        nonzero = np.flatnonzero(profile)
        if nonzero.size == 0:
            return 0, len(profile)
        return int(nonzero[0]), int(nonzero[-1]) + 1

    def encodes(self, image: PILImage):
        pixels = np.asarray(image)
        # Collapse channels; with non-negative pixel values the mean is 0 only
        # when the pixel is black in every channel.
        intensity = pixels.mean(2) if pixels.ndim == 3 else pixels
        # sum(0) -> one value per column (x axis); sum(1) -> one value per row (y axis).
        h_left, h_right = self._nonzero_span(intensity.sum(0))
        v_top, v_bottom = self._nonzero_span(intensity.sum(1))

        h_center = round((h_left + h_right) / 2)
        v_center = round((v_top + v_bottom) / 2)
        half = round(self.size / 2)
        # Second argument is the window's top-left corner (x, y); negative or
        # out-of-range coordinates are padded rather than cropped.
        return image.crop_pad((self.size, self.size), (h_center - half, v_center - half))

# Item-level transforms are the only difference between the two pipelines:
# either a whole_img_size window centred on the tissue (TissueCrop), or the
# top-left square (SideCrop) resized to whole_img_size.
if CFG['tissuecrop']:
    item_tfms = [TissueCrop(CFG['whole_img_size'])]
else:
    item_tfms = [SideCrop(), Resize(CFG['whole_img_size'], method='crop')]

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_x=ColReader('image_path'),
                   get_y=ColReader('label'),
                   splitter=RandomSplitter(valid_pct=0.2, seed=CFG['seed']),
                   item_tfms=item_tfms,
                   batch_tfms=[*aug_transforms(size=CFG['aug_img_size'],
                                               min_scale=CFG['aug_min_scale'],
                                               max_warp=0),
                               Normalize.from_stats(*imagenet_stats)])

In [None]:
%%time
dls = dblock.dataloaders(df, bs=CFG['batch_size'], num_workers=36)

In [None]:
# Visual sanity check: up to 25 augmented samples from one training batch.
dls.train.show_batch(max_n=25)

In [None]:
# Resolve the architecture by name from the notebook namespace (populated by
# `from fastai.vision.all import *`) instead of eval()-ing a config string.
arch_name = CFG['base_model']
arch = globals().get(arch_name)
if not callable(arch):
    raise ValueError(f"Unknown base_model {arch_name!r}: expected the name of a model "
                     f"constructor available in this notebook, e.g. 'resnet18'")
learn = vision_learner(dls, arch, metrics=BalancedAccuracy())
learn.fine_tune(CFG['epochs'], freeze_epochs=CFG['freeze_epochs'])

In [None]:
# Log per-epoch history to W&B: each recorder row is read as
# [train_loss, valid_loss, balanced_accuracy], matching the single metric passed
# to vision_learner.
# NOTE(review): fine_tune runs two fits; recorder.values presumably holds only the
# last (unfrozen) fit, so the frozen epoch(s) are not logged — confirm if needed.
try:
    for res in learn.recorder.values:
        wandb.log({'train_loss': res[0],
                   'valid_loss': res[1],
                   'balanced_accuracy': res[2]})
    # Export while the run is still open (run.name is used in the filename), and
    # create the target directory first: Learner.export writes straight to the path.
    model_dir = os.path.join(results_dir, 'models')
    os.makedirs(model_dir, exist_ok=True)
    learn.export(os.path.join(model_dir, f'ubc-ocean_{run.name}.pkl'))
finally:
    # Close the W&B run even if logging or export fails.
    wandb.finish()

In [None]:
# Show accuracy plot: balanced accuracy per recorded epoch (column 2 of the
# recorder history).
epoch_values = np.array(learn.recorder.values)
fig, ax = plt.subplots()
ax.plot(epoch_values[:, 2])
ax.set(title='Validation balanced accuracy', xlabel='epoch', ylabel='balanced accuracy')

# Train/valid loss curves on their own figure.
plt.figure()
learn.recorder.plot_loss();

In [None]:
# Confusion matrix on the validation set (ds_idx=1 selects the validation DataLoader).
interp = ClassificationInterpretation.from_learner(learn, ds_idx=1)
interp.plot_confusion_matrix()
# To list the most frequent mix-ups instead, run: interp.most_confused(min_val=5)

In [None]:
# The 18 highest-loss samples, laid out in 3 rows.
interp.plot_top_losses(k=18, nrows=3)