# EfficientNet Training with Mixed Precision

FP16 mixed-precision training of EfficientNet models (via `timm`) on NVIDIA GPUs, using PyTorch automatic mixed precision (`autocast` + `GradScaler`).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Rishav-raj-github/End-to-End-Computer-Vision-Pipeline-EfficientNet-on-NVIDIA-GPUs/blob/main/colab_notebooks/02_EfficientNet_Training_Colab.ipynb)

In [None]:
# Install the packages listed below; -q keeps pip's output quiet.
!pip install -q timm torch torchvision tensorboard

In [None]:
import torch
import torch.nn as nn
import timm
from torch.cuda.amp import autocast, GradScaler

# Report whether PyTorch can see a CUDA device and, if so, how much memory
# device 0 has in total (decimal gigabytes).
cuda_ready = torch.cuda.is_available()
print(f'CUDA Available: {cuda_ready}')
if cuda_ready:
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f'GPU Memory: {total_bytes / 1e9:.2f}GB')

In [None]:
class EfficientNetTrainer:
    """Train and evaluate a timm image classifier with CUDA mixed precision.

    Mixed precision (autocast plus gradient scaling) is used only when the
    trainer is placed on a CUDA device and CUDA is available; otherwise both
    are disabled and training runs in full precision.
    """

    def __init__(self, model_name='efficientnet_b0', num_classes=10, device='cuda'):
        """Build the model, the gradient scaler, and the loss.

        Args:
            model_name: Architecture name passed to ``timm.create_model``.
            num_classes: Number of output classes passed to
                ``timm.create_model``.
            device: Device the model and every batch are moved to.
        """
        self.device = device
        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
        self.model = self.model.to(device)
        # Decide once whether AMP applies, so autocast/GradScaler are switched
        # off explicitly on CPU instead of disabling themselves with warnings.
        self.use_amp = torch.cuda.is_available() and torch.device(device).type == 'cuda'
        self.scaler = GradScaler(enabled=self.use_amp)
        self.criterion = nn.CrossEntropyLoss()

    def train_step(self, batch, optimizer, epoch):
        """Run one optimisation step on a single batch.

        Args:
            batch: ``(inputs, targets)`` pair; both are moved to ``self.device``.
            optimizer: Optimizer whose gradients are zeroed and which is
                stepped through the gradient scaler.
            epoch: Unused; kept so existing callers keep working.

        Returns:
            The batch loss as a Python float.
        """
        # validate() leaves the model in eval mode; switch back so layers such
        # as BatchNorm and Dropout behave correctly while training.
        self.model.train()

        x, y = batch
        x, y = x.to(self.device), y.to(self.device)

        optimizer.zero_grad()

        with autocast(enabled=self.use_amp):
            logits = self.model(x)
            loss = self.criterion(logits, y)

        # Scale the loss before backward; the scaler unscales the gradients
        # before the optimizer step and then updates its scale.
        self.scaler.scale(loss).backward()
        self.scaler.step(optimizer)
        self.scaler.update()

        return loss.item()

    def validate(self, loader):
        """Compute top-1 accuracy over ``loader``.

        The model is put in eval mode and is left in eval mode afterwards.

        Args:
            loader: Iterable yielding ``(inputs, targets)`` batches.

        Returns:
            Fraction of correctly classified samples, or ``0.0`` when the
            loader yields no samples.
        """
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                with autocast(enabled=self.use_amp):
                    logits = self.model(x)
                correct += (logits.argmax(1) == y).sum().item()
                total += y.size(0)
        # An empty loader would otherwise raise ZeroDivisionError.
        if total == 0:
            return 0.0
        return correct / total

In [None]:
# Initialize trainer on the GPU when CUDA is available; otherwise fall back to
# the CPU so this cell does not fail while moving the model to a missing device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = EfficientNetTrainer(model_name='efficientnet_b0', num_classes=10, device=device)
param_count = sum(p.numel() for p in trainer.model.parameters())
print(f'Model parameters: {param_count / 1e6:.2f}M')