In [None]:
import sys

# Make the project helper modules imported in the next cell (Models_02, Paths,
# Preprocessing_01) importable; presumably they live in ../scripts. The path is
# resolved relative to the kernel's current working directory, so the notebook
# presumably has to be started from its own folder.
sys.path.append("../scripts")

In [None]:
import re
from pathlib import Path

import Models_02 as models
import numpy as np
import Paths as paths
import Preprocessing_01 as preprocessing
import pytorch_lightning as pl
import torch

from monai.data import DataLoader, Dataset
from monai.utils import set_determinism
from pytorch_lightning.callbacks import RichProgressBar

In [None]:
# Image names (or name fragments) of GESPIC scans to leave out of prediction; they are
# matched as substrings of the "image" column below. The list is maintained in Paths.
excluded_images = paths.excluded_images_GESPIC

In [None]:
# Define the data to use: every GESPIC image except those listed in excluded_images.
gespic_df = preprocessing.dataset_df[
    preprocessing.dataset_df["image"].str.contains("GESPIC", regex=False, na=False)
]
print(f"GESPIC images: {len(gespic_df)}")

# Drop rows whose path contains any excluded name. The names are escaped so characters
# such as "." in file names match literally. The filter is skipped when the list is
# empty, because "|".join([]) == "" matches every path and would silently drop the
# whole cohort.
# NOTE(review): substring matching means a name that is a prefix of another image path
# (e.g. "sub-1" vs "sub-10") excludes both. Confirm the names are specific enough.
if len(excluded_images) > 0:
    exclusion_pattern = "|".join(re.escape(str(name)) for name in excluded_images)
    is_excluded = gespic_df["image"].str.contains(exclusion_pattern, regex=True, na=False)
    predict_df = gespic_df[~is_excluded]
else:
    predict_df = gespic_df
print(f"GESPIC images after exclusions: {len(predict_df)}")

predict_x = predict_df["image"]
predict_y = predict_df["classification"]
# One {"image", "label"} dict per sample; presumably these keys match the ones used by
# preprocessing.val_transforms_416x628.
predict_data = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(predict_x, predict_y)
]

In [None]:
# Apply the 416x628 validation transforms to each record, then batch the records for
# inference (8 samples per batch, loaded by 16 worker processes).
predict_ds = Dataset(
    data=predict_data,
    transform=preprocessing.val_transforms_416x628,
)
predict_dataloader = DataLoader(
    predict_ds,
    batch_size=8,
    num_workers=16,
)

In [None]:
# Seed the random number generators through MONAI so that any randomness at inference
# time is reproducible.
RANDOM_SEED = 0
set_determinism(seed=RANDOM_SEED)

# NOTE: this binds the model *class*, not an instance. It is instantiated as model2
# when the checkpoint is loaded in the next cell.
model = models.ResNetTransferLearningDiscriminativeLR

In [None]:
# Single-GPU Trainer used only for trainer.predict. The epoch, sanity-check and
# gradient-clipping settings look carried over from a training configuration; they
# should have no effect on prediction.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=15,
    log_every_n_steps=1,
    callbacks=[RichProgressBar()],
    num_sanity_val_steps=2,
    gradient_clip_val=0.5,
    gradient_clip_algorithm="value",
)

CHECKPOINT_PATH = "../final_model.ckpt"
# Load the checkpoint tensors onto the CPU. Without map_location they are restored
# onto the device they were saved from. That device must then exist on this machine,
# and because `checkpoint` stays alive at notebook level, that copy of the weights
# keeps occupying GPU memory for the rest of the session. The Trainer moves the
# model itself to the GPU when predicting.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")

# Rebuild the network so the state-dict keys and shapes line up. The arguments are
# presumably the ones used for training. The optimiser-related values (max_lr, wd,
# lr_mult, alpha) are presumably irrelevant for inference — TODO confirm in Models_02.
model2 = model(
    only_fc=False,
    max_lr=5e-4,
    wd=0.0001,
    first_dropout=0.0,
    lr_mult=0.9,
    alpha=0,
)

# The default strict=True raises on any missing or unexpected key.
model2.load_state_dict(checkpoint["state_dict"])


In [None]:
# Run inference over the whole GESPIC loader. The result is a list with one entry per
# batch: whatever model2's predict step returns. The next cell reads the "preds" and
# "targets" entries of each item.
raw_preds = trainer.predict(
    model=model2,
    dataloaders=predict_dataloader,
    return_predictions=True,
)

In [None]:
# Merge the per-batch outputs into single float32 tensors. torch.vstack stacks like
# np.vstack: tensors are promoted to at least 2-D and concatenated along dim 0, and
# .float() reproduces the float32 cast that torch.Tensor(...) performed.
# NOTE(review): if "targets" are 1-D per batch, this yields shape
# (n_batches, batch_size) and fails when the last batch is smaller. torch.cat would
# give a flat (n_samples,) tensor instead. Confirm the shape the analysis expects.
raw_predictions0 = torch.vstack([batch["preds"] for batch in raw_preds]).float()
targets0 = torch.vstack([batch["targets"] for batch in raw_preds]).float()
preds = {"preds": raw_predictions0, "targets": targets0}

In [None]:
# Save predictions and targets so they can be reloaded for later analysis.
# torch.save does not create missing directories, so create ../eval first.
GESPIC_PREDS_PATH = Path("../eval/GESPIC_preds.pt")
GESPIC_PREDS_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(preds, GESPIC_PREDS_PATH)