In [2]:
import numpy as np
import torch
from torch import nn

In [3]:
%load_ext tensorboard

In [4]:
def load_treshold(file='../data/treshold.npy'):
    """Return the first element of the array stored in the ``.npy`` file `file`."""
    stored = np.load(file)
    return stored[0]


In [5]:
# Load the stored threshold from the default path. Note that the timed query cell
# further down passes a literal 0.7 to find_with_vae rather than this value.
treshold = load_treshold()

In [6]:
def read_string_array(arr, dtype=None):
    """Parse a delimited, space-separated string such as ``"[0.1 0.2 0.3]"`` into a 1-D array.

    Parameters
    ----------
    arr : str
        Text whose first and last characters are enclosing delimiters; both are dropped
        before parsing.
    dtype : data-type, optional
        Element type of the result. Float is used when None.

    Returns
    -------
    numpy.ndarray
        The parsed values.
    """
    # The original always forwarded dtype=None, silently ignoring the caller's choice.
    element_type = float if dtype is None else dtype
    return np.fromstring(arr[1:-1], dtype=element_type, sep=' ')

In [7]:
def read_string_ndarray(arr, dtype=None):
    """Parse every string in `arr` with `read_string_array` and stack the results.

    Parameters
    ----------
    arr : iterable of str
        One bracketed, space-separated vector per element.
    dtype : data-type, optional
        Forwarded to `read_string_array` (the original dropped it and always passed None).

    Returns
    -------
    numpy.ndarray
        Array built from the parsed rows.
    """
    return np.array([read_string_array(target, dtype=dtype) for target in arr])

In [8]:
import pandas as pd

In [9]:
# CSV locations, relative to the notebook's working directory.
train_file = '../data/train.csv'
validation_file = '../data/validation.csv'

In [10]:
def read_chunk(chunk):
    """Split a data-frame chunk into its feature column and integer labels.

    Parameters
    ----------
    chunk : pandas.DataFrame
        Must contain the columns ``'image'`` and ``'label'``.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The raw ``'image'`` values (left exactly as stored in the frame) and the
        ``'label'`` column converted to an integer array.
    """
    X = chunk['image'].values
    # np.int was only an alias of the builtin int; the alias was removed in NumPy 1.24.
    y = chunk['label'].to_numpy(dtype=int)

    return X, y

In [11]:
def to_tensor(X, y, device=torch.device('cpu')):
    """Copy `X` and `y` into two new tensors placed on `device` and return them as a pair."""
    features = torch.tensor(X, device=device)
    targets = torch.tensor(y, device=device)
    return features, targets

In [12]:
# Use the third GPU when the machine actually has one; otherwise fall back to the CPU so
# the notebook still runs (a hard-coded 'cuda:2' fails as soon as a tensor is moved there).
device = torch.device('cuda:2' if torch.cuda.device_count() > 2 else 'cpu')

In [13]:
def chunk_iterator(df, chunksize=1, limit=None):
    """Yield consecutive slices of `df`, each holding at most `chunksize` items.

    Parameters
    ----------
    df : sliceable sequence
        Anything supporting ``len()`` and ``df[a:b]`` (list, ndarray, DataFrame, ...).
    chunksize : int
        Maximum number of items per slice; must be at least 1.
    limit : int, optional
        Stop after this many leading items; None means the whole input.

    Yields
    ------
    Slices ``df[start:end]``. At least one slice is always produced, so an empty
    input (or ``limit=0``) yields a single empty slice.

    Raises
    ------
    ValueError
        If `chunksize` is smaller than 1 (the original looped forever in that case).
    """
    if chunksize < 1:
        raise ValueError('chunksize must be a positive integer')

    length = len(df)
    stop = length if limit is None else max(0, min(length, limit))
    start = 0

    while True:
        end = min(start + chunksize, stop)
        yield df[start:end]

        if end >= stop:
            # A plain return ends the generator. `raise StopIteration` inside a
            # generator body is turned into RuntimeError since Python 3.7 (PEP 479).
            return
        start = end
        

In [14]:
import matplotlib.pyplot as plt
%matplotlib inline

