In [18]:
"""LSTM-based textual encoder for tokenized input"""

from typing import Any

import torch
from torch import nn


class TextEncoder(nn.Module):
    """Bidirectional-LSTM encoder that turns token ids into word and sentence features."""

    def __init__(self, vocab_size: int, emb_dim: int, hidden_dim: int) -> None:
        """
        Create the token embedding table and the bidirectional LSTM.

        :param vocab_size: Number of entries in the embedding table.
        :param emb_dim: Size of the vector looked up for each token.
        :param hidden_dim:
            Hidden size of the LSTM *per direction*. The LSTM is bidirectional,
            so the emitted feature size is C = 2 x hidden_dim.
        """
        super().__init__()
        self.embs = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, tokens: torch.Tensor) -> Any:
        """
        Encode a batch of token ids.

        :param torch.Tensor tokens: Token ids, batch first (B x L).
        :return:
            Tuple ``(word_embs, sent_embs)``: per-token features laid out as
            B x C x L, and one B x C vector per sequence built from the final
            hidden states of the two directions.
        :rtype: Any
        """
        token_vectors = self.embs(tokens)
        per_step, (final_hidden, _cell) = self.lstm(token_vectors)

        # LSTM output is B x L x C (batch_first); move features ahead of time steps.
        word_embs = per_step.transpose(1, 2)

        # final_hidden stacks the directions on dim 0: the last slice goes first
        # in the concatenation, followed by the first slice.
        last_slice = final_hidden[-1]
        first_slice = final_hidden[0]
        sent_embs = torch.cat([last_slice, first_slice], dim=1)
        return word_embs, sent_embs


In [19]:
# Toy configuration used to exercise TextEncoder on random token ids.
vocab_size = 5000
emb_dim = 300

batch_size = 32
max_seq_len = 18
# Feature size of the encoder output (the bi-LSTM emits 2 * hidden_dim = D).
D = 256
# gamma1 scales the attention scores before the second softmax in func_attention.
gamma1 = 4 # for both bird, coco
# gamma2 scales the per-word cosine similarities before exp/sum/log.
gamma2 = 5 # for both bird, coco
# gamma3 is not used by any code visible in this notebook chunk.
gamma3 = 10 # for both bird, coco
# Number of words (query length) used when reshaping the similarity rows.
queryL = 10

# Per-direction LSTM hidden size.
hidden_dim = D // 2

text_encoder = TextEncoder(vocab_size, emb_dim, hidden_dim)
# Random token ids in [0, vocab_size), shape (batch_size, max_seq_len).
input_tokens = torch.randint(0, vocab_size, (batch_size, max_seq_len))
input_tokens.shape

torch.Size([32, 18])

In [20]:
input_tokens[0]

tensor([ 139, 1053, 2849, 4876, 2027,  865, 3506,  620,  323, 3078, 3034, 4910,
        3889,  999, 2298, 3799, 2169, 2054])

In [21]:
word_embs, sent_embs = text_encoder(input_tokens)
word_embs.shape, sent_embs.shape

(torch.Size([32, 256, 18]), torch.Size([32, 256]))

In [22]:
word = word_embs[0, :, :10].unsqueeze(0).contiguous()
word.shape

torch.Size([1, 256, 10])

In [23]:
word = word.repeat(batch_size, 1, 1)
word.shape

torch.Size([32, 256, 10])

## func_attention

