In [1]:
from IPython.core.display import display, HTML
# Widen the notebook container to 95% of the browser window.
# The closing tag must be "</style>"; the previous "<\style>" was not a valid end tag
# (and "\s" is an invalid Python string escape).
display(HTML("<style>.container { width:95% !important; }</style>"))


import argparse
import os
import pandas
# Restrict this process to physical GPUs 1, 2 and 3. This is set before `torch` is
# imported (next cell) so the restriction is in place when CUDA is first initialised.
# With this setting, 'cuda:0' inside the notebook refers to the first visible GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3'

import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import tqdm

In [2]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from collections import OrderedDict
from torch.nn.functional import one_hot as one_hot
import torch.utils.data as data
import utils
from model import * 


In [3]:
# Mini-batch size: used by every DataLoader, by the loss functions, and in the
# checkpoint file name.
batch_size = 64

# Dataset identifier handed to utils.get_dataset.
dataset_name = 'tiny-imagenet'

# Presumably the optimiser learning rate; not referenced in the visible cells.
lr = 1e-3
# Device that the loss functions move their negative-pair masks to.
device = 'cuda:0'
# NOTE(review): tau_plus and epochs are assigned again (0.1 and 400) in a later
# hyper-parameter cell, so on a top-to-bottom run the values below are overridden.
tau_plus = 0.01

epochs = 300

In [4]:
# Fetch the three dataset splits from the project helper.
train_data, memory_data, test_data = utils.get_dataset(dataset_name, root='../data')

# Keyword arguments common to all three loaders.
_loader_kwargs = dict(batch_size=batch_size, num_workers=12, pin_memory=True)

# Only the training loader shuffles and drops the final incomplete batch; the loss
# functions build their masks from a fixed batch_size.
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **_loader_kwargs)
memory_loader = DataLoader(memory_data, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **_loader_kwargs)

In [5]:
# Hyper-parameters for the losses and for selecting the checkpoint to load.
# NOTE(review): tau_plus and epochs were already assigned (0.01 and 300) in the
# earlier configuration cell; the values here are the ones in effect afterwards.

# Coefficient of the debiasing term in triplet() and criterion(); also part of the
# checkpoint file name.
tau_plus =0.1
# Passed to criterion(), which currently does not use it.
beta =1
# Not referenced in the visible cells.
estimator = 'hard'
# Similarity temperature: passed to criterion()/triplet() and read as a global by
# W() and orientation().
temperature =0.5
# Epoch index used in the checkpoint file name that is loaded below.
epoch = 20
epochs = 400
# Not referenced in the visible cells -- presumably the k for a kNN evaluation; TODO confirm.
k = 200
# Number of classes reported by the training dataset.
c = len(train_data.classes)

In [6]:
def get_negative_mask(batch_size):
    """Build the boolean mask that selects negative pairs in a two-view batch.

    The batch is assumed to be laid out as ``[view_1 ; view_2]`` (2 * batch_size rows).
    For row ``i`` the mask is False at the sample itself and at its augmented
    counterpart (columns ``i % batch_size`` and ``i % batch_size + batch_size``)
    and True everywhere else.

    Args:
        batch_size: number of samples per view.

    Returns:
        A CPU ``torch.bool`` tensor of shape ``(2 * batch_size, 2 * batch_size)``.
    """
    # Tiling the identity 2x2 marks exactly the "same sample" positions (self and
    # augmented view); negating it leaves the negatives.
    same_sample = torch.eye(batch_size, dtype=torch.bool).repeat(2, 2)
    return ~same_sample

def triplet(out_1, out_2, tau_plus, batch_size, temperature, debias=True):
    """Squared-distance triplet-style loss over a two-view batch.

    For every anchor the positive term is the squared Euclidean distance to its
    augmented view, and the negative term is the summed squared distance to the
    ``2 * batch_size - 2`` other samples, optionally debiased with ``tau_plus``.

    Args:
        out_1, out_2: embeddings of the two views, each ``(batch_size, dim)``.
        tau_plus: coefficient of the debiasing correction (must be != 1).
        batch_size: number of samples per view.
        temperature: accepted for signature compatibility; not used here.
        debias: if True, subtract the ``tau_plus``-weighted positive contribution
            from the negative sum; if False, use the plain negative sum (the
            ``tau_plus -> 0`` limit of the debiased form).

    Returns:
        Scalar tensor: mean over the ``2 * batch_size`` anchors of ``pos - neg``.
    """
    n_neg = batch_size * 2 - 2
    out = torch.cat([out_1, out_2], dim=0)  # (2 * bs, dim)

    # Pairwise squared distances, (2 * bs, 2 * bs).
    sq_dist = torch.pow(out.unsqueeze(0) - out.unsqueeze(1), 2).sum(-1)

    # Build the mask on the embeddings' own device rather than a notebook global.
    mask = get_negative_mask(batch_size).to(out.device)

    # Drop self and augmented view, then total the distance to the negatives: (2 * bs,).
    neg_sum = sq_dist.masked_select(mask).view(2 * batch_size, -1).sum(-1)

    # Distance between the two views of each sample, repeated for both halves: (2 * bs,).
    pos = torch.pow(out_1 - out_2, 2).sum(-1)
    pos = torch.cat([pos, pos], dim=0)

    if debias:
        neg = (-tau_plus * n_neg * pos + neg_sum) / (1 - tau_plus)
    else:
        # Previously this branch kept the un-reduced (2 * bs, 2 * bs - 2) matrix,
        # which cannot broadcast against the (2 * bs,) positive term.
        neg = neg_sum

    return (pos - neg).mean()

