In [1]:
from task2vec import Task2Vec
from models import get_model
import datasets
import task_similarity
import argparse
import torch
import torch.nn.functional as F
import torchvision.transforms as tt
from tqdm import tqdm

In [2]:
from synbols_utils import Synbols

In [3]:
saved = torch.load('all.pt')
seed = 123

In [4]:
import pandas as pd
ref = pd.read_csv('ref.csv', index_col = 0)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Add Task2Vec as part of the SingleHead Class

class AttentionNet(nn.Module):
    """Soft-Attention on the Embedding z"""
    
    def __init__(self, input_dim, latent_dim):
        super(Attention, self).__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.attention = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.ReLU(inplace=True), #FIXME Try Tanh as well
            nn.Linear(latent_dim,input_dim)
        )

    def forward(self, x):
        x = self.attention(x)
        return F.softmax(x)


class AttentionRawList(nn.Module):
    """Attention through Raw Params
            Input: Param shape, No of Params"""
    def __init__(self, shape, N):
        super(AttentionRaw, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn((shape))) for _ in range(N)])

    def __getitem__(self, index):
        return self.params[index]

    def forward(self, index, x):
        # Weighted/Atttention on x using Attention 'i'
        return self.params[index] * x 

class AttentionRaw(nn.Module):
    """Attention through Raw Params
            Input: Param shape, No of Params"""
    def __init__(self, shape):
        super(AttentionRaw, self).__init__()
        # Attention Params: Shape N_attn, M as N Attentions of dimension M 
        self.params = nn.Parameter(torch.randn(N, task2vec_dim), requires_grad = True)

    def forward(self, x):
        # Weighted/Atttention on x using Attention 'i'
        # X shape: N_z, M, Params shape: N_attn, M
        # Do Element wise multiplication
        return torch.mul(torch.unsqueeze(self.params, 1), x) 

In [6]:
Z = [z.hessian for z in saved]
task2vec_dim = Z[0].shape[0]
latent_dim = 128
N = 4 # No of attributes
M = len(Z) # No of Task vectors

In [7]:
Z_tensor = torch.tensor(Z)

In [8]:
attns = AttentionRaw((N, task2vec_dim)) #AttentionRawList(task2vec_dim, N)

In [9]:
attns_z_matrix = attns(Z_tensor)

import numpy as np
attns_query = torch.mm(attns.params, Z_tensor.T) / np.sqrt(1384)

attn_ = F.softmax(attns_query, 0)

Z_tensor.shape

batch_size = 1024
q_len = 32
v_len = 32
hidden_dim = 128
k_len = 32
query = torch.randn((batch_size, q_len, hidden_dim)) # q_len, hidden_dim @ hidden_dim, k_len 
key = torch.randn((batch_size, k_len, hidden_dim))
value = torch.randn((batch_size, v_len, hidden_dim))

score = torch.bmm(query, key.transpose(1, 2)) / np.sqrt(hidden_dim)
attn = F.softmax(score, -1)
context = torch.bmm(attn, value)

In [10]:

#Attentions.cuda()
Z_loader = torch.utils.data.DataLoader(Z, batch_size = len(Z))
optimizer = torch.optim.Adam(attns.parameters())

In [11]:
for _, batch in enumerate(Z_loader):
    break

In [12]:
attrs = attns(batch)


In [13]:
def _similarity_matrix(attrs, epsilon = 1e-8):
    norm = attrs.norm(dim = -1, keepdim = True)
    norm = torch.maximum(norm, 1e-8 * torch.ones(norm.shape))
    attrs = attrs / norm
    temp = attrs.view(attrs.shape[0], attrs.shape[1],1,1,attrs.shape[2])
    similarity_matrix = torch.mul(temp,attrs).sum(axis = -1)
    return similarity_matrix

In [14]:
similarity_matrix = _similarity_matrix(attrs)

## Sample 

In [15]:
Attns = [[0.5, 0.5], [0.2, 0.8], [0.1, 0.9]] # Three Attns 3x2
Z = [[1,2],[3,4],[5,6],[7,8],[9,10]] # 5 z vectors. 5 x 2
Attns = torch.tensor(Attns)
Z = torch.tensor(Z)
print(f"Attns: {Attns.shape}, Z: {Z.shape}")

