<a href="https://colab.research.google.com/github/ahmad-PH/nag-notebooks/blob/master/NAG_tripletLossExperiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [235]:
import subprocess
def run_shell_command(cmd):
    """Run ``cmd`` through the shell and print everything it wrote to stdout.

    stdout is captured and decoded as UTF-8 before printing; stderr is not
    captured and goes straight to the process's own stderr.
    """
    completed = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE)
    print(completed.stdout.decode('utf-8'))
  
def detect_env():
    """Guess which machine this notebook runs on.

    Returns "colab" when the filesystem root contains a ``content`` entry
    (Colab's working directory), and "IBM" otherwise.
    """
    import os
    on_colab = 'content' in os.listdir('/')
    return "colab" if on_colab else "IBM"
    
def create_env():
    """Instantiate the Env subclass that matches ``detect_env()``.

    Returns None if the detected environment name is not recognised.
    """
    detected = detect_env()
    if detected == "IBM":
        return IBMEnv()
    if detected == "colab":
        return ColabEnv()
    return None


class Env:
    """Base class for the runtime environments (IBM server / Google Colab).

    Subclasses are expected to set ``root_folder`` and ``python_files_path``
    in their ``__init__``; ``get_nag_util_files`` relies on both.
    """

    def get_nag_util_files(self):
        """Make sure the nag-public repository is present and up to date.

        If ``python_files_path`` already exists it is updated with ``git pull``
        and the working directory is then set to ``root_folder``; otherwise the
        repository is cloned into ``python_files_path``.
        """
        import os
        import shlex

        print("\ngetting git files ...")
        if os.path.isdir(self.python_files_path):
            os.chdir(self.python_files_path)
            try:
                run_shell_command('git pull')
            finally:
                # Always leave the process in root_folder, even if the pull fails.
                os.chdir(self.root_folder)
        else:
            # Clone into the exact directory checked above (and later added to
            # sys.path), independent of the current working directory.
            run_shell_command('git clone https://github.com/ahmad-PH/nag-public.git '
                              + shlex.quote(self.python_files_path))
        print("done.")
  

class IBMEnv(Env):
    """Environment for the IBM server; everything lives under ``root_folder``.

    ``save_filename`` must be assigned on the instance before calling
    ``get_csv_path`` or ``get_models_path``.
    """

    def __init__(self):
        self.root_folder = "/root/Derakhshani/adversarial"
        self.temp_csv_path = self.root_folder + "/temp"
        self.python_files_path = self.root_folder + "/nag-public"
        self.python_files_dir = "NAG-11May-beforeDenoiser"

        import sys
        # Relative path: resolved against the working directory at import time.
        sys.path.append('./nag/nag_util')

    def get_csv_path(self):
        return self.root_folder + "/textual_notes/CSVs/" + self.save_filename

    def get_models_path(self):
        return self.root_folder + "/models/" + self.save_filename

    def setup(self):
        """Fetch nag-public and restrict CUDA to GPU 0 (numbered by PCI bus id)."""
        self.get_nag_util_files()
        import os
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    def load_dataset(self, compressed_name, unpacked_name):
        # No-op on this machine; presumably the datasets are already unpacked
        # under root_folder/datasets (see set_data_path) -- TODO confirm.
        pass

    def load_test_dataset(self, root_folder=None):
        """Not supported on IBM yet; always raises NotImplementedError.

        ``root_folder`` is accepted (and ignored) so the call signature matches
        ``ColabEnv.load_test_dataset``.
        """
        raise NotImplementedError("test dataset on IBM needs work...")

    def set_data_path(self, path):
        self.data_path = Path(self.root_folder + '/datasets/' + path)
    
        
class ColabEnv(Env):
    def __init__(self):
      self.root_folder = '/content'
      self.temp_csv_path = self.root_folder
      self.python_files_path = self.root_folder + '/nag-public'
      self.python_files_dir = "NAG-11May-beforeDenoiser"
      self.torchvision_upgraded = False
      
    def get_csv_path(self):
      return self.root_folder + '/gdrive/My Drive/DL/textual_notes/CSVs/' + self.save_filename
    
    def get_models_path(self):
      return self.root_folder + "/gdrive/My Drive/DL/models/" + self.save_filename
        
    def setup(self):
        # ######################################################
        # # TODO remove this once torchvision 0.3 is present by
        # # default in Colab
        # ######################################################
        try:
            ColabEnv.torchvision_upgraded
        except AttributeError:
          !pip uninstall -y torchvision
          !pip install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl
          ColabEnv.torchvision_upgraded = True
        else:
          print("torchvision already upgraded")
          
        
        drive.mount('/content/gdrive')
        
        self.get_nag_util_files()
        
    def load_dataset(self, compressed_name, unpacked_name):
      if compressed_name not in os.listdir('.'):
        print(compressed_name + ' not found, getting it from drive')
        shutil.copyfile("/content/gdrive/My Drive/DL/{}.tar.gz".format(compressed_name), "./{}.tar.gz".format(compressed_name))

        gunzip_arg = "./{}.tar.gz".format(compressed_name)
        !gunzip -f $gunzip_arg

        tar_arg = "./{}.tar".format(compressed_name)
        !tar -xvf $tar_arg > /dev/null

        os.rename(unpacked_name, compressed_name)

    #     ls_arg = "./{}/train/n01440764".format(compressed_name)
    #     !ls $ls_arg

        !rm $tar_arg

        print("done") 
      else:
        print(compressed_name + " found")
        
    def load_test_dataset(root_folder):
      test_folder = root_folder + '/test/'
      if 'test' not in os.listdir(root_folder):
        os.mkdir(test_folder)
        for i in range(1,11):
          shutil.copy("/content/gdrive/My Drive/DL/full_test_folder/{}.zip".format(i), test_folder)
          shutil.unpack_archive(test_folder + "/{}.zip".format(i), test_folder)
          os.remove(test_folder + "/{}.zip".format(i))
          print("done with the {}th fragment".format(i))
        
    def set_data_path(self, path):
      self.data_path = Path('./' + path)
        

In [236]:
from fastai.vision import *
from fastai.imports import *
from fastai.callbacks import *
from fastai.utils.mem import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import sys; import os; import shutil
if detect_env() == "colab":
  from google.colab import drive

In [237]:
# Build the Env object for this machine (IBM server or Colab) and run its
# setup: git pull/clone of nag-public, plus the torchvision upgrade and Drive
# mount on Colab or the CUDA device environment variables on IBM.
env = create_env()
env.setup()


getting git files ...
Already up-to-date.

done.


In [238]:
# Make the checked-out nag_util code importable.
sys.path.append(env.python_files_path + '/' + env.python_files_dir)

from nag_util import *
# Keep a handle on the module itself so globals (batch_size, gpu_flag, arch)
# can be assigned on it further down.
import nag_util

In [239]:
# Dataset selector: "normal" for the full dataset, "sanity_check" for the
# small sanity-check dataset.
# mode = "sanity_check"
mode = "normal"

In [240]:
# Fetch/unpack the dataset selected by `mode` and point env.data_path at it.
if mode == "normal":
    env.load_dataset('dataset', 'data')
    env.set_data_path('dataset')
elif mode == "sanity_check":
    env.load_dataset('dataset_sanity_check_small', 'dataset_sanity_check_small')
    env.set_data_path('dataset_sanity_check_small')
else:
    # Fail here instead of later with a confusing missing env.data_path.
    raise ValueError("unknown mode: {!r}".format(mode))

In [241]:
# Training batch size / GPU flag, also published on the nag_util module
# (presumably read there as module globals -- TODO confirm).
batch_size = 8
gpu_flag = True
nag_util.batch_size = batch_size; nag_util.gpu_flag = gpu_flag;
# nag_util.set_globals(gpu_flag, batch_size)
# No flips and no rotation in the augmentation pipeline.
tfms = get_transforms(do_flip=False, max_rotate=0)
# The 'test' folder is used as the validation split; labels come from the
# folder names; images are resized to 224 and normalised with ImageNet stats.
data = (ImageList.from_folder(env.data_path)
        .split_by_folder(valid='test')
        .label_from_folder()
        .transform(tfms, size=224)
        .databunch(bs=batch_size, num_workers=1)
        .normalize(imagenet_stats))

# data.show_batch(rows=2, figsize=(5,5))

In [242]:
# Target classifier to attack (alternatives kept commented out).
# model = models.resnet50
model = models.vgg16_bn
# model = torchvision.models.googlenet
model_name = model.__name__
# Size of the generator's latent input z (passed to Gen(z_dim=...)).
z_dim = 10

class SoftmaxWrapper(nn.Module):
    """Wrap a classifier so that its forward pass returns probabilities.

    The wrapped model is stored as ``m``; the softmax over the last dimension
    is a named submodule (``softmax``) so that forward hooks can attach to it.
    """

    def __init__(self, m):
        super().__init__()
        self.m = m
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inp):
        logits = self.m(inp)
        probabilities = self.softmax(logits)
        return probabilities
  
# Pretrained target classifier on the GPU in eval mode, returning softmax
# probabilities; its parameters are frozen (only the generator is trained).
arch = SoftmaxWrapper(model(pretrained=True).cuda().eval())
nag_util.arch = arch
requires_grad(arch, False)

# Layers whose outputs are hooked for the diversity loss, per architecture.
# NOTE(review): the commented alternatives below index the raw model's
# attributes (arch.features, arch.layer2, ...); since arch is now a
# SoftmaxWrapper they would need arch.m.<...> instead.

# vgg:
# layers = []
# blocks = [i-1 for i,o in enumerate(children(arch.features)) if isinstance(o, nn.MaxPool2d)]
# layers = [arch.features[i] for i in blocks]
# layer_weights = [1] * len(layers)

# resnet:
# layers = [
#   arch.layer2[0].downsample,
#   arch.layer3[0].downsample,
#   arch.layer4[0].downsample
# ]
# Current experiment: diversity is measured on the output probabilities only.
layers = [
    arch.softmax
]

layer_weights = [1.] * len(layers)

# layers = []
# last_layer = None
# for o in children(arch):
#   if isinstance(o, nn.AdaptiveAvgPool2d):
#     layers.append(last_layer)
#   last_layer = o

# # layers = [arch.fc]

# layer_weights = [1] * len(layers)

# inception:
# layers = [
#     arch.Conv2d_1a_3x3,
#     arch.Mixed_6e,
#     arch.Mixed_7a,
#     arch.fc    
# ]
# layer_weights = [1.0/4.0] * len(layers)

