In [1]:
!pip install --upgrade wandb



In [2]:
import wandb
import pandas as pd
import torchvision.models as tvmodels
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

In [3]:
# W&B destination for runs (ENTITY is passed straight through to wandb.init).
WANDB_PROJECT = "mlops-project"
ENTITY = None

# Segmentation class id -> class name.
BDD_CLASSES = dict(enumerate([
    'background', 'road', 'traffic light', 'traffic sign',
    'person', 'vehicle', 'bicycle',
]))

# W&B artifact paths: the raw dataset and its train/valid/test split.
RAW_DATA_AT = 'av-team/mlops-course-001/bdd_simple_1k'
PROCESSED_DATA_AT = 'av-team/mlops-course-001/bdd_simple_1k_split'

# Model Optimization

In [4]:
from sklearn.metrics import ConfusionMatrixDisplay
from IPython.display import display, Markdown

# Reverse lookup of BDD_CLASSES: class name -> class id.
CLASS_INDEX = dict(zip(BDD_CLASSES.values(), BDD_CLASSES.keys()))

def t_or_f(arg):
    """Parse a loose boolean flag.

    Returns True when ``str(arg).upper()`` is a non-empty prefix of 'TRUE'
    (so True, 'true', 'T', 'tr' all parse as True). Anything else parses as
    False, including the empty string.
    """
    ua = str(arg).upper()
    # '' is a prefix of every string, so it must be rejected explicitly;
    # otherwise an empty value would be read as True.
    return bool(ua) and 'TRUE'.startswith(ua)

def iou_per_class(inp, targ):
    """Compute the IoU (Jaccard index) of every class for a single sample.

    Args:
        inp: per-class scores with the class dimension first (dim 0); each
            pixel's predicted label is ``inp.argmax(dim=0)``.
        targ: ground-truth label map, compared element-wise with the decoded
            prediction.

    Returns:
        list of float with one entry per class (``inp.shape[0]`` entries).
        Each entry is the class IoU, or ``np.nan`` when the class is absent
        from both the prediction and the target.
    """
    # Decoding does not depend on the class, so do it once outside the loop.
    dec_preds = inp.argmax(dim=0)
    iou_scores = []
    for c in range(inp.shape[0]):
        p = torch.where(dec_preds == c, 1, 0)
        t = torch.where(targ == c, 1, 0)
        c_inter = (p * t).float().sum().item()
        # This is |p| + |t|; the true union is c_union - c_inter.
        c_union = (p + t).float().sum().item()
        iou_scores.append(c_inter / (c_union - c_inter) if c_union > 0 else np.nan)
    return iou_scores

def create_row(sample, pred_label, prediction, class_labels):
    """Build one W&B table row: an image with predicted and ground-truth mask
    overlays, followed by the per-class IoU scores.

    Args:
        sample: ``(image, label)`` pair. ``image`` is permuted with
            ``permute(1, 2, 0)`` before logging (assumed channels-first).
        pred_label: decoded prediction; its first element is logged as the
            predicted mask.
        prediction: per-class prediction passed to ``iou_per_class``.
        class_labels: class id -> class name mapping for the mask legend.

    Returns:
        list: ``[wandb.Image, iou_class_0, iou_class_1, ...]``.
    """
    image, label = sample
    # Per-class IoU of the prediction against this sample's target mask.
    iou_scores = iou_per_class(prediction, label)
    # Assumes (C, H, W) -> (H, W, C) for wandb.Image; TODO confirm layout.
    image = image.permute(1, 2, 0)
    masks = {
        "predictions": {
            "mask_data": pred_label[0].numpy(),
            "class_labels": class_labels,
        },
        "ground_truths": {
            "mask_data": label.numpy(),
            "class_labels": class_labels,
        },
    }
    return [wandb.Image(image, masks=masks), *iou_scores]

def create_iou_table(samples, outputs, predictions, class_labels):
    """Create a wandb Table with one row per sample.

    The columns are "Image" (prediction and ground-truth overlays) followed
    by one "<class name> IoU" column per entry in ``class_labels``.
    """
    header = ["Image"] + [f'{str(name)} IoU' for name in class_labels.values()]
    table = wandb.Table(columns=header)

    # progress_bar needs a sized iterable, hence the list().
    rows = list(zip(samples, outputs, predictions))
    for sample, output, prediction in progress_bar(rows):
        table.add_data(*create_row(sample, output, prediction, class_labels=class_labels))

    return table

