In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import time

In [13]:
input_size = 512
hidden_size = 1024
output_size = 10
num_samples = 10000
batch_size = 512
epochs = 5

learning_rate = 0.001
criterion = nn.CrossEntropyLoss()

In [14]:
X = torch.randn(num_samples, input_size)
y = torch.randint(0, output_size, (num_samples,))

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [15]:
class NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        return self.layers(x)


class ModelParallelNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers1 = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        ).to('cuda:0')

        self.layers2 = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        ).to('cuda:1')

    def forward(self, x):
        x = x.to('cuda:0')
        x = self.layers1(x)
        x = x.to('cuda:1')
        x = self.layers2(x)
        return x

In [24]:
def train_model(model, dataloader, device):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    start_time = time.time()

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        epoch_loss = 0
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        print(f"Loss: {epoch_loss / len(dataloader)}")
    return time.time() - start_time

In [25]:
def train_model_parallel(model, dataloader):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    start_time = time.time()

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        epoch_loss = 0
        for batch_x, batch_y in dataloader:
            optimizer.zero_grad()
            output = model(batch_x)
            batch_y = batch_y.to('cuda:1')
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Loss: {epoch_loss / len(dataloader)}")
    return time.time() - start_time

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

print("Training on a single GPU...")
model_single = NN()
time_single = train_model(model_single, dataloader, device)

Using device: cuda:0
Training on a single GPU...
Epoch 1/5
Loss: 2.3075669527053835
Epoch 2/5
Loss: 2.2551753759384154
Epoch 3/5
Loss: 2.014691323041916
Epoch 4/5
Loss: 1.5095293402671814
Epoch 5/5
Loss: 0.6876625180244446


In [27]:
print("Training with model parallelism...")
model_mp = ModelParallelNN()
time_mp = train_model_parallel(model_mp, dataloader)

Training with model parallelism...
Epoch 1/5
Loss: 2.3064645171165465
Epoch 2/5
Loss: 2.235544431209564
Epoch 3/5
Loss: 1.9499459862709045
Epoch 4/5
Loss: 1.3909866333007812
Epoch 5/5
Loss: 0.5930090069770813


In [28]:
print("Training with data parallelism...")
model_dp = NN()
model_dp = nn.DataParallel(model_dp)
time_dp = train_model(model_dp, dataloader, device)

Training with data parallelism...
Epoch 1/5
Loss: 2.3065056204795837
Epoch 2/5
Loss: 2.247606670856476
Epoch 3/5
Loss: 1.9925490617752075
Epoch 4/5
Loss: 1.4459463596343993
Epoch 5/5
Loss: 0.6214424446225166


In [30]:
print(f"Single GPU: {time_single:.2f} seconds")
print(f"Model Parallel: {time_mp:.2f} seconds")
print(f"Data Parallel: {time_dp:.2f} seconds")

Single GPU: 1.60 seconds
Model Parallel: 1.10 seconds
Data Parallel: 4.34 seconds
