In [None]:
import pandas as pd
import yaml
from PIL import Image
import numpy
from torchvision import transforms
import sys
import matplotlib.pyplot as plt
import plotnine as pl
import seaborn as sns
# import math
import textwrap
from tqdm import tqdm

sys.path.append('../')

from train import create_dataloader, load_model
from dataset import MDclassDataset

import torch

# Identifiers that select which data split and which training run this notebook inspects.
# split_name is interpolated into the split CSV paths; run_name is interpolated into the
# run's config and checkpoint paths.
split_name = "frac1.0_split0.3"
# NOTE(review): config_name and model_name are not referenced in the visible cells — confirm
# whether they are still needed.
config_name = "resnet18"
run_name = "2025-01-22_01-26-42_resnet18_200e_256bs_0.0001lr_0.0wd_frac1.0_split0.3"
model_name = "resnet18"
# Print the kernel's working directory; the relative paths ("../") used below resolve against it.
!pwd

In [84]:
# Load the train/val/test tables for the selected split, the label lookup table, and the
# YAML config that was saved with the training run.
splits_dir = f"~/Documents/cv4e/CV4E-2025/data/tabular/splits/{split_name}"
dat_train = pd.read_csv(f"{splits_dir}/dat_train.csv")
dat_val = pd.read_csv(f"{splits_dir}/dat_val.csv")
dat_test = pd.read_csv(f"{splits_dir}/dat_test.csv")

lookup = pd.read_csv("~/Documents/cv4e/CV4E-2025/data/tabular/labels_lookup.csv")

# Use a context manager so the config file handle is closed as soon as it has been parsed.
with open(f"../runs/{run_name}/{run_name}_config.yaml", "r") as config_file:
    cfg = yaml.safe_load(config_file)

In [None]:
# Only the last expression of a cell is rendered, so intermediate tables must be passed to
# display() explicitly (display is provided by the IPython kernel).
display(
    dat_train[["label_group", "label_id"]].drop_duplicates(),  # classes present in train
    dat_val[["label_group", "label_id"]].drop_duplicates(),    # classes present in val
    dat_val.shape,
)

# Number of distinct class ids in the training split; used below to size the model head.
n_class = dat_train.label_id.nunique()
n_class

In [None]:
# Inverse relative class frequency: the most frequent class maps to 1.0, rarer classes to
# larger values. Order follows value_counts() (most frequent first), not label_id.
list(
    dat_train.label_id
    .value_counts()
    .pipe(lambda counts: 1 / (counts / max(counts)))
)

In [None]:
# Build the train/val dataloaders and pull one training batch to preview.
dl_train = create_dataloader(cfg, dat_train.crop_path, dat_train.label_id)
dl_val = create_dataloader(cfg, dat_val.crop_path, dat_val.label_id)

instance = next(iter(dl_train))

# Human-readable label for every sample in the batch: first column of the lookup row whose
# label_id matches.
labs = [lookup.query(f'label_id == {lab.item()}').iloc[0, 0] for lab in instance["label"]]

# Inverse of Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
# first multiply by std, then add the mean back.
# NOTE(review): assumes create_dataloader applies exactly that normalisation — confirm.
invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
                                                     std = [ 1/0.229, 1/0.224, 1/0.225 ]),
                                transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                                     std = [ 1., 1., 1. ]),
                               ])

ims = [invTrans(im) for im in instance["image"]]

# CHW float -> HWC uint8. Clip before the cast so values that land slightly outside
# [0, 255] saturate instead of wrapping around.
ims_pil = [
    Image.fromarray(
        (255 * numpy.transpose(im.numpy(), (1, 2, 0))).clip(0, 255).astype(numpy.uint8)
    )
    for im in ims
]
ims_pil[0]