def get_predictions(learner, test_dl=None, max_n=None):
    """Run inference and return decoded samples together with model outputs.

    Args:
        learner: a fastai ``Learner``.
        test_dl: dataloader to predict on; defaults to ``learner.dls.valid``.
        max_n: forwarded to ``show_results`` to limit how many samples are
            decoded.

    Returns:
        tuple ``(samples, outputs, predictions)``:
        - ``samples``: the decoded ``(x, y)`` items.
        - ``outputs``: the decoded model predictions.
        - ``predictions``: the predictions as returned by ``get_preds``
          (not decoded).
    """
    test_dl = learner.dls.valid if test_dl is None else test_dl
    inputs, predictions, targets, outputs = learner.get_preds(
        dl=test_dl, with_input=True, with_decoded=True
    )
    # Decode with the dataloader that produced the batch, so a custom
    # `test_dl` uses its own pipeline rather than the validation one.
    _, _, samples, outputs = test_dl.show_results(
        tuplify(inputs) + tuplify(targets), outputs, show=False, max_n=max_n
    )
    return samples, outputs, predictions

class MIOU(DiceMulti):
    "Mean over classes of the IoU computed from DiceMulti's accumulated per-class sums."
    @property
    def value(self):
        # Assumes DiceMulti accumulates `union` as |pred| + |targ|, so the
        # true union is union - inter. Classes with zero union become NaN
        # and are ignored by nanmean.
        per_class = [
            self.inter[c] / (self.union[c] - self.inter[c]) if self.union[c] > 0 else np.nan
            for c in self.inter
        ]
        return np.nanmean(np.array(per_class, dtype=float))

class IOU(DiceMulti):
    "IoU of a single class; subclasses choose the class via the `nm` attribute (a CLASS_INDEX key)."
    @property
    def value(self):
        idx = CLASS_INDEX[self.nm]
        union = self.union[idx]
        if union > 0:
            inter = self.inter[idx]
            return inter / (union - inter)
        # Class absent from both predictions and targets.
        return np.nan

# One IoU metric per class. `nm` must be a key of CLASS_INDEX; fastai derives
# the logged metric name from the class name (e.g. RoadIOU -> road_iou).
class BackgroundIOU(IOU):
    nm = 'background'

class RoadIOU(IOU):
    nm = 'road'

class TrafficLightIOU(IOU):
    nm = 'traffic light'

class TrafficSignIOU(IOU):
    nm = 'traffic sign'

class PersonIOU(IOU):
    nm = 'person'

class VehicleIOU(IOU):
    nm = 'vehicle'

class BicycleIOU(IOU):
    nm = 'bicycle'


class IOUMacro(DiceMulti):
    """Per-image ("macro") IoU for the single class named by the subclass attribute `nm`.

    For every image in a batch the class IoU is computed separately. Images
    where the class is absent from both prediction and target (0/0 -> NaN)
    are skipped. `value` is the mean IoU over the images that were kept, or
    NaN if none were.
    """
    @property
    def value(self):
        c = CLASS_INDEX[self.nm]
        if c not in self.count: return np.nan
        else: return self.macro[c]/self.count[c] if self.count[c] > 0 else np.nan

    def reset(self): self.macro, self.count = {}, {}

    def accumulate(self, learn):
        pred, targ = learn.pred.argmax(dim=self.axis), learn.y
        for c in range(learn.pred.shape[self.axis]):
            p = torch.where(pred == c, 1, 0)
            t = torch.where(targ == c, 1, 0)
            # Per-image sums; assumes label maps are (batch, H, W) -- TODO confirm.
            c_inter = (p * t).float().sum(dim=(1, 2))
            c_union = (p + t).float().sum(dim=(1, 2))  # |p| + |t|
            m = c_inter / (c_union - c_inter)
            # Drop only the images where the class is absent (NaN). A whole-batch
            # test such as `torch.any(m.isnan())` would discard or keep every
            # image at once and still count the full batch size.
            valid = m[~m.isnan()]
            count = valid.numel()

            if count > 0:
                self.count[c] = self.count.get(c, 0) + count
                self.macro[c] = self.macro.get(c, 0.0) + valid.sum().item()


