# Imports

In [1]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, sampler

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Normalize, Compose

import torchmetrics

from torchinfo import summary

from sklearn.model_selection import train_test_split

import numpy as np

from tqdm import tqdm

# Model Definition

In [2]:
## Residual Layer
class ResidualLayer(nn.Module):
    """Two residual conv stages followed by a downsampling block.

    Stage k computes ``relu(Shortcut_k(x) + ConvBlock_k(x))``, where ConvBlock is a
    'same'-padded 3x3 conv + BatchNorm and Shortcut is a 1x1 conv + BatchNorm
    projection. DownBlock then applies an unpadded 3x3 conv + BatchNorm + ReLU and a
    2x2 average pool, so channels go ``in_dim -> out_dim // 2 -> out_dim`` and the
    spatial size goes ``H -> (H - 2) // 2`` (e.g. 32 -> 15, 15 -> 6, 6 -> 2).

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels; the residual stages use ``out_dim // 2``.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        mid = out_dim // 2

        def conv_bn(c_in, c_out, kernel_size, padding=0):
            # Bias-free conv: the following BatchNorm supplies the affine shift.
            return nn.Sequential(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel_size,
                          padding=padding, bias=False),
                nn.BatchNorm2d(num_features=c_out),
            )

        # Registration order matters: it fixes both the state_dict layout and the
        # order in which parameters are randomly initialised.
        self.ConvBlock1 = conv_bn(in_dim, mid, 3, padding='same')
        self.ConvBlock2 = conv_bn(mid, mid, 3, padding='same')
        self.Shortcut1 = conv_bn(in_dim, mid, 1)
        self.Shortcut2 = conv_bn(mid, mid, 1)
        self.DownBlock = nn.Sequential(
            nn.Conv2d(in_channels=mid, out_channels=out_dim, kernel_size=3, bias=False),
            nn.BatchNorm2d(num_features=out_dim),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Apply both residual stages, then downsample."""
        branch = self.ConvBlock1(x)
        h = F.relu(self.Shortcut1(x) + branch)
        branch = self.ConvBlock2(h)
        h = F.relu(self.Shortcut2(h) + branch)
        return self.DownBlock(h)


In [3]:
## Feature Extractor
class FE(nn.Module):
    """Feature extractor made of two stacked ``ResidualLayer`` stages.

    With the default ``depths`` a ``(N, 3, 32, 32)`` batch is mapped to
    ``(N, 64, 6, 6)`` feature maps.

    Args:
        depths: channel counts ``[input, after stage 1, after stage 2]``;
            only the first three entries are read.
    """
    def __init__(self, depths=[3, 32, 64]) -> None:
        super().__init__()
        stage1 = ResidualLayer(depths[0], depths[1])
        stage2 = ResidualLayer(depths[1], depths[2])
        self.features = nn.Sequential(stage1, stage2)

    def forward(self, x):
        """Run ``x`` through both residual stages."""
        feature_maps = self.features(x)
        return feature_maps

In [4]:
# One 32x32 RGB image; expected feature-map output is (1, 64, 6, 6).
summary(model=FE(), input_size=(1, 3, 32, 32), depth=4)

Layer (type:depth-idx)                   Output Shape              Param #
FE                                       [1, 64, 6, 6]             --
├─Sequential: 1-1                        [1, 64, 6, 6]             --
│    └─ResidualLayer: 2-1                [1, 32, 15, 15]           --
│    │    └─Sequential: 3-1              [1, 16, 32, 32]           --
│    │    │    └─Conv2d: 4-1             [1, 16, 32, 32]           432
│    │    │    └─BatchNorm2d: 4-2        [1, 16, 32, 32]           32
│    │    └─Sequential: 3-2              [1, 16, 32, 32]           --
│    │    │    └─Conv2d: 4-3             [1, 16, 32, 32]           48
│    │    │    └─BatchNorm2d: 4-4        [1, 16, 32, 32]           32
│    │    └─Sequential: 3-3              [1, 16, 32, 32]           --
│    │    │    └─Conv2d: 4-5             [1, 16, 32, 32]           2,304
│    │    │    └─BatchNorm2d: 4-6        [1, 16, 32, 32]           32
│    │    └─Sequential: 3-4              [1, 16, 32, 32]           --
│    │    │

