In [1]:
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import torch.optim as optim
from data import load_imagenette, load_torchvision_dataset


if torch.cuda.is_available() == True:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
dtype = torch.float32

cuda:0


In [1]:
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Flatten, Conv2D, BatchNormalization, Dropout
from tensorflow.keras import regularizers

AUTOTUNE = tf.data.experimental.AUTOTUNE

import time



2021-10-04 14:02:18.896316: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [2]:
class ResBlock(nn.Module):
    """ ResBlock made from masked Layers"""
    def __init__(self, input_channels, output_channels, padding=0 , stride=1, kernel_size=3):
        super(ResBlock, self).__init__()
        self.s = stride
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.c1 = nn.Conv2d(input_channels, output_channels, kernel_size, padding=padding, stride=self.s, bias=False)
        #self.c1 = MaskedConvLayer((output_channels, input_channels, filter_size, filter_size), padding=padding, bias=False, stride=self.s)
        self.bn2 = nn.BatchNorm2d(output_channels)
        self.c2 = nn.Conv2d(output_channels, output_channels, kernel_size, padding=padding, stride=1, bias=False)
        #self.c2 = MaskedConvLayer((output_channels, output_channels, filter_size, filter_size), padding=padding, bias=False, stride=1)
        if self.s == 2:
            #self.c3 = MaskedConvLayer((output_channels, input_channels, 1,1),padding=0, bias=False, stride=self.s)
            self.c3 = nn.Conv2d(input_channels, output_channels, 1, padding=0, stride=self.s, bias=False)
            
    def forward(self, inputs):
        shortcut = inputs
        x = self.c1(F.relu(self.bn1(inputs)))
        x = self.c2(F.relu(self.bn2(x)))
        if self.s == 2:
            shortcut = self.c3(shortcut)
        x = torch.add(x, shortcut)
        return x

class CifarResNet(nn.Module):
    def __init__(self):
        super(CifarResNet, self).__init__()
        self.c1 = nn.Conv2d(3,64,3,padding=(1,1),bias=False, stride=1)
        self.r1 = ResBlock(64,64,padding=1)
        self.r2 = ResBlock(64,64,padding=1)
        self.r3 = ResBlock(64,128,stride=2,padding=1)
        self.r4 = ResBlock(128,128,padding=1)
        self.r5 = ResBlock(128,256,stride=2,padding=1)
        self.r6 = ResBlock(256,256,padding=1)
        self.r7 = ResBlock(256,512,stride=2,padding=1)
        self.r8 = ResBlock(512,512,padding=1)
        self.p2 = nn.AvgPool2d(4)
        self.d1 = nn.Linear(512,10)
        
    def forward(self, inputs):
        x = self.c1(inputs)
        x = self.r1(x)
        x = self.r2(x)
        x = self.r3(x)
        x = self.r4(x)
        x = self.r5(x)
        x = self.r6(x)
        x = self.r7(x)
        #x = self.r8(x)
        x = self.p2(x)
        x = x.view(x.shape[0], x.shape[1])
        x = self.d1(x)
        return (x)
    
class ImageNetResNet(nn.Module):
    def __init__(self):
        super(ImageNetResNet, self).__init__()
        #self.c1 = MaskedConvLayer((64, 3, 7, 7), padding=(3,3),bias=False, stride=2)
        self.c1 = nn.Conv2d(3,64,7, padding=(3,3),bias=False, stride=2)
        self.p1 = nn.MaxPool2d((3,3), stride=(2,2), padding=(1))
        self.r1 = ResBlock(64,64,padding=1)
        self.r2 = ResBlock(64,64,padding=1)
        self.r3 = ResBlock(64,128,stride=2,padding=1)
        self.r4 = ResBlock(128,128,padding=1)
        self.r5 = ResBlock(128,256,stride=2,padding=1)
        self.r6 = ResBlock(256,256,padding=1)
        self.r7 = ResBlock(256,512,stride=2,padding=1)
        self.r8 = ResBlock(512,512,padding=1)
        self.p2 = nn.AvgPool2d(7)
        self.d1 = nn.Linear(512,10)
    
    def forward(self, inputs):
        x = self.c1(inputs)
        x = self.p1(x)
        x = self.r1(x)
        x = self.r2(x)
        x = self.r3(x)
        x = self.r4(x)
        x = self.r5(x)
        x = self.r6(x)
        x = self.r7(x)
        x = self.r8(x)
        x = self.p2(x)
        x = x.view(x.shape[0], x.shape[1])
        x = self.d1(x)
        return (x)