class MIouMacro(IOUMacro):
    "Mean over classes of the per-image (macro) class IoU accumulated by IOUMacro."
    @property
    def value(self):
        # Average IoU per class; classes with no counted images become NaN
        # and are ignored by nanmean.
        per_class = [
            self.macro[c] / self.count[c] if self.count[c] > 0 else np.nan
            for c in self.count
        ]
        return np.nanmean(np.array(per_class, dtype=float))


# One per-image (macro) IoU metric per class; `nm` must be a key of CLASS_INDEX.
class BackgroundIouMacro(IOUMacro):
    nm = 'background'

class RoadIouMacro(IOUMacro):
    nm = 'road'

class TrafficLightIouMacro(IOUMacro):
    nm = 'traffic light'

class TrafficSignIouMacro(IOUMacro):
    nm = 'traffic sign'

class PersonIouMacro(IOUMacro):
    nm = 'person'

class VehicleIouMacro(IOUMacro):
    nm = 'vehicle'

class BicycleIouMacro(IOUMacro):
    nm = 'bicycle'


def display_diagnostics(learner, dls=None, return_vals=False):
    """
    Display per-class pixel statistics and a pixel-level confusion matrix for
    the unet learner.

    Shows (1) the fraction of pixels of each class in the targets (`y_true`)
    and in the predictions (`y_pred`), and (2) a confusion matrix normalized
    over the predicted labels (`normalize='pred'`).

    If `dls` is None it will get the validation set from the Learner

    You can create a test dataloader using the `test_dl()` method like so:
    >> dls = ... # You usually create this from the DataBlocks api, in this library it is get_data()
    >> tdls = dls.test_dl(test_dataframe, with_labels=True)

    See: https://docs.fast.ai/tutorial.pets.html#adding-a-test-dataloader-for-inference

    Returns:
        `(countdf, disp)` when `return_vals` is True, otherwise None.
    """
    probs, targs = learner.get_preds(dl = dls)
    preds = probs.argmax(dim=1)
    classes = list(BDD_CLASSES.values())
    class_ids = list(BDD_CLASSES.keys())
    y_true = targs.flatten().numpy()
    y_pred = preds.flatten().numpy()

    # Count pixels per class id and align rows on the ids themselves. A class
    # missing from both targets and predictions then shows as 0% instead of
    # shifting every later row onto the wrong class name.
    countdf = (
        pd.DataFrame({'y_true': pd.Series(y_true).value_counts(),
                      'y_pred': pd.Series(y_pred).value_counts()})
        .reindex(class_ids)
        .fillna(0)
        .astype(int)
        .rename(index=BDD_CLASSES)
    )
    countdf = countdf/countdf.sum()
    display(Markdown('### % Of Pixels In Each Class'))
    display(countdf.style.format('{:.1%}'))

    # Passing `labels` keeps the matrix len(classes) x len(classes), matching
    # `display_labels`, even when some classes never occur in this split.
    disp = ConfusionMatrixDisplay.from_predictions(y_true=y_true, y_pred=y_pred,
                                                   labels=class_ids,
                                                   display_labels=classes,
                                                   normalize='pred')
    fig = disp.ax_.get_figure()
    fig.set_figwidth(10)
    fig.set_figheight(10)
    disp.ax_.set_title('Confusion Matrix (by Pixels)', fontdict={'fontsize': 32, 'fontweight': 'medium'})
    fig.show()

    if return_vals: return countdf, disp

In [5]:
# Hyper-parameters and run options for `train`; passed to wandb.init as the run config.
train_config = SimpleNamespace(**{
    'framework': "fastai",
    'img_size': (180, 320),
    'batch_size': 8,
    'augment': True,      # use data augmentation
    'epochs': 10,
    'lr': 2e-3,
    'arch': "resnet18",
    'pretrained': True,   # whether to use pretrained encoder
    'seed': 42,
    'log_preds': True,
})

