In [6]:
from time import time
import matplotlib.pyplot as plt
import copy
import torch
import torchvision
from torchvision import models

import torch.nn.functional as F
from torch import nn
import torch.optim as optim
from torch.optim import lr_scheduler


from foolbox.attacks import LinfPGD, LinearSearchBlendedUniformNoiseAttack
from foolbox import PyTorchModel
#from load_models import load_mobilenet

if torch.cuda.is_available() == True:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
dtype = torch.float32

cuda:0


# Load Imagenette Data

In [7]:
def load_imagenette(path, bs=32):
    train_transforms = torchvision.transforms.Compose([
        #torchvision.transforms.ColorJitter(hue=.05, saturation=.05),
        #torchvision.transforms.RandomHorizontalFlip(),
        #torchvision.transforms.RandomRotation(20),
        torchvision.transforms.Resize(256),
        torchvision.transforms.RandomResizedCrop(224),
        torchvision.transforms.ToTensor(),
        #torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    val_transforms= torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        #torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


    train_path= path+'/train'
    imagenette_train = torchvision.datasets.ImageFolder(
        root=train_path,
        transform=train_transforms
    )

    val_path=path+'/val'
    imagenette_val = torchvision.datasets.ImageFolder(
        root=val_path,
        transform=val_transforms
    )


    train_loader = torch.utils.data.DataLoader(imagenette_train, num_workers=4,
                                              batch_size=bs,
                                              shuffle=True)
    val_loader = torch.utils.data.DataLoader(imagenette_val, num_workers=4,
                                              batch_size=bs,
                                              shuffle=True)
    return train_loader, val_loader, (imagenette_train, imagenette_val)


In [8]:
PATH = '/home/florian/data/imagenette2'

In [9]:
#train_dl, val_dl,(train_ds, val_ds) = load_imagenette(PATH, 5000)
#x_train, y_train = next(iter(train_dl))
#x_train, y_train = x_train.numpy(), y_train.numpy()

In [10]:
train_dl, val_dl,(train_ds, val_ds) = load_imagenette(PATH, 32)

dataloaders = {
    'train':train_dl, 
    'validation':val_dl
}
dataset_sizes = {
    'train':len(train_dl.dataset), 
    'validation':len(val_dl.dataset)
}

# Load Model

In [6]:
import torch, torchvision

from torchvision import models
from torch import nn

class ImageNetNormalization(nn.Module):
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super(ImageNetNormalization, self).__init__()
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def forward(self, x):
        return torchvision.transforms.functional.normalize(x, self.mean, self.std)
    

def load_resnet(MODEL_PATH):
    base_resnet = models.resnet18(pretrained=True)

    num_ftrs_in = base_resnet.fc.in_features
    num_ftrs_out = 10
    base_resnet.fc = nn.Linear(num_ftrs_in, num_ftrs_out)
    resnet = torch.nn.Sequential(
        #ImageNetNormalization(),
        base_resnet
    )

    state_dict = torch.load(MODEL_PATH)
    resnet.load_state_dict(state_dict)
    resnet = resnet
    model = resnet
    model.eval()
    return model

def load_mobilenet(MODEL_PATH):

    base_mobilenet = models.mobilenet_v3_small(pretrained=True)

    num_ftrs_in = base_mobilenet.classifier[0].in_features
    num_ftrs_out = base_mobilenet.classifier[0].out_features
    base_mobilenet.classifier[0] = nn.Linear(num_ftrs_in, num_ftrs_out)

    num_ftrs_in = base_mobilenet.classifier[3].in_features
    num_ftrs_out = 10
    base_mobilenet.classifier[3] = nn.Linear(num_ftrs_in, num_ftrs_out)

    mobilenet = torch.nn.Sequential(
        ImageNetNormalization(),
        base_mobilenet
    )

    
    state_dict = torch.load(MODEL_PATH)
    mobilenet.load_state_dict(state_dict)
    mobilenet = mobilenet

    model = mobilenet
    model.eval()
    
    return model

def initialize_mobilenet(pretrained=True):
    model = models.mobilenet_v3_small(pretrained=pretrained)

    num_ftrs_in = model.classifier[0].in_features
    num_ftrs_out = model.classifier[0].out_features
    model.classifier[0] = nn.Linear(num_ftrs_in, num_ftrs_out)

    num_ftrs_in = model.classifier[3].in_features
    num_ftrs_out = 10
    model.classifier[3] = nn.Linear(num_ftrs_in, num_ftrs_out)

    model = torch.nn.Sequential(
        model
    )
    return model

## Mobilenet

In [7]:
#MNET_PATH = '../ET-Adversarials/models/MobileNetV3Small-wo-normalization-layer.pt'
MNET_PATH = '../ET-Adversarials/models/MobileNetV3Small.pt'
model = load_mobilenet(MNET_PATH).to(device)
#model = load_resnet(RNET_PATH)


In [8]:
norm = torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
preprocessing = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], axis=-3)
images, labels = next(iter(dataloaders['val']))
images, labels = images.to(device), labels.to(device)
model.eval()
fmodel = PyTorchModel(model, bounds=(0,1), preprocessing=preprocessing)
attack = LinfPGD(abs_stepsize=(2/255), steps=50, random_start=True)
_, advs, success = attack(fmodel, images, labels, epsilons=8/255)

