In [None]:
import os
import time
import copy
import glob
import torch
import shutil
import pathlib
import torchvision

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.models as models
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from PIL import Image

In [None]:
class ImageFolderCustom(Dataset):
    """Dataset over the ``*.png`` files of a single directory.

    Colour and class label are read from the file name, whose first two
    dash-separated fields must be integers: ``<color>-<label>-....png``.
    A first field of 0 means colour 0 (red); anything else means colour 1
    (blue).

    Args:
        targ_dir: directory that directly contains the ``*.png`` files.
        transform: callable applied to the PIL image, or a falsy value to
            return the PIL image unchanged.
        epsilon: when truthy, the label is passed through
            ``GRR_Client(label, 10, epsilon)`` before being returned;
            otherwise the label from the file name is returned as is.
    """

    def __init__(self, targ_dir, transform, epsilon=None):
        # Sorted so that an index refers to the same file on every run
        # (glob order depends on the file system).
        self.paths = sorted(pathlib.Path(targ_dir).glob("*.png"))
        self.transform = transform
        self.epsilon = epsilon

    @staticmethod
    def _parse_file_name(image_path):
        """Return ``(color, label)`` encoded in the file name of ``image_path``.

        Only the final path component is inspected, so the result does not
        depend on how deep the directory is or on the path separator.
        """
        fields = pathlib.Path(image_path).name.split("-")
        color = 0 if int(fields[0]) == 0 else 1
        label = int(fields[1])
        return color, label

    def load_image(self, index):
        """Open image ``index`` and return ``(PIL image, color, label)``.

        The returned label is the randomised one when ``self.epsilon`` is set.
        """
        image_path = self.paths[index]
        color, label = self._parse_file_name(image_path)

        if self.epsilon:
            # NOTE(review): the label is re-randomised on every access, so the
            # same image can get different noisy labels across epochs --
            # confirm that this is intended.
            random_label = GRR_Client(label, 10, self.epsilon)
        else:
            random_label = label

        return Image.open(image_path), color, random_label

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, class_idx, color)``; color is 0 for red, 1 for blue."""
        img, color, class_idx = self.load_image(index)

        # Transform if necessary
        if self.transform:
            return self.transform(img), class_idx, color
        return img, class_idx, color

In [None]:
def GRR_Client(input_data, k, epsilon):
    """Randomise one value with generalised randomized response.

    With probability ``exp(epsilon) / (exp(epsilon) + k - 1)`` the input is
    returned unchanged; otherwise one of the other values in ``0 .. k-1`` is
    drawn uniformly with ``np.random.choice``.

    Args:
        input_data: the true value, compared against ``np.arange(k)``.
        k: size of the domain.
        epsilon: privacy parameter controlling the keep probability.

    Returns:
        ``input_data`` itself, or a NumPy integer different from it.
    """
    exp_eps = np.exp(epsilon)
    keep_probability = exp_eps / (exp_eps + k - 1)
    candidates = np.arange(k)

    # One Bernoulli draw decides whether the true value is reported.
    keep = np.random.binomial(1, keep_probability) == 1
    if not keep:
        return np.random.choice(candidates[candidates != input_data])
    return input_data

In [None]:
def get_p(epsilon, k=10):
    """Build the ``k x k`` label transition matrix of k-ary randomized response.

    Entry ``[i][j]`` is the probability that true label ``i`` is reported as
    ``j``: ``exp(epsilon) / (exp(epsilon) + k - 1)`` on the diagonal and
    ``1 / (exp(epsilon) + k - 1)`` elsewhere, so every row sums to 1. These
    are the same probabilities ``GRR_Client`` uses to randomise labels.

    Args:
        epsilon: privacy parameter of the randomisation.
        k: number of classes (defaults to 10).

    Returns:
        A ``torch`` tensor of shape ``(k, k)`` in the default float dtype.
    """
    exp_eps = np.exp(epsilon)
    # k - 1 alternatives share the "change" mass; the previous 1 + exp(eps)
    # denominator is only correct for a two-class problem.
    denominator = exp_eps + k - 1
    prob_stay = float(exp_eps / denominator)
    prob_change = float(1.0 / denominator)

    p = torch.zeros(k, k)
    p.fill_(prob_change)
    p.fill_diagonal_(prob_stay)
    return p

In [None]:
def set_parameter_requires_grad(model, feature_extracting=True):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    Every tensor yielded by ``model.parameters()`` gets ``requires_grad``
    set to ``False``. With a falsy flag the model is left untouched.
    """
    if not feature_extracting:
        return
    for weights in model.parameters():
        weights.requires_grad = False

