In [1]:
import argparse, os
import wandb
from pathlib import Path
import torchvision.models as tvmodels
import pandas as pd
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

import params
from utils import t_or_f

In [2]:
# Baseline hyperparameters for a training run. The namespace is handed to
# `wandb.init(config=...)` inside `train`, so every field is recorded with the run.
default_config = SimpleNamespace(**{
    "framework": "fastai",        # label stored with the run config
    "img_size": (224, 224),       # size passed to the Resize item transform
    "batch_size": 64,             # dataloader batch size
    "augment": True,              # use data augmentation
    "epochs": 1,                  # epochs passed to learn.fine_tune
    "arch": "resnet18",           # attribute name looked up on torchvision.models
    "lr": 2e-3,                   # learning rate passed to learn.fine_tune
    "pretrained": True,           # whether to use pretrained encoder
    "mixed_precision": True,      # adds the MixedPrecision callback when truthy
    "seed": 42,                   # value given to set_seed
    "log_preds": True,            # flag for logging predictions to W&B
})

In [3]:
def download_data():
    """Fetch the latest processed-dataset artifact and return its local directory.

    Resolves `params.PROCESSED_DATA_AT` at the `latest` alias through
    `wandb.use_artifact`, downloads it, and wraps the download location
    in a `Path`.

    Returns:
        Path: directory the artifact was downloaded to.
    """
    artifact = wandb.use_artifact(f'{params.PROCESSED_DATA_AT}:latest')
    return Path(artifact.download())

In [4]:
def get_df(processed_dataset_dir, is_test=False):
    """Load `data_split.csv` and return the rows for the requested split.

    Args:
        processed_dataset_dir: directory containing `data_split.csv` and the
            image files; may be a `Path` or a plain string path.
        is_test: when False (default) keep every row whose `Stage` is not
            'test' and add a boolean `is_valid` column (True where `Stage`
            is 'valid'); when True keep only the 'test' rows.

    Returns:
        pandas.DataFrame: the selected rows with a fresh 0..n-1 index and an
        `image_fname` column holding `processed_dataset_dir / File_Name` as
        `Path` objects.
    """
    # Coerce up front so a plain string directory works with the `/` joins below.
    processed_dataset_dir = Path(processed_dataset_dir)
    df = pd.read_csv(processed_dataset_dir / 'data_split.csv')

    if is_test:
        df = df[df.Stage == 'test'].reset_index(drop=True)
    else:
        df = df[df.Stage != 'test'].reset_index(drop=True)
        df['is_valid'] = df.Stage == 'valid'

    # File names are formatted to text first so non-string values can still be joined.
    df["image_fname"] = [processed_dataset_dir / f'{name}' for name in df.File_Name.values]
    return df

