<a href="https://colab.research.google.com/github/YianKim/2022_uncertainty_aware_semisupervise/blob/main/Keras_UPS_SVHN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install tensorflow_addons

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow_addons
  Downloading tensorflow_addons-0.17.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 8.4 MB/s 
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.17.1


In [None]:
import matplotlib.pyplot as plt
from tensorflow import keras
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import clone_model

import PIL
from PIL import Image

import pickle
import random
from tqdm import tqdm
from collections import Counter

from keras.layers.core import Lambda
from keras import backend as K

from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import BatchNormalization
from keras.regularizers import l2
from keras.layers import Activation
from keras.layers import Dropout
from keras.layers import MaxPooling2D, AveragePooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.layers import Reshape
from keras import optimizers
from keras.callbacks import *
from sklearn.metrics import *
from keras.models import load_model
import tensorflow_addons as tfa

from torchvision import transforms

import tensorflow as tf
import tensorflow.keras.backend as backend
import math
import gc

In [None]:
def dummy_labels(labels, num_classes=10):
  """One-hot encode integer class labels.

  Args:
    labels: array-like of 0-based integer class ids with one id per row,
      e.g. shape (n,) or (n, 1).
    num_classes: width of the encoding (10 for the SVHN digits).

  Returns:
    int8 array of shape (n, num_classes) holding a single 1 per row.
  """
  labels = np.asarray(labels)
  # Row k of the identity matrix is the one-hot vector for class k.
  return np.eye(num_classes, dtype=np.int8)[labels.reshape(-1)]

# SVHN

In [None]:
# Mount Google Drive: the SVHN .mat files are read from MyDrive/SVHN in the next cell.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
from scipy.io import loadmat

# Each .mat file loads as a dict; its 'X' (images) and 'y' (labels) entries are used below.
train_raw, test_raw = (
    loadmat('/content/drive/MyDrive/SVHN/train_32x32.mat'),
    loadmat('/content/drive/MyDrive/SVHN/test_32x32.mat'),
)

In [None]:
# The raw 'X' arrays keep the sample index on their last axis; moving it to the
# front gives (N, H, W, C), the layout the model's keras.Input(shape=(32, 32, 3)) expects.
train_images = train_raw['X'].transpose(3, 0, 1, 2)
test_images = test_raw['X'].transpose(3, 0, 1, 2)

# Train labels stay raw (1-based; they are split into labelled/unlabelled below),
# test labels are shifted to 0-based and one-hot encoded right away.
train_labels = train_raw['y']
test_labels = dummy_labels(test_raw['y'] - 1)

In [None]:
# Class-balanced labelled subset: walk the training set in file order and keep
# the first LABELS_PER_CLASS examples of every class as labelled; every other
# example goes to the unlabelled pool.
LABELS_PER_CLASS = 25
flat_train_labels = train_labels.reshape([-1])

temp = [0] * 10  # labelled examples taken so far, per class
label_indx = []
unlabel_indx = []

# Cover every training example, whatever the size of the loaded split.
for i in range(len(flat_train_labels)):
  class_idx = flat_train_labels[i] - 1  # stored labels are 1-based
  if temp[class_idx] < LABELS_PER_CLASS:
    temp[class_idx] += 1
    label_indx.append(i)
  else:
    unlabel_indx.append(i)

In [None]:
# Labelled subset: images plus one-hot targets (raw labels are 1-based, hence the -1).
lbl_train_images, lbl_train_labels = (
    np.take(train_images, label_indx, axis=0),
    dummy_labels(np.take(train_labels, label_indx, axis=0) - 1),
)

In [None]:
# Unlabelled pool; its true labels are kept one-hot encoded alongside.
ubl_train_images, ubl_train_labels = (
    np.take(train_images, unlabel_indx, axis=0),
    dummy_labels(np.take(train_labels, unlabel_indx, axis=0) - 1),
)

# Pseudo-labeling