In [3]:
sum([p.numel() for p in m.parameters()])

NameError: name 'm' is not defined

In [4]:
m = CifarResNet().to(device)
m(torch.randn((1,3,32,32)).to(device))

tensor([[ 0.0784,  0.0841, -0.0207,  0.0004,  0.0868,  0.2435,  0.0250, -0.2190,
          0.2123,  0.0037]], device='cuda:0', grad_fn=<AddmmBackward>)

In [7]:
m = ImageNetResNet().to(device)
m(torch.randn((1,3,224,224)).to(device))

tensor([[-0.2038,  0.1521, -0.0023, -0.2033,  0.3244,  0.0647,  0.2296,  0.2677,
         -0.1040, -0.0898]], device='cuda:0', grad_fn=<AddmmBackward>)

In [4]:
PATH = '/home/florian/data/imagenette2'
train_dl, val_dl = load_imagenette(PATH, 128)

In [5]:
train_dl, val_dl = load_torchvision_dataset('CIFAR10', data_augmentation=True)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
%%time
_fit(m, train_dl, val_dl, 100, device)

[1,     1] loss: 2.30916, train_accuracy: 11.52, time: 0.24
[1,    11] loss: 2.72165, train_accuracy: 15.23, time: 0.21
[1,    21] loss: 2.14381, train_accuracy: 19.34, time: 0.21
[1,    31] loss: 2.11411, train_accuracy: 24.02, time: 0.22
[1,    41] loss: 1.96324, train_accuracy: 28.52, time: 0.21
[1,    51] loss: 1.94638, train_accuracy: 31.25, time: 0.21
[1,    61] loss: 1.84912, train_accuracy: 25.39, time: 0.21
[1,    71] loss: 1.74550, train_accuracy: 32.62, time: 0.22
[1,    81] loss: 1.75294, train_accuracy: 35.55, time: 0.21
[1,    91] loss: 1.84840, train_accuracy: 29.69, time: 0.21
duration: 21 s - train loss: 2.19618 - train accuracy: 25.95 - validation loss: 1.77 - validation accuracy: 33.61 
[2,     1] loss: 1.82785, train_accuracy: 29.49, time: 0.22
[2,    11] loss: 1.74302, train_accuracy: 33.01, time: 0.22
[2,    21] loss: 1.69170, train_accuracy: 34.57, time: 0.21
[2,    31] loss: 1.64378, train_accuracy: 38.87, time: 0.22
[2,    41] loss: 1.64039, train_accuracy: 37.

[12,    61] loss: 0.49532, train_accuracy: 84.96, time: 0.22
[12,    71] loss: 0.51346, train_accuracy: 82.23, time: 0.22
[12,    81] loss: 0.51108, train_accuracy: 79.88, time: 0.22
[12,    91] loss: 0.51293, train_accuracy: 83.40, time: 0.21
duration: 21 s - train loss: 0.52309 - train accuracy: 81.94 - validation loss: 0.96 - validation accuracy: 71.59 
[13,     1] loss: 0.51331, train_accuracy: 83.40, time: 0.21
[13,    11] loss: 0.49318, train_accuracy: 82.62, time: 0.21
[13,    21] loss: 0.46806, train_accuracy: 82.03, time: 0.21
[13,    31] loss: 0.45059, train_accuracy: 84.57, time: 0.21
[13,    41] loss: 0.46364, train_accuracy: 85.74, time: 0.21
[13,    51] loss: 0.46378, train_accuracy: 82.62, time: 0.22
[13,    61] loss: 0.53060, train_accuracy: 81.05, time: 0.22
[13,    71] loss: 0.47208, train_accuracy: 83.20, time: 0.22
[13,    81] loss: 0.59908, train_accuracy: 79.49, time: 0.21
[13,    91] loss: 0.50819, train_accuracy: 82.81, time: 0.22
duration: 21 s - train loss: 0.

