# SGD playground

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from info_nce import InfoNCE
from losses.loss import sup_con_loss, sup_con_loss_no_norm, Supervised_NT_xent_n, Supervised_NT_xent_uni

In [2]:
def CrossEntropyDistill(outputs, targets, temp=3.0):
    """Soft-target cross entropy between two sets of logits.

    Both inputs are divided by ``temp`` before being turned into
    distributions along ``dim=1``: ``targets`` through a softmax and
    ``outputs`` through a log-softmax.

    Args:
        outputs: 2-D tensor of logits; one row per sample.
        targets: tensor of logits with the same shape as ``outputs``.
        temp: temperature dividing both sets of logits.

    Returns:
        Scalar tensor: the per-row cross entropy averaged over the rows.
    """
    soft_targets = F.softmax(targets / temp, dim=1)
    log_probs = F.log_softmax(outputs / temp, dim=1)
    per_sample = -(log_probs * soft_targets).sum(dim=1)
    return per_sample.mean()

def Supervised_NT_xent_pre(sim_matrix, labels, temperature=0.5, chunk=2, eps=1e-8, multi_gpu=False):
    """Supervised NT-Xent loss computed from a precomputed similarity matrix.

    Each row of ``sim_matrix`` is shifted by its own maximum, divided by
    ``temperature`` and exponentiated, then normalised by the row sum to
    give a softmax-like probability. The negative log of that probability
    is averaged over the entries whose row and column labels are equal.
    The diagonal is *not* masked out.

    Args:
        sim_matrix: square tensor of pairwise similarities, shape (B', B').
        labels: tensor with one label per row of ``sim_matrix``; it is
            flattened to a column before comparison.
        temperature: divisor applied to the shifted similarities.
        chunk: only used to derive ``B = B' // chunk``.
        eps: small constant added to the row sum, inside the log, and to
            the positive-count denominator to avoid division by zero and
            ``log(0)``.
        multi_gpu: accepted but not used by this function.

    Returns:
        Scalar tensor: the summed, mask-weighted loss divided by ``2 * B``.
        Note the normaliser uses a literal 2, not ``chunk``, so it equals
        the number of rows only when ``chunk == 2`` and B' is even.
    """
    target_device = sim_matrix.device
    n_groups = sim_matrix.size(0) // chunk  # B = B' / chunk

    # Subtract the (detached) row maximum so the largest exponent is 0.
    row_max = sim_matrix.max(dim=1, keepdim=True)[0]
    shifted = sim_matrix - row_max.detach()

    # Row-wise softmax written out by hand, with eps guarding the division
    # and the log; entries that underflow to 0 are clamped at -log(eps).
    exp_sim = torch.exp(shifted / temperature)
    row_total = exp_sim.sum(dim=1, keepdim=True)
    neg_log_prob = -torch.log(exp_sim / (row_total + eps) + eps)

    # Positive-pair mask: 1 where labels match, then normalised per row so
    # each row's positives carry a total weight of (approximately) 1.
    label_col = labels.contiguous().view(-1, 1)
    same_label = torch.eq(label_col, label_col.t()).float().to(target_device)
    weights = same_label / (same_label.sum(dim=1, keepdim=True) + eps)

    return torch.sum(weights * neg_log_prob) / (2 * n_groups)


# Seed the global RNG so the random inputs, and therefore every loss printed
# below, are reproducible after Restart Kernel -> Run All.
torch.manual_seed(0)

# Two random batches of 4 feature vectors with 100 dimensions each.
a = torch.randn(4, 100)
b = torch.randn(4, 100)

# One distinct label per row; the same labels are used for both batches.
# Loop-invariant, so it is built once instead of on every iteration.
y = torch.arange(a.size(0))

# Compare several contrastive / distillation losses over a range of temperatures.
for temp in [0.07, 0.1, 1.0, 2.0, 3.0]:

    infonce = InfoNCE(temperature=temp)
    ce_loss = CrossEntropyDistill(a, b, temp)
    # NOTE(review): no temperature is passed here, so CrossEntropyDistill uses
    # its default temp=3.0 and this value is identical for every `temp` in the
    # sweep -- confirm that is intended.
    ce_loss_norm = CrossEntropyDistill(F.normalize(a), F.normalize(b))
    infonce_loss = infonce(a, b)

    # Supervised contrastive losses on the stacked (a, b) batch with stacked labels.
    sup_con = sup_con_loss(torch.cat((a, b)), temp, torch.cat((y, y)))
    sup_con_no_norm = sup_con_loss_no_norm(torch.cat((a, b)), temp, torch.cat((y, y)))

    # (4, 4) cross-batch dot products of the un-normalised features.
    sim_matrix = torch.matmul(a, b.t())
    nt_xent_pre = Supervised_NT_xent_pre(sim_matrix, y, temp)

    print(f"\ntemp: {temp}")
    print(f"ce_loss: {ce_loss:.6f}")
    print(f"ce_loss_norm: {ce_loss_norm:.6f}")
    print(f"infonce_loss: {infonce_loss:.6f}")
    print(f"sup_con: {sup_con:.6f}")
    print(f"sup_con_no_norm: {sup_con_no_norm:.6f}")
    print(f"nt_xent_pre: {nt_xent_pre:.6f}")



temp: 0.07
ce_loss: 30.826508
ce_loss_norm: 4.605613
infonce_loss: 0.745558
sup_con: 1.485533
sup_con_no_norm: nan
nt_xent_pre: 4.605170

temp: 0.1
ce_loss: 22.824261
ce_loss_norm: 4.605613
infonce_loss: 0.851042
sup_con: 1.493399
sup_con_no_norm: nan
nt_xent_pre: 4.605170

temp: 1.0
ce_loss: 5.051229
ce_loss_norm: 4.605613
infonce_loss: 1.310059
sup_con: 1.860430
sup_con_no_norm: 5.485090
nt_xent_pre: 1.960001

temp: 2.0
ce_loss: 4.716408
ce_loss_norm: 4.605613
infonce_loss: 1.347467
sup_con: 1.901785
sup_con_no_norm: 2.880279
nt_xent_pre: 1.006603

temp: 3.0
ce_loss: 4.654295
ce_loss_norm: 4.605613
infonce_loss: 1.360251
sup_con: 1.916183
sup_con_no_norm: 2.098509
nt_xent_pre: 0.753990


In [3]:
# Repeat of the temperature sweep. This cell creates no inputs of its own: it
# relies on `a`, `b` and the loss helpers defined in earlier cells.
y = torch.arange(4)

for temp in (0.07, 0.1, 1.0, 2.0, 3.0):
    # Soft-target cross entropy on raw features (at `temp`) and on
    # L2-normalised features (at the helper's default temperature).
    ce_loss = CrossEntropyDistill(a, b, temp)
    ce_loss_norm = CrossEntropyDistill(F.normalize(a), F.normalize(b))

    infonce = InfoNCE(temperature=temp)
    infonce_loss = infonce(a, b)

    # Supervised contrastive losses on the stacked batch with stacked labels.
    sup_con = sup_con_loss(torch.cat((a, b)), temp, torch.cat((y, y)))
    sup_con_no_norm = sup_con_loss_no_norm(torch.cat((a, b)), temp, torch.cat((y, y)))

    # Cross-batch dot-product similarities fed to the NT-Xent variant.
    sim_matrix = a @ b.t()
    nt_xent_pre = Supervised_NT_xent_pre(sim_matrix, y, temp)

    print(f"\ntemp: {temp}")
    print(f"ce_loss: {ce_loss:.6f}")
    print(f"ce_loss_norm: {ce_loss_norm:.6f}")
    print(f"infonce_loss: {infonce_loss:.6f}")
    print(f"sup_con: {sup_con:.6f}")
    print(f"sup_con_no_norm: {sup_con_no_norm:.6f}")
    print(f"nt_xent_pre: {nt_xent_pre:.6f}")



temp: 0.07
ce_loss: 30.826508
ce_loss_norm: 4.605613
infonce_loss: 0.745558
sup_con: 1.485533
sup_con_no_norm: nan
nt_xent_pre: 4.605170

temp: 0.1
ce_loss: 22.824261
ce_loss_norm: 4.605613
infonce_loss: 0.851042
sup_con: 1.493399
sup_con_no_norm: nan
nt_xent_pre: 4.605170

temp: 1.0
ce_loss: 5.051229
ce_loss_norm: 4.605613
infonce_loss: 1.310059
sup_con: 1.860430
sup_con_no_norm: 5.485090
nt_xent_pre: 1.960001

temp: 2.0
ce_loss: 4.716408
ce_loss_norm: 4.605613
infonce_loss: 1.347467
sup_con: 1.901785
sup_con_no_norm: 2.880279
nt_xent_pre: 1.006603

temp: 3.0
ce_loss: 4.654295
ce_loss_norm: 4.605613
infonce_loss: 1.360251
sup_con: 1.916183
sup_con_no_norm: 2.098509
nt_xent_pre: 0.753990


In [3]:
from utils.rotation_transform import Rotation, rot_inner_all

# Shape check for `Rotation` on a random batch of four 3x32x32 tensors; the
# recorded output shows the leading dimension going from 4 to 64.
x = torch.randn(4, 3, 32, 32)
rot_x = Rotation(x)
print(x.shape, rot_x.shape)

torch.Size([4, 3, 32, 32]) torch.Size([64, 3, 32, 32])


In [4]:
# Shape check for `rot_inner_all` (imported in the preceding cell) on a fresh
# random batch; the recorded output shows the leading dimension going 4 -> 16.
x = torch.randn(4, 3, 32, 32)
rot_x = rot_inner_all(x)
print(x.shape, rot_x.shape)

torch.Size([4, 3, 32, 32]) torch.Size([16, 3, 32, 32])


In [2]:
def GlobalRotation(x):
    """Stack a batch with three rotated copies of itself along dim 0.

    Rotations are applied in the plane of dims (2, 3), so ``x`` needs at
    least four dimensions. The output order is: the original batch, then
    the batch rotated by 2, 1 and 3 quarter turns (``torch.rot90`` with
    ``k`` = 2, 1, 3).

    Args:
        x: input tensor; dims 2 and 3 are the ones rotated.

    Returns:
        Tensor whose first dimension is four times that of ``x``.
    """
    quarter_turns = (2, 1, 3)
    rotated = [torch.rot90(x, k, (2, 3)) for k in quarter_turns]
    return torch.cat([x, *rotated], dim=0)

# Shape check: GlobalRotation should quadruple the batch dimension (4 -> 16).
x = torch.randn(4, 3, 32, 32)
y = GlobalRotation(x)
print(y.shape)

torch.Size([16, 3, 32, 32])