KeyError: 'val'

In [9]:
def evaluate_acc(model, dataloader, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    norm = torchvision.transforms.Normalize(mean, std)
    model.eval()
    running_corrects = 0.
    running_loss = 0.
    for images,labels in dataloader:
        images,labels = images.to(device),labels.to(device)
        outputs = model(norm(images))
        running_loss += criterion(outputs,labels).detach()
        running_corrects += (outputs.argmax(-1)==labels).sum()
    model.train()
    return (running_loss/len(dataloader.dataset)).item(), (running_corrects/len(dataloader.dataset)).item()

def evaluate_rob_acc(model, attack, eps, dataloader, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), fmodel_bounds=(0,1)):
    norm = torchvision.transforms.Normalize(mean, std)
    model.eval()
    running_corrects = 0.
    running_loss = 0.
    fmodel = PyTorchModel(model, bounds=fmodel_bounds, preprocessing=dict(mean=mean, std=std, axis=-3)) 
    for images,labels in dataloader:
        images,labels = images.to(device),labels.to(device)
        _,adv_images,_ = attack(fmodel, images, labels, epsilons=[8/255])
        adv_outputs = model(norm(adv_images[0]))
        running_loss += criterion(adv_outputs,labels).detach()
        running_corrects += (adv_outputs.argmax(-1)==labels).sum()
    model.train()
    return (running_loss/len(dataloader.dataset)).item(), (running_corrects/len(dataloader.dataset)).item()

In [10]:
criterion =nn.CrossEntropyLoss()
attack = LinfPGD(abs_stepsize=(2/255), steps=7, random_start=True)
eps = [8/255]
dataloader=dataloaders['validation']
print(evaluate_rob_acc(model, LinfPGD(abs_stepsize=(2/255), steps=7, random_start=True), [8/255], dataloader, mean=(0,0,0), std=(1,1,1)))
print(evaluate_acc(model, dataloader, mean=(0,0,0), std=(1,1,1)))

(1.9666242599487305, 0.0)
(0.0028967992402613163, 0.9714649319648743)


In [33]:
model = load_mobilenet(MNET_PATH).to(device)

criterion =nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
dataloaders = dataloaders
num_epochs = 50
save_name=None
scheduler=None
advtrain=True
eps = [8]
dataloader = dataloaders['validation']



liveloss = fit_model(
    model, 
    criterion, 
    optimizer, 
    dataloaders, 
    device, 
    num_epochs,
    mean=(0,0,0),
    std=(1,1,1)
    #advtrain_hook=None,
    #mean=(0.485, 0.456, 0.406), 
    #std=(0.229, 0.224, 0.225)
)



