In [1]:
from dotenv import load_dotenv
load_dotenv()

from argparse import ArgumentParser
import warnings
from collections import OrderedDict
import json
import io
import os
import sys
import pickle
import base64
from traceback import print_exc

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data.sampler as sampler
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm
from time import sleep, time

import logging

from non_iid_generator.customDataset import CustomDataset

# Torch device string for training (e.g. "cuda", "cuda:0", "cpu"), read from the
# TORCH_DEVICE environment variable (which load_dotenv() above may populate from .env).
# Falls back to CUDA when available, otherwise CPU, instead of raising KeyError when unset.
DEVICE = os.environ.get("TORCH_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
def load_data(train_dataset_path, test_dataset_path, batch_size=128):
    """Load pickled train/test datasets and wrap them in DataLoaders.

    Each path must point to a pickle of a map-style dataset that
    ``torch.utils.data.DataLoader`` accepts (presumably a ``CustomDataset``
    from ``non_iid_generator`` — confirm). Unpickling can execute arbitrary
    code, so only load files you trust.

    Args:
        train_dataset_path: path to the pickled training set.
        test_dataset_path: path to the pickled test set.
        batch_size: mini-batch size for both loaders (default 128).

    Returns:
        ``(train_loader, test_loader)``. The training loader reshuffles every
        epoch; the test loader iterates in a fixed order.
    """
    def _unpickle(path):
        # `with` closes the file handle as soon as the object is read.
        with open(path, "rb") as fh:
            return pickle.load(fh)

    train_loader = torch.utils.data.DataLoader(
        _unpickle(train_dataset_path),
        batch_size=batch_size,
        shuffle=True)

    # Evaluation does not benefit from shuffling; a fixed order keeps runs
    # comparable (same choice as load_data_cifar).
    test_loader = torch.utils.data.DataLoader(
        _unpickle(test_dataset_path),
        batch_size=batch_size,
        shuffle=False)

    return train_loader, test_loader

def fine_tune(model, no_epoch, train_loader, print_frequency=100,
              num_classes=10, lr=0.001, momentum=0.9, weight_decay=1e-4,
              device=None):
    '''
        Short-term fine-tune a (simplified) model with SGD.

        Input:
            `model`: model to be fine-tuned; called as `model(inputs)` and
                expected to return one logit per class.
            `no_epoch`: (int) number of passes over `train_loader`.
            `train_loader`: iterable of `(inputs, target)` batches where
                `target` holds integer class labels, shape (B,) or (B, 1).
            `print_frequency`: (int) print an epoch banner every this many
                epochs; 0 disables the banner.
            `num_classes`: (int) width of the one-hot target encoding.
            `lr`, `momentum`, `weight_decay`: SGD hyper-parameters.
            `device`: device to train on; defaults to the module-level `DEVICE`.

        Output:
            `model`: the fine-tuned model (moved to `device`, left in train mode).
    '''
    if device is None:
        device = DEVICE

    # Move the model first so the optimizer is built over the on-device parameters.
    model = model.to(device)
    model.train()

    # BCE over one-hot targets treats each class as an independent sigmoid output.
    # NOTE(review): CrossEntropyLoss is the usual choice for single-label
    # classification; kept as-is to preserve the existing training recipe.
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr,
        momentum=momentum,
        weight_decay=weight_decay)

    print("Fine tuning started.")
    for epoch in range(no_epoch):
        if print_frequency and epoch % print_frequency == 0:
            print('Fine-tuning Epoch {}'.format(epoch))
            sys.stdout.flush()
        for inputs, target in tqdm(train_loader):
            inputs = inputs.to(device)
            # Flatten (B, 1) label tensors to (B,); one_hot requires int64 indices.
            target = target.reshape(-1).long().to(device)
            target_onehot = F.one_hot(target, num_classes).float()

            pred = model(inputs)
            loss = criterion(pred, target_onehot)
            optimizer.zero_grad()
            loss.backward()  # compute gradient and do SGD step
            optimizer.step()

    return model

######################
######################

