In [11]:
import torch
import torch.nn.functional as F
from torch import nn
from utils.fmodule import FModule, get_module_from_model
import utils.fmodule as fmodule

def init_weights(m):
    """
    Initialise a single sub-module; intended to be passed to ``Module.apply``.

    Only ``nn.Linear`` layers are touched: the weight gets Xavier-uniform
    initialisation and the bias (when the layer has one) is set to 0.01.
    Every other module type is left unchanged.

    Args:
        m: the module currently visited by ``apply``.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # nn.Linear(..., bias=False) stores ``bias`` as None; skip it instead
        # of failing with an AttributeError.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)
        
class DNN_proposal(FModule):
    """
    Small fully-connected classifier whose logits are filtered by class masks.

    The network has three parts:
      * encoder  (``fc1`` + sigmoid): input -> representation of size ``mid_dim``
      * decoder  (``fc2``): representation -> logits of size ``output_dim``
      * mask regenerator (``mg_fc1``/``mg_fc2``): representation -> a softmax
        vector of size ``output_dim`` that is thresholded into a 0/1 mask
        diagonal in ``forward``.

    The number of classes is always read from ``self.mg_fc2.out_features`` so
    that the model works for any ``output_dim``, not only 10.
    """

    def __init__(self, input_dim = 784, mid_dim = 100, output_dim = 10):
        super().__init__()
        # define network layers
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.fc2 = nn.Linear(mid_dim, output_dim)
        # mask regenerator
        self.mg_fc1 = nn.Linear(mid_dim, 128)
        self.mg_fc2 = nn.Linear(128, output_dim)
        self.apply(init_weights)

    def __call__(self, x, original_mask_diagonal=None):
        """
        Forward both arguments straight to ``forward``.

        NOTE(review): if FModule derives from ``nn.Module`` this override
        bypasses the hooks run by ``nn.Module.__call__`` -- confirm that this
        is intended.
        """
        return self.forward(x, original_mask_diagonal)

    def forward(self, x, original_mask_diagonal=None):
        """
        Compute the masked logits of ``x``.

        Args:
            x: input batch; flattened from dim 1 onwards by the encoder.
            original_mask_diagonal: optional vector of length ``output_dim``.
                When given (training) the logits are first multiplied by the
                diagonal matrix built from it; when ``None`` (inference) the
                raw decoder logits are used.

        Returns:
            Tensor of shape b x output_dim: the logits after applying the
            regenerated 0/1 mask. No softmax / log-softmax is applied.
        """
        # The representation is detached, so gradients of the returned logits
        # do not reach the encoder (fc1).
        r_x = self.encoder(x).detach()
        l_x = self.decoder(r_x)

        # The regenerated mask is detached and thresholded into a 0/1 float
        # vector. The threshold 1/output_dim equals the original 1/10 for the
        # default output_dim of 10.
        dm_x = self.mask_diagonal_regenerator(r_x).detach()
        dm_x = (dm_x > 1.0 / self.mg_fc2.out_features) * 1.
        m_x = torch.diag_embed(dm_x)

        if original_mask_diagonal is None:
            # inference: use the decoder logits as they are
            suro_l_x = l_x
        else:
            # training: apply the caller-supplied mask first
            suro_l_x = self.surogate_logits(l_x, torch.diag_embed(original_mask_diagonal))

        return self.mirror_surogate_logits(suro_l_x, m_x)

    def mask_diagonal(self, x):
        """
        Return the (un-thresholded) mask diagonal vector of ``x``.

        The representation is detached, so gradients of the result only flow
        into the mask regenerator layers.
        """
        r_x = self.encoder(x).detach()
        dm_x = self.mask_diagonal_regenerator(r_x)
        return dm_x

    def encoder(self, x):
        """
        Return the representation of ``x``: flatten from dim 1, apply ``fc1``
        and a sigmoid.
        """
        r_x = torch.flatten(x, 1)
        r_x = torch.sigmoid(self.fc1(r_x))
        return r_x

    def decoder(self, r_x):
        """
        Return the logits of the representation ``r_x`` (a single linear layer).
        """
        l_x = self.fc2(r_x)
        return l_x

    def mask_diagonal_regenerator(self, r_x):
        """
        Generate a mask diagonal vector for each element in ``r_x``.

        Returns:
            Tensor of shape b x output_dim whose rows are softmax outputs.
        """
        dm_x = F.relu(self.mg_fc1(r_x))
        dm_x = torch.softmax(self.mg_fc2(dm_x), dim=1)
        # Use the layer's own width instead of a hard-coded 10 so that a
        # non-default output_dim does not break the reshape.
        dm_x = dm_x.view(r_x.shape[0], self.mg_fc2.out_features)
        return dm_x

    def surogate_logits(self, l_x, original_mask):
        """
        Apply ``original_mask`` to the logits.

        Args:
            l_x             : b x output_dim
            original_mask   : output_dim x output_dim (broadcast over the batch)

        Returns:
            The masked logits, shape b x output_dim.
        """
        l_x = l_x.unsqueeze(2)
        suro_l_x = (original_mask * 1.0) @ l_x
        return suro_l_x.squeeze(2)

    def mirror_surogate_logits(self, suro_l_x, m_x):
        """
        Multiply each sample's logits by that sample's mask matrix.

        Args:
            suro_l_x: b x output_dim
            m_x     : b x output_dim x output_dim

        Returns:
            The batched matrix-vector product, shape b x output_dim.
        """
        mirr_suro_l_x = m_x @ suro_l_x.unsqueeze(2)
        return mirr_suro_l_x.squeeze(2)

In [12]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from utils.dataloader import CustomDataset
import json
from utils.train_smt import NumpyEncoder, batch_similarity, print_cfmtx, test
import numpy as np

# Shared preprocessing for both splits: convert to tensor, then normalise with
# mean 0.1307 / std 0.3081 (presumably the usual MNIST statistics -- confirm).
mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

training_data = datasets.MNIST(
    root="../data",
    train=True,
    download=False,
    transform=mnist_transform,
)

testing_data = datasets.MNIST(
    root="../data",
    train=False,
    download=False,
    transform=mnist_transform,
)


def _load_client_json(client_id):
    """Parse ./jsons/client<client_id>.json, closing the file once it is read."""
    with open(f"./jsons/client{client_id}.json", 'r') as json_file:
        return json.load(json_file)


client_id_list = [0,1,2,3,4]
# One CustomDataset per client, built from the training split and that
# client's JSON description.
clients_dataset = [CustomDataset(training_data, _load_client_json(client_id)) for client_id in client_id_list]

In [13]:
def create_mask_diagonal(dim, dataset, batch_size=256):
    """
    Build a 0/1 vector marking which class labels occur in ``dataset``.

    Args:
        dim: length of the returned vector (number of classes).
        dataset: dataset yielding ``(X, y)`` samples; ``y`` is used as an
            index into the mask.
        batch_size: how many samples are read per step while scanning the
            labels. It only affects speed, not the result.

    Returns:
        Float tensor of shape ``[dim]`` with 1 at every label that appears in
        ``dataset`` and 0 elsewhere.
    """
    label_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    mask = torch.zeros([dim])
    for _, y in label_loader:
        # Mark all labels of the batch at once instead of one sample at a time.
        mask[torch.as_tensor(y).reshape(-1).long()] = 1
    return mask

### Representation training

In [14]:
def representation_training(dataloader, model, loss_fn, optimizer, device="cuda:1"):
    """
    Run one pass of pairwise contrastive training on the encoder.

    Every batch is treated as one pair made of its samples 0 and 1. The
    distance between their representations is minimised when both share the
    same label and maximised (negated loss) otherwise.

    Args:
        dataloader: yields ``(X, y)`` batches holding at least two samples;
            this notebook uses batch_size=2 with drop_last=True. Only the
            first two samples of a batch are used.
        model:      module exposing ``encoder(X)``.
        loss_fn:    distance between two representations (mean square error
                    in this notebook).
        optimizer:  optimizer updating the model parameters.
        device:     device the model and batches are moved to.

    Returns:
        ``(same, diff)``: the mean distance over same-class pairs and over
        different-class pairs. A group without any pair yields ``nan``.
    """
    model = model.to(device)
    model.train()
    same_class_dis = []
    different_class_dis = []

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        representations = model.encoder(X)

        is_same_class = y[0].item() == y[1].item()
        distance = loss_fn(representations[0], representations[1])
        # Pull same-class pairs together, push different-class pairs apart.
        loss = distance if is_same_class else -distance

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        bucket = same_class_dis if is_same_class else different_class_dis
        bucket.append(distance.item())

    def _mean_or_nan(values):
        # np.mean([]) would emit a RuntimeWarning; report the empty group
        # explicitly instead.
        return np.mean(values) if values else np.nan

    return _mean_or_nan(same_class_dis), _mean_or_nan(different_class_dis)

In [16]:
client_id = 0
mydataset = clients_dataset[client_id]
# One mask diagonal slot per client; a slot stays None until it is computed.
clients_mask_diagonal = [None] * len(client_id_list)

if clients_mask_diagonal[client_id] is None:
    clients_mask_diagonal[client_id] = create_mask_diagonal(dim=10, dataset=mydataset)

# representation_training treats every batch as one pair, hence batch_size=2
# and drop_last=True.
train_dataloader = DataLoader(mydataset, batch_size=2, shuffle=True, drop_last=True)
loss_fn = torch.nn.MSELoss()

model = DNN_proposal()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Use the GPU when there is one, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

for t in range(8):
    same, diff = representation_training(train_dataloader, model, loss_fn, optimizer, device)
    print("Epochs", t, "Same: ", same, "Diff", diff)

Epochs 0 Same:  0.053941293309132256 Diff 0.0785806675752004
Epochs 1 Same:  0.045783394326766334 Diff 0.10896965861320496
Epochs 2 Same:  0.0423656408675015 Diff 0.22712791711091995
Epochs 3 Same:  0.03296972004075845 Diff 0.21574803193410239
Epochs 4 Same:  0.026166454330086707 Diff 0.268422394990921
Epochs 5 Same:  0.004615426994860172 Diff 0.42212681770324706
Epochs 6 Same:  0.005526202265173197 Diff 0.5099455356597901
Epochs 7 Same:  0.005472602788358927 Diff 0.5869247317314148


### Mask training

In [28]:
def mask_training(dataloader, model, optimizer, original_mask_diagonal, device="cuda:1"):
    """
    Run one pass that trains the model to generate a mask diagonal close to
    ``original_mask_diagonal``.

    The optimised loss is the squared error between ``model.mask_diagonal(X)``
    and the target, summed over all entries and divided by the batch size.
    The value that is *reported* is different: the generated mask is
    thresholded at ``1 / mask_width`` into 0/1 and the absolute mismatch with
    the target is summed and divided by the batch size.

    Args:
        dataloader: yields ``(X, y)`` batches; the labels are not used.
        model: module exposing ``mask_diagonal(X)``.
        optimizer: optimizer updating the model parameters.
        original_mask_diagonal: target mask diagonal vector.
        device: device the model, target and inputs are moved to.

    Returns:
        ``(mean_mismatch, mask)`` where ``mean_mismatch`` is the mean reported
        value over all batches and ``mask`` is the thresholded mask of the
        first sample of the last batch. For an empty dataloader the result is
        ``(nan, None)``.
    """
    original_mask_diagonal = original_mask_diagonal.to(device)
    model = model.to(device)
    model.train()

    losses = []
    last_binary_mask = None
    for X, _ in dataloader:
        X = X.to(device)
        mirr_mask_diagonal = model.mask_diagonal(X)
        batch_size = mirr_mask_diagonal.shape[0]
        loss = torch.sum(torch.pow(mirr_mask_diagonal - original_mask_diagonal, 2)) / batch_size

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Reporting only: binarise the pre-update prediction. The threshold is
        # 1/mask_width, i.e. 1/10 for a 10-entry mask.
        threshold = 1.0 / mirr_mask_diagonal.shape[-1]
        last_binary_mask = (mirr_mask_diagonal.detach() > threshold) * 1.0
        mismatch = torch.sum(torch.abs(last_binary_mask - original_mask_diagonal)) / batch_size
        losses.append(mismatch.item())

    if not losses:
        # Nothing was processed: there is no mask to return and no mean to take.
        return np.nan, None
    return np.mean(losses), last_binary_mask[0]

In [29]:
import copy
# Train the mask on a copy so the representation-trained `model` stays untouched.
model2 = copy.deepcopy(model)
train_dataloader = DataLoader(mydataset, batch_size=4, shuffle=True, drop_last=False)
mask_optimizer = torch.optim.Adam(model2.parameters(), lr=1e-3)
# Use the GPU when there is one, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Train the mask first
epoch_loss = []
print("True mask diag", clients_mask_diagonal[client_id].tolist())
for t in range(8):
    mask_loss, mirr_mask_diagonal = mask_training(train_dataloader, model2, mask_optimizer, clients_mask_diagonal[client_id], device)
    print("Epochs", t, "mask_loss: ", mask_loss)

True mask diag [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Epochs 0 mask_loss:  4.8125
Epochs 1 mask_loss:  2.0
Epochs 2 mask_loss:  0.3125
Epochs 3 mask_loss:  0.5
Epochs 4 mask_loss:  0.0625
Epochs 5 mask_loss:  0.0
Epochs 6 mask_loss:  0.0
Epochs 7 mask_loss:  0.0


### Classifier training

In [21]:
def classification_training(dataloader, model, loss_fn, optimizer, original_mask, device="cuda:1"):
    """
    Run one pass of supervised training and report the mean batch loss.

    Args:
        dataloader: yields ``(X, y)`` batches.
        model: callable as ``model(X, original_mask)`` returning predictions.
        loss_fn: loss applied to ``(predictions, y)``.
        optimizer: optimizer updating the model parameters.
        original_mask: mask passed to the model as its second argument.
        device: device the mask, model and batches are moved to.

    Returns:
        The mean of the per-batch loss values.
    """
    mask_on_device = original_mask.to(device)
    net = model.to(device)
    net.train()

    def _train_step(inputs, targets):
        # Forward pass and prediction error for one batch.
        batch_loss = loss_fn(net(inputs, mask_on_device), targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    batch_losses = [_train_step(X.to(device), y.to(device)) for X, y in dataloader]
    return np.mean(batch_losses)

In [22]:
# Train the classifier on a copy so the mask-trained `model2` stays untouched.
model3 = copy.deepcopy(model2)
train_dataloader = DataLoader(mydataset, batch_size=4, shuffle=True, drop_last=False)
optimizer = torch.optim.Adam(model3.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
# Use the GPU when there is one, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Train the classifier, passing this client's mask diagonal to the model
epoch_loss = []
for t in range(8):
    loss = classification_training(train_dataloader, model3, loss_fn, optimizer, clients_mask_diagonal[client_id], device)
    print("Epochs", t, "loss: ", loss)

Epochs 0 loss:  2.008186399936676
Epochs 1 loss:  1.8600953221321106
Epochs 2 loss:  1.715594321489334
Epochs 3 loss:  1.6210212111473083
Epochs 4 loss:  1.4336721301078796
Epochs 5 loss:  1.3050436675548553
Epochs 6 loss:  1.2295474410057068
Epochs 7 loss:  1.1053496301174164
