In [3]:
import sys

# Make the repo-root `library` package importable from this notebook's folder.
# Guarded so that re-running the cell does not keep appending duplicate
# ".." entries to sys.path (keeps the cell idempotent).
if ".." not in sys.path:
    sys.path.append("..")

from library.models._legacy_timm_model import TimmModule
from library.inference_utils import correct_img_rotation
from library.config import TrainMetadata
from library.data.utils import list_index

# Paths below are relative to the notebook's working directory.
ckpt_path = "../rotation_model/checkpoint.ckpt"

# Dataset metadata; later cells read its `.images` and `.val_idxs` attributes.
train_metadata = TrainMetadata.from_yaml("../train_metadata.yml")

In [4]:
# Restore the trained rotation model from its checkpoint for inference only,
# then move it to CPU and switch it to eval mode.
model = TimmModule.load_from_checkpoint(
    ckpt_path,
    pretrained=False,
    for_inference=True,
)
model = model.cpu().eval()

In [5]:
import os

# Validation-split image paths, prefixed with ".." so they resolve from this
# notebook's directory (same convention as the other paths above).
all_imgs = list(
    map(
        lambda rel_path: os.path.join("..", rel_path),
        list_index(train_metadata.images, train_metadata.val_idxs),
    )
)
print(len(all_imgs))
print(all_imgs[0])

4472
../data/hotel-id-to-combat-human-trafficking-2022-fgvc9/train_images/83679/000022124.jpg


In [20]:
import random
from dataclasses import dataclass
from typing import Any

import numpy as np
import PIL.Image as pil_img
from torchvision.transforms import Compose

from library.data.utils import read_img, read_img_rot

# Number of validation images to sample for the visual rotation check below.
n_imgs = 200


@dataclass
class ImgRot:
    """An image bundled with its three rotated views.

    Field order is the order returned by `vals()`: the unrotated image first,
    then the views for rotation steps 1, 2 and 3 (90/180/270 degrees judging
    by the field names and the "True angle" titles in the plotting cell).
    """

    img: Any  # unrotated image
    img_90: Any
    img_180: Any
    img_270: Any

    def vals(self):
        """Return a fresh list ``[img, img_90, img_180, img_270]``."""
        ordered = []
        for field_name in ("img", "img_90", "img_180", "img_270"):
            ordered.append(getattr(self, field_name))
        return ordered

    def apply_transform(self, t: Compose) -> "ImgRot":
        """Convert each view to a PIL image, apply ``t``, and wrap the results in a new ImgRot.

        NOTE(review): ``pil_img.fromarray`` assumes every view is an array-like
        image (e.g. a numpy array), not an existing PIL image.
        """
        transformed = [t(pil_img.fromarray(view)) for view in self.vals()]
        return ImgRot(*transformed)



def get_rot_img(img_p) -> ImgRot:
    """Load the image at ``img_p`` together with its three rotated variants.

    ``read_img_rot(img_p, k)`` is called for k = 1, 2, 3 (in that order), and
    the results fill ``img_90``, ``img_180`` and ``img_270`` respectively.
    """
    original = read_img(img_p)
    rotated = [read_img_rot(img_p, step) for step in (1, 2, 3)]
    return ImgRot(original, *rotated)


# Fixed seed so the sampled subset (and therefore the saved figures) is the
# same after every kernel restart; a private Random instance leaves the
# global `random` state untouched.
SAMPLE_SEED = 0
# random.sample raises ValueError when k exceeds the population size, so cap
# the sample at the number of available images.
imgs = random.Random(SAMPLE_SEED).sample(all_imgs, k=min(n_imgs, len(all_imgs)))

In [21]:
import matplotlib.pyplot as plt
import torch
from tqdm.notebook import tqdm
from typing import Optional
import time

# Preprocessing supplied by the model; applied to every rotated view before
# it is fed to the network (see ImgRot.apply_transform).
t_transform = model.get_transform()

# Output options for the comparison figures:
#   SAVE_IMGS - directory to write one PNG per sampled image (None disables saving)
#   PLOT_IMGS - whether to also show each figure
SAVE_IMGS: Optional[str] = ".cache/rotcors"
PLOT_IMGS = False

def bench_fn(typ, fn, args=(), kwargs=None, disable=False):
    """Call ``fn(*args, **kwargs)`` and optionally print how long it took.

    Args:
        typ: Label printed in front of the timing line.
        fn: Callable to invoke.
        args: Positional arguments forwarded to ``fn`` (default: none).
        kwargs: Keyword arguments forwarded to ``fn`` (default: none).
        disable: When True, do not print the timing line.

    Returns:
        Whatever ``fn`` returns.
    """
    if kwargs is None:
        kwargs = {}
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock (can jump) and is too coarse for short calls.
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if not disable:
        print(f"{typ}, took: {time.perf_counter() - start}")
    return result

if SAVE_IMGS is not None:
    # Create the output directory once rather than on every iteration.
    os.makedirs(SAVE_IMGS, exist_ok=True)

for i_img, img_p in enumerate(tqdm(imgs)):
    # Row 0: each rotated view titled with the model's predicted angle;
    # row 1: correct_img_rotation's output for that same view.
    fig, plts = plt.subplots(2, 4, figsize=(20, 10))

    img_rots = get_rot_img(img_p)

    np_imgs = img_rots.vals()
    t_imgs = img_rots.apply_transform(t_transform).vals()
    for i_rot in range(4):
        np_img = np_imgs[i_rot]
        # Add a leading batch dimension of size 1.
        model_inp = t_imgs[i_rot].unsqueeze(0)

        # Pure inference: no autograd graph needed, which saves memory and time.
        with torch.no_grad():
            pred = bench_fn(
                "inference",
                lambda: torch.argmax(model.forward(model_inp)[0, ...].cpu()).item(),
                [],
                {},
                disable=True,
            )

        # double work, but test if this function works
        np_img_cor = bench_fn("corr_img", lambda: correct_img_rotation(model, np_img), [], {}, disable=True)

        sp = plts[0, i_rot]
        sp.set_title(f"Prediction angle: {90 * pred}, True angle {90 * i_rot}")
        sp.imshow(np_img)

        sp = plts[1, i_rot]
        sp.set_title("Above corrected image")
        sp.imshow(np_img_cor)

    for sp in plts.flat:
        sp.get_xaxis().set_visible(False)
        sp.get_yaxis().set_visible(False)

    # Finalize once per image, after all four columns have been drawn.
    fig.tight_layout()

    if SAVE_IMGS is not None:
        fig.savefig(os.path.join(SAVE_IMGS, f"{i_img}.png"))

    if PLOT_IMGS:
        # fig.show() only works with GUI backends; plt.show() also renders
        # under the notebook's inline backend.
        plt.show()

    # Release the figure so memory does not grow across many images.
    plt.close(fig)


  0%|          | 0/200 [00:00<?, ?it/s]