In [30]:
import torch
from torchvision.io import read_image
import torchvision

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from math import sqrt, ceil
import os
from tqdm import tqdm
from copy import deepcopy

import sahi.predict as predict

# Import YOLOv5 helper functions
os.chdir("/home/ucloud/EUMothModel")
from tutils.yolo_helpers import non_max_suppression
from tutils.models import *

In [31]:
dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
# Model definition
backbone = "efficientnet_b0"
constructor = create_extendable_model_class(backbone_model=backbone)
model = constructor(path="models/run19.state", device=device, dtype=dtype)
model.eval()

ExtendableModel_Hierarchical

In [4]:
# # Bootstrap model from other state
# rstate = parse_state_file("models/run19.state")
# masks = rstate["masks"]
# class_handles = rstate["class_handles"]

# # Initialize model with AlexNet backbone
# backbone_other = "alexnet"
# constructor_other = create_extendable_model_class(backbone_model=backbone_other)
# model = constructor_other(device=device, dtype=dtype, class_handles=class_handles, masks=masks)

# # Save model with AlexNet backbone
# model.save_to_path("models/run19_alexnet.state")

# # Load model with AlexNet backbone
# model = constructor_other(path="models/run19_alexnet.state", device=device, dtype=dtype)

## Helper functions

In [5]:
from sahi import ObjectPrediction, BoundingBox
from sahi.postprocess.combine import GreedyNMMPostprocess, LSNMSPostprocess
from sahi.prediction import visualize_object_predictions, PredictionScore

def ObjectPrediction_to_yolov5_format(obj, num_classes):
    # Extract bounding box and objectness score
    bbox = obj.bbox.to_voc_bbox()  # Convert to [x_min, y_min, x_max, y_max]
    x_center = (bbox[0] + bbox[2]) / 2
    y_center = (bbox[1] + bbox[3]) / 2
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    objectness = obj.score.value

    # One-hot encode the predicted class
    class_id = obj.category.id
    class_scores = [0] * num_classes
    class_scores[class_id] = 1

    # Combine into YOLOv5 format
    yolov5_prediction = [x_center, y_center, width, height, objectness] + class_scores

    return torch.tensor(yolov5_prediction)

def convert_to_yolov5_format(object_predictions, num_classes):
    """
    Convert a list of sahi.ObjectPrediction objects to YOLOv5 prediction format.

    Args:
    - object_predictions (List[sahi.ObjectPrediction]): List of object predictions.
    - num_classes (int): Total number of classes.

    Returns:
    - List[List[float]]: Converted predictions in YOLOv5 format.
    """
    yolov5_predictions = []

    for prediction in object_predictions:
        yolov5_predictions += [ObjectPrediction_to_yolov5_format(prediction, num_classes)]    

    return torch.stack(yolov5_predictions)

def convert_to_yolov5_batch_format(yolov5_list):
    """
    Convert a list of YOLOv5 predictions to a batched YOLOv5 prediction.

    Args:
    - yolov5_list (List[List[float]]): List of YOLOv5 predictions.

    Returns:
    - torch.Tensor: Batched YOLOv5 prediction.
    """
    return torch.concat(yolov5_list)

def yolov5_to_ObjectPrediction(yp):
    """
    Convert a YOLOv5 style output to a list of sahi.ObjectPrediction objects.

    Args:
    - yp (torch.Tensor): YOLOv5 style output. Torch.Tensor of shape (n_predictions, 5 + num_classes).
    """

    object_predictions = []

    for pred in yp:
        x, y, w, h, obj = pred[:5]
        class_scores = pred[5:]

        # Convert to sahi.BoundingBox
        x_min = x - w / 2
        y_min = y - h / 2
        x_max = x + w / 2
        y_max = y + h / 2

        # Convert to sahi.Category
        category_id = class_scores.argmax().item()

        this_obj = ObjectPrediction(
            bbox=[x_min.item(), y_min.item(), x_max.item(), y_max.item()],
            category_id=category_id,
            category_name="salient_moth",
            score=obj.item()
        )

        object_predictions += [this_obj]

    return object_predictions

def plot_yolov5(yp, image, match_threshold=0.5, match_metric="IOS",**kwargs):
    object_predictions = yolov5_to_ObjectPrediction(yp)
    postprocessor = GreedyNMMPostprocess(match_threshold=match_threshold, match_metric=match_metric)
    object_predictions = postprocessor(object_predictions)
    image = np.ascontiguousarray(read_image(image).numpy().transpose(1, 2, 0), dtype=np.uint8)

    visualize_object_predictions(
        object_prediction_list=object_predictions,
        image=image,
        **kwargs
    )