In [89]:
def display_images(
    images, labs,
    columns=5, width=20, height=8, max_images=1000000,
    label_wrap_length=50, label_font_size=30):
    """Show images in a grid, each titled with its (wrapped) label.

    Args:
        images: Sequence of images accepted by ``plt.imshow``.
        labs: Sequence of title strings; ``labs[i]`` is the title of ``images[i]``.
        columns: Number of grid columns.
        width: Figure width in inches.
        height: Figure height in inches per grid row.
        max_images: Only the first ``max_images`` images are drawn.
        label_wrap_length: Maximum characters per title line before wrapping.
        label_font_size: Font size of the titles.

    Returns:
        None. The grid is drawn on a new current pyplot figure.
    """
    if images is None or len(images) == 0:
        print("No images to display.")
        return

    if len(images) > max_images:
        print(f"Showing {max_images} images of {len(images)}:")
        images = images[0:max_images]

    # Ceiling division: exactly enough rows for all images. The previous int(n/columns + 1)
    # added an empty row whenever n was a multiple of columns, and the figure height used a
    # different (floored) row count, squashing grids with a partial last row.
    n_rows = -(-len(images) // columns)
    plt.figure(figsize=(width, n_rows * height))

    for i, image in enumerate(images):
        plt.subplot(n_rows, columns, i + 1)
        plt.imshow(image)

        title = "\n".join(textwrap.wrap(labs[i], label_wrap_length))
        plt.title(title, fontsize=label_font_size)

# display_images(ims_pil, labs)

In [None]:
# Load the best checkpoint of the run onto the CPU. weights_only=True restricts unpickling
# to tensors and plain containers. The context manager closes the file handle once loaded.
# NOTE(review): this is an absolute, machine-specific path while the config is read from
# "../runs/..."; consider using the same relative root if they point to the same directory.
with open(f"/home/Vale/Documents/cv4e/CV4E-2025/runs/{run_name}/best.pt", "rb") as checkpoint_file:
    state = torch.load(checkpoint_file, map_location="cpu", weights_only=True)

# Epoch number stored in the checkpoint.
state["epoch"]

In [None]:
# Rebuild the network from the run config with n_class outputs. load_model's return value is
# indexed and only its first element is kept as the model; the remaining elements are not
# used in this notebook.
model = load_model(cfg, n_class)[0]
# Copy the weights stored under state["model"] in the checkpoint into the network. As the
# last expression of the cell, the key-matching report returned by load_state_dict is shown.
model.load_state_dict(state["model"])
# model  # uncomment to print the architecture

In [None]:
# Predict on the previewed training batch. Switch to eval mode so layers such as batch-norm
# and dropout behave deterministically (and running statistics are not updated), and disable
# autograd because no gradients are needed.
model.eval()
with torch.no_grad():
    preds = model(instance["image"])

# Index of the highest-scoring class per sample.
preds_id = list(preds.argmax(axis=1).numpy())

# NOTE(review): lookup.at indexes by row label, so this assumes lookup's index equals
# label_id (the ground-truth labels above are matched on the label_id column) — confirm.
preds_labs = [lookup.at[x, 'label_group'] for x in preds_id]

# "T:<truth> - P:<prediction>" title for every image.
truth_preds_labs = ["T:" + x + " - " + "P:" + y for x, y in zip(labs, preds_labs)]

display_images(ims_pil, truth_preds_labs, label_font_size=15)

In [None]:
# Same visual check as for the training batch, on one validation batch.
instance = next(iter(dl_val))

# Human-readable ground-truth label per sample (first column of the matching lookup row).
labs = [lookup.query(f'label_id == {lab.item()}').iloc[0, 0] for lab in instance["label"]]

# Undo the input normalisation, then convert CHW float -> HWC uint8. Clip before the cast so
# values slightly outside [0, 255] saturate instead of wrapping around.
ims = [invTrans(im) for im in instance["image"]]
ims_pil = [
    Image.fromarray(
        (255 * numpy.transpose(im.numpy(), (1, 2, 0))).clip(0, 255).astype(numpy.uint8)
    )
    for im in ims
]

# Inference must run in eval mode (deterministic batch-norm/dropout, no running-stat updates)
# and without autograd bookkeeping.
model.eval()
with torch.no_grad():
    preds = model(instance["image"])

preds_id = list(preds.argmax(axis=1).numpy())
# NOTE(review): assumes lookup's row index equals label_id — confirm.
preds_labs = [lookup.at[x, 'label_group'] for x in preds_id]
true_preds_labs = ["T:" + x + " - " + "P:" + y for x, y in zip(labs, preds_labs)]
display_images(ims_pil, true_preds_labs, label_font_size=15)

In [None]:
# Run the model over every batch of dl_train and collect the raw outputs and the labels.
# Fall back to the CPU when no CUDA device is available so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
all_predictions = []
all_labels = []
model.eval()
model.to(device)
with torch.no_grad():
    for batch in tqdm(dl_train, total=len(dl_train)):
        # Only the images need to be on the model's device; labels are just recorded.
        data = batch["image"].to(device)
        labels = batch["label"]

        # Forward pass; extend() appends one output row / one label per sample.
        all_predictions.extend(model(data).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

In [None]:
# Preview the first few output rows only; printing the whole list bloats the notebook.
all_predictions[:5]

In [82]:
# Stack the per-sample output rows into one (n_samples, n_outputs) array first; building a
# tensor directly from a list of ndarrays is slow and triggers a PyTorch warning.
logits = torch.from_numpy(numpy.stack(all_predictions))

# NOTE(review): despite its name, this holds logit(softmax(x)) = log(p / (1 - p)), not the
# softmax probabilities themselves — confirm this transform is intended for the export.
softmax = torch.logit(torch.softmax(logits, dim=1))

# One column per lookup.label_group entry, plus the ground-truth label_group joined in via
# label_id. The join key is stated explicitly instead of relying on the shared column names.
preds = (
    pd.DataFrame(softmax.numpy(), columns=lookup.label_group)
    .assign(label_id=all_labels)
    .merge(lookup, on="label_id")
    .drop(["label_id", "size"], axis=1)
)
preds.to_csv("../outputs/predictions.csv")