In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torch.optim as opt
import itertools

from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from nltk.tokenize import word_tokenize

from allennlp.modules.elmo import Elmo, batch_to_ids

import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/glebn/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
class SharedEncoder(nn.Module):
    def __init__(self, glove_matrix, elmo_dim, hidden_size, num_layers=2):
        super(SharedEncoder, self).__init__()
        vocab_size, glove_dim = glove_matrix.shape
        self.glove_embedding = nn.Embedding.from_pretrained(torch.tensor(glove_matrix, dtype=torch.float), freeze=False)
        self.elmo_dim = elmo_dim
        self.bilstm = nn.LSTM(glove_dim + elmo_dim, hidden_size, num_layers=num_layers, bidirectional=True, batch_first=True)

    def forward(self, glove_inputs, elmo_inputs, lengths):
        glove_embedded = self.glove_embedding(glove_inputs)
        combined_embeddings = torch.cat([glove_embedded, elmo_inputs], dim=-1)
        packed = pack_padded_sequence(combined_embeddings, lengths, batch_first=True, enforce_sorted=False)
        packed_output, _ = self.bilstm(packed)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        pooled = torch.max(output, dim=1)[0]
        return pooled

In [None]:
def load_glove_embeddings(filepath, embedding_dim):
    embeddings_index = {}
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            if len(values) != embedding_dim + 1:
                continue
            word = values[0]
            embedding = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = embedding
    return embeddings_index

glove_filepath = './glove.840B.300d.txt'
embedding_dim = 300
glove_embeddings = load_glove_embeddings(glove_filepath, embedding_dim)

In [5]:
def create_embedding_matrix(vocab, glove_embeddings, embedding_dim):
    vocab_size = len(vocab)
    embedding_matrix = np.zeros((vocab_size, embedding_dim))
    for word, idx in vocab.items():
        embedding_vector = glove_embeddings.get(word)
        if embedding_vector is not None:
            embedding_matrix[idx] = embedding_vector
        else:
            embedding_matrix[idx] = np.random.normal(scale=0.6, size=(embedding_dim,))
    return torch.tensor(embedding_matrix, dtype=torch.float)

In [6]:
elmo_options_file = './elmo_options.json'
elmo_weight_file = './elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5'
elmo = Elmo(elmo_options_file, elmo_weight_file, num_output_representations=1, dropout=0)

In [7]:
import pickle

with open("balanced_datasets_2.pkl", "rb") as file:
    loaded_datasets = pickle.load(file)

train_msr = loaded_datasets["train_msr"]
train_rte = loaded_datasets["train_rte"]
train_qnli = loaded_datasets["train_qnli"]
train_qqp = loaded_datasets["train_qqp"]
train_mnli = loaded_datasets["train_mnli"]
train_sst = loaded_datasets["train_sst"]
train_cola = loaded_datasets["train_cola"]

In [8]:
def preprocess_sentence_with_embeddings(sentence, tokenizer, glove_embeddings, elmo):
    tokens = tokenizer(sentence)
    glove_vectors = []
    for token in tokens:
        if token in glove_embeddings:
            glove_vectors.append(glove_embeddings[token])
        else:
            glove_vectors.append(np.zeros(300))
    character_ids = batch_to_ids([tokens])
    elmo_output = elmo(character_ids)
    elmo_vectors = elmo_output['elmo_representations'][0].detach().numpy()[0]
    concatenated_embeddings = np.concatenate([glove_vectors, elmo_vectors], axis=1)
    return concatenated_embeddings

In [9]:
class GLUEDataset(Dataset):
    def __init__(self, dataframe, tokenizer, glove_embeddings, elmo, max_length=50):
        self.sentences = dataframe['sentence'].values
        self.labels = dataframe['label'].values
        self.tokenizer = tokenizer
        self.glove_embeddings = glove_embeddings
        self.elmo = elmo
        self.max_length = max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        label = self.labels[idx]
        embeddings = preprocess_sentence_with_embeddings(sentence, self.tokenizer, self.glove_embeddings, self.elmo)
        if embeddings.shape[0] > self.max_length:
            embeddings = embeddings[:self.max_length]
        else:
            pad_length = self.max_length - embeddings.shape[0]
            pad_vector = np.zeros((pad_length, embeddings.shape[1]))
            embeddings = np.vstack([embeddings, pad_vector])
        return {
            'embeddings1': torch.tensor(embeddings, dtype=torch.float),
            'label': torch.tensor(label, dtype=torch.long)
        }
        