In [24]:
def func_attention(word, img_features=None, gamma=None):
    """Attend over image regions with word features and return region contexts.

    For every word a distribution over the image regions is computed and used
    to average the region features, giving one context vector per word.

    The scores are normalised twice: first over the words for each region,
    then (after scaling by ``gamma``) over the regions for each word.

    :param word: Word features, shape (batch, D, queryL).
    :param img_features:
        Region features, shape (batch, D, H, W). When omitted, random
        (batch, D, 17, 17) features are drawn, which is the placeholder the
        notebook experiments rely on.
    :param gamma:
        Scale applied to the scores before the second softmax. When omitted,
        the module-level ``gamma1`` is used.
    :return: Word-aligned region contexts, shape (batch, D, queryL).
    """
    # queryL is the length of the word sequence; the feature size is read from
    # `word` itself so the function no longer depends on a global `D`.
    batch_size, feat_dim, queryL = word.size(0), word.size(1), word.size(2)

    if img_features is None:
        img_features = torch.randn(batch_size, feat_dim, 17, 17)
    if gamma is None:
        gamma = gamma1

    ih, iw = img_features.size(2), img_features.size(3)
    sourceL = ih * iw  # number of image regions, N

    context = img_features.reshape(batch_size, -1, sourceL)  # batch x D x N
    contextT = context.transpose(1, 2)  # batch x N x D

    attn = contextT @ word  # batch x N x queryL
    # First normalisation: across the words, separately for every region.
    attn = torch.softmax(attn, dim=2)
    attn = attn.transpose(1, 2)  # batch x queryL x N
    # Second normalisation: across the regions, separately for every word.
    attn = torch.softmax(attn * gamma, dim=2)
    attnT = attn.transpose(1, 2)  # batch x N x queryL

    weighted_context = context @ attnT  # batch x D x queryL
    return weighted_context

In [46]:
def cosine_sim(x1, x2, dim=1, eps=1e-8):
    """Cosine similarity between ``x1`` and ``x2`` along ``dim``.

    :param x1: First tensor (floating point).
    :param x2: Second tensor, broadcastable against ``x1``.
    :param dim: Dimension that holds the feature vectors.
    :param eps: Lower bound for the product of the two norms, which avoids a
        division by zero when either vector is all zeros.
    :return: Similarities with ``dim`` reduced and every remaining size-1
        dimension squeezed away.
    """
    dot = torch.sum(x1 * x2, dim)
    # torch.norm is deprecated; torch.linalg.vector_norm is its replacement for
    # vector 2-norms.
    norm1 = torch.linalg.vector_norm(x1, ord=2, dim=dim)
    norm2 = torch.linalg.vector_norm(x2, ord=2, dim=dim)
    return (dot / (norm1 * norm2).clamp(min=eps)).squeeze()

In [45]:
# Scratch version of one iteration of the word/region similarity computation.
# NOTE(review): `weighted_context` and `sim_list` are not defined in this cell;
# it only runs after another cell has created them. It also reassigns `word`
# and `weighted_context`, so re-running the cell on its own results changes
# their layout again.

# Put the word axis before the feature axis.
word = word.transpose(1, 2).contiguous() #batch x queryL x D
weighted_context = weighted_context.transpose(1, 2).contiguous() #batch x queryL x D

# Flatten batch and word axes so each row is one (word, context) pair.
word = word.view(batch_size * queryL, -1) #batch*queryL x D
weighted_context = weighted_context.view(batch_size * queryL, -1) #batch*queryL x D

# One cosine similarity per (word, context) pair.
row_sim = cosine_sim(word, weighted_context)
row_sim = row_sim.view(batch_size, queryL) #batch x queryL

# log(sum_over_words(exp(gamma2 * similarity))) for every row.
row_sim = torch.exp(row_sim * gamma2) #batch x queryL
row_sim = row_sim.sum(dim = 1, keepdim = True) #batch x 1

row_sim = torch.log(row_sim) #batch x 1
sim_list.append(row_sim)

## with batch loop and combined

In [26]:
# End-to-end toy run: encode random captions, then build a (batch x batch)
# matrix of word/region similarity scores, one column per caption.
vocab_size = 5000
emb_dim = 300

batch_size = 4
max_seq_len = 18
D = 256
gamma1 = 4 # for both bird, coco
gamma2 = 5 # for both bird, coco
gamma3 = 10 # for both bird, coco
# Number of leading words of each caption that take part in the comparison.
# Defined here so the cell does not rely on a value left over from another
# cell, and so the slice and the reshapes below always agree.
queryL = 10

