## Purpose
Reusable pytorch code

In [3]:
import torch
import torch.nn as nn

In [6]:
# mostly copied from fastai v2
class LinBnDrop(nn.Sequential):
    "Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers"

    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        layers = [nn.BatchNorm1d(n_out if lin_first else n_in)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        lin = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: lin.append(act)
        layers = lin + layers if lin_first else layers + lin
        super().__init__(*layers)

In [32]:
def embed_columns(inp, dim, index, emb):
    """
    Replace columns with their embeddings. Works only with 2-d tensors.
    TODO - make it work for multi-dim tensors

    :param inp: tensor of two or more dimensions
    :param dim: dimension along which tensor should be expanded by inserting the embedding
    :param i: index of tensor along dim which is to be embedded
    :param emb: Embedding of shape [v,d], where v vocab_size and d is embedding dimension
    :return: 
    """
    # create a slice of the data to be replaced with embedding. 
    s = inp.index_select(dim, torch.tensor([index])).squeeze(dim)
    print(f's.shape: {s.shape}')
    embedded = emb(s.type(torch.long))
    print('embedded.shape - ', embedded.shape)
    
    first_indices = torch.arange(0,index)
    last_indices = torch.arange(index+1,inp.size(dim))
    print(first_indices, last_indices)
    return torch.cat([inp.index_select(dim, first_indices), embedded.type(inp.dtype), inp.index_select(dim, last_indices)], axis=dim)


# example
d = torch.tensor([[1,1,1],[0,1,1],[0,0,1]],dtype=torch.float)
print(f'd.shape - {d.shape}')
emb = nn.Embedding(2,3)
embed_columns(d, 0, 1, emb)
# print(f'output.shape - {embed_columns(d, 1,0, emb).shape}')

d.shape - torch.Size([3, 3])
s.shape: torch.Size([3])
embedded.shape -  torch.Size([3, 3])
tensor([0]) tensor([2])


tensor([[ 1.0000,  1.0000,  1.0000],
        [-1.1694,  1.6613,  0.4667],
        [ 0.7971, -1.2839,  2.0029],
        [ 0.7971, -1.2839,  2.0029],
        [ 0.0000,  0.0000,  1.0000]], grad_fn=<CatBackward>)

In [5]:
import torch
import torch.nn as nn
# t = torch.arange()

In [13]:
d = torch.tensor([[1,1,1],[0,1,1],[0,0,1]],dtype=torch.float)


tensor([[1., 1., 1.],
        [0., 1., 1.],
        [0., 0., 1.]])

In [17]:
indices = torch.tensor(1)
d.index_select(0,indices)

tensor([[0., 1., 1.]])

In [16]:
torch.tensor(3)

tensor(3)