class GLUEDatasetTwoSentence(Dataset):
    def __init__(self, dataframe, tokenizer, glove_embeddings, elmo, max_length=50):
        self.sent1 = dataframe['sentence1'].values
        self.sent2 = dataframe['sentence2'].values
        self.labels = dataframe['label'].values
        self.tokenizer = tokenizer
        self.glove_embeddings = glove_embeddings
        self.elmo = elmo
        self.max_length = max_length

    def __len__(self):
        return len(self.sent1)

    def __getitem__(self, idx):
        sentence1 = self.sent1[idx]
        sentence2 = self.sent2[idx]
        label = self.labels[idx]

        # Process embeddings for both sentences
        embeddings1 = preprocess_sentence_with_embeddings(sentence1, self.tokenizer, self.glove_embeddings, self.elmo)
        embeddings2 = preprocess_sentence_with_embeddings(sentence2, self.tokenizer, self.glove_embeddings, self.elmo)

        # Ensure max length by padding or truncating
        embeddings1 = self._pad_or_truncate(embeddings1)
        embeddings2 = self._pad_or_truncate(embeddings2)

        # Concatenate embeddings for both sentences
        # combined_embeddings = np.concatenate([embeddings1, embeddings2], axis=0)

        return {
            'embeddings1': torch.tensor(embeddings1, dtype=torch.float),
            'embeddings2': torch.tensor(embeddings2, dtype=torch.float),
            'label': torch.tensor(label, dtype=torch.long)
        }

    def _pad_or_truncate(self, embeddings):
        if embeddings.shape[0] > self.max_length:
            return embeddings[:self.max_length]
        else:
            pad_length = self.max_length - embeddings.shape[0]
            pad_vector = np.zeros((pad_length, embeddings.shape[1]))
            return np.vstack([embeddings, pad_vector])
        
class GLUEDatasetTwoSentenceSTS(Dataset):
    def __init__(self, dataframe, tokenizer, glove_embeddings, elmo, max_length=50):
        self.sent1 = dataframe['sentence1'].values
        self.sent2 = dataframe['sentence2'].values
        self.labels = dataframe['score'].values
        self.tokenizer = tokenizer
        self.glove_embeddings = glove_embeddings
        self.elmo = elmo
        self.max_length = max_length

    def __len__(self):
        return len(self.sent1)

    def __getitem__(self, idx):
        sentence1 = self.sent1[idx]
        sentence2 = self.sent2[idx]
        label = self.labels[idx]

        # Process embeddings for both sentences
        embeddings1 = preprocess_sentence_with_embeddings(sentence1, self.tokenizer, self.glove_embeddings, self.elmo)
        embeddings2 = preprocess_sentence_with_embeddings(sentence2, self.tokenizer, self.glove_embeddings, self.elmo)

        # Ensure max length by padding or truncating
        embeddings1 = self._pad_or_truncate(embeddings1)
        embeddings2 = self._pad_or_truncate(embeddings2)

        # Concatenate embeddings for both sentences
        # combined_embeddings = np.concatenate([embeddings1, embeddings2], axis=0)

        return {
            'embeddings1': torch.tensor(embeddings1, dtype=torch.float),
            'embeddings2': torch.tensor(embeddings2, dtype=torch.float),
            'label': torch.tensor(label, dtype=torch.float)
        }

    def _pad_or_truncate(self, embeddings):
        if embeddings.shape[0] > self.max_length:
            return embeddings[:self.max_length]
        else:
            pad_length = self.max_length - embeddings.shape[0]
            pad_vector = np.zeros((pad_length, embeddings.shape[1]))
            return np.vstack([embeddings, pad_vector])


In [10]:
class GLUEDatasetTest(Dataset):
    def __init__(self, dataframe, tokenizer, glove_embeddings, elmo, max_length=50):
        self.sentences = dataframe['sentence'].values
        # self.labels = dataframe['label'].values
        self.tokenizer = tokenizer
        self.glove_embeddings = glove_embeddings
        self.elmo = elmo
        self.max_length = max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        # label = self.labels[idx]
        embeddings = preprocess_sentence_with_embeddings(sentence, self.tokenizer, self.glove_embeddings, self.elmo)
        if embeddings.shape[0] > self.max_length:
            embeddings = embeddings[:self.max_length]
        else:
            pad_length = self.max_length - embeddings.shape[0]
            pad_vector = np.zeros((pad_length, embeddings.shape[1]))
            embeddings = np.vstack([embeddings, pad_vector])
        return {
            'embeddings1': torch.tensor(embeddings, dtype=torch.float),
            # 'label': torch.tensor(label, dtype=torch.long)
        }
        
