In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pytorch_lightning as pl
from torchmetrics import Accuracy
from pytorch_lightning.loggers import TensorBoardLogger

# Enable Tensor Core optimization: allow float32 matmuls to run at reduced
# internal precision. 'medium' is the fastest / least precise setting.
torch.set_float32_matmul_precision('medium')  # 'high' is more precise but slower; 'highest' is the full-fp32 default

# Create TensorBoard logger (save directory "lightning_logs", experiment name "my_model")
logger = TensorBoardLogger("lightning_logs", name="my_model")

class CustomDataset(Dataset):
    """Map-style dataset pairing indexable samples with indexable labels.

    Args:
        data: Indexable, sized container of samples (e.g. a tensor or list).
        labels: Indexable, sized container of labels, one per sample.
        transform: Optional callable applied to each sample on access.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, transform=None):
        # Fail early instead of raising IndexError (or silently ignoring
        # surplus labels) halfway through an epoch.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` at ``idx``, with ``transform`` applied to the sample."""
        sample = self.data[idx]
        label = self.labels[idx]

        # Compare against None rather than relying on truthiness: a callable
        # that defines __len__/__bool__ may evaluate falsy and would be skipped.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label
class HeavyWorkloadModel(pl.LightningModule):
    """Deliberately large CNN classifier used to generate a heavy GPU workload.

    The network is three convolutional blocks (two 3x3 conv + BatchNorm + ReLU
    layers each, with 64, 128 and 256 channels) followed by a deep fully
    connected head (1024 -> 512 -> 256 -> 128 -> ``num_classes``).

    Inputs are image batches shaped ``(batch, in_channels, H, W)``. The size
    comments below assume 28x28 images; the adaptive pooling fixes the feature
    map at 4x4 before the fully connected head.

    Args:
        learning_rate: Learning rate passed to the Adam optimizer.
        num_classes: Number of target classes (width of the output layer and
            class count of the accuracy metrics).
        in_channels: Number of channels in the input images.
    """

    def __init__(self, learning_rate=0.001, num_classes=10, in_channels=1):
        super().__init__()
        self.learning_rate = learning_rate

        # Much larger and more complex network.
        # The helpers return flat lists that are unpacked into one Sequential,
        # so module indices (and therefore state_dict keys) stay flat.
        self.conv_layers = nn.Sequential(
            # First convolutional block
            *self._conv_bn_relu(in_channels, 64),
            *self._conv_bn_relu(64, 64),
            nn.MaxPool2d(2, 2),  # 28x28 -> 14x14
            nn.Dropout2d(0.25),

            # Second convolutional block
            *self._conv_bn_relu(64, 128),
            *self._conv_bn_relu(128, 128),
            nn.MaxPool2d(2, 2),  # 14x14 -> 7x7
            nn.Dropout2d(0.25),

            # Third convolutional block
            *self._conv_bn_relu(128, 256),
            *self._conv_bn_relu(256, 256),
            nn.AdaptiveAvgPool2d((4, 4)),  # adaptive pooling to 4x4
            nn.Dropout2d(0.5),
        )

        # Very deep fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            *self._linear_bn_relu_dropout(256 * 4 * 4, 1024, 0.5),
            *self._linear_bn_relu_dropout(1024, 512, 0.5),
            *self._linear_bn_relu_dropout(512, 256, 0.3),
            *self._linear_bn_relu_dropout(256, 128, 0.3),
            nn.Linear(128, num_classes),
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Metrics (separate instances so train/val state never mixes)
        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

        # Save hyperparameters (constructor arguments) to self.hparams
        self.save_hyperparameters()

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch):
        """Return ``[Conv2d(3x3, padding=1), BatchNorm2d, ReLU]`` as a flat list."""
        return [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]

    @staticmethod
    def _linear_bn_relu_dropout(in_features, out_features, drop_p):
        """Return ``[Linear, BatchNorm1d, ReLU, Dropout(drop_p)]`` as a flat list."""
        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
            nn.Dropout(drop_p),
        ]

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for an image batch."""
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

    def _shared_step(self, batch, metric, stage, on_step):
        """Compute loss, update ``metric`` and log ``{stage}_loss`` / ``{stage}_acc``.

        The metric object itself is logged (not the per-batch tensor) so that
        Lightning calls ``compute()`` and ``reset()`` at epoch boundaries.
        """
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)

        # Calculate accuracy: updates the metric's running state
        preds = torch.argmax(outputs, dim=1)
        metric(preds, labels)

        # Log metrics
        self.log(f'{stage}_loss', loss, on_step=on_step, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_acc', metric, on_step=on_step, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Training step: returns the loss; logs per-step and per-epoch metrics."""
        return self._shared_step(batch, self.train_accuracy, 'train', on_step=True)

    def validation_step(self, batch, batch_idx):
        """Validation step: returns the loss; logs per-epoch metrics only."""
        return self._shared_step(batch, self.val_accuracy, 'val', on_step=False)

    def configure_optimizers(self):
        """Configure optimizer: Adam over all parameters at ``self.learning_rate``."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

  _C._set_float32_matmul_precision(precision)


In [2]:
# Seed Python, NumPy and torch RNGs so the synthetic data, weight
# initialisation and shuffling are reproducible on Restart & Run All.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Synthetic dataset, sized to keep the GPU busy.
# NOTE: labels are drawn at random, so there is nothing to learn here and
# accuracy should stay near chance (~10%); this run is a workload benchmark.
NUM_TRAIN = 10_000
NUM_VAL = 2_000
BATCH_SIZE = 32

X_train = torch.rand(NUM_TRAIN, 1, 28, 28)
y_train = torch.randint(0, 10, (NUM_TRAIN,))

X_val = torch.rand(NUM_VAL, 1, 28, 28)
y_val = torch.randint(0, 10, (NUM_VAL,))

# Create datasets (one shared per-sample transform mapping [0, 1] -> [-1, 1])
normalize = transforms.Normalize((0.5,), (0.5,))
train_dataset = CustomDataset(X_train, y_train, transform=normalize)
val_dataset = CustomDataset(X_val, y_val, transform=normalize)

# Create dataloaders (num_workers left at its default of 0; the data is already in memory)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Initialize model and trainer
model = HeavyWorkloadModel(learning_rate=0.001)
trainer = pl.Trainer(
    max_epochs=20,
    accelerator='gpu',      # requires a CUDA device
    devices=1,
    precision='16-mixed',   # fp16 automatic mixed precision
    logger=logger,
    log_every_n_steps=1,    # log every optimizer step
    enable_progress_bar=True,
    enable_model_summary=True
)

Using 16bit Automatic Mixed Precision (AMP)
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [3]:
# Train the model with validation
# Fits `model` on `train_loader` and runs validation on `val_loader`.
# NOTE(review): relies on `trainer`, `model` and both loaders from the cells
# above; after editing the model class, re-run those cells (ideally
# Restart Kernel -> Run All) so `model` is not built from a stale definition.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/office/miniforge3/envs/ml1/lib/python3.12/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | conv_layers    | Sequential         | 1.1 M  | train
1 | fc_layers      | Sequential         | 4.9 M  | train
2 | criterion      | CrossEntropyLoss   | 0      | train
3 | train_accuracy | MulticlassAccuracy | 0      | train
4 | val_accuracy   | MulticlassAccuracy | 0      | train
--------------------------------------------------------------
6.0 M     Trainable params
0         Non-trainable params
6.0 M     Total params
24.142    Total estimated model params size (MB)
47        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/office/miniforge3/envs/ml1/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


AttributeError: 'HeavyWorkloadModel' object has no attribute 'flatten'

In [None]:
# Load the TensorBoard notebook extension and open the dashboard inline,
# pointed at the "lightning_logs" directory passed to TensorBoardLogger.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/