# Mixed-Precision Training in SELM

This notebook demonstrates mixed-precision training with the SELM model. Mixed-precision training speeds up computation and reduces memory usage by running selected operations in half precision (FP16), while keeping numerically sensitive operations in full precision (FP32) to preserve model accuracy.
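To make the trade-off concrete, the sketch below rounds a few values to FP16 using only the Python standard library (the `e` format of `struct` is IEEE 754 half precision). The helper name `to_fp16` and the sample values are illustrative assumptions, not part of SELM or PyTorch.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# FP16 has an 11-bit significand, so not every integer above 2048 is representable.
print(to_fp16(2049.0))        # 2048.0

# Values far below the smallest FP16 subnormal (about 6e-8) round to zero.
tiny_gradient = 1e-8          # illustrative value
print(to_fp16(tiny_gradient)) # 0.0

# Values beyond the FP16 maximum (65504) cannot be packed at all.
try:
    to_fp16(70000.0)
    overflowed = False
except OverflowError:
    overflowed = True
print(overflowed)             # True
```

Reduced precision, underflow, and overflow are the reasons some operations are kept in FP32.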

We use PyTorch’s `torch.cuda.amp` module for automatic mixed precision (AMP) and compare training time and peak GPU memory against standard FP32 training. Recent PyTorch releases expose the same utilities under `torch.amp`.

### Import Necessary Libraries

In [None]:
import torch
from torch.cuda import amp
import torch.optim as optim
from torch.utils.data import DataLoader
from src.model.transformer import SELMTransformer
from src.tasks.text_classification import TextClassificationDataset
import time

### Load Dataset and Initialize Model

We will load a text classification dataset and initialize the SELM model for training.

In [None]:
# Load the dataset
train_dataset = TextClassificationDataset('data/processed/train_data.csv')
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Initialize the SELM model
model = SELMTransformer(config_path='config/model_config.yaml')
model = model.to('cuda')  # Move model to GPU

### Define Optimizer and Loss Function

We’ll define the optimizer (Adam) and the loss function (cross-entropy) for training.

In [None]:
# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

### Standard Training Loop (FP32 Precision)

We’ll first define a standard training loop in full precision (FP32) as a baseline for the mixed-precision loop.

In [None]:
# Standard training loop (FP32)
def train_fp32():
    model.train()
    torch.cuda.synchronize()  # Make sure no earlier GPU work leaks into the timing
    start_time = time.time()
    for batch in train_loader:
        inputs, labels = batch
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them to finish
    return time.time() - start_time

# Run standard FP32 training
fp32_time = train_fp32()
print(f"FP32 Training Time: {fp32_time:.2f} seconds")
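GPU timings are easy to skew: CUDA kernels launch asynchronously, and the first pass often includes one-time setup costs. A minimal sketch of a timing helper that addresses both is shown below. The name `timed`, the `warmup` parameter, and the toy matrix workload are assumptions for illustration. The helper falls back to CPU when no GPU is present.

```python
import time
import torch

def timed(fn, warmup=1):
    """Return the wall-clock seconds for one call of fn(), after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # finish queued GPU work before starting the clock
    start = time.perf_counter()
    fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()   # wait for the timed work to actually complete
    return time.perf_counter() - start

# Toy workload standing in for one training pass
device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(256, 256, device=device)
elapsed = timed(lambda: a @ a)
print(f"{elapsed:.6f} seconds")
```

Passing a training function as `fn` would run an extra, untimed pass over the data for each warm-up call, so warming up on a few batches may be preferable for long epochs.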

### Mixed-Precision Training Loop (FP16 Precision)

Next, we define the mixed-precision training loop using `torch.cuda.amp` for automatic mixed precision. This allows some operations to run in FP16, while others remain in FP32 to preserve accuracy.
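Before the full loop, here is a small sketch of what autocasting does to dtypes. It uses `torch.autocast`, the device-generic form of the same context manager, with a toy `Linear` layer standing in for the model. The fallback to bfloat16 on CPU is an assumption made only so the snippet runs without a GPU.

```python
import torch

# Use FP16 on a CUDA GPU; fall back to bfloat16 on CPU so the sketch still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
low_precision = torch.float16 if device == "cuda" else torch.bfloat16

layer = torch.nn.Linear(8, 4).to(device)   # toy layer, not the SELM model
x = torch.randn(2, 8, device=device)

with torch.autocast(device_type=device, dtype=low_precision):
    out = layer(x)  # the matrix multiply runs in the low-precision dtype

print(layer.weight.dtype)  # parameters are still stored in FP32
print(out.dtype)           # the activation comes out in low precision
```

The parameters themselves are not converted. Only the computation inside the context is cast, which is why the optimizer keeps updating FP32 weights.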

In [None]:
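# Mixed-precision training loop (FP16)
scaler = amp.GradScaler()  # Scales the loss so small FP16 gradients do not underflow

def train_fp16():
    model.train()
    torch.cuda.synchronize()  # Make sure no earlier GPU work leaks into the timing
    start_time = time.time()
    for batch in train_loader:
        inputs, labels = batch
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()

        # Run the forward pass and loss under autocast for mixed precision
        with amp.autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Backpropagate the scaled loss
        scaler.scale(loss).backward()
        # Unscale gradients and step; the step is skipped if they contain inf/NaN
        scaler.step(optimizer)
        # Adjust the scale factor for the next iteration
        scaler.update()
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them to finish
    return time.time() - start_time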

# Run mixed-precision FP16 training
fp16_time = train_fp16()
print(f"FP16 Mixed-Precision Training Time: {fp16_time:.2f} seconds")
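The reason the loop above scales the loss can be shown with plain arithmetic. The sketch below is illustrative rather than PyTorch's implementation: it emulates FP16 rounding with the standard library, and the gradient value is made up. The scale of 65536 matches the default `init_scale` of `GradScaler`.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

true_gradient = 1e-8   # illustrative gradient, too small for FP16
scale = 65536.0        # power-of-two scale factor (2**16)

unscaled = to_fp16(true_gradient)        # underflows to zero in FP16
scaled = to_fp16(true_gradient * scale)  # large enough to survive in FP16
recovered = scaled / scale               # unscaling happens in higher precision

print(unscaled)   # 0.0
print(recovered)  # close to 1e-8
```

Scaling the loss multiplies every gradient by the same factor through the chain rule, so dividing by that factor before the optimizer step recovers the intended update.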

### Compare Memory Usage

To compare memory usage between FP32 and FP16 training, we reset PyTorch’s peak-memory counter, rerun each training loop, and read the peak GPU memory allocated during that run.

In [None]:
# Function to get memory usage on GPU
def get_gpu_memory():
    return torch.cuda.max_memory_allocated() / (1024 ** 2)  # Convert to MB

# Get memory usage for FP32 training
torch.cuda.reset_peak_memory_stats()
train_fp32()
fp32_memory = get_gpu_memory()
print(f"FP32 Memory Usage: {fp32_memory:.2f} MB")

# Get memory usage for FP16 training
torch.cuda.reset_peak_memory_stats()
train_fp16()
fp16_memory = get_gpu_memory()
print(f"FP16 Memory Usage: {fp16_memory:.2f} MB")
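When reading these numbers, it helps to know which part of the footprint mixed precision can shrink. Under autocast, the weights, their gradients, and the Adam state all stay in FP32, so savings come mainly from activations. The back-of-the-envelope sketch below estimates that fixed FP32 portion. The function name and the parameter count are hypothetical, not SELM's actual size.

```python
def training_state_mib(num_params: int) -> dict:
    """Estimate memory for FP32 weights, gradients, and Adam state, in MiB."""
    bytes_per_fp32 = 4
    mib = 1024 ** 2
    weights = num_params * bytes_per_fp32 / mib
    grads = num_params * bytes_per_fp32 / mib
    adam_state = 2 * num_params * bytes_per_fp32 / mib  # first and second moment estimates
    return {
        "weights": weights,
        "grads": grads,
        "adam_state": adam_state,
        "total": weights + grads + adam_state,
    }

estimate = training_state_mib(100_000_000)  # hypothetical parameter count
print(f"Fixed FP32 training state: {estimate['total']:.1f} MiB")
```

If the measured FP16 peak is only slightly lower than the FP32 peak, this fixed portion is likely dominating, and larger batch sizes or longer sequences should make the difference more visible.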

### Performance Results

Let’s compare the training time and memory usage for both standard FP32 training and mixed-precision FP16 training.

In [None]:
# Print out performance comparison
print(f"FP32 Training Time: {fp32_time:.2f} seconds, FP32 Memory: {fp32_memory:.2f} MB")
print(f"FP16 Mixed-Precision Training Time: {fp16_time:.2f} seconds, FP16 Memory: {fp16_memory:.2f} MB")
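Ratios are often easier to interpret than raw numbers. A small helper such as the one sketched below turns the four measurements into a speedup and a memory reduction. The function name `summarize` is an assumption, and the values passed to it are placeholders for illustration, not measured results.

```python
def summarize(fp32_time, fp16_time, fp32_memory, fp16_memory):
    """Return (speedup factor, percent reduction in peak memory)."""
    speedup = fp32_time / fp16_time
    memory_saved_pct = 100.0 * (1.0 - fp16_memory / fp32_memory)
    return speedup, memory_saved_pct

# Placeholder numbers for illustration only; pass the measured variables instead
speedup, saved = summarize(10.0, 5.0, 4000.0, 3000.0)
print(f"Speedup: {speedup:.2f}x, peak memory reduced by {saved:.1f}%")
```

In the notebook, the call would be `summarize(fp32_time, fp16_time, fp32_memory, fp16_memory)`.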

### Conclusion

Mixed-precision training can improve both training speed and memory utilization. By running operations in FP16 where appropriate, training can finish faster and use less memory while maintaining the accuracy of the SELM model. The size of the gain depends on the GPU (in particular its FP16 hardware support), the model size, and the batch size.

In this notebook, we demonstrated how to integrate mixed-precision training into the SELM model using PyTorch’s `torch.cuda.amp` for automatic mixed precision. The training times and peak memory figures printed above show how much mixed precision helps on your hardware. Because this notebook measures only speed and memory, compare a validation metric between the two runs to confirm that model accuracy is preserved.