def plot_list(l, image, mult=None, slice_sizes=None, conf_threshold=0.75, match_threshold=0.5, match_metric="IOS",**kwargs):
    object_predictions = []
    if mult is None:
        mult = [1] * len(l)
    for i, ol in enumerate(l):
        tl = ol.object_prediction_list
        tl = [op for op in tl if (op.score.value * mult[i]) >= conf_threshold]
        if slice_sizes is not None:
            tslice_area = slice_sizes[i] ** 2
            tl = [op for op in tl if op.bbox.area <= (tslice_area * 0.5)]

        object_predictions += tl
    postprocessor = LSNMSPostprocess(match_threshold=match_threshold, match_metric=match_metric)
    object_predictions = postprocessor(object_predictions)
    image = np.ascontiguousarray(read_image(image).numpy().transpose(1, 2, 0), dtype=np.uint8)

    visualize_object_predictions(
        object_prediction_list=object_predictions,
        image=image,
        **kwargs
    )

    out = l[0]
    out.object_prediction_list = object_predictions
    return out

def combine_and_nms(l, mult, slice_sizes, conf_threshold, match_metric, match_threshold):
    object_predictions = []
    if mult is None:
        mult = [1] * len(l)
    for i, ol in enumerate(l):
        tl = ol.object_prediction_list
        tl = [op for op in tl if (op.score.value * mult[i]) >= conf_threshold]
        if slice_sizes is not None:
            tslice_area = slice_sizes[i] ** 2
            tl = [op for op in tl if op.bbox.area <= (tslice_area * 0.5)]

        object_predictions += tl
    postprocessor = LSNMSPostprocess(match_threshold=match_threshold, match_metric=match_metric)
    object_predictions = postprocessor(object_predictions)

    out = l[0]
    out.object_prediction_list = object_predictions
    return out

## Augmentations

In [6]:
import albumentations as A

augment = A.Compose([
    A.Cutout(num_holes=4, max_h_size=32, max_w_size=32, p=0.5),
    A.GaussNoise(var_limit=(0.01, 0.05), p=0.5),
    A.Rotate(limit=45, p=1.0, rotate_method="ellipse"),
    A.RandomBrightnessContrast(p=0.5),
    A.VerticalFlip(p=0.5)
])



## Model and hyperparameters

In [7]:
# ## Hyperparameters
skip = 0 # Number of batches to skip (used to resume script from a specific batch)
batch_size = 8 # Batch size for chunked loading of images
n_batches = 8 # Number of batches to load
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype=torch.bfloat16
real_class_index = 0 # The real class index of the species to fine-tune on (6 for the initial model; insectGBIF-1280m6.pt)
inference_size = 1280 # The size of the images to run inference on (1280 for the initial model; insectGBIF-1280m6.pt)
image_directory = "/home/ucloud/testCrops/Set5/"

model_weights = "insect_iter7-1280m6.pt"
pass

In [8]:
from sahi.predict import get_sliced_prediction
from sahi import AutoDetectionModel

detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov5',
    model_path="models/" + model_weights,
    config_path="models/custom5m.yaml",
    confidence_threshold=0.25,
    device=device
)

11/20/2023 11:32:24 - INFO - numexpr.utils -   Note: NumExpr detected 10 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
11/20/2023 11:32:24 - INFO - numexpr.utils -   NumExpr defaulting to 8 threads.


In [9]:
images = [image_directory + os.sep + i for i in os.listdir(image_directory) if i.endswith(".jpg")]

In [10]:
def get_image(path, size=inference_size):
    image = read_image(path, mode=torchvision.io.image.ImageReadMode.RGB)
    if size is not None:
        image = torch.functional.F.interpolate(image.unsqueeze(0), size=size, mode="bilinear")
    else:
        image = image.unsqueeze(0)
    image = image.to(device,dtype=dtype)
    return image

In [11]:
tbatch = torch.cat([get_image(i, 256) for i in images], dim=0)

In [12]:
torch.cuda.empty_cache()
# print(torch.cuda.memory_summary())
torch.cuda.memory_allocated() // 10 ** 6

211

In [13]:
with torch.no_grad():
    predictions = model(tbatch)

