In [1]:
from IPython.display import display, HTML
# Widen the notebook container to 95% of the browser window.
# The closing tag must be "</style>"; the previous "<\style>" left the <style> element unterminated.
display(HTML("<style>.container { width:95% !important; }</style>"))


import argparse
import os
import pandas
# Expose only GPUs 1, 2 and 3 to this process. This is set before `torch` is imported
# in the next cell. NOTE(review): with this setting 'cuda:0' presumably maps to
# physical GPU 1 -- confirm against the machine's layout.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'

import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import tqdm

In [2]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from collections import OrderedDict
from torch.nn.functional import one_hot as one_hot
import torch.utils.data as data
import utils
from model import * 


In [3]:
# Mini-batch size used by the DataLoaders and by every mask/reshape in the loss helpers.
batch_size = 64

# Dataset key handed to utils.get_dataset.
dataset_name = 'tiny-imagenet'

lr = 1e-3  # presumably the learning rate; not referenced in the visible cells
device = 'cuda:0'  # device onto which the loss helpers move their negative masks
tau_plus = 0.01  # NOTE(review): reassigned to 0.1 in cell In [5]

epochs = 300  # NOTE(review): reassigned to 400 in cell In [5]

In [4]:
# Fetch the three dataset splits, then wrap each one in a DataLoader.
train_data, memory_data, test_data = utils.get_dataset(dataset_name, root='../data')

# Options shared by all three loaders.
_loader_opts = dict(batch_size=batch_size, num_workers=12, pin_memory=True)

# Only the training loader shuffles. drop_last keeps every training batch at exactly
# `batch_size` rows, which the fixed-size masks in the loss helpers rely on.
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **_loader_opts)
memory_loader = DataLoader(memory_data, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_data, shuffle=False, **_loader_opts)

In [5]:
# Loss hyper-parameters. tau_plus and epochs override the values set in cell In [3].
tau_plus =0.1  # passed as `tau_plus` to triplet() and criterion()
beta =1  # passed to criterion(), which accepts it but never reads it
estimator = 'hard'  # not referenced in the visible cells
temperature =0.5  # read as a global by W() and orientation(); also passed to triplet()/criterion()
epoch = 20  # has no effect on any result computed in the visible cells
epochs = 400
k = 200  # not referenced in the visible cells
c = len(train_data.classes)  # number of entries in train_data.classes

In [6]:
def get_negative_mask(batch_size):
    """Build the boolean mask that selects negatives in a two-view batch.

    The batch is assumed to be laid out as ``[view_1; view_2]`` (``2 * batch_size``
    rows). For row ``r`` the mask is False at the columns holding the same sample
    (``r % batch_size`` and ``r % batch_size + batch_size``) and True elsewhere,
    so every row keeps exactly ``2 * batch_size - 2`` entries.

    Args:
        batch_size: number of samples per view.

    Returns:
        A CPU ``torch.bool`` tensor of shape ``(2 * batch_size, 2 * batch_size)``.
    """
    # One (bs, bs) "not the same sample" block, tiled 2 x 2. This replaces the
    # per-row Python loop with a single vectorised construction.
    not_same_sample = ~torch.eye(batch_size, dtype=torch.bool)
    return not_same_sample.repeat(2, 2)

