In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

In [16]:
"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")"""

'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")'

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:13<00:00, 712kB/s] 


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 131kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.25MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.24MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [4]:
subset1, subset2 = random_split(train_dataset, [600, len(train_dataset) - 600])

In [5]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [6]:
def average_model_parameters(models, weights):
    avg_params = []
    for params in zip(*[list(model.parameters()) for model in models]):
        avg_params.append(sum(weight * param.data for weight, param in zip(weights, params)))
    return avg_params

In [7]:
def update_model_parameters(model, avg_params):
    for param, avg_param in zip(model.parameters(), avg_params):
        param.data.copy_(avg_param)

In [9]:
def train_model(model, dataloader, epochs=5):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for data, target in dataloader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

In [17]:
model1 = CNN()
model2 = CNN()
dataloader1 = DataLoader(subset1, batch_size=50, shuffle=True)
dataloader2 = DataLoader(subset2, batch_size=50, shuffle=True)

In [18]:
train_model(model1, dataloader1)
train_model(model2, dataloader2)

In [20]:
avg_params = average_model_parameters([model1, model2], [0.5, 0.5])
update_model_parameters(model1, avg_params)
update_model_parameters(model2, avg_params)

In [22]:
correct = 0
model1.eval()
test_loader = DataLoader(datasets.MNIST(root='./data', train=False, download=True, transform=transform), batch_size=1000)

with torch.no_grad():
    for data, target in test_loader:
        data, target = data, target
        output = model1(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = correct / len(test_loader.dataset)
print(f'Accuracy after averaging: {accuracy:.4f}')

Accuracy after averaging: 0.8018


In [23]:
train_model(model1, dataloader1)
train_model(model2, dataloader2)

avg_params = average_model_parameters([model1, model2], [0.5, 0.5])
update_model_parameters(model1, avg_params)
update_model_parameters(model2, avg_params)

correct = 0
model1.eval()
with torch.no_grad():
    for data, target in test_loader:
        data, target = data, target
        output = model1(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

accuracy = correct / len(test_loader.dataset)
print(f'Accuracy after initializing: {accuracy:.4f}')

Accuracy after initializing: 0.9918


In [24]:
batch_sizes = [50, 25, 10, 5]
for batch_size in batch_sizes:
    dataloader1 = DataLoader(subset1, batch_size=batch_size, shuffle=True)
    dataloader2 = DataLoader(subset2, batch_size=batch_size, shuffle=True)

    train_model(model1, dataloader1)
    train_model(model2, dataloader2)

    avg_params = average_model_parameters([model1, model2], [0.5, 0.5])
    update_model_parameters(model1, avg_params)
    update_model_parameters(model2, avg_params)

    correct = 0
    model1.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data, target
            output = model1(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = correct / len(test_loader.dataset)
    print(f'Batch size {batch_size}, accuracy: {accuracy:.4f}')

Batch size 50, accuracy: 0.9923
Batch size 25, accuracy: 0.9935
Batch size 10, accuracy: 0.9930
Batch size 5, accuracy: 0.9931
