In [9]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.loggers import TensorBoardLogger

from pl_bolts.datamodules import FashionMNISTDataModule
from torchmetrics.functional import accuracy


# Fix the global random seed so repeated runs of the notebook are comparable.
seed_everything(seed=42)

Global seed set to 42


42

In [10]:
# Root directory handed to the data module; override with the PATH_DATASETS
# environment variable, otherwise the current directory is used.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")

# Use at most one GPU, and none when CUDA reports zero devices.
AVAIL_GPUS = min(1, torch.cuda.device_count())

BATCH_SIZE = 512

# Half of the reported CPUs are given to data loading. os.cpu_count() may
# return None when the count cannot be determined, so fall back to 1 (which
# yields NUM_WORKERS = 0) instead of raising a TypeError on the division.
NUM_WORKERS = (os.cpu_count() or 1) // 2

In [11]:
# Training and evaluation use the same preprocessing pipeline: a single
# ToTensor step, with no augmentation or normalization stages.
train_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

test_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])


# Data module for Fashion-MNIST. The evaluation pipeline (test_transforms) is
# reused for the validation split.
fashion_mnist_dm = FashionMNISTDataModule(
    data_dir=PATH_DATASETS,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
    train_transforms=train_transforms,
)

In [12]:
def create_model(in_features=28 * 28, hidden_sizes=(300, 100), num_classes=10, dropout=0.1):
    """Build the layer list for a fully connected classifier.

    With the default arguments this produces the original fixed stack:
    Flatten -> Linear(784, 300) -> ReLU -> Dropout(0.1)
            -> Linear(300, 100) -> ReLU -> Dropout(0.1)
            -> Linear(100, 10).

    Args:
        in_features: Number of features per sample after flattening.
        hidden_sizes: Output width of each hidden linear layer, in order.
            An empty sequence yields Flatten followed by one linear layer.
        num_classes: Output width of the final linear layer.
        dropout: Drop probability applied after every hidden activation.

    Returns:
        A list of ``nn.Module`` layers, not yet wrapped in ``nn.Sequential``.
    """
    layers = [nn.Flatten()]

    # Layers are created in input-to-output order so that parameter
    # initialisation consumes the random state in the same order as before.
    width = in_features
    for hidden in hidden_sizes:
        layers += [nn.Linear(width, hidden), nn.ReLU(), nn.Dropout(dropout)]
        width = hidden

    layers.append(nn.Linear(width, num_classes))
    return layers


In [13]:
class DNNModel(pl.LightningModule):
    """Fully connected classifier trained with cross-entropy loss.

    The network is the layer list returned by ``create_model()`` wrapped in
    an ``nn.Sequential``. Validation and test steps share one evaluation
    routine that logs loss and accuracy under a stage-specific prefix.
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(*create_model())

    def forward(self, x):
        """Pass ``x`` through the network and return its raw output."""
        return self.net(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one batch.

        Args:
            batch: An ``(inputs, targets)`` pair.
            batch_idx: Index of the batch; unused here.
        """
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)

        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for one batch and optionally log them.

        Args:
            batch: An ``(inputs, targets)`` pair.
            stage: Prefix for the logged metric names (``"val"`` or
                ``"test"`` in this class). When falsy, nothing is logged.
        """
        x, y = batch
        out = self(x)
        loss = F.cross_entropy(out, y)
        preds = torch.argmax(out, dim=1)

        # Fraction of correct predictions, computed directly in torch.
        # `accuracy` is imported twice in this notebook from two different
        # packages, and newer torchmetrics releases require an explicit
        # `task` argument; this expression depends on neither.
        acc = (preds == y).float().mean()

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` for one validation batch."""
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` for one test batch."""
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr 3e-4, weight decay 5e-4)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=3e-4, weight_decay=5e-4)
        return optimizer

In [14]:
### Uncomment the next two lines if using Colab

# %reload_ext tensorboard
# %tensorboard --logdir lightning_logs/

In [15]:
# Fresh model instance for this run.
fashion_mnist_model = DNNModel()

# Train for 30 epochs on the available device(s), writing TensorBoard event
# files under lightning_logs/test_run_1.
trainer = pl.Trainer(
    logger=TensorBoardLogger("lightning_logs/", name="test_run_1"),
    gpus=AVAIL_GPUS,
    max_epochs=30,
    progress_bar_refresh_rate=10,
)

trainer.fit(fashion_mnist_model, datamodule=fashion_mnist_dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | net  | Sequential | 266 K 
------------------------------------
266 K     Trainable params
0         Non-trainable params
266 K     Total params
1.066     Total estimated model params size (MB)


                                                              

Global seed set to 42


Epoch 1:   0%|          | 0/118 [00:00<00:00, 1331.95it/s, loss=0.786, v_num=0, val_loss=0.714, val_acc=0.744] 

  if hasattr(mod, name):
  if hasattr(mod, name):


Epoch 29: 100%|██████████| 118/118 [00:03<00:00, 35.45it/s, loss=0.279, v_num=0, val_loss=0.310, val_acc=0.888]


In [16]:
# Evaluate the trained model on the test split; the returned metrics are the cell output.
trainer.test(model=fashion_mnist_model, datamodule=fashion_mnist_dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 100%|██████████| 20/20 [00:00<00:00, 36.40it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.882099986076355, 'test_loss': 0.33224135637283325}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 20/20 [00:00<00:00, 29.04it/s]


[{'test_loss': 0.33224135637283325, 'test_acc': 0.882099986076355}]