In [None]:
def basic_augmentation(imagearray):
  """Return a randomly augmented copy of one image array.

  The image goes through torchvision's ColorJitter (brightness, contrast,
  saturation and hue strength 0.5 each) and then RandomRotation(10), which
  rotates by an angle drawn from [-10, 10] degrees. The randomness comes from
  the transforms themselves, so every call gives a different result.
  """
  jitter = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
  rotate = transforms.RandomRotation(10)
  pil_image = Image.fromarray(imagearray)
  return np.array(rotate(jitter(pil_image)))

def makeaugs(n, input):
  """Return `n` independently augmented copies of every image in `input`, as one array.

  Output order is pass-major: all images of pass 1 (in input order), then all
  images of pass 2, and so on. Labels for the result are therefore the input
  labels tiled `n` times. Each image is cast to uint8 before augmentation.
  """
  return np.array([
      basic_augmentation(np.array(img, np.uint8))
      for _pass in range(n)
      for img in input
  ])

### Learning-rate scheduler (SGDR)

In [None]:
class SGDR(Callback):
    """Learning-rate callback: linear warm-up, then cosine annealing with warm restarts.

    Schedule, in epochs:
      * cycle 0 lasts `base_epochs` epochs. With i of its epochs completed, the
        rate is the linear interpolation min_lr + i / base_epochs * (max_lr - min_lr).
      * cycle c >= 1 lasts `base_epochs * mul_epochs**c` epochs. It starts at
        `max_lr` and follows a half cosine that reaches `min_lr` on its last epoch.

    The rate in effect during each epoch is recorded in `logs['lr']`.

    NOTE(review): `on_train_begin` sets `max_lr` whenever `cycle_iterations`
    is 0. The first epoch of cycle 0 therefore runs at `max_lr`, although the
    warm-up formula would give `min_lr` there. Confirm this is intended.
    NOTE(review): this overrides the learning rate the optimizer was compiled with.
    """

    def __init__(self, min_lr=0.0, max_lr=0.03, base_epochs=20, mul_epochs=2):
        super(SGDR, self).__init__()

        self.min_lr = min_lr
        self.max_lr = max_lr
        self.base_epochs = base_epochs
        self.mul_epochs = mul_epochs

        self.cycles = 0.            # index of the current cycle
        self.cycle_iterations = 0.  # epochs completed within the current cycle
        self.trn_iterations = 0.    # epochs completed overall (not touched by _reset)

        self._reset()

    def _reset(self, new_min_lr=None, new_max_lr=None,
               new_base_epochs=None, new_mul_epochs=None):
        """Optionally override hyper-parameters, then restart the schedule at cycle 0."""
        if new_min_lr is not None:
            self.min_lr = new_min_lr
        if new_max_lr is not None:
            self.max_lr = new_max_lr
        if new_base_epochs is not None:
            self.base_epochs = new_base_epochs
        if new_mul_epochs is not None:
            self.mul_epochs = new_mul_epochs
        self.cycles = 0.
        self.cycle_iterations = 0.

    def _set_lr(self, value):
        # `learning_rate` is the canonical optimizer attribute; `lr` is a deprecated alias.
        K.set_value(self.model.optimizer.learning_rate, value)

    def sgdr(self):
        """Return the learning rate for the current (`cycles`, `cycle_iterations`) position."""
        if self.cycles == 0:
            # Warm-up: linear from min_lr (0 epochs done) towards max_lr (base_epochs done).
            return (self.cycle_iterations * self.max_lr
                    + (self.base_epochs - self.cycle_iterations) * self.min_lr) / self.base_epochs
        cycle_epochs = self.base_epochs * (self.mul_epochs ** self.cycles)
        # Cosine decay; the +1 makes the last epoch of the cycle land exactly on min_lr.
        return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (
            1 + np.cos(np.pi * (self.cycle_iterations + 1) / cycle_epochs))

    def on_train_begin(self, logs=None):
        if self.cycle_iterations == 0:
            self._set_lr(self.max_lr)
        else:
            self._set_lr(self.sgdr())

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Record the rate that was used for the epoch that just finished.
        logs['lr'] = K.get_value(self.model.optimizer.learning_rate)

        self.trn_iterations += 1
        self.cycle_iterations += 1
        if self.cycle_iterations >= self.base_epochs * (self.mul_epochs ** self.cycles):
            # Cycle finished: warm restart at max_lr.
            self.cycles += 1
            self.cycle_iterations = 0
            self._set_lr(self.max_lr)
        else:
            self._set_lr(self.sgdr())

