## Assumptions

In the first notebook, we established the structure and contents of the `modelnet40v1` dataset.

Now we will build a model that predicts an object's class from all images (views) of a single object instance. An architecture designed for exactly this is the *Multi-View Convolutional Neural Network (`MVCNN`)*; you can read about it in more detail [in this arXiv paper](https://arxiv.org/pdf/1505.00880.pdf). In short, it extracts a feature vector from each image, pools the per-image vectors with a symmetric (order-independent) function, and makes a prediction from the pooled vector. Because the pooling function is symmetric and defined for any number of inputs, the number of images per instance is arbitrary and can vary from instance to instance.

![mvcnn_figure](./data/mvcnn_figure.JPG)
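To make the pooling step concrete, here is a minimal pure-Python sketch of mean pooling over per-view feature vectors. The helper `mean_pool` and the tiny 3-dimensional vectors are illustrative assumptions; the real model pools the vectors produced by the feature extractor.

```python
from itertools import permutations


def mean_pool(view_vectors):
    """Average a list of equal-length feature vectors element-wise."""
    num_views = len(view_vectors)
    return [sum(column) / num_views for column in zip(*view_vectors)]


# Hypothetical 3-dimensional feature vectors for three views of one instance.
views = [
    [1.0, 0.0, 2.0],
    [3.0, 4.0, 2.0],
    [2.0, 2.0, 5.0],
]

pooled = mean_pool(views)  # [2.0, 2.0, 3.0]

# Reordering the views does not change the pooled vector.
order_invariant = all(
    mean_pool(list(order)) == pooled for order in permutations(views)
)

# An instance with fewer views still pools to a vector of the same size.
pooled_two_views = mean_pool(views[:2])  # [2.0, 2.0, 2.0]
```

The pooled vector always has the dimensionality of a single view vector, which is why the prediction layers do not need to know how many views an instance has.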

We will run this experiment with the following configuration:
* `ResNet50` pretrained on the `ImageNet` dataset as the backbone feature extractor for the individual images
* `mean` as the symmetric pooling function for the individual vectors
* `dropout` in the layers responsible for pooling and prediction
* `Adam` as the optimizer
* `F1-Macro` calculated on the validation set as the metric used to select the champion model
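As a reminder of how macro-averaged F1 is derived from confusion counts, the sketch below recomputes the epoch 0 validation score that appears in the training log further down in this notebook. The function names are ours, and treating F1 as 0 when a class has no TP, FP or FN is an assumed convention, not something taken from the project code.

```python
def f1_per_class(tp, fp, fn):
    """Per-class F1 from counts: 2*TP / (2*TP + FP + FN)."""
    scores = []
    for tp_c, fp_c, fn_c in zip(tp, fp, fn):
        denominator = 2 * tp_c + fp_c + fn_c
        # Assumed convention: a class with no counts at all scores 0.
        scores.append(2 * tp_c / denominator if denominator else 0.0)
    return scores


def macro_f1(tp, fp, fn):
    """Unweighted mean of the per-class F1 scores."""
    scores = f1_per_class(tp, fp, fn)
    return sum(scores) / len(scores)


# Counts copied from the epoch 0 "Eval stats" block in the training log
# (class order: airplane, bathtub, bed, bench).
tp = [19, 15, 15, 0]
fp = [3, 5, 23, 0]
fn = [1, 5, 5, 20]

per_class = f1_per_class(tp, fp, fn)
score = macro_f1(tp, fp, fn)  # about 0.5430, matching the logged val_f1
```

Because every class contributes equally to the average, a single class with zero true positives (here `bench`) pulls the score down sharply even when overall accuracy looks acceptable.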

In addition, because the experiment will run on a local machine without a capable GPU, we will simplify the training pipeline:
* the number of classes in training will be limited to just 4: `airplane`, `bathtub`, `bed` and `bench`
* the pretrained feature extractor weights will remain frozen throughout training
* at training time, the batch dimension of the input will hold the images of a single instance; instances are forwarded through the network one at a time, so the *de facto* batch size is 1
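A quick back-of-the-envelope check shows what these simplifications mean for the length of an epoch. The instance and image counts below are taken from the data module summary printed later in this notebook; the sketch assumes every training instance is visited exactly once per epoch.

```python
NUM_CLASSES = 4
TRAIN_INSTANCES_PER_CLASS = 80  # from the data module summary
TRAIN_IMAGES_PER_CLASS = 960    # from the data module summary
BATCH_SIZE = 1                  # one object instance per step

# Each forward pass receives all views of one instance in the batch dimension.
images_per_instance = TRAIN_IMAGES_PER_CLASS // TRAIN_INSTANCES_PER_CLASS

# One optimizer step per instance.
steps_per_epoch = NUM_CLASSES * TRAIN_INSTANCES_PER_CLASS // BATCH_SIZE


def last_global_step(epoch):
    """Zero-based global step at the end of a zero-based epoch."""
    return (epoch + 1) * steps_per_epoch - 1
```

Under these assumptions `images_per_instance` is 12 and `steps_per_epoch` is 320, so `last_global_step(0)` is 319 and `last_global_step(19)` is 6399, which agrees with the step numbers in the checkpoint messages of the training log.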

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

from dataset_classes.mvcnn_data_module import MVCNNDataModule
from model_classes.mvcnn import MVCNNClassifier
from model_classes.callbacks import UnfreezePretrainedWeights, ResetEvalResults

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


## Parameters

In [2]:
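NUM_CLASSES = 4                       # airplane, bathtub, bed, bench
LEARNING_RATE = 1e-3
LEARNING_RATE_REDUCTION_FACTOR = 1e3  # passed to UnfreezePretrainedWeights
NUM_EPOCHS = 20
NUM_EPOCHS_FREEZE_PRETRAINED = 20     # equals NUM_EPOCHS: backbone stays frozen for the whole run
BATCH_SIZE = 1                        # one object instance per step
DROPOUT_RATE = 0.3
SAVE_PATH = './output'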
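The notebook does not spell out how the freeze and learning-rate parameters interact, so the sketch below shows only one plausible reading. It assumes that the backbone is frozen while the zero-based epoch index is below `NUM_EPOCHS_FREEZE_PRETRAINED`, and that on unfreezing the learning rate is divided by `LEARNING_RATE_REDUCTION_FACTOR`. Both functions are hypothetical and are not part of the project's `UnfreezePretrainedWeights` callback.

```python
LEARNING_RATE = 1e-3
LEARNING_RATE_REDUCTION_FACTOR = 1e3
NUM_EPOCHS = 20
NUM_EPOCHS_FREEZE_PRETRAINED = 20


def backbone_is_frozen(epoch):
    """ASSUMPTION: frozen for zero-based epochs below the freeze threshold."""
    return epoch < NUM_EPOCHS_FREEZE_PRETRAINED


def assumed_learning_rate(epoch):
    """ASSUMPTION: the learning rate is divided by the factor once unfrozen."""
    if backbone_is_frozen(epoch):
        return LEARNING_RATE
    return LEARNING_RATE / LEARNING_RATE_REDUCTION_FACTOR


# With the values above, every epoch of this run falls in the frozen phase.
frozen_epochs = [e for e in range(NUM_EPOCHS) if backbone_is_frozen(e)]
```

Under this reading all 20 epochs run with a frozen backbone at a learning rate of `1e-3`, which is consistent with the simplification stated in the assumptions; the reduced rate of about `1e-6` would only apply if training continued past the freeze threshold.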

## Class initialization

In [3]:
data_module = MVCNNDataModule(NUM_CLASSES, BATCH_SIZE)

Dataset type: TRAIN
------------------------------------------------------------
Class name: AIRPLANE
Total number of instances: 80
Total number of images: 960
------------------------------------------------------------
Class name: BATHTUB
Total number of instances: 80
Total number of images: 960
------------------------------------------------------------
Class name: BED
Total number of instances: 80
Total number of images: 960
------------------------------------------------------------
Class name: BENCH
Total number of instances: 80
Total number of images: 960
Dataset type: TEST
------------------------------------------------------------
Class name: AIRPLANE
Total number of instances: 20
Total number of images: 240
------------------------------------------------------------
Class name: BATHTUB
Total number of instances: 20
Total number of images: 240
------------------------------------------------------------
Class name: BED
Total number of instances: 20
Total number of images: 

In [4]:
model = MVCNNClassifier(
    learning_rate=LEARNING_RATE,
    num_epochs_freeze_pretrained=NUM_EPOCHS_FREEZE_PRETRAINED,
    dropout_rate=DROPOUT_RATE,
)

Feature extractor weights frozen


In [5]:
callbacks = [
    # keep only the checkpoint with the highest validation F1
    ModelCheckpoint(monitor='val_f1', verbose=True, mode='max'),
    UnfreezePretrainedWeights(LEARNING_RATE_REDUCTION_FACTOR),
    ResetEvalResults(NUM_CLASSES)
]

In [6]:
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    fast_dev_run=False,
    default_root_dir=SAVE_PATH,
    callbacks=callbacks
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores


## Model training

In [7]:
trainer.fit(
    model,
    train_dataloader=data_module.train_dataloader(),
    val_dataloaders=data_module.val_dataloader()
)


  | Name                 | Type             | Params
----------------------------------------------------------
0 | feature_extractor    | FeatureExtractor | 25.6 M
1 | image_vector_creator | Sequential       | 512 K 
2 | predictor            | Sequential       | 74.5 K
----------------------------------------------------------
586 K     Trainable params
25.6 M    Non-trainable params
26.1 M    Total params
104.576   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


Eval stats:
TP per class: [0 0 0 0] Average: 0.00
TN per class: [0 2 2 0] Average: 1.00
FP per class: [0 0 0 2] Average: 0.50
FN per class: [2 0 0 0] Average: 0.50
ACCURACY per class: [0. 1. 1. 0.] Average: 0.50
PRECISION per class: [0. 0. 0. 0.] Average: 0.00
RECALL per class: [0. 0. 0. 0.] Average: 0.00
F1 per class: [0. 0. 0. 0.] Average: 0.00
Avg Loss val tensor(1.4337)
F1 val tensor(0.)
Accuracy VAL tensor(0.7500)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 0, global step 319: val_f1 reached 0.54300 (best 0.54300), saving model to "output/lightning_logs/version_0/checkpoints/epoch=0-step=319.ckpt" as top 1



Eval stats:
TP per class: [19 15 15  0] Average: 12.25
TN per class: [57 55 37 60] Average: 52.25
FP per class: [ 3  5 23  0] Average: 7.75
FN per class: [ 1  5  5 20] Average: 7.75
ACCURACY per class: [0.95  0.875 0.65  0.75 ] Average: 0.81
PRECISION per class: [0.86363636 0.75       0.39473684 0.        ] Average: 0.50
RECALL per class: [0.95 0.75 0.75 0.  ] Average: 0.61
F1 per class: [0.9047619  0.75       0.51724138 0.        ] Average: 0.54
Avg Loss val tensor(0.8390)
F1 val tensor(0.5430)
Accuracy VAL tensor(0.8250)


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 639: val_f1 reached 0.61787 (best 0.61787), saving model to "output/lightning_logs/version_0/checkpoints/epoch=1-step=639.ckpt" as top 1



Eval stats:
TP per class: [17 17  2 17] Average: 13.25
TN per class: [58 56 59 40] Average: 53.25
FP per class: [ 2  4  1 20] Average: 6.75
FN per class: [ 3  3 18  3] Average: 6.75
ACCURACY per class: [0.9375 0.9125 0.7625 0.7125] Average: 0.83
PRECISION per class: [0.89473684 0.80952381 0.66666667 0.45945946] Average: 0.71
RECALL per class: [0.85 0.85 0.1  0.85] Average: 0.66
F1 per class: [0.87179487 0.82926829 0.17391304 0.59649123] Average: 0.62
Avg Loss val tensor(1.3921)
F1 val tensor(0.6179)
Accuracy VAL tensor(0.8281)


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 959: val_f1 reached 0.78818 (best 0.78818), saving model to "output/lightning_logs/version_0/checkpoints/epoch=2-step=959.ckpt" as top 1



Eval stats:
TP per class: [19 14 14 16] Average: 15.75
TN per class: [57 59 56 51] Average: 55.75
FP per class: [3 1 4 9] Average: 4.25
FN per class: [1 6 6 4] Average: 4.25
ACCURACY per class: [0.95   0.9125 0.875  0.8375] Average: 0.89
PRECISION per class: [0.86363636 0.93333333 0.77777778 0.64      ] Average: 0.80
RECALL per class: [0.95 0.7  0.7  0.8 ] Average: 0.79
F1 per class: [0.9047619  0.8        0.73684211 0.71111111] Average: 0.79
Avg Loss val tensor(1.1146)
F1 val tensor(0.7882)
Accuracy VAL tensor(0.8969)


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 1279: val_f1 was not in top 1



Eval stats:
TP per class: [19 16  8 16] Average: 14.75
TN per class: [57 56 58 48] Average: 54.75
FP per class: [ 3  4  2 12] Average: 5.25
FN per class: [ 1  4 12  4] Average: 5.25
ACCURACY per class: [0.95  0.9   0.825 0.8  ] Average: 0.87
PRECISION per class: [0.86363636 0.8        0.8        0.57142857] Average: 0.76
RECALL per class: [0.95 0.8  0.4  0.8 ] Average: 0.74
F1 per class: [0.9047619  0.8        0.53333333 0.66666667] Average: 0.73
Avg Loss val tensor(1.2800)
F1 val tensor(0.7262)
Accuracy VAL tensor(0.8750)


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 1599: val_f1 reached 0.85042 (best 0.85042), saving model to "output/lightning_logs/version_0/checkpoints/epoch=4-step=1599.ckpt" as top 1



Eval stats:
TP per class: [19 16 19 14] Average: 17.00
TN per class: [57 59 52 60] Average: 57.00
FP per class: [3 1 8 0] Average: 3.00
FN per class: [1 4 1 6] Average: 3.00
ACCURACY per class: [0.95   0.9375 0.8875 0.925 ] Average: 0.93
PRECISION per class: [0.86363636 0.94117647 0.7037037  1.        ] Average: 0.88
RECALL per class: [0.95 0.8  0.95 0.7 ] Average: 0.85
F1 per class: [0.9047619  0.86486486 0.80851064 0.82352941] Average: 0.85
Avg Loss val tensor(1.1935)
F1 val tensor(0.8504)
Accuracy VAL tensor(0.9187)


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 1919: val_f1 was not in top 1



Eval stats:
TP per class: [14 16  3 18] Average: 12.75
TN per class: [57 59 60 35] Average: 52.75
FP per class: [ 3  1  0 25] Average: 7.25
FN per class: [ 6  4 17  2] Average: 7.25
ACCURACY per class: [0.8875 0.9375 0.7875 0.6625] Average: 0.82
PRECISION per class: [0.82352941 0.94117647 1.         0.41860465] Average: 0.80
RECALL per class: [0.7  0.8  0.15 0.9 ] Average: 0.64
F1 per class: [0.75675676 0.86486486 0.26086957 0.57142857] Average: 0.61
Avg Loss val tensor(2.1630)
F1 val tensor(0.6135)
Accuracy VAL tensor(0.8219)


Validating: 0it [00:00, ?it/s]

Epoch 6, global step 2239: val_f1 was not in top 1



Eval stats:
TP per class: [19 16 11 17] Average: 15.75
TN per class: [57 60 59 47] Average: 55.75
FP per class: [ 3  0  1 13] Average: 4.25
FN per class: [1 4 9 3] Average: 4.25
ACCURACY per class: [0.95  0.95  0.875 0.8  ] Average: 0.89
PRECISION per class: [0.86363636 1.         0.91666667 0.56666667] Average: 0.84
RECALL per class: [0.95 0.8  0.55 0.85] Average: 0.79
F1 per class: [0.9047619  0.88888889 0.6875     0.68      ] Average: 0.79
Avg Loss val tensor(1.5115)
F1 val tensor(0.7903)
Accuracy VAL tensor(0.8969)


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 2559: val_f1 was not in top 1



Eval stats:
TP per class: [12 11 19 13] Average: 13.75
TN per class: [53 60 45 57] Average: 53.75
FP per class: [ 7  0 15  3] Average: 6.25
FN per class: [8 9 1 7] Average: 6.25
ACCURACY per class: [0.8125 0.8875 0.8    0.875 ] Average: 0.84
PRECISION per class: [0.63157895 1.         0.55882353 0.8125    ] Average: 0.75
RECALL per class: [0.6  0.55 0.95 0.65] Average: 0.69
F1 per class: [0.61538462 0.70967742 0.7037037  0.72222222] Average: 0.69
Avg Loss val tensor(2.8054)
F1 val tensor(0.6877)
Accuracy VAL tensor(0.8438)


Validating: 0it [00:00, ?it/s]

Epoch 8, global step 2879: val_f1 was not in top 1



Eval stats:
TP per class: [20 17  9 15] Average: 15.25
TN per class: [51 60 59 51] Average: 55.25
FP per class: [9 0 1 9] Average: 4.75
FN per class: [ 0  3 11  5] Average: 4.75
ACCURACY per class: [0.8875 0.9625 0.85   0.825 ] Average: 0.88
PRECISION per class: [0.68965517 1.         0.9        0.625     ] Average: 0.80
RECALL per class: [1.   0.85 0.45 0.75] Average: 0.76
F1 per class: [0.81632653 0.91891892 0.6        0.68181818] Average: 0.75
Avg Loss val tensor(1.4977)
F1 val tensor(0.7543)
Accuracy VAL tensor(0.8813)


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 3199: val_f1 was not in top 1



Eval stats:
TP per class: [19 17 15 16] Average: 16.75
TN per class: [57 58 57 55] Average: 56.75
FP per class: [3 2 3 5] Average: 3.25
FN per class: [1 3 5 4] Average: 3.25
ACCURACY per class: [0.95   0.9375 0.9    0.8875] Average: 0.92
PRECISION per class: [0.86363636 0.89473684 0.83333333 0.76190476] Average: 0.84
RECALL per class: [0.95 0.85 0.75 0.8 ] Average: 0.84
F1 per class: [0.9047619  0.87179487 0.78947368 0.7804878 ] Average: 0.84
Avg Loss val tensor(1.0908)
F1 val tensor(0.8366)
Accuracy VAL tensor(0.9187)


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 3519: val_f1 was not in top 1



Eval stats:
TP per class: [19 17 11 12] Average: 14.75
TN per class: [49 58 57 55] Average: 54.75
FP per class: [11  2  3  5] Average: 5.25
FN per class: [1 3 9 8] Average: 5.25
ACCURACY per class: [0.85   0.9375 0.85   0.8375] Average: 0.87
PRECISION per class: [0.63333333 0.89473684 0.78571429 0.70588235] Average: 0.75
RECALL per class: [0.95 0.85 0.55 0.6 ] Average: 0.74
F1 per class: [0.76       0.87179487 0.64705882 0.64864865] Average: 0.73
Avg Loss val tensor(2.1048)
F1 val tensor(0.7319)
Accuracy VAL tensor(0.8719)


Validating: 0it [00:00, ?it/s]

Epoch 11, global step 3839: val_f1 was not in top 1



Eval stats:
TP per class: [20 17 10 16] Average: 15.75
TN per class: [52 54 59 58] Average: 55.75
FP per class: [8 6 1 2] Average: 4.25
FN per class: [ 0  3 10  4] Average: 4.25
ACCURACY per class: [0.9    0.8875 0.8625 0.925 ] Average: 0.89
PRECISION per class: [0.71428571 0.73913043 0.90909091 0.88888889] Average: 0.81
RECALL per class: [1.   0.85 0.5  0.8 ] Average: 0.79
F1 per class: [0.83333333 0.79069767 0.64516129 0.84210526] Average: 0.78
Avg Loss val tensor(1.1334)
F1 val tensor(0.7778)
Accuracy VAL tensor(0.8938)


Validating: 0it [00:00, ?it/s]

Epoch 12, global step 4159: val_f1 reached 0.87524 (best 0.87524), saving model to "output/lightning_logs/version_0/checkpoints/epoch=12-step=4159.ckpt" as top 1



Eval stats:
TP per class: [19 16 18 17] Average: 17.50
TN per class: [57 60 57 56] Average: 57.50
FP per class: [3 0 3 4] Average: 2.50
FN per class: [1 4 2 3] Average: 2.50
ACCURACY per class: [0.95   0.95   0.9375 0.9125] Average: 0.94
PRECISION per class: [0.86363636 1.         0.85714286 0.80952381] Average: 0.88
RECALL per class: [0.95 0.8  0.9  0.85] Average: 0.88
F1 per class: [0.9047619  0.88888889 0.87804878 0.82926829] Average: 0.88
Avg Loss val tensor(1.4887)
F1 val tensor(0.8752)
Accuracy VAL tensor(0.9375)


Validating: 0it [00:00, ?it/s]

Epoch 13, global step 4479: val_f1 was not in top 1



Eval stats:
TP per class: [14 10 20 16] Average: 15.00
TN per class: [59 60 45 56] Average: 55.00
FP per class: [ 1  0 15  4] Average: 5.00
FN per class: [ 6 10  0  4] Average: 5.00
ACCURACY per class: [0.9125 0.875  0.8125 0.9   ] Average: 0.88
PRECISION per class: [0.93333333 1.         0.57142857 0.8       ] Average: 0.83
RECALL per class: [0.7 0.5 1.  0.8] Average: 0.75
F1 per class: [0.8        0.66666667 0.72727273 0.8       ] Average: 0.75
Avg Loss val tensor(2.9285)
F1 val tensor(0.7485)
Accuracy VAL tensor(0.8750)


Validating: 0it [00:00, ?it/s]

Epoch 14, global step 4799: val_f1 was not in top 1



Eval stats:
TP per class: [ 7 15 19 13] Average: 13.50
TN per class: [56 60 49 49] Average: 53.50
FP per class: [ 4  0 11 11] Average: 6.50
FN per class: [13  5  1  7] Average: 6.50
ACCURACY per class: [0.7875 0.9375 0.85   0.775 ] Average: 0.84
PRECISION per class: [0.63636364 1.         0.63333333 0.54166667] Average: 0.70
RECALL per class: [0.35 0.75 0.95 0.65] Average: 0.67
F1 per class: [0.4516129  0.85714286 0.76       0.59090909] Average: 0.66
Avg Loss val tensor(4.8759)
F1 val tensor(0.6649)
Accuracy VAL tensor(0.8375)


Validating: 0it [00:00, ?it/s]

Epoch 15, global step 5119: val_f1 was not in top 1



Eval stats:
TP per class: [15  8 19  9] Average: 12.75
TN per class: [52 60 39 60] Average: 52.75
FP per class: [ 8  0 21  0] Average: 7.25
FN per class: [ 5 12  1 11] Average: 7.25
ACCURACY per class: [0.8375 0.85   0.725  0.8625] Average: 0.82
PRECISION per class: [0.65217391 1.         0.475      1.        ] Average: 0.78
RECALL per class: [0.75 0.4  0.95 0.45] Average: 0.64
F1 per class: [0.69767442 0.57142857 0.63333333 0.62068966] Average: 0.63
Avg Loss val tensor(4.9173)
F1 val tensor(0.6308)
Accuracy VAL tensor(0.8188)


Validating: 0it [00:00, ?it/s]

Epoch 16, global step 5439: val_f1 was not in top 1



Eval stats:
TP per class: [11 10 19 14] Average: 13.50
TN per class: [54 60 46 54] Average: 53.50
FP per class: [ 6  0 14  6] Average: 6.50
FN per class: [ 9 10  1  6] Average: 6.50
ACCURACY per class: [0.8125 0.875  0.8125 0.85  ] Average: 0.84
PRECISION per class: [0.64705882 1.         0.57575758 0.7       ] Average: 0.73
RECALL per class: [0.55 0.5  0.95 0.7 ] Average: 0.68
F1 per class: [0.59459459 0.66666667 0.71698113 0.7       ] Average: 0.67
Avg Loss val tensor(4.9312)
F1 val tensor(0.6696)
Accuracy VAL tensor(0.8375)


Validating: 0it [00:00, ?it/s]

Epoch 17, global step 5759: val_f1 was not in top 1



Eval stats:
TP per class: [16 13 20 12] Average: 15.25
TN per class: [53 60 48 60] Average: 55.25
FP per class: [ 7  0 12  0] Average: 4.75
FN per class: [4 7 0 8] Average: 4.75
ACCURACY per class: [0.8625 0.9125 0.85   0.9   ] Average: 0.88
PRECISION per class: [0.69565217 1.         0.625      1.        ] Average: 0.83
RECALL per class: [0.8  0.65 1.   0.6 ] Average: 0.76
F1 per class: [0.74418605 0.78787879 0.76923077 0.75      ] Average: 0.76
Avg Loss val tensor(3.2532)
F1 val tensor(0.7628)
Accuracy VAL tensor(0.8813)


Validating: 0it [00:00, ?it/s]

Epoch 18, global step 6079: val_f1 was not in top 1



Eval stats:
TP per class: [13 14  9 16] Average: 13.00
TN per class: [54 60 57 41] Average: 53.00
FP per class: [ 6  0  3 19] Average: 7.00
FN per class: [ 7  6 11  4] Average: 7.00
ACCURACY per class: [0.8375 0.925  0.825  0.7125] Average: 0.83
PRECISION per class: [0.68421053 1.         0.75       0.45714286] Average: 0.72
RECALL per class: [0.65 0.7  0.45 0.8 ] Average: 0.65
F1 per class: [0.66666667 0.82352941 0.5625     0.58181818] Average: 0.66
Avg Loss val tensor(4.6792)
F1 val tensor(0.6586)
Accuracy VAL tensor(0.8250)


Validating: 0it [00:00, ?it/s]

Epoch 19, global step 6399: val_f1 was not in top 1



Eval stats:
TP per class: [19 14 20 14] Average: 16.75
TN per class: [57 60 51 59] Average: 56.75
FP per class: [3 0 9 1] Average: 3.25
FN per class: [1 6 0 6] Average: 3.25
ACCURACY per class: [0.95   0.925  0.8875 0.9125] Average: 0.92
PRECISION per class: [0.86363636 1.         0.68965517 0.93333333] Average: 0.87
RECALL per class: [0.95 0.7  1.   0.7 ] Average: 0.84
F1 per class: [0.9047619  0.82352941 0.81632653 0.8       ] Average: 0.84
Avg Loss val tensor(1.9574)
F1 val tensor(0.8362)
Accuracy VAL tensor(0.9219)
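
To see which epoch produced the champion model, we can replay the checkpoint selection over the logged `val_f1` values. The sketch below is a simplified stand-in for the top-1 behaviour of `ModelCheckpoint` with `mode='max'`: it saves whenever the score strictly exceeds the best one seen so far. The function name is ours, and the values are copied from the `F1 val` lines of the log above.

```python
# val_f1 after each epoch (0 to 19), copied from the training log.
val_f1_by_epoch = [
    0.5430, 0.6179, 0.7882, 0.7262, 0.8504,
    0.6135, 0.7903, 0.6877, 0.7543, 0.8366,
    0.7319, 0.7778, 0.8752, 0.7485, 0.6649,
    0.6308, 0.6696, 0.7628, 0.6586, 0.8362,
]


def top1_checkpoint_epochs(scores):
    """Return the epochs at which a new best score would trigger a save."""
    best = float("-inf")
    saved = []
    for epoch, score in enumerate(scores):
        if score > best:
            best = score
            saved.append(epoch)
    return saved


saved_epochs = top1_checkpoint_epochs(val_f1_by_epoch)
champion_epoch = saved_epochs[-1]
champion_f1 = val_f1_by_epoch[champion_epoch]
```

This gives saves at epochs 0, 1, 2, 4 and 12, the same epochs for which the log reports a new best `val_f1`, so the champion is the epoch 12 checkpoint with a validation F1-Macro of about 0.875.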
