# Using a Vision Transformer on CIFAR-100

## Imports

In [None]:
!pip install datasets transformers[torch] renumics-spotlight

In [None]:
import datasets
from renumics import spotlight

# Load the CIFAR-100 test split and keep only rows 0..99 of it.
ds = datasets.load_dataset("cifar100", split="test").select(range(100))
# Optional: uncomment to browse the subset in Spotlight.
#spotlight.show(ds)

In [None]:
import torch
import transformers

# Use the first CUDA device when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# One Hub checkpoint name is used for the image processor and both models.
model_name = "Ahmed9275/Vit-Cifar100"
processor = transformers.ViTImageProcessor.from_pretrained(model_name)

# Model with classification head: `infer` reads its logits.
cls_model = transformers.ViTForImageClassification.from_pretrained(model_name)
cls_model = cls_model.to(device)

# Plain ViT model from the same checkpoint: `infer` reads its last hidden state.
fe_model = transformers.ViTModel.from_pretrained(model_name)
fe_model = fe_model.to(device)

def infer(batch):
    """Predict a class and extract an embedding for every image in a batch.

    Args:
        batch: Iterable of image objects that provide ``convert("RGB")``.
            When used with ``ds.map(..., input_columns="img", batched=True)``
            this is the list of values from the ``img`` column.

    Returns:
        A dict with two NumPy arrays:
        ``"prediction"`` - index of the highest-probability class per image,
        ``"embedding"`` - the vector at position 0 (first token) of
        ``fe_model``'s last hidden state per image.
    """
    rgb_images = [img.convert("RGB") for img in batch]
    model_inputs = processor(images=rgb_images, return_tensors="pt").to(device)

    # Inference only: no gradients are needed for either forward pass.
    with torch.no_grad():
        logits = cls_model(**model_inputs).logits
        class_probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
        hidden_states = fe_model(**model_inputs).last_hidden_state
        first_token_vectors = hidden_states[:, 0].cpu().numpy()

    predicted_classes = class_probs.argmax(axis=-1)
    return {"prediction": predicted_classes, "embedding": first_token_vectors}

# Schema of the enriched dataset: every original column, plus
#   - "prediction": reuses the feature type of "fine_label",
#   - "embedding": fixed-length (768) sequence of float32 values.
features = datasets.Features(
    {
        **ds.features,
        "prediction": ds.features["fine_label"],
        "embedding": datasets.Sequence(
            feature=datasets.Value("float32"),
            length=768,
        ),
    }
)

# Run `infer` over the "img" column, two images per call, and store the
# returned columns using the schema declared above.
ds_enriched = ds.map(
    infer,
    input_columns="img",
    batched=True,
    batch_size=2,
    features=features,
)

In [None]:
# Alternative to running inference locally: load precomputed results and join
# them column-wise onto `ds`. This overwrites any `ds_enriched` computed earlier.
ds_results = datasets.load_dataset('renumics/spotlight-cifar100-enrichment', split='test')

# `ds` was reduced to a subset of the test split, while the enrichment is loaded
# in full. Column-wise concatenation (axis=1) requires equal row counts, so trim
# the enrichment to the same leading rows.
# NOTE(review): assumes enrichment rows follow the same order as the CIFAR-100
# test split and that `ds` holds its first len(ds) rows - confirm.
if len(ds_results) > len(ds):
    ds_results = ds_results.select(range(len(ds)))

ds_enriched = datasets.concatenate_datasets([ds, ds_results], axis=1)

In [None]:
# Build Spotlight's classification-debugging layout: labels come from
# "fine_label", embeddings from "embedding", and the "img" column is
# inspected as an image.
layout = spotlight.layouts.debug_classification(
    label='fine_label',
    embedding='embedding',
    inspect={'img': spotlight.dtypes.image_dtype},
)

# Open the enriched dataset in Spotlight, treating "embedding" as an Embedding column.
spotlight.show(
    ds_enriched,
    dtype={'embedding': spotlight.Embedding},
    layout=layout,
)

On Saturn Cloud:

First server: [https://cnn-tutorial-1.community.saturnenterprise.io:8000](https://cnn-tutorial-1.community.saturnenterprise.io:8000)

Second server: [https://cnn-tutorial-2.community.saturnenterprise.io:8000](https://cnn-tutorial-2.community.saturnenterprise.io:8000)