In [14]:
import types

def augment_float_tensor(x):
    tx = x.clone()
    tx = tx.permute(1, 2, 0).cpu().numpy() / 255
    tx = augment(image=tx)['image']
    tx = torch.tensor(tx).permute(2, 0, 1).to(device, dtype=dtype) * 255
    return tx

def do_define_call_with_augment(model):
    old_call = model.__call__
    _dtype = model.parameters().__next__().dtype
    _device = model.parameters().__next__().device

    def call_with_augment(self, x, do_augment=False, n=3, **kwargs):
        if do_augment:
            xs = []
            for tx in x:
                xs += [torch.stack([tx.clone()] + [augment_float_tensor(tx.float()).to(_dtype) for _ in range(n)])]
            xs = torch.stack(xs)
            xs = xs.to(device=_device, dtype=_dtype)
            if len(xs.shape) > 4:
                augment_preds = [old_call(xsi, **kwargs) for xsi in xs]
                if self.return_embeddings:
                    augment_embeddings = torch.stack([ap[1].mean(0) for ap in augment_preds])
                    augment_preds = [ap[0] for ap in augment_preds]
                combine_pred = [torch.stack([torch.stack([ap[i][j] for j in range(n + 1)]).log_softmax(1) for ap in augment_preds]) for i in range(len(augment_preds[0]))]
                single_pred = [torch.stack([(j.float().logsumexp(0) - torch.log(torch.tensor(j.shape[0])).to(j.dtype).to(j.device)).log_softmax(0) for j in combine_pred[i]]) for i in range(len(combine_pred))]
            else:
                augment_preds = old_call(xs, **kwargs)
                if self.return_embeddings:
                    augment_embeddings = augment_preds[1].mean(0)
                    augment_preds = augment_preds[0]
                combine_pred = [torch.cat([ap[i] for ap in augment_preds]).log_softmax(0) for i in range(len(augment_preds[0]))]
                single_pred = [(j.float().logsumexp(0) - torch.log(torch.tensor(j.shape[0])).to(j.dtype).to(j.device)).log_softmax(0) for j in combine_pred]
            if self.return_embeddings:
                return single_pred, augment_embeddings
            else:
                return single_pred
        else:
            return old_call(x, **kwargs)

    model.__call__ = types.MethodType(call_with_augment, model)
    return model

model = do_define_call_with_augment(model)

In [15]:
weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
class create_image_preprocessing:
    def __init__(self, weights):
        self.transform = weights.transforms.func(crop_size=256)
        self.isize = self.transform.resize_size[0]
        self.mean = self.transform.mean
        self.std = self.transform.std

    def __call__(self, images):
        """Preprocess images for EfficientNet."""
        images = torchvision.transforms.Resize((self.isize, self.isize), antialias=True)(images)
        return self.transform(images)

image_preprocessing = create_image_preprocessing(weights)

In [16]:
def shannon_entropy(t):
    """
    This function computes the row-wise Shannon entropy of a tensor.
    
    Args:
    - t (torch.Tensor): Input tensor of shape (n, m).

    Returns:
    - torch.Tensor: Row-wise Shannon entropy of shape (n,).
    """

    return (-t * torch.log2(t)).sum(1)

In [17]:
import re