class GLUEDatasetTwoSentenceTest(Dataset):
    def __init__(self, dataframe, tokenizer, glove_embeddings, elmo, max_length=50):
        self.sent1 = dataframe['sentence1'].values
        self.sent2 = dataframe['sentence2'].values
        # self.labels = dataframe['label'].values
        self.tokenizer = tokenizer
        self.glove_embeddings = glove_embeddings
        self.elmo = elmo
        self.max_length = max_length

    def __len__(self):
        return len(self.sent1)

    def __getitem__(self, idx):
        sentence1 = self.sent1[idx]
        sentence2 = self.sent2[idx]
        # label = self.labels[idx]

        # Process embeddings for both sentences
        embeddings1 = preprocess_sentence_with_embeddings(sentence1, self.tokenizer, self.glove_embeddings, self.elmo)
        embeddings2 = preprocess_sentence_with_embeddings(sentence2, self.tokenizer, self.glove_embeddings, self.elmo)

        # Ensure max length by padding or truncating
        embeddings1 = self._pad_or_truncate(embeddings1)
        embeddings2 = self._pad_or_truncate(embeddings2)

        # Concatenate embeddings for both sentences
        # combined_embeddings = np.concatenate([embeddings1, embeddings2], axis=0)

        return {
            'embeddings1': torch.tensor(embeddings1, dtype=torch.float),
            'embeddings2': torch.tensor(embeddings2, dtype=torch.float),
            # 'label': torch.tensor(label, dtype=torch.long)
        }

    def _pad_or_truncate(self, embeddings):
        if embeddings.shape[0] > self.max_length:
            return embeddings[:self.max_length]
        else:
            pad_length = self.max_length - embeddings.shape[0]
            pad_vector = np.zeros((pad_length, embeddings.shape[1]))
            return np.vstack([embeddings, pad_vector])

In [11]:
tokenizer = word_tokenize
cola_train_dataset = GLUEDataset(train_cola, tokenizer, glove_embeddings, elmo, max_length=50)
sst_train_dataset = GLUEDataset(train_sst, tokenizer, glove_embeddings, elmo, max_length=50)

In [12]:
mnli_train_dataset = GLUEDatasetTwoSentence(train_mnli, tokenizer, glove_embeddings, elmo, max_length=50)
msr_train_dataset = GLUEDatasetTwoSentence(train_msr, tokenizer, glove_embeddings, elmo, max_length=50)
qnli_train_dataset = GLUEDatasetTwoSentence(train_qnli, tokenizer, glove_embeddings, elmo, max_length=50)
qqp_train_dataset = GLUEDatasetTwoSentence(train_qqp, tokenizer, glove_embeddings, elmo, max_length=50)
rte_train_dataset = GLUEDatasetTwoSentence(train_rte, tokenizer, glove_embeddings, elmo, max_length=50)

In [13]:
batch_size = 24

cola_train_loader = DataLoader(cola_train_dataset, batch_size=20) # 8551, 268
sst_train_loader = DataLoader(sst_train_dataset, batch_size=20) # 67349, 2105
mnli_train_loader = DataLoader(mnli_train_dataset, batch_size=28) # 391171, 12225
msr_train_loader = DataLoader(msr_train_dataset, batch_size=12) # 3260, 102
qnli_train_loader = DataLoader(qnli_train_dataset, batch_size=28) # 103106, 3223
qqp_train_loader = DataLoader(qqp_train_dataset, batch_size=28) # 363846, 11371
rte_train_loader = DataLoader(rte_train_dataset, batch_size=12) # 78, 2489

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [15]:
embedding_dim = 300 + 1024  # GloVe (300) + ELMo (1024)
hidden_size = 1500
num_classes_cola = 2  # Binary classification
num_classes_sst2 = 2  # Binary classification

In [16]:
cola_iter = iter(cola_train_loader)
sst_iter = iter(sst_train_loader)
mnli_iter = iter(mnli_train_loader)
msr_iter = iter(msr_train_loader)
qnli_iter = iter(qnli_train_loader)
qqp_iter = iter(qqp_train_loader)
rte_iter = iter(rte_train_loader)

# Containers to hold all embeddings
cola_all = []
sst_all = []
mnli_all = []
msr_all = []
qnli_all = []
qqp_all = []
rte_all = []

