In [93]:
from __future__ import print_function
import sys
from syslog import LOG_MAIL

import os
from urllib3 import Retry
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import random
import argparse
import numpy as np

from math import log2
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

from IPython.display import clear_output
import ipywidgets as widgets
from ipywidgets import interact, interact_manual, FloatSlider, IntSlider
from IPython.display import Image
import torch
%matplotlib inline

# UNICON Sharpening

In [94]:
def sharpening(labels, T):
    labels_sp = labels**(1/T)
    return labels_sp / labels_sp.sum(dim=0, keepdim=True)

# DINO Sharpening

In [95]:
def sharpening_DINO(labels, T):
    labels_sp = labels / T
    m = nn.Softmax(dim=0)
    return m(labels_sp)

# Visualization Sharpening

In [98]:
@interact
def visDensity(
    UNICON_T = FloatSlider(min=0., max=5., value=0.5, step=0.01),
    DINO_T = FloatSlider(min=0., max=5., value=0.06, step=0.01),
    inputX_0 = FloatSlider(min=0., max=1., value=0.4, step=0.01),
    inputX_1 = FloatSlider(min=0., max=1., value=0.2, step=0.01),
    inputX_2 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
    inputX_3 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
    inputX_4 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
    inputX_5 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
    inputX_6 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
    inputX_7 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
    inputX_8 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
    inputX_9 = FloatSlider(min=0., max=1., value=0.05, step=0.01),
): 
    num_class = 10
    inputX = [inputX_0, inputX_1, inputX_2,inputX_3,inputX_4,inputX_5,inputX_6,inputX_7,inputX_8,inputX_9]
    inputX = torch.FloatTensor(inputX)
    inputX_sp_UNICON = sharpening(inputX, UNICON_T)
    inputX_sp_DINO = sharpening_DINO(inputX, DINO_T)
    plt.figure(figsize=(24,6))
    plt.tight_layout()
    plt.subplot(131)
    plt.ylim(0,1)
    bars = plt.bar(range(0, 10), inputX.numpy())
    plt.bar_label(bars, fmt='%.2f')
    plt.title("Origin")
    plt.xticks(range(0, 10))
    plt.grid(True)
    plt.subplot(132)
    plt.ylim(0,1)
    bars = plt.bar(range(0, 10), inputX_sp_UNICON.numpy())
    plt.bar_label(bars, fmt='%.2f')
    plt.grid(True)
    plt.title("UNICON : " + str(UNICON_T))
    plt.xticks(range(0, 10))
    plt.subplot(133)
    plt.ylim(0,1)
    bars = plt.bar(range(0, 10), inputX_sp_DINO.numpy())
    plt.bar_label(bars, fmt='%.2f')
    plt.title("DINO T: " + str(DINO_T))
    plt.xticks(range(0, 10))
    plt.grid(True)
    plt.show()
    return 