def bbox_predict(model, output, image, epistemic_threshold=2, background_threshold=0.3, transform=None):
    _dtype = model.parameters().__next__().dtype
    _device = model.parameters().__next__().device

    # print(f'dtype: {_dtype}')
    # print(f'device: {_device}')

    # Convert SAHI output to to torch.Tensor of shape (n, 4) with xmin, ymin, xmax, ymax for each bounding box
    tensor_bboxes = torch.tensor([list(i.bbox.__dict__.values())[:4] for i in output.object_prediction_list]).round().float().to(device=_device)
    tensor_image = get_image(image, size=None).float().to(device=_device)

    untransformed_inputs = torchvision.ops.roi_align(tensor_image, [tensor_bboxes], output_size=(256, 256))
    if transform is not None:
        inputs = transform(untransformed_inputs)
    inputs = inputs.to(device=_device, dtype=_dtype)
    # print(f'inputs - shape: {inputs.shape} - dtype: {inputs.dtype} - device: {inputs.device}')

    model_uses_embeddings = model.toggle_embeddings(True)
    with torch.no_grad():
        crop_pred, crop_embeddings = model(inputs)

    max_scores = torch.stack([crop_pred[i].max(1).values for i in range(3)])
    score_multiplier = max_scores.T.diff(1).exp()
    best_level = [torch.where(i)[0].max().item() + 1 if any(i) else 0 for i in (score_multiplier > epistemic_threshold)]

    predicted_class, predicted_class_name, prediction_confidence = [], [], []
    for i in range(len(inputs)):
        this_scores = crop_pred[best_level[i]][i]
        this_predicted_class = int(this_scores.argmax(0).int().cpu().numpy())
        this_predicted_class_name = model.class_handles["idx_to_class"][best_level[i]][this_predicted_class]
        this_prediction_confidence = this_scores.max(0).values.exp().float().cpu().numpy()

        predicted_class += [this_predicted_class]
        predicted_class_name += [this_predicted_class_name] if this_prediction_confidence > background_threshold else ["background"]
        prediction_confidence += [this_prediction_confidence]

    # Sort by class and confidence
    order = np.lexsort((prediction_confidence, predicted_class))
    untransformed_inputs = untransformed_inputs[order]
    predicted_class = [predicted_class[i] for i in order]
    predicted_class_name = [predicted_class_name[i] for i in order]
    prediction_confidence = [prediction_confidence[i] for i in order]
    crop_embeddings = crop_embeddings[order]
    best_level = [best_level[i] for i in order]

    # Restore model state
    model.toggle_embeddings(model_uses_embeddings)

    # Return the cropped tensors, predicted class level, predicted class index, predicted class name, and prediction confidence
    return untransformed_inputs, best_level, predicted_class, predicted_class_name, prediction_confidence, crop_embeddings

def plot_batch_predict(images, labels, conf, output_file="sahi/plots/crops.png", mixed=True):
    if mixed:
        images_background = torch.stack([i for i, l in zip(images, labels) if l == "background"])
        labels_background = [l for l in labels if l == "background"]
        conf_background = [c for c, l in zip(conf, labels) if l == "background"]
        plot_batch_predict(images_background, labels_background, conf_background, output_file=output_file.replace(".png", "_background.png"), mixed=False)

        images_salient = torch.stack([i for i, l in zip(images, labels) if l != "background"])
        labels_salient = [l for l in labels if l != "background"]
        conf_salient = [c for c, l in zip(conf, labels) if l != "background"]
        plot_batch_predict(images_salient, labels_salient, conf_salient, output_file=output_file.replace(".png", "_salient.png"), mixed=False)
    else:
        # plot crops for debugging
        ncol = int(sqrt(len(images)))
        nrow = ceil(len(images) / ncol)
        assert ncol * nrow >= len(images)

        fig, axs = plt.subplots(nrow, ncol, figsize=(ncol * 2, nrow * 2))

        for ax, crop, p, s in zip(axs.flatten(), images.int().cpu(), labels, conf):
            ax.imshow(crop.cpu().numpy().transpose(1, 2, 0))
            p = re.search(r"^[^_]+(_[^_]+){0,1}", p).group(0) if p != "background" else p
            p = p.replace("_", "\n")
            ax.set_title(f"{p}\n({100 * s:.1f}%)")
            ax.axis("off")
        # Set unused axes to invisible
        for ax in axs.flatten()[len(images):]:
            ax.axis("off")
            ax.set_visible(False)

        plt.tight_layout()
        plt.savefig(output_file, dpi=300)
        plt.close()

In [18]:
slice_sizes = np.arange(256, 512 + 1, 64).tolist()

pbar = tqdm(enumerate([images[0]]), total=len(images))

for image_idx, this_image in pbar:
    pbar.set_description(f"Processing image {image_idx + 1} of {len(images)} ({this_image.split(os.sep)[-1]})")
    torch.cuda.empty_cache()
    # Perform localization on the image
    these_predictions = [
        get_sliced_prediction(
        this_image,
        detection_model,
        slice_height=i,
        slice_width=i,
        perform_standard_pred=False,
        postprocess_type="GREEDYNMM",
        postprocess_match_threshold=1,
        overlap_height_ratio=0.25,
        overlap_width_ratio=0.25,
        verbose=0
        )
        for i in slice_sizes
    ]

    # Combine predictions
    this_combined_prediction = combine_and_nms(these_predictions, None, slice_sizes, 0.8, "IOU", 0.1)

    # Perform bounding box classification
    crops, plevel, pclass, pclass_name, pconfidence, embeddings = bbox_predict(model, this_combined_prediction, this_image, background_threshold=0.2, transform=lambda x : image_preprocessing(x / 255.0))

    # Plot the crops
    this_image_name = this_image.split(os.sep)[-1].split(".")[0]
    plot_batch_predict(crops, pclass_name, pconfidence, output_file=f"sahi/inference/{this_image_name}.png")

  0%|          | 0/20 [00:00<?, ?it/s]

