In [1]:
import torch
import argparse
import sys
import os

from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

torch.manual_seed(0)
sys.path.append("../../../../../")

from fedlab.core.client.scale import ScaleClientManager
from fedlab.core.network import DistNetwork
from fedlab.utils.serialization import SerializationTool
from fedlab.utils.logger import Logger
from fedlab.utils.aggregator import Aggregators
from fedlab.utils.functional import load_dict
from fedlab.utils.functional import AverageMeter

from fedlab_benchmarks.models.rnn import RNN_Shakespeare
from fedlab_benchmarks.datasets.leaf_data_process.dataloader import get_LEAF_dataset, get_LEAF_dataloader

In [2]:
def evaluate(model, criterion, test_loader):
    """Run one gradient-free pass of ``model`` over ``test_loader``.

    The model is switched to eval mode and stays in eval mode after the call.
    Each batch is moved to the device of the model's first parameter before
    the forward pass.

    Args:
        model: network to score; ``model.parameters()`` must yield at least
            one parameter, since the first one is used to find the device.
        criterion: callable invoked as ``criterion(outputs, labels)`` whose
            result supports ``.item()``.
        test_loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        A ``(loss, accuracy)`` tuple taken from the ``sum`` attribute of the
        loss meter and the ``avg`` attribute of the accuracy meter.
        NOTE(review): the first item is the meter's ``sum``, not its ``avg``,
        so it presumably grows with the number of batches; confirm against
        ``AverageMeter``.
    """
    model.eval()
    device = next(model.parameters()).device

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            loss_meter.update(criterion(logits, batch_labels).item())

            # Predicted class = index of the largest score along dim 1.
            predicted = torch.max(logits, dim=1).indices
            n_correct = (predicted == batch_labels).sum().item()
            acc_meter.update(n_correct, len(batch_labels))

    return loss_meter.sum, acc_meter.avg

In [3]:
# Obtain the train/test loaders returned by get_LEAF_dataloader for "shakespeare".
# NOTE(review): the second positional argument (1) is presumably the client id
# whose data shard is loaded; confirm against get_LEAF_dataloader's signature.
train_loader, test_loader = get_LEAF_dataloader("shakespeare",1)

In [11]:
# Train RNN_Shakespeare with plain SGD and report test metrics after each epoch.

# Use the GPU when one is present, otherwise fall back to CPU so the cell still
# runs on machines without CUDA (the hard-coded .cuda() calls would raise there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tunable hyper-parameters for this run.
NUM_EPOCHS = 10
LEARNING_RATE = 0.02

model = RNN_Shakespeare().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

# A named loop variable avoids shadowing IPython's `_` (last cell output).
for epoch in range(NUM_EPOCHS):
    # evaluate() switches the model to eval mode, so training mode has to be
    # re-enabled at the start of every epoch.
    model.train()
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)

        output = model(data)
        loss = criterion(output, target)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Prints the (loss, accuracy) tuple returned by evaluate() on the test loader.
    print(evaluate(model, criterion, test_loader))

(4.40276575088501, 0.17871222076215507)
(4.294654369354248, 0.17871222076215507)
(4.1673760414123535, 0.17871222076215507)
(4.001285076141357, 0.17871222076215507)
(3.7913155555725098, 0.17871222076215507)
(3.6132514476776123, 0.17871222076215507)
(3.4578545093536377, 0.17871222076215507)
(3.334017753601074, 0.17871222076215507)
(3.258016586303711, 0.17871222076215507)
(3.212472915649414, 0.17871222076215507)


In [7]:
# Pull the first (inputs, labels) batch from the test loader for inspection.
# next(iter(...)) raises StopIteration on an empty loader, whereas a for/break
# loop would silently leave `data`/`label` unset or stale from an earlier cell.
data, label = next(iter(test_loader))

In [8]:
# Display the inputs of the batch captured in the previous cell.
# NOTE(review): the integer entries are presumably character/token ids; confirm
# against the LEAF Shakespeare preprocessing.
data

tensor([[20, 23, 68,  ..., 45,  4, 68],
        [23, 68, 45,  ...,  4, 68, 44],
        [68, 45,  4,  ..., 68, 44, 64],
        ...,
        [58, 13, 72,  ..., 24,  1, 65],
        [13, 72, 13,  ...,  1, 65,  3],
        [72, 13,  0,  ..., 65,  3, 58]])

In [9]:
# Display the targets of the same batch (one integer per input row).
# NOTE(review): presumably the id of the next character for each row; confirm.
label

tensor([44, 64, 35, 24, 13,  0, 65, 24,  3,  2, 64, 63, 24, 68, 45, 64, 13, 65,
        24, 13, 41, 68,  4, 13, 24,  2, 68,  4,  4, 65, 24,  1, 16, 13, 65, 42,
        13, 65,  4, 13, 24, 66, 64,  2,  2, 13, 24, 23, 13, 24,  4, 45, 23, 44,
        21,  2, 69, 13, 63, 24, 13,  4,  1, 23, 68, 13, 24,  3, 64, 63, 22, 35,
        24,  4, 13, 23, 42, 58, 13, 72, 13, 25, 65,  2,  2, 13,  1, 64, 44, 20,
        64, 42, 23, 45,  4,  1, 13, 64, 63,  4, 13, 44, 23, 13, 42, 65, 24,  1,
        13, 23, 42, 13, 49, 23, 45,  4, 68, 44, 64, 35, 24, 13, 41, 68,  4,  4,
        35, 45, 65, 44, 21, 58, 13, 10, 45, 65,  4,  1, 64, 64, 16, 13, 63,  2,
         2, 23, 25, 13,  4,  1, 64, 13, 25, 65, 44,  0, 58, 13, 72, 44,  0, 64,
        64,  0, 16, 13, 24, 65, 45, 16, 13, 65, 42, 13, 69, 23, 68, 45, 13, 66,
        64,  4, 63,  3,  1, 23, 45, 13, 24,  4, 65, 44, 22, 16, 13, 72, 13, 25,
        65,  2,  2, 13, 24,  4, 23,  3, 13, 66, 69, 13, 44, 23, 24, 64, 39, 13,
        23, 45, 13, 63, 21, 63, 65, 44, 