In [5]:
## Classifier
class CLF(nn.Module):
    """Classification head: one ``ResidualLayer`` followed by a linear layer.

    The linear layer is sized for a 2x2 spatial map after the residual stage,
    which is what a 6x6 input (the FE output size) produces.

    Args:
        depths: ``[in_channels, out_channels]`` of the residual stage.
        num_classes: number of output logits.
    """
    def __init__(self, depths=[64, 96], num_classes=10) -> None:
        super().__init__()
        self.features = ResidualLayer(depths[0], depths[1])
        flat_dim = 2 * 2 * depths[1]
        # NOTE: "classfier" (sic) is a state_dict key; do not rename.
        self.classfier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=flat_dim, out_features=num_classes),
        )

    def forward(self, x):
        """Return class logits for the feature maps ``x``."""
        return self.classfier(self.features(x))

In [6]:
# Input (1, 64, 6, 6) is the FE output shape for a single 32x32 image.
summary(model=CLF(), input_size=(1, 64, 6, 6))

Layer (type:depth-idx)                   Output Shape              Param #
CLF                                      [1, 10]                   --
├─ResidualLayer: 1-1                     [1, 96, 2, 2]             --
│    └─Sequential: 2-1                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-1                  [1, 48, 6, 6]             27,648
│    │    └─BatchNorm2d: 3-2             [1, 48, 6, 6]             96
│    └─Sequential: 2-2                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-3                  [1, 48, 6, 6]             3,072
│    │    └─BatchNorm2d: 3-4             [1, 48, 6, 6]             96
│    └─Sequential: 2-3                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-5                  [1, 48, 6, 6]             20,736
│    │    └─BatchNorm2d: 3-6             [1, 48, 6, 6]             96
│    └─Sequential: 2-4                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-7                  [1, 48, 6, 6]             2,304
│

In [7]:
## Attacker
class ATK(nn.Module):
    """Membership-inference adversary comparing shadow and actual feature maps.

    Each input has its own branch: a ``ResidualLayer`` followed by a
    Flatten + Linear + ReLU head. The two branch embeddings are concatenated
    (shadow first, then actual) and mapped to a single membership logit of
    shape ``(batch, 1)``.

    Args:
        depths: ``[in_channels, out_channels]`` of each branch's ResidualLayer.
            The heads assume a 2x2 spatial map after it (true for 6x6 inputs).
        widths: ``[branch embedding width, hidden width]`` of the MLP layers.
    """
    def __init__(self, depths=[64, 96], widths=[128, 64]) -> None:
        super(ATK, self).__init__()
        # Registration order kept unchanged so parameter init and state_dict keys
        # match earlier versions. "classfier" (sic) is a state_dict key.
        self.features_fe = ResidualLayer(depths[0], depths[1])
        self.features_shadow = ResidualLayer(depths[0], depths[1])
        self.classfier_fe = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=2*2*depths[1], out_features=widths[0]),
            nn.ReLU()
        )
        self.classfier_shadow = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=2*2*depths[1], out_features=widths[0]),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(in_features=2*widths[0], out_features=widths[1]),
            nn.ReLU(),
            nn.Linear(in_features=widths[1], out_features=1)
        )

    def forward(self, f_shadow, f_actual):
        """Return membership logits for shadow features ``f_shadow`` and actual features ``f_actual``."""
        # Each branch uses its own residual stage AND its own head (previously the
        # heads were crossed: shadow features went through the "fe" head and vice versa).
        f_shadow = self.classfier_shadow(self.features_shadow(f_shadow))
        f_actual = self.classfier_fe(self.features_fe(f_actual))
        return self.classifier(torch.cat([f_shadow, f_actual], dim=1))

In [8]:
# Two random FE-shaped inputs: (shadow features, actual features).
summary(ATK(), input_data=tuple(torch.randn(1, 64, 6, 6) for _ in range(2)))

Layer (type:depth-idx)                   Output Shape              Param #
ATK                                      [1, 1]                    --
├─ResidualLayer: 1-1                     [1, 96, 2, 2]             --
│    └─Sequential: 2-1                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-1                  [1, 48, 6, 6]             27,648
│    │    └─BatchNorm2d: 3-2             [1, 48, 6, 6]             96
│    └─Sequential: 2-2                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-3                  [1, 48, 6, 6]             3,072
│    │    └─BatchNorm2d: 3-4             [1, 48, 6, 6]             96
│    └─Sequential: 2-3                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-5                  [1, 48, 6, 6]             20,736
│    │    └─BatchNorm2d: 3-6             [1, 48, 6, 6]             96
│    └─Sequential: 2-4                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-7                  [1, 48, 6, 6]             2,304
│