In [15]:
def pairs_iterator(df, chunksize=50000):
    """Yield one batch of labelled sample pairs per distinct label in `df`.

    For every label, ``chunksize / 4`` (truncated) same-label pairs and as many
    different-label pairs are drawn with replacement.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs an ``'image'`` column of equal-length numeric vectors and a ``'label'`` column.
    chunksize : int
        Controls the batch size: each batch holds ``4 * int(chunksize / 4)`` rows.

    Yields
    ------
    (X, Y)
        `X` stacks the two members of each pair on consecutive rows, same-label pairs
        first and different-label pairs after them. `Y` has one row per pair: 1 where
        the two labels are equal in the first half, -1 where they differ in the second.
    """
    size = int(chunksize / 4)

    def _features(frame):
        # np.float was an alias of the builtin float and was removed in NumPy 1.24.
        images, labels = read_chunk(frame)
        return np.array(images.tolist(), dtype=float), labels

    def _interleave(first, second):
        # Rows alternate first[0], second[0], first[1], second[1], ...
        paired = np.stack([first, second], axis=1)
        return paired.reshape(-1, *paired.shape[2:])

    for label in df["label"].unique():
        positive = df[df["label"] == label]
        negative = df[df["label"] != label]

        # Sampling order is kept as in the original so a seeded run draws the same rows.
        positive_1 = positive.sample(size, replace=True)
        positive_2 = positive.sample(size, replace=True)
        positive_3 = positive.sample(size, replace=True)
        negative_1 = negative.sample(size, replace=True)

        x1_1, y1_1 = _features(positive_1)
        x1_2, y1_2 = _features(positive_2)
        x2_1, y2_1 = _features(positive_3)
        x2_2, y2_2 = _features(negative_1)

        x1 = _interleave(x1_1, x1_2)
        y1 = (y1_1 == y1_2).reshape(-1, 1).astype(int)

        x2 = _interleave(x2_1, x2_2)
        y2 = -(y2_1 != y2_2).reshape(-1, 1).astype(int)

        yield np.concatenate([x1, x2]), np.concatenate([y1, y2])
    # Falling off the end finishes the generator; an explicit `raise StopIteration`
    # would surface as RuntimeError on Python 3.7+ (PEP 479).

In [16]:
from sklearn.utils import shuffle

In [17]:
def triples_iterator(df, size=10000):
    """Yield one batch of (anchor, positive, negative) triples per distinct label in `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs an ``'image'`` column of equal-length numeric vectors and a ``'label'`` column.
        It is shuffled before the labels are enumerated.
    size : int
        Number of triples per batch; anchors and positives are sampled (with
        replacement) from the current label, negatives from all other labels.

    Yields
    ------
    numpy.ndarray
        ``3 * size`` rows ordered anchor, positive, negative, anchor, positive, ...
        so that stride-3 slicing recovers each role.
    """
    df = shuffle(df)

    def _features(frame):
        # np.float was an alias of the builtin float and was removed in NumPy 1.24.
        images, _ = read_chunk(frame)
        return np.array(images.tolist(), dtype=float)

    for label in df["label"].unique():
        positive = df[df["label"] == label]
        negative = df[df["label"] != label]

        # Sampling order is kept as in the original so a seeded run draws the same rows.
        anchor = positive.sample(size, replace=True)
        positives = positive.sample(size, replace=True)
        negatives = negative.sample(size, replace=True)

        stacked = np.stack(
            [_features(anchor), _features(positives), _features(negatives)], axis=1
        )
        yield stacked.reshape(-1, *stacked.shape[2:])
    # Falling off the end finishes the generator; an explicit `raise StopIteration`
    # would surface as RuntimeError on Python 3.7+ (PEP 479).

In [18]:
from torch.nn import Parameter

In [19]:
def recall_at_k(x1, x2, y, k=5):
    """Recall@k over the positive pairs of a batch.

    A positive pair (``y > 0``) counts as a hit when the argmax index of its `x1` row
    is among the `k` largest entries of its `x2` row. Pairs with ``y <= 0`` are ignored.

    Parameters
    ----------
    x1, x2 : torch.Tensor
        2-D score tensors with one row per pair.
    y : torch.Tensor
        Pair labels holding exactly one value per row (it is reshaped to ``len(y)``).
    k : int
        Number of top entries of `x2` to compare against.

    Returns
    -------
    torch.Tensor
        0-dim float tensor with the fraction of hits; 1.0 when the batch has no
        positive pair. The original returned a Python float in that case, which
        breaks callers that invoke tensor methods such as ``.cpu()`` on the result.
    """
    positive_mask = (y > 0).reshape(len(y))
    count = int(positive_mask.sum())

    if count == 0:
        return torch.tensor(1.0, device=positive_mask.device)

    top1 = x1.detach()[positive_mask, :].argmax(dim=1).reshape(-1, 1)
    # Negate so that an ascending argsort lists the largest scores first.
    topk = (-x2.detach()[positive_mask, :]).argsort(dim=1)[:, :k]
    hits = (top1 == topk).any(dim=1)
    return hits.sum().type(torch.float) / count

In [20]:
def triplet_recall_at_k(anchor, positives, k=5):
    """Fraction of rows whose anchor argmax index is among the top-`k` entries of `positives`.

    Both inputs are 2-D score tensors with one row per sample; the result is a 0-dim
    float tensor.
    """
    n_rows = len(anchor)
    anchor_top1 = anchor.detach().argmax(dim=1).view(-1, 1)
    # Ascending sort of the negated scores puts the largest scores first.
    ranked = (-positives.detach()).argsort(dim=1)
    candidates = ranked[:, :k]
    matches_per_row = (anchor_top1 == candidates).sum(dim=1).view(-1, 1)
    hit = matches_per_row > 0
    return hit.sum().type(torch.float) / n_rows

In [21]:
# Sanity check on two hand-made rows: only the first pair is labelled positive, and the
# argmax of its first row (index 0) is among the two largest entries of its second row.
recall_at_k(torch.tensor([[0.9, 0.1, 0.5, 0.6, 0.7],[0.1, 0.5 , 0.2, 0.3, 0.1]]), 
            torch.tensor([[1.0, 0.0, 0.2, 0.3, 0.4], [0.2, 1 , 0.3, 0.5, 0.6]]),
            torch.tensor([[1],[-1]]), k=2)