interactive(children=(FloatSlider(value=0.5, description='UNICON_T', max=5.0, step=0.01), FloatSlider(value=0.…

備忘錄：
1. 預設參數
2. 相同強度 Tps 0.1
3. asym的情況 0.55 0.45 
4. 

In [83]:
a = torch.tensor([[ 2.5849e-02,  1.3727e-01, -2.9217e-03, -5.0169e-02,  4.0312e-02,
          9.4924e-02, -1.7414e-02,  1.1302e-01,  3.8955e-02, -1.1567e-01,
         -4.2032e-02,  7.0781e-03, -7.0436e-03, -4.3275e-02, -6.1347e-02,
         -1.4306e-02, -5.9328e-02, -2.7418e-02, -1.0878e-01, -9.2970e-02,
         -2.3346e-02,  1.2615e-01, -4.5262e-02, -1.8080e-03, -4.1264e-02,
          5.5995e-02, -9.7353e-02,  1.0617e-02,  7.0020e-02,  9.7239e-02,
         -3.6515e-02,  5.9807e-04,  1.3264e-01,  5.0183e-02,  1.1850e-01,
          6.3413e-02, -2.6445e-02,  1.7152e-02, -8.6439e-02, -5.8670e-02,
          4.3982e-02, -1.3629e-01, -4.0352e-02,  1.3771e-01,  1.1423e-04,
         -9.1020e-02, -1.2737e-01,  2.9980e-02, -1.2439e-01,  2.8136e-02,
          3.7607e-02, -4.0858e-02,  6.5722e-02, -8.3969e-02, -2.7378e-02,
          1.3255e-01, -7.9063e-02,  5.4204e-02, -2.0350e-02,  1.2682e-02,
          1.2240e-01,  4.2547e-02, -1.6056e-02,  1.2655e-01,  3.1946e-02,
          1.0046e-01,  4.6876e-03, -3.8074e-02, -1.0052e-01, -9.1502e-02,
          5.9118e-02, -4.8431e-02, -1.6149e-01, -3.0141e-02,  1.8592e-01,
          4.0732e-02, -1.9329e-01, -6.4903e-02, -2.9889e-03, -1.1932e-01,
         -1.0286e-01, -1.2362e-02, -1.7452e-01, -7.5198e-02, -1.3520e-02,
          1.1519e-01, -1.0938e-01, -3.7216e-02, -8.1227e-02,  1.0517e-01,
         -1.9429e-01, -1.5507e-01,  1.0705e-01, -1.6433e-02, -7.6020e-02,
          9.6546e-02, -6.9244e-02, -6.9090e-02, -1.1710e-01,  1.1706e-01,
         -1.9317e-01, -7.7380e-02,  4.9766e-03, -6.3141e-02,  5.2052e-02,
          1.3147e-02, -7.2976e-02, -2.6641e-02, -5.5907e-02, -9.4763e-02,
         -5.5690e-03,  2.9907e-02, -5.4849e-02, -1.2614e-01,  1.2306e-02,
          2.2610e-01,  1.9708e-02,  7.9473e-02,  2.9587e-02, -1.4246e-02,
         -9.6205e-02, -2.8490e-01, -9.1136e-02, -3.3081e-02, -1.3401e-01,
         -5.0882e-02, -3.1863e-02, -9.1243e-02]])


b = torch.tensor([[ 6.4785e-02,  1.2803e-01, -3.6303e-02, -7.7526e-02,  5.5587e-02,                                      
          1.0225e-01,  9.0887e-02,  2.4126e-03,  4.9468e-02, -1.1827e-01,    
          5.2309e-02, -5.0036e-02,  1.5834e-02,  1.5710e-02, -2.4707e-01,                                                                                                                                                     
         -4.3604e-02, -8.6595e-02, -4.1258e-02, -7.4943e-02,  2.3530e-02,
         -6.2038e-02, -1.0780e-01, -5.5874e-02, -1.5814e-01, -1.2427e-01,                                                                                                                                                     
          4.4659e-02,  1.3200e-02,  4.2071e-02,  8.1512e-02, -6.1901e-02,  
          7.5295e-02, -5.0164e-02, -9.2014e-02, -1.3430e-01,  6.6311e-02,                                 
          8.0811e-02, -1.2673e-01,  5.7392e-02, -3.2054e-02, -1.3951e-01,             
          1.0002e-01,  1.7893e-02,  1.2977e-01,  8.8536e-02,  6.5056e-02,       
          8.0566e-03,  9.0353e-02, -1.1336e-01, -1.0186e-01,  1.0732e-01,    
          1.7694e-01, -7.9595e-02, -4.0497e-02, -4.3559e-02,  6.4190e-02,
          1.3910e-01, -1.1184e-01,  2.8591e-02,  1.6668e-01, -1.0529e-02,                                
         -1.7412e-02,  1.0820e-01, -6.1826e-02,  1.9421e-01, -1.1821e-01,
          4.7045e-02, -1.2279e-01,  6.8970e-02, -9.7729e-02,  3.6076e-02,     
          2.7115e-04, -1.5079e-03, -1.1539e-01, -1.4463e-01,  1.7668e-01,  
          9.5951e-02, -1.1452e-01,  7.0626e-02, -3.2202e-02,  6.2415e-03,                                
         -6.2763e-02,  1.7882e-02,  1.2298e-01,  2.4615e-02,  2.2870e-02,
          2.8605e-02,  3.1279e-02,  2.9901e-03, -5.9874e-02,  7.9426e-02,      
         -6.2542e-02, -6.5372e-02,  3.9405e-02, -2.1939e-01,  7.0284e-02,
          4.8612e-02, -7.4576e-02,  1.2524e-02, -2.2450e-02,  3.0984e-02,                                
         -1.8122e-01,  8.1006e-02,  7.2468e-02, -8.8894e-02,  3.8744e-02,
         -4.2511e-02,  4.2579e-02,  8.0874e-03,  9.8572e-02, -8.5389e-02,                        
         -1.3015e-01, -3.9547e-02,  7.7807e-02,  7.8035e-02, -4.0452e-02,                             
         -1.7186e-02,  1.0481e-02,  1.1151e-01,  7.2522e-02,  1.4350e-01,                      
         -1.4862e-01,  2.0109e-02, -1.5078e-01,  8.7314e-02, -3.0347e-02,
          3.5852e-02, -1.1390e-03, -5.6330e-02]])

In [70]:
def entropy(p):
    print(p)
    log_p = (p + 1e-10).log2()
    print(log_p)
    return -p.mul((p + 1e-10).log2()).sum(dim=1)

In [77]:
m = nn.GELU()
a1 = m(a)
a1
# a2 = sharpening_DINO(a, 0.06)
# a2
# entropy(a2)

tensor([[ 1.3191e-02,  7.6129e-02, -1.4574e-03, -2.4081e-02,  2.0804e-02,
          5.1051e-02, -8.5860e-03,  6.1595e-02,  2.0083e-02, -5.2509e-02,
         -2.0311e-02,  3.5590e-03, -3.5020e-03, -2.0891e-02, -2.9173e-02,
         -7.0714e-03, -2.8261e-02, -1.3409e-02, -4.9679e-02, -4.3042e-02,
         -1.1456e-02,  6.9407e-02, -2.1814e-02, -9.0270e-04, -1.9953e-02,
          2.9248e-02, -4.4901e-02,  5.3535e-03,  3.6964e-02,  5.2386e-02,
         -1.7726e-02,  2.9918e-04,  7.3318e-02,  2.6096e-02,  6.4839e-02,
          3.3310e-02, -1.2944e-02,  8.6934e-03, -4.0242e-02, -2.7963e-02,
          2.2762e-02, -6.0758e-02, -1.9527e-02,  7.6397e-02,  5.7120e-05,
         -4.2209e-02, -5.7230e-02,  1.5349e-02, -5.6038e-02,  1.4384e-02,
          1.9368e-02, -1.9763e-02,  3.4583e-02, -3.9175e-02, -1.3390e-02,
          7.3264e-02, -3.7040e-02,  2.8274e-02, -1.0010e-02,  6.4052e-03,
          6.7162e-02,  2.1995e-02, -7.9252e-03,  6.9647e-02,  1.6380e-02,
          5.4249e-02,  2.3526e-03, -1.

In [73]:
m = nn.GELU()
i = torch.randn(2)
m(i)

tensor([ 0.2810, -0.1211])

In [91]:
def kl_divergence(p, q):
    print((p+1e-10)/(q+1e-10))
    return (p * ((p+1e-10) / (q+1e-10)).log()).sum(dim=1)

In [92]:
kl_divergence(a,b)

tensor([[ 3.9900e-01,  1.0722e+00,  8.0481e-02,  6.4712e-01,  7.2521e-01,
          9.2835e-01, -1.9160e-01,  4.6846e+01,  7.8748e-01,  9.7802e-01,
         -8.0353e-01, -1.4146e-01, -4.4484e-01, -2.7546e+00,  2.4830e-01,
          3.2809e-01,  6.8512e-01,  6.6455e-01,  1.4515e+00, -3.9511e+00,
          3.7632e-01, -1.1702e+00,  8.1007e-01,  1.1433e-02,  3.3205e-01,
          1.2538e+00, -7.3752e+00,  2.5236e-01,  8.5901e-01, -1.5709e+00,
         -4.8496e-01, -1.1922e-02, -1.4415e+00, -3.7366e-01,  1.7870e+00,
          7.8471e-01,  2.0867e-01,  2.9886e-01,  2.6967e+00,  4.2054e-01,
          4.3973e-01, -7.6169e+00, -3.1095e-01,  1.5554e+00,  1.7559e-03,
         -1.1298e+01, -1.4097e+00, -2.6447e-01,  1.2212e+00,  2.6217e-01,
          2.1254e-01,  5.1332e-01, -1.6229e+00,  1.9277e+00, -4.2652e-01,
          9.5291e-01,  7.0693e-01,  1.8958e+00, -1.2209e-01, -1.2045e+00,
         -7.0296e+00,  3.9323e-01,  2.5970e-01,  6.5161e-01, -2.7025e-01,
          2.1354e+00, -3.8176e-02, -5.

tensor([nan])