### Main: model definition, training, MC-dropout pseudo-labelling and label selection

In [None]:
def PermaDropout(rate):
    """Return a dropout layer that is applied on every call, inference included.

    K.dropout is applied unconditionally, with no training flag. Repeated
    predictions therefore differ, which is what the Monte-Carlo dropout
    sampling in this notebook relies on.
    """
    def _dropout_always(x):
        return K.dropout(x, level=rate)

    return Lambda(_dropout_always)

def create_cnn_13(num_classes=10, input_shape=(32, 32, 3), dropout_rate=0.4):
  """Build the trimmed CNN-13-style classifier used for pseudo-labelling.

  Architecture:
    * Three stages, each made of two weight-normalised 3x3 convolutions, each
      followed by BatchNorm and LeakyReLU(0.1).
    * Each stage ends in pooling and an always-on (PermaDropout) dropout layer.
    * A 128-unit hidden layer, then a linear output layer.

  Args:
    num_classes: number of output logits.
    input_shape: shape of one input image, (height, width, channels).
    dropout_rate: drop probability of the three PermaDropout layers.

  Returns:
    An uncompiled keras Sequential model that outputs raw logits (no softmax),
    so compile it with a from_logits=True loss.
  """
  def conv_block(filters, padding):
    # WeightNorm(conv) -> BN -> LeakyReLU, with a fresh activation layer per use.
    return [
        tfa.layers.WeightNormalization(Conv2D(filters, (3, 3), padding=padding)),
        BatchNormalization(),
        keras.layers.LeakyReLU(0.1),
    ]

  stack = [keras.Input(shape=input_shape)]
  stack += conv_block(32, 'same') + conv_block(32, 'same')
  stack += [MaxPooling2D(2, 2), PermaDropout(dropout_rate)]
  stack += conv_block(64, 'same') + conv_block(64, 'same')
  stack += [MaxPooling2D(2, 2), PermaDropout(dropout_rate)]
  # For 32x32 input the 'valid' convs shrink 8x8 -> 6x6 -> 4x4, and the
  # 4x4 average pool (stride 1) then collapses that to 1x1.
  stack += conv_block(128, 'valid') + conv_block(128, 'valid')
  stack += [AveragePooling2D(4, 1), PermaDropout(dropout_rate), Flatten()]
  stack += [Dense(128), keras.layers.LeakyReLU(0.1), Dense(num_classes)]
  return Sequential(stack)

def compile_cnn_13(model):
  """Compile `model` in place for one-hot targets and return it.

  Optimizer: Adam with learning rate 0.001. Loss: categorical cross-entropy
  computed from raw logits. Metric: accuracy.
  """
  model.compile(
      optimizer=keras.optimizers.Adam(0.001),
      loss=keras.losses.CategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'],
  )
  return model

def cnn_13():
  """Build a fresh classifier with create_cnn_13 and return it compiled by compile_cnn_13."""
  return compile_cnn_13(create_cnn_13())