tensor(1.)

In [22]:
# Load the training table; the converter parses each 'image' cell from its bracketed
# text form into a NumPy vector while the file is read.
df_train = pd.read_csv(train_file,
                 header=0,
                 converters={'image': read_string_array})

In [23]:
df_train.head()

Unnamed: 0,image,label
0,"[-0.102006525, -0.202537119, 0.113066137, 0.14...",47
1,"[-0.118899353, 0.101335108, 0.0376838185, 0.48...",18
2,"[-0.260352045, -0.338344872, -0.137531623, -0....",41
3,"[0.0528041609, -0.177006856, 0.0429622903, -0....",12
4,"[-0.212558284, 0.015732646, 0.230285764, 0.069...",29


In [24]:
class VAE(nn.Module):
    """Autoencoder with a Gumbel-softmax bottleneck over a table of learned embeddings.

    The encoder maps an input vector to `n_clusters` logits. `reparametrize` draws a
    relaxed one-hot sample from those logits and uses it to mix the rows of
    `embeddings`; the decoder maps that mixture back to an `input_dim`-sized vector.

    Parameters
    ----------
    tau : float
        Gumbel-softmax temperature (changeable later via `update_temperature`).
    input_dim : int
        Size of the input vectors, of each embedding row and of the decoder layers.
    hidden_dim : int
        Width of the encoder's hidden layer.
    n_clusters : int
        Number of logits / embedding rows.

    The defaults reproduce the original fixed architecture (512 -> 256 -> 128), and the
    attribute names are unchanged, so previously saved state dicts still load.
    """

    def __init__(self, tau=1, input_dim=512, hidden_dim=256, n_clusters=128):
        super(VAE, self).__init__()
        self.tau = tau

        # Encoder: input_dim -> hidden_dim -> n_clusters logits.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_clusters)

        # One learnable vector per cluster, initialised uniformly in [-0.5, 0.5).
        # nn.Parameter is used so the class only needs the `nn` import.
        self.embeddings = nn.Parameter(
            torch.empty((n_clusters, input_dim)).uniform_(-0.5, 0.5)
        )

        # Decoder: two input_dim-wide linear layers with batch norm in between.
        self.fc3 = nn.Linear(input_dim, input_dim)
        self.bn2 = nn.BatchNorm1d(input_dim)
        self.fc4 = nn.Linear(input_dim, input_dim)

    def update_temperature(self, tau):
        """Replace the Gumbel-softmax temperature used by `reparametrize`."""
        self.tau = tau

    def encode(self, x):
        """Return the unnormalised cluster logits for the batch `x`."""
        x = self.fc1(x)
        x = self.bn1(x)
        # torch.tanh replaces the deprecated F.tanh.
        x = torch.tanh(x)
        return self.fc2(x)

    def decode(self, z):
        """Map mixed embeddings `z` back to input space (no final non-linearity)."""
        z = self.fc3(z)
        z = self.bn2(z)
        z = torch.tanh(z)
        return self.fc4(z)

    def reparametrize(self, q_y):
        """Sample relaxed one-hot weights from the logits `q_y` and mix the embeddings."""
        weights = nn.functional.gumbel_softmax(q_y, tau=self.tau, dim=-1)
        return torch.mm(weights, self.embeddings)

    def forward(self, x):
        """Return ``(reconstruction, mixed_embedding, logits)`` for the batch `x`."""
        encoding = self.encode(x)
        z = self.reparametrize(encoding)
        return self.decode(z), z, encoding

In [25]:
# class VAE(nn.Module):
#     def __init__(self, tau=1):
#         super(VAE, self).__init__()
#         self.tau = tau
        
#         self.fc1 = nn.Linear(512, 256)
#         self.bn1 = nn.BatchNorm1d(256)
#         self.fc2 = nn.Linear(256, 128)
#         self.bn2 = nn.BatchNorm1d(128)
#         self.fc3 = nn.Linear(128, 128 * 512)

# #         self.embeddings = Parameter(torch.empty((128, 512)).uniform_(-0.5, 0.5).requires_grad_(True))
        
#         self.fc4 = nn.Linear(128 * 512, 128)
#         self.bn3 = nn.BatchNorm1d(128)
#         self.fc5 = nn.Linear(128, 256)
#         self.bn4 = nn.BatchNorm1d(256)
#         self.fc6 = nn.Linear(256, 512)
        
#     def update_temperature(self, tau):
#         self.tau = tau

#     def encode(self, x):
#         x = self.fc1(x)
#         x = self.bn1(x)
#         x = F.tanh(x)
#         x = self.fc2(x)
#         x = self.bn2(x)
#         x = F.tanh(x)
# #         x = F.tanh(x)
#         return x

#     def decode(self, z):
#         z = z.view(z.size(0), 128 * 512)
#         z = self.fc4(z)
#         z = self.bn3(z)
#         z = F.tanh(z)
#         z = self.fc5(z)
#         z = self.bn4(z)
#         z = F.tanh(z)
#         z = self.fc6(z)
# #         z = F.tanh(z)
#         return z
    