def W(out_d, out_b, batch_size):
    """Per-negative weights comparing model_d's and model_b's similarity profiles.

    For each embedding set the exponentiated dot-product similarities to the
    negatives are computed (using the notebook-level ``temperature``) and each row
    is L2-normalised. The weight of a negative is ``1 + s_d / (s_b + s_d + 1e-6)``,
    so for finite scores it lies between 1 and 2.

    An alternative that measured difficulty by squared distance instead of the
    dot-product similarity was tried and is not used.

    Args:
        out_d: stacked two-view embeddings from the first model, ``(2 * batch_size, dim)``.
        out_b: stacked two-view embeddings from the second model, same shape.
        batch_size: number of samples per view.

    Returns:
        Detached tensor of shape ``(2 * batch_size, 2 * batch_size - 2)``.
    """
    negatives = get_negative_mask(batch_size).to(device)

    def _normalised_negative_scores(emb):
        # exp(sim / temperature) for all pairs, keep negatives only, L2-normalise rows.
        scores = torch.exp(torch.mm(emb, emb.t().contiguous()) / temperature)
        scores = scores.masked_select(negatives).view(2 * batch_size, -1)
        return F.normalize(scores, dim=-1)

    s_d = _normalised_negative_scores(out_d)
    s_b = _normalised_negative_scores(out_b)

    weight = 1 + s_d / (s_b + s_d + 1e-6)
    if np.isnan(weight.sum().item()):
        print('weight NaN')

    # Weights act as constants in the loss; no gradient flows through them.
    return weight.detach()
    
def orientation(out_1_d, out_2_d, out_1_b, out_2_b, batch_size, temp=None):
    """Alignment ("space sharing") loss between the two models' embeddings.

    Stacks both views of each model and returns the negative mean of the row-wise
    dot product between the two models' embeddings, divided by the temperature.
    Earlier variants based on an MSE between the two embedding sets were tried and
    are not used.

    Args:
        out_1_d, out_2_d: the two views' embeddings from the first model.
        out_1_b, out_2_b: the two views' embeddings from the second model.
        batch_size: accepted for signature compatibility; not used here.
        temp: temperature to divide by. Defaults to the notebook-level
            ``temperature`` when omitted, which is what existing callers get.

    Returns:
        Scalar tensor.
    """
    if temp is None:
        temp = temperature

    out_d = torch.cat([out_1_d, out_2_d], dim=0)
    out_b = torch.cat([out_1_b, out_2_b], dim=0)

    # Equivalent to -log(exp(x)).mean(), without the exp/log round-trip that
    # overflows to inf once x exceeds the float range of exp.
    return -((out_d * out_b).sum(-1) / temp).mean()
    