def triplet(out_1,out_2,tau_plus,batch_size,temperature, debias = True):
    """Squared-distance contrast between the positive pair and the summed negatives.

    Args:
        out_1, out_2: embeddings of the two views, each ``(batch_size, feature_dim)``.
        tau_plus: weight of the positive-distance correction in the debiased
            negative term.
        batch_size: number of samples per view.
        temperature: accepted for signature compatibility; not used here.
        debias: if True, subtract ``tau_plus * N * pos`` from the summed negative
            distances and rescale by ``1 / (1 - tau_plus)``. If False, use the
            plain sum of negative distances (what the debiased formula gives
            for ``tau_plus = 0``).

    Returns:
        Scalar tensor: mean over the ``2 * batch_size`` anchors of ``pos - neg``.
    """
    N = batch_size * 2 - 2  # negatives per anchor
    out = torch.cat([out_1, out_2], dim=0)  # (2 * bs, fs)

    # Pairwise squared Euclidean distances, (2 * bs, 2 * bs).
    s = torch.pow(out.unsqueeze(0) - out.unsqueeze(1), 2).sum(-1)

    # Keep the mask on the same device as the embeddings instead of the global `device`.
    mask = get_negative_mask(batch_size).to(out.device)
    s = s.masked_select(mask).view(2 * batch_size, -1)  # (2 * bs, 2 * bs - 2): self and its augment removed

    # Squared distance between the two views of each sample, repeated for both halves.
    pos = torch.pow(out_1 - out_2, 2).sum(-1)
    pos = torch.cat([pos, pos], dim=0)  # (2 * bs,)

    neg_sum = s.sum(-1)  # (2 * bs,)
    if debias:
        neg = (-tau_plus * N * pos + neg_sum) / (1 - tau_plus)
    else:
        # Previously `neg` stayed a (2 * bs, 2 * bs - 2) matrix here, which cannot
        # broadcast against the (2 * bs,) `pos` and raised at the subtraction below.
        neg = neg_sum

    return (pos - neg).mean()

def W(out_d, out_b, batch_size, temp=None):
    """Per-negative weights comparing similarity under two embedding spaces.

    For each embedding set, the exponentiated inner-product similarity matrix is
    reduced to its negative entries and every row is L2-normalised. The weight
    of a negative is ``1 + s_d / (s_b + s_d + 1e-6)``, so it lies in ``[1, 2)``.

    Args:
        out_d: stacked two-view embeddings, ``(2 * batch_size, feature_dim)``.
        out_b: embeddings of the same samples from the second model, same shape.
        batch_size: number of samples per view.
        temp: similarity temperature. When omitted, the module-level
            ``temperature`` is used, which is what the function always did.

    Returns:
        Detached tensor of shape ``(2 * batch_size, 2 * batch_size - 2)``.
    """
    if temp is None:
        temp = temperature  # notebook-level default

    # Follow the inputs' device rather than the global `device`.
    mask = get_negative_mask(batch_size).to(out_d.device)

    def _normalised_negative_sims(z):
        # exp(<z_i, z_j> / temp), restricted to negatives, each row scaled to unit L2 norm.
        sim = torch.exp(torch.mm(z, z.t().contiguous()) / temp)
        sim = sim.masked_select(mask).view(2 * batch_size, -1)  # (2 * bs, 2 * bs - 2)
        return F.normalize(sim, dim = -1)

    s_d = _normalised_negative_sims(out_d)
    s_b = _normalised_negative_sims(out_b)

    weight = 1 + s_d / (s_b + s_d + 1e-6)
    if torch.isnan(weight).any():
        print('weight NaN')

    return weight.detach()
    
def orientation(out_1_d,out_2_d,out_1_b,out_2_b,batch_size, temp=None):
    """Alignment loss between two models' embeddings of the same samples.

    Returns the negated mean of ``<d_i, b_i> / temp`` over all ``2 * batch_size``
    rows, so the loss decreases as corresponding rows align.

    Args:
        out_1_d, out_2_d: the two views embedded by the first model.
        out_1_b, out_2_b: the two views embedded by the second model.
        batch_size: accepted for signature compatibility; not used here.
        temp: temperature. When omitted, the module-level ``temperature`` is
            used, which is what the function always did.

    Returns:
        Scalar tensor.
    """
    if temp is None:
        temp = temperature  # notebook-level default

    # space sharing: stack both views so row i of out_d and out_b is the same sample
    out_d = torch.cat([out_1_d,out_2_d], dim=0)
    out_b = torch.cat([out_1_b,out_2_b], dim=0)

    # log(exp(x)) is just x. Evaluating it literally overflows float32 to inf once
    # x exceeds roughly 88, which happens for small temperatures, so use x directly.
    return -((out_d * out_b).sum(-1) / temp).mean()
    