#     def reparametrize(self, q_y):
# #         print(q_y)
# #         print(self.embeddings)
#         z = F.gumbel_softmax(q_y, tau=self.tau, dim=-1)
#         z = self.fc3(z)
#         z = z.view(z.size(0), 128, 512)
        
# #         z = torch.mm(z, self.embeddings)
# #         print(z)
#         return z
    
#     def forward(self, x):
#         encoding = self.encode(x)
#         z = self.reparametrize(encoding)
#         return self.decode(z), z, encoding

In [26]:
from torch.nn import MSELoss, L1Loss, CosineEmbeddingLoss, KLDivLoss, BCELoss, TripletMarginLoss, BCEWithLogitsLoss
import torch.nn.functional as F

In [27]:
class CosineLoss(nn.Module):
    """Pairwise cosine loss with signed labels.

    For a pair with cosine similarity ``s`` the per-pair loss is

    * ``1 - s``                 when ``y > 0`` (pull together),
    * ``max(0, s - margin)``    when ``y < 0`` (push apart),
    * ``0``                     when ``y == 0``.

    Parameters
    ----------
    reduction : str
        ``'sum'`` adds the per-pair losses; any other value averages them.
    margin : float
        Similarity below which a negative pair stops contributing.
    """

    def __init__(self, reduction='mean', margin=0):
        super(CosineLoss, self).__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, x1, x2, y):
        # Keep everything 1-D. The original multiplied an (n,) mask with an (n, 1)
        # similarity column, which broadcast to (n, n) and paired every similarity
        # with every label.
        sim = F.cosine_similarity(x1, x2, dim=1)
        y = y.view(-1)
        positive = (y > 0).to(sim.dtype)
        negative = (y < 0).to(sim.dtype)

        # Masks multiply the whole term so unrelated pairs contribute exactly zero
        # (the original charged a constant 1 to every non-positive pair).
        positive_loss = positive * (1 - sim)
        # clamp works on the inputs' own device; no module-level `device` is needed.
        negative_loss = negative * torch.clamp(sim - self.margin, min=0)
        loss = positive_loss + negative_loss

        if self.reduction == 'sum':
            return torch.sum(loss)

        return torch.mean(loss)

class TripletCosineLoss(nn.Module):
    """Triplet loss on cosine similarity: ``relu(margin - sim(a, p) + sim(a, n))``.

    Parameters
    ----------
    reduction : str
        ``'sum'`` adds the per-triplet losses; any other value averages them.
    margin : float
        Required gap between the positive and the negative similarity.
    """

    def __init__(self, reduction='mean', margin=0):
        super(TripletCosineLoss, self).__init__()
        # Stored as a plain float: the original built a tensor on the module-level
        # `device`, which tied the loss to that global and to that device existing.
        self.margin = float(margin)
        self.reduction = reduction

    def forward(self, anchor, positives, negatives):
        sim_pos = F.cosine_similarity(anchor, positives, dim=1).view(-1, 1)
        sim_neg = F.cosine_similarity(anchor, negatives, dim=1).view(-1, 1)
        loss = F.relu(self.margin - sim_pos + sim_neg)

        if self.reduction == 'sum':
            return torch.sum(loss)
        return torch.mean(loss)
        
class VaeLoss(nn.Module):
    """Sum of a pairwise cosine loss and a mean-squared reconstruction error.

    `reduction` is handed to both component losses; `margin` only to the cosine loss.
    """

    def __init__(self, reduction='mean', margin=0):
        super(VaeLoss, self).__init__()
        self.cosine_loss = CosineLoss(reduction=reduction, margin=margin)
        self.mse_loss = MSELoss(reduction=reduction)

    def forward(self, real, pred, x1, x2, y):
        """Return ``(total, cosine_term, mse_term)`` for one batch."""
        similarity_term = self.cosine_loss(x1, x2, y)
        reconstruction_term = self.mse_loss(real, pred)
        total = similarity_term + reconstruction_term
        return total, similarity_term, reconstruction_term

# class VaeTripletLoss(nn.Module):
#     def __init__(self, reduction='mean', margin=0):
#         super(VaeTripletLoss, self).__init__()
# #         self.cos_loss = CosineLoss(margin=margin, reduction=reduction)
#         self.trilpet_loss = TripletCosineLoss(margin=margin, reduction=reduction)
#         self.mse_loss = MSELoss(reduction=reduction)

#     def forward(self, real, pred, anchor, positive, negative):
#         trilpet_loss = self.trilpet_loss(anchor, positive, negative)
# #         print(cos_loss)
# #         y_pos = torch.tensor([1] * len(anchor), device=device).view(-1, 1)
# #         y_neg = -y_pos
#         mse_loss = self.mse_loss(real, pred)
# #         cos_loss_1 = self.cos_loss(anchor, positive, y_pos)
# #         cos_loss_2 = self.cos_loss(anchor, negative, y_neg)
# #         cos_loss = (cos_loss_1 + cos_loss_2) / 2

