In [None]:
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torchvision.transforms
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch
import numpy as np
from keypoint_detection.utils.heatmap import generate_channel_heatmap, get_keypoints_from_heatmap
from keypoint_detection.utils.visualization import overlay_image_with_heatmap
from keypoint_detection.models.detector import KeypointDetector
from keypoint_detection.data.unlabeled_dataset import UnlabeledKeypointsDataset

In [None]:
import wandb
from pathlib import Path
from skimage import io
import torchvision

In [None]:
## Get Model checkpoint from wandb

# wandb artifact reference (entity/project/name:version) of the trained KeypointDetector checkpoint.
checkpoint_reference = 'airo-box-manipulation/clothes/model-1kkjks2v:v1'

# download checkpoint locally (if not already cached)
run = wandb.init(project="clothes", entity="airo-box-manipulation")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()
# The run is only needed to record which artifact this notebook consumed; close it right away
# so it does not stay open (and keep syncing) for the rest of the kernel session.
run.finish()

# Load the weights onto the CPU explicitly: batches from `dataloader` are fed to the model
# without being moved to any device, so the model has to live on the CPU even if the
# checkpoint was saved from a GPU run.
model = KeypointDetector.load_from_checkpoint(
    Path(artifact_dir) / "model.ckpt",
    map_location="cpu",
    backbone_type='ConvNeXtUnet',
)

In [None]:
import os

# Locally generated towel dataset, resolved relative to the user's home directory.
home = os.path.expanduser("~")
dataset_dir = os.path.join(home, "cloth-keypoint-generation", "datasets", "towel_dataset_2")
IMAGE_DIR = os.path.join(dataset_dir, "images")
# Printed for reference only: the dataset below is constructed from IMAGE_DIR alone.
JSON_PATH = os.path.join(dataset_dir, "annotations.json")
print(JSON_PATH)

dataset = UnlabeledKeypointsDataset(IMAGE_DIR)
print(len(dataset))
# No shuffling, so every run visualizes the same images grouped into the same batches of 4.
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4, shuffle=False)

In [None]:
def imshow(img):
    """
    Plot a channels-first (C, H, W) image tensor with matplotlib and show the figure.

    Values are handed to ``plt.imshow`` unchanged, so the tensor is expected to already be
    in a displayable range such as [0, 1]. Images are kept in [0, 1] here, although in
    theory [-1, 1] should be used to whiten; no un-whitening is done.

    Args:
        img: image tensor in (C, H, W) layout. It may live on any device and may require
            grad; it is detached and copied to the CPU before plotting. A single-channel
            image is squeezed to (H, W) and rendered as a 2-D (colormapped) image, because
            some matplotlib versions reject (H, W, 1) arrays.
    """
    # .numpy() only works on CPU tensors that do not require grad.
    np_img = img.detach().cpu().numpy()
    # bring channels-first (C, H, W) to channels-last (H, W, C), the layout matplotlib expects
    hwc_img = np.transpose(np_img, (1, 2, 0))
    if hwc_img.shape[-1] == 1:
        hwc_img = hwc_img[..., 0]
    plt.imshow(hwc_img)
    plt.show()


In [None]:
# Sanity check: pull the first batch from the loader and keep only its first image.
img = next(iter(dataloader))
img = img[0]

In [None]:
# Resize every batch to exactly 256 x 256 pixels (height, width) before it goes through the model.
transform = torchvision.transforms.Resize(size=(256, 256))

In [None]:
# Show the first image of the first batch exactly as loaded (no resize applied) to sanity-check the raw data.
img = next(iter(dataloader))
imshow(img[0])

In [None]:
def _render_keypoint_heatmap(heatmap, image_shape, n_keypoints, min_keypoint_pixel_distance, sigma):
    """Extract keypoints from one predicted heatmap and re-draw them as a synthetic heatmap.

    ``min_keypoint_pixel_distance`` and ``n_keypoints`` are forwarded positionally to
    ``get_keypoints_from_heatmap`` (presumably the minimum pixel distance between peaks and
    the maximum number of peaks — confirm against keypoint_detection.utils.heatmap).
    The result is produced by ``generate_channel_heatmap`` with ``image_shape`` as its
    shape argument, the given ``sigma`` and ``device='cpu'``.
    """
    keypoints = get_keypoints_from_heatmap(heatmap, min_keypoint_pixel_distance, n_keypoints)
    return generate_channel_heatmap(image_shape, keypoints, sigma=sigma, device="cpu")


def show_results(
    show_extracted_keypoints=True,
    mode="eval",
    channel=0,
    n_keypoints=4,
    min_keypoint_pixel_distance=30,
    sigma=4,
    alpha=0.6,
    nrow=8,
):
    """
    Show network outputs on the dataset, one image grid per dataloader batch.

    Every batch from the global ``dataloader`` is resized with the global ``transform``,
    passed through the global ``model`` under ``torch.no_grad()``, and one output channel
    is overlaid on each image. All keyword defaults reproduce the original hard-coded values.

    Args:
        show_extracted_keypoints: if True, overlay a heatmap re-rendered from the keypoints
            extracted from the prediction; if False, overlay the raw predicted heatmap.
        mode: "eval" or "train". The model is switched to this mode before inference and is
            left in it afterwards. In "train" mode any BatchNorm layers still update their
            running statistics even though gradients are disabled.
        channel: index of the model output channel to visualize.
        n_keypoints: keypoint limit forwarded to ``get_keypoints_from_heatmap``.
        min_keypoint_pixel_distance: second positional argument of
            ``get_keypoints_from_heatmap`` (presumably the minimum peak distance — confirm).
        sigma: ``sigma`` passed to ``generate_channel_heatmap`` for extracted keypoints.
        alpha: blending factor passed to ``overlay_image_with_heatmap``.
        nrow: number of images per row in each displayed grid.

    Raises:
        ValueError: if ``mode`` is neither "eval" nor "train" (previously any other value
            silently selected train mode).
    """
    if mode not in ("eval", "train"):
        raise ValueError(f"mode must be 'eval' or 'train', got {mode!r}")

    plt.rcParams["figure.figsize"] = (10, 10)
    pil_to_torch = torchvision.transforms.ToTensor()
    if mode == "eval":
        model.eval()
    else:
        model.train()

    for batch in dataloader:
        with torch.no_grad():
            batch = transform(batch)
            output = model(batch)[:, channel]
            overlays = []
            for image, heatmap in zip(batch, output):
                heatmap = heatmap.cpu()
                if show_extracted_keypoints:
                    heatmap = _render_keypoint_heatmap(
                        heatmap, batch.shape[-2:], n_keypoints, min_keypoint_pixel_distance, sigma
                    )
                # overlay_image_with_heatmap expects a leading channel dim on the heatmap.
                overlay = overlay_image_with_heatmap(image, torch.unsqueeze(heatmap, 0), alpha)
                overlays.append(pil_to_torch(overlay))
            overlayed_heatmap = torch.stack(overlays)
        grid = torchvision.utils.make_grid(overlayed_heatmap, nrow=nrow)
        imshow(grid)


In [None]:
show_results(show_extracted_keypoints=True)