# Simple Autoencoder

In [1]:
# https://pytorch-lightning.readthedocs.io/en/stable/model/train_model_basic.html

# XLA Configuration

Before starting, tell the Python kernel to use the TPU accelerator by setting the `XRT_TPU_CONFIG` environment variable to the address of the host process that serves the TPU cores on your VM.<br> `XLA_USE_BF16=1` instructs PyTorch/XLA to use the bfloat16 format instead of float32 for floating-point tensors, which generally improves TPU performance at the cost of some numerical precision.

In [3]:
# Tell torch_xla where the local TPU runtime service is listening.
%env XRT_TPU_CONFIG=localservice;0;localhost:51011
# Use bfloat16 instead of float32 for floating-point tensors on the TPU.
%env XLA_USE_BF16=1

env: XRT_TPU_CONFIG=localservice;0;localhost:51011
env: XLA_USE_BF16=1


In [4]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

# XLA Libraries
Here we import the PyTorch/XLA (`torch_xla`) modules needed to run on the TPU.

In [5]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as ploader
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu

# Create a Pytorch Lightning Module
Below is a simple autoencoder: a fully connected `Encoder` and `Decoder` (plain `nn.Module`s) combined in a `LightningModule`.

In [6]:
class Encoder(nn.Module):
    """MLP that compresses a flattened image into a small latent vector.

    The defaults reproduce the original 784 -> 64 -> 3 architecture used for
    28x28 MNIST images.

    Args:
        in_features: Size of the flattened input (28 * 28 for MNIST).
        hidden_features: Width of the single hidden layer.
        latent_features: Size of the latent code.
    """

    def __init__(self, in_features=28 * 28, hidden_features=64, latent_features=3):
        super().__init__()
        # The attribute name ``l1`` is kept so existing state_dict keys are unchanged.
        self.l1 = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, latent_features),
        )

    def forward(self, x):
        """Map ``x`` of shape ``(batch, in_features)`` to ``(batch, latent_features)``."""
        return self.l1(x)


class Decoder(nn.Module):
    """MLP that maps a latent vector back to a flattened image.

    The defaults reproduce the original 3 -> 64 -> 784 architecture used for
    28x28 MNIST images.

    Args:
        latent_features: Size of the latent code (must match the encoder output).
        hidden_features: Width of the single hidden layer.
        out_features: Size of the flattened reconstruction (28 * 28 for MNIST).
    """

    def __init__(self, latent_features=3, hidden_features=64, out_features=28 * 28):
        super().__init__()
        # The attribute name ``l1`` is kept so existing state_dict keys are unchanged.
        self.l1 = nn.Sequential(
            nn.Linear(latent_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        """Map ``(batch, latent_features)`` to ``(batch, out_features)``; no output activation."""
        return self.l1(x)

In [7]:
class LitAutoEncoder(pl.LightningModule):
    """Autoencoder that reconstructs flattened images and reports the MSE loss.

    Args:
        encoder: Module mapping a flattened image batch to a latent code.
        decoder: Module mapping a latent code back to a flattened image batch.
        lr: Learning rate for the Adam optimizer (default 1e-3).
    """

    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr

    def forward(self, batch, batch_idx, device=None):
        """Compute the reconstruction loss for one ``(images, labels)`` batch.

        Unlike the usual LightningModule convention, ``forward`` returns the
        training loss rather than the reconstruction.

        Args:
            batch: ``(x, y)`` pair. ``x`` is flattened to ``(batch, -1)``;
                ``y`` (the labels) is ignored.
            batch_idx: Index of the batch; unused, kept for the call signature.
            device: Optional device to move ``x`` to. Leave as ``None`` when
                the loader already yields batches on the right device.

        Returns:
            MSE loss between the reconstruction and the flattened input.
        """
        x, _ = batch  # labels are not needed to train an autoencoder
        if device is not None:
            x = x.to(device)
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return F.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Lightning training hook: delegates to ``forward`` and returns the loss."""
        return self(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Distributed Sampling
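To train in parallel, we set up a `DistributedSampler` from `torch.utils.data.distributed`. Passed to the `DataLoader`, it gives each process its own subset of the dataset. The sampler shards the data using the world size and rank it is given, so `xm.xrt_world_size()` and `xm.get_ordinal()` must be evaluated inside the processes started by `xmp.spawn` for each TPU core to receive a different subset.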

In [8]:
from torch.utils.data.distributed import DistributedSampler

# Download MNIST once, into the current working directory, as tensors.
train_dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())

# NOTE(review): this cell runs in the notebook (parent) process, before
# xmp.spawn, so the world size and ordinal read here are presumably 1 and 0 and
# this loader is not sharded across TPU cores. For multi-core training the
# sampler should be built inside the spawned function instead.
train_sampler = DistributedSampler(
    train_dataset,
    shuffle=False,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
)

train_loader = DataLoader(
    dataset=train_dataset,
    sampler=train_sampler,
    batch_size=128,
    num_workers=0,
)

In [11]:
# Build a model and optimizer in the notebook process (e.g. to inspect the
# architecture). NOTE(review): train_model() constructs its own model and
# optimizer, so these globals are not used by the training run.
autoencoder = LitAutoEncoder(encoder=Encoder(), decoder=Decoder())
optimizer = autoencoder.configure_optimizers()

In [None]:
# Intentionally no `xm.xla_device()` call here: acquiring the XLA device in the
# notebook (parent) process before `xmp.spawn(..., start_method='fork')` can
# initialise the TPU runtime for a single device, after which the forked
# replicas fail with "Cannot replicate if number of devices (1) is different
# from 8". Each replica acquires its own device inside `train_model()`.

# Training Loop Changes
Below is the training loop. The most important steps here are:
1. Wrapping the model with `xmp.MpModelWrapper`
2. Initializing our XLA device with `xm.xla_device()`
3. Sending our model to the device.
4. Initializing parallel data loading with `ploader.ParallelLoader` and taking this core's iterator from `per_device_loader(device)`.
5. Replacing `optimizer.step()` lines with `xm.optimizer_step(optimizer)`

In [27]:
def _make_train_loader(dataset, batch_size):
    """Build a DataLoader over this replica's shard of ``dataset``.

    The world size and ordinal are read when this is called, so call it from
    inside the process started by ``xmp.spawn``.
    """
    from torch.utils.data.distributed import DistributedSampler

    sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=0)


def train_model(batch_size=128, seed=0):
    """Train one ``LitAutoEncoder`` replica for one epoch on this process's XLA device.

    Call this from the function passed to ``xmp.spawn``. Everything that
    depends on the replica is created here, inside the spawned process: the
    XLA device, the data shard, the model and its optimizer.

    Args:
        batch_size: Per-replica batch size (previously hard-coded to 128).
        seed: Seed applied before the model is built, so every replica starts
            from identical weights. ``xm.optimizer_step`` all-reduces gradients
            across replicas, which keeps them in sync only if they start equal.
    """
    device = xm.xla_device()
    train_loader = _make_train_loader(train_dataset, batch_size)

    torch.manual_seed(seed)
    wrapped_model = xmp.MpModelWrapper(LitAutoEncoder(Encoder(), Decoder()))
    model = wrapped_model.to(device)
    # Build the optimizer only after the move. CPU and XLA tensors are not
    # shallow-copy compatible, so ``.to(device)`` may register new Parameter
    # objects. An optimizer created earlier would then keep stepping stale CPU
    # copies while the on-device model never changes.
    optimizer = model.configure_optimizers()

    para_loader = ploader.ParallelLoader(train_loader, [device])
    para_train_loader = para_loader.per_device_loader(device)
    xm.master_print('Parallel Loader Created. Training ...')

    num_batches = len(train_loader)
    # max(1, ...) avoids a modulo-by-zero when there are fewer than 10 batches.
    report_every = max(1, num_batches // 10)

    for batch_idx, batch in enumerate(para_train_loader):
        if (batch_idx + 1) % report_every == 0:
            xm.master_print(f'PROGRESS: Training is {((batch_idx + 1)/num_batches*100):.2f}% complete...')

        optimizer.zero_grad()
        loss = model(batch, batch_idx, device)
        loss.backward()
        xm.optimizer_step(optimizer)

    xm.master_print("SUCCESS: Training is 100% complete!")

# Multiprocessing Function
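Finally, we define a per-process entry point `_mp_fn` that runs the training loop, and pass it to `xmp.spawn`.
<br>`nprocs` sets the number of processes, i.e. the number of TPU cores used. Note: with `nprocs=1`, `xmp.spawn` runs `_mp_fn` directly in the notebook process, which initializes the TPU runtime there. Spawning 8 processes afterwards in the same kernel then fails with the error shown below, so restart the kernel before switching to `nprocs=8`.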

In [28]:
def _mp_fn(rank, flags):
    """Entry point run by ``xmp.spawn`` in each process; ``rank``/``flags`` are unused."""
    return train_model()


FLAGS = dict()
# NOTE(review): with nprocs=1, xmp.spawn appears to call _mp_fn directly in this
# notebook process, initialising the TPU runtime here; a later multi-core spawn
# in the same kernel can then fail (as the 8-core cell below shows).
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1, start_method="fork")

Parallel Loader Created. Training ...
PROGRESS: Training is 9.81% complete...
PROGRESS: Training is 19.62% complete...
PROGRESS: Training is 29.42% complete...
PROGRESS: Training is 39.23% complete...
PROGRESS: Training is 49.04% complete...
PROGRESS: Training is 58.85% complete...
PROGRESS: Training is 68.66% complete...
PROGRESS: Training is 78.46% complete...
PROGRESS: Training is 88.27% complete...
PROGRESS: Training is 98.08% complete...
SUCCESS: Training is 100% complete!


In [29]:
# Reuse `_mp_fn` and `FLAGS` from the single-core cell above instead of
# redefining them here.
# NOTE(review): run this in a fresh kernel without first running the single-core
# spawn. With nprocs=1, xmp.spawn appears to run `_mp_fn` inside this notebook
# process and initialise the TPU runtime here; forked replicas then fail with
# "Cannot replicate if number of devices (1) is different from 8".
NUM_TPU_CORES = 8
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=NUM_TPU_CORES, start_method='fork')

Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
Exception in device=TPU:1: Cannot replicate if number of devices (1) is different from 8Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 316, in _setup_replication
    xm.set_replication(device, [device])
Exception in device=TPU:2: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 318, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_

ProcessExitedException: process 0 terminated with exit code 17