#         return trilpet_loss + mse_loss * 10, trilpet_loss, mse_loss, 0
class VaeTripletLoss(nn.Module):
    """Cosine triplet loss plus a weighted mean-squared reconstruction error.

    Parameters
    ----------
    reduction : str
        Passed to both component losses.
    margin : float
        Passed to the triplet loss.
    mse_weight : float
        Multiplier applied to the reconstruction term. The default of 10 is the value
        that used to be hard-coded in `forward`.
    """

    def __init__(self, reduction='mean', margin=0, mse_weight=10):
        super(VaeTripletLoss, self).__init__()
        # Attribute name keeps its historical "trilpet" spelling so existing
        # references to it keep working.
        self.trilpet_loss = TripletCosineLoss(margin=margin, reduction=reduction)
        self.mse_loss = MSELoss(reduction=reduction)
        self.mse_weight = mse_weight

    def forward(self, real, pred, anchor, positive, negative):
        """Return ``(total, triplet_term, mse_term, 0)``.

        The trailing 0 fills the slot of a cosine-loss term that is no longer
        computed; it keeps the 4-tuple shape that callers unpack.
        """
        triplet_loss = self.trilpet_loss(anchor, positive, negative)
        mse_loss = self.mse_loss(real, pred)

        return triplet_loss + mse_loss * self.mse_weight, triplet_loss, mse_loss, 0

In [28]:
# Build the model with its default temperature and move its parameters to `device`.
vae = VAE().to(device)


In [29]:
# vae.load_state_dict(torch.load('vae.torch', map_location=device))

In [30]:
# Pairwise loss (referenced only by the commented-out pair-training cell below) and the
# triplet loss used by the active training loop.
vae_loss = VaeLoss(margin=0, reduction='mean')
vae_triplet_loss = VaeTripletLoss(margin=1, reduction='mean')

In [31]:
def get_optimizer(lr, models):
    """Build an Adam optimizer with one parameter group per model, all using rate `lr`."""
    param_groups = []
    for model in models:
        param_groups.append({'params': model.parameters()})
    return torch.optim.Adam(param_groups, lr=lr)

In [32]:
from torch.utils.tensorboard import SummaryWriter

In [33]:
# TensorBoard writer with default settings; the training loop logs its scalars through it.
writer = SummaryWriter()

In [34]:
# vae.train(True)
# lr = 0.99
# for epoch in range(10):
#     train_loss = 0.
#     train_cos_loss = 0.
#     train_mse_loss = 0.
#     train_recall = 0.
#     count = 0
#     optimizer = get_optimizer(lr, [vae])
#     for X, y in pairs_iterator(df_train, chunksize=40000):
#         X = torch.tensor(X, device=device, dtype=torch.float, requires_grad=True)
#         y = torch.tensor(y, device=device, dtype=torch.float)

#         optimizer.zero_grad()
        
#         decoded, thetas, encoding = vae(X)
        
#         pred_1, pred_2 = encoding[0::2], encoding[1::2]
#         x1, x2 = thetas[0::2], thetas[1::2]
#         recall = recall_at_k(pred_1, pred_2, y)
#         loss, cos_loss, mse_loss = vae_loss(X, decoded, x1, x2, y, -1)
# #         loss = criterion(x1, x2, y)
#         loss.backward()
# #         print(head.embeddings.grad)
#         optimizer.step()
#         train_loss += loss.item()
#         train_cos_loss += cos_loss.item()
#         train_mse_loss += mse_loss.item()
#         train_recall += recall.cpu().detach().numpy()
#         count += 1
#         print('Epoch:', epoch, 
#               'Loss:', train_loss / count, 
#               'Recall:', train_recall / count, 
#               'Cos Loss:', train_cos_loss / count, 
#               'MSE Loss:', train_mse_loss / count)
        
#         if count % 100:
#             torch.save(vae.state_dict(), 'vae.torch')
#     lr = max(lr / 10, 1e-10)

In [35]:
# Triplet training of the VAE: 20 epochs, one optimisation step per batch produced by
# triples_iterator (one batch per distinct label).
vae.train(True)
lr = 1e-3
tau_0 = 10.0
for epoch in range(20):
    # Running sums for this epoch; they are divided by `count` when logged.
    train_loss = 0.
    train_cos_loss = 0.
    train_mse_loss = 0.
    train_triplet_loss = 0.
    train_recall = 0.
    count = 0
    # Every epoch restarts from the initial temperature and builds a fresh Adam
    # optimizer (so Adam's moment estimates do not carry over between epochs).
    vae.update_temperature(tau_0)
    optimizer = get_optimizer(lr, [vae])
#     for X in triples_iterator(df_train, size=10000):
    for X in triples_iterator(df_train, size=1000):
        X = torch.tensor(X, device=device, dtype=torch.float, requires_grad=True)
#         y = torch.tensor(y, device=device, dtype=torch.float)

        optimizer.zero_grad()

        decoded, thetas, encoding = vae(X)

        # Batch rows are interleaved anchor, positive, negative (see triples_iterator),
        # so stride-3 slices separate the three roles.
        pred_anchor, pred_positive = encoding[0::3], encoding[1::3]
        anchor, positive, negative = thetas[0::3], thetas[1::3], thetas[2::3]
        # The fourth returned value is the constant 0 placeholder from VaeTripletLoss.
        loss, triplet_loss, mse_loss, cos_loss = vae_triplet_loss(X, decoded, anchor, positive, negative)