# Data Preparation

In [9]:
# Map each RGB channel from [0, 1] (ToTensor) to [-1, 1].
tfms = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 256
batch_size_shadow = 2*batch_size
DATA_ROOT = 'PyTorch-StudioGAN/data/'
# Fixed seed for the index splits below so the member / non-member assignment
# (and therefore every downstream metric) is reproducible across runs.
SPLIT_SEED = 0

# Shadow loaders: the full CIFAR-10 train split for training, the official test split for evaluation.
clf_trainds = CIFAR10(root=DATA_ROOT, transform=tfms)

shadow_testds = CIFAR10(root=DATA_ROOT, train=False, transform=tfms)
shadow_train_dl = DataLoader(clf_trainds, batch_size=batch_size_shadow, shuffle=True)
shadow_test_dl = DataLoader(shadow_testds, batch_size=batch_size_shadow, shuffle=False)

# Target classifier: stratified 50/50 split of the train split into training (members)
# and held-out (non-members) indices.
clf_train_idx, clf_test_idx = train_test_split(
    np.arange(len(clf_trainds.targets)), test_size=0.5, shuffle=True,
    stratify=clf_trainds.targets, random_state=SPLIT_SEED
)
assert len(clf_train_idx) + len(clf_test_idx) == len(clf_trainds)
train_sampler = sampler.SubsetRandomSampler(clf_train_idx)
test_sampler = sampler.SubsetRandomSampler(clf_test_idx)

clf_train_dl = DataLoader(clf_trainds, batch_size=batch_size, sampler=train_sampler)
clf_test_dl = DataLoader(clf_trainds, batch_size=batch_size, sampler=test_sampler)

# Adversary dataset: same images, but the label is membership
# (1 = used to train the target classifier, 0 = held out).
atk_trainds = CIFAR10(root=DATA_ROOT, transform=tfms)
atk_targets = np.asarray(atk_trainds.targets)
atk_targets[clf_train_idx] = 1
atk_targets[clf_test_idx] = 0
atk_trainds.targets = atk_targets.tolist()
# Every original class label must have been overwritten by a membership label.
assert set(atk_trainds.targets) == {0, 1}

atk_train_idx, atk_test_idx = train_test_split(
    np.arange(len(atk_trainds.targets)), test_size=0.2, shuffle=True,
    stratify=atk_trainds.targets, random_state=SPLIT_SEED
)
atk_train_sampler = sampler.SubsetRandomSampler(atk_train_idx)
atk_test_sampler = sampler.SubsetRandomSampler(atk_test_idx)

atk_train_dl = DataLoader(atk_trainds, batch_size=batch_size, sampler=atk_train_sampler)
atk_test_dl = DataLoader(atk_trainds, batch_size=batch_size, sampler=atk_test_sampler)

# Training

In [10]:
def _set_trainable(model, trainable):
    """Toggle ``requires_grad`` on every parameter, then switch train/eval mode to match."""
    for p in model.parameters():
        p.requires_grad_(trainable)
    # train(False) is exactly what Module.eval() does.
    model.train(trainable)

def freeze(model):
    """Stop gradient tracking for ``model`` and put it in eval mode."""
    _set_trainable(model, False)

def unfreeze(model):
    """Re-enable gradient tracking for ``model`` and put it back in train mode."""
    _set_trainable(model, True)

In [11]:
@torch.no_grad()
def eval_atk(fe, shadow_fe, atk, atk_test_dl, atk_criterion, device):
    """Print the adversary's mean loss and accuracy over ``atk_test_dl``.

    ``atk`` receives ``(shadow_fe(X), fe(X))`` and produces one membership logit
    per sample. All three modules are put in eval mode for the pass, then
    restored to their previous train/eval mode, so calling this mid-training
    does not leave BatchNorm layers stuck in inference mode.

    The loss is the per-sample mean over the whole loader. This assumes
    ``atk_criterion`` averages within a batch (the default ``reduction='mean'``).
    """
    modules = (fe, shadow_fe, atk)
    previous_modes = [m.training for m in modules]
    for m in modules:
        m.eval()
    acc = torchmetrics.Accuracy().to(device)
    total_loss, n_seen = 0.0, 0
    try:
        for (X, y) in atk_test_dl:
            X, y = X.to(device), y.to(device)
            features_actual = fe(X)
            features_shadow = shadow_fe(X)
            # squeeze(-1), not squeeze(): a final batch of one sample must stay shape (1,)
            # or BCE fails on an input/target size mismatch.
            atk_y = atk(features_shadow, features_actual).squeeze(-1)
            # Re-weight the batch-mean loss by batch size so a short last batch is not over-counted.
            total_loss += atk_criterion(atk_y, y.float()).item() * y.size(0)
            n_seen += y.size(0)
            # Pass probabilities explicitly instead of relying on version-dependent logit detection.
            acc(torch.sigmoid(atk_y), y)
    finally:
        for m, was_training in zip(modules, previous_modes):
            m.train(was_training)
    loss = total_loss / n_seen if n_seen else float('nan')
    print(f'Adversary Loss: {loss} | Adversary Accuracy: {acc.compute()}')

