In [1]:
import pytorch_lightning as pl
import torch
from torch import nn, optim
from torchmetrics.classification import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split, Subset
from pytorch_lightning.loggers import TensorBoardLogger

# Define a simple model
class SimpleClassifier(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim, lr):
        super().__init__()
        self.save_hyperparameters()  # Save hyperparameters
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, output_dim)
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(num_classes=output_dim, task='multiclass')

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # Flatten the input
        x = torch.relu(self.layer_1(x))
        x = self.layer_2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        
        # Compute accuracy
        acc = self.accuracy(y_hat, y)
        
        logs = {
            'train_loss': loss,
            'train_acc': acc
        }
        # Log accuracy and loss
        self.log_dict(logs, on_step=True, on_epoch=False)
        
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.lr)

# Data preparation
dataset = MNIST(root='../data', train=True, transform=transforms.ToTensor(), download=True)

# Use a small subset of the dataset for quick training
subset_indices = list(range(1000))  # Using 1000 samples for quick training
subset = Subset(dataset, subset_indices)

train_set, val_set = random_split(subset, [800, 200])

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)  # Smaller batch size
val_loader = DataLoader(val_set, batch_size=16)

# Define hyperparameters
hyperparams = {
    'input_dim': 28*28,    # 28x28 images
    'hidden_dim': 32,      # Reduced hidden dimension
    'output_dim': 10,      # 10 classes
    'lr': 1e-3,
}

# Initialize the model
model = SimpleClassifier(**hyperparams)

# Initialize TensorBoard Logger
# logger = TensorBoardLogger('.', default_hp_metric=False)  # Use default folder
logger = TensorBoardLogger('.')  # Use default folder

# Initialize the trainer
trainer = pl.Trainer(
    max_epochs=5,           # Train for only 1 epoch
    logger=logger,
    log_every_n_steps=1,
    enable_progress_bar=True,  # Disable progress bar
    accelerator='cpu',     # Use CPU to speed up (remove if you have a GPU)
)

# Train the model
trainer.fit(model, train_loader, val_loader)

  from .autonotebook import tqdm as notebook_tqdm
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

  | Name     | Type               | Params
------------------------------------------------
0 | layer_1  | Linear             | 25.1 K
1 | layer_2  | Linear             | 330   
2 | loss_fn  | CrossEntropyLoss   | 0     
3 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
25.4 K    Trainable params
0         Non-trainable params
25.4 K    Total params
0.102     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 4: 100%|██████████| 50/50 [00:00<00:00, 133.84it/s, loss=0.473, v_num=4]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 50/50 [00:00<00:00, 132.42it/s, loss=0.473, v_num=4]


# run tensorboard --logdir lightning_logs
# in current folder