In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import Dataset

In [2]:
# Hyperparameter
batch_size = 64
learning_rate = 0.001
num_epochs = 1

In [3]:

# Laden des FashionMNIST-Datensatzes
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = FashionMNIST(root='./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [9]:
class ListDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]

In [20]:
def get_filtered_dataloader(dataset, classes):
    filtered_list = [
        i for i, (x, y) in enumerate(dataset) if y in classes
    ]
    return torch.utils.data.Subset(dataset, filtered_list)
    
        

In [21]:
filtered_dataloader = DataLoader(get_filtered_dataloader(train_dataset, [1, 2, 3]),  batch_size=batch_size, shuffle=True)

In [13]:

# LeNet Modell Definition
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(2450, 120) #800
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2, stride=2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2, stride=2)
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [22]:

# Modellinitialisierung
model_all = LeNet()

# Verlustfunktion und Optimierer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [31]:



def train_model(model, train_dataloader):
    # Trainingsschleife
    total_step = len(train_dataloader)#len(train_loader)
    # Verlustfunktion und Optimierer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_dataloader):

        #for i, (images, labels) in enumerate(train_loader):
            # Vorwärtsdurchlauf
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Rückwärtsdurchlauf und Optimierung
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Schritt [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
                
    return model

def test_model(model, test_dataloader):
    # Testen des Modells
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Genauigkeit des Modells auf Testdaten: {} %'.format(100 * correct / total))


In [35]:
model_0123 = LeNet()
model_456 = LeNet()
model_789 = LeNet()

dataloader_0123 = DataLoader(get_filtered_dataloader(train_dataset, [0, 1, 2, 3]), batch_size=batch_size, shuffle=True)
dataloader_456 = DataLoader(get_filtered_dataloader(train_dataset, [4, 5, 6]),  batch_size=batch_size, shuffle=True)
dataloader_789 = DataLoader(get_filtered_dataloader(train_dataset, [7, 8, 9]),  batch_size=batch_size, shuffle=True)


In [36]:
model_0123 = train_model(model_0123, dataloader_0123)
model_456 = train_model(model_456, dataloader_456)
model_789 = train_model(model_789, dataloader_789)


Epoch [1/1], Schritt [100/375], Loss: 0.3041
Epoch [1/1], Schritt [200/375], Loss: 0.1666
Epoch [1/1], Schritt [300/375], Loss: 0.2774
Epoch [1/1], Schritt [100/282], Loss: 0.2294
Epoch [1/1], Schritt [200/282], Loss: 0.3087
Epoch [1/1], Schritt [100/282], Loss: 0.1628
Epoch [1/1], Schritt [200/282], Loss: 0.0979


In [37]:
test_model(model_0123, test_loader)
test_model(model_456, test_loader)
test_model(model_789, test_loader)

Genauigkeit des Modells auf Testdaten: 37.64 %
Genauigkeit des Modells auf Testdaten: 26.88 %
Genauigkeit des Modells auf Testdaten: 28.93 %


In [45]:
params1 = model_0123.state_dict()
params2 = model_456.state_dict()

params_mean = dict()
for param_name in params1.keys():
    params_mean[param_name] = torch.add(params1[param_name], params2[param_name]) / 2

In [46]:
model_avg = LeNet()
model_avg.load_state_dict(params_mean)

<All keys matched successfully>

In [47]:
test_model(model_avg, test_loader)

Genauigkeit des Modells auf Testdaten: 21.82 %


In [57]:
def get_average_model(models):
    state_dicts = [model.state_dict() for model in models]
    state_dict_average = dict()
    for i in range(len(state_dicts) - 1):
        params1 = state_dicts[i]
        params2 = state_dicts[i + 1]
        for param_name in params1.keys():
            state_dict_average[param_name] = torch.add(params1[param_name], params2[param_name])
            
    for param_name in state_dict_average.keys():
        state_dict_average[param_name] = state_dict_average[param_name] / len(state_dicts)
        
    model_avg = LeNet()
    model_avg.load_state_dict(state_dict_average)
    
    return model_avg

In [50]:
model_avg = get_average_model([model_0123, model_456, model_789])
test_model(model_avg, test_loader)

Genauigkeit des Modells auf Testdaten: 21.82 %


In [58]:
model_0123 = LeNet()
model_456 = LeNet()
model_789 = LeNet()

In [None]:
avgs = list()

for i in range(3):
    print(i)
    model_0123 = train_model(model_0123, dataloader_0123)
    model_456 = train_model(model_456, dataloader_456)
    model_789 = train_model(model_789, dataloader_789)
    
    model_avg = get_average_model([model_0123, model_456, model_789])
    avgs.append(model_avg.state_dict())
    test_model(model_0123, test_loader)
    test_model(model_456, test_loader)
    test_model(model_789, test_loader)

    test_model(model_avg, test_loader)
    
    model_0123.load_state_dict(model_avg.state_dict())
    model_456.load_state_dict(model_avg.state_dict())
    model_789.load_state_dict(model_avg.state_dict())

    

0
Epoch [1/1], Schritt [100/375], Loss: 0.1850
Epoch [1/1], Schritt [200/375], Loss: 0.1855
Epoch [1/1], Schritt [300/375], Loss: 0.0938
Epoch [1/1], Schritt [100/282], Loss: 0.2129
Epoch [1/1], Schritt [200/282], Loss: 0.4150
Epoch [1/1], Schritt [100/282], Loss: 0.2017
Epoch [1/1], Schritt [200/282], Loss: 0.0948
Genauigkeit des Modells auf Testdaten: 37.75 %
Genauigkeit des Modells auf Testdaten: 27.01 %
Genauigkeit des Modells auf Testdaten: 29.01 %
Genauigkeit des Modells auf Testdaten: 8.6 %
1
Epoch [1/1], Schritt [100/375], Loss: 0.3075
Epoch [1/1], Schritt [200/375], Loss: 0.2571
Epoch [1/1], Schritt [300/375], Loss: 0.1237
Epoch [1/1], Schritt [100/282], Loss: 0.2176
Epoch [1/1], Schritt [200/282], Loss: 0.1953
Epoch [1/1], Schritt [100/282], Loss: 0.2335
Epoch [1/1], Schritt [200/282], Loss: 0.2013
Genauigkeit des Modells auf Testdaten: 36.19 %
Genauigkeit des Modells auf Testdaten: 25.23 %
Genauigkeit des Modells auf Testdaten: 28.41 %
Genauigkeit des Modells auf Testdaten: 

In [56]:
for param_name in avgs[0]:
    for i in range(len(avgs) - 1):
        print(torch.sum(avgs[i][param_name] - avgs[i + 1][param_name] ))

tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
