In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import attacks
import pandas as pd
import utils
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, TensorDataset
from torch import Tensor
# dataloaders, dataset_sizes = utils.data_loader()

In [4]:
def pgd_linf(model, X, y, epsilon=0.1, alpha=0.01, num_iter=5, randomize=False):
    """Construct L-infinity PGD adversarial perturbations for the examples X.

    Runs `num_iter` steps of projected gradient ascent on the cross-entropy
    loss of `model(X + delta)` against `y`, taking a signed-gradient step of
    size `alpha` and clamping the perturbation to [-epsilon, epsilon] after
    every step.

    Args:
        model: callable mapping a batch of inputs to class logits. It must
            live on the same device as `X`.
        X: input batch; the perturbation is created with the same shape,
            dtype and device.
        y: class-index targets accepted by `nn.CrossEntropyLoss`.
        epsilon: L-infinity radius of the perturbation.
        alpha: step size of each signed-gradient update.
        num_iter: number of PGD steps.
        randomize: if True, start from a uniform random point in
            [-epsilon, epsilon]; otherwise start from zero.

    Returns:
        A detached tensor `delta` shaped like `X` with |delta| <= epsilon.
    """
    loss_fn = nn.CrossEntropyLoss()

    # Build delta directly from X so it shares X's device. Creating it with
    # requires_grad=True and then calling .to(device) yields a non-leaf tensor
    # whenever a copy is needed, and its .grad would then be None.
    if randomize:
        delta = torch.rand_like(X) * 2 * epsilon - epsilon
    else:
        delta = torch.zeros_like(X)
    delta = delta.detach()

    for _ in range(num_iter):
        delta.requires_grad_(True)
        loss = loss_fn(model(X + delta), y)
        # Differentiate w.r.t. delta only: loss.backward() would also
        # accumulate gradients into the model's parameters as a side effect.
        (grad,) = torch.autograd.grad(loss, delta)
        delta = (delta.detach() + alpha * grad.sign()).clamp(-epsilon, epsilon)
    return delta.detach()

class Flatten(nn.Module):
    """Layer that collapses every non-batch dimension into a single one."""

    def forward(self, x):
        """Return `x` viewed as (batch, features), keeping dim 0 intact."""
        batch_size = x.size(0)
        return x.view(batch_size, -1)

# Select the compute device in this cell so it does not depend on a `device`
# variable left behind by some other (possibly deleted) cell.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Classifier that the PGD attack targets: four 3x3 convolutions (two of them
# stride-2, so a 28x28 input is reduced to 7x7) followed by two linear layers
# producing 10 class logits.
model = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.ReLU(),
                          nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(64, 64, 3, padding=1, stride=2), nn.ReLU(),
                          Flatten(),
                          nn.Linear(7*7*64, 100), nn.ReLU(),
                          nn.Linear(100, 10)).to(device)
# map_location lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine instead of failing during deserialization.
model.load_state_dict(torch.load("model_cnn.pt", map_location=device))