def criterion(out_1_d, out_2_d, out_1_b, out_2_b, tau_plus, batch_size, beta, temperature):
    """Weighted, debiased contrastive loss for the first model's embeddings.

    Negative scores ``exp(sim / temperature)`` are re-weighted per pair by ``W``
    (which compares the first and second models' similarity profiles), summed,
    debiased with ``tau_plus`` and clamped from below before entering the usual
    ``-log(pos / (pos + Ng))`` objective.

    Args:
        out_1_d, out_2_d: the two views' embeddings from the first model, ``(batch_size, dim)``.
        out_1_b, out_2_b: the two views' embeddings from the second model, same shape.
        tau_plus: coefficient of the debiasing correction (must be != 1).
        batch_size: number of samples per view.
        beta: accepted for signature compatibility; not used here.
        temperature: similarity temperature.

    Returns:
        Scalar loss tensor. If the loss is NaN, ``neg`` is printed and ``pos``,
        ``Ng``, ``neg`` and ``weight`` are dumped to text files in the working directory.
    """
    emb_d = torch.cat([out_1_d, out_2_d], dim=0)
    emb_b = torch.cat([out_1_b, out_2_b], dim=0)

    # Negative scores: all-pairs similarity with self and augmented view removed,
    # giving (2 * bs, 2 * bs - 2).
    all_pairs = torch.exp(torch.mm(emb_d, emb_d.t().contiguous()) / temperature)
    keep = get_negative_mask(batch_size).to(device)
    neg = all_pairs.masked_select(keep).view(2 * batch_size, -1)

    # Positive scores: similarity between the two views, repeated for both halves.
    pos_half = torch.exp(torch.sum(out_1_d * out_2_d, dim=-1) / temperature)
    pos = torch.cat([pos_half, pos_half], dim=0)

    # Per-negative weights, (2 * bs, 2 * bs - 2).
    weight = W(emb_d, emb_b, batch_size)

    # Debiased, re-weighted negative mass per anchor.
    n_neg = batch_size * 2 - 2
    weighted_neg_total = (weight * neg).sum(dim=-1)
    Ng = (-tau_plus * n_neg * pos + weighted_neg_total) / (1 - tau_plus)
    # Lower bound: n_neg negatives each at the minimum score e^(-1 / temperature).
    Ng = torch.clamp(Ng, min=n_neg * np.e ** (-1 / temperature))

    loss = (-torch.log(pos / (pos + Ng))).mean()

    if np.isnan(loss.item()):
        print("neg : ", neg)
        dumps = (('pos.txt', pos), ('Ng.txt', Ng), ('neg.txt', neg), ('weight.txt', weight))
        for fname, tensor in dumps:
            np.savetxt(fname, tensor.detach().cpu().numpy(), delimiter=',')

    return loss

In [7]:
# Two independently initialised encoders, each wrapped for multi-GPU execution and
# moved to the GPU.
model_d = nn.DataParallel(Image_Model()).cuda()
model_b = nn.DataParallel(Image_Model()).cuda()

In [13]:
# Directory holding the saved checkpoints: tiny-im/results/<method>_v2/<dataset>.
save_dir = '/'.join(('tiny-im', 'results', 'wctr' + '_v2', 'tiny-imagenet'))

In [14]:
# Preview the model_d checkpoint path for the current settings. The file name is built
# from batch_size / tau_plus / epoch, the same template the loading cell uses, instead
# of a hard-coded name that can drift from the configuration.
os.path.join(save_dir, 'model_d_{}_{}_{}.pth'.format(batch_size, tau_plus, epoch))

'tiny-im/results/wctr_v2/tiny-imagenet/model_d_64_0.1_20.pth'

In [12]:
# Restore both encoders from the checkpoints saved for the current
# (batch_size, tau_plus, epoch) setting.
# map_location='cpu' deserialises the tensors on the CPU instead of on whichever GPU
# index they were saved from; load_state_dict then copies the values into the models'
# existing parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
for _prefix, _net in (('model_d', model_d), ('model_b', model_b)):
    _ckpt_path = os.path.join(
        save_dir, '{}_{}_{}_{}.pth'.format(_prefix, batch_size, tau_plus, epoch)
    )
    _net.load_state_dict(torch.load(_ckpt_path, map_location='cpu'))

FileNotFoundError: [Errno 2] No such file or directory: 'tiny-im/results/wctr_v2/tiny-imagenet/model_d_64_0.1_20.pth'

In [10]:
# Smoke test: push one training batch through both encoders and evaluate the losses.
# Use the built-in next(); the iterator's `.next()` method is a Python-2-style alias
# that is not part of the iterator protocol.
pos_1, pos_2, target = next(iter(train_loader))
pos_1, pos_2 = pos_1.cuda(), pos_2.cuda()

# Each model returns (feature, projection); only the projections feed the losses.
feature_1, out_1_d = model_d(pos_1)
feature_2, out_2_d = model_d(pos_2)

# feature_1 / feature_2 are overwritten here and end up holding model_b's features.
feature_1, out_1_b = model_b(pos_1)
feature_2, out_2_b = model_b(pos_2)

loss_tri = triplet(out_1_b, out_2_b, tau_plus, batch_size, temperature, True)
# Orientation loss is currently disabled:
#     loss_ori = orientation(out_1_d, out_2_d, out_1_b, out_2_b, batch_size)
loss_crt = criterion(out_1_d, out_2_d, out_1_b, out_2_b, tau_plus, batch_size, beta, temperature)

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/taeuk/network/git/WCTR/image/model.py", line 42, in forward
    feature = self.model(x)
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torchvision/models/resnet.py", line 249, in forward
    return self._forward_impl(x)
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torchvision/models/resnet.py", line 232, in _forward_impl
    x = self.conv1(x)
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 446, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/taeuk/anaconda3/envs/main/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED
