In [4]:
import torchvision
from torchvision import utils
from torchvision import datasets
from torchvision import models
import torchvision.transforms as T
import torchvision.datasets as D

from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim
import argparse
import datetime
import os

from lib.models import build_model

from lib import augmentation, build_dataset, teachaugment
from lib.utils import utils, lr_scheduler
from lib.models import build_model
from lib.losses import non_saturating_loss


%matplotlib inline 
from matplotlib import pyplot as plt


In [5]:
dataset = 'MNIST'
root = '/hdd/hdd4/lsj/torchvision_dataset/MNIST'
log_dir = '/hdd/hdd4/lsj/teach_augment'

In [90]:
base_aug, train_trans, val_trans, normalizer = augmentation.get_transforms(dataset)
train_data, eval_data, n_classes = build_dataset(dataset, root, train_trans, val_trans)

In [91]:
subset_indices = [0, 1, 2, 3, 4, 5, 7, 13, 15, 17]
train_data = torch.utils.data.Subset(train_data, subset_indices)
sampler = None
train_loader = DataLoader(train_data, 10,
                          num_workers=8, pin_memory=True,
                        #   drop_last=True
                          )

In [92]:
eval_loader = DataLoader(eval_data, 1)
n_channel = 1
n_class = 10

In [93]:
rbuffer = augmentation.replay_buffer.ReplayBuffer(0.9)

g_offset = 0.5
g_scale = 0.5
g_scale_unlimited = False

c_scale = 0.8
c_scale_unlimited = False
c_shift_unlimited = False

c_reg_coef = 10
device = 'cuda'

trainable_aug = augmentation.build_augmentation(n_class, n_channel,
                                                g_offset, g_scale, g_scale_unlimited,
                                                c_scale, c_scale_unlimited, c_shift_unlimited,
                                                c_reg_coef, normalizer, rbuffer,
                                                1,
                                                True).to(device)


------------------- self.g_scale_unlimited: False -------------------
------------------- self.c_scale_unlimited: False -------------------
------------------- self.c_shift_unlimited: False -------------------


In [96]:
trainable_aug

AugmentationContainer(
  (c_aug): ColorAugmentation(
    (context_layer): Conv2d(10, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (color_enc1): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (color_enc_body): Sequential(
      (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Dropout2d(p=0.8, inplace=False)
      (3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (5): LeakyReLU(negative_slope=0.2, inplace=True)
      (6): Dropout2d(p=0.8, inplace=False)
    )
    (c_regress): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (noise_enc): Sequential(
      (0): Linear(in_features=138, out_features=512, bias=False)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU

In [94]:
base_aug = torch.nn.Sequential(*base_aug).to(device)

In [95]:
trainable_aug.eval()

for i, data in enumerate(train_loader):
    inputs, targets = data
    inputs, targets = inputs.to(device), targets.to(device)
    break

context = targets

In [84]:
c_param, g_param, A = trainable_aug.get_params(inputs, context)

In [85]:
A_mean = torch.mean(A.abs(), dim=0)
A_mean

tensor([[1., 0., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<MeanBackward1>)

In [87]:
scale = c_param[0]
shift = c_param[1]

print(scale.mean())
print(shift.mean())

tensor(1.0000, device='cuda:0', grad_fn=<MeanBackward0>)
tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)


In [68]:
trainable_aug.c_aug

ColorAugmentation(
  (context_layer): Conv2d(10, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (color_enc1): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (color_enc_body): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout2d(p=0.8, inplace=False)
    (3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (5): LeakyReLU(negative_slope=0.2, inplace=True)
    (6): Dropout2d(p=0.8, inplace=False)
  )
  (c_regress): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (noise_enc): Sequential(
    (0): Linear(in_features=138, out_features=512, bias=False)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout(p=0.8, inplace

In [69]:
trainable_aug

AugmentationContainer(
  (c_aug): ColorAugmentation(
    (context_layer): Conv2d(10, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (color_enc1): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (color_enc_body): Sequential(
      (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
      (2): Dropout2d(p=0.8, inplace=False)
      (3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (5): LeakyReLU(negative_slope=0.2, inplace=True)
      (6): Dropout2d(p=0.8, inplace=False)
    )
    (c_regress): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (noise_enc): Sequential(
      (0): Linear(in_features=138, out_features=512, bias=False)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU

In [70]:
for m in trainable_aug.modules():
    if isinstance(m, nn.Conv2d):
        m.weight.data.fill_(0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.constant_(m.weight, 0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [71]:
c_param, g_param, A = trainable_aug.get_params(inputs, context)
A_mean = torch.mean(A.abs(), dim=0)
A_mean

tensor([[1., 0., 0.],
        [0., 1., 0.]], device='cuda:0', grad_fn=<MeanBackward1>)

In [72]:
list(trainable_aug.c_aug.named_parameters())

[('logits',
  Parameter containing:
  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
         requires_grad=True)),
 ('context_layer.weight',
  Parameter containing:
  tensor([[[[0.]],
  
           [[0.]],
  
           [[0.]],
  
           ...,
  
           [[0.]],
  
           [[0.]],
  
           [[0.]]],
  
  
          [[[0.]],
  
           [[0.]],
  
           [[0.]],
  
           ...,
  
           [[0.]],
  
           [[0.]],
  
           [[0.]]],
  
  
          [[[0.]],
  
           [[0.]],
  
           [[0.]],
  
           ...,
  
           [[0.]],
  
           [[0.]],
  
           [[0.]]],
  
  
          ...,
  
  
          [[[0.]],
  
           [[0.]],
  
           [[0.]],
  
           ...,
  
           [[0.]],
  
           [[0.]],
  
           [[0.]]],
  
  
          [[[0.]],
  
           [[0.]],
  
           [[0.]],
  
           ...,
  
           [[0.]],
  
           [[0.]],
  
           [[0.]]],
  
  
          [[[0.]],


In [73]:
trainable_aug.g_aug.body

Sequential(
  (0): Linear(in_features=138, out_features=512, bias=False)
  (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (2): LeakyReLU(negative_slope=0.2, inplace=True)
  (3): Dropout(p=0.8, inplace=False)
  (4): Linear(in_features=512, out_features=512, bias=False)
  (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (6): LeakyReLU(negative_slope=0.2, inplace=True)
  (7): Dropout(p=0.8, inplace=False)
)