# Exploring a PyTorch Model's Latent Space
Atlas can help you better understand your deep neural network's training and test behavior.
By interacting with your model's embeddings (here, its output logits) during training and evaluation, you can:
- Identify which classes, targets, or concepts your model finds easy or difficult to learn.
- Identify mislabeled datapoints.
- Spot bugs in your model implementation.
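In the map you would spot mislabeled datapoints visually, for example as an image sitting inside a cluster of a different class. The same idea can be expressed numerically: flag points where the model confidently predicts a class other than the given label. This is a minimal sketch over plain Python lists rather than tensors. The helpers `softmax` and `flag_suspect_labels` and the 0.9 threshold are hypothetical illustrations, not part of Atlas.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def flag_suspect_labels(all_logits, labels, threshold=0.9):
    """Return indices where the model confidently disagrees with the label.

    `threshold` is an arbitrary illustrative cutoff, not a recommended value.
    """
    suspects = []
    for i, (logits, label) in enumerate(zip(all_logits, labels)):
        probs = softmax(logits)
        pred = max(range(len(probs)), key=probs.__getitem__)  # index of top class
        if pred != label and probs[pred] >= threshold:
            suspects.append(i)
    return suspects


# Toy three-class example: point 1 is confidently predicted as class 0
# but labeled 2, so it is the only one flagged.
logits = [[0.1, 2.0, 0.3], [5.0, 0.0, 0.0], [0.2, 0.1, 0.3]]
labels = [1, 2, 2]
print(flag_suspect_labels(logits, labels))  # [1]
```

Points flagged this way are candidates to inspect, not proof of a bad label; a confident wrong prediction can also mean the model has learned something incorrectly.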

Atlas provides a PyTorch Lightning callback that you can plug straight into your PyTorch Lightning training scripts.
This tutorial walks you through using it to visualize the latent space of a two-layer neural network trained on MNIST.
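For a sense of scale, the network is very small. Using the layer sizes from `MNISTModel` below (784 input pixels to 10 hidden units, then 10 hidden units to 10 class logits) and counting weights plus biases, the parameter count works out as follows. The helper `linear_params` is only for illustration; it reproduces the arithmetic for a fully connected layer with a bias, which is the default for `torch.nn.Linear`.

```python
def linear_params(in_features, out_features, bias=True):
    """Trainable parameters in a fully connected layer: weights plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


# Layer sizes taken from MNISTModel: 28*28 pixels -> 10 hidden -> 10 logits
l1 = linear_params(28 * 28, 10)
l2 = linear_params(10, 10)
print(l1, l2, l1 + l2)  # 7850 110 7960
```

Each test image is therefore mapped to a 10-dimensional vector of logits, and those vectors are what Atlas lays out as the latent space.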

```
!pip install pytorch-lightning torch torchvision torchmetrics nomic
```

```python
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import torchmetrics
from nomic.pl_callbacks import AtlasEmbeddingExplorer
import nomic

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
torch.manual_seed(0)
nomic.login("YOUR_NOMIC_API_KEY")  # replace with your own Atlas API key; never commit a real key
```
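Rather than pasting an API key into the notebook, you can read it from an environment variable. This is a minimal sketch; the variable name `NOMIC_API_KEY` and the helper `get_api_key` are hypothetical choices made here, not something the nomic client requires.

```python
import os


def get_api_key(env_var="NOMIC_API_KEY"):
    """Read an API key from the environment instead of hard-coding it.

    NOMIC_API_KEY is a hypothetical variable name chosen for this sketch.
    """
    key = os.environ.get(env_var)
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable to your Atlas API key")
    return key
```

You would then call `nomic.login(get_api_key())` in place of the hard-coded string, which keeps the key out of the notebook file and its version history.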


## The Lightning Module

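```python
class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)  # 784 pixels -> 10 hidden units
        self.l2 = torch.nn.Linear(10, 10)       # 10 hidden units -> 10 class logits

    def forward(self, x):
        # Flatten each 1x28x28 image and apply the hidden layer with ReLU.
        hidden = torch.relu(self.l1(x.view(x.size(0), -1)))
        # Return raw logits: no activation on the output, since F.cross_entropy
        # expects unnormalized scores and a ReLU here would clip them at zero.
        return self.l2(hidden)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        # trainer.validate() below is given the raw test Dataset, so each batch is a
        # single (image, label) sample and batch_idx is that image's test-set index.
        x, y = batch
        logits = self(x)
        prediction = torch.argmax(logits, dim=1)[0].item()
        label = int(y)  # works whether y is a plain int or a one-element tensor

        link = f'https://s3.amazonaws.com/static.nomic.ai/mnist/eval/{label}/{batch_idx}.jpg'
        metadata = {'label': label, 'prediction': prediction, 'url': link}

        self.atlas.log(embeddings=logits, metadata=metadata)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
```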
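To make the logged data concrete, here is a plain-Python illustration of the record `validation_step` builds for one test image: the 10 logits become the point's embedding, and the label, predicted class, and image URL become its metadata. The helpers `argmax` and `build_point` and the logit values are hypothetical; only the URL pattern comes from `validation_step`.

```python
URL_TEMPLATE = "https://s3.amazonaws.com/static.nomic.ai/mnist/eval/{label}/{index}.jpg"


def argmax(values):
    """Index of the largest value (Python's max keeps the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def build_point(logits, label, index):
    """Mirror one logged point: logits as the embedding plus a metadata dict."""
    metadata = {
        "label": label,
        "prediction": argmax(logits),
        "url": URL_TEMPLATE.format(label=label, index=index),
    }
    return logits, metadata


# Made-up logits for a test image labeled 2 at index 1 in the test set
embedding, meta = build_point([0.0, 0.2, 3.1, 0.0, 0.5, 0.0, 0.0, 1.2, 0.0, 0.0], label=2, index=1)
print(meta["prediction"], meta["url"])
# 2 https://s3.amazonaws.com/static.nomic.ai/mnist/eval/2/1.jpg
```

Because the URL is part of each point's metadata, hovering over a point in the map can surface the underlying digit image alongside its label and prediction.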

## Training the model

```python
mnist_model = MNISTModel()

# Init DataLoaders from the MNIST datasets
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
test_ds = MNIST(PATH_DATASETS, train=False, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)  # reshuffle every epoch

# Initialize the Embedding Explorer 🗺️ callback
embedding_explorer = AtlasEmbeddingExplorer(max_points=20_000,
                                            name="MNIST Validation Latent Space",
                                            description="MNIST Validation Latent Space",
                                            overwrite_on_validation=True)

# Initialize a trainer
max_epochs = 10
trainer = Trainer(
    accelerator="auto",
    devices=1,  # one GPU if available, otherwise one CPU process
    max_epochs=max_epochs,
    callbacks=[TQDMProgressBar(refresh_rate=20),
               embedding_explorer],
)

# Train the model ⚡
trainer.fit(mnist_model, train_loader)
```
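For a sense of how much work `trainer.fit` does: the MNIST training split has 60,000 images, and `BATCH_SIZE` is 256 on a GPU or 64 on a CPU, trained for `max_epochs = 10`. The sketch below assumes one optimizer step per batch (no gradient accumulation) and that the final partial batch is kept, which is the `DataLoader` default (`drop_last=False`). The helper `optimizer_steps` is hypothetical.

```python
import math

TRAIN_SIZE = 60_000  # images in the MNIST training split


def optimizer_steps(batch_size, epochs, n=TRAIN_SIZE):
    """Total optimizer steps, counting the last partial batch of each epoch."""
    return math.ceil(n / batch_size) * epochs


print(optimizer_steps(256, 10))  # 2350 (235 batches per epoch on GPU)
print(optimizer_steps(64, 10))   # 9380 (938 batches per epoch on CPU)
```

The CPU run takes roughly four times as many steps because of its smaller batch size, which is worth keeping in mind when comparing runs across machines.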


## Validate the model and log the embeddings

```python
trainer.validate(mnist_model, test_ds)
```

```
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Validation: 0it [00:00, ?it/s]
2023-03-14 10:53:03.504 | INFO     | nomic.project:_create_project:946 - Creating project `MNIST Validation Latent Space` in organization `Atlas Demo`
2023-03-14 10:53:04.977 | INFO     | nomic.atlas:map_embeddings:98 - Uploading embeddings to Atlas.
100%|██████████| 10/10 [00:03<00:00,  3.10it/s]
2023-03-14 10:53:08.206 | INFO     | nomic.atlas:map_embeddings:117 - Embedding upload succeeded.
2023-03-14 10:53:09.863 | INFO     | nomic.project:create_index:1259 - Created map `MNIST Validation Latent Space` in project `MNIST Validation Latent Space`: https://atlas.nomic.ai/map/5c920000-f861-400e-ae04-9401d54c2633/41b407d8-ba49-4e8c-83ed-a943ed9bed0b
2023-03-14 10:53:09.864 | INFO     | nomic.atlas:map_embeddings:130 - MNIST Validation Latent Space: https://atlas.nomic.ai/map/5c920000-f861-400e-ae04-9401d54c2633/41b407d8-ba49-4e8c-83ed-a943ed9bed0b

[{}]
```
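The `label` and `prediction` fields logged as metadata also support a numeric check of which classes the model struggles with, complementing what you see in the map. This is a minimal sketch over plain lists; `per_class_accuracy` is a hypothetical helper, and the labels and predictions below are made up rather than taken from this run.

```python
from collections import Counter


def per_class_accuracy(labels, predictions):
    """Fraction of points of each true class that were predicted correctly."""
    totals = Counter(labels)
    correct = Counter(l for l, p in zip(labels, predictions) if l == p)
    return {c: correct[c] / totals[c] for c in sorted(totals)}


# Made-up example: class 3 is often confused, classes 1 and 8 never are
labels = [1, 1, 3, 3, 3, 8]
predictions = [1, 1, 3, 5, 8, 8]
print(per_class_accuracy(labels, predictions))
# {1: 1.0, 3: 0.3333333333333333, 8: 1.0}
```

Classes with low accuracy here are the ones whose points you would expect to see scattered or overlapping with other clusters in the map.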

## View the map

```python
embedding_explorer.map
```