def criterion(out_1_d, out_2_d, out_1_b, out_2_b, tau_plus, batch_size, beta, temperature):
    """Debiased contrastive loss with per-negative weights from ``W``.

    Args:
        out_1_d, out_2_d: two-view embeddings from the model being scored,
            each ``(batch_size, feature_dim)``.
        out_1_b, out_2_b: embeddings of the same views from the second model;
            only used to compute the weights.
        tau_plus: weight of the positive correction in the negative term.
        batch_size: number of samples per view.
        beta: accepted for signature compatibility; not used here.
        temperature: temperature applied to the similarities in this function.
            NOTE(review): ``W`` is called without it and reads the module-level
            ``temperature``, so the two agree only when this argument equals
            that global.

    Returns:
        Scalar loss tensor. If the loss is NaN, ``neg`` is printed and ``pos``,
        ``Ng``, ``neg`` and ``weight`` are written to text files in the working
        directory for inspection.
    """
    out = torch.cat([out_1_d, out_2_d], dim=0)
    out_b = torch.cat([out_1_b, out_2_b], dim=0)

    # neg score
    neg = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
    # Follow the inputs' device rather than the global `device`.
    mask = get_negative_mask(batch_size).to(out.device)
    neg = neg.masked_select(mask).view(2 * batch_size, -1)  # (2 * bs, 2 * bs - 2): self and its augment removed

    # pos score
    pos = torch.exp(torch.sum(out_1_d * out_2_d, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)  # (2 * bs,)

    weight = W(out, out_b, batch_size)  # (2 * bs, 2 * bs - 2), detached

    # negative samples similarity scoring
    N = batch_size * 2 - 2
    reweight_neg = weight * neg

    Ng = (-tau_plus * N * pos + reweight_neg.sum(dim = -1)) / (1 - tau_plus)
    # constrain (optional): lower bound is N times the smallest value exp(sim / T)
    # can take when sim >= -1
    Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))

    # contrastive loss
    loss = (-torch.log(pos / (pos + Ng) )).mean()

    if torch.isnan(loss):
        print("neg : ", neg)

        # Dump the intermediates so the NaN can be diagnosed offline.
        np.savetxt('pos.txt', pos.detach().cpu().numpy(), delimiter=',')
        np.savetxt('Ng.txt', Ng.detach().cpu().numpy(), delimiter=',')
        np.savetxt('neg.txt', neg.detach().cpu().numpy(), delimiter=',')
        np.savetxt('weight.txt', weight.detach().cpu().numpy(), delimiter=',')

    return loss

In [7]:
# Two separate Image_Model instances, each wrapped in DataParallel and moved to the GPU.
model_d = nn.DataParallel(Image_Model()).cuda()
model_b = nn.DataParallel(Image_Model()).cuda()

In [13]:
# Results directory for the 'wctr' tiny-imagenet run.
# NOTE(review): save_dir is reassigned in cell In [18] before any checkpoint is loaded.
save_dir = 'tiny-im/results/{}_v2/{}'.format('wctr', 'tiny-imagenet')

In [15]:
# Display-only: shows the checkpoint path built from save_dir; nothing is loaded here.
os.path.join(save_dir, 'model_d_512_0.1_50.pth')



'tiny-im/results/wctr_v2/tiny-imagenet/model_d_512_0.1_50.pth'

In [18]:
# Point save_dir at the ImageNet 'wcl/no_orient_new' results folder (replaces the value from cell In [13]).
save_dir = '../results/imagenet/wcl/no_orient_new'
# Display-only preview of the model_b checkpoint path.
os.path.join(save_dir, 'imagenet_model_b_256_0.01_0.07_0.03_0.005_500.pth')


'../results/imagenet/wcl/no_orient_new/imagenet_model_b_256_0.01_0.07_0.03_0.005_500.pth'

In [19]:
# Restore both networks from the checkpoints under save_dir.
# The file names are fixed literals: the earlier `.format(batch_size, tau_plus, epoch)`
# calls had no placeholders to fill, so they were dropped to avoid implying the
# names depend on those settings.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model_d.load_state_dict(torch.load(os.path.join(save_dir, 'imagenet_model_d_256_0.01_0.07_0.03_0.005_500.pth')))
model_b.load_state_dict(torch.load(os.path.join(save_dir, 'imagenet_model_b_256_0.01_0.07_0.03_0.005_500.pth')))

<All keys matched successfully>

In [58]:
# Pull a single training batch. The built-in next() is the Python 3 iterator
# protocol; iterators are not guaranteed to expose a `.next()` method.
pos_1, pos_2, target = next(iter(train_loader))
pos_1, pos_2 = pos_1.cuda(), pos_2.cuda()

