In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics import Accuracy
from pytorch_lightning.loggers import TensorBoardLogger

# Allow reduced-precision float32 matmuls so Tensor Cores can be used
torch.set_float32_matmul_precision('medium')  # fastest, least precise setting; 'high' or 'highest' (the default) trade speed for precision

# Create TensorBoard logger (passed to the Trainer below); run data goes under "lightning_logs" with run name "my_model"
logger = TensorBoardLogger("lightning_logs", name="my_model")

class CustomDataset(Dataset):
    """Map-style dataset that pairs each sample with its label.

    Args:
        data: indexable, sized collection of samples.
        labels: indexable, sized collection of labels, aligned with ``data``
            by position.
        transform: optional callable applied to a sample (not to its label)
            each time the sample is fetched.

    Raises:
        ValueError: if ``data`` and ``labels`` do not have the same length.
    """

    def __init__(self, data, labels, transform=None):
        # Fail fast on misaligned inputs: otherwise a shorter ``labels`` only
        # surfaces as an IndexError in the middle of an epoch, and a longer
        # one is silently truncated.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` at ``idx``, with the transform applied to the sample."""
        sample = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            sample = self.transform(sample)
        return sample, label

class EnhancedLightningModel(pl.LightningModule):
    """Single-hidden-layer MLP classifier with loss and accuracy logging.

    Inputs are flattened, passed through ``Linear -> ReLU -> Linear`` and
    trained with cross-entropy loss using Adam.

    Args:
        learning_rate: step size handed to the Adam optimizer.
        input_dim: number of features per sample after flattening
            (default ``28 * 28``).
        hidden_dim: width of the hidden layer (default ``128``).
        num_classes: number of output classes (default ``10``).
    """

    def __init__(self, learning_rate=0.001, input_dim=28 * 28, hidden_dim=128, num_classes=10):
        super().__init__()
        self.learning_rate = learning_rate

        # Model architecture
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes)
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Metrics: one instance per stage so their accumulated state never mixes
        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

        # Save the constructor arguments to self.hparams / checkpoints
        self.save_hyperparameters()

    def forward(self, x):
        """Return the raw class scores (logits) for a batch ``x``."""
        x = self.flatten(x)
        x = self.fc(x)
        return x

    def _shared_step(self, batch, accuracy):
        """Compute the loss for ``batch`` and update ``accuracy`` with its predictions."""
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        # Calling the metric updates its running state and caches the batch value.
        accuracy(torch.argmax(outputs, dim=1), labels)
        return loss

    def training_step(self, batch, batch_idx):
        """Training step: returns the loss and logs ``train_loss`` / ``train_acc``."""
        loss = self._shared_step(batch, self.train_accuracy)

        # Log the Metric object (not its tensor value) so Lightning computes the
        # epoch figure from the accumulated state and resets it every epoch.
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: returns the loss and logs epoch-level ``val_loss`` / ``val_acc``."""
        loss = self._shared_step(batch, self.val_accuracy)

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

# Data preparation
# Seed Python, NumPy and PyTorch RNGs so the synthetic data, weight
# initialisation and shuffling are the same on every Restart & Run All.
pl.seed_everything(42, workers=True)

# Synthetic stand-in data: uniform-random images with uniform-random labels,
# so validation accuracy is expected to stay near chance (about 0.1 for 10 classes).
X_train = torch.rand(80, 1, 28, 28)  # 80 samples for training
y_train = torch.randint(0, 10, (80,))

X_val = torch.rand(20, 1, 28, 28)    # 20 samples for validation
y_val = torch.randint(0, 10, (20,))

# Create datasets (the same normalisation is shared by both splits)
normalize = transforms.Normalize((0.5,), (0.5,))
train_dataset = CustomDataset(X_train, y_train, transform=normalize)
val_dataset = CustomDataset(X_val, y_val, transform=normalize)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Initialize model and trainer
model = EnhancedLightningModel(learning_rate=0.001)

# Fall back to CPU with full precision when no CUDA device is present, so the
# notebook still runs end to end on machines without a GPU.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=5,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=1,  # Use a single device explicitly
    precision='16-mixed' if use_gpu else '32-true',  # Mixed precision only on GPU
    logger=logger,
    log_every_n_steps=1,  # only 3 batches per epoch, so log every step
    enable_progress_bar=True,
    enable_model_summary=True
)

# Train the model with validation
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

  _C._set_float32_matmul_precision(precision)
Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/office/miniforge3/envs/ml1/lib/python3.12/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | flatten        | Flatten            | 0      | train
1 | fc             | Sequential         | 101 K  | train
2 | criterion      | CrossEntropyLoss   | 0      | train
3 | train_accuracy | Multiclass

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/office/miniforge3/envs/ml1/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

/home/office/miniforge3/envs/ml1/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 4: 100%|██████████| 3/3 [00:00<00:00, 71.56it/s, v_num=0, train_loss_step=1.240, train_acc_step=1.000, val_loss=2.310, val_acc=0.100, train_loss_epoch=1.270, train_acc_epoch=1.000] 

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 3/3 [00:00<00:00, 60.81it/s, v_num=0, train_loss_step=1.240, train_acc_step=1.000, val_loss=2.310, val_acc=0.100, train_loss_epoch=1.270, train_acc_epoch=1.000]


In [2]:
# Load the TensorBoard notebook extension, then open TensorBoard on the "lightning_logs" directory used by the logger in the training cell.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/