In [None]:
import warnings

import numpy as np
import pandas as pd

import lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau

import albumentations as albm
from albumentations.pytorch import ToTensorV2 as ToTensor

from tqdm import tqdm

from thesis.utils import get_test_dataloader
from thesis.optim import Lion
from thesis.models import EfficientNetB7
from thesis.models.model_wrapper import ModelWrapper
from thesis.custom_callbacks import progress_bar, MetricsCallback, checkpoint_callback

In [None]:
# Folder with the images to run inference on; passed to get_test_dataloader.
FOLDER = 'path/to/folder/with/images'
# Inference batch size: can be any number, it does not affect prediction quality
# (only speed and memory use).
BATCH_SIZE = 16
# Checkpoint of the trained ModelWrapper to load.
chkp_path = 'path/to/checkpoint.chkp'

In [None]:
# Rebuild the objects ModelWrapper is constructed with, then restore the trained
# weights from the checkpoint. Lightning's load_from_checkpoint forwards extra
# keyword arguments to the constructor. The optimizer/scheduler settings are
# presumably only relevant if training is resumed — TODO confirm.
backbone = EfficientNetB7()
optimizer = Lion(backbone.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, factor=0.2, patience=3)

wrapper_kwargs = dict(model=backbone, optimizer=optimizer, scheduler=scheduler)
model = ModelWrapper.load_from_checkpoint(chkp_path, **wrapper_kwargs)

In [None]:
# Trainer used for inference. Because it carries the checkpoint callback, the
# same object could also be used to continue training from the checkpoint.
pb = progress_bar()
ck = checkpoint_callback()

callbacks = [pb, MetricsCallback(), ck]
trainer = pl.Trainer(callbacks=callbacks, deterministic=False)

In [None]:
# Read image ids as strings. Otherwise pandas parses purely numeric ids as
# integers, which drops leading zeros and makes the `+ '.jpg'` below raise
# a TypeError. Missing ids still stay NaN, as before.
df = pd.read_csv('path/to/file_with_images_paths.csv', dtype={'image_id': str})
# Image file name for each row (presumably resolved relative to FOLDER by
# get_test_dataloader — TODO confirm).
df['path'] = df['image_id'] + '.jpg'

In [None]:
# Test-time preprocessing: resize, then normalise with the ImageNet channel
# statistics. IMAGE_SIZE presumably matches the training resolution — TODO confirm.
IMAGE_SIZE = 300
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transforms = albm.Compose([
    albm.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
    # `always_apply` is deprecated in recent albumentations releases. p=1.0
    # (Normalize's default) already applies the transform to every image.
    albm.Normalize(IMAGENET_MEAN, IMAGENET_STD, p=1.0),
    ToTensor(),
])

data_loader = get_test_dataloader(FOLDER, df, transforms, BATCH_SIZE)

In [None]:
# Run inference over the test loader. The result is a list with one model
# output per batch.
preds = trainer.predict(
    model=model,
    dataloaders=data_loader,
    return_predictions=True,
)

In [None]:
# trainer.predict returns one output per batch; stack them into a single array.
# .detach().cpu() keeps this working even if outputs are still on an accelerator.
# The isinstance guard makes the cell safe to re-run. Without it, a second run
# would iterate over the rows of the already-stacked ndarray and fail on `.numpy()`.
if not isinstance(preds, np.ndarray):
    preds = np.concatenate([x.detach().cpu().numpy() for x in preds])
# Pick the highest-scoring column per sample. This assumes preds is
# (n_samples, n_classes) class scores/logits — TODO confirm.
y_pred = np.argmax(preds, axis=1)