In [2]:
import wandb
import pandas as pd
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
import params

Let's now create a `train_config` that we'll pass to the W&B run to control the training hyperparameters.

In [3]:
# Training hyperparameters, collected in one attribute-style namespace so they
# can be handed to the W&B run as its config.
train_config = SimpleNamespace(
    **{
        "framework": "fastai",
        "img_size": (224, 224),
        "batch_size": 64,
        "augment": True,  # use data augmentation
        "epochs": 5,
        "lr": 2e-3,
        "pretrained": True,  # whether to use pretrained encoder
        "seed": 42,
    }
)

We set the random seed for reproducibility.

In [4]:
# Seed with the value stored in train_config; `reproducible=True` is passed
# explicitly so the helper also applies its reproducibility settings.
set_seed(train_config.seed, reproducible=True)

In [28]:
# Open the W&B run for training and attach the hyperparameters as its config.
run = wandb.init(
    project=params.WANDB_PROJECT,
    job_type="training",
    config=train_config,
)

As usual, we will use W&B Artifacts to track the lineage of our models.

In [6]:
# Declare the processed-data artifact as an input of this run, then download
# it and keep the local directory as a Path.
processed_data_at = run.use_artifact(
    f"{params.PROCESSED_DATA_AT}:latest"
)
processed_dataset_dir = Path(
    processed_data_at.download()
)

[34m[1mwandb[0m: Downloading large artifact data_split:latest, 2266.04MB. 36310 files... 
[34m[1mwandb[0m:   36310 of 36310 files downloaded.  
Done. 0:0:6.2


In [7]:
# Load the split table that ships inside the downloaded artifact directory.
df = pd.read_csv(processed_dataset_dir.joinpath('data_split.csv'))

We will not use the hold-out (test) split at this stage. The `is_valid` column tells the trainer how to split the data between training and validation.

In [8]:
# Drop the rows whose Stage is 'test', renumber the index, and flag the rows
# whose Stage is 'valid' in a boolean `is_valid` column.
df = (
    df.loc[df["Stage"].ne('test')]
    .reset_index(drop=True)
    .assign(is_valid=lambda frame: frame["Stage"].eq('valid'))
)

In [11]:
def label_func(fname):
    """Return the name of the directory that directly contains ``fname``.

    ``fname`` must expose ``.parent`` (e.g. a ``pathlib.Path``).
    """
    containing_dir = fname.parent
    return containing_dir.name

We will use fastai's DataBlock API to feed data into model training and validation.

In [12]:
# Collect the image files under the dataset directory, leaving out every path
# that contains the substring "media", then show how many remain.
fnames = [
    image_path
    for image_path in get_image_files(processed_dataset_dir)
    if "media" not in str(image_path)
]
len(fnames)

18160

In [13]:
# Resolve each File_Name entry against the downloaded dataset directory.
df["image_fname"] = [
    processed_dataset_dir / f"{name}"
    for name in df["File_Name"].values
]

In [14]:
def get_data(df, bs=64, img_size=(224, 224), augment=True):
    """Build the dataloaders for image classification from ``df``.

    Args:
        df: DataFrame read through its ``image_fname`` column (inputs) and
            ``Label`` column (targets). The train/validation split is done by
            ``ColSplitter()`` with its default settings.
        bs: Batch size forwarded to ``DataBlock.dataloaders``.
        img_size: Size passed to the per-item ``Resize`` transform.
        augment: When true, ``aug_transforms()`` is used as the batch
            transforms; otherwise no batch transforms are set.

    Returns:
        Whatever ``DataBlock.dataloaders(df, bs=bs)`` returns.
    """
    # Decide on augmentation up front so the DataBlock definition stays flat.
    if augment:
        batch_level_tfms = aug_transforms()
    else:
        batch_level_tfms = None

    block = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=ColReader("image_fname"),
        get_y=ColReader("Label"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=batch_level_tfms,
    )
    return block.dataloaders(df, bs=bs)

We are using wandb.config to track our training hyperparameters.

In [17]:
# Short alias for the W&B config object; the cells below read the
# hyperparameters through it.
config = wandb.config

In [18]:
# Build the dataloaders from the values stored in the run config.
dls = get_data(
    df,
    bs=config.batch_size,
    img_size=config.img_size,
    augment=config.augment,
)

In [19]:
# Metrics handed to the learner alongside the loss.
metrics = [accuracy, error_rate]

# resnet18-based vision learner; `pretrained` comes from the run config.
learn = vision_learner(
    dls,
    arch=resnet18,
    pretrained=config.pretrained,
    metrics=metrics,
)



In [20]:
# Training callbacks, passed to the fit call in this order.
callbacks = [
    # Checkpointing driven by the monitored 'valid_loss' value.
    SaveModelCallback(monitor='valid_loss'),
    # W&B integration with prediction and model logging switched on.
    WandbCallback(
        log_preds=True,
        log_model=True,
    ),
]

Let's train our model!

In [21]:
# Fine-tune for the configured number of epochs at the configured base
# learning rate, with the callbacks defined above attached.
learn.fine_tune(
    config.epochs,
    base_lr=config.lr,
    cbs=callbacks,
)

epoch,train_loss,valid_loss,accuracy,error_rate,time
0,0.562043,0.289595,0.898678,0.101322,00:24


Better model found at epoch 0 with valid_loss value: 0.28959470987319946.


epoch,train_loss,valid_loss,accuracy,error_rate,time
0,0.213966,0.095208,0.96641,0.03359,00:26
1,0.094866,0.087995,0.965859,0.034141,00:26
2,0.05751,0.026764,0.989537,0.010463,00:26
3,0.032075,0.035872,0.985132,0.014868,00:26
4,0.020896,0.020239,0.992841,0.007159,00:26


Better model found at epoch 0 with valid_loss value: 0.0952075719833374.
Better model found at epoch 1 with valid_loss value: 0.08799496293067932.
Better model found at epoch 2 with valid_loss value: 0.026764262467622757.
Better model found at epoch 4 with valid_loss value: 0.02023932710289955.


We are reloading the model from the best checkpoint at the end and saving it. To make sure we track the final metrics correctly, we will validate the model again and save the final loss and metrics to wandb.summary.

In [22]:
# Re-run validation on the model currently held by `learn` so the summary
# reflects that model rather than the last values logged during training.
scores = learn.validate()
metric_names = ['final_loss', 'Accuracy', 'Error_rate']

# The names are paired with the scores by position. Fail loudly if the two
# lists stop lining up (e.g. after the learner's metrics change) instead of
# recording a partial or mislabelled summary.
if len(scores) != len(metric_names):
    raise ValueError(
        f"learn.validate() returned {len(scores)} value(s) "
        f"but {len(metric_names)} metric name(s) are defined"
    )

# Kept as a plain dict because a later cell reuses it to build a results table.
final_results = dict(zip(metric_names, scores))

# Write every final value into the run summary in one call.
wandb.summary.update(final_results)

In [45]:
# Open a run with job_type "metric"; the results table is logged to it below.
# NOTE(review): a run was already opened earlier in this notebook for
# training. Confirm that starting another run here, rather than logging the
# table to the training run, is intended.
run = wandb.init(
    project=params.WANDB_PROJECT,
    job_type="metric",
)

In [46]:
# One table row per final metric: (name, value).
table = wandb.Table(columns=["Metric", "Value"])
for metric_row in final_results.items():
    table.add_data(*metric_row)

# Log the table to the active run under the key 'Results'.
wandb.log({'Results': table})

In [47]:
# Close the active W&B run now that everything has been logged.
wandb.finish()