In [None]:
!pip install pytorch_lightning

Collecting pytorch_lightning
[?25l  Downloading https://files.pythonhosted.org/packages/07/0c/e2d52147ac12a77ee4e7fd7deb4b5f334cfb335af9133a0f2780c8bb9a2c/pytorch_lightning-1.2.10-py3-none-any.whl (841kB)
[K     |████████████████████████████████| 849kB 5.3MB/s 
[?25hCollecting future>=0.17.1
[?25l  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
[K     |████████████████████████████████| 829kB 10.0MB/s 
Collecting torchmetrics==0.2.0
[?25l  Downloading https://files.pythonhosted.org/packages/3a/42/d984612cabf005a265aa99c8d4ab2958e37b753aafb12f31c81df38751c8/torchmetrics-0.2.0-py3-none-any.whl (176kB)
[K     |████████████████████████████████| 184kB 20.5MB/s 
Collecting PyYAML!=5.4.*,>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████|

In [None]:
import torch
import torchvision
import pytorch_lightning as pl

In [None]:

class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule for MNIST: the train split for training, the test split for validation.

    Args:
        path: directory where torchvision stores (and looks for) the MNIST files.
        batch_size: number of samples per batch for both dataloaders.
        num_workers: worker processes for each DataLoader; 0 (the default) loads
            data in the main process, which is the previous behaviour.
    """

    def __init__(self, path='../data', batch_size=64, num_workers=0):
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        # One preprocessing pipeline shared by both splits (previously duplicated):
        # tensor conversion + normalization with the commonly used MNIST mean/std.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        """Download both MNIST splits if they are missing.

        Lightning calls this hook from a single process (instead of every process,
        like ``setup``), so concurrent workers do not race on the download.
        """
        torchvision.datasets.MNIST(self.path, train=True, download=True)
        torchvision.datasets.MNIST(self.path, train=False, download=True)

    def setup(self, stage=None):
        """Create the train and validation datasets (``stage`` is not used)."""
        # download=True is kept so setup() still works when called on its own;
        # torchvision skips the download when the files already exist.
        self.mnist_train = torchvision.datasets.MNIST(
            self.path, train=True, download=True, transform=self.transform
        )
        self.mnist_val = torchvision.datasets.MNIST(
            self.path, train=False, download=True, transform=self.transform
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(
            self.mnist_train, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Loader over the validation (MNIST test) split, in fixed order."""
        return torch.utils.data.DataLoader(
            self.mnist_val, batch_size=self.batch_size,
            num_workers=self.num_workers,
        )



In [None]:
from pytorch_lightning.metrics.functional.classification import accuracy
import torch.nn.functional as F

def block(c_in, c_out, k=3, p=1, s=1, pk=2, ps=2):
    """Build one convolution -> ReLU -> max-pool stage.

    Args:
        c_in: input channels of the convolution.
        c_out: output channels of the convolution.
        k, p, s: convolution kernel size, padding and stride.
        pk, ps: max-pool kernel size and stride.

    Returns:
        A ``torch.nn.Sequential`` with submodules at indices 0 (conv), 1 (ReLU)
        and 2 (pool); these indices appear in the model's state_dict keys.
    """
    conv = torch.nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p)
    activation = torch.nn.ReLU()
    pool = torch.nn.MaxPool2d(kernel_size=pk, stride=ps)
    stages = [conv, activation, pool]
    return torch.nn.Sequential(*stages)
class Modelo(pl.LightningModule):
    """Small CNN classifier: two ``block`` stages followed by a linear head.

    With ``block``'s default arguments each stage keeps the spatial size through
    the padded 3x3 convolution and halves it with the 2x2/stride-2 max-pool, so
    the linear layer's ``128*7*7`` input size assumes 28x28 images (e.g. MNIST).

    Args:
        n_channels: number of input image channels.
        n_outputs: number of classes (length of the logit vector).
        lr: Adam learning rate; the default keeps the previous hard-coded 1e-3.

    Logged keys: ``loss``/``acc``/``precision``/``recall`` in training and
    ``val_loss``/``val_acc``/``val_precision``/``val_recall`` in validation.
    """

    def __init__(self, n_channels=1, n_outputs=10, lr=1e-3):
        super().__init__()
        # Store the constructor arguments in checkpoints so load_from_checkpoint
        # rebuilds the same architecture (e.g. a non-default n_outputs) instead of
        # silently falling back to the defaults.
        self.save_hyperparameters()
        self.lr = lr
        self.conv1 = block(n_channels, 64)
        self.conv2 = block(64, 128)
        self.fc = torch.nn.Linear(128 * 7 * 7, n_outputs)
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        # With the default micro averaging, precision and recall are identical to
        # accuracy for single-label multi-class targets; macro averaging
        # (unweighted mean over classes) makes them carry extra information.
        self.train_precision = pl.metrics.Precision(num_classes=n_outputs, average='macro')
        self.val_precision = pl.metrics.Precision(num_classes=n_outputs, average='macro')
        self.train_recall = pl.metrics.Recall(num_classes=n_outputs, average='macro')
        self.val_recall = pl.metrics.Recall(num_classes=n_outputs, average='macro')

    def forward(self, x):
        """Return the class logits, one row of ``n_outputs`` scores per sample."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)  # flatten everything except the batch dimension
        return self.fc(x)

    def _shared_step(self, batch, prefix, acc_metric, precision_metric, recall_metric):
        """Run one batch, update and log the given metrics, and return the loss.

        The Metric objects themselves (not their per-batch values) are passed to
        ``self.log``: when logged per epoch (the validation default) Lightning then
        calls ``compute()`` over every sample of the epoch instead of averaging
        per-batch scores, which the smaller last batch would bias.
        """
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # The metrics only use the arg-max class, which softmax leaves unchanged;
        # probabilities also satisfy metric versions that reject float
        # predictions outside [0, 1].
        probs = F.softmax(logits, dim=1)
        acc_metric(probs, y)
        precision_metric(probs, y)
        recall_metric(probs, y)
        self.log(prefix + 'acc', acc_metric, prog_bar=True)
        self.log(prefix + 'precision', precision_metric, prog_bar=True)
        self.log(prefix + 'recall', recall_metric, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and metrics for one batch."""
        loss = self._shared_step(
            batch, '', self.train_acc, self.train_precision, self.train_recall
        )
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and metrics for one batch."""
        loss = self._shared_step(
            batch, 'val_', self.val_acc, self.val_precision, self.val_recall
        )
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

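Callbacks: permiten actuar durante el entrenamiento; por ejemplo, detenerlo si la accuracy de validación no mejora durante varios epochs seguidos (early stopping) o guardar el mejor modelo (checkpoint).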


In [None]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping #earlystoopinf
from pytorch_lightning.callbacks import ModelCheckpoint

early_stop_callback = EarlyStopping(
    monitor='val_acc',  # metric logged by the model's validation_step
    patience=3,         # epochs in a row without improvement before stopping
    verbose=False,
    mode='max',         # higher val_acc is better
)

# Keep the best model seen during training on disk.
checkpoint = ModelCheckpoint(
    dirpath='./',                     # directory where the checkpoint is written
    filename='modelo-{val_acc:.5f}',  # file name template; may include logged metrics
    save_top_k=1,                     # number of best models to keep
    monitor='val_acc',                # metric used to rank the models
    mode='max',                       # higher val_acc is better
)

# Seed Python/NumPy/torch RNGs so weight init and data shuffling are reproducible.
pl.seed_everything(42)

modelo = Modelo()
dm = MNISTDataModule()

trainer = pl.Trainer(
    # Fall back to CPU instead of failing when the kernel has no GPU.
    gpus=1 if torch.cuda.is_available() else 0,
    callbacks=[early_stop_callback, checkpoint],
)

trainer.fit(modelo, dm)


In [None]:
# Load the best model saved by the ModelCheckpoint callback, e.g. to make predictions.
# best_model_path is used instead of a hard-coded file name (such as
# "modelo-val_acc=0.99060.ckpt") because the name embeds val_acc and therefore
# changes every time training is re-run.
modelo = Modelo.load_from_checkpoint(checkpoint_path=checkpoint.best_model_path)