# BiteMe | Train

This notebook includes the most important part of the project - the modelling. The notebook tests methodologies for training, and in it the chosen algorithm is decided. Validation also occurs before final testing, which is conducted in the test notebook. This stage is highly iterative, so all model artefacts, logs and configurations are recorded and saved to disk automatically. This initial setup of what will eventually become MLOps for the final product will be really useful, and helps keep track of what is successful and what isn't.

Models to try:

 - [SE-ResNet50](https://github.com/Cadene/pretrained-models.pytorch#senet)


 - [DenseNet121](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [DenseNet161](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [DenseNet169](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [DenseNet201](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [DenseNet201](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [DualPathNet68](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
 - [DualPathNet92](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
 - [DualPathNet98](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
 - [DualPathNet107](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
 - [DualPathNet113](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
 - [FBResNet152](https://github.com/Cadene/pretrained-models.pytorch#facebook-resnet)
 - [InceptionResNetV2](https://github.com/Cadene/pretrained-models.pytorch#inception)
 - [InceptionV4](https://github.com/Cadene/pretrained-models.pytorch#inception)
 - [NASNet-A-Large](https://github.com/Cadene/pretrained-models.pytorch#nasnet)
 - [PNASNet-5-Large](https://github.com/Cadene/pretrained-models.pytorch#pnasnet)
 - [PolyNet](https://github.com/Cadene/pretrained-models.pytorch#polynet)
 - [ResNeXt101_32x4d](https://github.com/Cadene/pretrained-models.pytorch#resnext)
 - [ResNeXt101_64x4d](https://github.com/Cadene/pretrained-models.pytorch#resnext)
 - [ResNet101](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [ResNet152](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [ResNet50](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [SENet154](https://github.com/Cadene/pretrained-models.pytorch#senet)
 - [SE-ResNet101](https://github.com/Cadene/pretrained-models.pytorch#senet)
 - [SE-ResNet152](https://github.com/Cadene/pretrained-models.pytorch#senet)
 - [SE-ResNeXt50_32x4d](https://github.com/Cadene/pretrained-models.pytorch#senet)
 - [SE-ResNeXt101_32x4d](https://github.com/Cadene/pretrained-models.pytorch#senet)
 - [SqueezeNet1_0](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [SqueezeNet1_1](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [VGG16](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [VGG19](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [VGG16_BN](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [VGG19_BN](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
 - [Xception](https://github.com/Cadene/pretrained-models.pytorch#xception)


- efficientnet_b0
- efficientnet_b1
- efficientnet_b2
- efficientnet_b3
- efficientnet_b4
- efficientnet_b5

Initial model work is done by using simple, typical image recognition models (CNN architectures) to see how effective these models can be for the problem. Although I don't expect them to be particularly successful, it's important to establish baselines and take a holistic approach to modelling when it's possible.

In [1]:
# Basic imports
import pandas as pd
import numpy as np
import os
import sys
from argparse import ArgumentParser
import datetime
from time import time
import gc

# Data visualisation
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn

# Image processing
import cv2
import albumentations as A
import imgaug as ia
import imgaug.augmenters as iaa

# Model evaluation
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score

import torch
import pretrainedmodels
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Local imports
sys.path.append("..")
from utils.dataset import generate_transforms, generate_dataloaders
from models.models import se_resnet50
from utils.loss_function import CrossEntropyLossOneHot
from utils.lrs_scheduler import WarmRestart, warm_restart
from utils.utils import read_images, augs, get_augs, seed_reproducer, init_logger
from utils.constants import *

plt.rcParams["figure.figsize"] = (14, 8)

In [2]:
# Define directories
base_dir_path = "../"

data_dir_path = os.path.join(base_dir_path, "data")
data_preprocessed_dir_path = os.path.join(data_dir_path, "preprocessed")
data_preprocessed_train_dir_path = os.path.join(data_dir_path, "preprocessed/train")

data_dir = os.listdir(data_dir_path)
data_preprocessed_dir = os.listdir(data_preprocessed_dir_path)
data_preprocessed_train_dir = os.listdir(data_preprocessed_train_dir_path)

metadata_preprocessed_path = os.path.join(data_preprocessed_dir_path, "metadata.csv")
metadata = pd.read_csv(metadata_preprocessed_path)
# Subset to train only
metadata = metadata.loc[metadata.split == "train"]

metadata.head()

Unnamed: 0,img_name,img_path,label,split
1,ea1b100b581fcdb7ddfae52cc62347a99e304ba4.jpg,../data/cleaned/none/ea1b100b581fcdb7ddfae52cc...,none,train
2,6eac051b9c45ff6821ec8675216f371711b7cea9.jpg,../data/cleaned/none/6eac051b9c45ff6821ec86752...,none,train
3,fc72767f8520df9b2b83941077dc0ee013eb9399.jpg,../data/cleaned/none/fc72767f8520df9b2b8394107...,none,train
4,49850884a00703afe5aab78c3ce074d2d4acae30.jpg,../data/cleaned/none/49850884a00703afe5aab78c3...,none,train
5,74c8654309dbd09440342475dc307dbd7195bd28.jpg,../data/cleaned/none/74c8654309dbd09440342475d...,none,train


In [3]:
# Read in train images
X_train = read_images(
    data_dir_path=data_preprocessed_train_dir_path, 
    rows=ROWS, 
    cols=COLS, 
    channels=CHANNELS, 
    write_images=False, 
    output_data_dir_path=None,
    verbose=VERBOSE
)

# Get labels
y_train = np.array(pd.get_dummies(metadata["label"]))

Reading images from: ../data/preprocessed/train
Rows set to 1024
Columns set to 1024
Channels set to 3
Writing images is set to: False
Reading images...


100%|███████████████████████████████████████████| 27/27 [00:00<00:00, 48.08it/s]
100%|███████████████████████████████████████████| 55/55 [00:02<00:00, 20.39it/s]
100%|███████████████████████████████████████████| 21/21 [00:01<00:00, 13.25it/s]
100%|███████████████████████████████████████████| 46/46 [00:04<00:00, 10.23it/s]
100%|███████████████████████████████████████████| 25/25 [00:03<00:00,  8.15it/s]
100%|███████████████████████████████████████████| 21/21 [00:02<00:00,  7.23it/s]
100%|███████████████████████████████████████████| 27/27 [00:04<00:00,  6.47it/s]
100%|███████████████████████████████████████████| 21/21 [00:03<00:00,  5.84it/s]


Image reading complete.
Image array shape: (243, 1024, 1024, 3)


## Set Parameters

In [4]:
# Choose augmentations to use in preprocessing
# For full list see helpers.py
#augs_to_select = [
#    "Resize",
#    "HorizontalFlip", 
#    "VerticalFlip",
#    "Normalize"
#]
## Subset augs based on those selected
#AUGS = dict((aug_name, augs[aug_name]) for aug_name in augs_to_select)


def init_hparams():
    """
    Initialise hyperparameters for modelling.
    
    Returns
    ---------
    hparams : argparse.Namespace
        Parsed hyperparameters
    """
    parser = ArgumentParser(add_help=False)
    parser.add_argument("-backbone", "--backbone", type=str, default=MODEL_NAME)
    parser.add_argument("-device_name", type=str, default=DEVICE_NAME)
    parser.add_argument("--gpus", default=[0])
    parser.add_argument("--n_workers", type=int, default=N_WORKERS)
    parser.add_argument("--image_size", nargs="+", default=[ROWS, COLS])
    parser.add_argument("--seed", type=int, default=SEED)
    parser.add_argument("--min_epochs", type=int, default=MIN_EPOCHS)
    parser.add_argument("--max_epochs", type=int, default=MAX_EPOCHS)
    parser.add_argument("--patience", type=str, default=PATIENCE)    
    parser.add_argument("-tbs", "--train_batch_size", type=int, default=TRAIN_BATCH_SIZE)
    parser.add_argument("-vbs", "--val_batch_size", type=int, default=VAL_BATCH_SIZE)
    parser.add_argument("--n_splits", type=int, default=N_SPLITS)
    parser.add_argument("--test_size", type=float, default=TEST_SIZE)
    parser.add_argument("--precision", type=int, default=PRECISION)
    parser.add_argument("--gradient_clip_val", type=float, default=GRADIENT_CLIP_VAL)
    parser.add_argument("--verbose", type=str, default=VERBOSE)
    parser.add_argument("--log_dir", type=str, default=LOG_DIR)
    parser.add_argument("--log_name", type=str, default=LOG_NAME)
    
    
    try:
        hparams, unknown = parser.parse_known_args()
    except:
        hparams, unknown = parser.parse_args([])

    if len(hparams.gpus) == 1:
        hparams.gpus = [int(hparams.gpus[0])]
    else:
        hparams.gpus = [int(gpu) for gpu in hparams.gpus]

    hparams.image_size = [int(size) for size in hparams.image_size]
    
    return hparams

### Create Model

In [5]:
class CoolSystem(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        seed_reproducer(self.hparams.seed)

        self.model = se_resnet50()
        self.criterion = CrossEntropyLossOneHot()
        self.logger_kun = init_logger(
            hparams.log_name, 
            hparams.log_dir
        )

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
        )
        self.scheduler = WarmRestart(self.optimizer, T_max=2, T_mult=3, eta_min=1e-5)
        return [self.optimizer], [self.scheduler]

    def training_step(self, batch, batch_idx):
        step_start_time = time()
        images, labels, data_load_time = batch

        scores = self(images)
        loss = self.criterion(scores, labels)

        data_load_time = torch.sum(data_load_time)

        return {
            "loss": loss,
            "data_load_time": data_load_time,
            "batch_run_time": torch.Tensor([time() - step_start_time + data_load_time]).to(
                data_load_time.device
            ),
        }

    def training_epoch_end(self, outputs):
        # outputs is the return of training_step
        train_loss_mean = torch.stack([output["loss"] for output in outputs]).mean()
        self.data_load_times = torch.stack([output["data_load_time"] for output in outputs]).sum()
        self.batch_run_times = torch.stack([output["batch_run_time"] for output in outputs]).sum()

        self.current_epoch += 1
        if self.current_epoch < (self.trainer.max_epochs - 4):
            self.scheduler = warm_restart(self.scheduler, T_mult=2)

        return {"train_loss": train_loss_mean}

    def validation_step(self, batch, batch_idx):
        step_start_time = time()
        images, labels, data_load_time = batch
        data_load_time = torch.sum(data_load_time)
        scores = self(images)
        loss = self.criterion(scores, labels)

        # must return key -> val_loss
        return {
            "val_loss": loss,
            "scores": scores,
            "labels": labels,
            "data_load_time": data_load_time,
            "batch_run_time": torch.Tensor([time() - step_start_time + data_load_time]).to(
                data_load_time.device
            ),
        }

    def validation_epoch_end(self, outputs):
        # compute loss
        val_loss_mean = torch.stack([output["val_loss"] for output in outputs]).mean()
        self.data_load_times = torch.stack([output["data_load_time"] for output in outputs]).sum()
        self.batch_run_times = torch.stack([output["batch_run_time"] for output in outputs]).sum()

        # compute roc_auc
        scores_all = torch.cat([output["scores"] for output in outputs]).cpu()
        labels_all = torch.round(torch.cat([output["labels"] for output in outputs]).cpu())
        #print(f"scores_all: {scores_all}")
        val_roc_auc = torch.tensor(roc_auc_score(labels_all, scores_all))

        # terminal logs
        self.logger_kun.info(
            f"{self.hparams.fold_i}-{self.current_epoch} | "
            f"lr : {self.scheduler.get_lr()[0]:.6f} | "
            f"val_loss : {val_loss_mean:.4f} | "
            f"val_roc_auc : {val_roc_auc:.4f} | "
            f"data_load_times : {self.data_load_times:.2f} | "
            f"batch_run_times : {self.batch_run_times:.2f}"
        )

        return {"val_loss": val_loss_mean, "val_roc_auc": val_roc_auc}

## Cross Validation

In [6]:
# Initialise hyperparameters
hparams = init_hparams()

log_notes = "decreased lr T_max to 2"

# Initialise logger
logger = init_logger(hparams.log_name, hparams.log_dir)

# Log parameters
logger.info(f"backbone: {hparams.backbone}")
logger.info(f"device_name: {hparams.device_name}")
logger.info(f"gpus: {hparams.gpus}")
logger.info(f"n_workers: {hparams.n_workers}")
logger.info(f"image_size: {hparams.image_size}")
logger.info(f"seed: {hparams.seed}")
logger.info(f"min_epochs: {hparams.min_epochs}")
logger.info(f"max_epochs: {hparams.max_epochs}")
logger.info(f"patience: {hparams.patience}")
logger.info(f"train_batch_size: {hparams.train_batch_size}")
logger.info(f"val_batch_size: {hparams.val_batch_size}")
logger.info(f"n_splits: {hparams.n_splits}")
logger.info(f"test_size: {hparams.test_size}")
logger.info(f"precision: {hparams.precision}")
logger.info(f"gradient_clip_val: {hparams.gradient_clip_val}")
logger.info(f"log_dir: {hparams.log_dir}")
logger.info(f"log_name: {hparams.log_name}")

# Log any notes if they exist
if "log_notes" in locals():
    logger.info(f"Notes: {log_notes}")


# Create transform pipeline
transforms = generate_transforms(hparams.image_size)

# List for validation scores 
valid_roc_auc_scores = []

# Initialise cross validation
folds = StratifiedKFold(n_splits=hparams.n_splits, shuffle=True, random_state=hparams.seed)

# Start cross validation
for fold_i, (train_index, val_index) in enumerate(folds.split(metadata[["img_path"]], metadata[["label"]])):
    hparams.fold_i = fold_i
    # Split train images and validation sets
    train_data = metadata.iloc[train_index][["img_path", "label"]].reset_index(drop=True)
    train_data = pd.get_dummies(train_data, columns=["label"], prefix="", prefix_sep="")

    val_data = metadata.iloc[val_index][["img_path", "label"]].reset_index(drop=True)
    val_data = pd.get_dummies(val_data, columns=["label"], prefix="", prefix_sep="")
    
    logger.info(f"Fold {fold_i} num train records: {train_data.shape[0]}")
    logger.info(f"Fold {fold_i} num val records: {val_data.shape[0]}")
    
    train_dataloader, val_dataloader = generate_dataloaders(hparams, train_data, val_data, transforms)
    
    checkpoint_callback = ModelCheckpoint(
        monitor="val_roc_auc",
        save_top_k=2,
        mode="max",
        filepath=os.path.join(
            hparams.log_dir, 
            hparams.log_name, 
            f"fold={fold_i}" + "-{epoch}-{val_loss:.4f}-{val_roc_auc:.4f}"
        )
    )
    
    early_stop_callback = EarlyStopping(
        monitor="val_roc_auc", 
        patience=hparams.patience, 
        mode="max", 
        verbose=hparams.verbose
    )
    
    # Instance Model, Trainer and train model
    model = CoolSystem(hparams)
    trainer = pl.Trainer(
        gpus=hparams.gpus,
        min_epochs=hparams.min_epochs,
        max_epochs=hparams.max_epochs,
        early_stop_callback=early_stop_callback,
        checkpoint_callback=checkpoint_callback,
        progress_bar_refresh_rate=0,
        precision=hparams.precision,
        num_sanity_val_steps=0,
        profiler=False,
        weights_summary=None,
        gradient_clip_val=hparams.gradient_clip_val,
        default_root_dir=os.path.join(hparams.log_dir, hparams.log_name)
    )
    
    # Fit model
    trainer.fit(model, train_dataloader, val_dataloader)
            
    # Save val scores
    valid_roc_auc_scores.append(checkpoint_callback.best)
    
    # Cleanup
    del model
    gc.collect()
    torch.cuda.empty_cache()
    
valid_roc_auc_scores = [i.item() for i in valid_roc_auc_scores]

# Add val scores to csv with all scores
if os.path.isfile("../logs/scores.csv") == False:
    pd.DataFrame(columns=["name", "scores", "mean_score"]).to_csv("../logs/scores.csv", index=False)
    
# Append to current scores csv
all_scores_df = pd.concat([
    pd.read_csv("../logs/scores.csv"),
    pd.DataFrame.from_dict(
        {
            "name": [hparams.log_name],
            "scores": [valid_roc_auc_scores],
            "mean_score": [np.mean(valid_roc_auc_scores)]
        }
    )],
    ignore_index=True
)
# Write all scores df to csv
all_scores_df.to_csv("../logs/scores.csv", index=False)

logger.info(f"Best scores: {valid_roc_auc_scores}")
logger.info("Training complete.")

[2022-11-08 19:48:45] 1328260540.py[  10] : INFO  backbone: se_resnet50
[2022-11-08 19:48:45] 1328260540.py[  11] : INFO  device_name: NVIDIA GeForce RTX 3090
[2022-11-08 19:48:45] 1328260540.py[  12] : INFO  gpus: [0]
[2022-11-08 19:48:45] 1328260540.py[  13] : INFO  n_workers: 128
[2022-11-08 19:48:45] 1328260540.py[  14] : INFO  image_size: [1024, 1024]
[2022-11-08 19:48:45] 1328260540.py[  15] : INFO  seed: 14
[2022-11-08 19:48:45] 1328260540.py[  16] : INFO  min_epochs: 30
[2022-11-08 19:48:45] 1328260540.py[  17] : INFO  max_epochs: 50
[2022-11-08 19:48:45] 1328260540.py[  18] : INFO  patience: 11
[2022-11-08 19:48:45] 1328260540.py[  19] : INFO  train_batch_size: 4
[2022-11-08 19:48:45] 1328260540.py[  20] : INFO  val_batch_size: 4
[2022-11-08 19:48:45] 1328260540.py[  21] : INFO  n_splits: 3
[2022-11-08 19:48:45] 1328260540.py[  22] : INFO  test_size: 0.1
[2022-11-08 19:48:45] 1328260540.py[  23] : INFO  precision: 16
[2022-11-08 19:48:45] 1328260540.py[  24] : INFO  gradient_c

[2022-11-08 20:08:53] 1230659057.py[  85] : INFO  0-41 | lr : 0.000823 | val_loss : 1.7296 | val_roc_auc : 0.7557 | data_load_times : 32.16 | batch_run_times : 32.38
[2022-11-08 20:09:22] 1230659057.py[  85] : INFO  0-42 | lr : 0.000801 | val_loss : 1.6997 | val_roc_auc : 0.7522 | data_load_times : 31.97 | batch_run_times : 32.18
[2022-11-08 20:09:50] 1230659057.py[  85] : INFO  0-43 | lr : 0.000777 | val_loss : 1.6990 | val_roc_auc : 0.7635 | data_load_times : 31.13 | batch_run_times : 31.37
[2022-11-08 20:10:18] 1230659057.py[  85] : INFO  0-44 | lr : 0.000753 | val_loss : 1.7391 | val_roc_auc : 0.7520 | data_load_times : 32.51 | batch_run_times : 32.73
[2022-11-08 20:10:46] 1230659057.py[  85] : INFO  0-45 | lr : 0.000727 | val_loss : 1.7424 | val_roc_auc : 0.7564 | data_load_times : 32.37 | batch_run_times : 32.60
[2022-11-08 20:11:14] 1230659057.py[  85] : INFO  0-46 | lr : 0.000701 | val_loss : 1.6778 | val_roc_auc : 0.7713 | data_load_times : 32.72 | batch_run_times : 32.94
[202

[2022-11-08 20:27:22] 1230659057.py[  85] : INFO  1-29 | lr : 0.000992 | val_loss : 1.9566 | val_roc_auc : 0.6213 | data_load_times : 33.28 | batch_run_times : 33.58
[2022-11-08 20:27:51] 1230659057.py[  85] : INFO  1-30 | lr : 0.000987 | val_loss : 1.9287 | val_roc_auc : 0.6418 | data_load_times : 33.44 | batch_run_times : 33.66
[2022-11-08 20:28:22] 1230659057.py[  85] : INFO  1-31 | lr : 0.000979 | val_loss : 1.9331 | val_roc_auc : 0.5987 | data_load_times : 33.09 | batch_run_times : 33.32
[2022-11-08 20:28:52] 1230659057.py[  85] : INFO  1-32 | lr : 0.000970 | val_loss : 1.9532 | val_roc_auc : 0.6043 | data_load_times : 35.39 | batch_run_times : 35.62
[2022-11-08 20:29:21] 1230659057.py[  85] : INFO  1-33 | lr : 0.000960 | val_loss : 1.9432 | val_roc_auc : 0.6113 | data_load_times : 34.37 | batch_run_times : 34.59
[2022-11-08 20:29:51] 1230659057.py[  85] : INFO  1-34 | lr : 0.000947 | val_loss : 1.9439 | val_roc_auc : 0.6103 | data_load_times : 33.91 | batch_run_times : 34.14
[202

[2022-11-08 20:48:00] 1230659057.py[  85] : INFO  2-29 | lr : 0.000992 | val_loss : 2.0124 | val_roc_auc : 0.6128 | data_load_times : 33.92 | batch_run_times : 34.19
[2022-11-08 20:48:29] 1230659057.py[  85] : INFO  2-30 | lr : 0.000987 | val_loss : 1.9485 | val_roc_auc : 0.6007 | data_load_times : 34.35 | batch_run_times : 34.60
[2022-11-08 20:48:59] 1230659057.py[  85] : INFO  2-31 | lr : 0.000979 | val_loss : 1.9969 | val_roc_auc : 0.5800 | data_load_times : 37.10 | batch_run_times : 37.32
[2022-11-08 20:49:28] 1230659057.py[  85] : INFO  2-32 | lr : 0.000970 | val_loss : 1.9878 | val_roc_auc : 0.6364 | data_load_times : 35.39 | batch_run_times : 35.83
[2022-11-08 20:49:57] 1230659057.py[  85] : INFO  2-33 | lr : 0.000960 | val_loss : 1.9680 | val_roc_auc : 0.6035 | data_load_times : 33.22 | batch_run_times : 33.49
[2022-11-08 20:50:26] 1230659057.py[  85] : INFO  2-34 | lr : 0.000947 | val_loss : 1.9475 | val_roc_auc : 0.6246 | data_load_times : 34.46 | batch_run_times : 34.69
[202

## Validation Inference

In [10]:
# Get model run path and define chosen fold
log_dir = "../logs/logs"
#model_run = "2022_11_08_14:57:52"
model_run = hparams.log_name
model_run_path = os.path.join(log_dir, model_run)
#best_fold = 0
best_fold = valid_roc_auc_scores.index(max(valid_roc_auc_scores))

# Get best model for chosen fold
model_run_dir = os.listdir(model_run_path)
model_folds = [i for i in model_run_dir if i.startswith(f"fold={best_fold}")]
model_folds_scores = [i.split(".")[-2] for i in model_folds]
model_name = model_folds[model_folds_scores.index(max(model_folds_scores))]
model_path = os.path.join(model_run_path, model_name)

# Load fold's model
model = CoolSystem(hparams)
model.load_state_dict(
    torch.load(model_path)["state_dict"]
)
model.eval()

# Retrieve validation indices for chosen fold
for fold_i, (train_index, val_index) in enumerate(folds.split(metadata[["img_path"]], metadata[["label"]])):
    if fold_i == best_fold:
        break

# Select fold validation images
X_val = torch.from_numpy(X_train[val_index]).permute(0, 3, 1, 2).float()

# Create predictions looped by batch
counter = 0
val_idx_batch = []
scores_df = pd.DataFrame()

for i in range(0, len(val_index)):
    counter += 1
    val_idx_batch.append(i) 
    
    # Run inference for val_batch_size
    if counter == hparams.val_batch_size:
        
        preds = model(X_val[val_idx_batch])
        
        # Create activation output
        log_softmax = torch.nn.LogSoftmax(dim=-1)

        # Convert raw output to probabilities
        preds = np.exp(log_softmax(preds).detach().numpy())

        # Create df with img paths and predicted label probs
        scores_df_batch = pd.DataFrame(preds, columns=val_data.columns[1:])
        scores_df_batch = pd.merge(
            metadata.iloc[val_index, 1:3].reset_index(drop=True),
            scores_df_batch, 
            left_index=True,
            right_index=True
        )
        scores_df = pd.concat([scores_df, scores_df_batch], ignore_index=True, axis=0)

        # Cleanup
        gc.collect()
        torch.cuda.empty_cache()
        # Reset counter and batch
        counter = 0
        val_idx_batch = []
        
    # Run inference for remaining batch
    elif i == len(val_index) - 1:
        preds = model(X_val[val_idx_batch])
        
        # Create activation output
        log_softmax = torch.nn.LogSoftmax(dim=-1)

        # Convert raw output to probabilities
        preds = np.exp(log_softmax(preds).detach().numpy())

        # Create df with img paths and predicted label probs
        scores_df_batch = pd.DataFrame(preds, columns=val_data.columns[1:])
        scores_df_batch = pd.merge(
            metadata.iloc[val_index, 1:3].reset_index(drop=True),
            scores_df_batch, 
            left_index=True,
            right_index=True
        )
        scores_df = pd.concat([scores_df, scores_df_batch], ignore_index=True, axis=0)
        
        # Cleanup
        gc.collect()
        torch.cuda.empty_cache()

        
# Write predictions to log
scores_df.to_csv(
    os.path.join(model_run_path, f"{model_run}_preds_fold_{best_fold}.csv"),
    index=False
)

In [11]:
scores_df

Unnamed: 0,img_path,label,ant,bedbug,bee,horsefly,mite,mosquito,none,tick
0,../data/cleaned/none/ea1b100b581fcdb7ddfae52cc...,none,0.133760,0.163069,0.078456,0.257561,0.088188,0.101174,0.057207,0.120586
1,../data/cleaned/none/6eac051b9c45ff6821ec86752...,none,0.134779,0.166059,0.077732,0.260209,0.087800,0.098749,0.057563,0.117109
2,../data/cleaned/none/f64f47a69e72ce97b3ae2b07e...,none,0.133815,0.163775,0.078367,0.258696,0.087899,0.100552,0.057515,0.119381
3,../data/cleaned/none/b8f71c61c19392c3cb24ebc54...,none,0.134907,0.166490,0.077403,0.261625,0.087551,0.098084,0.057636,0.116304
4,../data/cleaned/none/ea1b100b581fcdb7ddfae52cc...,none,0.134132,0.164425,0.078293,0.258571,0.088021,0.100126,0.057586,0.118846
...,...,...,...,...,...,...,...,...,...,...
76,../data/cleaned/none/ea1b100b581fcdb7ddfae52cc...,none,0.138921,0.164599,0.072797,0.276708,0.086375,0.093328,0.050856,0.116417
77,../data/cleaned/none/6eac051b9c45ff6821ec86752...,none,0.134711,0.164123,0.077650,0.258863,0.088338,0.099940,0.056430,0.119945
78,../data/cleaned/none/f64f47a69e72ce97b3ae2b07e...,none,0.134113,0.163946,0.078291,0.256321,0.088618,0.100971,0.057305,0.120435
79,../data/cleaned/none/b8f71c61c19392c3cb24ebc54...,none,0.135089,0.166448,0.077249,0.262180,0.087573,0.097927,0.057238,0.116295