duration: 21 s - train loss: 0.28864 - train accuracy: 89.86 - validation loss: 0.74 - validation accuracy: 79.40 
[24,     1] loss: 0.24826, train_accuracy: 91.60, time: 0.22
[24,    11] loss: 0.24130, train_accuracy: 91.21, time: 0.21
[24,    21] loss: 0.25133, train_accuracy: 91.02, time: 0.22
[24,    31] loss: 0.24302, train_accuracy: 91.60, time: 0.21
[24,    41] loss: 0.31349, train_accuracy: 89.26, time: 0.22
[24,    51] loss: 0.32660, train_accuracy: 87.50, time: 0.21
[24,    61] loss: 0.24841, train_accuracy: 91.80, time: 0.22
[24,    71] loss: 0.25528, train_accuracy: 92.38, time: 0.22
[24,    81] loss: 0.28798, train_accuracy: 91.02, time: 0.22
[24,    91] loss: 0.32571, train_accuracy: 88.28, time: 0.22
duration: 21 s - train loss: 0.28075 - train accuracy: 90.36 - validation loss: 1.12 - validation accuracy: 72.42 
[25,     1] loss: 0.19596, train_accuracy: 92.77, time: 0.22
[25,    11] loss: 0.19380, train_accuracy: 93.95, time: 0.21
[25,    21] loss: 0.24630, train_accur

[35,    21] loss: 0.18538, train_accuracy: 94.14, time: 0.22
[35,    31] loss: 0.15778, train_accuracy: 94.34, time: 0.21
[35,    41] loss: 0.17411, train_accuracy: 92.97, time: 0.21
[35,    51] loss: 0.27937, train_accuracy: 90.82, time: 0.21
[35,    61] loss: 0.16613, train_accuracy: 93.36, time: 0.22
[35,    71] loss: 0.23233, train_accuracy: 90.62, time: 0.22
[35,    81] loss: 0.20276, train_accuracy: 93.95, time: 0.22
[35,    91] loss: 0.17553, train_accuracy: 93.55, time: 0.21
duration: 21 s - train loss: 0.19217 - train accuracy: 93.26 - validation loss: 0.82 - validation accuracy: 81.72 
[36,     1] loss: 0.11530, train_accuracy: 96.88, time: 0.22
[36,    11] loss: 0.17840, train_accuracy: 93.16, time: 0.21
[36,    21] loss: 0.16246, train_accuracy: 94.92, time: 0.22
[36,    31] loss: 0.18018, train_accuracy: 92.97, time: 0.22
[36,    41] loss: 0.18178, train_accuracy: 93.55, time: 0.22
[36,    51] loss: 0.24731, train_accuracy: 92.38, time: 0.22
[36,    61] loss: 0.14275, trai

[46,    61] loss: 0.10638, train_accuracy: 96.09, time: 0.22
[46,    71] loss: 0.13524, train_accuracy: 94.73, time: 0.22
[46,    81] loss: 0.11968, train_accuracy: 96.09, time: 0.21
[46,    91] loss: 0.12787, train_accuracy: 95.90, time: 0.21
duration: 21 s - train loss: 0.12761 - train accuracy: 95.54 - validation loss: 0.64 - validation accuracy: 85.05 
[47,     1] loss: 0.09298, train_accuracy: 96.48, time: 0.22
[47,    11] loss: 0.14290, train_accuracy: 95.31, time: 0.22
[47,    21] loss: 0.10729, train_accuracy: 96.48, time: 0.22
[47,    31] loss: 0.11930, train_accuracy: 97.27, time: 0.21
[47,    41] loss: 0.11706, train_accuracy: 95.12, time: 0.22
[47,    51] loss: 0.11583, train_accuracy: 96.88, time: 0.22
[47,    61] loss: 0.10102, train_accuracy: 96.68, time: 0.22
[47,    71] loss: 0.15364, train_accuracy: 94.53, time: 0.22
[47,    81] loss: 0.12495, train_accuracy: 95.51, time: 0.21
[47,    91] loss: 0.12622, train_accuracy: 96.09, time: 0.21
duration: 21 s - train loss: 0.

duration: 21 s - train loss: 0.09974 - train accuracy: 96.49 - validation loss: 1.88 - validation accuracy: 70.07 
[58,     1] loss: 0.09201, train_accuracy: 95.90, time: 0.22
[58,    11] loss: 0.08211, train_accuracy: 97.27, time: 0.21
[58,    21] loss: 0.13349, train_accuracy: 96.29, time: 0.22
[58,    31] loss: 0.11773, train_accuracy: 95.51, time: 0.21
[58,    41] loss: 0.10411, train_accuracy: 96.68, time: 0.21
[58,    51] loss: 0.07550, train_accuracy: 97.27, time: 0.22
[58,    61] loss: 0.08089, train_accuracy: 96.88, time: 0.22
[58,    71] loss: 0.07216, train_accuracy: 98.05, time: 0.22
[58,    81] loss: 0.09736, train_accuracy: 96.68, time: 0.22
[58,    91] loss: 0.10075, train_accuracy: 95.70, time: 0.21
duration: 21 s - train loss: 0.09827 - train accuracy: 96.53 - validation loss: 0.64 - validation accuracy: 85.62 
[59,     1] loss: 0.07835, train_accuracy: 96.68, time: 0.22
[59,    11] loss: 0.09496, train_accuracy: 97.07, time: 0.22
[59,    21] loss: 0.08488, train_accur