def offline_task(input, model, probas, name):
    """Build and save a dataset of clean / adversarially corrupted samples.

    Every sample of `input` is, at random, either kept as is or perturbed with
    `pgd_linf` against `model`. The images, their labels and a 0/1 corruption
    flag are stored as a `TensorDataset` written to './<name>.pt'.

    Args:
        input: indexable dataset with `len()` whose items are (image, label)
            pairs; each image must be reshapeable to (1, 1, 28, 28).
        model: classifier handed to `pgd_linf` to craft the perturbation.
        probas: two probabilities `[p_corrupt, p_clean]` passed to
            `np.random.choice` over the outcomes `[1, 0]`.
        name: file stem of the saved dataset.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data = []
    label = []
    status = []
    for i in tqdm(range(len(input))):
        is_corrupt = int(np.random.choice([1, 0], p=probas))
        # Read from the dataset that was passed in; the previous code indexed
        # the global `mnist_train`, silently ignoring `input`.
        X, y = input[i]
        X = X.reshape((1, 1, 28, 28)).to(device)
        y = torch.tensor([y], dtype=torch.long, device=device)
        if is_corrupt == 1:
            delta = pgd_linf(model, X, y).to(device)
            X = X + delta
        # Keep the stored tensors on the CPU so the saved file can be loaded
        # on any machine, whatever device was used to generate it.
        data.append(X[0].detach().cpu())
        label.append(y.cpu())
        status.append([is_corrupt])
    data = torch.stack(data)
    label = torch.stack(label)
    status = torch.tensor(status, dtype=torch.float32)

    dataset = TensorDataset(data, label, status)
    torch.save(dataset, './{}.pt'.format(name))

# mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor() )
# offline_task(mnist_train,model,[0.5,0.5], 'train')

# Build the offline validation task from the MNIST test split: each sample is
# adversarially perturbed with probability 0.5 and the result is written to
# './val.pt' by offline_task.
mnist_val = datasets.MNIST(
    root="../data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
offline_task(input=mnist_val, model=model, probas=[0.5, 0.5], name='val')


# dataset = torch.load('./train.pt')
# dataloader = DataLoader( dataset, batch_size = 10, shuffle=True)

# for x,y,z in dataloader:
#     print(x.shape)
#     print(y)
#     print(z)


# dataloaders = {
#         'train': DataLoader(mnist_train, batch_size = batch_size, shuffle=True, num_workers=4),
#         'val': DataLoader(mnist_val, batch_size = 100, shuffle=False, num_workers = 4)
#     }


  0%|          | 0/10000 [00:00<?, ?it/s]

In [None]:
def create_offline_task(target):
    """Load MNIST and return a 40% random training subset plus the test split.

    The earlier body tried to corrupt a batch `X` before any `X`, `delta`,
    `y`, `model` or `device` existed in this scope, so every call raised
    UnboundLocalError; that unreachable draft code has been removed, together
    with the data loaders and size dict that were built but never returned.

    Args:
        target: currently unused; kept so existing call sites keep working.

    Returns:
        A tuple `(mnist_train, mnist_val)` where `mnist_train` is a random
        subset holding 40% of the MNIST training split and `mnist_val` is the
        full MNIST test split.
    """
    mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
    mnist_val = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())

    # Keep only a fraction of the training split; the remainder is discarded.
    train_ratio = 0.4
    train_size = int(len(mnist_train) * train_ratio)
    dump_size = len(mnist_train) - train_size
    mnist_train = torch.utils.data.random_split(mnist_train, [train_size, dump_size])[0]

    print(len(mnist_train))
    print(len(mnist_val))

    return mnist_train, mnist_val

In [None]:
 

# Compute device, defined here so the cell does not rely on a variable left
# over from another cell.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `dataloaders` and `dataset_sizes` are not defined anywhere in
# the visible notebook (the `utils.data_loader()` call in the import cell is
# commented out). They must exist before this cell runs -- confirm their source.

# Binary "controller" network: same convolutional trunk as the classifier, but
# with a single sigmoid output scoring whether an input has been corrupted.
controler = nn.Sequential(nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.ReLU(),
                          nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
                          nn.Conv2d(64, 64, 3, padding=1, stride=2), nn.ReLU(),
                          Flatten(),
                          nn.Linear(7*7*64, 100), nn.ReLU(),
                          nn.Linear(100, 1),
                          nn.Sigmoid() ).to(device)

n_epochs = 50
# `trange` was never imported; wrapping range() with the imported tqdm gives
# the same progress bar object (set_description / set_postfix).
bar = tqdm(range(n_epochs))
epoch_loss = {'train': 0, 'val': 0}
# Column names are kept as they were, although the stored values are
# BCE-based losses rather than RMSE.
dft = pd.DataFrame(columns=['train_rmse', 'val_rmse'])

optimizer = optim.Adam(controler.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
bce_loss = torch.nn.BCELoss()

for epoch in bar:

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        is_train = phase == 'train'
        if is_train:
            controler.train()  # Set model to training mode
        else:
            controler.eval()   # Set model to evaluate mode

        running_loss = 0.0
        bar.set_description(f'Epoch {epoch} {phase}'.ljust(20))

        # Iterate over data.
        for X, y in dataloaders[phase]:
            X, y = X.to(device), y.to(device)
            # The attack needs gradients, so it runs outside the
            # set_grad_enabled context below (also during validation).
            delta = attacks.pgd_linf(model, X, y)

            # Corrupt each sample independently with probability 0.5.
            status = torch.tensor(
                np.random.choice([1, 0], size=X.shape[0], p=[0.5, 0.5]),
                dtype=torch.float32,
                device=device,
            )
            corrupt_mask = status.bool().view(-1, *([1] * (X.dim() - 1)))
            X = torch.where(corrupt_mask, X + delta, X)

            # Zero the parameter gradients; without this every optimizer step
            # would use the sum of the gradients of all previous batches.
            if is_train:
                optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(is_train):
                yp = controler(X)

                # Feedback is only available for samples the controller flags
                # as corrupted (score > 0.5); only those contribute to the loss.
                flagged = yp.view(-1) > 0.5
                if flagged.any():
                    y_feedback = yp[flagged]
                    status_feedback = status[flagged].reshape(-1, 1)
                    loss = bce_loss(y_feedback, status_feedback)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                    # NOTE(review): the loss is a mean over the flagged samples
                    # but is weighted by the full batch size -- confirm that this
                    # matches the normalisation by dataset_sizes below.
                    running_loss += loss.item() * X.size(0)

        # Computed once per phase so that a phase without any feedback reports
        # 0 instead of a stale value from a previous epoch.
        epoch_loss[phase] = running_loss / dataset_sizes[phase]

        if is_train:
            scheduler.step()
            print(scheduler.get_last_lr())

        bar.set_postfix( train_loss=f'{epoch_loss["train"]:0.5f}', val_loss=f'{epoch_loss["val"]:0.5f}')
        dft.loc[epoch, f'{phase}_rmse'] = epoch_loss[phase]

        # Named `checkpoint` so it no longer overwrites the per-batch `status`
        # tensor used above.
        checkpoint = {
            "epoch": epoch,
            "model_state": controler.state_dict(),
            "optimizer_state": optimizer.state_dict() }

        torch.save(checkpoint, './controler.pt')