Processing image 1 of 20 (SS2 - 20230626030000-167-snapshot.jpg):   5%|▌         | 1/20 [00:33<10:42, 33.80s/it]


In [19]:
import matplotlib as mpl

plt.close()

fig, axs = plt.subplots(6, 6, figsize=(24, 24))

for i, (ax, crop, p, s) in enumerate(zip(axs.flatten(), crops[:36].int().cpu(), pclass_name, pconfidence)):
    ax.imshow(crop.cpu().numpy().transpose(1, 2, 0))
    p = re.search(r"^[^_]+(_[^_]+){0,1}", p).group(0) if s > 0.3 else "background"
    p = p.replace("_", "\n")
    ax.set_title(f"{p}\n({100 * s:.1f}%) - {i}")
    ax.axis("off")

plt.tight_layout()
plt.savefig("sahi/inference/crops.png", dpi=300)

In [None]:
orig_do_augment = model.toggle_embeddings(True)

with torch.no_grad():
    test_crop = crops[5].clone()

    test_crop = test_crop / 255.0
    test_crop = image_preprocessing(test_crop)
    tout, temb = model.__call__(test_crop.unsqueeze(0), do_augment=True, n = 100)

model.toggle_embeddings(orig_do_augment)

In [47]:
import time

orig_do_augment = model.toggle_embeddings(True)

with torch.no_grad():
    test_crop = crops.clone()

    test_crop = test_crop / 255.0
    test_crop = image_preprocessing(test_crop)
    start_time = time.time()
    # tout, temb = model(test_crop.to(dtype))
    tout, temb = model.__call__(test_crop, do_augment=True, n = 25)
    end_time = time.time()

print(f"Time taken: {end_time - start_time:.2f} seconds")

model.toggle_embeddings(orig_do_augment)

[2023-11-20 11:47:41,610] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _forward_impl
[2023-11-20 11:47:43,448] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-11-20 11:47:48,222] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 2


BackendCompilerFailed: debug_wrapper raised RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/compile-ptx-src-de44af, line 151; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 151; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 153; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 153; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 155; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 155; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 157; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 157; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 159; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 159; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 161; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 161; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 163; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 163; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 165; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 165; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 167; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 167; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 169; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 169; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 171; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 171; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 173; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 173; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 175; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 175; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 177; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 177; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 179; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 179; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 181; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 181; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 225; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 225; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 227; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 227; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 229; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 229; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 231; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 231; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 233; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 233; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 235; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 235; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 237; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 237; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 239; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 239; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 241; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 241; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 243; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 243; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 245; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 245; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 247; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 247; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 249; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 249; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 251; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 251; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 253; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 253; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 255; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 255; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 291; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 291; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 293; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 293; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 295; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 295; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 297; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 297; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 299; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 299; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 301; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 301; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 303; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 303; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 305; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 305; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 307; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 307; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 309; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 309; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 311; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 311; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 313; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 313; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 315; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 315; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 317; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 317; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 319; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 319; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 321; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 321; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 357; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 357; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 359; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 359; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 361; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 361; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 363; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 363; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 365; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 365; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 367; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 367; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 369; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 369; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 371; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 371; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 373; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 373; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 375; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 375; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 377; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 377; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 379; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 379; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 381; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 381; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 383; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 383; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 385; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 385; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 387; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 387; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 423; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 423; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 425; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 425; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 427; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 427; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 429; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 429; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 431; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 431; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 433; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 433; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 435; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 435; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 437; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 437; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 439; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 439; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 441; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 441; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 443; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 443; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 445; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 445; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 447; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 447; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 449; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 449; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 451; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 451; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 453; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 453; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 667; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 667; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 669; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 669; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 671; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 671; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 673; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 673; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 675; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 675; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 677; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 677; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 679; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 679; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 681; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-de44af, line 681; error   : Feature 'cvt with .bf16' requires .target sm_80 or higher
ptxas fatal   : Ptx assembly aborted due to errors


Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True