In [12]:
@torch.no_grad()
def eval_clf(fe, clf, clf_test_dl, clf_criterion, device):
    """Print the mean loss and accuracy of ``clf(fe(X))`` over ``clf_test_dl``.

    ``fe`` and ``clf`` are switched to eval mode for the pass and restored to
    their previous train/eval mode afterwards. This avoids a later training
    step silently running with BatchNorm in inference mode.

    The loss is the per-sample mean over the whole loader. This assumes
    ``clf_criterion`` averages within a batch (the default ``reduction='mean'``).
    """
    modules = (fe, clf)
    previous_modes = [m.training for m in modules]
    for m in modules:
        m.eval()
    acc = torchmetrics.Accuracy().to(device)
    total_loss, n_seen = 0.0, 0
    try:
        for (X, y) in clf_test_dl:
            X, y = X.to(device), y.to(device)
            clf_y = clf(fe(X))
            # Re-weight the batch-mean loss by batch size so a short last batch is not over-counted.
            total_loss += clf_criterion(clf_y, y).item() * y.size(0)
            n_seen += y.size(0)
            acc(clf_y, y)
    finally:
        for m, was_training in zip(modules, previous_modes):
            m.train(was_training)
    loss = total_loss / n_seen if n_seen else float('nan')
    print(f'Classifier Loss: {loss} | Classifier Accuracy: {acc.compute()}')