[69,    21] loss: 0.06560, train_accuracy: 98.24, time: 0.21
[69,    31] loss: 0.07181, train_accuracy: 96.68, time: 0.22
[69,    41] loss: 0.08324, train_accuracy: 97.66, time: 0.22
[69,    51] loss: 0.06328, train_accuracy: 97.66, time: 0.22
[69,    61] loss: 0.05066, train_accuracy: 98.05, time: 0.22
[69,    71] loss: 0.10185, train_accuracy: 96.29, time: 0.22
[69,    81] loss: 0.05843, train_accuracy: 97.85, time: 0.21
[69,    91] loss: 0.07956, train_accuracy: 97.07, time: 0.22
duration: 21 s - train loss: 0.07911 - train accuracy: 97.29 - validation loss: 1.09 - validation accuracy: 81.60 
[70,     1] loss: 0.06561, train_accuracy: 97.27, time: 0.22
[70,    11] loss: 0.12023, train_accuracy: 95.51, time: 0.22
[70,    21] loss: 0.08062, train_accuracy: 97.66, time: 0.22
[70,    31] loss: 0.07498, train_accuracy: 97.85, time: 0.21
[70,    41] loss: 0.11229, train_accuracy: 95.90, time: 0.22
[70,    51] loss: 0.08267, train_accuracy: 96.68, time: 0.22
[70,    61] loss: 0.07575, trai

[80,    61] loss: 0.05663, train_accuracy: 97.46, time: 0.21
[80,    71] loss: 0.05860, train_accuracy: 98.05, time: 0.21
[80,    81] loss: 0.05521, train_accuracy: 98.44, time: 0.22
[80,    91] loss: 0.04316, train_accuracy: 98.05, time: 0.24
duration: 21 s - train loss: 0.05957 - train accuracy: 97.91 - validation loss: 0.90 - validation accuracy: 84.93 
[81,     1] loss: 0.06014, train_accuracy: 98.05, time: 0.22
[81,    11] loss: 0.07466, train_accuracy: 97.46, time: 0.21
[81,    21] loss: 0.06304, train_accuracy: 98.63, time: 0.22
[81,    31] loss: 0.10172, train_accuracy: 96.68, time: 0.22
[81,    41] loss: 0.04484, train_accuracy: 98.83, time: 0.22
[81,    51] loss: 0.05833, train_accuracy: 98.44, time: 0.22
[81,    61] loss: 0.02482, train_accuracy: 99.61, time: 0.22
[81,    71] loss: 0.09741, train_accuracy: 97.07, time: 0.21
[81,    81] loss: 0.11929, train_accuracy: 97.46, time: 0.22
[81,    91] loss: 0.05356, train_accuracy: 98.44, time: 0.22
duration: 21 s - train loss: 0.

duration: 21 s - train loss: 0.05034 - train accuracy: 98.34 - validation loss: 1.02 - validation accuracy: 83.32 
[92,     1] loss: 0.06067, train_accuracy: 98.44, time: 0.22
[92,    11] loss: 0.05157, train_accuracy: 98.05, time: 0.22
[92,    21] loss: 0.03914, train_accuracy: 98.44, time: 0.22
[92,    31] loss: 0.03496, train_accuracy: 98.83, time: 0.21
[92,    41] loss: 0.06225, train_accuracy: 97.85, time: 0.22
[92,    51] loss: 0.10543, train_accuracy: 96.68, time: 0.22
[92,    61] loss: 0.04716, train_accuracy: 98.24, time: 0.22
[92,    71] loss: 0.07610, train_accuracy: 97.85, time: 0.22
[92,    81] loss: 0.06546, train_accuracy: 97.85, time: 0.21
[92,    91] loss: 0.03806, train_accuracy: 98.05, time: 0.22
duration: 21 s - train loss: 0.05596 - train accuracy: 98.12 - validation loss: 1.01 - validation accuracy: 84.08 
[93,     1] loss: 0.05653, train_accuracy: 98.24, time: 0.22
[93,    11] loss: 0.03948, train_accuracy: 98.63, time: 0.22
[93,    21] loss: 0.05557, train_accur

