In [None]:
# Install the notebook's dependencies into the running kernel's environment
# (%pip, unlike !pip, targets this kernel's interpreter).
# NOTE(review): versions are unpinned, so a fresh install may pull torch /
# torchvision / pyspark releases this notebook was never run against —
# consider pinning exact versions.
%pip install torch torchvision pyspark

In [8]:
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import os
from pyspark.sql import SparkSession
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms

In [9]:
# Rendezvous settings for torch.distributed (env:// initialization). Values that
# are already in the environment (e.g. exported by a multi-process launcher) take
# precedence; the literals below are only fallbacks for an interactive run.
os.environ.setdefault('MASTER_ADDR', 'hartford.cs.colostate.edu')
os.environ.setdefault('MASTER_PORT', '31220')  # This is for DDP, not Spark's master port
os.environ.setdefault('RANK', '0')
# This kernel is one process and nothing here launches other ranks, so default to
# a world of one; with a larger world size rank 0 would block inside
# init_process_group waiting for ranks that never join.
os.environ.setdefault('WORLD_SIZE', '1')

# Create a SparkSession
spark = SparkSession.builder.appName("PyTorch Distributed Training").getOrCreate()

# Initialize the default process group once; re-running this cell in the same
# kernel would otherwise try to create the default group a second time.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",
        rank=int(os.environ['RANK']),
        world_size=int(os.environ['WORLD_SIZE']),
    )


[W socket.cpp:426] [c10d] The server socket has failed to bind to [::]:31215 (errno: 98 - Address already in use).
[W socket.cpp:426] [c10d] The server socket has failed to bind to 0.0.0.0:31215 (errno: 98 - Address already in use).
[E socket.cpp:462] [c10d] The server socket has failed to listen on any local network address.


RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:31215 (errno: 98 - Address already in use). The server socket has failed to bind to 0.0.0.0:31215 (errno: 98 - Address already in use).

In [None]:
class LeNet5(nn.Module):
    """LeNet-5 convolutional classifier for single-channel 32x32 images.

    Three 5x5 convolutions with tanh activations (the first two each followed by
    2x2 average pooling) reduce a 32x32 input to 120 feature maps of size 1x1,
    which a two-layer tanh MLP maps to ``n_classes`` scores.

    Args:
        n_classes: Number of output classes.
    """

    def __init__(self, n_classes):
        super().__init__()

        # Spatial size per stage for a 32x32 input: 32 -> 28 -> 14 -> 10 -> 5 -> 1.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh(),
        )

        # 120 inputs = 120 channels x 1 x 1 left after the feature extractor.
        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Float tensor of shape (N, 1, 32, 32).

        Returns:
            ``(logits, probs)``: raw class scores of shape (N, n_classes) and
            their softmax over the class dimension. Train on ``logits`` (e.g.
            with ``cross_entropy``); ``probs`` is for inspection/reporting.
        """
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        # torch.softmax rather than F.softmax: torch.nn.functional is not
        # imported as F in this notebook.
        probs = torch.softmax(logits, dim=1)
        return logits, probs

In [None]:
# Preprocessing applied to every image: resize to 32x32 (the input size at which
# LeNet5's last 5x5 convolution yields 1x1 maps), then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Load the MNIST dataset
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
val_data = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Create RDDs of (image, label) pairs. ToTensor already returns channel-first
# (C, H, W) tensors (C == 1 for grayscale MNIST), so samples are stored as-is:
# an extra unsqueeze(0) would make every collated batch 5-D (N, 1, 1, H, W),
# which Conv2d does not accept.
# NOTE(review): both splits are materialized in driver memory before parallelize.
train_rdd = spark.sparkContext.parallelize(list(train_data))
val_rdd = spark.sparkContext.parallelize(list(val_data))


In [None]:
# Define params
batch_size = 32   # samples per mini-batch
epochs = 15       # passes over the training set
lr = 0.001        # SGD learning rate
N_CLASSES = 10    # MNIST digit classes
SEED = 0          # seed for torch's global RNG

# Seed torch's global RNG before the model is built so weight initialization
# (and any other torch randomness that follows) is reproducible across runs.
torch.manual_seed(SEED)

In [None]:
# Create DataLoaders from the RDDs. DataLoader needs a sized, indexable dataset
# (len() and [i]); a PySpark RDD provides neither, so the (image, label) pairs
# are collected back to the driver as plain lists first.
# NOTE(review): collect() pulls each full split into driver memory; with more
# than one DDP rank, a DistributedSampler would be needed to split the data.
train_loader = DataLoader(train_rdd.collect(), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_rdd.collect(), batch_size=batch_size)

In [None]:
# Build LeNet-5 and wrap it with DistributedDataParallel, then attach plain SGD.
# The wrapper exposes the wrapped module's own Parameter objects, so building the
# optimizer from it covers exactly the same parameters.
model = nn.parallel.DistributedDataParallel(LeNet5(N_CLASSES))
optimizer = optim.SGD(model.parameters(), lr=lr)

In [None]:
def _to_nchw(batch):
    """Return ``batch`` as a float32 (N, C, H, W) tensor, keeping the batch axis.

    Every axis between the batch axis and the two trailing spatial axes is merged
    into one channel axis, so (N, H, W), (N, 1, H, W) and (N, 1, 1, H, W) all
    become (N, 1, H, W). Unlike ``squeeze()``, this never drops the batch axis
    when N == 1 and never removes the channel axis the model's Conv2d expects.
    """
    return batch.reshape(batch.size(0), -1, *batch.shape[-2:]).to(torch.float32)


# Train the model
for epoch in range(epochs):
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        # The model returns (logits, probs); cross_entropy applies log-softmax
        # itself, so it takes the raw logits (nll_loss would need log-probs).
        logits, _ = model(_to_nchw(data))
        loss = nn.functional.cross_entropy(logits, target.to(torch.long))
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in val_loader:
            logits, _ = model(_to_nchw(data))
            predicted = logits.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target.to(torch.long)).sum().item()

    # Guard against an empty validation set instead of dividing by zero.
    accuracy = correct / total if total else float('nan')
    print("Epoch {} Accuracy: {}".format(epoch+1, accuracy))

# Stop the SparkSession
spark.stop()