In [13]:
def train(clf_train_dl, clf_test_dl, atk_train_dl, atk_test_dl, shadow_train_dl, shadow_test_dl, n_epochs, steps_per_epoch = 5):
    """Train the target model, its shadow copy and the membership adversary.

    Each of the ``n_epochs`` epochs runs three phases:
      1. FE + CLF: ``steps_per_epoch`` passes over ``clf_train_dl``, then evaluated on ``clf_test_dl``.
      2. Shadow FE + shadow CLF: ``steps_per_epoch`` passes over ``shadow_train_dl``,
         then evaluated on ``shadow_test_dl``.
      3. ATK: ``steps_per_epoch // 2`` passes over ``atk_train_dl`` with both feature
         extractors frozen, then evaluated on ``atk_test_dl``.

    Note: one "step" here is a full pass over the corresponding DataLoader.

    Returns:
        ``(feature_extractor, shadow_feature_extractor, adversary, classifier)``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    feature_extractor = FE().to(device)
    adversary = ATK().to(device)
    classifier = CLF().to(device)
    # The shadow pair starts from exactly the same initial weights as the actual pair.
    shadow_feature_extractor = FE().to(device)
    shadow_feature_extractor.load_state_dict(feature_extractor.state_dict())
    shadow_classifier = CLF().to(device)
    shadow_classifier.load_state_dict(classifier.state_dict())

    atk_criterion = nn.BCEWithLogitsLoss().to(device)
    clf_criterion = nn.CrossEntropyLoss().to(device)

    fe_optim = optim.AdamW(feature_extractor.parameters())
    atk_optim = optim.AdamW(adversary.parameters())
    clf_optim = optim.AdamW(classifier.parameters())
    shadow_fe_optim = optim.AdamW(shadow_feature_extractor.parameters())
    shadow_clf_optim = optim.AdamW(shadow_classifier.parameters())

    for epoch in range(n_epochs):
        print(f'EPOCH [{epoch+1}/{n_epochs}]:')

        print('Training FE+CLF...')
        _fit_classifier(feature_extractor, classifier, (fe_optim, clf_optim),
                        clf_train_dl, clf_criterion, steps_per_epoch, device)
        print('FE+CLF evaluation:')
        eval_clf(feature_extractor, classifier, clf_test_dl, clf_criterion, device)

        print('Training Shadow FE+Shadow CLF...')
        _fit_classifier(shadow_feature_extractor, shadow_classifier, (shadow_fe_optim, shadow_clf_optim),
                        shadow_train_dl, clf_criterion, steps_per_epoch, device)
        print('Shadow FE+Shadow CLF evaluation:')
        eval_clf(shadow_feature_extractor, shadow_classifier, shadow_test_dl, clf_criterion, device)

        print('Training ATK...')
        _fit_adversary(feature_extractor, shadow_feature_extractor, adversary, atk_optim,
                       atk_train_dl, atk_criterion, steps_per_epoch // 2, device)
        print('Adversary Evaluation:')
        eval_atk(feature_extractor, shadow_feature_extractor, adversary, atk_test_dl, atk_criterion, device)
    return feature_extractor, shadow_feature_extractor, adversary, classifier


def _fit_classifier(fe, clf, optimizers, train_dl, criterion, n_passes, device):
    """Train ``clf(fe(X))`` for ``n_passes`` full passes over ``train_dl``, stepping every optimizer."""
    # Explicitly re-enter train mode: the evaluation helpers switch these modules to
    # eval mode, and without this every epoch after the first would train with
    # BatchNorm using frozen running statistics.
    fe.train()
    clf.train()
    bar = tqdm(range(n_passes))
    for i in bar:
        loss = None
        for (X, y) in train_dl:
            X, y = X.to(device), y.to(device)
            for opt in optimizers:
                opt.zero_grad()
            loss = criterion(clf(fe(X)), y)
            loss.backward()
            for opt in optimizers:
                opt.step()
        if loss is not None:  # empty loader: nothing to report
            bar.set_postfix_str(f'[Loss after pass {i+1}: {loss.item():.4f}]', refresh=True)


def _fit_adversary(fe, shadow_fe, adversary, atk_optim, atk_train_dl, criterion, n_passes, device):
    """Train ``adversary`` on frozen ``fe`` / ``shadow_fe`` features for ``n_passes`` passes over ``atk_train_dl``."""
    freeze(fe)
    freeze(shadow_fe)
    # The adversary may have been left in eval mode by a previous evaluation.
    adversary.train()
    try:
        bar = tqdm(range(n_passes))
        for i in bar:
            loss = None
            for (X, y) in atk_train_dl:
                X, y = X.to(device), y.to(device).float()
                atk_optim.zero_grad()
                # Extractors are frozen; no need to build their autograd graph.
                with torch.no_grad():
                    features_actual = fe(X)
                    features_shadow = shadow_fe(X)
                # squeeze(-1), not squeeze(): a batch of one sample must stay shape (1,).
                preds = adversary(features_shadow, features_actual).squeeze(-1)
                loss = criterion(preds, y)
                loss.backward()
                atk_optim.step()
            if loss is not None:  # empty loader: nothing to report
                bar.set_postfix_str(f'[Loss after pass {i+1}: {loss.item():.4f}]', refresh=True)
    finally:
        unfreeze(fe)
        unfreeze(shadow_fe)


In [14]:
# One epoch; each phase makes 50 full passes over its loader (the adversary makes 25).
fe, shadow_fe, adv, clf = train(
    clf_train_dl, clf_test_dl,
    atk_train_dl, atk_test_dl,
    shadow_train_dl, shadow_test_dl,
    n_epochs=1, steps_per_epoch=50,
)

EPOCH [1/1]:
Training FE+CLF...


100%|██████████| 50/50 [03:25<00:00,  4.10s/it, [Loss at epoch 50: 0.000629]]


FE+CLF evaluation:
Classifier Loss: 1.1783652281274601 | Classifier Accuracy: 0.7829999923706055
Training Shadow FE+Shadow CLF...


100%|██████████| 50/50 [06:09<00:00,  7.38s/it, [Loss at epoch 50: 0.000461]]


Shadow FE+Shadow CLF evaluation:
Classifier Loss: 1.0214308142662047 | Classifier Accuracy: 0.8202000260353088
Training ATK...


100%|██████████| 25/25 [02:48<00:00,  6.75s/it, [Loss at epoch 25: 0.057444]]


Adversary Evaluation:
Adversary Loss: 3.3722291350364686 | Adversary Accuracy: 0.49549999833106995