Unnamed: 0,epoch,train_loss,train_accuracy,validation_loss,validation_accuracy,duration,criterion,optimizer,method,batchsize
0,1.0,2.196176,25.951128,1.766533,33.61,21.328791,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
1,2.0,1.595128,40.962384,1.640888,42.39,42.681702,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
2,3.0,1.343528,51.022781,1.623494,47.01,64.041028,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
3,4.0,1.152945,58.990923,1.770410,46.16,85.327015,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
4,5.0,1.011138,64.042191,1.212974,60.79,106.638274,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
...,...,...,...,...,...,...,...,...,...,...
95,96.0,0.049020,98.313745,1.075013,83.73,2049.109720,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
96,97.0,0.052766,98.246743,0.957367,85.57,2070.442676,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
97,98.0,0.052664,98.261833,0.980498,84.10,2091.747197,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0
98,99.0,0.049976,98.292582,0.794041,86.56,2113.115365,CrossEntropyLoss(),Adam (\nParameter Group 0\n amsgrad: False\...,standard,512.0


In [6]:
#from . helpers import _craft_advs, _evaluate_model

def _fit(model, train_loader, val_loader, epochs, device, patience=None, evaluate_robustness=False):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    
    train_stats=pd.DataFrame([])
    total_time = 0
    epochs_trained = 0
    train_loss_hist, train_acc_hist, val_loss_hist, val_acc_hist = [], [], [], []
    for epoch in range(epochs):  # loop over the dataset multiple times
        t0 = time.time()
        acc_epoch_loss, avg_epoch_loss, epoch_accuracy, acc_epoch_accuracy = 0.0, 0.0, 0.0, 0.0
        
        
        for i, data in enumerate(train_loader, 0):
            t00 = time.time()
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            batchsize = labels.size(0)
            correct = (predicted == labels).sum().item()
            accuracy = 100 * correct / batchsize
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

                # print statistics
            acc_epoch_loss += loss.item() 
            avg_epoch_loss = acc_epoch_loss / (i+1)
            acc_epoch_accuracy += accuracy
            avg_epoch_accuracy = acc_epoch_accuracy / (i+1)
            t11 = time.time()
            if i%10 == 0:
                print('[%d, %5d] loss: %.5f, train_accuracy: %.2f, time: %.2f' %(epoch + 1, i + 1, loss.item(), accuracy, t11-t00))
        t1 = time.time()
        total_time += t1 - t0
        accuracy, loss = _evaluate_model(model, val_loader, device, criterion)
        #print('duration:', t1-t0,'- train loss: ',avg_epoch_loss,' - train accuracy: ',avg_epoch_accuracy,' - validation accuracy: ', accuracy,' - validation loss: ', loss)
        print('duration: %d s - train loss: %.5f - train accuracy: %.2f - validation loss: %.2f - validation accuracy: %.2f ' %(t1-t0, avg_epoch_loss, avg_epoch_accuracy, loss, accuracy))
        train_loss_hist.append(avg_epoch_loss)
        train_acc_hist.append(avg_epoch_accuracy)
        val_loss_hist.append(loss)
        val_acc_hist.append(accuracy)
        data = {
            'epoch': epoch+1,
            'train_loss':avg_epoch_loss, 
            'train_accuracy':avg_epoch_accuracy,
            'validation_loss':loss,
            'validation_accuracy':accuracy,
            'duration':total_time,
            'criterion':criterion,
            'optimizer':optimizer,
            'method': 'standard',
            'batchsize': len(next(iter(train_loader))[1])
        }
        
        
        if epoch%3==0 and evaluate_robustness == True:
            (l_0_robustness, l_0_loss), (l_2_robustness, l_2_loss), (l_inf_robustness, l_inf_loss) = _evaluate_robustness(model, val_loader, device)
            date['l_0_robustness'] = l_0_robustness
            date['l_2_robustness'] = l_2_robustness
            date['l_inf_robustness'] = l_inf_robustness
        
        train_stats = train_stats.append(data, ignore_index=True)
        
        if patience != None and patience < epoch and stop_early(val_loss_hist, patience) == True:
            epochs_trained = i + 1
            print('stopped early after', patience, 'epochs without decrease of validation loss')
            break
    print('Finished Training')
    
    return train_stats
def _evaluate_model(model, data_loader, device, criterion):
    correct = 0
    total = 0
    acc_loss = 0.0
    avg_loss = 0.0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            #print(i)
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if criterion != None:
                loss = criterion(outputs, labels)
                acc_loss += loss.item() 
                avg_loss = acc_loss / (i+1)
            #print(outputs)
            _, predicted = torch.max(outputs.data, 1)
            #print(predicted)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    model.train()
    return accuracy, avg_loss

In [2]:


class CifarResNet(tf.keras.Model):
    def __init__(self):
        super(CifarResNet, self).__init__()
        self.conv1 = Conv2D(64, kernel_size=3,activation='relu', padding='same', strides=1)        
        #self.pool1 = layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding='same')
        
        self.res_block1 = ResBlock(64, 64)
        self.res_block3 = ResBlock(64, 64)
        self.res_block4 = ResBlock(64, 128, 2)
        self.res_block7 = ResBlock(128, 128)
        self.res_block8 = ResBlock(128, 256, 2)
        self.res_block13 = ResBlock(256, 256)
        self.res_block14 = ResBlock(256 ,512, 2)
        self.res_block16 = ResBlock(512, 512)
        self.pool2 = layers.GlobalAveragePooling2D()
        self.dense1 = Dense(10, activation='softmax')
        
    def call(self,inputs, training=False):
        x = self.conv1(inputs)
        x = self.res_block1(x)
        x = self.res_block3(x)
        x = self.res_block4(x)
        x = self.res_block7(x)
        x = self.res_block8(x)
        x = self.res_block13(x)
        x = self.res_block14(x)
        x = self.res_block16(x)
        x = self.pool2(x)
        x = self.dense1(x)
        return x
    

class ImagenetteResNet(tf.keras.Model):
    def __init__(self):
        super(ResNet, self).__init__()
        self.conv1 = Conv2D(64, kernel_size=7,activation='relu', padding='same', strides=2)
        self.pool1 = layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding='same')
        
        self.res_block1 = ResBlock(64, 64)
        self.res_block3 = ResBlock(64, 64)
        self.res_block4 = ResBlock(64, 128, 2)
        self.res_block7 = ResBlock(128, 128)
        self.res_block8 = ResBlock(128, 256, 2)
        self.res_block13 = ResBlock(256, 256)
        self.res_block14 = ResBlock(256 ,512, 2)
        self.res_block16 = ResBlock(512, 512)
        self.pool2 = layers.GlobalAveragePooling2D()
        self.dense1 = Dense(10, activation='softmax')
        
    def call(self,inputs, training=False):
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.res_block1(x)
        x = self.res_block3(x)
        x = self.res_block4(x)
        x = self.res_block7(x)
        x = self.res_block8(x)
        x = self.res_block13(x)
        x = self.res_block14(x)
        x = self.res_block16(x)
        x = self.pool2(x)
        x = self.dense1(x)
        return x
    