for step in tqdm(range(10)):  # or more steps!
    cola_batch = next(cola_iter)
    cola_embeddings = cola_batch['embeddings1'].to(device)  # shape: [batch_size, 50, 1324]
    cola_labels = cola_batch['label'].to(device)

    sst_batch = next(sst_iter)
    sst_embeddings = sst_batch['embeddings1'].to(device)    # shape: [batch_size, 50, 1324]
    sst_labels = sst_batch['label'].to(device)
    
    mnli_batch = next(mnli_iter)
    mnli_embeddings = mnli_batch['embeddings1'].to(device)    # shape: [batch_size, 50, 1324]
    mnli_labels = mnli_batch['label'].to(device)
    
    msr_batch = next(msr_iter)
    msr_embeddings = msr_batch['embeddings1'].to(device)    # shape: [batch_size, 50, 1324]
    msr_labels = msr_batch['label'].to(device)
    
    qnli_batch = next(qnli_iter)
    qnli_embeddings = qnli_batch['embeddings1'].to(device)    # shape: [batch_size, 50, 1324]
    qnli_labels = qnli_batch['label'].to(device)
    
    qqp_batch = next(qqp_iter)
    qqp_embeddings = qqp_batch['embeddings1'].to(device)    # shape: [batch_size, 50, 1324]
    qqp_labels = qqp_batch['label'].to(device)
    
    rte_batch = next(rte_iter)
    rte_embeddings = rte_batch['embeddings1'].to(device)    # shape: [batch_size, 50, 1324]
    rte_labels = rte_batch['label'].to(device)

    # Flatten across time (sequence) dimension
    cola_all.append(cola_embeddings.view(-1, 1324))  # [batch_size * 50, 1324]
    sst_all.append(sst_embeddings.view(-1, 1324))    # same here
    mnli_all.append(mnli_embeddings.view(-1, 1324))  # [batch_size * 50, 1324]
    msr_all.append(msr_embeddings.view(-1, 1324))
    qnli_all.append(qnli_embeddings.view(-1, 1324))  # [batch_size * 50, 1324]
    qqp_all.append(qqp_embeddings.view(-1, 1324))
    rte_all.append(rte_embeddings.view(-1, 1324))  # [batch_size * 50, 1324]

# Combine into full matrices
cola_matrix = torch.cat(cola_all, dim=0)  # shape: [total_Cola_sentences * 50, 1324]
sst_matrix = torch.cat(sst_all, dim=0)    # shape: [total_SST_sentences * 50, 1324]
mnli_matrix = torch.cat(mnli_all, dim=0)  # shape: [total_Cola_sentences * 50, 1324]
msr_matrix = torch.cat(msr_all, dim=0)
qnli_matrix = torch.cat(qnli_all, dim=0)  # shape: [total_Cola_sentences * 50, 1324]
qqp_matrix = torch.cat(qqp_all, dim=0)
rte_matrix = torch.cat(rte_all, dim=0)  # shape: [total_Cola_sentences * 50, 1324]

100%|██████████| 10/10 [09:38<00:00, 57.89s/it]


In [21]:
non_zero_rows = torch.nonzero(cola_matrix.abs().sum(dim=1) > 0).squeeze()
cola_matrix = cola_matrix[non_zero_rows]
non_zero_rows = torch.nonzero(sst_matrix.abs().sum(dim=1) > 0).squeeze()
sst_matrix = sst_matrix[non_zero_rows]
non_zero_rows = torch.nonzero(mnli_matrix.abs().sum(dim=1) > 0).squeeze()
mnli_matrix = mnli_matrix[non_zero_rows]
non_zero_rows = torch.nonzero(msr_matrix.abs().sum(dim=1) > 0).squeeze()
msr_matrix = msr_matrix[non_zero_rows]
non_zero_rows = torch.nonzero(qnli_matrix.abs().sum(dim=1) > 0).squeeze()
qnli_matrix = qnli_matrix[non_zero_rows]
non_zero_rows = torch.nonzero(qqp_matrix.abs().sum(dim=1) > 0).squeeze()
qqp_matrix = qqp_matrix[non_zero_rows]
non_zero_rows = torch.nonzero(rte_matrix.abs().sum(dim=1) > 0).squeeze()
rte_matrix = rte_matrix[non_zero_rows]

In [22]:
min_size = min(cola_matrix.shape[0], sst_matrix.shape[0], mnli_matrix.shape[0], msr_matrix.shape[0], qnli_matrix.shape[0], qqp_matrix.shape[0], rte_matrix.shape[0])

In [24]:
sst_matrix = sst_matrix[:min_size]
cola_matrix = cola_matrix[:min_size]
mnli_matrix = mnli_matrix[:min_size]
msr_matrix = msr_matrix[:min_size]
qnli_matrix = qnli_matrix[:min_size]
qqp_matrix = qqp_matrix[:min_size]
rte_matrix = rte_matrix[:min_size]

