In [None]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# PyTorch Lightning Training
This notebook trains a model to predict whether the given sonar signals are bouncing off a metal cylinder or off a cylindrical rock from [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Connectionist+Bench+%28Sonar%2C+Mines+vs.+Rocks%29).

This notebook is derived from the [PyTorch sample](https://github.com/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/pytorch/TrainingAndPredictionWithPyTorch.ipynb). It demonstrates how to perform the same task using [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), a lightweight wrapper around PyTorch.

The notebook is intended to run within [AI Platform Notebooks](https://cloud.google.com/ai-platform-notebooks). The model will be trained within the notebook instance VM, optionally attached to GPUs or TPUs. With the following link, you can directly [Open in AI Platform Notebooks](https://console.cloud.google.com/ai-platform/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/ai-platform-samples/master/notebooks/samples/pytorch/lightning/TrainingAndPredictionWithPyTorchLightning.ipynb). 

### Dataset
The Sonar Signals dataset that this sample uses for training is provided by the UC Irvine Machine Learning Repository. Google has hosted the data on a public GCS bucket `gs://cloud-samples-data/ai-platform/sonar/sonar.all-data`.

* `sonar.all-data` is split for both training and evaluation

Note: Your typical development process with your own data would require you to upload your data to GCS so that you can access that data from inside your notebook. However, in this case, Google has put the data on GCS to avoid the steps of having you download the data from UC Irvine and then upload the data to GCS.

### Disclaimer
This dataset is provided by a third party. Google provides no representation, warranty, or other guarantees about the validity or any other aspects of this dataset.

## (Optional) TPU configuration

To use [Cloud TPUs](https://cloud.google.com/tpu), first create a [TPU node](https://cloud.google.com/tpu/docs/creating-deleting-tpus#setup_TPU_only). Set the **TPU software version** to a matching PyTorch version (e.g. `pytorch-1.7`) and the **Network** to the same network used for your notebook instance (e.g. `datalab-network`).

Uncomment this section only if you are using TPUs. Note that you must be running this notebook on an [XLA](https://github.com/pytorch/xla) image such as `gcr.io/deeplearning-platform-release/pytorch-xla.1-7` for PyTorch to connect to Cloud TPUs. To use an XLA image, you can create a new notebook instance with the **Environment** set to `Custom container` and the **Docker container image** set to the XLA image location.

If you need a quota increase for Cloud TPUs, please review the [Cloud TPU Quota Policy](https://cloud.google.com/tpu/docs/quota) for more details.

### Review TPU configuration

Run the gcloud command to review the available TPUs for the one you wish to use.
Make note of the IP address (from NETWORK_ENDPOINT, without the port), and the # of TPU cores (derived from ACCELERATOR_TYPE). An ACCELERATOR_TYPE of v3-8 will indicate 8 TPU cores, for example.

In [None]:
# !gcloud compute tpus list --zone=YOUR_ZONE_HERE_SUCH_AS_us-central1-b

### Update TPU configuration

Uncomment the cell below and set the IP address and the number of cores to match your TPU.

In [None]:
# tpu_ip_address='10.1.2.3'
# tpu_cores=8

### Set TPU environment variables

In [None]:
# # TPU configuration
# %env XRT_TPU_CONFIG=tpu_worker;0;$tpu_ip_address:8470

# # Use bfloat16
# %env XLA_USE_BF16=1

## Install and import packages

In [None]:
!pip install -U pytorch-lightning --quiet

In [None]:
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
if XLADeviceUtils.tpu_device_exists():
    import torch_xla  # noqa: F401

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split

import pandas as pd
from google.cloud import storage

from pytorch_lightning.core import LightningModule, LightningDataModule
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.trainer.trainer import Trainer

## Environment configuration

In [None]:
_ = !nproc
tpu_cores = tpu_cores if 'tpu_cores' in vars() else 0
num_cpus = int(_[0])
num_gpus = torch.cuda.device_count()
device = torch.device('cuda') if num_gpus else 'cpu'

print(f'Device: {device}')
print(f'CPUs: {num_cpus}')
print(f'GPUs: {num_gpus}')
print(f'TPUs: {tpu_cores}')

## Download data

In [None]:
# Public bucket holding data for samples
BUCKET = 'cloud-samples-data'

# Path to the directory inside the public bucket containing the sample data.
# Keep the trailing slash: the object name is built as BUCKET_PATH + FILE.
BUCKET_PATH = 'ai-platform/sonar/'

# Sample data file; the same name is used for the local copy when downloading
FILE = 'sonar.all-data'

In [None]:
# Reference the public bucket that hosts the sample data
bucket = storage.Client().bucket(BUCKET)

# Reference the data file object inside the bucket
blob = bucket.blob(BUCKET_PATH + FILE)

# Download the object to a local file named FILE (a relative path)
blob.download_to_filename(FILE)

## Define the PyTorch [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)

In [None]:
class SonarDataset(Dataset):
    """Map-style dataset over a header-less CSV of sonar readings.

    Every CSV row is expected to hold the numeric feature columns followed by
    a final label column ('R' for rock, 'M' for mine).
    """

    # Binary encoding of the label column:
    # R = rock --> 0
    # M = mine --> 1
    LABEL_TO_TARGET = {'R': 0, 'M': 1}

    def __init__(self, csv_file):
        """Read ``csv_file`` (anything ``pandas.read_csv`` accepts) into memory."""
        self.dataframe = pd.read_csv(csv_file, header=None)

    def __len__(self):
        """Return the number of rows (samples) in the CSV."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the sample at integer position ``idx``.

        Returns:
            dict with two entries:
              'features': float64 ``torch.Tensor`` built from every column
                  except the last.
              'target': float64 ``numpy.ndarray`` holding the encoded label
                  (0.0 for 'R', 1.0 for 'M'; labels that are already numeric
                  are passed through unchanged).

        Raises:
            ValueError: if the label is neither 'R'/'M' nor convertible to a
                float.
        """
        # `astype` returns a fresh array, so the tensor never aliases the frame
        features = self.dataframe.iloc[idx, :-1].values.astype(dtype='float64')

        # Encode on a private copy: the array exposed by `.values` may be a
        # read-only view (pandas Copy-on-Write), so it must not be written to.
        target = self.dataframe.iloc[idx, -1:].values.copy()
        label = target[0]
        target[0] = self.LABEL_TO_TARGET.get(label, label)
        try:
            target = target.astype(dtype='float64')
        except ValueError as err:
            raise ValueError(
                f"Unexpected label {label!r} at row {idx}; expected 'R' or 'M'"
            ) from err

        # Features are returned as a tensor; the target stays a NumPy array
        # (kept for backward compatibility with existing consumers).
        data = {'features': torch.from_numpy(features),
                'target': target}
        return data

## Define a data processing module

In this step, you will create a custom data module that extends [LightningDataModule](https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html) to encapsulate the data processing steps.

In [None]:
class SonarDataModule(LightningDataModule):
    """Downloads the sonar data file from GCS and serves train/val/test loaders.

    Args:
        bucket: name of the GCS bucket holding the data file.
        bucket_path: directory prefix inside the bucket; it is concatenated
            directly with ``file``, so it must end with '/'.
        file: object name inside ``bucket_path``; also the local filename.
        batch_size: batch size used by all three dataloaders.
        num_workers: ``num_workers`` passed to all three dataloaders.
        seed: seed for the train/val/test split, so that repeated calls to
            ``setup()`` produce the same partition.
    """

    def __init__(self, bucket=BUCKET, bucket_path=BUCKET_PATH, file=FILE, batch_size=32, num_workers=0, seed=42):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.bucket = bucket
        self.bucket_path = bucket_path
        self.file = file
        self.seed = seed

    def prepare_data(self):
        """Download the data file from the bucket to a local file named ``self.file``."""
        # Public bucket holding the data
        bucket = storage.Client().bucket(self.bucket)

        # Path to the data inside the public bucket
        blob = bucket.blob(self.bucket_path + self.file)

        # Download the data
        blob.download_to_filename(self.file)

    def setup(self, stage=None):
        """Load the local data file and split it 60/20/20 into train/val/test.

        ``stage`` is accepted for interface compatibility and is not used:
        all three splits are always assigned.
        """
        # Load the data
        sonar_dataset = SonarDataset(self.file)

        # Compute the size of each split
        dataset_size = len(sonar_dataset)
        test_size = int(0.2 * dataset_size)  # Use a test split of 0.2
        val_size = int(0.2 * dataset_size)  # Use a validation split of 0.2
        train_size = dataset_size - test_size - val_size

        # Seed the split. Lightning may call setup() more than once (e.g. once
        # for fitting and again for testing); an unseeded split would then hand
        # out a different test set that overlaps the data used for training.
        generator = torch.Generator().manual_seed(self.seed)

        # Assign train/val/test datasets for use in dataloaders
        self.sonar_train, self.sonar_val, self.sonar_test = random_split(
            sonar_dataset, [train_size, val_size, test_size], generator=generator)

    def train_dataloader(self):
        """Return the training loader; samples are reshuffled every epoch."""
        return DataLoader(self.sonar_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.sonar_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.sonar_test, batch_size=self.batch_size, num_workers=self.num_workers)


# Give the dataloaders as many workers as there are CPUs on this machine
dm = SonarDataModule(num_workers=num_cpus)

## Define a model

Next, you will create a module that extends [LightningModule](https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html). This module includes your model code and organizes steps of the model-building process.

In [None]:
class SonarModel(LightningModule):
    """Feed-forward binary classifier over 60 sonar features.

    The network ends in a sigmoid, so its output is a value in (0, 1);
    targets use 0 for rock and 1 for mine.
    """

    def __init__(self):
        super().__init__()

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Linear(60, 60),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(60, 30),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(30, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run the network on ``x`` after casting it to float32."""
        return self.model(x.float())

    def _predict_and_loss(self, batch):
        """Return ``(y_hat, y, loss)`` for a batch dict with 'features' and 'target'."""
        x, y = batch['features'].float(), batch['target'].float()
        y_hat = self(x)
        return y_hat, y, F.binary_cross_entropy(y_hat, y)

    def _evaluate(self, batch, prefix):
        """Compute loss and accuracy for a batch and log them as ``{prefix}_loss`` / ``{prefix}_acc``."""
        y_hat, y, loss = self._predict_and_loss(batch)

        # Binarize the output
        y_hat_binary = y_hat.round()
        acc = accuracy(y_hat_binary, y.int())

        # Log metrics for TensorBoard
        self.log(f'{prefix}_loss', loss, prog_bar=True)
        self.log(f'{prefix}_acc', acc, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Return the binary cross-entropy loss for one training batch."""
        _, _, loss = self._predict_and_loss(batch)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` for one validation batch and return the loss."""
        return self._evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` for one test batch and return the loss.

        Uses the same computation as validation, but under its own metric
        names so test results are not reported as validation metrics.
        """
        return self._evaluate(batch, 'test')

    def configure_optimizers(self):
        """Use plain SGD with momentum over all model parameters."""
        return torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.5, nesterov=False)


# Create a fresh, untrained model instance
model = SonarModel()

## Train and evaluate the model

Finally, you will create and use a [Trainer](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html) to build the model and evaluate its accuracy.

The trainer is initialized with an [accelerator](https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#accelerator), with different options depending on your environment:
* TPUs support only [ddp](https://pytorch-lightning.readthedocs.io/en/stable/tpu.html#distributed-backend-with-tpu), or distributed data parallel, and so `accelerator` cannot be specified. 
* GPUs support a variety of [distributed modes](https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#distributed-modes). In this notebook, we are using [dp](https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#data-parallel) for multiple GPUs on 1 machine.
* CPUs can support [ddp_cpu](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html#trainer-flags) for multi-node CPU training. For multiple CPUs on one node, there is no speed increase from using this accelerator, and so the default of `None` is used in this notebook. 

In [None]:
epochs = 100

# Start from the settings shared by every environment, then add the
# hardware-specific flags. TPUs take precedence over GPUs; a CPU-only run
# keeps the Trainer defaults.
trainer_kwargs = {'max_epochs': epochs}
if tpu_cores:
    trainer_kwargs['tpu_cores'] = tpu_cores
elif num_gpus:
    trainer_kwargs.update(gpus=num_gpus, accelerator='dp')

trainer = Trainer(**trainer_kwargs)

In [None]:
# Train the model on the data served by the data module
trainer.fit(model, dm)

In [None]:
# Evaluate using the data module's test dataloader.
# NOTE(review): no model is passed explicitly, so the Trainer decides which
# weights are evaluated — confirm against the installed Lightning version.
trainer.test(datamodule=dm)

## Save and load a trained model

The following steps aren't required, but are shown for use in a production environment.

First, we'll export the model to a file. Then, we'll load the model file (which isn't required in a notebook, because we already have a trained model). Finally, we'll set the model to evaluation mode (rather than train mode) for inference.

In [None]:
# Write the model's state dict (not the whole module object) to disk
torch.save(model.state_dict(), 'model.pt')

# Read the state dict back and load it into the existing model instance
model.load_state_dict(torch.load('model.pt'))

# Switch to evaluation mode for inference (the network contains Dropout layers)
model.eval()

## Predict with the model 

Finally, let's illustrate model inference, with set values as inputs:

In [None]:
# Place the inputs on whichever device the model's weights currently live on.
# The Trainer may have moved the model since the global `device` was computed,
# and a model/input device mismatch makes the forward pass fail.
model_device = next(model.parameters()).device

# One sample expected to be classified as a rock (60 features)
rock_feature = torch.tensor([[3.6800e-02, 4.0300e-02, 3.1700e-02, 2.9300e-02, 8.2000e-02, 1.3420e-01,
                              1.1610e-01, 6.6300e-02, 1.5500e-02, 5.0600e-02, 9.0600e-02, 2.5450e-01,
                              1.4640e-01, 1.2720e-01, 1.2230e-01, 1.6690e-01, 1.4240e-01, 1.2850e-01,
                              1.8570e-01, 1.1360e-01, 2.0690e-01, 2.1900e-02, 2.4000e-01, 2.5470e-01,
                              2.4000e-02, 1.9230e-01, 4.7530e-01, 7.0030e-01, 6.8250e-01, 6.4430e-01,
                              7.0630e-01, 5.3730e-01, 6.6010e-01, 8.7080e-01, 9.5180e-01, 9.6050e-01,
                              7.7120e-01, 6.7720e-01, 6.4310e-01, 6.7200e-01, 6.0350e-01, 5.1550e-01,
                              3.8020e-01, 2.2780e-01, 1.5220e-01, 8.0100e-02, 8.0400e-02, 7.5200e-02,
                              5.6600e-02, 1.7500e-02, 5.8000e-03, 9.1000e-03, 1.6000e-02, 1.6000e-02,
                              8.1000e-03, 7.0000e-03, 1.3500e-02, 6.7000e-03, 7.8000e-03, 6.8000e-03]], dtype=torch.float64, device=model_device)

# One sample expected to be classified as a mine (60 features)
mine_feature = torch.tensor([[5.9900e-02, 4.7400e-02, 4.9800e-02, 3.8700e-02, 1.0260e-01, 7.7300e-02,
                              8.5300e-02, 4.4700e-02, 1.0940e-01, 3.5100e-02, 1.5820e-01, 2.0230e-01,
                              2.2680e-01, 2.8290e-01, 3.8190e-01, 4.6650e-01, 6.6870e-01, 8.6470e-01,
                              9.3610e-01, 9.3670e-01, 9.1440e-01, 9.1620e-01, 9.3110e-01, 8.6040e-01,
                              7.3270e-01, 5.7630e-01, 4.1620e-01, 4.1130e-01, 4.1460e-01, 3.1490e-01,
                              2.9360e-01, 3.1690e-01, 3.1490e-01, 4.1320e-01, 3.9940e-01, 4.1950e-01,
                              4.5320e-01, 4.4190e-01, 4.7370e-01, 3.4310e-01, 3.1940e-01, 3.3700e-01,
                              2.4930e-01, 2.6500e-01, 1.7480e-01, 9.3200e-02, 5.3000e-02, 8.1000e-03,
                              3.4200e-02, 1.3700e-02, 2.8000e-03, 1.3000e-03, 5.0000e-04, 2.2700e-02,
                              2.0900e-02, 8.1000e-03, 1.1700e-02, 1.1400e-02, 1.1200e-02, 1.0000e-02]], dtype=torch.float64, device=model_device)

# Inference only: no gradients are needed, so skip autograd bookkeeping
with torch.no_grad():
    rock_prediction = model(rock_feature)
    mine_prediction = model(mine_feature)

print('Result Values: (Rock: 0) - (Mine: 1)\n')
print(f'Rock Prediction:\n\t{"Rock" if rock_prediction <= 0.5 else "Mine"} - {rock_prediction.item()}')
print(f'Mine Prediction:\n\t{"Rock" if mine_prediction <= 0.5 else "Mine"} - {mine_prediction.item()}')