In [1]:
%pip install torch torchvision pyspark

Defaulting to user installation because normal site-packages is not writeable
Collecting sparktorch
  Downloading sparktorch-0.2.0-py3-none-any.whl (25 kB)
Collecting flask
  Downloading Flask-2.2.3-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 3.2 MB/s eta 0:00:011
Collecting Werkzeug>=2.2.2
  Downloading Werkzeug-2.2.3-py3-none-any.whl (233 kB)
[K     |████████████████████████████████| 233 kB 161.4 MB/s eta 0:00:01
[?25hCollecting click>=8.0
  Downloading click-8.1.3-py3-none-any.whl (96 kB)
[K     |████████████████████████████████| 96 kB 2.0 MB/s s eta 0:00:01
[?25hCollecting itsdangerous>=2.0
  Downloading itsdangerous-2.1.2-py3-none-any.whl (15 kB)
Installing collected packages: Werkzeug, click, itsdangerous, flask, sparktorch
Successfully installed Werkzeug-2.2.3 click-8.1.3 flask-2.2.3 itsdangerous-2.1.2 sparktorch-0.2.0
Note: you may need to restart the kernel to use updated packages.


In [None]:
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pyspark.sql import SparkSession
from sparktorch import serialize_torch_obj, SparkTorch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

In [None]:
# torch.distributed "env://" rendezvous settings. setdefault lets a launcher
# (or the shell) set per-process values instead of every process claiming the
# same hard-coded rank.
os.environ.setdefault('MASTER_ADDR', 'hartford.cs.colostate.edu')
os.environ.setdefault('MASTER_PORT', '31220')  # This is for DDP, not Spark's master port
os.environ.setdefault('RANK', '0')
# With RANK fixed at 0, a world size > 1 makes init_process_group block until
# the other ranks join, and nothing in this notebook launches them. Default to
# a single process; export WORLD_SIZE and a distinct RANK per process for a
# real multi-process job.
os.environ.setdefault('WORLD_SIZE', '1')

# Create a SparkSession (getOrCreate reuses an existing session on re-run)
spark = SparkSession.builder.appName("PyTorch Distributed Training").getOrCreate()

# Initialize the default process group exactly once: calling
# init_process_group a second time raises, so guard against re-running this cell.
if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo",
        rank=int(os.environ['RANK']),
        world_size=int(os.environ['WORLD_SIZE']),
    )


In [None]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN: tanh convolutions with average pooling, followed by a
    two-layer fully connected classifier.

    The conv stack shrinks a 32x32 input to a 120x1x1 feature map
    (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5 -conv5-> 1), which is what
    the first ``Linear(120, 84)`` layer requires, so inputs must be 32x32.

    Args:
        n_classes: number of output classes.
        in_channels: number of input image channels (default 1).

    ``forward(x)`` takes a float tensor of shape (N, in_channels, 32, 32) and
    returns ``(logits, probs)``, both of shape (N, n_classes); ``probs`` is the
    softmax of ``logits`` over the class dimension.
    """

    def __init__(self, n_classes, in_channels=1):
        super().__init__()

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh()
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        # (N, 120, 1, 1) -> (N, 120)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        probs = F.softmax(logits, dim=1)
        return logits, probs

In [None]:
# Preprocessing pipeline: resize every image to 32x32 (the size at which the
# LeNet5 conv stack ends in a 120x1x1 feature map), then convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)

In [None]:
# MNIST train/validation splits (downloaded into ./data if missing), each
# sample passed through `transform`.
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
val_data = datasets.MNIST(root='./data', train=False, transform=transform, download=True)


def _dataset_to_rdd(dataset):
    """Materialize every ``(image, label)`` sample of ``dataset`` on the driver,
    prepend a leading dim to each image with ``unsqueeze(0)``, and distribute
    the resulting list as a Spark RDD."""
    pairs = [(image.unsqueeze(0), label) for image, label in dataset]
    return spark.sparkContext.parallelize(pairs)


# NOTE(review): this iterates the full dataset in the driver process and ships
# every tensor through Spark, which is slow and memory-hungry — confirm the
# RDDs are actually needed.
train_rdd = _dataset_to_rdd(train_data)
val_rdd = _dataset_to_rdd(val_data)


In [None]:
# Training hyper-parameters — tune here.
batch_size, epochs = 32, 15   # mini-batch size, number of passes over the training set
lr = 1e-3                     # SGD learning rate
N_CLASSES = 10                # size of the classifier's output layer

In [None]:
# Build DataLoaders directly from the torchvision datasets. A Spark RDD is not a
# torch Dataset (it has no __len__/__getitem__), so DataLoader cannot iterate it.
# DistributedSampler gives each rank a disjoint, per-epoch-shuffled shard of the
# training set; it reads rank/world size from the initialized process group.
train_sampler = DistributedSampler(train_data, shuffle=True)
train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
# Every rank evaluates the full validation set, in a fixed order.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

In [None]:
# Build the network and its SGD optimizer, then wrap the network for
# data-parallel training. DDP wraps the same Parameter objects, so an optimizer
# created from the bare module still updates the wrapped model.
net = LeNet5(N_CLASSES)
optimizer = optim.SGD(net.parameters(), lr=lr)
model = nn.parallel.DistributedDataParallel(net)

In [None]:
# Train the model
def _as_model_input(images):
    """Cast a batch to float32 and merge any extra leading dims into the batch
    dim, giving (N, C, H, W) — e.g. (N, 1, 1, 32, 32) -> (N, 1, 32, 32).

    Unlike ``squeeze()``, this keeps the channel dim and a batch of size 1.
    """
    images = images.to(torch.float32)
    return images.reshape(-1, *images.shape[-3:])


def _train_one_epoch(model, loader, optimizer):
    """Run one optimization pass over ``loader``."""
    model.train()
    for images, targets in loader:
        optimizer.zero_grad()
        # LeNet5 returns (logits, probs); the loss needs only the logits.
        logits, _ = model(_as_model_input(images))
        # cross_entropy applies log_softmax itself; nll_loss would expect
        # log-probabilities, not raw logits.
        loss = nn.functional.cross_entropy(logits, targets.to(torch.long))
        loss.backward()
        optimizer.step()


def _evaluate(model, loader):
    """Return the fraction of samples in ``loader`` classified correctly
    (0.0 when ``loader`` yields no samples)."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, targets in loader:
            logits, _ = model(_as_model_input(images))
            predicted = logits.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets.to(torch.long)).sum().item()
    return correct / total if total else 0.0


for epoch in range(epochs):
    # A DistributedSampler only reshuffles between epochs when told the epoch.
    sampler = getattr(train_loader, "sampler", None)
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)
    _train_one_epoch(model, train_loader, optimizer)
    accuracy = _evaluate(model, val_loader)
    print("Epoch {} Accuracy: {}".format(epoch + 1, accuracy))

# Release the DDP process group, then stop the SparkSession
if dist.is_initialized():
    dist.destroy_process_group()
spark.stop()