In [None]:
def rbf_kernel(x, y, sigma):
    # x_norm = (x ** 2).sum(dim=1).view(-1, 1)  # Shape (m, 1)
    # y_norm = (y ** 2).sum(dim=1).view(1, -1)  # Shape (1, n)
    # dist = x_norm + y_norm - 2.0 * torch.mm(x, y.t())  # Shape (m, n)
    xx=torch.mm(x,x.t())
    yy=torch.mm(y,y.t())
    xy=torch.mm(x,y.t())
    return torch.exp(-(xx+yy-xy) / (2 * sigma ** 2))  # Apply Gaussian kernel

def compute_mmd2(X, Y, kernel_fn, sigma_x, sigma_y, sigma_xy):
    K_xx = kernel_fn(X, X, sigma_x)  # Shape (m, m)
    K_yy = kernel_fn(Y, Y, sigma_y)  # Shape (n, n)
    K_xy = kernel_fn(X, Y, sigma_xy)  # Shape (m, n)
    
    XX = K_xx.sum() - torch.diag(K_xx).sum()  # Sum all elements excluding diagonal
    YY = K_yy.sum() - torch.diag(K_yy).sum()  # Sum all elements excluding diagonal
    XY = K_xy[:10].sum()  # Sum of all off-diagonal elements
    print(K_xy.shape)
    
    print(XX,YY,XY)
    
    m = X.shape[0]

    mmd2 = XX.item()/(m*(m-1)) + YY.item()/(m*(m-1)) - 2 * XY.item()/(m*(m)) # Normalize by number of pairs

    return mmd2, K_xy

tasks = {
    'sst_matrix': sst_matrix,
    'cola_matrix': cola_matrix,
    'mnli_matrix': mnli_matrix,
    'msr_matrix': msr_matrix,
    'qnli_matrix': qnli_matrix,
    'qqp_matrix': qqp_matrix,
    'rte_matrix': rte_matrix,
}

combinations = list(itertools.combinations(tasks.items(), 2))

graph_weights = {}

for (name1, mat1), (name2, mat2) in combinations:
    
    x_ = mat1
    y_ = mat2

    distances = torch.cdist(x_, x_, p=2) ** 2
    median_distance = torch.median(distances[distances > 0])  # Exclude zero distances
    sigma_x = torch.sqrt(median_distance)

    distances = torch.cdist(y_, y_, p=2) ** 2
    median_distance = torch.median(distances[distances > 0])  # Exclude zero distances
    sigma_y = torch.sqrt(median_distance)

    distances = torch.cdist(x_, y_, p=2) ** 2
    median_distance = torch.median(distances[distances > 0])  # Exclude zero distances
    sigma_xy = torch.sqrt(median_distance)

    mmd_score, k_xy = compute_mmd2(x_, y_, rbf_kernel, sigma_x=sigma_x.item(), sigma_y=sigma_y.item(), sigma_xy=sigma_xy.item())
    graph_weights[f"Combining {name1} and {name2}"] = round(mmd_score, 2)


torch.Size([1728, 1728])
tensor(2811537., device='cuda:0') tensor(2827945.7500, device='cuda:0') tensor(16340.8848, device='cuda:0')
torch.Size([1728, 1728])
tensor(2811537., device='cuda:0') tensor(2868462.2500, device='cuda:0') tensor(16455.6699, device='cuda:0')
torch.Size([1728, 1728])
tensor(2811537., device='cuda:0') tensor(2889557.7500, device='cuda:0') tensor(16458.8809, device='cuda:0')
torch.Size([1728, 1728])
tensor(2811537., device='cuda:0') tensor(2830970.2500, device='cuda:0') tensor(16258.6875, device='cuda:0')
torch.Size([1728, 1728])
tensor(2811537., device='cuda:0') tensor(2804926.5000, device='cuda:0') tensor(16055.9043, device='cuda:0')
torch.Size([1728, 1728])
tensor(2811537., device='cuda:0') tensor(2888106.5000, device='cuda:0') tensor(16397.3828, device='cuda:0')
torch.Size([1728, 1728])
tensor(2827945.7500, device='cuda:0') tensor(2868462.2500, device='cuda:0') tensor(16446.8398, device='cuda:0')
torch.Size([1728, 1728])
tensor(2827945.7500, device='cuda:0') te

In [56]:
import networkx as nx
import matplotlib.pyplot as plt

G = nx.Graph()
for key, weight in graph_weights.items():
    parts = key.replace('Combining ', '').replace('_matrix', '').split(' and ')
    task1, task2 = parts
    G.add_edge(task1, task2, weight=weight)