### Importing Libraries

In [None]:
import os
import sys
import argparse
import options
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
import pywt
import matplotlib.pyplot as plt
import time
import datetime
from tqdm import tqdm
from PIL import Image
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

In [None]:
##########################
### SETTINGS
##########################

# Hyperparameters
RANDOM_SEED = 1  # NOTE(review): no seeding call in the visible cells consumes this -- confirm it is used

LEARNING_RATE = 0.001  # learning rate of the classifier's Adam optimizer
BATCH_SIZE = 512  # batch size of the two MNIST data loaders
NUM_EPOCHS = 25  # NOTE(review): the classifier training cell iterates a hard-coded 10 epochs -- confirm which value is intended

# Architecture
NUM_FEATURES = 32*32  # = 1024; not referenced by the visible cells
NUM_CLASSES = 10  # size of the classifier's output layer

# Other
DEVICE = "cuda:0"  # device that models and batches are moved to
# DEVICE = "cpu"
GRAYSCALE = False  # False -> the ResNet stem is built for 3 input channels, True -> 1 channel

In [None]:
# DATA

##########################
### MNIST DATASET
##########################

# transforms.ToTensor() already scales pixel values to the [0, 1] range, so no
# additional transforms.Normalize([0.5], [0.5]) step is applied.
#
# MNIST images have a single channel, while the classifier (GRAYSCALE = False)
# and the denoiser (in_channels = 3) are both built for 3-channel input.
# The grey channel is therefore replicated three times; without this the first
# convolution of either network rejects the batch.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

# NOTE(review): `path` (dataset root directory) is not defined in the visible
# cells; it has to be set before this cell runs.
os.makedirs(path, exist_ok=True)