#         loss = criterion(x1, x2, y)
        loss.backward()
#         print(vae.embeddings.grad)
#         print(vae.embeddings)
        optimizer.step()
#         print(vae.embeddings)
#         for param in vae.parameters():
#             print(param)
        # Recall is computed on the encoder logits, not on the mixed embeddings.
        recall = triplet_recall_at_k(pred_anchor, pred_positive)
        train_loss += loss.item()
#         train_cos_loss += cos_loss.item()
        train_mse_loss += mse_loss.item()
        train_triplet_loss += triplet_loss.item()
        train_recall += recall.detach().cpu().numpy()
        count += 1
        # Log running means under per-epoch tags, using the step count as the x-axis.
        writer.add_scalar('Epoch {}/Loss/VAE'.format(epoch), train_loss / count, count)
        writer.add_scalar('Epoch {}/Recall'.format(epoch), train_recall / count, count)
        writer.add_scalar('Epoch {}/Loss/Triplet'.format(epoch), train_triplet_loss / count, count)
        writer.add_scalar('Epoch {}/Loss/MSE'.format(epoch), train_mse_loss / count, count)
        writer.add_scalar('Epoch {}/Temperature'.format(epoch), vae.tau, count)
#         print('Epoch:', epoch, 
#               'Loss:', train_loss / count, 
#               'Recall:', train_recall / count, 
#               'Triplet Loss:', train_triplet_loss / count, 
# #               'Cos Loss:', train_cos_loss / count, 
#               'MSE Loss:', train_mse_loss / count)

        # Every 400 steps scale the temperature by exp(-3e-5 * count), never going
        # below 0.5.
        if count % 400 == 0 and count > 0:
#             pass
            vae.update_temperature(np.maximum(vae.tau * np.exp(-0.00003 * count), 0.5))
#             print(vae.tau)
        # Checkpoint every 100 steps; the same file is overwritten each time.
        if count % 100 == 0:
            torch.save(vae.state_dict(), 'vae1.torch')
    # Divide the learning rate by 10 after each epoch, with a floor of 1e-10.
    lr = max(lr / 10, 1e-10)

  


In [36]:
# Final checkpoint; overwrites the file written periodically during training.
torch.save(vae.state_dict(), 'vae1.torch')

In [None]:
# Push any buffered TensorBoard events to disk.
writer.flush()

In [37]:
def predict_cluster_with_vae(vae, file, chunksize=100000):
    """Assign every row of a CSV file to its highest-scoring cluster.

    The file is streamed in chunks of `chunksize` rows; each chunk's ``'image'``
    vectors are passed through ``vae.encode`` and the argmax logit index is kept.
    The running row count is printed after each chunk.

    Parameters
    ----------
    vae : VAE
        Model to use; it is switched to eval mode.
    file : str
        CSV path with a header row and an ``'image'`` column in bracketed text form.
    chunksize : int
        Rows read per chunk.

    Returns
    -------
    numpy.ndarray
        One cluster index per row, in file order.
    """
    vae.eval()

    with torch.no_grad():
        count = 0
        reader = pd.read_csv(file,
                             chunksize=chunksize,
                             header=0,
                             converters={'image': read_string_array})

        result = []

        for chunk in reader:
            count += len(chunk)
            print(count)
            X, _ = read_chunk(chunk)
            # np.float was an alias of the builtin float and was removed in NumPy 1.24.
            X = np.array(X.tolist(), dtype=float)
            X = torch.tensor(X, device=device, dtype=torch.float)

            result.append(vae.encode(X).argmax(dim=1))

        return torch.cat(result, 0).cpu().numpy()

In [38]:
# Cluster index for every training row, in file order.
preds = predict_cluster_with_vae(vae, train_file, 100000)

100000
200000
300000
343350


In [39]:
def get_buckets(cluster_predictions):
    """Group positions by predicted cluster.

    Returns a dict mapping each cluster value to the list of positions (in increasing
    order) at which it occurs in `cluster_predictions`.
    """
    buckets = {}

    for position, cluster in enumerate(cluster_predictions):
        buckets.setdefault(cluster, []).append(position)
    return buckets

In [40]:
# Map each cluster index to the list of training-row positions assigned to it.
clusters = get_buckets(cluster_predictions=preds)
# clusters

In [41]:
len(clusters)

116

In [42]:
# Print the size of every bucket, in the order the clusters were first encountered.
for i in clusters:
    print(len(clusters[i]))

126032
3188
3508
1451
2850
5074
1792
1368
702
751
6333
2100
3695
4829
4262
1343
1004
1884
7307
6261
625
2418
8038
6156
665
1980
620
1611
2960
1909
6235
779
1919
2796
1650
337
850
6604
578
11094
2678
826
3873
1577
7796
1461
2791
4800
1107
1931
928
1506
4304
1459
1167
2142
966
468
1433
993
642
1287
1297
1050
959
702
1058
13
587
2532
397
474
1119
1494
1743
769
1341
853
1279
675
863
451
1040
1585
2822
1137
760
2458
1067
712
1786
756
1871
1207
1596
648
1370
792
902
341
1626
643
3154
669
796
352
197
490
659
340
99
3
117
2
3
1