In [6]:
def download_data():
    """Download the latest version of the processed (split) dataset artifact and return its local directory as a Path."""
    artifact = wandb.use_artifact(f'{PROCESSED_DATA_AT}:latest')
    return Path(artifact.download())

In [7]:
def label_func(fname):
    """Map an image path to its mask path.

    The mask lives in a ``labels`` folder next to the image's parent folder
    and is named ``<image stem>_mask.png``, e.g.
    ``root/images/abc.jpg`` -> ``root/labels/abc_mask.png``.
    """
    root = fname.parent.parent
    return root / "labels" / (fname.stem + "_mask.png")

In [8]:
def get_df(processed_dataset_dir, is_test=False):
    """Load ``data_split.csv`` and attach image and label paths.

    Args:
        processed_dataset_dir: Path to the downloaded split artifact.
        is_test: if True keep only ``Stage == 'test'`` rows. Otherwise keep
            the remaining rows and add a boolean ``is_valid`` column marking
            ``Stage == 'valid'``.

    Returns:
        DataFrame with added ``image_fname`` and ``label_fname`` columns.
    """
    split_df = pd.read_csv(processed_dataset_dir / 'data_split.csv')

    if is_test:
        split_df = split_df[split_df.Stage == 'test'].reset_index(drop=True)
    else:
        split_df = split_df[split_df.Stage != 'test'].reset_index(drop=True)
        split_df['is_valid'] = split_df.Stage == 'valid'

    split_df["image_fname"] = [processed_dataset_dir / f'images/{name}' for name in split_df.File_Name.values]
    split_df["label_fname"] = [label_func(img_path) for img_path in split_df.image_fname.values]
    return split_df

In [9]:
def get_data(df, bs=4, img_size=(180, 320), augment=True):
    """Build segmentation DataLoaders from a dataframe with ``image_fname``,
    ``label_fname`` and split columns.

    Images and masks are resized to ``img_size`` per item. When ``augment``
    is True, fastai's default ``aug_transforms()`` are applied per batch.
    """
    batch_tfms = aug_transforms() if augment else None
    seg_block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=BDD_CLASSES)),
        get_x=ColReader("image_fname"),
        get_y=ColReader("label_fname"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=batch_tfms,
    )
    return seg_block.dataloaders(df, bs=bs)

In [10]:
def log_predictions(learn):
    """Log a W&B Table named "pred_table" with validation predictions and per-class IoU."""
    table = create_iou_table(*get_predictions(learn), BDD_CLASSES)
    wandb.log({"pred_table": table})

In [11]:
def log_final_metrics(learn):
    """Re-run validation and store the results in the W&B run summary.

    Keys are ``final_loss`` followed by ``final_<metric name>`` for each
    entry of ``learn.metrics``, in the order returned by ``learn.validate()``.
    """
    scores = learn.validate()
    names = ['final_loss', *(f'final_{metric.name}' for metric in learn.metrics)]
    final_results = {names[i]: score for i, score in enumerate(scores)}
    for key, value in final_results.items():
        wandb.summary[key] = value

In [12]:
def train(config):
    """Train a fastai U-Net on the processed BDD split and track it with W&B.

    The steps are:
    1. Seed everything and start a W&B run.
    2. Download the split artifact and build DataLoaders.
    3. Fit with one-cycle while saving the best-`miou` checkpoint.
    4. Optionally log a prediction table.
    5. Record the final validation metrics in the run summary.

    Args:
        config: namespace with the fields read below (seed, batch_size,
            img_size, augment, arch, pretrained, epochs, lr, log_preds).
    """
    set_seed(config.seed, reproducible=True)

    # Using the run as a context manager finishes it on exit, including when
    # training raises, so a failed cell does not leave a dangling W&B run.
    with wandb.init(project=WANDB_PROJECT, entity=ENTITY, job_type="training",
                    config=config, anonymous="allow") as run:
        # Read hyper-parameters back from the run config (presumably so that
        # W&B-side overrides, e.g. sweeps, take effect).
        config = run.config

        processed_dataset_dir = download_data()
        df = get_df(processed_dataset_dir)

        dls = get_data(df, bs=config.batch_size, img_size=config.img_size, augment=config.augment)

        metrics = [MIOU(), BackgroundIOU(), RoadIOU(), TrafficLightIOU(),
                   TrafficSignIOU(), PersonIOU(), VehicleIOU(), BicycleIOU()]

        learn = unet_learner(dls, arch=getattr(tvmodels, config.arch), pretrained=config.pretrained, metrics=metrics)

        callbacks = [
            # Save the best checkpoint according to the `miou` metric.
            SaveModelCallback(monitor='miou'),
            WandbCallback(log_preds=False, log_model=True)
        ]

        learn.fit_one_cycle(config.epochs, config.lr, cbs=callbacks)

        if config.log_preds:
            log_predictions(learn)

        log_final_metrics(learn)

In [13]:
# Launch a single training run with the default configuration defined above.
train(config=train_config)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: (1) Private W&B dashboard, no account required
[34m[1mwandb[0m: (2) Use an existing W&B account


[34m[1mwandb[0m: Enter your choice: 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Downloading large artifact bdd_simple_1k_split:latest, 813.25MB. 4010 files... 
[34m[1mwandb[0m:   4010 of 4010 files downloaded.  
Done. 0:1:10.4
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 194MB/s]


epoch,train_loss,valid_loss,miou,background_iou,road_iou,traffic_light_iou,traffic_sign_iou,person_iou,vehicle_iou,bicycle_iou,time
0,0.508315,0.373088,0.298293,0.847094,0.627754,0.0,0.0,0.0,0.613202,0.0,00:42
1,0.414142,0.312673,0.324899,0.871612,0.77478,0.0,0.0,0.0,0.627902,0.0,00:40
2,0.353783,0.285624,0.342206,0.899976,0.811155,0.0,0.0,0.0,0.684313,0.0,00:42
3,0.30261,0.273936,0.339724,0.901371,0.806622,0.0,0.0,0.0,0.670078,0.0,00:46
4,0.279481,0.253478,0.349226,0.911338,0.824685,0.0,0.0,0.0,0.708562,0.0,00:45
5,0.254253,0.251364,0.365557,0.910293,0.827412,0.10077,0.0,0.0,0.720425,0.0,00:42
6,0.216596,0.268127,0.355422,0.907123,0.825305,0.077311,0.000538,0.0,0.67768,0.0,00:43
7,0.208213,0.229635,0.371713,0.91684,0.840746,0.106758,0.001072,0.0,0.736577,0.0,00:42
8,0.185974,0.233359,0.378048,0.917886,0.84084,0.140905,0.001732,0.0,0.744976,0.0,00:43
9,0.172926,0.227373,0.377671,0.918306,0.842633,0.135409,0.002924,0.0,0.744425,0.0,00:42


Better model found at epoch 0 with miou value: 0.2982930011210065.
Better model found at epoch 1 with miou value: 0.324899188369408.
Better model found at epoch 2 with miou value: 0.3422062811208069.
Better model found at epoch 4 with miou value: 0.3492263478062214.
Better model found at epoch 5 with miou value: 0.3655570098671184.
Better model found at epoch 7 with miou value: 0.37171328205295534.
Better model found at epoch 8 with miou value: 0.3780483566085093.


  state = torch.load(file, map_location=device, **torch_load_kwargs)


VBox(children=(Label(value='126.598 MB of 126.598 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
background_iou,▁▃▆▆▇▇▇███
bicycle_iou,▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▁▂▂▂▅▇▇▇█████████▇▇▆▆▆▆▆▅▅▄▄▃▂▂▂▂▂▂▁▁▁▁
lr_1,▂▃▄▅▆████████▇▇▇▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁▁
lr_2,▁▁▁▂▂▃▄▅▆███████▇▇▆▆▆▅▅▄▄▃▃▃▃▃▃▃▂▂▁▁▁▁▁▁
miou,▁▃▅▅▅▇▆▇██

0,1
background_iou,0.91831
bicycle_iou,0.0
epoch,10.0
eps_0,1e-05
eps_1,1e-05
eps_2,1e-05
final_background_iou,0.91789
final_bicycle_iou,0.0
final_loss,0.23336
final_miou,0.37805