In [None]:
def train_model(model, dataloaders, criterion, optimizer, device, epsilon, num_epochs=100, is_train=True):
    """Train ``model`` and save one checkpoint per epoch.

    Args:
        model: network producing raw class scores (logits) of shape
            ``(batch, classes)``.
        dataloaders: a single DataLoader yielding ``(inputs, labels, group)``
            batches; the third element is ignored here.
        criterion: loss taking ``(log_probabilities, labels)``, e.g.
            ``nn.NLLLoss``.
        optimizer: optimizer over the trainable parameters of ``model``.
        device: device on which the model and the batches are placed.
        epsilon: when truthy, the labels are assumed to be randomised and the
            loss is computed on the predicted distribution pushed through the
            transition matrix ``get_p(epsilon)`` (forward loss correction).
            When falsy, the plain log-softmax output is used.
        num_epochs: number of passes over ``dataloaders``.
        is_train: currently unused; kept for interface compatibility.

    Returns:
        ``(acc_history, loss_history)``: per-epoch accuracy against the labels
        served by the loader, and per-epoch mean loss, both as lists of floats.

    Side effects:
        Writes ``working/cimnist/resnet/<epoch>.pth`` after every epoch,
        creating the directory if needed.
    """
    checkpoint_dir = 'working/cimnist/resnet/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    acc_history = []
    loss_history = []

    # Loop-invariant: move the transition matrix and the model only once.
    transition = get_p(epsilon).to(device) if epsilon else None
    model.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels, _ in dataloaders:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            if transition is not None:
                # Forward correction: map clean-class probabilities to
                # noisy-label probabilities first, then take the log.
                noisy_prob = torch.matmul(torch.softmax(outputs, dim=1), transition)
                log_prob = torch.log(noisy_prob)
            else:
                # The criterion expects log-probabilities, not probabilities.
                log_prob = torch.log_softmax(outputs, dim=1)

            loss = criterion(log_prob, labels)

            # backward
            loss.backward()
            optimizer.step()

            preds = outputs.argmax(dim=1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        dataset_size = len(dataloaders.dataset)
        acc_history.append(running_corrects / dataset_size)
        loss_history.append(running_loss / dataset_size)

        torch.save(model.state_dict(), os.path.join(checkpoint_dir, '{0:0=2d}.pth'.format(epoch)))

    return acc_history, loss_history
    

In [None]:
def _score_checkpoint(model, dataloaders, device):
    """Count correct predictions and samples per colour for the loaded weights.

    Returns ``(correct_red, correct_blue, num_red, num_blue)`` as Python ints,
    where colour 0 is red and any other colour value is blue.
    """
    correct_red = correct_blue = num_red = num_blue = 0

    for inputs, labels, color in dataloaders:
        inputs = inputs.to(device)
        labels = labels.to(device)
        is_blue = torch.as_tensor(color).to(device) != 0

        with torch.no_grad():
            preds = model(inputs).argmax(dim=1)

        hits = preds == labels
        correct_blue += (hits & is_blue).sum().item()
        correct_red += (hits & ~is_blue).sum().item()
        num_blue += is_blue.sum().item()
        num_red += (~is_blue).sum().item()

    return correct_red, correct_blue, num_red, num_blue


def eval_model(model, dataloaders, device):
    """Evaluate every checkpoint in ``working/cimnist/resnet/`` and report the best.

    Each ``*.pth`` file is loaded into ``model`` (in sorted file-name order)
    and scored on ``dataloaders``. The checkpoint with the highest accuracy
    (the earliest one on ties) is loaded back into ``model`` before returning,
    and its overall and per-colour accuracies are printed.

    Args:
        model: network whose architecture matches the saved state dicts.
        dataloaders: a single DataLoader yielding ``(inputs, labels, color)``
            batches, with colour 0 for red and non-zero for blue.
        device: device used for the model and the batches.

    Returns:
        List with the accuracy (float) of every checkpoint, in load order.

    Raises:
        FileNotFoundError: if the checkpoint directory holds no ``*.pth`` file.

    NOTE(review): the best checkpoint is selected on the same data it is then
    reported on -- confirm this is the intended protocol.
    """
    since = time.time()

    saved_models = sorted(glob.glob('working/cimnist/resnet/' + '*.pth'))
    if not saved_models:
        raise FileNotFoundError("no '*.pth' checkpoints found in 'working/cimnist/resnet/'")

    acc_history = []
    # Start below any reachable accuracy so that a checkpoint is always
    # selected, even when every checkpoint scores 0.
    best_acc = -1.0
    best_model = None
    best_counts = None

    model.to(device)
    dataset_size = len(dataloaders.dataset)

    for model_path in saved_models:
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        counts = _score_checkpoint(model, dataloaders, device)
        epoch_acc = (counts[0] + counts[1]) / dataset_size

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model = model_path
            best_counts = counts

        acc_history.append(epoch_acc)

    # Leave the model holding the best weights for the caller.
    model.load_state_dict(torch.load(best_model, map_location=device))
    model.eval()
    model.to(device)

    correct_red, correct_blue, num_red, num_blue = best_counts
    # A colour group may be absent from the data; report nan instead of failing.
    red_acc = correct_red / num_red if num_red else float('nan')
    blue_acc = correct_blue / num_blue if num_blue else float('nan')

    time_elapsed = time.time() - since
    print('Validation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best Acc: {:.4f}'.format(best_acc))
    print('Red Acc: {:.4f}'.format(red_acc), 'Blue Acc: {:.4f}'.format(blue_acc))
    return acc_history

In [None]:
# Experiment driver: for every privacy level, train and evaluate a fresh
# ResNet-18 head five times.
epsilons = [None, .001, .01, .1, .25, .5, 1, 2, 5] #None gives original baseline

# Objects that do not change between runs are built once.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([transforms.ToTensor()])
criterion = nn.NLLLoss()

# train_model saves its per-epoch checkpoints here and eval_model reads them
# back; torch.save does not create missing directories.
os.makedirs('working/cimnist/resnet/', exist_ok=True)

# Histories of every run, keyed by (epsilon, round), so that later runs do not
# overwrite earlier results.
history_by_run = {}

for e in epsilons:
    for ro in range(5):
        print('========================================================')
        print('Epsilon ', str(e), 'Round ', ro+1)
        print('========================================================')

        # Only the training labels are randomised; the test labels stay clean.
        train_data = ImageFolderCustom(targ_dir='data/cimnist/train/',
                                       transform=transform,
                                       epsilon=e)

        test_data = ImageFolderCustom(targ_dir='data/cimnist/test/',
                                      transform=transform,
                                      epsilon=None)

        train_loader = torch.utils.data.DataLoader(train_data,
                                                   batch_size=256,
                                                   num_workers=8,
                                                   shuffle=True)

        # Evaluation only sums per-batch counts, so shuffling is unnecessary.
        test_loader = torch.utils.data.DataLoader(test_data,
                                                  batch_size=256,
                                                  num_workers=8,
                                                  shuffle=False)

        # Pretrained backbone with all weights frozen; only the new 10-way
        # classification layer is trained.
        resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        set_parameter_requires_grad(resnet18)
        resnet18.fc = nn.Linear(512, 10)

        params_to_update = [param for param in resnet18.parameters() if param.requires_grad]
        optimizer = optim.Adam(params_to_update, lr=.0001)

        # Train model
        train_acc_hist, train_loss_hist = train_model(resnet18, train_loader, criterion, optimizer, device, e)
        val_acc_hist = eval_model(resnet18, test_loader, device)

        history_by_run[(e, ro)] = {
            'train_acc': train_acc_hist,
            'train_loss': train_loss_hist,
            'val_acc': val_acc_hist,
        }