hidden_dim = D // 2

text_encoder = TextEncoder(vocab_size, emb_dim, hidden_dim)
input_tokens = torch.randint(0, vocab_size, (batch_size, max_seq_len))

word_embs, sent_embs = text_encoder(input_tokens)

sim_list = []

for batch_idx in range(batch_size):
    # Words of caption `batch_idx`, copied once per batch entry.
    word = word_embs[batch_idx, :, :queryL].unsqueeze(0).contiguous()
    word = word.repeat(batch_size, 1, 1) #batch x D x queryL

    weighted_context = func_attention(word) #batch x D x queryL

    word = word.transpose(1, 2).contiguous() #batch x queryL x D
    weighted_context = weighted_context.transpose(1, 2).contiguous() #batch x queryL x D

    word = word.view(batch_size * queryL, -1) #batch*queryL x D
    weighted_context = weighted_context.view(batch_size * queryL, -1) #batch*queryL x D

    row_sim = cosine_sim(word, weighted_context)
    row_sim = row_sim.view(batch_size, queryL) #batch x queryL

    # log(sum_over_words(exp(gamma2 * similarity)))
    row_sim = torch.exp(row_sim * gamma2) #batch x queryL
    row_sim = row_sim.sum(dim = 1, keepdim = True) #batch x 1

    row_sim = torch.log(row_sim) #batch x 1
    sim_list.append(row_sim)

similarities = torch.cat(sim_list, dim = 1) #batch x batch
similarities.shape


torch.Size([4, 4])

In [73]:
print(class_ids)

tensor([[10],
        [10],
        [ 8],
        [ 1]])


In [72]:
import numpy as np
# One random class id in 1..10 per sample, shape (batch_size, 1).
class_ids = torch.randint(1, 11, (batch_size, 1))

# Row i will flag the OTHER samples that share sample i's class id.
masks = []

for i in range(batch_size):
    # 1 wherever a sample has the same class id as sample i (shape batch_size x 1).
    mask = (class_ids == class_ids[i]).numpy().astype(np.uint8)
    # Sample i always matches itself; clear that entry so it is not flagged.
    mask[i] = 0
    # Store as a single row of length batch_size.
    masks.append(mask.reshape(1, -1))

masks[0]

array([[0, 1, 0, 0]], dtype=uint8)

In [75]:
masks

[array([[0, 1, 0, 0]], dtype=uint8),
 array([[1, 0, 0, 0]], dtype=uint8),
 array([[0, 0, 0, 0]], dtype=uint8),
 array([[0, 0, 0, 0]], dtype=uint8)]

In [76]:
masks = np.concatenate(masks, axis = 0)
masks.shape

(4, 4)

In [78]:
masks = torch.BoolTensor(masks)
masks

