In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import os
import argparse
from tqdm import tqdm
from models.resnet import ResNet50
from utils.reproducibility import make_it_reproducible, seed_worker
from utils.fedavg_utils import get_datasets

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Reproducibility: one integer seed, plus a dedicated torch.Generator seeded
# from it (handed to the DataLoaders below so their shuffling is repeatable).
seed = 0

# Generator.manual_seed returns the generator itself, so build and seed in one go.
g = torch.Generator().manual_seed(seed)

In [None]:
# Training length: number of full passes over `trainloader` in the epoch loop.
EPOCHS: int = 25

In [None]:
# Datasets come from utils.fedavg_utils.get_datasets; both loaders share the
# seeded generator `g` and the `seed_worker` init hook (presumably seeds each
# worker process -- defined in utils/reproducibility, not visible here).
# NOTE(review): because `g` is shared, each pass over `testloader` presumably
# also draws its worker base seed from `g`, so evaluation frequency affects the
# training shuffle sequence. This is still deterministic, but the two loaders are coupled.
trainset, testset = get_datasets()

loader_opts = dict(num_workers=2, worker_init_fn=seed_worker, generator=g)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, **loader_opts)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, **loader_opts)

In [None]:
# Seed torch's global RNG (CPU and CUDA) right before building the model so the
# initial weights are reproducible. The project helper make_it_reproducible(seed)
# is only called in a later cell, i.e. after the weights would already exist.
torch.manual_seed(seed)
net = ResNet50()
# Alternative variant: ResNet50("Group Norm") -- presumably swaps in GroupNorm
# for the default normalisation layers; TODO confirm in models/resnet.py.
# net = ResNet50("Group Norm")
net = net.to(device)

# train() switches the mode itself each epoch; trailing ';' suppresses the
# (very long) module repr that would otherwise be displayed as cell output.
net.train();

In [None]:
# Loss, optimiser and step-wise learning-rate schedule.
criterion = nn.CrossEntropyLoss()

# 1e-3 is the same float literal as the original 0.1e-2.
optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.5)

# MultiStepLR multiplies the LR by gamma=0.1 once scheduler.step() has been
# called 20 times, and again at 25.
# NOTE(review): the epoch loop calls scheduler.step() once per epoch, so with
# the current EPOCHS value (25) the second milestone only fires after the final
# epoch and never affects training.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=[20, 25],
    gamma=0.1,
)

In [None]:
# Apply the project's reproducibility helper (what exactly it seeds/configures
# lives in utils/reproducibility and is not visible here), then start fresh
# per-epoch metric logs that test() appends to.
make_it_reproducible(seed)

accuracies, losses = [], []

In [None]:
def train():
    """Run one optimisation pass of ``net`` over ``trainloader``.

    Uses the module-level ``net``, ``trainloader``, ``device``, ``criterion``
    and ``optimizer``. Puts ``net`` in training mode, then for every batch does
    zero_grad -> forward -> loss -> backward -> optimizer step.

    Returns:
        ``(mean_loss, accuracy)`` for the pass as Python floats. Both are
        computed from each batch's pre-step outputs. ``mean_loss`` is a
        per-sample average; ``accuracy`` is a fraction in [0, 1]. If
        ``trainloader`` yields no batches, both are ``nan`` (nothing is trained).
    """
    net.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = targets.size(0)
        # criterion (nn.CrossEntropyLoss() in this notebook) averages over the
        # batch by default; re-weight by batch size so a smaller final batch
        # does not skew the epoch mean.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()
        total += batch_size

    if total == 0:
        return float("nan"), float("nan")
    return loss_sum / total, correct / total

In [None]:
def test():
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

        test_loss = loss.item()
        _, predicted = outputs.max(1)
        total = targets.size(0)
        correct = predicted.eq(targets).sum().item()
        accuracies.append(correct/total)
        losses.append(test_loss)

In [None]:
# Each epoch: one optimisation pass, one evaluation pass (test() appends to
# `accuracies` / `losses`), then advance the epoch-based LR schedule.
epoch_iter = tqdm(range(EPOCHS))
for epoch in epoch_iter:
    train()
    test()
    scheduler.step()