In [None]:
import json
from argparse import Namespace
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as T


from tvcalib.cam_modules import SNProjectiveCamera
from tvcalib.utils.objects_3d import SoccerPitchSNCircleCentralSplit, SoccerPitchLineCircleSegments
from tvcalib.inference import get_camera_from_per_sample_output, load_annotated_points
from tvcalib.utils import visualization_mpl_min as viz

# Input/output locations for this visualization run; point these at another
# experiment to inspect a different set of predictions.
args = Namespace(
    file_hparams=Path("configs/wc14-test/extrem-pred.json"),
    per_sample_output=Path("experiments/wc14-test/extrem-pred/per_sample_output.json"),
    dir_images=Path("data/datasets/wc14-test"),
    output_dir=Path("tmp"),
)


# Pitch model whose line/circle segments are used below for annotation lookup
# (load_annotated_points) and for drawing the reprojection.
object3d = SoccerPitchLineCircleSegments(device="cpu", base_field=SoccerPitchSNCircleCentralSplit())

# Read the hyper-parameters, then close the file before interpreting them.
with open(args.file_hparams) as fr:
    hparams = json.load(fr)

# True only when the config's "lens_distortion" entry compares equal to True.
lens_dist = bool(hparams["lens_distortion"] == True)  # noqa: E712

# One JSON record per line; index by image id but keep the column for later lookups.
df_cam = pd.read_json(args.per_sample_output, orient="records", lines=True)
df_cam = df_cam.set_index("image_ids", drop=False)
df_cam.head(5)

In [None]:
# Row position (not index label) of the sample to inspect; change to view another image.
sample_idx = 10
sample = df_cam.iloc[sample_idx]

image_id = Path(sample.image_ids).stem
print(f"{image_id=}")
# Close the underlying file handle once the pixels have been converted; to_tensor
# returns a float tensor in CxHxW layout.
with Image.open(args.dir_images / sample.image_ids) as pil_image:
    image = T.functional.to_tensor(pil_image.convert("RGB"))

cam = get_camera_from_per_sample_output(sample, lens_dist)
print(cam, cam.str_lens_distortion_coeff(b=0) if lens_dist else "")
points_line, points_circle = load_annotated_points(hparams, image_id, object3d)

# The distortion flag is the `lens_dist` derived from hparams in the first cell;
# `args` only carries paths and has no `lens_dist` attribute.
if lens_dist:
    # Visualize the annotated points and the image after undistortion.
    # Two leading singleton dims are added for the camera calls and squeezed away again.
    image = cam.undistort_images(image.unsqueeze(0).unsqueeze(0)).squeeze()
    points_line = SNProjectiveCamera.static_undistort_points(
        points_line.unsqueeze(0).unsqueeze(0), cam
    ).squeeze()
    points_circle = SNProjectiveCamera.static_undistort_points(
        points_circle.unsqueeze(0).unsqueeze(0), cam
    ).squeeze()


# Marker styles for the annotated points: a large translucent hollow ring ("outer")
# with a small black dot ("inner") on top.
kwargs_outer = {
    "zorder": 1000,
    "rasterized": False,
    "s": 500,
    "alpha": 0.3,
    "facecolor": "none",
    "linewidths": 3,
}
kwargs_inner = {
    "zorder": 1000,
    "rasterized": False,
    "s": 50,
    "marker": ".",
    "color": "k",
    "linewidths": 4.0,
}

fig, ax = viz.init_figure(hparams["image_width"], hparams["image_height"])
ax = viz.draw_image(ax, image)
ax = viz.draw_reprojection(ax, object3d, cam)
ax = viz.draw_selected_points(
    ax,
    object3d,
    points_line,
    points_circle,
    kwargs_outer=kwargs_outer,
    kwargs_inner=kwargs_inner,
)

# Export the current figure in several formats. The output directory is not created
# anywhere else, so make sure it exists; otherwise savefig raises FileNotFoundError.
dpi = 50
args.output_dir.mkdir(parents=True, exist_ok=True)
for suffix in ("pdf", "svg", "png"):
    plt.savefig(args.output_dir / f"{image_id}.{suffix}", dpi=dpi)