In [43]:
def save_buckets(buckets, file='../data/buckets_vae.npy'):
    """Write `buckets` to `file` with ``np.save``, wrapped in a one-element list.

    Reading the mapping back therefore needs ``np.load(..., allow_pickle=True)[0]``.
    """
    wrapped = [buckets]
    np.save(file, wrapped)

In [44]:
# Persist the cluster buckets at the default path.
save_buckets(clusters)

In [271]:
def collect_by_indexes(df, indexes, chunksize=100000):
    """Gather the rows of `df` at the given positions as numeric arrays.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with an ``'image'`` column in bracketed text form (it is parsed here
        with `read_string_ndarray`) and a ``'label'`` column.
    indexes : array-like of int
        Positional (0-based) row numbers to fetch.
    chunksize : int
        Unused; kept so existing call sites keep working.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The parsed feature rows and their integer labels, in the order of `indexes`.

    Notes
    -----
    The original walked `df` in chunks and applied the *global* positions to each
    chunk-local array with ``mode='clip'``. That returned mostly unrelated rows
    (out-of-range positions were clipped to the chunk's last row) and repeated the
    selection once per chunk. Rows are now selected once, by position, from the whole
    frame; out-of-range positions raise IndexError instead of being clipped.
    """
    positions = np.asarray(indexes, dtype=int)
    rows = df.iloc[positions]

    X, y = read_chunk(rows)
    X = read_string_ndarray(X)
    # np.float was an alias of the builtin float and was removed in NumPy 1.24.
    X = np.array(X.tolist(), dtype=float)
    return X, y

In [269]:
def find_with_vae(vae, clusters, query, df, treshold, chunksize=100000, device=device):
    """Look up the label of the first stored vector similar enough to `query`.

    Clusters are visited from the highest to the lowest encoder logit for the query.
    Inside each cluster the member rows are fetched in chunks, and the best cosine
    match of a chunk is accepted as soon as its similarity exceeds `treshold`.
    The cluster ranking and each visited cluster index are printed.

    Parameters
    ----------
    vae : VAE
        Model used to rank the clusters; it is switched to eval mode.
    clusters : dict
        Maps a cluster index to the list of row positions in `df` assigned to it.
    query : array-like
        A single feature vector.
    df : pandas.DataFrame
        Table handed to `collect_by_indexes`.
    treshold : float
        Minimum cosine similarity for a match.
    chunksize : int
        Number of cluster members compared per step.
    device : torch.device
        Device on which the comparison runs.

    Returns
    -------
    torch.Tensor or None
        The label of the accepted row, or None when no chunk produced a match.
    """
    vae.eval()

    # CosineSimilarity has no parameters or buffers, so it needs no .to()/.eval().
    similiarity = torch.nn.CosineSimilarity(dim=1)

    with torch.no_grad():
        # Add the batch axis in NumPy first: building a tensor from a Python list that
        # wraps an ndarray goes through PyTorch's slow per-element conversion path.
        Q = torch.tensor(np.asarray(query)[None, ...], device=device, dtype=torch.float)
        treshold = torch.tensor(treshold, device=device)
        clusters_hat = vae.encode(Q)
        # Negate so that the ascending argsort lists the best-scoring cluster first.
        clusters_sorted = torch.argsort(-clusters_hat).cpu().numpy()[0]
        print(clusters_sorted)

        for cluster_idx in clusters_sorted:
            print(cluster_idx)
            cluster = clusters.get(cluster_idx)
            if not cluster:
                # No training row was assigned to this cluster.
                continue
            cluster = np.array(cluster)

            for indexes in chunk_iterator(cluster, chunksize=chunksize):
                X, y = collect_by_indexes(df, indexes, chunksize)

                X, y = to_tensor(X, y, device)

                sim = similiarity(Q, X)

                idx = torch.argmax(sim)

                if sim[idx] > treshold:
                    return y[idx]
        return None

In [263]:

# Second copy of the training table, read without the 'image' converter: the column
# stays as raw text here and is parsed later inside collect_by_indexes.
train_df = pd.read_csv(train_file, header=0)

In [273]:
# Read a single validation row (after skipping the first 50000 lines of the file) to
# use as the query vector.
df = pd.read_csv(validation_file,
                 chunksize=1,
                 skiprows=50000,
                 converters={'image': read_string_array}, names=['image', 'label'])
imgs, label = read_chunk(df.get_chunk())
# np.float was an alias of the builtin float and was removed in NumPy 1.24.
img = np.array(imgs.tolist()[0], dtype=float)

In [274]:
import time

In [275]:
# Confirm the query's true label occurs in the training table, then time one lookup
# with a similarity threshold of 0.7.
print(label[0] in train_df['label'].values)
start = time.time()
label_hat = find_with_vae(vae, clusters, img, train_df, 0.7, chunksize=10000)
end = time.time()
duration = end - start
print('Duration:', duration)
print('Pred:', label_hat, 'Real', label)