def load_data_cifar(dataset_path="./data/", batch_size=128, image_size=224):
    """Build CIFAR-10 train/test DataLoaders from torchvision, downloading if needed.

    Images are resized to `image_size` (224 by default, presumably to match an
    ImageNet-style backbone such as the AlexNet loaded later — confirm) and
    normalized with fixed per-channel mean/std.

    Args:
        dataset_path: root directory for the torchvision CIFAR10 files.
        batch_size: mini-batch size for both loaders (default 128).
        image_size: size argument passed to `transforms.Resize` (default 224).

    Returns:
        `(train_loader, test_loader)`. The training loader shuffles and applies
        random-crop / horizontal-flip augmentation; the test loader iterates the
        official test split in a fixed order.
    """
    # One Normalize shared by both splits so train/test preprocessing cannot drift apart.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # augment at native 32x32, then upscale
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(root=dataset_path, train=True, download=True,
                                     transform=train_transform)
    test_dataset = datasets.CIFAR10(root=dataset_path, train=False, download=True,
                                    transform=test_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True)
    return train_loader, test_loader
    

In [3]:
# Load the pickled train/test partition that belongs to a single client.
client_id = 0

train_dataset_path = f"./data/Cifar10/train/{client_id}.pkl"
test_dataset_path = f"./data/Cifar10/test/{client_id}.pkl"
trainloader, testloader = load_data(
    train_dataset_path=train_dataset_path,
    test_dataset_path=test_dataset_path,
)

In [4]:
# Full torchvision CIFAR-10 train/test loaders; the training loader is used by the
# second fine-tuning run below.
cifartrainloader, cifartestloader = load_data_cifar(dataset_path="./data/")

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# The checkpoint appears to be a whole pickled nn.Module (it is used directly as a
# model below), not a state_dict. torch>=2.6 defaults to weights_only=True, which
# rejects such files, so opt out explicitly. Unpickling can run arbitrary code:
# only load checkpoints you trust. map_location keeps loading device-independent;
# fine_tune moves the model to the training device itself.
MODEL_PATH = "models/alexnet/model_cpu.pth.tar"
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)

In [6]:
# Fine-tune on this client's partition for 5 epochs. Binding the result keeps the
# lineage explicit and stops Jupyter from echoing the whole model repr as output.
model = fine_tune(model, 5, trainloader, print_frequency=1)

Fine tuning started.
Fine-tuning Epoch 0


 43%|████▎     | 6/14 [00:05<00:06,  1.14it/s]


KeyboardInterrupt: 

In [7]:
# Fine-tune on full CIFAR-10 for 5 epochs.
# NOTE: `model` still carries any updates made by the previous fine-tuning cell;
# reload the checkpoint first if an independent run is intended.
model = fine_tune(model, 5, cifartrainloader, print_frequency=1)


Fine tuning started.
Fine-tuning Epoch 0


 42%|████▏     | 166/391 [02:24<03:19,  1.13it/s]

KeyboardInterrupt: 

In [3]:
# Server-side train/test split. Note: this rebinds trainloader/testloader, replacing
# the client loaders created earlier.
train_dataset_path = "./data/Cifar10/server/train.pkl"
test_dataset_path = "./data/Cifar10/server/test.pkl"
loaders = load_data(train_dataset_path, test_dataset_path)
trainloader, testloader = loaders

In [2]:
import os
import pickle
import torch
from non_iid_generator.customDataset import CustomDataset


# Training partition "1" of the 32_Cifar10_NIID_56c_a03 split.
# NOTE(review): the directory name suggests 56 clients with a non-IID alpha of 0.3;
# confirm against the generator config.
train_dataset_path = os.path.join("./data/32_Cifar10_NIID_56c_a03", "train", "1.pkl")

# `with` closes the file handle once the object is read. Unpickling can execute
# arbitrary code, so only load partitions you generated yourself; the pickle
# presumably holds a CustomDataset, hence the import above.
with open(train_dataset_path, "rb") as fh:
    train_data = pickle.load(fh)

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=128,
    shuffle=True)

In [None]:
# Number of samples in the partition loaded by the previous cell.
len(train_data)