In [None]:
import torch
from torch import nn
from torch import optim

from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

In [None]:
#! pip install pytorch-lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint

In [None]:
# Choose the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
# (x.cpu() / x.cuda() are the hard-wired alternatives to the x.to(device) calls below.)
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Small random tensor used to demonstrate moving data between devices.
x = torch.randn(5)

print(device)
print(x.to(device))

# Keep a handle on the moved tensor and display where it lives.
x_on_device = x.to(device)
x_on_device.device

----
# PyTorch Lightning

1. model
2. optimizer
3. data
4. training loop "the magic"
5. validation loop "the validation magic"

In [None]:
import pytorch_lightning as pl

# old code
# from pytorch_lightning.metrics.functional import accuracy
from torchmetrics.functional import accuracy


# Lightning
class ResMLPLightning(pl.LightningModule):
    """Residual MLP classifier for MNIST, bundled with its optimizer, data and loops.

    The network is 784 -> 64 -> 64 -> 10 with a residual connection around the
    second hidden layer and dropout before the output layer.

    Args:
        learning_rate: step size passed to the SGD optimizer.
        batch_size: batch size used by the train/val/test dataloaders.
        num_workers: worker processes used by each dataloader.
        split_seed: seed for the train/validation split, so every call to
            ``setup`` (one per stage, and one per process) yields the same
            partition.
    """

    def __init__(self, learning_rate, batch_size=32, num_workers=2, split_seed=42):
        super().__init__()
        # Records all constructor arguments in self.hparams and in checkpoints.
        self.save_hyperparameters()

        # Attribute names are part of the checkpoint state_dict keys: do not rename.
        self.l1 = nn.Linear(28*28, 64)
        self.l2 = nn.Linear(64, 64)
        self.l3 = nn.Linear(64, 10)
        self.do = nn.Dropout(0.1)

        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, x):
        """Map flattened images (last dimension 28*28) to 10 class logits."""
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        do = self.do(h2 + h1) # added residual connection
        logits = self.l3(do)
        return logits

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all parameters."""
        optimiser = optim.SGD(self.parameters(), lr=self.learning_rate)
        return optimiser

    def _shared_step(self, batch):
        """Run the forward pass on one batch and return ``(loss, accuracy)``.

        Shared by the training, validation and test steps, which differ only
        in the name under which the metrics are logged.
        """
        x, y = batch # pl takes care of the devices

        # x (batch_size, 1, 28, 28) -> (batch_size, 1*28*28)
        x = x.view(x.shape[0], -1)

        logits = self(x) # outputs before softmax (batch size, output size)
        loss = self.loss_fct(logits, y)
        acc = accuracy(preds=logits, target=y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log the batch accuracy."""
        loss, acc = self._shared_step(batch)

        # Same naming scheme as 'val_acc' / 'test_acc' (no trailing colon).
        self.log('train_acc', acc, prog_bar=True)

        # 'loss' is the key Lightning reads to run the backward pass.
        return {'loss': loss}

    def validation_step(self, batch, batch_idx, ):
        """Compute loss and accuracy for one validation batch.

        Lightning handles eval mode and gradient disabling itself, so the body
        is the same computation as the training step with different log names.
        """
        loss, acc = self._shared_step(batch)
        self.log('val_acc', acc, prog_bar=True)
        return {'loss': loss}

    def validation_epoch_end(self, outputs):
        """Log the mean of the per-batch validation losses as ``val_loss``.

        Lightning ignores the return value of this hook, so the metric has to
        be logged explicitly; the dict is still returned for direct callers.
        """
        avg_val_loss = torch.stack([out['loss'] for out in outputs]).mean()
        self.log('val_loss', avg_val_loss, prog_bar=True)
        return {'val_loss': avg_val_loss}

    def test_step(self, batch, batch_idx):
        """Compute loss and accuracy for one test batch."""
        loss, acc = self._shared_step(batch)
        self.log('test_acc', acc, prog_bar=True,)
        return {'loss': loss}

    def test_epoch_end(self, outputs):
        """Log the mean of the per-batch test losses as ``test_loss``.

        As with validation, the return value is ignored by Lightning and kept
        only for direct callers.
        """
        avg_test_loss = torch.stack([out['loss'] for out in outputs]).mean()
        self.log('test_loss', avg_test_loss)
        return {'test_loss': avg_test_loss}

    def prepare_data(self):
        """Download MNIST to ``data/`` if it is not there yet."""
        # happens only once (once per node or once per all gpus on all nodes)
        # the point is not to download for each of the n_nodes*n_gpus_per_node gpus
        # DO NOT USE self.something = something here! It will only happen on a subset of gpus
        #   and self.something will not be available on the other gpus
        #   this is because in a multi-gpu scenario we put a copy of the model on each of the gpus
        datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())  # train/val data
        datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor()) # test data

    def setup(self, stage=None):
        """Build the train/val/test datasets from the already-downloaded files.

        The 55000/5000 split uses a generator seeded with ``split_seed``:
        an unseeded split would differ on every call to this hook, mixing
        validation samples into training across stages and processes.
        """
        # transforms go into setup
        train_val_data = datasets.MNIST('data', train=True, download=False, transform=transforms.ToTensor())
        self.train_data, self.val_data = random_split(
            train_val_data,
            [55000, 5000],
            generator=torch.Generator().manual_seed(self.split_seed),
        )

        self.test_data = datasets.MNIST('data', train=False, download=False, transform=transforms.ToTensor())

    def train_dataloader(self):
        """Return the training loader; reshuffled every epoch."""
        train_loader = DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )
        return train_loader

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        val_loader = DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        return val_loader

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        test_loader = DataLoader(
            self.test_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        return test_loader
    

# Instantiate the network that the training cells below will fit.
LEARNING_RATE = 1e-2
model = ResMLPLightning(learning_rate=LEARNING_RATE)

In [None]:
# Summarise the CUDA setup as (available, device count, device name, compute capability).
# The name/capability queries raise when no CUDA device is present, so they are
# only issued when CUDA is available; otherwise None is shown in their place.
cuda_available = torch.cuda.is_available()
(
    cuda_available,
    torch.cuda.device_count(),
    torch.cuda.get_device_name() if cuda_available else None,
    torch.cuda.get_device_capability() if cuda_available else None,
)

In [None]:
# Callbacks passed to the Trainer.

# Early stopping driven by validation accuracy (higher is better).
early_stop_callback = EarlyStopping(
    monitor="val_acc",
    mode="max",
    patience=3,
    min_delta=0.01,
    verbose=True,
)

# Keep the two checkpoints with the highest validation accuracy;
# the epoch and the accuracy are embedded in the file name.
save_best_weights_callback = ModelCheckpoint(
    filename='{epoch}-{val_acc:.2f}',
    monitor='val_acc',
    mode='max',
    save_top_k=2,
)

# Progress bar with a refresh rate of 20.
progress_bar_callback = TQDMProgressBar(refresh_rate=20)

In [None]:
# Train for at most 20 epochs (early stopping may end the run sooner).
# Use one GPU when CUDA is available and fall back to the CPU otherwise,
# so the cell also runs on machines without a GPU.
trainer = pl.Trainer(
    max_epochs=20,
    gpus=1 if torch.cuda.is_available() else 0,
    callbacks=[early_stop_callback, progress_bar_callback, save_best_weights_callback],
    )
trainer.fit(model)

In [None]:
%ls lightning_logs\version_6

In [None]:
%cat lightning_logs/version_6/hparams.yaml

# Test the results

In [None]:
# Stand-alone MNIST test split and loader, independent of the loaders defined on the model.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_loader = DataLoader(dataset=test_data, batch_size=32, num_workers=8)

## Right after training has been run

In [None]:
# Evaluate the model that was just fitted, using the explicit test loader.
trainer.test(
    model=model,
    dataloaders=test_loader,
)

## The training has been run previously

In [None]:
# Baseline: evaluate a freshly initialised (untrained) model.
model_untrained = ResMLPLightning(learning_rate=1e-2)
# One GPU when CUDA is available, CPU otherwise, so the cell runs on any machine.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0)

trainer.test(model=model_untrained,
    #dataloaders=test_loader # only need if it's not defined in the model
    )

In [None]:
# Path of the best checkpoint recorded by the ModelCheckpoint callback.
best_ckpt_path = save_best_weights_callback.best_model_path
best_ckpt_path

In [None]:
# Checkpoint written by an earlier run. NOTE: the version/epoch in this path is
# specific to that run; adjust it (or use the callback's best_model_path below)
# when the notebook has been re-run.
CKPT_PATH = "lightning_logs/version_2/checkpoints/epoch=8-val_acc=0.96.ckpt"
model_trained = ResMLPLightning.load_from_checkpoint(CKPT_PATH)
#model_trained = ResMLPLightning.load_from_checkpoint(save_best_weights_callback.best_model_path)


# One GPU when CUDA is available, CPU otherwise, so the cell runs on any machine.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0)
trainer.test(model=model_trained,
    #dataloaders=test_loader # only need if it's not defined in the model
    )

In [None]:
# Open the raw checkpoint file to look at what Lightning stored in it.
ckpt_file = "lightning_logs/version_2/checkpoints/epoch=7-val_acc=0.96.ckpt"
# map_location hands every storage back unchanged instead of remapping it.
checkpoint = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
print(checkpoint.keys())

# Entries describing the saved hyper-parameters.
for key in ("hparams_name", "hyper_parameters"):
    print(checkpoint[key])

In [None]:
# Same checkpoint, loaded with torch.load's default device mapping.
ckpt_file = "lightning_logs/version_2/checkpoints/epoch=7-val_acc=0.96.ckpt"
checkpoint = torch.load(ckpt_file)
print(checkpoint.keys())