# Embed both augmented views with each model. Only the second output of each
# call is used below; feature_1 / feature_2 end up holding model_b's first output.
feature_1, out_1_d = model_d(pos_1)
feature_2, out_2_d = model_d(pos_2)

feature_1, out_1_b = model_b(pos_1)
feature_2, out_2_b = model_b(pos_2)

# Debiased triplet term on model_b's outputs and the weighted contrastive loss
# on model_d's outputs. The orientation term is not evaluated in this cell.
loss_tri = triplet(out_1_b, out_2_b, tau_plus, batch_size, temperature, True)
loss_crt = criterion(out_1_d, out_2_d, out_1_b, out_2_b, tau_plus, batch_size, beta, temperature)

In [22]:
# Stack the first-view embeddings with themselves to get (2 * bs, fs) inputs for the weight cell below.
# NOTE(review): orientation() and criterion() stack view 1 with view 2 (out_1_* with out_2_*);
# here view 1 is duplicated instead -- confirm whether that is intentional.
out_d = torch.cat([out_1_d, out_1_d], dim = 0)
out_b = torch.cat([out_1_b, out_1_b], dim = 0)

In [32]:
# Cell-level replay of the computation inside W(), so that the intermediates
# (mask, s_d, s_b) remain available for inspection. The weight is not detached here.
mask = get_negative_mask(batch_size).to(device)


def _normalised_negative_sims(embeddings):
    """exp(cosine-style similarity / temperature) over negatives, each row L2-normalised."""
    sims = torch.exp(torch.mm(embeddings, embeddings.t().contiguous()) / temperature)
    # (2 * bs, 2 * bs - 2): the sample itself and its paired row are masked out
    negatives = sims.masked_select(mask).view(2 * batch_size, -1)
    return F.normalize(negatives, dim = -1)


s_d = _normalised_negative_sims(out_d)
s_b = _normalised_negative_sims(out_b)

weight = 1 + s_d / (s_b + s_d + 1e-6)

In [64]:
# Full (2 * bs, 2 * bs) matrix of exponentiated similarities for model_d's two views.
# The negative mask is deliberately not applied here, so self and positive-pair
# entries are still present.
out = torch.cat((out_1_d, out_2_d), dim=0)
similarity = torch.mm(out, out.t().contiguous())
neg = torch.exp(similarity / temperature)

In [60]:
# Exponentiated similarity of each sample's two views, duplicated so that there is
# one entry per row of the stacked (2 * bs) batch.
pair_similarity = (out_1_d * out_2_d).sum(dim=-1)
pos = torch.exp(pair_similarity / temperature)
pos = torch.cat((pos, pos), dim=0)

In [61]:
# Sanity check: one positive score per row of the stacked batch (2 * batch_size entries).
pos.shape

torch.Size([128])

In [63]:
# Display-only: dumps every positive score; the second half repeats the first because pos was concatenated with itself.
pos