def fit_and_labeling_cnn_13(Epoch, Batch, n_aug=5, mc_samples=10, T=1):
  """Train the global `model` on the labelled set and return MC-dropout predictions for the unlabelled set.

  NOTE(review): this reads the notebook globals `model`, `lbl_train_images`,
  `lbl_train_labels`, `ubl_train_images`, `test_images` and `test_labels`
  instead of taking them as arguments.

  Args:
    Epoch: number of training epochs.
    Batch: mini-batch size.
    n_aug: number of augmented copies of the labelled set (see makeaugs)
      appended to it; 0 trains on the original images only.
    mc_samples: number of stochastic forward passes over the unlabelled set
      (must be >= 1). PermaDropout stays active at prediction time, so the
      passes differ.
    T: softmax temperature applied to the logits.

  Returns:
    Array of shape (mc_samples, n_unlabelled, n_classes) of softmax probabilities.
  """
  X = lbl_train_images
  y = lbl_train_labels
  if n_aug > 0:
    # makeaugs emits n_aug full passes over X in order, so the labels are tiled the same way.
    X = np.concatenate((X, makeaugs(n_aug, X)))
    y = np.concatenate([y] * (n_aug + 1))

  # To enable the SGDR schedule, add callbacks=[SGDR(min_lr=0.0, max_lr=0.03, base_epochs=20)].
  model.fit(x=X, y=y, epochs=Epoch, verbose=0, batch_size=Batch)

  model_test_eval(model, test_images, test_labels)

  # Collect all MC samples, then stack once (no repeated concatenation).
  return np.stack([
      np.array(tf.nn.softmax(model.predict(ubl_train_images) / T))
      for _ in range(mc_samples)
  ])

def model_test_eval(model, test_images, test_labels, mc_samples=10, T=1):
  """Print and return the MC-dropout ensemble accuracy of `model` on a labelled test set.

  Softmax probabilities from `mc_samples` stochastic forward passes (must be
  >= 1) are summed. The argmax of the sum is compared with the argmax of the
  one-hot `test_labels`.

  Args:
    model: keras model returning logits.
    test_images: input images.
    test_labels: one-hot targets aligned with `test_images`.
    mc_samples: number of forward passes to ensemble.
    T: softmax temperature, applied to every pass (previously only the first).

  Returns:
    Accuracy as a float in [0, 1].
  """
  pred = sum(
      np.array(tf.nn.softmax(model.predict(test_images) / T))
      for _ in range(mc_samples)
  )
  acc = np.mean(np.argmax(pred, axis=1) == np.argmax(test_labels, axis=1))
  print("test set 성능 : " + str(acc))
  return acc

In [None]:
def label_selecting(K_conf=0.9, K_uncert=0.05, max_per_class=500, balance_until=20):
  """Select confident, low-uncertainty pseudo-labels and rebuild the train splits.

  NOTE(review): this reads the notebook globals `predictions` (MC softmax
  samples stacked on axis 0, as returned by fit_and_labeling_cnn_13), `itr`,
  `lbl_train_images`, `lbl_train_labels` and `ubl_train_images`.

  An unlabelled example is selected when two conditions both hold:
    * the mean predicted probability of its argmax class is above `K_conf`;
    * the standard deviation of that probability across MC samples is below
      `K_uncert`.

  While `itr < balance_until`, only a class-balanced random subset of the
  selected examples moves to the labelled set: min(max_per_class, smallest
  per-class selected count) examples per class. The rest return to the
  unlabelled pool. Otherwise all selected examples move.

  Returns:
    (image1, label1, image2): the new labelled images, their one-hot labels,
    and the remaining unlabelled images.
  """
  mean_pred = np.mean(predictions, axis=0)
  n_classes = mean_pred.shape[1]
  pseudo_cls = np.argmax(mean_pred, axis=1)
  conf = np.max(mean_pred, axis=1)
  # Standard deviation across MC samples of the probability given to each example's pseudo-label.
  uncert = np.std(predictions, axis=0)[np.arange(len(pseudo_cls)), pseudo_cls]

  selected = (conf > K_conf) & (uncert < K_uncert)
  lbl_idx = np.flatnonzero(selected)
  ubl_idx = np.flatnonzero(~selected)

  # One-hot pseudo-labels; stays 2-D even when the unlabelled pool is empty.
  pseudo = np.eye(n_classes, dtype=int)[pseudo_cls]

  ubl_append = ubl_train_images[lbl_idx]
  pseudo_append = pseudo[lbl_idx]

  if itr < balance_until:
    multlabel = pseudo_cls[lbl_idx]
    counts = list(Counter(multlabel).values())
    # NOTE(review): classes with no selected example are absent from `counts`,
    # so they do not pull this minimum down to 0. Confirm that is intended.
    numsamples = np.min(counts) if counts else 0
    per_class_cap = min(max_per_class, numsamples)

    idxcounter = [0] * n_classes
    idxsample = []
    # random.sample over the full range is a random permutation of the candidates.
    for i in random.sample(range(len(multlabel)), len(multlabel)):
      if idxcounter[multlabel[i]] < per_class_cap:
        idxcounter[multlabel[i]] += 1
        idxsample.append(i)

    not_taken = np.delete(np.arange(len(ubl_append)), idxsample)
    image1 = np.concatenate((lbl_train_images, ubl_append[idxsample]))
    label1 = np.concatenate((lbl_train_labels, pseudo_append[idxsample]))
    image2 = np.concatenate((ubl_train_images[ubl_idx], ubl_append[not_taken]))
  else:
    image1 = np.concatenate((lbl_train_images, ubl_append))
    label1 = np.concatenate((lbl_train_labels, pseudo_append))
    image2 = ubl_train_images[ubl_idx]

  return image1, label1, image2

