This notebook demonstrates how to train a screen classification model on the Enrico dataset. Screen classification assigns a whole screen to one of 20 possible screen categories, such as `media player` or `login`.

In [6]:
import os

from screenclassification.ui_datasets import *
from screenclassification.ui_models import *
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import *

from torch import nn
import torch
import datetime
from pytorch_lightning.loggers import TensorBoardLogger

import torchvision.models as models



In [7]:
ARTIFACT_DIR = "checkpoints_screenclassification_imagenet-enrico"
CHECK_INTERVAL_STEPS = 8000

# Create the checkpoint directory if it does not exist yet.
os.makedirs(ARTIFACT_DIR, exist_ok=True)

In [8]:
logger = TensorBoardLogger(ARTIFACT_DIR)

## Create the data module

This model uses the Enrico dataset (Thanks Luis!), a high-quality dataset of UI screens for topic modeling, or screen categorization. Enrico classifies screens into categories such as `login`, `maps`, and `media player`. See more about the dataset here: https://userinterfaces.aalto.fi/enrico/

In [9]:
data = EnricoDataModule()

## Instantiate the model class

The screen classifier uses a ResNet-50 base model, documented here: https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html

The ResNet model is based on the paper "Deep Residual Learning for Image Recognition". It is an image classification model, which we will train to recognize screen types from an input screenshot. We start from weights pretrained on ImageNet, then add another layer to the output to map to the correct number of screen category classes. You can find a full list of classification models here: https://pytorch.org/vision/main/models.html#classification. To use some of these, you will likely need to modify the model definitions in the `__init__` function of `UIScreenClassifier`.

model = UIScreenClassifier(num_classes=20, arch="resnet50pretrained")

In [13]:
print("***********************************")
print("checkpoints: " + str(os.listdir(ARTIFACT_DIR)))
print("***********************************")

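# Save a checkpoint every CHECK_INTERVAL_STEPS training steps and keep the most recent one as well.
checkpoint_callback = ModelCheckpoint(dirpath=ARTIFACT_DIR, every_n_train_steps=CHECK_INTERVAL_STEPS, save_last=True)
# Keep only the single checkpoint with the highest weighted F1 score.
checkpoint_callback2 = ModelCheckpoint(dirpath=ARTIFACT_DIR, filename="screenclassification", monitor="f1_weighted", mode="max", save_top_k=1)
# Stop training once weighted F1 has not improved for 20 consecutive validation checks.
earlystopping_callback = EarlyStopping(monitor="f1_weighted", mode="max", patience=20)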

***********************************
checkpoints: []
***********************************


In [14]:
trainer = Trainer(
    accelerator='cpu',
    gradient_clip_val=1.0,
    callbacks=[checkpoint_callback, checkpoint_callback2, earlystopping_callback],
    logger=logger,
    accumulate_grad_batches=2,
    min_epochs=10, 
    max_epochs=100
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [15]:
trainer.fit(model, data)

Missing logger folder: checkpoints_screenclassification_imagenet-enrico/lightning_logs

  | Name     | Type       | Params
----------------------------------------
0 | model    | ResNet     | 23.5 M
1 | conv_cls | Sequential | 368 K 
----------------------------------------
23.9 M    Trainable params
0         Non-trainable params
23.9 M    Total params
95.458    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

{'f1_macro': 0.028708133971291863, 'f1_micro': 0.1875, 'f1_weighted': 0.05921052631578947}


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Instead of `mAP`, which you used when training the UIElementDetector, this model reports three F1 scores. `f1_macro` takes the unweighted mean of all the per-class F1 scores. `f1_micro` computes a global F1 score by summing the True Positives (TP), False Negatives (FN), and False Positives (FP) across all classes and feeding the totals into the F1 equation. Finally, `f1_weighted` takes the mean of all per-class F1 scores, weighting each by the number of actual occurrences of the class in the dataset. Ideally, all of these metrics should keep increasing as the model trains.
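
The three averages can be checked by hand on a tiny example. The sketch below is plain Python and is not the metric code used by this model: the function name `f1_scores` and the toy labels are made up for illustration, and it assumes the macro average runs over the classes that appear in the labels or predictions (the model's own metric may average over all 20 classes instead).

```python
from collections import Counter


def f1_scores(y_true, y_pred):
    """Macro, micro, and weighted F1 for single-label multiclass predictions."""
    # Assumption: average over classes seen in either the labels or the predictions.
    classes = sorted(set(y_true) | set(y_pred))
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted class gets a false positive
            fn[truth] += 1  # true class gets a false negative

    def f1(tp_, fp_, fn_):
        denom = 2 * tp_ + fp_ + fn_
        return 2 * tp_ / denom if denom else 0.0

    per_class = {c: f1(tp[c], fp[c], fn[c]) for c in classes}
    support = Counter(y_true)  # number of true occurrences per class
    return {
        "f1_macro": sum(per_class.values()) / len(classes),
        "f1_micro": f1(sum(tp.values()), sum(fp.values()), sum(fn.values())),
        "f1_weighted": sum(per_class[c] * support[c] for c in classes) / len(y_true),
    }


# Toy data (hypothetical): six screens, three categories.
y_true = ["login", "login", "login", "maps", "maps", "media player"]
y_pred = ["login", "login", "maps", "maps", "media player", "media player"]
print(f1_scores(y_true, y_pred))
```

Here the per-class F1 scores are 0.8 (`login`), 0.5 (`maps`), and about 0.667 (`media player`). The weighted score leans toward `login` because it has the most true examples, while the macro score treats all three classes equally.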