In [None]:
import csv
import torch
from model import CNN
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms

In [None]:
n_classes = 10  # CIFAR-10 has ten categories
device = 'cpu'
model_weigths_file = "./custom_model.pt"

# Build the network, restore the trained parameters, then switch to inference mode.
# The checkpoint tensors are mapped onto `device` while being deserialised, so a
# checkpoint saved from a GPU run still loads here.
# NOTE(review): CNN(3, n_classes) presumably means (input channels, number of classes) — confirm in model.py.
model = CNN(3, n_classes)
model.load_state_dict(torch.load(model_weigths_file, map_location=device))
model = model.eval()  # nn.Module.eval() returns the module itself

In [None]:
# getting the images
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# each RGB channel to [-1, 1], matching the preprocessing used for inference below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# CIFAR-10 test split, preprocessed with the normalising `transform` above so
# images fed to the model are on the same scale as in the inference pipeline.
images_dataset = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
# NOTE(review): placeholder list; not populated anywhere in the visible notebook.
selected_images = []

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import umap
import numpy as np


def _make_umap_reducer():
    """Return a fresh, seeded 2-D UMAP reducer (one per feature space to project)."""
    return umap.UMAP(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        metric="euclidean",
        random_state=42
    )


def _scatter_by_class(ax, points_2d, point_labels, title, class_names):
    """Scatter 2-D points on `ax`, coloured by integer class label, with a labelled colour bar.

    The colour limits are pinned to [-0.5, n_classes - 0.5] so class k always maps to
    the k-th tab10 colour (with a centred colour-bar tick), even if some classes are
    absent from `point_labels`.
    """
    n_plot_classes = len(class_names)
    scatter = ax.scatter(
        points_2d[:, 0],
        points_2d[:, 1],
        c=point_labels,
        cmap="tab10",
        vmin=-0.5,
        vmax=n_plot_classes - 0.5,
        s=5
    )
    ax.set_title(title)
    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")
    colorbar = ax.figure.colorbar(scatter, ax=ax, ticks=range(n_plot_classes))
    colorbar.ax.set_yticklabels(class_names)
    return scatter


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model was loaded onto the CPU in an earlier cell; it must live on the same
# device as the input batches, otherwise the forward pass fails on CUDA machines.
model = model.to(device)

# ToTensor scales to [0, 1]; Normalize maps each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

# NOTE(review): this visualises the *training* split; use train=False for held-out data.
dataset = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform
)

loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,  # fixed order -> the same points are plotted on every run
    num_workers=4
)

# Upper bound on the number of correctly classified samples to embed and plot.
MAX_N_POINTS = 500

# computing embeddings and logits for correctly classified samples
model.eval()

all_embeddings = []
all_logits = []
all_labels = []

with torch.no_grad():
    correctly_classified = 0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # NOTE(review): assumes the CNN forward returns (embeddings, logits) — confirm in model.py.
        batch_embeddings, batch_logits = model(batch_images)
        preds = torch.argmax(batch_logits, dim=1)

        # keep only the samples the model classified correctly
        correct_mask = preds == batch_labels
        correctly_classified += correct_mask.sum().item()

        if correct_mask.any():
            all_embeddings.append(batch_embeddings[correct_mask].cpu())
            all_logits.append(batch_logits[correct_mask].cpu())
            all_labels.append(batch_labels[correct_mask].cpu())

        if correctly_classified >= MAX_N_POINTS:
            break

if not all_labels:
    raise RuntimeError("No correctly classified samples were found; nothing to visualise.")

# The final batch can overshoot MAX_N_POINTS, so trim to at most that many points.
embeddings = torch.cat(all_embeddings, dim=0)[:MAX_N_POINTS].numpy()
logits = torch.cat(all_logits, dim=0)[:MAX_N_POINTS].numpy()
labels = torch.cat(all_labels, dim=0)[:MAX_N_POINTS].numpy()

# umap computation: an independent reducer per feature space
embeddings_2d = _make_umap_reducer().fit_transform(embeddings)
logits_2d = _make_umap_reducer().fit_transform(logits)

# plotting
fig, (ax_embeddings, ax_logits) = plt.subplots(1, 2, figsize=(10, 4))
_scatter_by_class(
    ax_embeddings, embeddings_2d, labels,
    "UMAP of Embeddings (Correct Predictions)", dataset.classes
)
_scatter_by_class(
    ax_logits, logits_2d, labels,
    "UMAP of Logits (Correct Predictions)", dataset.classes
)

fig.tight_layout()
plt.show()
