In [1]:
import os
import lightning.pytorch as pl

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torchmetrics
import wandb
from lightning.pytorch.loggers import WandbLogger

In [2]:
# Dataset root directory: the PATH_DATASETS environment variable takes
# precedence; otherwise fall back to the hard-coded default location.
PATH_DATASETS = os.getenv("PATH_DATASETS", "/users/PLS0129/ysu0053/CSCI4852_6852_F23_DL/data")

# Mini-batch size: 64 by default, raised to 256 when a CUDA device is available.
BATCH_SIZE = 64
if torch.cuda.is_available():
    BATCH_SIZE = 256

In [3]:
# Select the compute device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print('Number of GPUs:', torch.cuda.device_count())
print()

# Extra diagnostics for CUDA device 0. Memory figures are converted from
# bytes by dividing by 1024**3 and rounded to one decimal place.
if device.type == 'cuda':
    bytes_per_gib = 1024**3
    allocated_gib = round(torch.cuda.memory_allocated(0) / bytes_per_gib, 1)
    reserved_gib = round(torch.cuda.memory_reserved(0) / bytes_per_gib, 1)
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', allocated_gib, 'GB')
    print('Cached:   ', reserved_gib, 'GB')

Using device: cuda
Number of GPUs: 1

Tesla V100-PCIE-16GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [4]:
# Define the Lightning module
class MNISTLightning(pl.LightningModule):
    """Three-layer MLP classifier for MNIST that also owns its data pipeline.

    The module bundles the network, the train/val/test step logic, the
    optimizer and the dataset/dataloader hooks, so ``trainer.fit(model)`` can
    be called without passing dataloaders.

    Args:
        data_dir: Directory where the MNIST files are downloaded to / read from.
        hidden_size: Width of the two hidden linear layers.
        learning_rate: Learning rate passed to the Adam optimizer.
        batch_size: Mini-batch size for all dataloaders. ``None`` (default)
            uses the module-level ``BATCH_SIZE`` constant.
    """

    # Fixed seed for the train/validation split so the partition is the same
    # on every run, independent of the global RNG state.
    SPLIT_SEED = 42

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4, batch_size=None):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        self.batch_size = BATCH_SIZE if batch_size is None else batch_size

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, height, width = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model: flatten -> 2 x (Linear, ReLU, Dropout) -> Linear
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * height * width, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        # Separate metric objects per stage so their accumulated state never mixes.
        self.val_accuracy = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=self.num_classes
        )
        self.test_accuracy = torchmetrics.classification.Accuracy(
            task="multiclass", num_classes=self.num_classes
        )

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the negative log-likelihood loss for one batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        # Surface the training loss in the attached logger.
        self.log("train_loss", loss)
        return loss

    def _evaluate(self, batch, metric, stage):
        """Shared validation/test logic: log ``{stage}_loss`` and ``{stage}_acc``."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        metric.update(preds, y)

        # self.log forwards the scalars to the trainer's logger and progress bar.
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", metric, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; logs ``val_loss`` and ``val_acc``."""
        self._evaluate(batch, self.val_accuracy, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; logs ``test_loss`` and ``test_acc``."""
        self._evaluate(batch, self.test_accuracy, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        """Download the MNIST train and test splits into ``self.data_dir``."""
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test", or ``None`` for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            # Seeded generator keeps the 50000/10000 partition identical across runs.
            split_generator = torch.Generator().manual_seed(self.SPLIT_SEED)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [50000, 10000], generator=split_generator
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader; reshuffled every epoch so batch order varies."""
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

In [None]:
# Seed the Python, NumPy and PyTorch RNGs so weight initialisation and
# training are repeatable across kernel restarts.
pl.seed_everything(42, workers=True)

# Init our model
mnist_model = MNISTLightning()

# Initialize wandb. The settings object must be handed to wandb.init to take
# effect; silent=True suppresses wandb's console output.
settings = wandb.Settings(silent=True)
wandb.init(project='mnist_mlp', settings=settings)

try:
    # Create the WandbLogger (attaches to the run started above)
    wandb_logger = WandbLogger()

    # Initialize a trainer
    trainer = pl.Trainer(accelerator="auto", devices=1, max_epochs=15, logger=wandb_logger)
    # Train the model ⚡
    trainer.fit(mnist_model)

    # Testing: evaluate the best checkpoint saved during fit. This is what
    # trainer.test() falls back to when called without a model; stating it
    # explicitly documents the choice and avoids the fallback warning.
    trainer.test(ckpt_path="best")
finally:
    # Close wandb run even if training or testing raised
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malazar[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | model         | Sequential         | 55.1 K
1 | val_accuracy  | MulticlassAccuracy | 0     
2 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 196/196 [00:16<00:00, 11.55it/s, v_num=6407]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/40 [00:00<?, ?it/s][A
Validation DataLoader 0:   2%|▎         | 1/40 [00:00<00:00, 60.10it/s][A
Validation DataLoader 0:   5%|▌         | 2/40 [00:00<00:01, 27.17it/s][A
Validation DataLoader 0:   8%|▊         | 3/40 [00:00<00:02, 16.55it/s][A
Validation DataLoader 0:  10%|█         | 4/40 [00:00<00:02, 15.64it/s][A
Validation DataLoader 0:  12%|█▎        | 5/40 [00:00<00:02, 13.77it/s][A
Validation DataLoader 0:  15%|█▌        | 6/40 [00:00<00:02, 13.68it/s][A
Validation DataLoader 0:  18%|█▊        | 7/40 [00:00<00:02, 12.81it/s][A
Validation DataLoader 0:  20%|██        | 8/40 [00:00<00:02, 12.51it/s][A
Validation DataLoader 0:  22%|██▎       | 9/40 [00:00<00:02, 12.02it/s][A
Validation DataLoader 0:  25%|██▌       | 10/40 [00:00<00:02, 11.89it/s][A
Validation DataLoader 0