0 0.37955433519907067
train  - clean acc -  (0.03237147629261017, 0.6791635751724243)
train  - robust acc -  (0.11662627011537552, 0.0844862163066864)
0 0.7365605095541401
validation  - clean acc -  (0.0269654281437397, 0.7365604639053345)
validation  - robust acc -  (0.10990846902132034, 0.0993630513548851)
1 0.38071602069912347
train  - clean acc -  (0.033691901713609695, 0.667018711566925)
train  - robust acc -  (0.05850416049361229, 0.36371317505836487)
1 0.729171974522293
validation  - clean acc -  (0.028699427843093872, 0.7291719317436218)
validation  - robust acc -  (0.05353434011340141, 0.41681525111198425)
2 0.39835251874537964
train  - clean acc -  (0.03150007873773575, 0.7036646008491516)
train  - robust acc -  (0.05277319997549057, 0.4283451437950134)
2 0.7536305732484077
validation  - clean acc -  (0.026546522974967957, 0.753630518913269)
validation  - robust acc -  (0.04844166710972786, 0.47337576746940613)
3 0.41841799556447357
train  - clean acc -  (0.029828444123268127

26 0.8476433121019108
validation  - clean acc -  (0.016576098278164864, 0.847643256187439)
validation  - robust acc -  (0.03933171182870865, 0.5757961273193359)
27 0.5912979195268772
train  - clean acc -  (0.017629902809858322, 0.8431724905967712)
train  - robust acc -  (0.037125322967767715, 0.5877072811126709)
27 0.8476433121019108
validation  - clean acc -  (0.016378119587898254, 0.847643256187439)
validation  - robust acc -  (0.03901614248752594, 0.5745222568511963)
28 0.5986904636181223
train  - clean acc -  (0.016217505559325218, 0.8490865230560303)
train  - robust acc -  (0.036667875945568085, 0.5982680320739746)
29 0.8509554140127389
validation  - clean acc -  (0.015536785125732422, 0.8509553670883179)
validation  - robust acc -  (0.04037930816411972, 0.5740126967430115)
30 0.602492343436477
train  - clean acc -  (0.01626785285770893, 0.8503537774085999)
train  - robust acc -  (0.03670831769704819, 0.5964726805686951)
30 0.8496815286624204
validation  - clean acc -  (0.01605595

In [35]:
print('final robust accuracy: ',evaluate_rob_acc(model, LinfPGD(abs_stepsize=(2/255), steps=7, random_start=True), [8/255], dataloaders['validation'], mean=(0,0,0), std=(1,1,1)))
print('final clean accuracy: ',evaluate_acc(model, dataloaders['validation'], mean=(0,0,0), std=(1,1,1)))

final robust accuracy:  (0.040747176855802536, 0.5834394693374634)
final clean accuracy:  (0.01431315392255783, 0.8603821396827698)


In [36]:
SAVE_PATH = '../ET-Adversarials/models/MobileNetV3Small-adversarially-trained.pt'

torch.save(model.state_dict(), SAVE_PATH)

In [28]:
def PGD7_training():
    # Madry adversarial training with PGD-7
    attack = LinfPGD(abs_stepsize=(2/255), steps=7, random_start=True)
    attack_kwargs = {"epsilons": (8/255)}
    fmodel_bounds = (0,1)

    def PGD7_attack(fmodel, images, labels):
        return attack(fmodel, images, labels, **attack_kwargs)[1]
    
    return PGD7_attack, fmodel_bounds

def fit_model(
    model, 
    criterion, 
    optimizer, 
    dataloaders, 
    device, 
    num_epochs,
    advtrain_hook=PGD7_training, 
    mean=(0.485, 0.456, 0.406), 
    std=(0.229, 0.224, 0.225)
):
    BEST_MODEL_PATH = './tmp/best_model.pt'
    begin = time()
    model = model.to(device) # Moves and/or casts the parameters and buffers to device.
    best_val_acc = 0
    norm = torchvision.transforms.Normalize(mean, std)
    if advtrain_hook is not None:
        attack, fmodel_bounds = advtrain_hook()

    for epoch in range(num_epochs): # Number of passes through the entire training & validation datasets
        logs = {}
        for phase in ['train', 'validation']: # First train, then validate

            # Switch between training and test eval mode depending on phase.
            model.train() if phase == 'train' else model.eval()

            running_loss = 0.0 # keep track of loss
            running_corrects = 0 # count of carrectly classified images

            for images, labels in dataloaders[phase]:
                images = images.to(device) # Perform Tensor device conversion
                labels = labels.to(device)

                if phase == "train" and advtrain_hook is not None:
                    # Perturb images before computing gradients
                    model.eval()
                    preprocessing = dict(mean=mean, std=std, axis=-3)
                    fmodel = PyTorchModel(model, bounds=fmodel_bounds, preprocessing=preprocessing)
                    images = attack(fmodel, images, labels)
                    model.train()

                # Compute gradients and update weights
                outputs = model(norm(images))
                loss = criterion(outputs, labels)
                if phase == "train":
                    optimizer.zero_grad() # Set all previously calculated gradients to 0
                    loss.backward() # Calculate gradients
                    optimizer.step() # Update weights

                preds = torch.argmax(outputs, dim=1) # Get model's predictions
                running_loss += loss.detach() * images.size(0) # multiply mean loss by the number of elements
                running_corrects += torch.sum(preds == labels.data) # add number of correct predictions to total

            epoch_loss = running_loss.item() / len(dataloaders[phase].dataset) # get the "mean" loss for the epoch
            epoch_acc = running_corrects.item() / len(dataloaders[phase].dataset) # Get proportion of correct predictions

            print(epoch, epoch_acc)
            #print(epoch_loss)

            # Logging
            prefix = ''
            if phase == 'validation':
                prefix = 'val_'

            logs[prefix + 'log loss'] = epoch_loss
            logs[prefix + 'accuracy'] = epoch_acc

            
                
            acc = evaluate_acc(model, dataloaders[phase], mean=mean, std=std)
            rob_acc = evaluate_rob_acc(model, LinfPGD(abs_stepsize=(2/255), steps=7, random_start=True), [8/255], dataloaders[phase], mean=mean, std=std)
            
            if phase == 'validation' and acc[1]+rob_acc[1]>best_val_acc:
                torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(phase, ' - clean acc - ', acc)
            print(phase,' - robust acc - ', rob_acc)
    model.load_state_dict(torch.load(BEST_MODEL_PATH))
    print('final robust accuracy: ',evaluate_rob_acc(model, LinfPGD(abs_stepsize=(2/255), steps=7, random_start=True), [8/255], dataloaders['validation'], mean=mean, std=std))
    print('final clean accuracy: ',evaluate_acc(model, dataloaders['validation'], mean=mean, std=std))

In [None]:
BEST_MODEL_PATH = './tmp/best_model.pt'
torch.save(model.state_dict(), BEST_MODEL_PATH)