In [None]:
import time

N_ROUNDS = 20             # pseudo-labelling rounds (label_selecting balances classes while itr < 20)
FIRST_ROUND_EPOCHS = 500  # first round trains on the small initial labelled set only
LATER_ROUND_EPOCHS = 50
BATCH_SIZE = 64
SEED = 0

# Seed the Python, NumPy and TensorFlow RNGs. torchvision augmentation draws
# from torch's RNG, which is not seeded here.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

start = time.time()
epoch = FIRST_ROUND_EPOCHS

# `model`, `predictions` and `itr` must stay globals:
# fit_and_labeling_cnn_13 and label_selecting read them.
for itr in range(N_ROUNDS):
  # Drop Keras global state from the previous round's model before building a new one.
  tf.keras.backend.clear_session()
  model = cnn_13()
  print(Counter(np.argmax(lbl_train_labels, axis=1)))
  predictions = fit_and_labeling_cnn_13(epoch, BATCH_SIZE)
  lbl_train_images, lbl_train_labels, ubl_train_images = label_selecting()
  del predictions
  epoch = LATER_ROUND_EPOCHS
  gc.collect()
  print("time :", time.time() - start)

print("time :", time.time() - start)

Counter({0: 25, 8: 25, 1: 25, 2: 25, 4: 25, 7: 25, 6: 25, 3: 25, 5: 25, 9: 25})
test set 성능 : 0.6868085433312846
time : 302.83457040786743
Counter({0: 525, 8: 525, 1: 525, 2: 525, 4: 525, 7: 525, 6: 525, 3: 525, 5: 525, 9: 525})
test set 성능 : 0.7584895513214506
time : 698.5731875896454
Counter({0: 1025, 8: 1025, 1: 1025, 2: 1025, 4: 1025, 7: 1025, 6: 1025, 3: 1025, 5: 1025, 9: 1025})
test set 성능 : 0.7780808236017209
time : 1377.8334851264954
Counter({0: 1525, 8: 1525, 1: 1525, 2: 1525, 4: 1525, 7: 1525, 6: 1525, 3: 1525, 5: 1525, 9: 1525})
test set 성능 : 0.7896051014136447
time : 2256.0211379528046
Counter({0: 2025, 8: 2025, 1: 2025, 2: 2025, 4: 2025, 7: 2025, 6: 2025, 3: 2025, 5: 2025, 9: 2025})
test set 성능 : 0.796020282728949
time : 3350.698147535324
Counter({0: 2525, 8: 2525, 1: 2525, 2: 2525, 4: 2525, 7: 2525, 6: 2525, 3: 2525, 5: 2525, 9: 2525})
test set 성능 : 0.7979025814382299
time : 4731.074002504349
Counter({0: 3025, 8: 3025, 1: 3025, 2: 3025, 4: 3025, 7: 3025, 6: 3025, 3: 3025,