# Training split
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(path, train=True, download=True, transform=mnist_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(path, train=False, download=True, transform=mnist_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# classifier

# Classifier model - ResNet-34



def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


class BasicBlock(nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN plus a shortcut, then ReLU.

    Args:
        inplanes: channels of the incoming feature map.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the block input to produce the
            shortcut; when ``None`` the input itself is the shortcut.
    """

    # Output channels of the block are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Shortcut branch
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)

class ResNet(nn.Module):
    """ResNet classifier whose head is applied to the flattened stage-4 output.

    Args:
        block: residual block class; must expose an ``expansion`` attribute and
            accept ``(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks in each stage.
        num_classes: number of output logits.
        grayscale: if true the stem takes 1 input channel, otherwise 3.

    The stem and the stages reduce the spatial size by a factor of 32 and no
    pooling is applied before ``fc``, so the stage-4 map must be 1x1 (for
    example a 32x32 input) for the flattened features to match ``fc``.
    """

    def __init__(self, block, layers, num_classes, grayscale):
        super(ResNet, self).__init__()
        self.inplanes = 64
        in_dim = 1 if grayscale else 3
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Kept as an attribute for compatibility; neither forward path applies it.
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Conv weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels)));
        # BatchNorm starts as the identity transform.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n)**.5)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks producing ``planes`` channels."""
        downsample = None
        # A 1x1 conv + BN shortcut is needed whenever the first block changes
        # the spatial size or the channel count.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def _features(self, x):
        """Run the stem and the four stages, returning the flattened feature map."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        """Return the logits only.

        This is the entry point used by libraries that call the model directly
        (e.g. torchattacks).
        """
        return self.fc(self._features(x))

    def my_forward(self, x, extract_map=False):
        """Return ``(logits, probas)``, or the flattened features if ``extract_map``.

        Args:
            x: input batch.
            extract_map: when true, skip the classification head and return
                the flattened stage-4 features instead.
        """
        features = self._features(x)
        if extract_map:
            return features

        logits = self.fc(features)
        probas = F.softmax(logits, dim=1)
        return logits, probas

def resnet34(num_classes):
    """Constructs a ResNet-34 model.

    Args:
        num_classes: number of logits produced by the classification head.

    Returns:
        A ``ResNet`` built from ``BasicBlock`` with stage depths [3, 4, 6, 3].
        The number of input channels follows the module-level ``GRAYSCALE`` flag.
    """
    # The argument is honoured here; it used to be silently replaced by the
    # NUM_CLASSES global.
    return ResNet(block=BasicBlock,
                  layers=[3, 4, 6, 3],
                  num_classes=num_classes,
                  grayscale=GRAYSCALE)


In [None]:
# for training the classifier

# The model factory defined in this file is `resnet34(num_classes)`;
# there is no `ResNet34` name.
classifier = resnet34(NUM_CLASSES).to(DEVICE)
# Placed with DEVICE (rather than .cuda()) so the loss follows the same device setting as the model.
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = optim.Adam(classifier.parameters(), lr=LEARNING_RATE)
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [None]:
def compute_accuracy(model, data_loader, device):
    """Return the top-1 accuracy of ``model`` over ``data_loader`` in percent.

    Args:
        model: object exposing ``my_forward(features) -> (logits, probas)``.
        data_loader: iterable of ``(features, targets)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        A 0-dim float tensor holding ``100 * correct / total``.
    """
    hits = 0
    seen = 0
    for features, targets in data_loader:
        features, targets = features.to(device), targets.to(device)

        _, probas = model.my_forward(features)
        predicted_labels = probas.max(dim=1).indices

        seen += targets.size(0)
        hits += (predicted_labels == targets).sum()
    return hits.float() / seen * 100

In [None]:
# Classifier Training

# Number of epochs actually trained. The progress messages report this value
# so that the "epoch/total" counter matches the loop.
num_epochs = 10

# NOTE(review): `filename1` (checkpoint path for the classifier) is not
# defined in the visible cells; it has to be set before this cell runs.
start_time = time.time()
best_acc = 0
for epoch in range(num_epochs):

    classifier.train()
    for batch_idx, (features, targets) in enumerate(dataloader):

        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        ### FORWARD AND BACK PROP
        outputs = classifier(features)      # logits
        cost = criterion(outputs, targets)

        optimizer.zero_grad()
        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if not batch_idx % 50:
            print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f'
                   %(epoch+1, num_epochs, batch_idx,
                     len(dataloader), cost.item()))

    classifier.eval()

    with torch.no_grad():  # save memory during inference
        # Accuracy is measured on test_loader, hence the "Test" label.
        acc = compute_accuracy(classifier, test_loader, device=DEVICE).item()
        print('Epoch: %03d/%03d | Test: %.3f%%' % (
              epoch+1, num_epochs, acc
              ))

    # Keep the checkpoint with the best test accuracy seen so far.
    if acc > best_acc:
        torch.save(classifier, filename1)           #filename is the path where you want to save your classifer
        best_acc = acc

    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))

### Creating the adversarial dataset

In [None]:
import torchattacks
# Three attacks, all crafted against `classifier`, used to build the adversarial dataset.
atk1 = torchattacks.FGSM(classifier, eps=0.08)
atk2 = torchattacks.PGD(classifier, eps=0.05, alpha=2/255, steps=10)
atk3 = torchattacks.BIM(classifier, eps=0.05, alpha=2/255, steps=10)

# Single callable wrapping the three attacks.
# NOTE(review): how MultiAttack combines the individual attacks per sample is defined by torchattacks -- confirm against its docs.
atk = torchattacks.MultiAttack([atk1,atk2,atk3])

In [None]:
# Per-sample (clean image, attacked image, label) triplets for the training split.
adv_train = []

for imgs, labels in dataloader:
    # clean images and their labels, moved to the working device
    imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

    # adversarial counterparts of the clean batch
    attacked_imgs = atk(imgs, labels)

    # zip() walks the batch dimension, so one triplet is stored per sample
    adv_train.extend(zip(imgs, attacked_imgs, labels))

In [None]:
# Per-sample (clean image, attacked image, label) triplets for the test split.
adv_test = []

for imgs, labels in test_loader:
    # clean images and their labels, moved to the working device
    imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

    # adversarial counterparts of the clean batch
    attacked_imgs = atk(imgs, labels)

    # zip() walks the batch dimension, so one triplet is stored per sample
    adv_test.extend(zip(imgs, attacked_imgs, labels))

In [None]:
# Loaders over the in-memory lists of (clean, attacked, label) triplets built above;
# each batch is therefore indexed as batch[0] = clean, batch[1] = attacked, batch[2] = labels.
# The stored tensors were moved to DEVICE before being appended to the lists.
batch_size = 256  # Adjust this as needed
adv_train_loader = DataLoader(adv_train, batch_size=batch_size, shuffle=False)
adv_test_loader = DataLoader(adv_test, batch_size=batch_size, shuffle=False)

### DWT Transform

In [None]:
def dwt_batch(batch_img):
    """Apply a single-level 2-D Haar DWT to a batch of images.

    The batch is copied to the CPU and converted to NumPy, transformed with
    ``pywt.dwt2`` and the four coefficient arrays are converted back to
    (CPU) tensors.

    Args:
        batch_img: tensor that does not require grad (``.numpy()`` is called
            on it directly).

    Returns:
        Tuple ``(LL, LH, HL, HH)`` of tensors, in the order ``pywt.dwt2``
        returns its approximation and detail coefficients.
    """
    images_np = batch_img.cpu().numpy()

    approx, details = pywt.dwt2(images_np, "haar")
    first_detail, second_detail, third_detail = details

    bands = (approx, first_detail, second_detail, third_detail)
    return tuple(torch.tensor(band) for band in bands)

Visualizing the effect of the adversarial attack in the frequency (wavelet) domain.

In [None]:
# First batch of (clean, attacked, label) triplets
b = next(iter(adv_train_loader))
clean, noisy, targets = b[0], b[1], b[2]

LL, LH, HL, HH = dwt_batch(clean)
LL_1, LH_1, HL_1, HH_1 = dwt_batch(noisy)


def _show_dwt_grid(approx, details):
    """Draw the four DWT sub-bands of one image on a 2x2 grid.

    ``approx`` is shown as is; each band in ``details`` is shown as
    ``log(1 + band)``. All inputs are channel-first and are permuted to
    channel-last for ``imshow``.
    """
    plt.figure(figsize=(3, 3))

    plt.subplot(2, 2, 1)
    plt.imshow(torch.permute(approx, (1, 2, 0)).to("cpu"))
    plt.axis("off")

    # NOTE(review): detail coefficients may be negative, in which case
    # log(1 + x) can produce -inf/NaN pixels -- confirm this is acceptable.
    for position, band in enumerate(details, start=2):
        plt.subplot(2, 2, position)
        plt.imshow(torch.log(1 + torch.permute(band, (1, 2, 0)).to("cpu")))
        plt.axis("off")

    plt.show()


# DWT plot of clean
_show_dwt_grid(LL[0], (LH[0], HL[0], HH[0]))

# DWT plot of noisy
_show_dwt_grid(LL_1[0], (LH_1[0], HL_1[0], HH_1[0]))

### Denoiser Architecture

In [None]:
import restormer_arch as base

In [None]:
def bicubic_downsample(image, scale_factor=2):
    """Shrink the last two dimensions of ``image`` by ``scale_factor`` (bicubic).

    The target height and width are the original sizes divided by
    ``scale_factor`` and truncated to an integer.
    """
    height, width = image.size(-2), image.size(-1)
    target_size = (int(height / scale_factor), int(width / scale_factor))
    return F.interpolate(image, size=target_size, mode='bicubic', align_corners=False)

def bicubic_upsample(image, scale_factor=2):
    """Enlarge ``image`` by ``scale_factor`` using bicubic interpolation."""
    return F.interpolate(
        image,
        scale_factor=scale_factor,
        mode='bicubic',
        align_corners=False,
    )



In [None]:
from einops import rearrange

class CrossAttention(nn.Module):
    """Multi-head cross-attention computed across channels.

    Queries come from ``x``; keys and values come from ``y``. Within each head
    the attention matrix relates channels to channels (spatial positions are
    the dimension that is summed over), and is scaled by a learnable per-head
    temperature.

    Args:
        dim: channel count of both inputs and of the output.
        num_heads: number of heads the channels are split into.
        bias: whether the convolutions use a bias term.
    """

    def __init__(self, dim, num_heads, bias):
        super().__init__()

        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # 1x1 projection followed by a depth-wise 3x3 convolution, for the
        # query and for the concatenated key/value.
        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(dim * 2, dim * 2, kernel_size=3, stride=1, padding=1, groups=dim * 2, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x, y):
        """Attend from ``x`` to ``y`` and return a map shaped like ``x``.

        ``y`` must have the same number of spatial positions as ``x``: the
        attended values are reshaped back using the height and width of ``x``.
        No residual connection is added here.
        """
        _, _, height, width = x.shape
        heads = self.num_heads
        split_heads = 'b (head c) h w -> b head c (h w)'

        query = self.q_dwconv(self.q(x))
        key, value = self.kv_dwconv(self.kv(y)).chunk(2, dim=1)

        query = rearrange(query, split_heads, head=heads)
        key = rearrange(key, split_heads, head=heads)
        value = rearrange(value, split_heads, head=heads)

        # L2-normalise along the spatial axis so the logits are cosine similarities
        query = torch.nn.functional.normalize(query, dim=-1)
        key = torch.nn.functional.normalize(key, dim=-1)

        scores = (query @ key.transpose(-2, -1)) * self.temperature
        weights = scores.softmax(dim=-1)

        attended = weights @ value
        attended = rearrange(attended, 'b head c (h w) -> b (head c) h w', head=heads, h=height, w=width)

        return self.project_out(attended)

In [None]:
class my_denoiser(nn.Module):
    """Three-scale denoiser that fuses image features with wavelet detail features.

    At each scale the input (the image, then successive Haar LL bands) and the
    upsampled sum of its LH/HL/HH detail bands are embedded, passed through a
    transformer block each, and fused by cross-attention. The three fused maps
    are then merged coarse-to-fine with further cross-attention and
    transformer blocks, and projected to ``out_channels``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the restored image.
        embed_dim: channel width of all internal feature maps.
        LayerNorm_type: forwarded to ``base.TransformerBlock``.
        ffn_expansion_factor: forwarded to ``base.TransformerBlock``.
        bias: whether convolutions use a bias term.

    Note:
        The last refinement uses ``self.T6``. Earlier revisions applied
        ``self.T5`` a second time and left ``T6`` untrained, so checkpoints
        saved before this change should be retrained.
    """

    def __init__(self,
                 in_channels = 3,
                 out_channels = 3,
                 embed_dim = 48,
                 LayerNorm_type = "WithBias",
                 ffn_expansion_factor = 2.66,
                 bias = False):

        super(my_denoiser, self).__init__()

        # Small factories; every sub-module of a kind shares the same settings.
        def patch_embed():
            return base.OverlapPatchEmbed(in_channels, embed_dim)

        def transformer():
            return base.TransformerBlock(dim = embed_dim, num_heads = 4,
                                         ffn_expansion_factor = ffn_expansion_factor,
                                         bias = bias, LayerNorm_type = LayerNorm_type)

        def cross_attention():
            return CrossAttention(dim = embed_dim, num_heads = 4, bias = bias)

        # ---- scale 1: full resolution ----
        self.patch_embed_1 = patch_embed()
        self.patch_embed_dwt_1 = patch_embed()
        self.T1 = transformer()
        self.T1_dwt = transformer()
        self.cross_attn_1 = cross_attention()

        # ---- scale 2: LL band of the input ----
        self.patch_embed_2 = patch_embed()
        self.patch_embed_dwt_2 = patch_embed()
        self.T2 = transformer()
        self.T2_dwt = transformer()
        self.cross_attn_2 = cross_attention()

        # ---- scale 3: LL band of the scale-2 input ----
        self.patch_embed_3 = patch_embed()
        self.patch_embed_dwt_3 = patch_embed()
        self.T3 = transformer()
        self.T3_dwt = transformer()
        self.cross_attn_3 = cross_attention()

        # ---- coarse-to-fine fusion ----
        self.T4 = transformer()
        self.cross_attn_2_3 = cross_attention()
        self.cross_attn_1_23 = cross_attention()
        self.T5 = transformer()
        self.T6 = transformer()
        self.cross_attn_4_5 = cross_attention()
        self.cross_attn_6_45 = cross_attention()

        self.out_project = nn.Conv2d(embed_dim, out_channels = out_channels, kernel_size = 1, bias = bias)

    def _wavelet_split(self, img):
        """Return ``(LL, detail)`` for ``img``, both on the device of ``img``.

        ``detail`` is the sum of the LH, HL and HH bands, bicubically upsampled
        by 2. The DWT runs on the CPU through ``dwt_batch``, so no gradient
        flows through either output.
        """
        device = img.device
        ll, lh, hl, hh = dwt_batch(img.detach())
        detail = bicubic_upsample(lh + hl + hh).to(device)
        return ll.to(device), detail

    def forward(self, in_img):
        """Return the restored image for ``in_img``.

        NOTE(review): assumes ``in_img`` is (batch, in_channels, H, W) with H
        and W divisible by 8, so that every upsampled map matches the map it is
        fused with -- confirm.
        """
        #----scale 1-----#
        LL_1, dwt_1 = self._wavelet_split(in_img)
        out_enc_1 = self.T1(self.patch_embed_1(in_img))
        out_enc_1_dwt = self.T1_dwt(self.patch_embed_dwt_1(dwt_1))
        CA_1 = self.cross_attn_1(out_enc_1, out_enc_1_dwt)

        #----scale 2-----#
        LL_2, dwt_2 = self._wavelet_split(LL_1)
        out_enc_2 = self.T2(self.patch_embed_2(LL_1))
        out_enc_2_dwt = self.T2_dwt(self.patch_embed_dwt_2(dwt_2))
        CA_2 = self.cross_attn_2(out_enc_2, out_enc_2_dwt)

        #----scale 3-----#
        _, dwt_3 = self._wavelet_split(LL_2)
        out_enc_3 = self.T3(self.patch_embed_3(LL_2))
        out_enc_3_dwt = self.T3_dwt(self.patch_embed_dwt_3(dwt_3))
        CA_3 = self.cross_attn_3(out_enc_3, out_enc_3_dwt)

        #----coarse-to-fine fusion (each step keeps a residual connection)-----#
        t4 = self.T4(CA_3) + CA_3

        CA_2_3 = self.cross_attn_2_3(CA_2, bicubic_upsample(CA_3)) + CA_2
        t5 = self.T5(CA_2_3) + CA_2_3

        CA_1_23 = self.cross_attn_1_23(CA_1, bicubic_upsample(CA_2_3)) + CA_1
        # T6 refines the finest scale (this call used to go through T5 again).
        t6 = self.T6(CA_1_23) + CA_1_23

        CA_4_5 = self.cross_attn_4_5(t5, bicubic_upsample(t4)) + t5
        CA_6_45 = self.cross_attn_6_45(t6, bicubic_upsample(CA_4_5)) + t6

        return self.out_project(CA_6_45)

In [None]:
class CharbonnierLoss(nn.Module):
    """Charbonnier loss: the mean of ``sqrt((x - y)^2 + eps^2)``.

    A smooth variant of the L1 loss; ``eps`` controls the smoothing.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return the scalar Charbonnier loss between ``x`` and ``y``."""
        residual = x - y
        smoothed = torch.sqrt((residual * residual) + (self.eps * self.eps))
        return smoothed.mean()

### Training of Denoiser

In [None]:
# Denoiser with its default constructor arguments, on the working device.
denoiser = my_denoiser().to(DEVICE)
# NOTE: `criterion` and `optimizer` rebind the names used earlier for the classifier;
# from here on they refer to the denoiser's loss and optimizer.
criterion = CharbonnierLoss().cuda()
optimizer = torch.optim.Adam(denoiser.parameters(), lr = 0.004, betas = (0.9, 0.999), eps = 1e-8)


In [None]:
# training block
torch.cuda.empty_cache()

# NOTE(review): `best_denoiser_path` (checkpoint path for the denoiser) is not
# defined in the visible cells; it has to be set before this cell runs.

# Start from +inf so the first epoch always produces a checkpoint, regardless
# of how large the summed loss is.
best_loss = float("inf")

for epoch in range(15):
    epoch_start_time = time.time()
    epoch_loss = 0

    # ---- train: restore the attacked image towards its clean version ----
    for data in tqdm(adv_train_loader):
        optimizer.zero_grad()

        target = data[0].to(DEVICE)     # first element is clean
        input_ = data[1].to(DEVICE)     # 2nd element is noisy

        restored = denoiser(input_)
        loss = criterion(restored, target)

        # A fresh graph is built every iteration, so it does not need to be retained.
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # ---- validate: accuracy of the fixed classifier on denoised test images ----
    denoiser.eval()
    classifier.eval()
    with torch.no_grad():
        correct_pred, num_examples = 0, 0

        for data_val in adv_test_loader:
            noisy = data_val[1].to(DEVICE)
            targets = data_val[2].to(DEVICE)

            restored = denoiser(noisy)
            # restored = torch.clamp(restored,0,1)

            outputs = classifier(restored)
            _, predicted_labels = outputs.max(1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()

        acc = correct_pred.float()/num_examples * 100
        print("The accuracy at current level is ", acc)

    denoiser.train()
    torch.cuda.empty_cache()

    print("------------------------------------------------------------------")
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}".format(epoch, time.time()-epoch_start_time,epoch_loss))
    print("------------------------------------------------------------------")

    # Checkpoint selection is based on the summed training loss of the epoch.
    if epoch_loss<best_loss:
        best_loss = epoch_loss
        torch.save(denoiser, best_denoiser_path)

### Adversarial Training

In [None]:
# Command-line style options for the adversarial-training stage.
# Of these, the visible cells read only lr, b1, b2 and n_epochs (which is overwritten before the training loop).
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
parser.add_argument("--n_residual_blocks", type=int, default=9, help="number of residual blocks in generator")

# Presumably absorbs the `-f <connection file>` argument that Jupyter passes to the kernel,
# so parse_args() does not abort inside a notebook -- confirm.
parser.add_argument('-f')
opt_1 = parser.parse_args()
print(opt_1)

In [None]:
# adversarial training

# Classification loss used to fine-tune the classifier.
auxiliary_loss = torch.nn.CrossEntropyLoss()

# NOTE(review): "trained_classifier_path" is a literal file name here, while the classifier
# checkpoint was written to `filename1` by the training cell -- confirm the two refer to the same file.
classifier_1 = torch.load("trained_classifier_path", map_location='cuda')
# Adam over the loaded classifier's parameters, configured from the parsed options.
optimizer_C = torch.optim.Adam(classifier_1.parameters(), lr=opt_1.lr, betas=(opt_1.b1, opt_1.b2))

classifier_1.cuda()

auxiliary_loss.cuda()

In [None]:
# Reload the denoiser checkpoint written by the denoiser training cell (the epoch with the lowest summed loss).
best_denoiser = torch.load(best_denoiser_path, map_location = DEVICE)

In [None]:
from torch.autograd import Variable
# True when a CUDA device is available.
cuda = True if torch.cuda.is_available() else False
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

# Tensor type aliases that follow CUDA availability.
# NOTE(review): none of these aliases is referenced by the visible cells -- confirm they are still needed.
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


In [None]:

# ----------
#  Training
# ----------
# Fine-tunes classifier_1 on the outputs of the frozen denoiser so that it
# classifies denoised adversarial images correctly.
#
# NOTE(review): `classifier_stage_1_path` is not defined in the visible cells.
# NOTE(review): classifier_1 keeps whatever train/eval mode it had when it was
# saved; call classifier_1.train() first if BatchNorm statistics should adapt.
batches_done = 0
opt_1.n_epochs = 20

for epoch in range(opt_1.n_epochs):

    for i, data in enumerate(adv_train_loader):

        clean = data[0].to(DEVICE)
        noisy = data[1].to(DEVICE)
        labels = data[2].to(DEVICE)

        # The denoiser is frozen in this stage: no autograd graph is needed.
        with torch.no_grad(), torch.cuda.amp.autocast():
            gen_imgs = best_denoiser(noisy)
        gen_imgs = gen_imgs.float()     # back to fp32 after autocast

        optimizer_C.zero_grad()

        # Clean images are only monitored; they do not contribute to the loss.
        with torch.no_grad():
            real_aux = classifier_1(clean)

        # Loss on the denoised images. CrossEntropyLoss applies log-softmax
        # itself, so the raw logits are passed in.
        fake_aux = classifier_1(gen_imgs)
        c_fake_loss = auxiliary_loss(fake_aux, labels)

        c_loss = c_fake_loss

        # Calculate classifier accuracy
        c_acc_real = (real_aux.argmax(dim=1) == labels).float().mean().item()         # classifier accuracy for real images
        c_acc_generated = (fake_aux.argmax(dim=1) == labels).float().mean().item()    # classifier accuracy for denoised images

        c_loss.backward()
        # Apply the gradients; without this step the classifier never changes.
        optimizer_C.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [C loss: %f, real_acc: %d%% gen_acc %d%%]"
            % (epoch, opt_1.n_epochs, i, len(adv_train_loader), c_loss.item(), 100 * c_acc_real,100*c_acc_generated)
        )
        batches_done = epoch * len(adv_train_loader) + i

    # Save after each epoch so the checkpoint includes that epoch's updates
    # (including those of the final epoch).
    torch.save(classifier_1, classifier_stage_1_path)


### Testing

In [None]:
# Evaluation attacks; each is crafted against the original `classifier`.
fgsm = torchattacks.FGSM(classifier, eps=0.15)
pgd = torchattacks.PGD(classifier, eps=0.1, alpha=0.1, steps=15)
upgd = torchattacks.UPGD(classifier, eps=0.1,alpha=0.1,steps=15)
mifgsm = torchattacks.MIFGSM(classifier,eps=0.2,steps = 15,decay = 1)
bim = torchattacks.BIM(classifier, eps=0.2)

In [None]:
class CustomAdversarialDataset():
    """Map-style dataset over a sequence of ``(image, label)`` pairs.

    Args:
        adversarial_data: indexable collection of ``(image, label)`` pairs.
        transform: optional callable applied to the image on access.
    """

    def __init__(self, adversarial_data, transform=None):
        self.adversarial_data = adversarial_data
        self.transform = transform

    def __len__(self):
        """Return the number of samples (needed by samplers and ``DataLoader``)."""
        return len(self.adversarial_data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx``, with ``transform`` applied to the image."""
        image, label = self.adversarial_data[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
def attacked_dataset(loader, atk, batch_size=256):
    """Attack every batch of ``loader`` and return a loader over the results.

    Args:
        loader: iterable of ``(imgs, labels)`` batches.
        atk: callable ``atk(imgs, labels)`` returning the attacked images.
        batch_size: batch size of the returned loader.

    Returns:
        An unshuffled ``DataLoader`` whose batches are
        ``(clean images, attacked images, labels)``.
    """
    atk_data = []

    for imgs, labels in loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        attacked_imgs = atk(imgs, labels).detach()

        # zip() walks the batch dimension: one triplet per sample
        atk_data.extend(zip(imgs, attacked_imgs, labels))

    # The list of triplets is served directly; the samples are already tensors,
    # so no dataset wrapper or ToTensor transform is needed.
    return DataLoader(atk_data, batch_size=batch_size, shuffle=False)

In [None]:
# One loader of (clean, attacked, label) batches per evaluation attack, all built from the MNIST test split.
fgsm_atk = attacked_dataset(test_loader, fgsm)
pgd_atk = attacked_dataset(test_loader, pgd)
upgd_atk = attacked_dataset(test_loader, upgd)
mifgsm_atk = attacked_dataset(test_loader, mifgsm)
bim_atk = attacked_dataset(test_loader, bim)

In [None]:
# Reload the checkpoint written by the adversarial-training loop.
classifier_2 = torch.load(classifier_stage_1_path, map_location = DEVICE)    # the adversarially trained classifier

In [None]:
def compute_accuracy_clean(model, data_loader, device):
    """Return the accuracy (percent) of ``model`` on the clean images.

    Args:
        model: object exposing ``my_forward(x) -> (logits, probas)``.
        data_loader: iterable of ``(clean, attacked, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        A 0-dim float tensor holding ``100 * correct / total``.
    """
    correct_pred, num_examples = 0, 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for data in data_loader:
            clean = data[0].to(device)       # clean images
            targets = data[2].to(device)

            _, probas = model.my_forward(clean)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100

In [None]:
def compute_accuracy_attacked(model, data_loader, device):
    """Return the accuracy (percent) of ``model`` on the attacked images.

    Args:
        model: object exposing ``my_forward(x) -> (logits, probas)``.
        data_loader: iterable of ``(clean, attacked, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        A 0-dim float tensor holding ``100 * correct / total``.
    """
    correct_pred, num_examples = 0, 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for data in data_loader:
            noisy = data[1].to(device)       # attacked images
            targets = data[2].to(device)

            _, probas = model.my_forward(noisy)
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100

In [None]:
def compute_accuracy_denoised(model, denoiser, data_loader, device):
    """Return the accuracy (percent) of ``model`` on denoised attacked images.

    Each attacked batch is passed through ``denoiser`` (under CUDA autocast),
    clamped to [0, 1], cast to float32 and classified.

    Args:
        model: object exposing ``my_forward(x) -> (logits, probas)``.
        denoiser: callable mapping attacked images to restored images.
        data_loader: iterable of ``(clean, attacked, labels)`` batches.
        device: device the batches (and restored images) are placed on.

    Returns:
        A 0-dim float tensor holding ``100 * correct / total``.
    """
    correct_pred, num_examples = 0, 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for data in data_loader:
            noisy = data[1].to(device)
            targets = data[2].to(device)

            with torch.cuda.amp.autocast():
                restored = denoiser(noisy)

            # Use the `device` argument rather than the global DEVICE so the
            # function honours what the caller asked for.
            restored = torch.clamp(restored, 0, 1).to(device)

            _, probas = model.my_forward(restored.float())
            _, predicted_labels = torch.max(probas, 1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    return correct_pred.float()/num_examples * 100

In [None]:
print("FGSM")

# Baseline classifier on clean and on attacked images; adversarially trained
# classifier on the denoised images.
clean_acc = compute_accuracy_clean(classifier, fgsm_atk, DEVICE)
atk_acc = compute_accuracy_attacked(classifier, fgsm_atk, DEVICE)
denoised_acc = compute_accuracy_denoised(classifier_2, best_denoiser, fgsm_atk, DEVICE)

for caption, score in (
    ("Clean Accuracy", clean_acc),
    ("Attacked Accuracy", atk_acc),
    ("Denoised Accuracy", denoised_acc),
):
    print(caption, score.item(), "%")

In [None]:
print("PGD")

# Baseline classifier on clean and on attacked images.
clean_acc = compute_accuracy_clean(classifier, pgd_atk, DEVICE)
atk_acc = compute_accuracy_attacked(classifier, pgd_atk, DEVICE)
# Denoised accuracy is measured with classifier_2 (the classifier after
# adversarial training), exactly as for FGSM, UPGD, MIFGSM and BIM.
denoised_acc = compute_accuracy_denoised(classifier_2, best_denoiser, pgd_atk, DEVICE)   # using classifier after adversarial training



print("Clean Accuracy", clean_acc.item(),"%")
print("Attacked Accuracy", atk_acc.item(),"%")
print("Denoised Accuracy", denoised_acc.item(),"%")

In [None]:
print("UPGD")

# Baseline classifier on clean and on attacked images; adversarially trained
# classifier on the denoised images.
clean_acc = compute_accuracy_clean(classifier, upgd_atk, DEVICE)
atk_acc = compute_accuracy_attacked(classifier, upgd_atk, DEVICE)
denoised_acc = compute_accuracy_denoised(classifier_2, best_denoiser, upgd_atk, DEVICE)

for caption, score in (
    ("Clean Accuracy", clean_acc),
    ("Attacked Accuracy", atk_acc),
    ("Denoised Accuracy", denoised_acc),
):
    print(caption, score.item(), "%")

In [None]:
print("MIFGSM")

# Baseline classifier on clean and on attacked images; adversarially trained
# classifier on the denoised images.
clean_acc = compute_accuracy_clean(classifier, mifgsm_atk, DEVICE)
atk_acc = compute_accuracy_attacked(classifier, mifgsm_atk, DEVICE)
denoised_acc = compute_accuracy_denoised(classifier_2, best_denoiser, mifgsm_atk, DEVICE)

for caption, score in (
    ("Clean Accuracy", clean_acc),
    ("Attacked Accuracy", atk_acc),
    ("Denoised Accuracy", denoised_acc),
):
    print(caption, score.item(), "%")

In [None]:
print("BIM")

# Baseline classifier on clean and on attacked images; adversarially trained
# classifier on the denoised images.
clean_acc = compute_accuracy_clean(classifier, bim_atk, DEVICE)
atk_acc = compute_accuracy_attacked(classifier, bim_atk, DEVICE)
denoised_acc = compute_accuracy_denoised(classifier_2, best_denoiser, bim_atk, DEVICE)

for caption, score in (
    ("Clean Accuracy", clean_acc),
    ("Attacked Accuracy", atk_acc),
    ("Denoised Accuracy", denoised_acc),
):
    print(caption, score.item(), "%")