In [None]:
import wandb
import pandas as pd
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
import params

Let's now create a `train_config` that we'll pass to the W&B run to control the training hyperparameters.

In [None]:
# Training hyperparameters, collected in one namespace so they can be handed
# to the W&B run as its config and read back by attribute.
train_config = SimpleNamespace(**{
    "framework": "fastai",
    "img_size": (224, 224),
    "batch_size": 64,
    "augment": True,      # use data augmentation
    "epochs": 5,
    "lr": 2e-3,
    "pretrained": True,   # whether to use pretrained encoder
    "seed": 42,
})

We set the random seed for reproducibility.

In [None]:
# Seed the random number generators with the value stored in train_config.
# reproducible=True is forwarded to set_seed as-is; see fastai's set_seed
# documentation for exactly which settings it toggles.
set_seed(train_config.seed, reproducible=True)

In [None]:
# Start the W&B run for this training job; train_config is attached as the
# run's config so the hyperparameters are tracked with the run.
run = wandb.init(project=params.WANDB_PROJECT, job_type="training", config=train_config)

As usual, we will use W&B Artifacts to track the lineage of our models.

In [None]:
# Declare the ':latest' version of the processed-data artifact as an input of
# this run, then download it. The value returned by download() is wrapped in
# a Path and used as the dataset directory by the following cells.
processed_data_at = run.use_artifact(f'{params.PROCESSED_DATA_AT}:latest')
processed_dataset_dir = Path(processed_data_at.download())

In [None]:
# Load the split table shipped with the artifact. Later cells read its
# Stage, File_Name and Label columns.
df = pd.read_csv(processed_dataset_dir / 'data_split.csv')

We will not use the hold-out (test) split at this stage. The `is_valid` column tells the trainer how to split the data between training and validation.

In [None]:
# Drop the rows whose Stage is 'test', renumber the index, and add a boolean
# is_valid column that is True exactly for the rows whose Stage is 'valid'.
df = (
    df.loc[df["Stage"] != 'test']
    .reset_index(drop=True)
    .assign(is_valid=lambda frame: frame["Stage"] == 'valid')
)

In [None]:
def label_func(fname):
    """Return the name of the directory that directly contains ``fname``.

    ``fname`` must expose ``.parent.name`` (for example a ``pathlib.Path``).
    """
    containing_dir = fname.parent
    return containing_dir.name

We will use fastai's DataBlock API to feed data into model training and validation.

In [None]:
# Collect the image files under the downloaded dataset directory, skipping
# every path whose string form contains "media", then show how many remain.
fnames = [
    image_path
    for image_path in get_image_files(processed_dataset_dir)
    if "media" not in str(image_path)
]
len(fnames)

In [None]:
# Build the full path of every image by joining the dataset directory with
# the (stringified) File_Name entry of each row.
df["image_fname"] = [
    processed_dataset_dir / str(file_name)
    for file_name in df["File_Name"].values
]

In [None]:
def get_data(df, bs=64, img_size=(224, 224), augment=True):
    """Build fastai dataloaders from a dataframe of image paths and labels.

    Args:
        df: Dataframe whose ``image_fname`` column is read as the input and
            whose ``Label`` column is read as the target. The train/validation
            split is delegated to ``ColSplitter()`` with its default column
            (presumably ``is_valid`` -- confirm against the fastai docs).
        bs: Batch size forwarded to ``DataBlock.dataloaders``.
        img_size: Size handed to the per-item ``Resize`` transform.
        augment: When truthy, ``aug_transforms()`` is used as the batch
            transforms; otherwise ``None`` is passed.

    Returns:
        Whatever ``DataBlock.dataloaders(df, bs=bs)`` returns.
    """
    if augment:
        batch_transforms = aug_transforms()
    else:
        batch_transforms = None

    data_block = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=ColReader("image_fname"),
        get_y=ColReader("Label"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=batch_transforms,
    )
    return data_block.dataloaders(df, bs=bs)

We are using wandb.config to track our training hyperparameters.

In [None]:
# Read hyperparameters back through wandb.config (rather than train_config
# directly) so the values used below are the ones tracked by the W&B run.
config = wandb.config    

In [None]:
# Build the dataloaders from the prepared dataframe, using the batch size,
# image size and augmentation flag stored in the run config.
dls = get_data(df, bs=config.batch_size, img_size=config.img_size, augment=config.augment)

In [None]:
# Metrics handed to the learner, in this order: accuracy, then error rate.
metrics = [accuracy, error_rate]

# resnet18-based vision learner; config.pretrained decides whether the
# pretrained weights are requested.
learn = vision_learner(
    dls,
    arch=resnet18,
    pretrained=config.pretrained,
    metrics=metrics,
)

In [None]:
# Callbacks used during training, in order:
#   1. SaveModelCallback watching 'valid_loss'
#   2. WandbCallback with prediction logging and model logging enabled
callbacks = [SaveModelCallback(monitor='valid_loss')]
callbacks.append(WandbCallback(log_preds=True, log_model=True))

Let's train our model!

In [None]:
# Fine-tune the learner with the epoch count and learning rate from the run
# config, attaching the callbacks defined above for this call.
learn.fine_tune(config.epochs, config.lr, cbs=callbacks)

We are reloading the model from the best checkpoint at the end and saving it. To make sure we track the final metrics correctly, we will validate the model again and save the final loss and metrics to wandb.summary.

In [None]:
# Re-evaluate the learner and record the final numbers in the run summary.
scores = learn.validate()

# Summary keys, paired positionally with the values returned by
# learn.validate(). Assumes the loss comes first, followed by the learner's
# metrics in the order they were passed (accuracy, error_rate) -- confirm
# against the fastai docs if the metrics list changes.
metric_names = ['final_loss', 'Accuracy', 'Error_rate']
if len(scores) > len(metric_names):
    # Fail loudly instead of silently dropping unnamed scores.
    raise IndexError(
        f"learn.validate() returned {len(scores)} values but only "
        f"{len(metric_names)} metric names are defined"
    )
final_results = dict(zip(metric_names, scores))

for name, value in final_results.items():
    wandb.summary[name] = value

In [None]:
# Explicitly close the run that is still active before opening a separate
# run for the metrics table. What wandb.init does while another run is active
# depends on wandb's `reinit` handling, which differs between notebooks and
# scripts; finishing first makes the new `job_type="metric"` run deterministic.
# wandb.finish() is expected to be harmless when no run is active.
wandb.finish()
run = wandb.init(project=params.WANDB_PROJECT, job_type="metric")

In [None]:
# Log the final validation numbers as a two-column W&B table, one row per
# entry of final_results, under the 'Results' key.
table = wandb.Table(columns=["Metric", "Value"])
for metric_name, metric_value in final_results.items():
    table.add_data(metric_name, metric_value)

wandb.log(dict(Results=table))

In [None]:
# Close the active W&B run now that everything has been logged.
wandb.finish()