class ResBlock(tf.keras.layers.Layer):
    def __init__(self, input_channels=3 ,output_channels = 64, stride=1, filter_size=3):
        super(ResBlock, self).__init__()
        self.stride = stride
        self.conv1 = Conv2D(output_channels, 3, strides=self.stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = Conv2D(output_channels, 3, strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        if stride == 2:
            self.conv3 = Conv2D(output_channels, 1, strides=self.stride, padding='same')
            self.bn3 = layers.BatchNormalization()
        self.add1 = layers.Add()
    
    def call(self, inputs, training=False):
        x = self.bn1(inputs, training=training)
        x = tf.nn.relu(x)
        x = self.conv1(x)
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        #print(x.shape)
        if self.stride == 2:
            inputs = self.conv3(x)
            inputs = self.bn3(x)
        #print(inputs.shape)
        return (self.add1([x, inputs]))

    
class std_conv(tf.keras.layers.Layer):
    def __init__(self, filters_in, filters_out, strides=1, regulization=10, do=0):
        super(std_conv, self).__init__()
        self.conv = Conv2D(filters_out, 3, activation='relu', padding='same', strides=strides,kernel_regularizer=regularizers.l2(regulization))
        self.do = Dropout(do)
        self.bn = BatchNormalization()
    def call(self, x, training):
        x = self.conv(x)
        x = self.bn(x, training)
        x = self.do(x)
        return x
        


In [15]:
tf_m = CifarResNet()

In [9]:
tf_m(tf.random.uniform([1,224,224,3]))

2021-10-04 09:20:21.593150: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.7
2021-10-04 09:20:22.388727: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.10


<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[0.09670077, 0.11641659, 0.10450514, 0.09369759, 0.08706291,
        0.11690781, 0.07528384, 0.10028592, 0.12022388, 0.0889155 ]],
      dtype=float32)>

