[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](
https://colab.research.google.com/github/CMU-IDeeL/CMU-IDeeL.github.io/blob/master/F25/document/Recitation_0_Series/0.23/0_23_Distributed_Training.ipynb)

# DataParallel Notebook
-------------------------
Inspired by: [PyTorch Data Parallelism Tutorial](https://docs.pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html)

# Imports and Initial Setup
------------------------------------
This cell imports the necessary PyTorch libraries.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

In [None]:
print(f"PyTorch Version: {torch.__version__}")
print("-" * 30)

# Check for available GPUs
if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Found {num_gpus} GPUs.")
    # Set the primary device
    device = torch.device("cuda:0")
else:
    num_gpus = 0
    print("No GPUs found. Running on CPU.")
    device = torch.device("cpu")

PyTorch Version: 2.6.0+cu124
------------------------------
Found 2 GPUs.


# Data Parallel
-----------------------------
Source: [DataParallel vs. DistributedDataParallel in PyTorch: What’s the Difference?](https://medium.com/@mlshark/dataparallel-vs-distributeddataparallel-in-pytorch-whats-the-difference-0af10bb43bc7)

# Define a Simple Model
-----------------------------
We'll create a basic neural network for this demonstration.
DataParallel will replicate this model on each available GPU.

In [None]:
class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, output_size)

    def forward(self, x, debug=False):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        if debug:
            print("\tInside the Model: input size", x.size(), "output size", out.size())
        return out
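
Because every GPU receives a full copy of the model, it can help to know how large one copy is. The sketch below counts the parameters of the two linear layers by hand, using the sizes from this notebook (784 inputs, 128 hidden units, 10 outputs). The 4 bytes per parameter figure assumes the default float32 dtype, and the variable names are chosen here for illustration only.

```python
# Layer sizes taken from SimpleModel and the hyperparameters in this notebook
input_size, hidden_size, output_size = 784, 128, 10

# Each nn.Linear layer holds a weight matrix plus one bias per output unit
fc1_params = input_size * hidden_size + hidden_size
fc2_params = hidden_size * output_size + output_size
total_params = fc1_params + fc2_params

# Assumption: parameters are stored as float32 (4 bytes each)
bytes_per_replica = total_params * 4

print(f"fc1: {fc1_params}, fc2: {fc2_params}, total: {total_params}")
print(f"Approximate size of one replica: {bytes_per_replica / 1024:.1f} KiB")
```

Inside the notebook, the same total should be obtainable with sum(p.numel() for p in model.parameters()).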

# Data Preparation and Training Loop
------------------------------------------
This is the main part where we wrap our model with DataParallel
and run the training process.

## 1. Hyperparameters and Data

In [None]:
input_size = 784
output_size = 10
batch_size = 256  # A larger batch size helps utilize multiple GPUs
learning_rate = 0.01
num_epochs = 20

# Create dummy data
# We create a dataset of 10000 samples
inputs = torch.randn(10000, input_size)
targets = torch.randint(0, output_size, (10000,))

# Use DataLoader for batching
dataset = TensorDataset(inputs, targets)
# DataParallel splits each batch along its first dimension. With 2 GPUs,
# each GPU processes batch_size / 2 = 128 samples per step.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
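
The batch arithmetic can be checked without any GPU. The sketch below computes how many batches the loader yields per epoch and how a batch would be divided between devices. It assumes that the split behaves like torch.chunk (near-equal chunks, with the last one possibly smaller) and that 2 GPUs are present; per_gpu_chunks is a hypothetical helper written for this illustration, not a PyTorch function.

```python
import math

num_samples = 10000  # dataset size used in this notebook
batch_size = 256     # batch size used in this notebook
num_gpus = 2         # assumption: a 2-GPU machine

# DataLoader keeps the final partial batch by default (drop_last=False)
batches_per_epoch = math.ceil(num_samples / batch_size)
last_batch_size = num_samples - (batches_per_epoch - 1) * batch_size


def per_gpu_chunks(n, k):
    """Sizes of the chunks when n samples are split into at most k chunks.

    Assumption: mirrors torch.chunk, where every chunk has ceil(n / k)
    samples except possibly the last one.
    """
    chunk = math.ceil(n / k)
    sizes = []
    remaining = n
    while remaining > 0:
        take = min(chunk, remaining)
        sizes.append(take)
        remaining -= take
    return sizes


print("Batches per epoch:", batches_per_epoch)
print("Size of the final batch:", last_batch_size)
print("Full batch split:", per_gpu_chunks(batch_size, num_gpus))
print("Final batch split:", per_gpu_chunks(last_batch_size, num_gpus))
```

Under these assumptions the final batch of each epoch is much smaller than the others, so each GPU does far less work on that step.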

## 2. Initialize and Wrap the Model
Instantiate the model, then wrap it in nn.DataParallel if more than one GPU is available.

In [None]:
model = SimpleModel(input_size, output_size)

# IMPORTANT: Wrap the model with nn.DataParallel
# This is the key step for data parallelism.
# If multiple GPUs are available, this wrapper will handle the data distribution.
if num_gpus > 1:
    print(f"Using {num_gpus} GPUs for training!")
    model = nn.DataParallel(model)
else:
    print("Training on a single device (CPU or 1 GPU).")

# Move the model to the primary device. On each forward pass, DataParallel
# replicates it from there onto the other GPUs.
model.to(device)

Using 2 GPUs for training!


DataParallel(
  (module): SimpleModel(
    (fc1): Linear(in_features=784, out_features=128, bias=True)
    (relu): ReLU()
    (fc2): Linear(in_features=128, out_features=10, bias=True)
  )
)
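
The wrapper's forward pass can be pictured as three steps: scatter the batch, run each chunk through a model copy, and gather the outputs. The sketch below imitates those steps on the CPU so it runs without GPUs. It is not DataParallel's actual implementation: a single model stands in for all replicas, the chunks run one after another rather than in parallel, and the small layer sizes and the name num_replicas are assumptions made for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in model (hypothetical sizes, CPU only)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
batch = torch.randn(6, 8)
num_replicas = 2  # pretend two GPUs are available

# 1. Scatter: split the batch along dim 0, one chunk per replica
chunks = torch.chunk(batch, num_replicas, dim=0)
chunk_shapes = [tuple(c.shape) for c in chunks]

# 2. Apply: each "replica" processes its own chunk
#    (here one shared model is called sequentially)
chunk_outputs = [model(chunk) for chunk in chunks]

# 3. Gather: concatenate the per-chunk outputs back into one batch
gathered = torch.cat(chunk_outputs, dim=0)

# The gathered result should match a single forward pass on the full batch
full = model(batch)
same_output = torch.allclose(gathered, full, atol=1e-5)

print("Chunk shapes:", chunk_shapes)
print("Gathered shape:", tuple(gathered.shape))
print("Matches full-batch forward:", same_output)
```

This is why code outside the model sees the full batch size while code inside forward sees only its chunk.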

## 3. Loss and Optimizer

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

## 4. Training Loop

In [None]:
print("\nStarting training...")
for epoch in range(num_epochs):
    total_loss = 0
    for i, (batch_inputs, batch_targets) in enumerate(data_loader):
        # Move data to the primary device. DataParallel will scatter it.
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Forward pass
        # DataParallel automatically splits the batch, sends it to the GPUs,
        # executes the forward pass, and gathers the outputs on the primary device.
        debug = epoch == 0 and i == 0
        outputs = model(batch_inputs, debug=debug)
        if debug:
            print("Outside: input size", batch_inputs.size(), "output_size", outputs.size())
        loss = criterion(outputs, batch_targets)

        # Backward and optimize
        # The loss is computed on the primary GPU. The backward pass computes
        # gradients on each GPU replica and sums them into the original
        # model's parameters on the primary GPU.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

print("\nTraining finished!")




Starting training...
	Inside the Model: input size torch.Size([128, 784]) output size torch.Size([128, 10])
	Inside the Model: input size torch.Size([128, 784]) output size torch.Size([128, 10])
Outside: input size torch.Size([256, 784]) output_size torch.Size([256, 10])
Epoch [1/20], Loss: 2.3282
Epoch [2/20], Loss: 2.3176
Epoch [3/20], Loss: 2.3090
Epoch [4/20], Loss: 2.3036
Epoch [5/20], Loss: 2.2976
Epoch [6/20], Loss: 2.2932
Epoch [7/20], Loss: 2.2872
Epoch [8/20], Loss: 2.2796
Epoch [9/20], Loss: 2.2780
Epoch [10/20], Loss: 2.2728
Epoch [11/20], Loss: 2.2654
Epoch [12/20], Loss: 2.2600
Epoch [13/20], Loss: 2.2561
Epoch [14/20], Loss: 2.2523
Epoch [15/20], Loss: 2.2479
Epoch [16/20], Loss: 2.2409
Epoch [17/20], Loss: 2.2359
Epoch [18/20], Loss: 2.2329
Epoch [19/20], Loss: 2.2275
Epoch [20/20], Loss: 2.2236

Training finished!
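
The loss values above stay close to 2.30 for a reason that can be worked out by hand. The targets are random integers unrelated to the inputs, so a model that assigns equal probability to all 10 classes has a cross-entropy of ln(10). The sketch below computes that baseline in plain Python; the cross_entropy helper is written for this illustration and is not the PyTorch function.

```python
import math

num_classes = 10  # output_size in this notebook


def cross_entropy(logits, target):
    """Cross-entropy of one sample: log-sum-exp of the logits minus the target logit."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Equal logits give a uniform prediction over the classes
uniform_loss = cross_entropy([0.0] * num_classes, target=3)

print(f"Chance-level loss: {uniform_loss:.4f}")
print(f"ln({num_classes}) = {math.log(num_classes):.4f}")
```

The slow drift below this baseline over 20 epochs most likely reflects the model memorizing the fixed random dataset rather than learning a real pattern.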


## 5. Accessing the Original Model
To save the model's state dict or access the original model
without the DataParallel wrapper, use the wrapper's .module attribute.

In [None]:
if isinstance(model, nn.DataParallel):
    # Unwrap so the saved keys carry no "module." prefix
    original_model = model.module
    print("\nModel was wrapped in DataParallel. Accessing the original model via .module")
else:
    original_model = model
    print("\nModel was not wrapped. Saving the model directly.")

# Save once, from the unwrapped model in both cases
torch.save(original_model.state_dict(), 'model_state.pth')


Model was wrapped in DataParallel. Accessing the original model via .module
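
If a state dict was instead saved from the wrapper itself, its keys start with "module." because the original model lives in the wrapper's module attribute. The sketch below shows one way to rename such keys so they load into an unwrapped model. strip_module_prefix is a hypothetical helper, and the placeholder strings stand in for real tensors.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with the prefix removed from matching keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Keys as they would appear when saved from the DataParallel wrapper
# (values are placeholders, not real tensors)
wrapped_state = {
    "module.fc1.weight": "w1",
    "module.fc1.bias": "b1",
    "module.fc2.weight": "w2",
    "module.fc2.bias": "b2",
}

clean_state = strip_module_prefix(wrapped_state)
print(list(clean_state))
```

Keys that lack the prefix are left unchanged, so the helper is safe to apply to a state dict saved from an unwrapped model as well.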