tensor([[False,  True, False, False],
        [ True, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])

In [79]:
similarities.shape

torch.Size([4, 4])

In [80]:
# Overwrite, in place, every entry where `masks` is True with -inf so that it
# gets zero weight in a subsequent softmax / cross-entropy.
# NOTE(review): the write goes through `.data`, presumably to keep the fill out
# of the autograd graph — confirm this is intended.
similarities.data.masked_fill_(masks, -float('inf'))

tensor([[4.2583,   -inf, 4.2598, 4.4091],
        [  -inf, 4.3158, 4.1988, 4.1712],
        [4.2803, 4.2669, 4.1280, 4.3717],
        [4.4219, 4.3610, 4.1831, 4.2365]])

In [81]:
# Row i of `similarities` should score highest in column i, so the target
# class of row i is simply i. torch.arange yields the same int64 indices as
# the legacy torch.LongTensor(range(...)) constructor.
labels = torch.arange(batch_size)
loss0 = nn.CrossEntropyLoss()(similarities, labels)
loss0

tensor(1.2860, grad_fn=<NllLossBackward0>)

In [82]:
labels

tensor([0, 1, 2, 3])

In [61]:
labels.shape

torch.Size([32])

In [54]:
similarities.shape

torch.Size([32, 32])

In [89]:
class_ids == class_ids[0]

tensor([[ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [ True],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False],
        [False]])

In [85]:
import torch
# Toy logits: random scores, some entries masked to -inf, diagonal forced to 0.
a = torch.randn((batch_size, batch_size))
classes = torch.randint(1, 11, (batch_size, 1))
# Random boolean mask (True where a standard-normal draw exceeds 0.5).
masks = torch.randn((batch_size, batch_size)) > 0.5
# `a` does not track gradients, so the in-place fill can be applied directly.
a.masked_fill_(masks, -float('inf'))
# Reset the diagonal (the target entries) to 0, including any masked ones.
a.fill_diagonal_(0)
a

tensor([[ 0.0000,    -inf,  0.2741, -0.6443],
        [ 0.0257,  0.0000,  2.5821, -1.0813],
        [-1.8071,  0.1304,  0.0000,  0.9694],
        [-0.2055,    -inf,    -inf,  0.0000]])

In [87]:
a.shape

torch.Size([4, 4])

In [88]:
# Target class of row i is column i (indices 0 .. batch_size - 1, int64).
true_label = torch.tensor(range(batch_size), dtype=torch.long)
#init cross-entropy loss
criterion = torch.nn.CrossEntropyLoss()
# `a` is used as the logits matrix, one row per sample.
loss0 = criterion(a, true_label)
loss0

tensor(1.4959)

In [93]:
preds = torch.tensor([[2, 3, 5], [4, 0, -1]], dtype = torch.float32) #batch x dim
labels = [2, 0] #batch

loss = torch.nn.CrossEntropyLoss()(preds, torch.tensor(labels))
loss

tensor(0.0973)

In [94]:
a = torch.tensor([0, 2, -float("inf")])
sm = nn.Softmax(dim = 0)
a

tensor([0., 2., -inf])

In [105]:
sm_out = sm(a)
sm_out = sm_out.unsqueeze(0)
sm_out

tensor([[0.1192, 0.8808, 0.0000]])

In [107]:
# The target index 2 is the entry whose logit was -inf, so its softmax
# probability in `sm_out` is 0; log(0) is -inf and the resulting loss is +inf.
nll = nn.NLLLoss()
nll(torch.log(sm_out), torch.tensor([2]))

tensor(inf)

In [118]:
# Imported here because this cell uses pickle before the later cell that
# imports it; without this the cell fails with NameError on a fresh kernel.
import pickle

file_path = "../../Repo/src/data/birds/train/filenames.pickle"

# NOTE(security): pickle.load can execute arbitrary code from the file; only
# load pickles that come from a trusted source.
with open(file_path, "rb") as f:
    filenames = pickle.load(f)

filenames[:5]


['002.Laysan_Albatross/Laysan_Albatross_0002_1027',
 '002.Laysan_Albatross/Laysan_Albatross_0003_1033',
 '002.Laysan_Albatross/Laysan_Albatross_0082_524',
 '002.Laysan_Albatross/Laysan_Albatross_0044_784',
 '002.Laysan_Albatross/Laysan_Albatross_0070_788']

In [116]:
path = "../../Repo/src/data/birds/train/class_info.pickle"
import pickle
with open(path, 'rb') as f:
    class_info = pickle.load(f, encoding='latin1')

class_info[:5]

[2, 2, 2, 2, 2]

In [111]:
len(class_info)

8855

In [113]:
# Distinct class ids that occur in class_info.
classes = set(class_info)

In [115]:
len(classes)

150

In [4]:
import torch
import numpy as np
batch_size = 4
class_ids = torch.randint(1, 11, (batch_size, 1))
mask = (class_ids == class_ids[0])
mask

tensor([[ True],
        [False],
        [False],
        [ True]])

In [5]:
#convert bool tensor mask to torch.int
mask = mask.type(torch.int)
mask

tensor([[1],
        [0],
        [0],
        [1]], dtype=torch.int32)

In [14]:
batch_size = 2
D = 2
numb_words = 3
actual_words = torch.randn(1, D, numb_words)
actual_words

tensor([[[-1.2904,  2.0137,  0.9687],
         [-1.4123,  0.5085, -0.4709]]])

In [15]:
actual_words_ex = actual_words.expand(batch_size, -1, -1) # shape: (batch, D, numb_words)
actual_words_ex

tensor([[[-1.2904,  2.0137,  0.9687],
         [-1.4123,  0.5085, -0.4709]],

        [[-1.2904,  2.0137,  0.9687],
         [-1.4123,  0.5085, -0.4709]]])

In [16]:
actual_words_ex.shape

torch.Size([2, 2, 3])

In [17]:
actual_words.shape

torch.Size([1, 2, 3])

In [62]:
def mean_squared_error(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Average of the element-wise squared differences between two tensors.

    Args:
        input: Tensor of predicted values.
        target: Tensor of reference values to compare against.

    Returns:
        A scalar tensor holding the mean over all elements.
    """
    return nn.functional.mse_loss(input, target, reduction="mean")

In [63]:
real = torch.randn(4, 128, 128, 128)
fake = torch.randn(4, 128, 128, 128)

mean_squared_error(real, fake)

tensor(1.9986)

In [64]:
import torch.nn.functional as F
F.mse_loss(real, fake)

tensor(1.9986)

In [1]:
import torch
batch = 4
D = 10
global_incept_feat = torch.randn(1, batch, D)
incept_feat_norm = torch.norm(global_incept_feat, 2, dim=2, keepdim=True)
incept_feat_norm

tensor([[[2.9519],
         [3.3639],
         [2.8623],
         [2.7975]]])

In [2]:
incept_feat_norm.shape

torch.Size([1, 4, 1])

In [4]:
new_norm = torch.linalg.norm(global_incept_feat, ord = 2, dim = 2, keepdim = True)
new_norm

tensor([[[2.9519],
         [3.3639],
         [2.8623],
         [2.7975]]])

In [5]:
new_norm.shape

torch.Size([1, 4, 1])

In [3]:
help(torch.norm)

Help on function norm in module torch.functional:

norm(input, p='fro', dim=None, keepdim=False, out=None, dtype=None)
    Returns the matrix norm or vector norm of a given tensor.
    
    
        torch.norm is deprecated and may be removed in a future PyTorch release.
        Its documentation and behavior may be incorrect, and it is no longer
        actively maintained.
    
        Use :func:`torch.linalg.norm`, instead, or :func:`torch.linalg.vector_norm`
        when computing vector norms and :func:`torch.linalg.matrix_norm` when
        computing matrix norms. Note, however, the signature for these functions
        is slightly different than the signature for torch.norm.
    
    Args:
        input (Tensor): The input tensor. Its data type must be either a floating
            point or complex type. For complex inputs, the norm is calculated using the
            absolute value of each element. If the input is complex and neither
            :attr:`dtype` nor :attr:`out` is

In [9]:
batch = 5
# One random caption length in 0..3 per sample; the shape follows `batch`
# instead of repeating the literal.
cap_len = torch.randint(0, 4, (batch, 1))


In [12]:
cap_len

tensor([[3],
        [0],
        [0],
        [2],
        [1]])

In [13]:
numb_words = cap_len[0]
numb_words

tensor([3])

In [14]:
b = torch.arange(16).reshape(4, 4)
b

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [15]:
b[2:, :numb_words]

tensor([[ 8,  9, 10],
        [12, 13, 14]])