In [16]:
ds_train, ds_test, _, _ = load_data("imagenette", )
ds_train, ds_test, _, _ = load_data("cifar10", )

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


In [17]:
tf_m.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy() ,
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        metrics=['accuracy'],
        experimental_run_tf_function=True
    )

In [18]:
%%time
hist = tf_m.fit(
    x=ds_train,
    epochs=100,
    validation_data=ds_test,
)

Epoch 1/100








Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 7

In [5]:


def load_data(dataset,ratio='100%'):

    def augment(image,label):
        #image = tf.image.convert_image_dtype(image, tf.float32)
        #image = tf.image.rot90(image, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)) # random rotation
        image = tf.image.random_flip_left_right(image)
        #image = tf.image.random_flip_up_down(image)
        #image = tf.image.random_hue(image, 0.08)
        #image = tf.image.random_saturation(image, 0.6, 1.6)
        #image = tf.image.random_contrast(image, 0.7, 1.3)
        #image = tf.image.random_brightness(image, max_delta=0.5) # Random brightness
        image = tf.image.resize_with_crop_or_pad(image, 224+60, 224+60) # Add 60 pixels of padding
        image = tf.image.random_crop(image, size=[224,224,3]) # Random crop back to 28x28
        return image,label
    
    def cifar_augment(image,label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.resize_with_crop_or_pad(image, 32+6, 32+6)
        image = tf.image.random_crop(image, size=[32,32,3])
        return image,label
    
    @tf.function
    def load_image(datapoint):
        input_image, label = normalize(datapoint)
        return input_image, label

    
    if dataset=='cifar10':
        ds, info = tfds.load(name=dataset, with_info=True, split=[f"train[:{ratio}]",f"test[:{ratio}]"])
        ds_train=ds[0]
        ds_test=ds[1]
        def normalize(x):
            y = {'image': tf.image.convert_image_dtype(x['image'], tf.float32), 'label': x['label']}
            y = (tf.image.resize(y['image'], (32,32)), y['label'])
            return y
        num_train_examples= info.splits['train'].num_examples
        BATCH_SIZE = 512

        ds_train = (
            ds_train
            .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .take(num_train_examples)
            .cache()
            .shuffle(num_train_examples)
            .map(cifar_augment, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE)
        ) 

        ds_test = ds_test.map(
            normalize, )
        ds_test = ds_test.batch(BATCH_SIZE)
        ds_test = ds_test.cache()
        ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)



        attack_set = list(ds[1].map(load_image))[:1000]

        attack_images = tf.convert_to_tensor([sample[0] for sample in attack_set])
        attack_labels = tf.convert_to_tensor([sample[1] for sample in attack_set])
        return ds_train, ds_test, attack_images, attack_labels
    
    
    if dataset=='imagenette':
        ds, info = tfds.load(name=dataset, with_info=True, split=[f"train[:{ratio}]",f"validation[:{ratio}]"])
        
        ds_train=ds[0]
        ds_test=ds[1]
        def normalize(x):
            y = {'image': tf.image.convert_image_dtype(x['image'], tf.float32), 'label': x['label']}
            y = (tf.image.resize(y['image'], (224,224)), y['label'])
            return y


        num_train_examples= info.splits['train'].num_examples
        BATCH_SIZE = 128

        ds_train = (
            ds_train
            .map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .take(num_train_examples)
            .cache()
            .shuffle(num_train_examples)
            .map(augment, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE)
        ) 

        ds_test = ds_test.map(
            normalize, )
        ds_test = ds_test.batch(BATCH_SIZE)
        ds_test = ds_test.cache()
        ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)



        attack_set = list(ds[1].map(load_image))[:256]

        attack_images = tf.convert_to_tensor([sample[0] for sample in attack_set])
        attack_labels = tf.convert_to_tensor([sample[1] for sample in attack_set])

        return ds_train, ds_test, attack_images, attack_labels
    
    return False