In [243]:
class Gen(nn.Module):
    """Generator mapping latent vectors z to additive image perturbations.

    A linear layer projects z to a (gf_dim * 7, 4, 4) map, followed by seven
    ``deconv_layer`` stages (spatial size 4 -> 7 -> 14 -> 28 -> 56 -> 112 ->
    224 -> 224, 3 output channels). Before every stage a block of fresh
    uniform noise channels (``make_z``) is concatenated to the feature map.

    ``forward`` uses ``inputs`` only for its batch size, dtype and device:
    it samples z, builds a positive and a negative neighbour of z for the
    triplet loss, and returns the three perturbations plus ``inputs``.

    Requires CUDA: tensors are moved with ``.cuda()`` / created on cuda:0.

    Args:
        z_dim: latent dimensionality.
        gf_dim: base channel width of the generator.
        y_dim, df_dim, image_shape: stored on the module but unused here.
    """

    def __init__(self, z_dim, gf_dim=64, y_dim=None, df_dim=64, image_shape=None):
        super(Gen, self).__init__()

        self.bs = None  # batch size of the most recent forward pass
        self.z_dim = z_dim
        self.gf_dim = gf_dim
        self.y_dim = y_dim
        self.df_dim = df_dim
        # Fresh list per instance instead of a shared mutable default argument.
        self.image_shape = [3, 128, 128] if image_shape is None else image_shape

        self.z_ = nn.Linear(self.z_dim, self.gf_dim * 7 * 4 * 4, bias=True)
        self.z_.bias.data.fill_(0)
        self.BN_ = nn.BatchNorm2d(self.gf_dim * 7)

        # Noise channels concatenated before stage 3 (half), stage 4 (quarter)
        # and stages 5-7 (eighth). Clamped to at least one channel so a small
        # gf_dim never produces an empty noise block.
        self.half = max(self.gf_dim // 2, 1)
        self.quarter = max(self.gf_dim // 4, 1)
        self.eighth = max(self.gf_dim // 8, 1)

        # Each stage's input width = previous stage's output width + noise width
        # (e.g. stage 1: gf_dim * 7 projected channels + gf_dim noise channels).
        self.CT2d_1 = deconv_layer(self.gf_dim * 8, self.gf_dim * 4,
                                   k_size=(5, 5), s=(2, 2), pad=(2, 2))
        self.CT2d_2 = deconv_layer(self.gf_dim * 5, self.gf_dim * 2)
        self.CT2d_3 = deconv_layer(self.gf_dim * 2 + self.half, self.gf_dim * 1)
        self.CT2d_4 = deconv_layer(self.gf_dim * 1 + self.quarter, self.gf_dim * 1)
        self.CT2d_5 = deconv_layer(self.gf_dim * 1 + self.eighth, self.gf_dim * 1)
        self.CT2d_6 = deconv_layer(self.gf_dim * 1 + self.eighth, self.gf_dim * 1)
        self.CT2d_7 = deconv_layer(self.gf_dim * 1 + self.eighth, 3, k_size=(5, 5),
                                   s=(1, 1), pad=(2, 2), activation=False)

    def forward_z(self, z):
        """Map latent vectors ``z`` of shape (bs, z_dim) to (bs, 3, 224, 224) perturbations."""
        self.bs = z.shape[0]

        # Linear projection -> (bs, gf_dim * 7, 4, 4) -> batch norm -> ReLU.
        h0 = F.relu(self.BN_(self.z_(z).contiguous().view(self.bs, -1, 4, 4)))
        assert h0.shape[2:] == (4, 4), "Non-expected shape, it shoud be (4,4)"

        h0 = torch.cat([h0, self.make_z([self.bs, self.gf_dim, 4, 4])], dim=1)
        h1 = self.CT2d_1(h0)
        assert h1.shape[2:] == (7, 7), "Non-expected shape, it shoud be (7,7)"

        h1 = torch.cat([h1, self.make_z([self.bs, self.gf_dim, 7, 7])], dim=1)
        h2 = self.CT2d_2(h1)
        assert h2.shape[2:] == (14, 14), "Non-expected shape, it shoud be (14,14)"

        h2 = torch.cat([h2, self.make_z([self.bs, self.half, 14, 14])], dim=1)
        h3 = self.CT2d_3(h2)
        assert h3.shape[2:] == (28, 28), "Non-expected shape, it shoud be (28,28)"

        h3 = torch.cat([h3, self.make_z([self.bs, self.quarter, 28, 28])], dim=1)
        h4 = self.CT2d_4(h3)
        assert h4.shape[2:] == (56, 56), "Non-expected shape, it shoud be (56,56)"

        h4 = torch.cat([h4, self.make_z([self.bs, self.eighth, 56, 56])], dim=1)
        h5 = self.CT2d_5(h4)
        assert h5.shape[2:] == (112, 112), "Non-expected shape, it shoud be (112,112)"

        h5 = torch.cat([h5, self.make_z([self.bs, self.eighth, 112, 112])], dim=1)
        h6 = self.CT2d_6(h5)
        assert h6.shape[2:] == (224, 224), "Non-expected shape, it shoud be (224,224)"

        h6 = torch.cat([h6, self.make_z([self.bs, self.eighth, 224, 224])], dim=1)
        h7 = self.CT2d_7(h6)
        assert h7.shape[2:] == (224, 224), "Non-expected shape, it shoud be (224,224)"

        # Scale tanh's [-1, 1] output for images normalised with imagenet_stats
        # (roughly the [-2.5, 2.5] range): dividing by 255 * mean(std) makes the
        # perturbation about ksi on the 0-255 pixel scale.
        ksi = 10.0
        output_coeff = ksi / (255.0 * np.mean(imagenet_stats[1]))
        return output_coeff * torch.tanh(h7)

    def forward(self, inputs):
        """Return (perturb(z), perturb(positive), perturb(negative), inputs) for a fresh z."""
        self.bs = inputs.shape[0]
        z = inputs.new_empty([self.bs, self.z_dim]).uniform_(-1, 1).cuda()
        # Positives: within radius 0.1 of z; negatives: in the shell 0.1 <= r <= 2.
        p, n = self.make_triplet_samples(z, 0.1, 0.1, 2.)

        z_out = self.forward_z(z)
        p_out = self.forward_z(p)
        n_out = self.forward_z(n)

        return z_out, p_out, n_out, inputs

    def forward_single_z(self, z):
        """Run one latent vector (shape (z_dim,)) through the generator, dropping size-1 dims."""
        return self.forward_z(z[None]).squeeze()

    def make_triplet_samples(self, z, margin, r2, r3):
        """Return (positive, negative) neighbours of ``z``.

        Positives are offset by a vector of norm <= ``margin``; negatives by a
        vector whose norm lies in [r2, r3]. Offsets are uniform in volume.
        """
        positive_sample = z + self.random_vector_volume(z.shape, 0, margin).cuda()
        negative_sample = z + self.random_vector_volume(z.shape, r2, r3).cuda()
        return positive_sample, negative_sample

    def random_vector_surface(self, shape, r=1.):
        """Random (rows, dim) vectors with uniformly random direction and L2 norm ``r`` per row."""
        mat = torch.randn(size=shape).cuda()
        norm = torch.norm(mat, p=2, dim=1, keepdim=True)
        return (mat / norm) * r

    @staticmethod
    def _shell_radii(n, dim, inner_r, outer_r):
        """Sample ``n`` radii so points are uniform in volume over inner_r <= r <= outer_r in ``dim`` dims.

        Inverse CDF of P(R <= r) proportional to r**dim - inner_r**dim, computed in units of
        ``outer_r`` for numerical stability. With ``inner_r == 0`` this reduces to
        ``outer_r * U ** (1 / dim)``.
        """
        u = torch.empty(n).uniform_(0, 1)
        lo = (inner_r / outer_r) ** dim
        return outer_r * (lo + u * (1.0 - lo)) ** (1.0 / dim)

    def random_vector_volume(self, shape, inner_r, outer_r):
        """Random (rows, dim) vectors uniformly distributed in the shell inner_r <= |v| <= outer_r."""
        radii = self._shell_radii(shape[0], shape[1], inner_r, outer_r).cuda().unsqueeze(-1)
        return self.random_vector_surface(shape, 1) * radii

    def make_z(self, in_shape):
        """Uniform(-1, 1) noise of shape ``in_shape`` on cuda:0."""
        return torch.empty(in_shape, device="cuda:0").uniform_(-1, 1)


In [244]:
def load_starting_point(learn, name, z_dim):
    """Load a shared initial checkpoint into ``learn``, creating it if absent.

    Checkpoints live on Google Drive under model_starting_points/ and are keyed
    by ``<name>-zdim<z_dim>``. If none exists yet, the learner's current weights
    are saved there first, then loaded back. Only implemented on Colab; raises
    NotImplementedError elsewhere.
    """
    if detect_env() != "colab":
        raise NotImplementedError("load_starting_point not implemented for non-colab environments yet.")
    import os

    token = name + '-zdim' + str(z_dim)
    checkpoint = '/content/gdrive/My Drive/DL/model_starting_points/' + token
    if not os.path.isfile(checkpoint + '.pth'):
        print("\n\nno starting point found for model:" + token + ". creating one from the current learner.\n\n")
        learn.save(checkpoint)
    learn.load(checkpoint)

In [245]:
torch.set_printoptions(precision=2, sci_mode=False, threshold=5000)


def print_softmax_tensor(x, threshold=0.01):
    """Print the entries of a 1-D tensor whose magnitude exceeds ``threshold``.

    Output format: ``[index: value, index: value]`` with values to two
    decimals, e.g. ``[1: 0.50, 2: 0.20]``; ``[]`` if nothing passes.
    """
    entries = ["{}: {:.2f}".format(i, x_i.item())
               for i, x_i in enumerate(x.data) if abs(x_i) > threshold]
    # join() puts separators only between kept entries (no trailing ", ").
    print("[" + ", ".join(entries) + "]")
  
# print_softmax_tensor(torch.tensor([0.01, 2.5, 5.]))

In [246]:
def js_distance(x1, x2, eps=1e-12):
    """Jensen-Shannon divergence between batches of probability vectors.

    ``F.kl_div(input, target)`` expects ``input`` as *log*-probabilities and
    computes KL(target || exp(input)), so the mixture ``m`` is passed through
    ``log`` (clamped at ``eps`` to avoid log(0)). ``batchmean`` averages the
    per-sample divergences over the batch.
    """
    m = 0.5 * (x1 + x2)
    log_m = torch.clamp(m, min=eps).log()
    return 0.5 * (F.kl_div(log_m, x1, reduction='batchmean')
                  + F.kl_div(log_m, x2, reduction='batchmean'))


def kl_distance(x1, x2, eps=1e-12):
    """KL(x2 || x1) for batches of probability vectors (x1 is F.kl_div's input side)."""
    return F.kl_div(torch.clamp(x1, min=eps).log(), x2, reduction='batchmean')


def wasserstein_distance(x1, x2):
    """Placeholder: not implemented (raises instead of silently returning None)."""
    raise NotImplementedError("wasserstein_distance is not implemented")


def l1_distance(x1, x2):
    """Mean absolute difference over all elements."""
    return F.l1_loss(x1, x2)


def l2_distance(x1, x2):
    """Mean squared difference over all elements, with inputs scaled by 10 (i.e. MSE * 100)."""
    return F.mse_loss(x1 * 10, x2 * 10)


def cos_distance(x1, x2):
    """Negative mean cosine similarity along dim 1."""
    return -1 * torch.mean(F.cosine_similarity(x1, x2))


triplet_call_cnt = 0  # number of triplet_loss calls, used to throttle debug printing


def triplet_loss(anchor, positive, negative, distance_func, margin):
    """Hinge triplet loss ``relu(d(a, p) - d(a, n) + margin)``.

    With the reducing distance functions above, ``d`` is a single scalar for
    the whole batch, so the hinge is applied to batch-level distances.
    NOTE: those distances are element-wise means; for C-class probability
    vectors l1_distance <= 2 / C and l2_distance <= 200 / C, which bounds how
    large the margin can usefully be.

    Every 10th call prints the first anchor/positive/negative sample and the
    two distances.
    """
    ap_dist = distance_func(anchor, positive)
    an_dist = distance_func(anchor, negative)

    global triplet_call_cnt
    triplet_call_cnt += 1
    if triplet_call_cnt % 10 == 0:
        print("a: ", end=""); print_softmax_tensor(anchor[0])
        print("p: ", end=""); print_softmax_tensor(positive[0])
        print("n: ", end=""); print_softmax_tensor(negative[0])
        print("ap_dist: {}, an_dist: {}".format(ap_dist, an_dist))

    return torch.mean(F.relu(ap_dist - an_dist + margin))

In [247]:
def diversity_loss(input, target):
    """Mean per-sample cosine similarity between ``input`` and ``target``.

    Each sample (dim 0) is flattened to a vector. The batch size is taken from
    ``input`` itself, so a smaller final batch is handled correctly instead of
    being reshaped with a fixed global batch size (which mixes samples or fails).
    ``reshape`` also copes with non-contiguous hooked activations.
    """
    n = input.shape[0]
    return torch.mean(F.cosine_similarity(
        input.reshape(n, -1),
        target.reshape(n, -1),
    ))

In [248]:
# z1 = torch.tensor([[1., 0.]])
# z2 = torch.tensor([[-1., 0]])
# cos_sim(z1,z2)

# z1 = torch.tensor([10.] + ([0.] * 999))
# z2 = torch.tensor([0., 10.] + [0.] * 998)
# l2_distance(z1, z2)

In [249]:
class FeatureLoss(nn.Module):
    """NAG training loss: fooling + weighted diversity + weighted triplet terms.

    ``forward`` expects ``inp`` to be the 4-tuple produced by the generator:
    (perturbation for z, for the positive neighbour, for the negative
    neighbour, clean image batch). After each call ``self.losses`` holds the
    summed terms and ``self.metrics`` maps ``metric_names`` to values.

    Args:
        dis: the (frozen) classifier under attack.
        layers: modules of ``dis`` whose outputs feed the diversity loss.
        layer_weights: one weight per entry in ``layers``.
        triplet_margin: hinge margin passed to ``triplet_loss``.
        triplet_weight: multiplier applied to the triplet loss.
    """

    def __name__(self):
        return "feature_loss"

    def __init__(self, dis, layers, layer_weights, triplet_margin=10., triplet_weight=10.):
        super().__init__()

        self.dis = dis
        self.diversity_layers = layers
        # detach=False: stored activations stay in the autograd graph so the
        # diversity loss can backpropagate to the generator.
        self.hooks = hook_outputs(self.diversity_layers, detach=False)
        self.weights = layer_weights
        self.metric_names = (["fool_loss"]
                             + [f"div_loss_{i}" for i in range(len(layers))]
                             + ['triplet_loss'])
        self.triplet_weight = triplet_weight
        self.triplet_margin = triplet_margin

    def make_features(self, x, clone=False):
        """Run ``dis`` on ``x``; return its output and the hooked layer outputs."""
        y = self.dis(x)
        return y, [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, inp, target):
        sigma_B, sigma_pos, sigma_neg, X_B = inp

        X_A = self.add_perturbation(X_B, sigma_B)
        X_A_pos = self.add_perturbation(X_B, sigma_pos)
        X_A_neg = self.add_perturbation(X_B, sigma_neg)

        # Same perturbations, but each applied to a different image of the batch.
        X_S = self.add_perturbation_shuffled(X_B, sigma_B)

        B_Y, _ = self.make_features(X_B)
        A_Y, A_feat = self.make_features(X_A)
        _, S_feat = self.make_features(X_S)
        pos_softmax, _ = self.make_features(X_A_pos)
        neg_softmax, _ = self.make_features(X_A_neg)

        fooling_loss = fool_loss(A_Y, B_Y)

        # Compute each diversity term once; the weighted version reuses it.
        raw_diversity_losses = [diversity_loss(a_f, s_f) for a_f, s_f in zip(A_feat, S_feat)]
        weighted_diversity_losses = [loss * weight
                                     for loss, weight in zip(raw_diversity_losses, self.weights)]

        raw_triplet_loss = triplet_loss(A_Y, pos_softmax, neg_softmax, l2_distance, self.triplet_margin)
        weighted_triplet_loss = raw_triplet_loss * self.triplet_weight

        self.losses = [fooling_loss] + weighted_diversity_losses + [weighted_triplet_loss]
        # NOTE(review): diversity metrics are logged raw but the triplet metric
        # is logged weighted -- confirm this asymmetry is intended.
        self.metrics = dict(zip(self.metric_names,
                                [fooling_loss] + raw_diversity_losses + [weighted_triplet_loss]))

        return sum(self.losses)

    def add_perturbation(self, inp, perturbation):
        return inp.add(perturbation)

    def add_perturbation_shuffled(self, inp, perturbation):
        """Add perturbations permuted by a derangement, so no image keeps its own."""
        j = derangement(inp.shape[0])
        return inp.add(perturbation[j])

In [250]:
# Combined loss (fooling + diversity on `layers` + triplet) against `arch`.
feat_loss = FeatureLoss(arch, layers, layer_weights)

In [251]:
# Name of this run; determines the CSV and model destination paths.
env.save_filename = 'vgg16_11'

# Refuse to start if results for this run name already exist (no overwrite).
if Path(env.get_csv_path() + '.csv').exists(): raise FileExistsError("csv_path already exists")
if Path(env.get_models_path()).exists(): raise FileExistsError("models_path already exists")

In [252]:
# Release any previous learner (and its GPU memory) before building a new one.
learn = None; gc.collect()
# Per-epoch CSV log written to the temp folder; it is copied to its final
# location after training.
csv_logger = partial(ImmediateCSVLogger, filename= env.temp_csv_path + '/' + env.save_filename)
# learn = Learner(data, Gen(z_dim=10), loss_func = feat_loss, metrics=[validation], callback_fns=LossMetrics, opt_func = optim.SGD)
# learn = Learner(data, Gen(z_dim=z_dim), loss_func = feat_loss, metrics=[validation], callback_fns=[LossMetrics, DiversityWeightsScheduler])
# The generator is the trained model; feat_loss scores its perturbations.
learn = Learner(data, Gen(z_dim=z_dim), loss_func = feat_loss, metrics=[validation], callback_fns=[LossMetrics, csv_logger])
# load_starting_point(learn, model_name, z_dim)
# random_seed(42, True)

In [253]:
# learn.lr_find(1e-6, 1000)
# learn.recorder.plot()

In [254]:
# List the checkpoints of the previous run (vgg16_10), one of which is loaded below.
!ls ./models/vgg16_10

vgg16_10-best.pth  vgg16_10_15.pth  vgg16_10_22.pth  vgg16_10_3.pth
vgg16_10_0.pth	   vgg16_10_16.pth  vgg16_10_23.pth  vgg16_10_4.pth
vgg16_10_1.pth	   vgg16_10_17.pth  vgg16_10_24.pth  vgg16_10_5.pth
vgg16_10_10.pth    vgg16_10_18.pth  vgg16_10_25.pth  vgg16_10_6.pth
vgg16_10_11.pth    vgg16_10_19.pth  vgg16_10_26.pth  vgg16_10_7.pth
vgg16_10_12.pth    vgg16_10_2.pth   vgg16_10_27.pth  vgg16_10_8.pth
vgg16_10_13.pth    vgg16_10_20.pth  vgg16_10_28.pth  vgg16_10_9.pth
vgg16_10_14.pth    vgg16_10_21.pth  vgg16_10_29.pth


In [255]:
# !cp "/content/gdrive/My Drive/DL/models/vgg-16_2.pth"  "/content/"
# learn.load('/content/vgg-16_2')

# Initialise this run from the epoch-29 checkpoint of run vgg16_10
# (absolute path on the IBM machine).
learn.load('/root/Derakhshani/adversarial/models/vgg16_10/vgg16_10_29')

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Learner(data=ImageDataBunch;

Train: LabelList (9000 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n02454379,n02454379,n02454379,n02454379,n02454379
Path: /root/Derakhshani/adversarial/datasets/dataset;

Valid: LabelList (50000 items)
x: ImageList
Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)
y: CategoryList
n02454379,n02454379,n02454379,n02454379,n02454379
Path: /root/Derakhshani/adversarial/datasets/dataset;

Test: None, model=Gen(
  (z_): Linear(in_features=10, out_features=7168, bias=True)
  (BN_): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (CT2d_1): deconv_layer(
    (CT2d): ConvTranspose2d(512, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (BN2d): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (CT2d_2): deconv_layer(
    (CT2d): ConvTranspose2d(3

In [195]:
# Warn about non-standard configurations before starting a long run.
if mode == "sanity_check":
  print("\n\n\nWARNING: you are training on a sanity_check dataset.\n\n\n\n")
# NOTE(review): fastai v1's Learner presumably prepends a Recorder to
# callback_fns, so with [LossMetrics, csv_logger] the length is 3 and this
# warning cannot fire -- confirm and check for the scheduler explicitly.
if len(learn.callback_fns) == 1:
  print("\n\n\nWARNING: you are not using the DiversityWeightsScheduler callback.\n\n\n")

    
# Save the best model by the `validation` metric, plus one checkpoint per epoch.
saver_best = SaveModelCallback(learn, every='improvement', monitor='validation', name=env.save_filename + "-best")
saver_every_epoch = SaveModelCallback(learn, every='epoch', name=env.save_filename)

# import cProfile

# pr = cProfile.Profile()
# pr.enable()
learn.fit(30, lr=5e-03, callbacks=[saver_best, saver_every_epoch])
# pr.disable()

# learn.fit(30, lr=5e-03, wd=0.005, callbacks=[saver_best, saver_every_epoch])
# learn.fit_one_cycle(20, max_lr=5e-1, callbacks=[saver_callback])
# learn.fit_one_cycle(8, max_lr=5e-01) #mohammad's setting that got 77 validation start on resnet with diversity loss on AdaptiveAvgPool2d
# learn.fit_one_cycle(5, max_lr=2e-2) #used for vgg-19-bn
# learn.fit_one_cycle(5, max_lr=3e-3) # used for resnet50

# Copy the CSV log and the checkpoint folder (presumably where
# SaveModelCallback writes, i.e. <data_path>/models -- confirm) to their
# permanent locations.
shutil.copyfile(env.temp_csv_path + '/' + env.save_filename + ".csv", env.get_csv_path() + '.csv')
shutil.copytree(env.data_path/"models", env.get_models_path())

# pr.print_stats()

# shutil.copyfile("/content/dataset/models/" + save_filename + "-best.pth", "/content/gdrive/My Drive/DL/models/" + save_filename + ".pth")

epoch,train_loss,valid_loss,validation,fool_loss,div_loss_0,triplet_loss,time
0,100.627815,100.756844,0.647,0.342822,0.620214,99.793793,08:59
1,100.044991,100.162331,0.871,0.118285,0.450352,99.593681,08:57
2,99.877731,99.898506,0.899,0.09294,0.382543,99.423035,08:57
3,99.75032,99.783417,0.91,0.083422,0.331236,99.368767,08:57
4,99.7005,99.940247,0.914,0.077206,0.454322,99.408714,08:58
5,99.622986,99.614777,0.906,0.088418,0.305684,99.220665,08:57
6,99.77916,99.85154,0.923,0.066126,0.455241,99.330162,08:57
7,99.578293,99.683853,0.909,0.086096,0.312228,99.28553,08:57
8,99.546761,99.562775,0.925,0.064251,0.314543,99.18399,08:57
9,99.548126,99.594177,0.911,0.077767,0.330165,99.186226,08:58


a: [414: 0.06, 423: 0.16, 446: 0.04, 464: 0.11, 597: 0.05, 636: 0.07, 748: 0.15, 774: 0.02, 828: 0.03, 861: 0.01, ]
p: [414: 0.06, 423: 0.15, 446: 0.04, 464: 0.12, 597: 0.05, 636: 0.07, 748: 0.14, 774: 0.02, 828: 0.03, 861: 0.01, ]
n: [414: 0.06, 423: 0.15, 446: 0.04, 464: 0.11, 597: 0.05, 636: 0.07, 748: 0.15, 774: 0.02, 828: 0.03, 861: 0.01, ]
ap_dist: 7.597187959618168e-06, an_dist: 0.00012840304407291114
a: [100: 1.00, ]
p: [100: 1.00, ]
n: [100: 1.00, ]
ap_dist: 8.314139449794311e-06, an_dist: 4.494517270359211e-06
a: [712: 0.02, 828: 0.97, ]
p: [712: 0.02, 828: 0.97, ]
n: [712: 0.02, 828: 0.97, ]
ap_dist: 9.282163773605134e-06, an_dist: 3.3816706945799524e-06
a: [160: 0.05, 184: 0.04, 185: 0.01, 186: 0.03, 188: 0.02, 189: 0.18, 191: 0.10, 202: 0.15, 204: 0.01, 218: 0.01, 219: 0.03, 226: 0.03, 265: 0.02, 267: 0.03, 805: 0.04, 852: 0.01, 904: 0.04, ]
p: [160: 0.05, 184: 0.04, 185: 0.01, 186: 0.03, 188: 0.02, 189: 0.18, 191: 0.10, 202: 0.15, 204: 0.01, 218: 0.01, 219: 0.03, 226: 0.0

a: [39: 0.42, 48: 0.03, 314: 0.04, 318: 0.02, 319: 0.01, 363: 0.06, 904: 0.34, ]
p: [39: 0.43, 48: 0.02, 314: 0.05, 318: 0.02, 319: 0.02, 363: 0.05, 904: 0.32, ]
n: [39: 0.06, 48: 0.01, 556: 0.02, 828: 0.02, 904: 0.84, ]
ap_dist: 0.0006345316651277244, an_dist: 0.019828269258141518
a: [452: 0.05, 533: 0.79, 658: 0.08, 735: 0.02, 911: 0.02, ]
p: [452: 0.05, 533: 0.84, 658: 0.06, 735: 0.01, 911: 0.02, ]
n: [40: 0.01, 904: 0.91, ]
ap_dist: 0.00035640556598082185, an_dist: 0.034285563975572586
a: [400: 0.29, 439: 0.02, 456: 0.01, 542: 0.02, 667: 0.32, 723: 0.01, 731: 0.10, 971: 0.05, ]
p: [400: 0.26, 439: 0.01, 456: 0.01, 501: 0.01, 542: 0.02, 641: 0.01, 667: 0.26, 731: 0.10, 920: 0.01, 971: 0.08, ]
n: [197: 0.05, 199: 0.17, 205: 0.02, 214: 0.01, 223: 0.01, 224: 0.06, 226: 0.03, 233: 0.01, 369: 0.01, 400: 0.02, 439: 0.02, 487: 0.02, 667: 0.04, 731: 0.06, 752: 0.01, 834: 0.01, 920: 0.08, 971: 0.03, ]
ap_dist: 0.0005222423351369798, an_dist: 0.005783132743090391
a: [481: 0.03, 482: 0.08, 485

a: [441: 0.02, 443: 0.03, 455: 0.03, 504: 0.01, 599: 0.04, 647: 0.01, 651: 0.01, 692: 0.27, 700: 0.06, 722: 0.01, 728: 0.02, 737: 0.07, 746: 0.01, 794: 0.01, 886: 0.09, 898: 0.01, 904: 0.05, 999: 0.02]
p: [441: 0.02, 443: 0.04, 455: 0.02, 599: 0.03, 692: 0.27, 700: 0.07, 728: 0.02, 737: 0.08, 746: 0.01, 794: 0.01, 886: 0.08, 898: 0.01, 904: 0.05, 999: 0.03]
n: [411: 0.01, 441: 0.01, 443: 0.06, 455: 0.03, 599: 0.03, 651: 0.01, 692: 0.25, 700: 0.06, 728: 0.02, 737: 0.07, 746: 0.01, 794: 0.01, 886: 0.08, 898: 0.01, 904: 0.05, 999: 0.03]
ap_dist: 0.001311380648985505, an_dist: 0.009759420529007912
a: [794: 0.01, 904: 0.96, ]
p: [794: 0.01, 904: 0.97, ]
n: [773: 0.01, 794: 0.04, 904: 0.84, ]
ap_dist: 0.00045787551789544523, an_dist: 0.017677756026387215
a: [554: 0.21, 562: 0.05, 904: 0.71, ]
p: [554: 0.03, 562: 0.02, 904: 0.93, ]
n: [554: 0.06, 562: 0.02, 904: 0.90, ]
ap_dist: 0.003031934378668666, an_dist: 0.01809913106262684
a: [448: 0.02, 580: 0.02, 853: 0.90, ]
p: [448: 0.02, 580: 0.03,

a: [310: 0.01, 599: 0.09, 644: 0.04, 700: 0.01, 722: 0.02, 794: 0.02, 828: 0.06, 868: 0.05, 904: 0.29, 927: 0.03, 949: 0.03, 950: 0.03, 951: 0.01, 989: 0.03, ]
p: [301: 0.01, 310: 0.01, 599: 0.06, 644: 0.02, 700: 0.02, 722: 0.02, 794: 0.01, 828: 0.09, 868: 0.07, 904: 0.24, 927: 0.04, 928: 0.01, 949: 0.04, 950: 0.04, 951: 0.01, 968: 0.01, 989: 0.03, ]
n: [398: 0.03, 522: 0.01, 644: 0.91, ]
ap_dist: 0.0020694152917712927, an_dist: 0.01944788172841072
a: [74: 0.02, 76: 0.07, 77: 0.32, 306: 0.07, 308: 0.04, 316: 0.03, 319: 0.06, 904: 0.35, ]
p: [74: 0.01, 76: 0.07, 77: 0.30, 306: 0.06, 308: 0.07, 316: 0.05, 319: 0.11, 599: 0.01, 904: 0.26, ]
n: [74: 0.01, 76: 0.20, 77: 0.50, 303: 0.01, 319: 0.12, ]
ap_dist: 0.0005175953265279531, an_dist: 0.016687681898474693
a: [904: 0.99, ]
p: [904: 0.99, ]
n: [904: 0.99, ]
ap_dist: 0.0003465483314357698, an_dist: 0.0008134212112054229
a: [199: 0.01, 205: 0.15, 208: 0.03, 256: 0.02, 369: 0.02, 379: 0.04, 752: 0.06, 805: 0.01, 852: 0.31, 904: 0.03, ]
p: [

n: [47: 0.06, 178: 0.05, 205: 0.02, 208: 0.03, 210: 0.08, 319: 0.06, 411: 0.01, 457: 0.04, 549: 0.04, 599: 0.01, 605: 0.02, 633: 0.03, 643: 0.02, 676: 0.03, 721: 0.01, 750: 0.02, 796: 0.01, 904: 0.07, 906: 0.01, 917: 0.02, 921: 0.02, ]
ap_dist: 0.0006149281398393214, an_dist: 0.03904120624065399
a: [69: 0.14, 443: 0.02, 458: 0.08, 611: 0.02, 640: 0.02, 646: 0.18, 647: 0.01, 688: 0.02, 712: 0.02, 741: 0.02, 748: 0.04, 861: 0.02, 868: 0.08, 918: 0.02, 922: 0.04, ]
p: [69: 0.12, 443: 0.01, 458: 0.08, 478: 0.01, 549: 0.02, 611: 0.01, 640: 0.01, 646: 0.10, 647: 0.02, 666: 0.01, 700: 0.01, 712: 0.02, 748: 0.05, 783: 0.02, 861: 0.02, 868: 0.13, 883: 0.01, 922: 0.04, 999: 0.01]
n: [904: 0.92, ]
ap_dist: 0.0009870862122625113, an_dist: 0.035293128341436386
a: [444: 0.01, 456: 0.16, 477: 0.08, 587: 0.01, 688: 0.42, 740: 0.02, 754: 0.04, 758: 0.01, 764: 0.03, 786: 0.13, ]
p: [456: 0.11, 477: 0.05, 587: 0.01, 688: 0.44, 740: 0.02, 754: 0.03, 764: 0.03, 786: 0.22, ]
n: [456: 0.21, 477: 0.05, 613: 0

a: [904: 0.99, ]
p: [904: 0.99, ]
n: [904: 0.99, ]
ap_dist: 0.00024357323127333075, an_dist: 0.011537984013557434
a: [547: 0.99, ]
p: [547: 0.99, ]
n: [547: 1.00, ]
ap_dist: 0.012968714348971844, an_dist: 0.02124372497200966
a: [456: 0.89, 494: 0.01, 843: 0.05, ]
p: [456: 0.88, 494: 0.01, 843: 0.06, ]
n: [456: 0.87, 494: 0.01, 843: 0.06, ]
ap_dist: 0.0005476564401760697, an_dist: 0.06056831777095795
a: [96: 0.04, 369: 0.03, 381: 0.01, 411: 0.01, 414: 0.02, 722: 0.02, 752: 0.04, 794: 0.02, 815: 0.01, 852: 0.03, 904: 0.53, ]
p: [96: 0.04, 369: 0.03, 411: 0.02, 414: 0.02, 722: 0.03, 752: 0.03, 794: 0.02, 815: 0.01, 852: 0.02, 904: 0.52, ]
n: [414: 0.03, 457: 0.01, 487: 0.02, 489: 0.02, 591: 0.01, 615: 0.06, 664: 0.03, 688: 0.02, 746: 0.02, 752: 0.11, 788: 0.01, 791: 0.01, 794: 0.01, 799: 0.02, 805: 0.03, 851: 0.01, 879: 0.01, 890: 0.05, 904: 0.08, 971: 0.02, ]
ap_dist: 0.00033155723940581083, an_dist: 0.042803384363651276
a: [409: 0.02, 508: 0.02, 530: 0.02, 611: 0.04, 620: 0.09, 664: 0.0

a: [611: 0.99, ]
p: [611: 0.98, 646: 0.01, ]
n: [611: 1.00, ]
ap_dist: 0.0004815561987925321, an_dist: 0.038275063037872314
a: [313: 0.02, 319: 0.25, 320: 0.19, 904: 0.50, ]
p: [313: 0.02, 319: 0.22, 320: 0.20, 904: 0.52, ]
n: [319: 0.05, 418: 0.26, 446: 0.01, 456: 0.01, 546: 0.01, 549: 0.01, 563: 0.01, 623: 0.02, 633: 0.03, 688: 0.02, 749: 0.03, 767: 0.02, 769: 0.04, 783: 0.01, 798: 0.02, 823: 0.01, 916: 0.02, 918: 0.09, 922: 0.05, ]
ap_dist: 0.0010637878440320492, an_dist: 0.024202359840273857
a: [904: 0.98, ]
p: [904: 0.99, ]
n: [904: 0.95, ]
ap_dist: 0.00039363838732242584, an_dist: 0.03213104233145714
a: [918: 1.00, ]
p: [918: 1.00, ]
n: [446: 0.01, 549: 0.17, 709: 0.03, 721: 0.03, 767: 0.01, 769: 0.01, 918: 0.54, 921: 0.01, 922: 0.10, ]
ap_dist: 0.0005482765845954418, an_dist: 0.05023867264389992
a: [916: 0.02, 922: 0.94, ]
p: [916: 0.02, 922: 0.94, ]
n: [549: 0.01, 916: 0.05, 922: 0.87, ]
ap_dist: 0.00031549425330013037, an_dist: 0.0489024817943573
a: [498: 0.02, 509: 0.04, 582:

a: [904: 0.99, ]
p: [904: 1.00, ]
n: [904: 0.97, ]
ap_dist: 0.0017156440299004316, an_dist: 0.025248436257243156
a: [14: 0.01, 36: 0.01, 39: 0.03, 49: 0.02, 50: 0.06, 58: 0.08, 82: 0.04, 84: 0.09, 85: 0.07, 131: 0.02, 135: 0.08, 142: 0.01, 308: 0.02, 310: 0.01, 316: 0.04, 319: 0.03, 580: 0.01, 904: 0.03, 975: 0.01, ]
p: [39: 0.02, 49: 0.01, 50: 0.05, 58: 0.18, 82: 0.04, 84: 0.08, 85: 0.05, 131: 0.01, 135: 0.04, 308: 0.02, 316: 0.07, 318: 0.02, 319: 0.05, 363: 0.01, 580: 0.01, 904: 0.06, ]
n: [36: 0.01, 39: 0.03, 49: 0.02, 50: 0.05, 58: 0.08, 78: 0.01, 82: 0.04, 84: 0.05, 85: 0.10, 131: 0.01, 135: 0.07, 305: 0.02, 308: 0.05, 310: 0.03, 316: 0.05, 319: 0.03, 363: 0.02, 580: 0.01, 904: 0.04, ]
ap_dist: 0.0010853918502107263, an_dist: 0.02090313844382763
a: [398: 0.01, 446: 0.02, 489: 0.03, 507: 0.18, 553: 0.03, 637: 0.01, 695: 0.09, 743: 0.01, 771: 0.05, 789: 0.08, 794: 0.03, 799: 0.01, 904: 0.02, 918: 0.31, ]
p: [446: 0.02, 489: 0.03, 507: 0.19, 553: 0.04, 637: 0.02, 695: 0.12, 771: 0.05

n: [55: 0.03, 64: 0.01, 84: 0.02, 96: 0.21, 319: 0.02, 489: 0.01, 580: 0.10, 752: 0.04, 815: 0.01, 852: 0.02, 890: 0.01, 904: 0.22, 956: 0.03, ]
ap_dist: 4.6554519940400496e-05, an_dist: 0.046503473073244095
a: [507: 0.01, 607: 0.01, 611: 0.65, 637: 0.03, 695: 0.06, 721: 0.01, 748: 0.06, 750: 0.02, 861: 0.01, 999: 0.03]
p: [611: 0.71, 637: 0.03, 695: 0.06, 748: 0.05, 750: 0.02, 861: 0.01, 999: 0.02]
n: [60: 0.04, 468: 0.01, 478: 0.01, 489: 0.02, 502: 0.01, 509: 0.02, 611: 0.03, 636: 0.01, 637: 0.01, 641: 0.01, 695: 0.02, 723: 0.02, 748: 0.05, 774: 0.01, 790: 0.01, 805: 0.09, 815: 0.01, 879: 0.03, 950: 0.15, ]
ap_dist: 0.0004138390359003097, an_dist: 0.04922392964363098
a: [467: 0.02, 489: 0.03, 509: 0.08, 549: 0.04, 582: 0.29, 599: 0.03, 611: 0.11, 692: 0.04, 879: 0.05, 922: 0.15, 952: 0.01, ]
p: [415: 0.01, 467: 0.02, 489: 0.01, 509: 0.10, 549: 0.03, 582: 0.26, 599: 0.02, 611: 0.11, 692: 0.04, 879: 0.03, 922: 0.23, ]
n: [489: 0.02, 599: 0.03, 794: 0.02, 904: 0.83, ]
ap_dist: 0.0006770

a: [319: 0.06, 320: 0.01, 473: 0.03, 477: 0.01, 587: 0.01, 599: 0.01, 623: 0.02, 633: 0.02, 709: 0.01, 752: 0.07, 783: 0.01, 784: 0.03, 828: 0.07, 904: 0.27, ]
p: [319: 0.06, 320: 0.01, 473: 0.02, 587: 0.01, 599: 0.01, 623: 0.01, 633: 0.02, 709: 0.01, 752: 0.07, 784: 0.02, 828: 0.08, 904: 0.33, ]
n: [426: 0.04, 611: 0.45, 635: 0.07, 769: 0.02, 778: 0.02, 798: 0.03, 835: 0.25, 892: 0.03, ]
ap_dist: 0.005613110028207302, an_dist: 0.036546431481838226
a: [904: 0.93, ]
p: [904: 0.93, ]
n: [904: 0.94, ]
ap_dist: 0.0011732240673154593, an_dist: 0.02691601775586605
a: [184: 0.59, 191: 0.02, 271: 0.01, 274: 0.33, ]
p: [184: 0.66, 191: 0.03, 271: 0.01, 274: 0.26, ]
n: [184: 0.64, 191: 0.02, 271: 0.01, 274: 0.30, ]
ap_dist: 0.0005844576517120004, an_dist: 0.0471716932952404
a: [454: 0.19, 467: 0.01, 478: 0.04, 489: 0.01, 549: 0.02, 582: 0.05, 611: 0.02, 624: 0.01, 692: 0.03, 781: 0.01, 860: 0.01, 917: 0.09, 921: 0.03, 922: 0.17, ]
p: [96: 0.02, 454: 0.16, 467: 0.01, 478: 0.04, 489: 0.02, 549: 0.

a: [55: 0.02, 59: 0.03, 64: 0.08, 313: 0.02, 319: 0.01, 411: 0.02, 443: 0.02, 605: 0.01, 620: 0.02, 646: 0.01, 664: 0.01, 688: 0.05, 712: 0.03, 738: 0.01, 752: 0.01, 790: 0.01, 828: 0.02, 868: 0.08, 904: 0.28, 943: 0.02, 948: 0.01, ]
p: [55: 0.04, 59: 0.04, 64: 0.12, 313: 0.02, 319: 0.01, 411: 0.02, 443: 0.02, 620: 0.02, 664: 0.01, 688: 0.03, 712: 0.02, 738: 0.01, 752: 0.01, 790: 0.01, 828: 0.02, 868: 0.10, 904: 0.20, 943: 0.02, 948: 0.02, ]
n: [611: 1.00, ]
ap_dist: 0.003775110002607107, an_dist: 0.0599118173122406
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [611: 1.00, ]
ap_dist: 0.0017523614224046469, an_dist: 0.04397724196314812
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [611: 0.97, ]
ap_dist: 0.001771491370163858, an_dist: 0.05251225456595421
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [922: 0.97, ]
ap_dist: 5.827882318953925e-07, an_dist: 0.046943582594394684
a: [489: 0.21, 580: 0.01, 611: 0.06, 646: 0.13, 658: 0.01, 741: 0.02, 806: 0.03, 815: 0.21, 824: 0.13, 839: 0.01, 911: 0.06, ]
p: [474: 0.01, 

a: [611: 1.00, ]
p: [611: 1.00, ]
n: [664: 0.01, 688: 0.01, 781: 0.03, 922: 0.90, ]
ap_dist: 9.503161709289998e-05, an_dist: 0.04855623468756676
a: [922: 1.00, ]
p: [922: 1.00, ]
n: [611: 1.00, ]
ap_dist: 0.0022835799027234316, an_dist: 0.07522698491811752
a: [18: 0.01, 50: 0.02, 82: 0.03, 96: 0.01, 97: 0.02, 99: 0.03, 127: 0.01, 134: 0.02, 138: 0.30, 489: 0.12, 599: 0.01, 815: 0.07, 904: 0.15, ]
p: [18: 0.01, 50: 0.02, 77: 0.01, 82: 0.02, 96: 0.01, 97: 0.01, 99: 0.02, 134: 0.01, 138: 0.24, 489: 0.09, 599: 0.01, 815: 0.08, 904: 0.30, ]
n: [18: 0.01, 50: 0.03, 82: 0.02, 96: 0.01, 97: 0.02, 99: 0.03, 134: 0.01, 138: 0.21, 489: 0.07, 815: 0.07, 904: 0.35, ]
ap_dist: 0.0008484211866743863, an_dist: 0.05385381728410721
a: [409: 0.03, 411: 0.01, 489: 0.02, 508: 0.02, 530: 0.01, 599: 0.03, 611: 0.04, 620: 0.06, 664: 0.03, 688: 0.03, 692: 0.05, 746: 0.02, 781: 0.02, 782: 0.02, 865: 0.01, 868: 0.02, 886: 0.02, 892: 0.02, 904: 0.13, ]
p: [409: 0.03, 489: 0.02, 508: 0.02, 530: 0.01, 599: 0.02, 61

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



a: [489: 0.42, 815: 0.55, ]
p: [489: 0.29, 815: 0.69, ]
n: [489: 0.27, 815: 0.68, 904: 0.03, ]
ap_dist: 0.01280798576772213, an_dist: 0.03917097672820091
a: [458: 0.99, ]
p: [458: 1.00, ]
n: [688: 0.67, 922: 0.31, ]
ap_dist: 0.00013333246170077473, an_dist: 0.05937478318810463
a: [664: 0.02, 688: 0.05, 782: 0.01, 916: 0.01, 922: 0.84, ]
p: [664: 0.02, 688: 0.04, 922: 0.87, ]
n: [458: 1.00, ]
ap_dist: 0.0022284386213868856, an_dist: 0.06950470060110092
a: [458: 0.99, ]
p: [458: 1.00, ]
n: [458: 1.00, ]
ap_dist: 0.004080630838871002, an_dist: 0.03286509960889816
a: [454: 0.06, 458: 0.88, 921: 0.02, 922: 0.03, ]
p: [454: 0.05, 458: 0.87, 921: 0.02, 922: 0.03, ]
n: [454: 0.13, 458: 0.14, 509: 0.01, 594: 0.01, 611: 0.02, 624: 0.01, 921: 0.02, 922: 0.60, ]
ap_dist: 0.000599214225076139, an_dist: 0.06773196905851364
a: [458: 0.08, 611: 0.92, ]
p: [458: 0.05, 611: 0.95, ]
n: [458: 1.00, ]
ap_dist: 0.00026561625418253243, an_dist: 0.09159781038761139
a: [18: 0.04, 361: 0.85, 388: 0.03, ]
p: [18

a: [157: 0.02, 212: 0.02, 231: 0.03, 232: 0.07, 446: 0.03, 749: 0.08, 916: 0.01, 918: 0.06, 921: 0.04, 922: 0.55, ]
p: [232: 0.02, 446: 0.02, 749: 0.02, 918: 0.01, 922: 0.86, ]
n: [157: 0.02, 200: 0.02, 212: 0.02, 217: 0.06, 231: 0.08, 232: 0.12, 458: 0.03, 749: 0.09, 907: 0.02, 918: 0.09, 921: 0.17, 922: 0.15, ]
ap_dist: 0.007719255052506924, an_dist: 0.07461748272180557
a: [289: 0.02, 381: 0.01, 383: 0.01, 489: 0.67, 815: 0.14, 839: 0.03, 912: 0.02, ]
p: [289: 0.01, 489: 0.68, 815: 0.19, 839: 0.04, 912: 0.01, ]
n: [489: 0.80, 815: 0.15, ]
ap_dist: 0.0002935748198069632, an_dist: 0.07328332215547562
a: [592: 0.01, 688: 0.95, 922: 0.01, ]
p: [531: 0.02, 592: 0.04, 688: 0.37, 704: 0.01, 922: 0.48, ]
n: [409: 0.03, 487: 0.02, 530: 0.22, 531: 0.37, 605: 0.01, 664: 0.01, 688: 0.14, 782: 0.03, 826: 0.06, ]
ap_dist: 0.008896177634596825, an_dist: 0.07252372801303864
a: [458: 0.38, 692: 0.01, 916: 0.08, 921: 0.02, 922: 0.45, ]
p: [458: 0.01, 916: 0.08, 922: 0.85, ]
n: [611: 1.00, ]
ap_dist: 0

a: [319: 0.30, 320: 0.04, 418: 0.44, 563: 0.06, 752: 0.06, 767: 0.04, ]
p: [319: 0.23, 320: 0.03, 418: 0.46, 563: 0.07, 752: 0.07, 767: 0.06, ]
n: [418: 0.52, 563: 0.05, 752: 0.01, 767: 0.29, 769: 0.04, 918: 0.02, ]
ap_dist: 0.0003021348384208977, an_dist: 0.07462058216333389
a: [39: 0.02, 84: 0.02, 318: 0.03, 458: 0.01, 489: 0.03, 562: 0.05, 580: 0.29, 611: 0.01, 815: 0.01, 835: 0.02, 936: 0.31, 937: 0.01, 938: 0.02, 971: 0.01, ]
p: [39: 0.02, 84: 0.02, 318: 0.03, 458: 0.01, 489: 0.03, 562: 0.05, 580: 0.29, 611: 0.01, 815: 0.01, 835: 0.02, 936: 0.32, 937: 0.01, 938: 0.02, 971: 0.01, ]
n: [611: 1.00, ]
ap_dist: 0.0002137883857358247, an_dist: 0.07911109179258347
a: [922: 1.00, ]
p: [922: 1.00, ]
n: [458: 1.00, ]
ap_dist: 0.014682017266750336, an_dist: 0.10824834555387497
a: [77: 0.02, 409: 0.05, 646: 0.31, 688: 0.02, 794: 0.02, 815: 0.06, 868: 0.09, 892: 0.03, 904: 0.16, ]
p: [77: 0.02, 409: 0.05, 646: 0.20, 688: 0.02, 794: 0.02, 815: 0.07, 868: 0.11, 892: 0.03, 904: 0.19, 948: 0.01, ]

a: [565: 0.36, 611: 0.18, 781: 0.04, 916: 0.01, 917: 0.04, 921: 0.02, 922: 0.21, ]
p: [549: 0.01, 565: 0.50, 611: 0.10, 781: 0.04, 916: 0.02, 917: 0.03, 921: 0.02, 922: 0.16, ]
n: [565: 0.46, 611: 0.52, ]
ap_dist: 0.000621944316662848, an_dist: 0.06380723416805267
a: [922: 0.98, ]
p: [688: 0.02, 922: 0.97, ]
n: [922: 0.99, ]
ap_dist: 2.614044933579862e-05, an_dist: 0.10267700999975204
a: [904: 0.97, ]
p: [904: 0.97, ]
n: [664: 0.13, 688: 0.46, 752: 0.02, 781: 0.23, 782: 0.02, 815: 0.01, 922: 0.05, ]
ap_dist: 0.0023564163129776716, an_dist: 0.08245081454515457
a: [458: 1.00, ]
p: [458: 1.00, ]
n: [458: 0.03, 611: 0.97, ]
ap_dist: 0.0005763106746599078, an_dist: 0.047974441200494766
a: [688: 0.56, 754: 0.01, 781: 0.03, 815: 0.03, 922: 0.33, ]
p: [688: 0.01, 922: 0.96, ]
n: [96: 0.06, 313: 0.01, 315: 0.04, 319: 0.02, 530: 0.02, 599: 0.09, 620: 0.04, 688: 0.08, 790: 0.01, 815: 0.03, 852: 0.04, 904: 0.35, 948: 0.01, ]
ap_dist: 0.01012332085520029, an_dist: 0.08621133863925934
a: [458: 0.06,

p: [96: 0.11, 411: 0.03, 443: 0.03, 448: 0.04, 496: 0.06, 721: 0.01, 904: 0.56, ]
n: [546: 0.07, 594: 0.12, 688: 0.78, ]
ap_dist: 0.005115066189318895, an_dist: 0.017763866111636162
a: [562: 0.05, 611: 0.04, 664: 0.04, 688: 0.72, 698: 0.04, 782: 0.04, ]
p: [562: 0.04, 611: 0.04, 664: 0.03, 688: 0.75, 698: 0.02, 782: 0.04, ]
n: [421: 0.01, 458: 0.09, 483: 0.01, 489: 0.01, 498: 0.01, 562: 0.26, 565: 0.02, 611: 0.17, 628: 0.02, 698: 0.12, 718: 0.02, 916: 0.01, 921: 0.03, ]
ap_dist: 0.00018711057782638818, an_dist: 0.032422665506601334
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [458: 1.00, ]
ap_dist: 0.011406839825212955, an_dist: 0.11591082811355591
a: [55: 0.01, 64: 0.02, 72: 0.01, 311: 0.05, 313: 0.01, 315: 0.09, 319: 0.22, 320: 0.03, 546: 0.01, 594: 0.02, 664: 0.01, 688: 0.20, 752: 0.02, 815: 0.02, 904: 0.04, 981: 0.01, ]
p: [319: 0.01, 594: 0.03, 611: 0.01, 664: 0.03, 688: 0.82, 769: 0.03, ]
n: [72: 0.02, 303: 0.02, 307: 0.01, 311: 0.03, 313: 0.02, 315: 0.02, 319: 0.12, 320: 0.03, 904: 0.46

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



a: [611: 1.00, ]
p: [611: 1.00, ]
n: [611: 1.00, ]
ap_dist: 0.012175695970654488, an_dist: 0.044760361313819885
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [458: 0.99, ]
ap_dist: 0.00011913209891645238, an_dist: 0.06166951358318329
a: [904: 0.99, ]
p: [904: 0.99, ]
n: [904: 0.98, ]
ap_dist: 0.008691808208823204, an_dist: 0.054840076714754105
Better model found at epoch 8 with validation value: 0.925000011920929.
a: [549: 0.04, 614: 0.01, 815: 0.15, 922: 0.71, 971: 0.02, 982: 0.01, ]
p: [549: 0.04, 815: 0.07, 922: 0.81, ]
n: [549: 0.05, 611: 0.08, 815: 0.02, 922: 0.79, ]
ap_dist: 0.0002652605762705207, an_dist: 0.11535416543483734
a: [904: 1.00, ]
p: [904: 1.00, ]
n: [815: 1.00, ]
ap_dist: 0.0004522631352301687, an_dist: 0.09176768362522125
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [611: 1.00, ]
ap_dist: 0.003105894895270467, an_dist: 0.15757736563682556
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [458: 0.95, 835: 0.02, 921: 0.02, ]
ap_dist: 8.265633368864655e-05, an_dist: 0.08397726714611053
a: [922: 1.0

a: [458: 0.98, 611: 0.01, ]
p: [458: 0.98, 611: 0.01, ]
n: [611: 1.00, ]
ap_dist: 3.044543154828716e-05, an_dist: 0.1010923683643341
a: [299: 0.04, 922: 0.96, ]
p: [299: 0.13, 922: 0.86, ]
n: [922: 1.00, ]
ap_dist: 0.00030507854535244405, an_dist: 0.025458959862589836
a: [440: 0.02, 455: 0.13, 508: 0.05, 530: 0.01, 599: 0.10, 711: 0.02, 722: 0.01, 737: 0.05, 746: 0.02, 794: 0.01, 852: 0.01, 878: 0.01, 898: 0.03, 904: 0.06, 907: 0.20, 966: 0.03, 999: 0.01]
p: [440: 0.02, 455: 0.14, 508: 0.05, 530: 0.01, 599: 0.10, 711: 0.02, 722: 0.01, 737: 0.05, 746: 0.02, 794: 0.01, 852: 0.01, 878: 0.01, 898: 0.03, 904: 0.05, 907: 0.21, 966: 0.03, 999: 0.01]
n: [440: 0.17, 455: 0.13, 509: 0.01, 559: 0.01, 582: 0.01, 737: 0.33, 878: 0.02, 898: 0.08, 907: 0.15, ]
ap_dist: 3.1198364013107494e-05, an_dist: 0.07342735677957535
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [611: 1.00, ]
ap_dist: 0.00012238114140927792, an_dist: 0.1517583131790161
a: [55: 0.03, 61: 0.23, 62: 0.03, 64: 0.02, 96: 0.02, 319: 0.02, 752: 

a: [815: 1.00, ]
p: [815: 0.99, ]
n: [815: 1.00, ]
ap_dist: 0.00017795010353438556, an_dist: 0.03195762634277344
a: [458: 1.00, ]
p: [458: 1.00, ]
n: [611: 1.00, ]
ap_dist: 6.538275101775071e-06, an_dist: 0.11675622314214706
a: [611: 1.00, ]
p: [611: 1.00, ]
n: [922: 1.00, ]
ap_dist: 0.00028992953593842685, an_dist: 0.06847590953111649
a: [611: 1.00, ]
p: [458: 0.01, 611: 0.99, ]
n: [611: 1.00, ]
ap_dist: 0.024024929851293564, an_dist: 0.0945165604352951
a: [815: 0.99, ]
p: [74: 0.01, 815: 0.79, 904: 0.17, ]
n: [815: 1.00, ]
ap_dist: 0.008199981413781643, an_dist: 0.040215782821178436
a: [77: 0.02, 409: 0.08, 688: 0.02, 815: 0.03, 904: 0.78, ]
p: [77: 0.03, 409: 0.04, 688: 0.01, 815: 0.03, 904: 0.83, ]
n: [77: 0.02, 409: 0.02, 904: 0.92, ]
ap_dist: 0.019435834139585495, an_dist: 0.05043068900704384
a: [815: 1.00, ]
p: [815: 1.00, ]
n: [815: 1.00, ]
ap_dist: 1.6470987247885205e-05, an_dist: 2.306606256752275e-05
a: [904: 1.00, ]
p: [904: 1.00, ]
n: [904: 1.00, ]
ap_dist: 0.0211979541927

a: [664: 0.13, 688: 0.70, 782: 0.03, 815: 0.11, 922: 0.02, ]
p: [664: 0.17, 688: 0.51, 782: 0.03, 815: 0.02, 922: 0.24, ]
n: [815: 1.00, ]
ap_dist: 0.004933744203299284, an_dist: 0.12199704349040985
a: [144: 0.99, ]
p: [144: 1.00, ]
n: [58: 0.01, 144: 0.94, ]
ap_dist: 0.0005947242025285959, an_dist: 0.026700634509325027
a: [7: 0.02, 8: 0.04, 37: 0.01, 151: 0.14, 159: 0.01, 163: 0.02, 184: 0.02, 202: 0.03, 211: 0.02, 225: 0.02, 242: 0.04, 243: 0.02, 245: 0.13, 254: 0.02, 273: 0.03, 282: 0.01, 292: 0.01, 489: 0.04, 580: 0.01, 738: 0.05, 852: 0.03, 955: 0.01, ]
p: [7: 0.03, 8: 0.05, 151: 0.17, 159: 0.01, 163: 0.01, 184: 0.03, 185: 0.01, 202: 0.03, 211: 0.02, 225: 0.02, 242: 0.03, 243: 0.01, 245: 0.10, 254: 0.01, 273: 0.03, 282: 0.01, 330: 0.01, 489: 0.05, 580: 0.01, 738: 0.06, 852: 0.03, 955: 0.01, ]
n: [7: 0.07, 8: 0.05, 37: 0.03, 45: 0.01, 151: 0.15, 202: 0.02, 225: 0.02, 242: 0.03, 243: 0.01, 245: 0.11, 254: 0.02, 281: 0.01, 282: 0.03, 285: 0.02, 292: 0.03, 363: 0.04, 539: 0.01, 738: 0

p: [419: 0.01, 446: 0.02, 549: 0.05, 591: 0.01, 611: 0.50, 635: 0.01, 692: 0.07, 709: 0.04, 748: 0.02, 769: 0.03, 868: 0.02, 893: 0.02, 922: 0.01, ]
n: [446: 0.05, 478: 0.03, 549: 0.12, 611: 0.13, 664: 0.08, 692: 0.10, 781: 0.06, 782: 0.02, 815: 0.12, 868: 0.01, 893: 0.04, 921: 0.01, 922: 0.08, ]
ap_dist: 0.0019959278870373964, an_dist: 0.09090009331703186
a: [815: 1.00, ]
p: [815: 0.95, 922: 0.05, ]
n: [815: 1.00, ]
ap_dist: 0.00013816206774208695, an_dist: 0.07374440878629684
a: [458: 0.96, 611: 0.02, ]
p: [458: 0.95, 611: 0.02, ]
n: [611: 1.00, ]
ap_dist: 0.007089363411068916, an_dist: 0.10325079411268234
a: [458: 1.00, ]
p: [458: 1.00, ]
n: [611: 1.00, ]
ap_dist: 9.909018262987956e-05, an_dist: 0.13125815987586975
a: [458: 1.00, ]
p: [458: 1.00, ]
n: [611: 1.00, ]
ap_dist: 7.737319265288534e-07, an_dist: 0.0996999442577362
a: [458: 1.00, ]
p: [458: 1.00, ]
n: [458: 1.00, ]
ap_dist: 0.0006973594427108765, an_dist: 0.05167998746037483
a: [58: 0.01, 61: 0.03, 752: 0.01, 904: 0.79, ]
p

a: [815: 1.00, ]
p: [815: 1.00, ]
n: [77: 0.02, 319: 0.01, 815: 0.01, 904: 0.92, ]
ap_dist: 0.00022834142146166414, an_dist: 0.0931861400604248
a: [318: 0.04, 904: 0.94, ]
p: [318: 0.04, 904: 0.94, ]
n: [318: 0.05, 904: 0.93, ]
ap_dist: 1.942948620126117e-05, an_dist: 0.10185740143060684
a: [904: 0.99, ]
p: [904: 1.00, ]
n: [904: 0.99, ]
ap_dist: 2.012778895732481e-05, an_dist: 0.07934422045946121
a: [815: 1.00, ]
p: [815: 1.00, ]
n: [611: 1.00, ]
ap_dist: 0.04201538488268852, an_dist: 0.14648455381393433
a: [16: 0.02, 96: 0.03, 319: 0.03, 368: 0.01, 381: 0.04, 456: 0.02, 489: 0.03, 491: 0.01, 595: 0.01, 688: 0.02, 724: 0.05, 752: 0.01, 790: 0.02, 791: 0.17, 815: 0.09, 852: 0.04, 904: 0.03, ]
p: [16: 0.02, 96: 0.02, 319: 0.02, 368: 0.01, 381: 0.03, 456: 0.02, 489: 0.03, 491: 0.01, 554: 0.01, 595: 0.02, 688: 0.02, 724: 0.06, 754: 0.01, 790: 0.03, 791: 0.21, 815: 0.06, 852: 0.05, 904: 0.02, ]
n: [15: 0.01, 16: 0.04, 94: 0.01, 96: 0.05, 303: 0.02, 306: 0.01, 313: 0.01, 904: 0.70, ]
ap_dis

KeyboardInterrupt: 

In [None]:
# Copy the saved best checkpoint from Google Drive to local storage, then load
# it into the learner. The `!` line is IPython/Colab shell syntax, not Python.
# NOTE(review): learn.load is given the path without ".pth" — presumably the
# library appends the extension itself; verify.
!cp "/content/gdrive/My Drive/DL/models/resnet50-dir/resnet50-dir-best.pth" "/content/resnet50-best.pth"
learn.load("/content/resnet50-best")

In [None]:
# One epoch with lr=0 and wd=0: presumably meant to run the training loop and
# its callbacks without changing the weights. NOTE(review): layers that keep
# running statistics (e.g. BatchNorm) may still update in train mode.
learn.fit(1, lr = 0., wd=0.)

In [None]:
# Run validation with `feat_loss` (defined elsewhere in the notebook) as metric.
learn.validate(metrics=[feat_loss])

In [None]:
# Walk the latent space between two random points z1 and z2 (each entry drawn
# uniformly from [-1, 1]) using the notebook helper `interpolate` (third
# argument 0.15 — presumably the step size), and show/save the perturbation the
# generator produces at every intermediate point.
import os

z1 = torch.empty(10).uniform_(-1,1).cuda()
z2 = torch.empty(10).uniform_(-1,1).cuda()
print("distance: ", torch.norm(z1-z2,p=2))
model = learn.model.eval()

z_s = interpolate(z1, z2, 0.15)
print(len(z_s))
# Saving into ./pics fails on a fresh runtime unless the folder exists.
os.makedirs('./pics', exist_ok=True)
for i,z in enumerate(z_s):
  img = noise_to_image(model.forward_single_z(z))
  img.show()
  img.save('./pics/' + str(i) + '.png')

In [256]:
def generate_perturbations(learn, n_perturbations, device='cuda'):
  """Sample `n_perturbations` perturbations from the generator in `learn.model`.

  The first validation image is fed to the generator for every sample, and the
  first element of the model's output (squeezed) is kept as the perturbation.
  Sampling runs in eval mode and without autograd, so the returned tensors do
  not hold on to the generator's computation graph. The model's original
  training/eval mode is restored afterwards, even if the generator raises.

  Args:
    learn: fastai Learner whose `model` is the perturbation generator.
    n_perturbations: number of perturbations to draw.
    device: device the input image is moved to (default 'cuda', as before).

  Returns:
    list of `n_perturbations` perturbation tensors.
  """
  initial_training_mode = learn.model.training

  model = learn.model.eval()
  try:
    # add a batch dimension; the image content is the same for every sample
    input_img = learn.data.valid_ds[0][0].data[None].to(device)
    with torch.no_grad():  # inference only; don't keep autograd graphs alive
      perturbations = [model(input_img)[0].squeeze()
                       for _ in range(n_perturbations)]
  finally:
    learn.model.train(initial_training_mode)
  return perturbations
  
  
# def compute_mean_prediction_histogram(learn, perturbations):
#   pred_histogram = [0] * 1000
#   for j, perturbation in enumerate(perturbations):
#     for i in range(len(learn.data.valid_ds)):
#       img = learn.data.valid_ds[i][0].data[None].cuda()
#       perturbed_img = img + perturbation
#       pred = torch.argmax(arch(perturbed_img).squeeze())
#       pred_histogram[pred]+= 1./len(perturbations)
#     print("finished creating histogram for the %dth perturbation"%j)
#   return pred_histogram

  
def compute_mean_prediction_histogram(learn, perturbations, classifier=None,
                                      n_classes=1000):
  """Return the per-class prediction histogram averaged over `perturbations`.

  Every batch of `learn.data.valid_dl` is perturbed by each perturbation in
  turn and classified. Entry ``c`` of the result is the number of validation
  images predicted as class ``c``, averaged over the perturbations (divided by
  ``len(perturbations)`` exactly once — the previous version divided twice).

  Args:
    learn: fastai Learner whose ``data.valid_dl`` yields ``(batch, target)``.
    perturbations: sequence of perturbation tensors; each is added to a batch
      as ``perturbation[None]``.
    classifier: model whose predictions are counted; defaults to the
      notebook-global ``arch``.
    n_classes: number of output classes (length of the histogram).

  Returns:
    numpy float64 array of shape ``(n_classes,)``; all zeros when
    `perturbations` is empty.
  """
  if classifier is None:
    classifier = arch  # notebook-global model under attack
  counts = torch.zeros(n_classes, dtype=torch.float64)
  if len(perturbations) == 0:
    # nothing to average over (the old code produced 0/0 = NaN here)
    return counts.numpy()

  with torch.no_grad():  # predictions only; no autograd graph needed
    for j, perturbation in enumerate(perturbations):
      for batch_no, (batch, _) in enumerate(learn.data.valid_dl):
        if batch_no % 100 == 0 : print("at batch no {}".format(batch_no))
        preds = classifier(batch + perturbation[None]).argmax(1)
        # one vectorised count per batch instead of a Python loop over preds
        counts += torch.bincount(preds.reshape(-1).cpu(),
                                 minlength=n_classes).double()
      print("finished creating histogram for the %dth perturbation"%j)

  return (counts / len(perturbations)).numpy()


def diversity(learn, n_perturbations, percentage):
  """Count how many classes the generated perturbations push predictions into.

  Builds the mean prediction histogram over `n_perturbations` sampled
  perturbations, sorts classes by histogram value (descending), and takes
  classes from the top until they cover at least `percentage` percent of all
  predictions.

  Args:
    learn: fastai Learner holding the perturbation generator.
    n_perturbations: number of perturbations to sample.
    percentage: coverage threshold in percent (e.g. 95).

  Returns:
    ``(n_used_classes, indexed_pred_histogram, top_classes)`` where
    ``indexed_pred_histogram`` is a list of ``(class_index, value)`` pairs
    sorted by value descending and ``top_classes`` lists the class indices
    that were needed to reach the threshold.
  """
  pred_histogram = compute_mean_prediction_histogram(
      learn, generate_perturbations(learn, n_perturbations)
  )
  print("finished creating the prediction histogram")
  pred_histogram_sum = np.sum(pred_histogram)

  indexed_pred_histogram = sorted(enumerate(pred_histogram),
                                  key=lambda x: x[1], reverse=True)

  cumulative_percent = 0
  top_classes = []
  # Bounded by the histogram length: with float round-off the cumulative
  # percentage can stay just below e.g. 100, which previously ran past the end
  # of the list (IndexError). A zero sum would divide by zero, so skip it.
  if pred_histogram_sum > 0:
    for class_idx, value in indexed_pred_histogram:
      if cumulative_percent >= percentage:
        break
      cumulative_percent += (value / pred_histogram_sum) * 100.
      top_classes.append(class_idx)
  n_used_classes = len(top_classes)

  #top_classes is a useful piece of info that is currently unused
  return n_used_classes, indexed_pred_histogram, top_classes

In [None]:
# Diversity over 10 sampled perturbations at a 95% coverage threshold: number
# of classes needed, the (class, value) histogram sorted descending, and the
# top classes. The bare last line displays the tuple as the cell output.
n, hist, tk = diversity(learn, 10, 95)
n, hist, tk

at batch no 0
at batch no 100
at batch no 200
at batch no 300
at batch no 400
at batch no 500
at batch no 600
at batch no 700
at batch no 800
at batch no 900
at batch no 1000
at batch no 1100
at batch no 1200
at batch no 1300
at batch no 1400
at batch no 1500
at batch no 1600
at batch no 1700
at batch no 1800
at batch no 1900
at batch no 2000
at batch no 2100
at batch no 2200
at batch no 2300
at batch no 2400
at batch no 2500
at batch no 2600
at batch no 2700
at batch no 2800
at batch no 2900
at batch no 3000
at batch no 3100
at batch no 3200
at batch no 3300
at batch no 3400
at batch no 3500
at batch no 3600
at batch no 3700
at batch no 3800
at batch no 3900
at batch no 4000
at batch no 4100
at batch no 4200
at batch no 4300
at batch no 4400
at batch no 4500
at batch no 4600
at batch no 4700
at batch no 4800
at batch no 4900
at batch no 5000
at batch no 5100
at batch no 5200
at batch no 5300
at batch no 5400
at batch no 5500
at batch no 5600
at batch no 5700
at batch no 5800
at batch 

In [None]:
# learn.recorder.plot_losses()
# learn.recorder.plot_lr()
# learn.recorder.plot_metrics()

In [None]:
# Estimate the fooling rate over 10 sampled perturbations. Each one is stored in
# nag_util.global_perturbations (presumably read by the
# validation_single_perturbation metric), and learn.validate()[1] is recorded
# as that perturbation's fooling rate.
fooling_rates = []
model = learn.model.eval()
learn.metrics = [validation_single_perturbation]
for i in range(10):
  # Random input of the usual 224x224 image size (was a 224x244 typo).
  global_perturbations = model(torch.rand(1, 3, 224, 224).cuda())[0]
  nag_util.global_perturbations = global_perturbations
  fooling_rates.append(learn.validate()[1].cpu().item())
  print("%d : %f"%(i, fooling_rates[-1]))

mean = np.mean(fooling_rates)
stddev = np.std(fooling_rates)
print(mean, stddev); print(fooling_rates)

In [None]:
#the Image works good for floats in range [0..1]
model = learn.model.eval()

x_img = learn.data.train_ds[4][0]
x = x_img.data.cuda()
z = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.float32).cuda()
# z = torch.empty(z_dim).uniform_(-1,1).cuda()
p = model.forward_single_z(z).detach()
x = normalize(x)

p_x = x + p
p_x = denormalize(p_x)
p_x.clamp_(0,1)


#prepare images
p_x_img = Image(p_x)
p = scale_to_range(p, [0., 1.])
p_img = Image(p)
# x_img.show()
p_img.show()
# p_x_img.show()

# print_range(p)
# print_range(x)
# print_range(p_x)

In [None]:
# Generate perturbations for three corner points of the latent cube [-1, 1]^10
# that differ only in their leading coordinates, then display the L2 distance
# between the first and the third (bare last expression = cell output).
z1 = torch.full((10,), -1., dtype=torch.float32).cuda()
p1 = model.forward_single_z(z1)

z2 = z1.clone()
z2[0] = 1
p2 = model.forward_single_z(z2)

z3 = z2.clone()
z3[1] = 1
p3 = model.forward_single_z(z3)

l2_distance(p1, p3)

In [None]:
#the Image works good for floats in range [0..1]
model = learn.model.eval()

x_img = learn.data.train_ds[4][0]
x = x_img.data[None].cuda()
p = model(x)[0].squeeze().detach() 
x = x.squeeze()
x = normalize(x)

p_x = x + p
p_x = denormalize(p_x)
p_x.clamp_(0,1)


#prepare images
p_x_img = Image(p_x)
p = scale_to_range(p, [0.,1.])
p_img = Image(p)
# x_img.show()
p_img.show()
# p_x_img.show()

print_range(p)
print_range(x)
print_range(p_x)