# Exploring a PyTorch Model's Latent Space
Atlas can help you better understand the training and test characteristics of your deep neural networks.
By interacting with your model's embeddings (logits) during training and evaluation, you can:
- Identify which classes, targets, or concepts your model learns easily and which it struggles with.
- Identify mislabeled datapoints.
- Spot bugs/errors in your model implementation.

Atlas has a PyTorch Lightning hook that you can plug straight into your PyTorch Lightning training scripts.
This tutorial walks you through using it to visualize the latent space of a two-layer neural network trained on MNIST.

In [1]:
!pip install pytorch-lightning torch torchvision torchmetrics nomic



In [2]:
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import torchmetrics
from nomic.pl_callbacks import AtlasEmbeddingExplorer
import nomic

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
torch.manual_seed(0)
# Uncomment to authenticate with Atlas before logging embeddings
# nomic.login()


<torch._C.Generator at 0x7f155c63f1f0>

## The Lightning Module

In [3]:

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.l2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        return torch.relu(self.l2(torch.relu(self.l1(x.view(x.size(0), -1)))))

    def training_step(self, batch, batch_nb):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

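    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # Predicted class for the first sample in the batch
        prediction = torch.argmax(logits, dim=1)[0].item()

        # URL of the image for this datapoint, keyed by label and index
        link = f'https://s3.amazonaws.com/static.nomic.ai/mnist/eval/{y}/{batch_idx}.jpg'
        metadata = {'label': y, 'prediction': prediction, 'fetched_datum_image_resource_url': link}

        # Log the logits as this point's embedding, together with its metadata
        self.atlas.log(embeddings=logits, metadata=metadata)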

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
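
To make the shapes in `validation_step` concrete, here is a minimal sketch of what a single forward pass produces. It assumes each validation item is one MNIST-shaped image of shape `(1, 28, 28)`, which is what `transforms.ToTensor()` yields. `TwoLayerNet` is a hypothetical stand-in for `MNISTModel`, written as a plain `torch.nn.Module` so the snippet runs without Lightning, Atlas, or the dataset download, and the input is random noise rather than a real digit.

```python
import torch


class TwoLayerNet(torch.nn.Module):  # hypothetical stand-in for MNISTModel
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.l2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten (N, 28, 28) to (N, 784)
        return torch.relu(self.l2(torch.relu(self.l1(x))))


torch.manual_seed(0)
net = TwoLayerNet()

# One MNIST-shaped sample filled with random values (assumption: no real data)
sample = torch.rand(1, 28, 28)

with torch.no_grad():
    logits = net(sample)

# Same expression as in validation_step: class index of the first row
prediction = torch.argmax(logits, dim=1)[0].item()

print(logits.shape)  # torch.Size([1, 10])
print(prediction)    # an int between 0 and 9
```

The 10-dimensional `logits` row is the vector that gets logged as the point's embedding, so each point in the resulting map corresponds to one such row.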

## Training the model

In [4]:
mnist_model = MNISTModel()

# Init DataLoader from MNIST Dataset
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

# Number of training epochs
max_epochs = 10

# Initialize the Embedding Explorer 🗺️ hook
embedding_explorer = AtlasEmbeddingExplorer(max_points=20_000,
                                            name="MNIST Validation Latent Space",
                                            description="MNIST Validation Latent Space",
                                            overwrite=True)

# Initialize a trainer with the explorer attached as a callback
trainer = Trainer(
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=max_epochs,
    callbacks=[TQDMProgressBar(refresh_rate=20),
               embedding_explorer],
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type               | Params
------------------------------------------------
0 | l1       | Linear             | 7.9 K 
1 | l2       | Linear             | 110   
2 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
8.0 K     Trainable params
0         Non-trainable params
8.0 K     Total params
0.032     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## Validate the model and log the embeddings

In [5]:
trainer.validate(mnist_model, test_ds)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: 0it [00:00, ?it/s]

2023-03-14 00:14:57.969 | INFO     | nomic.project:__init__:856 - Found existing project `MNIST Validation Latent Space` in organization `andriy`. Clearing it of data by request.
2023-03-14 00:14:58.788 | INFO     | nomic.project:_create_project:946 - Creating project `MNIST Validation Latent Space` in organization `andriy`
2023-03-14 00:14:59.620 | INFO     | nomic.atlas:map_embeddings:98 - Uploading embeddings to Atlas.
100%|██████████| 10/10 [00:01<00:00,  6.16it/s]
2023-03-14 00:15:01.252 | INFO     | nomic.atlas:map_embeddings:117 - Embedding upload succeeded.
2023-03-14 00:15:02.181 | INFO     | nomic.project:create_index:1259 - Created map `MNIST Validation Latent Space` in project `MNIST Validation Latent Space`: https://staging-atlas.nomic.ai/map/ca7de2b6-d039-4357-a8f8-52acde7eb804/64f90fe2-1f8c-48ed-83b8-1a5594914b99
2023-03-14 00:15:02.183 | INFO     | nomic.atlas:map_embeddings:130 - MNIST Validation Latent Space: https://staging-atlas.nomic.ai/map/ca7de2b6-d039-4357-a8f8-

[{}]
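
Once each validation point carries a `label` and a `prediction`, the disagreements that stand out visually in the map can also be checked offline. The following is a minimal sketch and not part of the Atlas API: the `records` list is made-up sample data shaped like the `metadata` dictionaries logged in `validation_step`, and the helper names `error_rate_by_class` and `disagreements` are hypothetical.

```python
from collections import Counter

# Made-up records with the same 'label' and 'prediction' keys as the logged metadata
records = [
    {"label": 3, "prediction": 3},
    {"label": 5, "prediction": 3},
    {"label": 5, "prediction": 5},
    {"label": 9, "prediction": 4},
    {"label": 9, "prediction": 4},
]


def error_rate_by_class(records):
    """Fraction of points per true label whose prediction differs from the label."""
    totals = Counter(r["label"] for r in records)
    errors = Counter(r["label"] for r in records if r["label"] != r["prediction"])
    return {label: errors[label] / totals[label] for label in sorted(totals)}


def disagreements(records):
    """Points where the model and the label disagree."""
    return [r for r in records if r["label"] != r["prediction"]]


rates = error_rate_by_class(records)
suspects = disagreements(records)

print(rates)          # {3: 0.0, 5: 0.5, 9: 1.0}
print(len(suspects))  # 3
```

Classes with a high error rate are the ones the model has difficulty learning. The individual disagreements are points where either the model or the label is wrong, so they are a reasonable place to start looking for mislabeled datapoints.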


## View the map

In [6]:
embedding_explorer.map