# Imports

In [1]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, sampler

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Normalize, Compose

import torchmetrics

from torchinfo import summary

from sklearn.model_selection import train_test_split

import numpy as np

from tqdm import tqdm

# Model Definition

In [2]:
## Residual Layer
class ResidualLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(ResidualLayer, self).__init__()
        self.ConvBlock1 = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=out_dim//2, kernel_size=3, padding='same', bias=False),
            nn.BatchNorm2d(num_features=out_dim//2)
        )
        self.ConvBlock2 = nn.Sequential(
            nn.Conv2d(in_channels=out_dim//2, out_channels=out_dim//2, kernel_size=3, padding='same', bias=False),
            nn.BatchNorm2d(num_features=out_dim//2)
        )
        self.Shortcut1 = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=out_dim//2, kernel_size=1, bias=False),
            nn.BatchNorm2d(num_features=out_dim//2)
        )
        self.Shortcut2 = nn.Sequential(
            nn.Conv2d(in_channels=out_dim//2, out_channels=out_dim//2, kernel_size=1, bias=False),
            nn.BatchNorm2d(num_features=out_dim//2)
        )
        self.DownBlock = nn.Sequential(
            nn.Conv2d(in_channels=out_dim//2, out_channels=out_dim, kernel_size=3, bias=False),
            nn.BatchNorm2d(num_features=out_dim),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        y = self.ConvBlock1(x)
        x = F.relu(self.Shortcut1(x) + y)
        y = self.ConvBlock2(x)
        x = F.relu(self.Shortcut2(x) + y)
        x = self.DownBlock(x)
        return x


In [3]:
## Feature Extractor
class FE(nn.Module):
    def __init__(self, depths=[3, 32, 64]) -> None:
        super(FE, self).__init__()
        self.features = nn.Sequential(
            ResidualLayer(depths[0], depths[1]),
            ResidualLayer(depths[1], depths[2])
        )
        
    def forward(self, x):
        return self.features(x)

In [4]:
summary(FE(),input_size=(1, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
FE                                       [1, 64, 6, 6]             --
├─Sequential: 1-1                        [1, 64, 6, 6]             --
│    └─ResidualLayer: 2-1                [1, 32, 15, 15]           --
│    │    └─Sequential: 3-1              [1, 16, 32, 32]           464
│    │    └─Sequential: 3-2              [1, 16, 32, 32]           80
│    │    └─Sequential: 3-3              [1, 16, 32, 32]           2,336
│    │    └─Sequential: 3-4              [1, 16, 32, 32]           288
│    │    └─Sequential: 3-5              [1, 32, 15, 15]           4,672
│    └─ResidualLayer: 2-2                [1, 64, 6, 6]             --
│    │    └─Sequential: 3-6              [1, 32, 15, 15]           9,280
│    │    └─Sequential: 3-7              [1, 32, 15, 15]           1,088
│    │    └─Sequential: 3-8              [1, 32, 15, 15]           9,280
│    │    └─Sequential: 3-9              [1, 32, 15, 15]           1

In [5]:
## Classifier
class CLF(nn.Module):
    def __init__(self, depths=[64, 96], num_classes=10) -> None:
        super(CLF, self).__init__()
        self.features = ResidualLayer(depths[0], depths[1])
        self.classfier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=384, out_features=num_classes)
        )
        
    def forward(self, x):
        x = self.features(x)
        return self.classfier(x)

In [6]:
summary(CLF(),input_size=(1, 64, 6, 6))

Layer (type:depth-idx)                   Output Shape              Param #
CLF                                      [1, 10]                   --
├─ResidualLayer: 1-1                     [1, 96, 2, 2]             --
│    └─Sequential: 2-1                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-1                  [1, 48, 6, 6]             27,648
│    │    └─BatchNorm2d: 3-2             [1, 48, 6, 6]             96
│    └─Sequential: 2-2                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-3                  [1, 48, 6, 6]             3,072
│    │    └─BatchNorm2d: 3-4             [1, 48, 6, 6]             96
│    └─Sequential: 2-3                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-5                  [1, 48, 6, 6]             20,736
│    │    └─BatchNorm2d: 3-6             [1, 48, 6, 6]             96
│    └─Sequential: 2-4                   [1, 48, 6, 6]             --
│    │    └─Conv2d: 3-7                  [1, 48, 6, 6]             2,304
│

In [7]:
# Attacker
class ATK(nn.Module):
    def __init__(self, num_classes=10, embed_dim=128) -> None:
        super(ATK, self).__init__()
        self.true_down = nn.Sequential(
            nn.Embedding(num_classes, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim//2),
            nn.ReLU()
        )
        self.pred_down = nn.Sequential(
            nn.Linear(num_classes, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim//2),
            nn.ReLU()
        )
        self.classfier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim//2),
            nn.ReLU(),
            nn.Linear(embed_dim//2, 1)
        )
        
    def forward(self, y_true, y_pred):
       fy1 = self.true_down(y_true)
       fy2 = self.pred_down(y_pred)
       fy = torch.cat([fy1, fy2], dim=-1)
       return self.classfier(fy)

In [8]:
summary(ATK(),input_data=(torch.randint(0, 10, (1,)), torch.randn(1, 10)))

Layer (type:depth-idx)                   Output Shape              Param #
ATK                                      [1, 1]                    --
├─Sequential: 1-1                        [1, 64]                   --
│    └─Embedding: 2-1                    [1, 128]                  1,280
│    └─ReLU: 2-2                         [1, 128]                  --
│    └─Linear: 2-3                       [1, 64]                   8,256
│    └─ReLU: 2-4                         [1, 64]                   --
├─Sequential: 1-2                        [1, 64]                   --
│    └─Linear: 2-5                       [1, 128]                  1,408
│    └─ReLU: 2-6                         [1, 128]                  --
│    └─Linear: 2-7                       [1, 64]                   8,256
│    └─ReLU: 2-8                         [1, 64]                   --
├─Sequential: 1-3                        [1, 1]                    --
│    └─Linear: 2-9                       [1, 64]                   8,256


# Data Preparation

In [9]:
class CustomTensorDataset(Dataset):
    def __init__(self, tensors, transform=None):
        assert all(tensors[0].shape[0] == tensor.shape[0] for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        X = self.tensors[0][index]
        if self.transform is not None:
            X = self.transform(X)
        return X, tuple(self.tensors[i][index] for i in range(1, len(self.tensors)))
    
    def __len__(self):
        return self.tensors[0].shape[0]

In [10]:
tfms = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
batch_size = 256

clf_trainds = CIFAR10(root='PyTorch-StudioGAN/data/', transform=tfms)

clf_train_idx, clf_test_idx = train_test_split(
    np.arange(len(clf_trainds.targets)), test_size=0.5, shuffle=True, stratify=clf_trainds.targets
)
train_sampler = sampler.SubsetRandomSampler(clf_train_idx)
test_sampler = sampler.SubsetRandomSampler(clf_test_idx)

clf_train_dl = DataLoader(clf_trainds, batch_size=batch_size, sampler=train_sampler)
clf_test_dl = DataLoader(clf_trainds, batch_size=batch_size, sampler=test_sampler)

y_true_atk = [t for t in clf_trainds.targets]
y_atk = np.zeros(len(clf_trainds), dtype=int)
y_atk[clf_train_idx.tolist()] = 1

atk_trainds = CustomTensorDataset(
    tensors=(clf_trainds.data, torch.tensor(y_true_atk), torch.tensor(y_atk)),
    transform=tfms
)

atk_train_idx, atk_test_idx = train_test_split(
    np.arange(len(atk_trainds.tensors[2].tolist())), test_size=0.2, shuffle=True, stratify=atk_trainds.tensors[2].tolist()
)
atk_train_sampler = sampler.SubsetRandomSampler(atk_train_idx)
atk_test_sampler = sampler.SubsetRandomSampler(atk_test_idx)

atk_train_dl = DataLoader(atk_trainds, batch_size=batch_size, sampler=atk_train_sampler)
atk_test_dl = DataLoader(atk_trainds, batch_size=batch_size, sampler=atk_test_sampler)

# Training

In [11]:
def freeze(model):
    for param in model.parameters():
        param.requires_grad_(False)
    model.eval()

def unfreeze(model):
    for param in model.parameters():
        param.requires_grad_(True)
    model.train()

In [12]:
def eval_atk(fe, atk, clf, atk_test_dl, atk_criterion, device):
    fe.eval()
    atk.eval()
    clf.eval()
    acc = torchmetrics.Accuracy().to(device)
    loss = []
    with torch.no_grad():
        for (X, (y1, y2)) in atk_test_dl:
            X, y1, y2 = X.to(device), y1.to(device), y2.to(device)
            features = fe(X)
            clf_y = clf(features)
            atk_y = atk(y1, clf_y).squeeze()
            loss.append(atk_criterion(atk_y, y2.float()).item())
            acc(atk_y, y2)
    loss = np.asarray(loss).mean()
    print(f'Adversary Loss: {loss} | Adversary Accuracy: {acc.compute()}')

In [13]:
def eval_clf(fe, clf, clf_test_dl, clf_criterion, device):
    fe.eval()
    clf.eval()
    acc = torchmetrics.Accuracy().to(device)
    loss = []
    with torch.no_grad():
        for (X, y) in clf_test_dl:
            X, y = X.to(device), y.to(device)
            features = fe(X)
            clf_y = clf(features)
            loss.append(clf_criterion(clf_y, y).item())
            acc(clf_y, y)
    loss = np.asarray(loss).mean()
    print(f'Classifier Loss: {loss} | Classifier Accuracy: {acc.compute()}')

In [14]:
def train(clf_train_dl, clf_test_dl, atk_train_dl, atk_test_dl, n_epochs, steps_per_epoch = 5):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    feature_extractor = FE().to(device)
    adversary = ATK().to(device)
    classifier = CLF().to(device)

    atk_criterion = nn.BCEWithLogitsLoss().to(device)
    clf_criterion = nn.CrossEntropyLoss().to(device)

    fe_optim = optim.AdamW(feature_extractor.parameters())
    atk_optim = optim.AdamW(adversary.parameters())
    clf_optim = optim.AdamW(classifier.parameters())
    
    for epoch in range(n_epochs):
        print(f'EPOCH [{epoch+1}/{n_epochs}]:')
        # TODO: Clean train steps
        ## train FE+CLF
        print('Training FE+CLF...')
        for _ in tqdm(range(steps_per_epoch)):
            for (X, y) in clf_train_dl:
                X, y = X.to(device), y.to(device)
                fe_optim.zero_grad()
                clf_optim.zero_grad()
                features = feature_extractor(X)
                predictions = classifier(features)
                loss = clf_criterion(predictions, y)
                loss.backward()
                fe_optim.step()
                clf_optim.step()
        
        ## train FE+ATK
        print('Training FE+ATK...')
        freeze(classifier)
        for _ in tqdm(range(steps_per_epoch)):
            for (X, (y1, y2)) in atk_train_dl:
                X, y1, y2 = X.to(device), y1.to(device), y2.to(device).float()

                atk_optim.zero_grad()
                features = feature_extractor(X)
                classifier_preds = classifier(features)
                adversary_preds = adversary(y1, classifier_preds).squeeze()
                loss_atk = atk_criterion(adversary_preds, y2)    
                loss_atk.backward()
                atk_optim.step()
                
                fe_optim.zero_grad()
                features = feature_extractor(X)
                classifier_preds = classifier(features)
                adversary_preds = adversary(y1, classifier_preds).squeeze()   
                loss_fe = atk_criterion(1 - adversary_preds, y2) + clf_criterion(classifier_preds, y1)
                loss_fe.backward()
                fe_optim.step()      
        unfreeze(classifier)

        ## train CLF
        print('Training CLF...')
        freeze(feature_extractor)
        for _ in tqdm(range(steps_per_epoch)):
            for (X, y) in clf_train_dl:
                X, y = X.to(device), y.to(device)
                clf_optim.zero_grad()
                features = feature_extractor(X)
                predictions = classifier(features)
                loss = clf_criterion(predictions, y)
                loss.backward()
                clf_optim.step()
        unfreeze(feature_extractor)

        ## train ATK
        print('Training ATK...')
        freeze(feature_extractor)
        freeze(classifier)
        for _ in tqdm(range(steps_per_epoch)):
            for (X, (y1, y2)) in atk_train_dl:
                X, y1, y2 = X.to(device), y1.to(device), y2.to(device).float()
                atk_optim.zero_grad()
                features = feature_extractor(X)
                classifier_preds = classifier(features)
                adversary_preds = adversary(y1, classifier_preds).squeeze()
                loss_atk = atk_criterion(adversary_preds, y2)    
                loss_atk.backward()
                atk_optim.step()
        unfreeze(feature_extractor)
        unfreeze(classifier)
        
        ## eval steps
        eval_clf(feature_extractor, classifier, clf_test_dl, clf_criterion, device)
        eval_atk(feature_extractor, adversary, classifier, atk_test_dl, atk_criterion, device)
    return feature_extractor, adversary, classifier


In [15]:
fe, adv, clf = train(clf_train_dl, clf_test_dl, atk_train_dl, atk_test_dl, 2, 75)

EPOCH [1/2]:
Training FE+CLF...


100%|██████████| 75/75 [05:10<00:00,  4.14s/it]


Training FE+ATK...


100%|██████████| 75/75 [07:52<00:00,  6.30s/it]


Training CLF...


100%|██████████| 75/75 [04:31<00:00,  3.62s/it]


Training ATK...


100%|██████████| 75/75 [04:00<00:00,  3.20s/it]


Classifier Loss: 0.4615499711766535 | Classifier Accuracy: 0.9100800156593323
Adversary Loss: 0.6302769497036934 | Adversary Accuracy: 0.6105999946594238
EPOCH [2/2]:
Training FE+CLF...


100%|██████████| 75/75 [05:15<00:00,  4.21s/it]


Training FE+ATK...


100%|██████████| 75/75 [07:59<00:00,  6.39s/it]


Training CLF...


100%|██████████| 75/75 [04:33<00:00,  3.65s/it]


Training ATK...


100%|██████████| 75/75 [04:01<00:00,  3.22s/it]


Classifier Loss: 2.9020180142655665 | Classifier Accuracy: 0.4946799874305725
Adversary Loss: 0.632641413807869 | Adversary Accuracy: 0.6589000225067139