In [22]:
plt.close()

plt.figure(figsize=(12, 6))
plt.imshow(tout[0].exp().float().cpu().numpy(), interpolation="nearest", cmap="magma", norm=mpl.colors.LogNorm())
plt.gca().set_aspect("auto")

plt.colorbar()

plt.tight_layout()
plt.savefig("sahi/plots/test1.png", dpi=300)
plt.close()

In [23]:
plt.close()

plt.figure(figsize=(12, 6))
plt.scatter(*torch.pca_lowrank(tout[0].float(), 2)[0].cpu().numpy().transpose())
plt.gca().set_aspect("equal")

plt.tight_layout()
plt.savefig("sahi/plots/test2.png", dpi=300)
plt.close()

In [24]:
[f'{model.class_handles["idx_to_class"][0][i.argmax(0).int().item()]} ({i.max().exp().item() * 100:.1f}%)' for i in tout[0]]

['Thaumetopoea_processionea_Linnaeus_1758 (8.4%)',
 'Thyatira_batis_batis (17.0%)',
 'Thaumetopoea_processionea_Linnaeus_1758 (12.8%)',
 'Thaumetopoea_processionea_Linnaeus_1758 (20.2%)',
 'Thyatira_batis_batis (17.8%)',
 'Thyatira_batis_batis (11.5%)',
 'Thaumetopoea_processionea_Linnaeus_1758 (15.5%)',
 'Thyatira_batis_batis (11.5%)',
 'Thyatira_batis_batis (15.4%)',
 'Thaumetopoea_processionea_Linnaeus_1758 (17.6%)']

In [25]:
# crop_ord = torch.pca_lowrank(test_crop_inf[1].float(), 2)[0]
# Try umap for testing
import umap
import matplotlib as mpl
crop_ord = umap.UMAP(n_components=2, metric = "euclidean", min_dist = 0.05, n_neighbors = 15).fit_transform(embeddings.float().cpu())
crop_ord = torch.tensor(crop_ord)

plt.figure(figsize=(5, 5))
plt.scatter(*crop_ord.cpu().T)
plt.savefig("sahi/plots/crop_pca.png", dpi=300)

# Plot each crop at the PCA coordinates, in a sort of "scatterplot" of the crops
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

crop_ord_screen_coords = crop_ord.clone()
crop_ord_screen_coords -= crop_ord_screen_coords.min(0, keepdim=True).values
crop_ord_screen_coords /= crop_ord_screen_coords.max(0, keepdim=True).values

plt.figure(figsize=(30, 30))
ax = plt.gca()

# define colormap for classes
cmap = mpl.colormaps['tab20']
norm = mpl.colors.Normalize(vmin=0, vmax=len(set(pclass_name)))
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)

for i, (x, y) in enumerate(crop_ord_screen_coords.cpu()):
    image = crops[i].int().cpu().numpy().transpose(1, 2, 0)
    conf = pconfidence[i].item()  # Prediction confidence

    # Adjust the alpha of the image based on the confidence
    im = OffsetImage(image, zoom=0.25, alpha=1)  # Adjust zoom as needed
    box_color = list(cmap.to_rgba(pclass[i])) if pclass_name[i] != "background" else [0, 0, 0, 0]
    box_color[3] = conf  # Adjust alpha based on confidence

    # Create an annotation box with adjusted alpha for the border
    bboxprops = dict(edgecolor=box_color, linewidth=3)  # Add confidence to the RGBA tuple
    ab = AnnotationBbox(im, (x, y), xycoords='data', frameon=True, pad=0.1, boxcoords="offset points", bboxprops=bboxprops)
    ax.add_artist(ab)

    # Extract and format the box label
    box_label = re.search(r"^[^_]+(_[^_]+){0,1}", pclass_name[i]).group(0)
    box_label = box_label.replace("_", "\n")

    # Adjust these parameters to position the annotation above the box
    text_offset = 20  # Adjust this value as needed
    # Adjust the alpha of the text and its background based on confidence
    if pclass_name[i] != "background":
        ax.annotate(box_label, (x, y), xycoords='data', xytext=(0, text_offset), textcoords='offset points', ha="center", va="bottom", fontsize=12, color="white", bbox=dict(boxstyle="round", fc=box_color, ec="none", alpha=conf))

plt.savefig("sahi/plots/crop_pca_scatter.png", dpi=300)
plt.close()