Attns: torch.Size([3, 2]), Z: torch.Size([5, 2])


In [16]:
def forward_pass(Attns, Z):
    return torch.mul(torch.unsqueeze(Attns, 1), Z) 

In [17]:
Attrs = forward_pass(Attns, Z)
Attrs.shape

torch.Size([3, 5, 2])

In [54]:
def cos_sim(A,B, eps=1e-8, ):
    A_norm = A.norm(dim = -1, keepdim = True)
    B_norm = B.norm(dim = -1, keepdim = True)
    A_norm = torch.maximum(A_norm, eps * torch.ones(A_norm.shape))
    B_norm = torch.maximum(B_norm, eps * torch.ones(B_norm.shape))
    A = torch.div(A,A_norm)
    B = torch.div(B,B_norm)
    return torch.mm(A,B.T)
cos_sim(Attrs[0],Attrs[0]).flatten().contiguous()

In [76]:
def norm_unitvectors(M, eps = 1e-8):
    M_norm = M.norm(dim = -1, keepdim = True)
    M_norm = torch.maximum(M_norm, eps * torch.ones(M_norm.shape))
    return torch.div(M, M_norm)
def positives(Attrs):
    M = norm_unitvectors(Attrs)
    M_trans = M.transpose(1,2)
    return torch.bmm(M, M_trans)

tensor([1.0000, 0.9839, 0.9734, 0.9676, 0.9640, 0.9839, 1.0000, 0.9987, 0.9972,
        0.9960, 0.9734, 0.9987, 1.0000, 0.9997, 0.9993, 0.9676, 0.9972, 0.9997,
        1.0000, 0.9999, 0.9640, 0.9960, 0.9993, 0.9999, 1.0000])

In [45]:
# Positives
# sim(Attr[i], Attr[j]) = 1 | -1 
# |sim(Attr[i], Attr[j])| = 1
# for all i, and i = j

# Negatives or Zero
#   Negatives:
#       sim(Attr[i], Attr[j]) for all i=/j

# or induce Orthogonality as   Zero:
#   sim(Attr[i], Attr[j]) = 0



def positives(Attrs):
    """
    """
    

tensor([0.9430, 0.9615, 0.9668, 0.9693, 0.9708, 0.8682, 0.8969, 0.9056, 0.9097,
        0.9122, 0.8417, 0.8730, 0.8826, 0.8873, 0.8900, 0.8284, 0.8610, 0.8711,
        0.8759, 0.8788, 0.8205, 0.8538, 0.8641, 0.8691, 0.8720])

In [59]:
cos_matrix = _similarity_matrix(Attrs)
cos_matrix.shape

torch.Size([3, 5, 3, 5])

In [64]:
cos_matrix[0][:][0][:]

tensor([[1.0000, 0.9839, 0.9734, 0.9676, 0.9640],
        [0.9430, 0.9615, 0.9668, 0.9693, 0.9708],
        [0.9179, 0.9285, 0.9318, 0.9335, 0.9345]])

In [25]:
i = 0
j = 6
k = 6
similarity_matrix[i][:][i][:].shape

torch.Size([4, 135])

In [None]:
def positives(sim_matrix, i):
    """
        For all j,k: do summation of cos similarities across matrix for given i
        
        Note: mat[i][j][i][k] represents sim(att[i][j], att[i][k])
        So, for i'th attribute, positives are:
            SUM_j,k mat[i][j][i][k]
    """

In [None]:
def alignment_loss(similarity_matrix):
    

In [None]:
for epoch in range(n_epochs):
    print(f"Epoch:{epoch}")
    attentions.train()
    for batch_ix, Z in enumerate(Z_loader):
        optimizer.zero_grad()
        

In [None]:

for epoch in range(args.epochs):
    print("Epoch", epoch)
    model.train()
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        x, y = batch
        x = x.cuda()
        y = y.cuda()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()

    print("train_loss", float(loss))
    with torch.no_grad():
        model.eval()
        hits = 0
        total = 0
        for batch in tqdm(val_loader):
            x, y = batch
            x = x.cuda()
            y = y.cuda()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            hits += (logits.argmax(-1) == y).float().sum().item()
            total += x.size(0)

    print("val_loss", float(loss), "val_accuracy", hits / total)
