<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Implementing_Mixed_Precision_Training_with_AMP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Pick the compute device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Minimal placeholder Generator and Discriminator (one linear layer each); swap in your own architectures as needed
class Generator(nn.Module):
    """Map latent noise vectors to flattened images with values in [-1, 1].

    The model is a single fully connected layer followed by ``tanh``.

    Args:
        latent_dim: Length of each input noise vector (default 100).
        img_dim: Length of each flattened output image (default 784 = 28 * 28).
    """

    def __init__(self, latent_dim: int = 100, img_dim: int = 784):
        super().__init__()
        # Attribute name kept as ``fc`` so existing state_dict keys still load.
        self.fc = nn.Linear(latent_dim, img_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(fc(x))``; ``x`` is expected to be ``(batch, latent_dim)``."""
        return torch.tanh(self.fc(x))

class Discriminator(nn.Module):
    """Score flattened images with the probability that each one is real.

    The model is a single fully connected layer followed by ``sigmoid``, so
    the output is a probability rather than a logit. Pair it with
    ``nn.BCELoss``. Under CUDA autocast that loss must be evaluated outside
    the autocast region, in float32.

    Args:
        img_dim: Length of each flattened input image (default 784 = 28 * 28).
    """

    def __init__(self, img_dim: int = 784):
        super().__init__()
        # Attribute name kept as ``fc`` so existing state_dict keys still load.
        self.fc = nn.Linear(img_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(fc(x))`` with shape ``(batch, 1)``."""
        return torch.sigmoid(self.fc(x))

# Build both networks directly on the target device. Generator is built first:
# nn.Linear initialises its weights randomly, so construction order matters
# for seeded reproducibility.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid output, plus one Adam
# optimizer per network sharing the same learning rate.
criterion = nn.BCELoss()
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=0.001) for net in (generator, discriminator)
)

# Convert images to tensors in [0, 1], then shift/scale to [-1, 1] so real
# samples live in the same range as the Generator's tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data on first use, in shuffled
# mini-batches of 64.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# Mixed precision (AMP) is enabled only when training on a CUDA device; on the
# CPU the scaler and autocast contexts below are no-ops and everything runs in
# float32. One scaler is shared by both optimizers, so scaler.update() must be
# called exactly once per iteration, after every optimizer has stepped.
use_amp = device.type == "cuda"
scaler = torch.amp.GradScaler(enabled=use_amp)

# Training loop with mixed precision
n_epochs = 10
for epoch in range(n_epochs):
    for real_data, _ in train_loader:
        # Flatten each image into a vector to match the models' Linear layers.
        real_data = real_data.view(real_data.size(0), -1).to(device)
        batch_size = real_data.size(0)
        # Latent size 100 must match the Generator's input layer.
        noise = torch.randn(batch_size, 100, device=device)
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ---- Train Discriminator ----
        # Clear gradients BEFORE backward: the generator step below backprops
        # through the discriminator and leaves gradients in its parameters,
        # which would otherwise leak into this update.
        optimizer_D.zero_grad()
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            fake_data = generator(noise)
            d_real = discriminator(real_data)
            # detach() keeps the discriminator loss from reaching the generator.
            d_fake = discriminator(fake_data.detach())
        # nn.BCELoss raises inside CUDA autocast regions (it is not safe in
        # float16), so the loss is evaluated in float32 outside the region.
        # Prefer BCEWithLogitsLoss if the Discriminator is changed to return logits.
        d_loss = criterion(d_real.float(), real_labels) + criterion(
            d_fake.float(), fake_labels
        )
        scaler.scale(d_loss).backward()
        scaler.step(optimizer_D)

        # ---- Train Generator ----
        optimizer_G.zero_grad()
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            # Reuse fake_data: the generator has not been updated since it was
            # produced, and its graph is intact because only a detached copy
            # was used above.
            d_on_fake = discriminator(fake_data)
        g_loss = criterion(d_on_fake.float(), real_labels)
        scaler.scale(g_loss).backward()
        scaler.step(optimizer_G)

        # Single scale update per iteration, after both optimizers stepped.
        scaler.update()

    # Reported losses are those of the last batch in the epoch.
    print(f"Epoch [{epoch + 1}/{n_epochs}]  Loss D: {d_loss.item()}, loss G: {g_loss.item()}")

print("Training completed!")