True
[ 94  30  71  84 125  26  63  90  39  67   4  10  93  48  24  31  49  78
   3 122  35 120  22 110  72  76  18  36 126  44 105 103  11 116 119  19
  61  83 121  89  29  88  97  17  98  81   6  23  12  33  57  15 124  60
  43  50  92   7  52  66 114  32  13  79 123  53  47 111  54 118  75 115
  14 117 127  64  99  86  38  58  34 112 104  85  80 107  96  62   9 106
   0  16  41  65  45  73  91  37  27 102  77 108  95  40 113 101  69  56
  87  51  20  82  28  70  46   5   2  55   8  42  68  74   1 109  25  59
  21 100]
94


  


30


KeyboardInterrupt: 

In [49]:
%tensorboard --logdir=runs --port=6007

Reusing TensorBoard on port 6007 (pid 9054), started 0:23:05 ago. (Use '!kill 9054' to kill it.)

In [264]:
# Scratch cell: compare soft and hard Gumbel-softmax samples on random logits.
logits = torch.randn(20, 32)
# Sample soft categorical using reparametrization trick:
y_soft = F.gumbel_softmax(logits, tau=1, hard=False)
print(y_soft)
# Sample hard categorical using "Straight-through" trick:
y_hard = F.gumbel_softmax(logits, tau=1, hard=True)
print(y_hard)
# Manual straight-through expression: the two y_soft terms cancel in value, so the
# result is numerically (up to rounding) equal to y_hard.
y_hard - y_soft.detach() + y_soft

tensor([[8.5398e-02, 1.9197e-02, 2.8779e-02, 4.8194e-03, 3.6361e-03, 2.4457e-02,
         5.3987e-03, 1.2014e-02, 1.4412e-01, 1.7968e-02, 4.1328e-03, 2.6410e-02,
         4.4057e-02, 4.2040e-03, 1.5978e-03, 1.0701e-02, 1.4340e-03, 9.4143e-03,
         4.7526e-03, 2.7632e-02, 2.3507e-03, 1.1498e-02, 6.8986e-02, 1.3603e-02,
         3.3304e-02, 5.4837e-02, 7.4567e-03, 5.5239e-02, 1.5254e-02, 2.0303e-01,
         1.6258e-02, 3.8065e-02],
        [1.4381e-02, 4.9327e-03, 3.4376e-02, 3.9673e-03, 7.7558e-02, 3.6949e-03,
         1.8732e-02, 3.0171e-02, 4.1198e-03, 1.4229e-02, 4.8411e-03, 1.3468e-01,
         3.5893e-03, 1.3703e-02, 1.7639e-02, 8.4737e-04, 6.7688e-04, 4.7508e-04,
         1.7213e-02, 1.2203e-01, 2.5529e-03, 2.3087e-03, 5.6765e-02, 3.4419e-01,
         2.1852e-03, 4.1514e-03, 5.6185e-03, 4.7198e-02, 9.6883e-04, 6.1324e-04,
         1.0351e-02, 1.2454e-03],
        [2.6342e-03, 3.0033e-03, 7.0458e-03, 1.9195e-02, 3.5504e-03, 2.0011e-03,
         1.4947e-03, 6.6767e-03, 1.5606e-

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 

In [623]:
x1 = torch.tensor([[0.9,0.2,0.1]], requires_grad=True)
x1

tensor([[0.9000, 0.2000, 0.1000]], requires_grad=True)

In [624]:
# The right-hand side of this assignment was missing, which made the cell a SyntaxError.
# NOTE(review): presumably a random 3x3 leaf tensor — the recorded output shows values
# inside (-0.5, 0.5) with requires_grad=True, matching the initialisation used for
# VAE.embeddings. Confirm against the original notebook.
x2 = torch.empty((3, 3)).uniform_(-0.5, 0.5).requires_grad_(True)
x2

tensor([[-0.2713, -0.3969,  0.0938],
        [-0.1533, -0.4739, -0.1860],
        [ 0.4577,  0.4374,  0.1955]], requires_grad=True)

In [625]:
# Scratch optimizer that updates both the logits (x1) and the embedding table (x2).
opt = torch.optim.Adam([x1, x2],
                    lr=0.0001)


In [710]:
opt.zero_grad()

In [711]:
# Same operation as VAE.reparametrize: a Gumbel-softmax sample over x1 mixes the rows of x2.
res = torch.mm(F.gumbel_softmax(x1, dim=1), x2)
res

tensor([[-0.0450, -0.2233,  0.0075]], grad_fn=<MmBackward>)

In [712]:
# Triplet-style objective without clamping: 1 - cos(target, res) + cos(-target, res),
# where target = [0.1, 0.2, 0.3].
loss = 1 - F.cosine_similarity(torch.tensor([[0.1,0.2,0.3]]), res) + F.cosine_similarity(torch.tensor([[-0.1,-0.2,-0.3]]), res)
loss

tensor([2.0999], grad_fn=<AddBackward0>)

In [713]:
loss.backward()

In [714]:
x1.grad

tensor([[-0.1138,  0.6112, -0.4975]])

In [715]:
opt.step()

In [716]:
x2

tensor([[-0.2704, -0.3961,  0.0949],
        [-0.1524, -0.4736, -0.1849],
        [ 0.4586,  0.4381,  0.1965]], requires_grad=True)

In [717]:
x1

tensor([[0.8997, 0.1990, 0.1009]], requires_grad=True)