tensor([6.1965, 6.4684, 6.6369, 5.1381, 7.2278, 6.5985, 5.7907, 7.2830, 7.2033,
        5.3678, 7.3236, 6.3417, 7.2716, 7.3307, 5.4713, 6.4377, 6.6081, 5.1806,
        5.3092, 7.1126, 6.7840, 6.6477, 7.3575, 6.4398, 7.3449, 6.0548, 4.6535,
        6.4515, 6.1537, 6.0765, 7.3122, 6.2015, 5.8231, 5.9929, 7.3526, 7.3533,
        7.3270, 5.9478, 7.2632, 7.3291, 7.2754, 6.6456, 7.2601, 6.8021, 7.3570,
        5.1825, 6.6082, 7.3115, 7.3356, 5.6945, 7.2748, 7.3138, 7.3522, 7.3571,
        7.3456, 7.3515, 5.8596, 6.7059, 5.6224, 7.2870, 7.2970, 7.3388, 7.3248,
        6.1510, 6.1965, 6.4684, 6.6369, 5.1381, 7.2278, 6.5985, 5.7907, 7.2830,
        7.2033, 5.3678, 7.3236, 6.3417, 7.2716, 7.3307, 5.4713, 6.4377, 6.6081,
        5.1806, 5.3092, 7.1126, 6.7840, 6.6477, 7.3575, 6.4398, 7.3449, 6.0548,
        4.6535, 6.4515, 6.1537, 6.0765, 7.3122, 6.2015, 5.8231, 5.9929, 7.3526,
        7.3533, 7.3270, 5.9478, 7.2632, 7.3291, 7.2754, 6.6456, 7.2601, 6.8021,
        7.3570, 5.1825, 6.6082, 7.3115, 

In [67]:
# Cell-level replay of the debiased branch of triplet() on model_b's outputs,
# keeping each intermediate as a notebook variable for inspection.
N = 2 * batch_size - 2  # negatives per anchor
out = torch.cat((out_1_b, out_2_b), dim=0)  # (2 * bs, fs)
mask = get_negative_mask(batch_size).to(device)

# Squared Euclidean distance between every pair of rows, then keep negatives only.
pairwise_sq_dist = (out.unsqueeze(0) - out.unsqueeze(1)).pow(2).sum(-1)  # (2 * bs, 2 * bs)
s = pairwise_sq_dist.masked_select(mask).view(2 * batch_size, -1)  # (2 * bs, 2 * bs - 2)

# Squared distance between the two views of each sample, repeated for both halves.
view_sq_diff = (out_1_b - out_2_b).pow(2)
pos = torch.cat((view_sq_diff, view_sq_diff), dim=0).sum(-1)

# Debiased negative term: summed negative distances with the positive correction removed.
neg = (-tau_plus * N * pos + s.sum(-1)) / (1 - tau_plus)

In [74]:
# Sanity check: pos from the distance-based cell also has 2 * batch_size entries.
pos.shape

torch.Size([128])

In [88]:
# Display-only: equals -neg / N for the `neg` of cell In [67], i.e. the debiased
# negative term with negatives averaged (s has N columns) instead of summed, sign flipped.
(tau_plus * pos - s.mean(-1))/(1-tau_plus)

tensor([-2.1938, -2.2426, -2.2314, -2.2796, -2.2426, -2.3109, -2.2128, -2.2007,
        -2.2298, -2.2087, -2.2826, -2.2054, -2.2269, -2.3162, -2.1834, -2.1977,
        -2.1322, -2.3097, -2.2710, -2.3155, -2.2893, -2.2327, -2.2134, -2.3082,
        -2.3050, -2.3109, -2.2172, -2.2030, -2.2078, -2.2281, -2.2679, -2.2002,
        -2.3042, -2.1847, -2.3172, -2.2421, -2.1983, -2.2168, -2.2100, -2.2937,
        -2.3048, -2.2206, -2.3112, -2.1504, -2.2203, -2.1878, -2.2833, -2.2676,
        -2.1439, -2.1972, -2.2368, -2.3055, -2.2941, -2.3074, -2.2097, -2.2594,
        -2.2567, -2.2960, -2.2355, -2.1983, -2.2047, -2.2086, -2.2510, -2.3137,
        -2.2338, -2.2189, -2.2166, -2.2622, -2.2386, -2.3107, -2.2411, -2.1958,
        -2.2267, -2.2138, -2.2937, -2.1858, -2.2298, -2.3155, -2.1854, -2.1977,
        -2.1837, -2.3125, -2.2577, -2.3157, -2.2897, -2.2372, -2.2127, -2.3038,
        -2.3044, -2.3075, -2.1980, -2.1957, -2.1871, -2.2358, -2.2675, -2.2062,
        -2.3090, -2.1858, -2.3172, -2.24

In [81]:
# Display-only: masked pairwise squared distances, one row per anchor and one column per negative.
s

tensor([[0.0332, 3.8393, 0.5072,  ..., 3.0464, 3.9708, 1.4636],
        [0.0332, 3.9513, 0.2904,  ..., 3.3384, 3.9999, 1.1228],
        [3.8393, 3.9513, 3.8956,  ..., 0.3683, 0.0539, 3.2500],
        ...,
        [3.0464, 3.3384, 0.3683,  ..., 0.0042, 0.6789, 3.9226],
        [3.9708, 3.9999, 0.0539,  ..., 0.7792, 0.6789, 2.8563],
        [1.4636, 1.1228, 3.2500,  ..., 3.9543, 3.9226, 2.8563]],
       device='cuda:0', grad_fn=<ViewBackward0>)