In [5]:
def get_data(df, bs=64, img_size=(224, 224), augment=True):
    """Create fastai dataloaders for image classification from a dataframe.

    Args:
        df: dataframe read through `ColReader("image_fname")` for inputs and
            `ColReader("Label")` for targets; `ColSplitter()` is used with
            its default column to separate training and validation rows.
        bs: batch size forwarded to `dataloaders`.
        img_size: size given to the `Resize` item transform.
        augment: when truthy, `aug_transforms()` is applied as batch
            transforms; otherwise no batch transforms are set.

    Returns:
        The object returned by `DataBlock.dataloaders(df, bs=bs)`.
    """
    datablock_kwargs = dict(
        blocks=(ImageBlock, CategoryBlock),
        get_x=ColReader("image_fname"),
        get_y=ColReader("Label"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
    )
    # Batch-level augmentation is optional; None leaves batch_tfms unset.
    datablock_kwargs["batch_tfms"] = aug_transforms() if augment else None

    datablock = DataBlock(**datablock_kwargs)
    return datablock.dataloaders(df, bs=bs)

In [6]:
def log_metrics(learn):
    """Validate the learner and write the resulting scores to the W&B run summary.

    The values returned by `learn.validate()` are matched by position to the
    summary keys 'final_loss', 'Accuracy' and 'Error_rate'.

    Args:
        learn: learner exposing a `validate()` method that returns a sequence
            of scores.
    """
    validation_scores = learn.validate()
    summary_keys = ['final_loss', 'Accuracy', 'Error_rate']

    # Build the full mapping first, then publish each entry to the run summary.
    named_scores = {summary_keys[pos]: value for pos, value in enumerate(validation_scores)}
    for key in named_scores:
        wandb.summary[key] = named_scores[key]

In [7]:
def train(config):
    """Run one W&B-tracked fine-tuning job and log its final validation metrics.

    Steps: start a W&B run with `config` as defaults, seed the RNGs from the
    resolved config, download the processed dataset, build dataloaders,
    fine-tune a torchvision architecture with fastai, then write the final
    validation scores to the run summary.

    The run is left open when this function returns; the caller is expected
    to call `wandb.finish()`.

    Args:
        config: object whose attributes provide `seed`, `batch_size`,
            `img_size`, `augment`, `arch`, `pretrained`, `mixed_precision`,
            `epochs` and `lr`. `log_preds` is optional and defaults to True.
    """
    run = wandb.init(project=params.WANDB_PROJECT, job_type="training", config=config)

    # From here on read the values W&B resolved, so any override applied to
    # the run config replaces the defaults that were passed in.
    config = wandb.config

    # Seed only after the config is resolved; seeding earlier would ignore an
    # overridden `seed` value.
    set_seed(config.seed, reproducible=True)

    processed_dataset_dir = download_data()
    df = get_df(processed_dataset_dir)

    dls = get_data(df, bs=config.batch_size, img_size=config.img_size, augment=config.augment)

    metrics = [accuracy, error_rate]
    learn = vision_learner(dls, arch=getattr(tvmodels, config.arch), pretrained=config.pretrained, metrics=metrics)

    # Honour the configured prediction-logging flag instead of hard-coding it;
    # configs that do not define it keep the previous behaviour (True).
    log_preds = config.get("log_preds", True)
    cbs = [WandbCallback(log_preds=log_preds, log_model=True),
           SaveModelCallback(fname=f'run-{run.id}-model', monitor='valid_loss')]
    if config.mixed_precision:
        cbs.append(MixedPrecision())

    learn.fine_tune(config.epochs, config.lr, cbs=cbs)

    log_metrics(learn)

In [8]:
# Launch a training run using the baseline hyperparameters in `default_config`.
train(default_config)

[34m[1mwandb[0m: Currently logged in as: [33msolab5[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact data_split:latest, 2266.04MB. 36310 files... 
[34m[1mwandb[0m:   36310 of 36310 files downloaded.  
Done. 0:0:6.9


epoch,train_loss,valid_loss,accuracy,error_rate,time
0,0.561696,0.289912,0.898678,0.101322,00:45


Better model found at epoch 0 with valid_loss value: 0.2899123430252075.


epoch,train_loss,valid_loss,accuracy,error_rate,time
0,0.173657,0.102228,0.962555,0.037445,00:56


Better model found at epoch 0 with valid_loss value: 0.1022278442978859.


In [9]:
# `train` leaves its W&B run open; mark it as finished here.
wandb.finish()

0,1
accuracy,▁█
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
error_rate,█▁
lr_0,▁▁▁▂▂▂▃▃▄▄▅▅▆▆▇▇▇███▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_1,▁▁▁▂▂▂▃▃▄▄▅▅▆▆▇▇▇███▂▂▃▃▄▄▄▄▄▄▄▃▃▃▂▂▁▁▁▁
lr_2,▁▁▁▂▂▂▃▃▄▄▅▅▆▆▇▇▇███▂▂▃▃▄▄▄▄▄▄▄▃▃▃▂▂▁▁▁▁
mom_0,████▇▇▆▆▆▅▄▄▃▃▂▂▂▁▁▁█▇▅▄▂▁▁▁▂▂▃▃▄▅▆▆▇███

0,1
Accuracy,0.96256
Error_rate,0.03744
accuracy,0.96256
epoch,2.0
eps_0,1e-05
eps_1,1e-05
eps_2,1e-05
error_rate,0.03744
final_loss,0.1022
lr_0,0.0
