In [None]:
import io
import os

import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import transforms, datasets

from neurocorgi_sdk import NeuroCorgiNet, Head4ImageNet
from neurocorgi_sdk.transforms import ToNeuroCorgiChip

In [None]:
# Run on the first CUDA GPU when the machine has one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Root of the ImageNet copy; torchvision's ImageFolder expects one sub-directory per class
# under <root>/val. Set the IMAGENET_PATH environment variable to point elsewhere instead of
# editing the notebook -- the original location is kept as the default.
imagenet_path = os.environ.get("IMAGENET_PATH", "/datasets/imagenet")

# Usual ImageNet evaluation preprocessing (resize the shorter side to 256, center-crop 224x224),
# followed by the SDK transform that adapts the tensor to the NeuroCorgi chip input format.
transform_val = transforms.Compose([transforms.Resize(256),
                                    transforms.CenterCrop(224),
                                    transforms.ToTensor(),
                                    ToNeuroCorgiChip()
                                    ])

validation_dataset = datasets.ImageFolder(f"{imagenet_path}/val", transform=transform_val)

# shuffle=False keeps the evaluation order deterministic across runs.
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False)

In [None]:
# Load the pretrained NeuroCorgi backbone weights and move the network to the selected device.
model = NeuroCorgiNet("neurocorginet_imagenet.safetensors")
model.to(device)
# This notebook only runs inference: switch to eval mode so layers such as dropout or
# batch-norm (if the network contains any) use their inference behaviour.
# Assumes NeuroCorgiNet is a torch.nn.Module (it exposes .to(device)) -- confirm.
model.eval()

In [None]:
# Load the ImageNet classification head weights and move it to the same device as the backbone.
head = Head4ImageNet("neurocorginet_head_imagenet.safetensors")
head.to(device)
# Inference only: put the head in eval mode as well (assumes a torch.nn.Module -- confirm).
head.eval()

In [None]:
# Evaluate backbone + ImageNet head on the validation split, showing the running top-1 accuracy.
pbar = tqdm(validation_loader)
val_correct = 0  # correctly classified images so far
val_seen = 0     # images evaluated so far
accuracy = 0.0

with torch.no_grad():
    for inputs, labels in pbar:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # The backbone returns four feature maps; only the last one (div32) feeds the head.
        _, _, _, div32 = model(inputs)
        out = head(div32)

        val_preds = out.argmax(dim=1)
        val_correct += (val_preds == labels).sum().item()
        # Count the images of the current batch before dividing, so the first batch does not
        # divide by zero and the ratio stays correct for any batch size.
        val_seen += labels.size(0)
        accuracy = 100.0 * val_correct / val_seen

        pbar.set_description_str(f"Torch model - ImageNet test accuracy: {accuracy:.2f}%")

In [None]:
def im_convert(tensor, scale=255.0):
    """Convert a (C, H, W) image tensor into an (H, W, C) NumPy array for ``plt.imshow``.

    Args:
        tensor: image tensor with channels first. Presumably holds values in [0, scale]
            after ``ToNeuroCorgiChip`` -- confirm against the SDK.
        scale: divisor that maps the tensor's value range onto [0, 1] (default 255).

    Returns:
        A float array of shape (H, W, C) with values clipped to [0, 1].
    """
    image = tensor.detach().cpu().numpy()
    image = image.transpose(1, 2, 0)
    # Rescale *before* clipping: clipping first saturates everything above 1 and then squashes
    # the whole picture into [0, 1/scale], which renders as an almost black image.
    image = image / scale
    return image.clip(0, 1)

In [None]:
import requests
from PIL import Image

# Sample picture of a Pembroke Welsh corgi used for a single-image sanity check below.
url = "https://media.os.fressnapf.com/cms/2020/07/ratgeber_hund_rasse_portraits_welsh-corgi-pembroke_1200x527.jpg?t=cmsimg_920"
# A timeout keeps the cell from hanging on a stalled connection; raise_for_status() surfaces
# HTTP errors (404, 5xx) instead of letting PIL choke on an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
# Decode from the downloaded body: `response.raw` bypasses requests' Content-Encoding handling
# (e.g. gzip). convert("RGB") guarantees three channels like the RGB ImageNet images.
img = Image.open(io.BytesIO(response.content)).convert("RGB")
plt.imshow(img);

In [None]:
# Apply the validation preprocessing to the downloaded picture and preview the result.
# NOTE: this rebinds `img` from the PIL image to the transformed tensor, so re-running this cell
# without re-running the download cell fails (transforms.ToTensor expects a PIL image or ndarray).
img = transform_val(img)
preview = im_convert(img)
plt.imshow(preview)

In [None]:
# Dog image expected label: 263: 'Pembroke, Pembroke Welsh corgi'
# See https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
image = img.to(device).unsqueeze(0)  # add a leading batch dimension of size 1

# Inference only: no_grad avoids building an autograd graph (and its memory) for this pass.
with torch.no_grad():
    _, _, _, div32 = model(image)
    out = head(div32)

# Top-5 scores and class indices; index 263 is expected to rank first for this image.
print(torch.topk(out.flatten(), 5))