## Imports

In [1]:
from transformer import TransformerClassifier
from dataloader import *
from CV.utils import * 
import torch
import ot
from copy import deepcopy
from tqdm import tqdm
import optuna

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load the Data

In [3]:
#init
tokenizer = Tokenizer()
loader = DataLoader(tokenize = tokenizer.tokenize)

# import data (combine train/test as we split afterwards anyways)
data = pd.read_csv("./Data/IMDB Dataset.csv", encoding='ISO-8859-1')
# convert string label to binary (int) label (spam:1, non-spam:0)
data["sentiment"] = data['sentiment'].apply(lambda x : int(x == "positive"))

# train, test, val split
train, valid, test = loader.make_dataset(data)
vocab = loader.get_vocab(train.iloc[:, 0])
train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,
                                                     batch_size=512,
                                                     device=device)

# NLP stuff
pad_idx = vocab['__PAD__']
voc_size = len(vocab)
print("Vocabulary Size : ", voc_size)

dataset initializing start
Length of data after first step of preprocessing:  35832
Tokenizing the data...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data["len"] = data.iloc[:, 0].apply(lambda x : len(self.tokenize(x)))


Length of the data :  29544
0
review       [[CLS], one, of, the, many, silent, comedies, ...
sentiment                                                    0
len                                                        186
Name: 46539, dtype: object


100%|██████████| 23635/23635 [00:01<00:00, 14257.29it/s]
100%|██████████| 2954/2954 [00:00<00:00, 17048.35it/s]
100%|██████████| 2955/2955 [00:00<00:00, 16056.68it/s]


dataset initializing done
Vocabulary Size :  23050


In [4]:
# Load the embedding matrix
embedding = torch.load("Models/embedding_16_trained_multilayer.pt")

In [5]:
embedding.weight

Parameter containing:
tensor([[-0.3602, -2.3004,  1.2823,  ...,  1.3947, -0.0689, -0.4801],
        [-0.1015,  0.2015, -1.9556,  ..., -0.1140,  0.2189,  0.8465],
        [ 1.2555,  0.2540,  1.1998,  ..., -0.5179, -0.0536,  1.1762],
        ...,
        [-0.2308,  0.4539,  0.3592,  ..., -0.7208, -0.6409, -0.3728],
        [ 1.0316,  1.4529,  0.5553,  ..., -0.8593,  0.1438, -1.7995],
        [-0.0379, -1.8211, -0.7899,  ...,  0.9734,  0.4521,  0.8370]])

## Load the Model Weights

In [6]:
modelA = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.5,
                              device = device)

modelA.load_state_dict(torch.load("Models\modelA_IMDB_256_multilayer"))
modelA.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(23050, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.5, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.5, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

In [7]:
modelB = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.5,
                              device = device)

modelB.load_state_dict(torch.load("Models\modelB_IMDB_256_multilayer"))
modelB.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(23050, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.5, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.5, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

## OT Functions 

In [8]:
def getSupport(model, trainloader, l, alignment = "acts", numOfBatches= 10):
    '''
    Get the support matrices using Activation-based ("acts") or Weight-based ("wts") alignment 
    '''
    if alignment == "acts":
        activation = None
        for i, data in enumerate(trainloader, 0):
            if i >= numOfBatches:
                break
            
            inputs, targets = data
            outputs = model(inputs)

            if activation is None:
                activation = model.actMatrix[l]
            else:
                activation = torch.cat((activation, model.actMatrix[l]))

        return activation
    elif alignment == "wts":
        return model.state_dict()[l]


In [33]:
def fusion(nameA, nameB, weightA, weightB, transport_matrix, beta):
    support_y = getSupport(modelB, train_iter, nameB, alignment="wts")
    # Get the weights at layer "idx" from the first model
    W_A = weightA
    W_B = weightB
    # Align the weights from the first model
    aligned_W = torch.matmul(W_A, torch.matmul(transport_matrix, torch.diag(1 / beta)))
    # Get the X-Support
    n = W_A.shape[0]
    alpha = torch.ones(n) * (1/n)
    support_x = getSupport(modelA, train_iter, nameA, alignment="wts")
    # Calculate the euclidean distance between the supports
    distance = ot.dist(support_x, support_y)
    # Calculate beta
    m = W_B.shape[0]
    beta = torch.ones(m) * (1/m)
    # Calculate the transport matrix using optimal transport
    transport_matrix = torch.from_numpy(ot.emd(alpha.numpy(), beta.numpy(), distance.detach().numpy())).float().reshape((n, m))
    # Align model neurons
    aligned_model = torch.matmul(torch.diag(1 / beta), torch.matmul(transport_matrix.T, aligned_W))
    # Get the weights at layer "idx" from the second model
    fused = (aligned_model + W_B) / 2
    return  fused, transport_matrix, beta

In [26]:
def fusion_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, layers):
    layer = int(nameA.partition("layers.")[2][0])
    
    W_A = torch.empty((weightA.shape[0] * layers, weightA.shape[1]))
    W_B = torch.empty((weightB.shape[0] * layers, weightB.shape[1]))

    for l in range(layers):
        W_A[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelA.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]
        W_B[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelB.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]

    support_y = W_B
    support_x = W_A

    # Initialize the fused model and transport matrix
    fused = torch.empty(W_B.shape)
    transport_matrix_new = torch.zeros((W_A.shape[0], W_B.shape[0]))
    # Align the weights from the first model
    aligned_W = torch.matmul(W_A, torch.matmul(transport_matrix, torch.diag(1 / beta)))
    # Get the X-Support
    n = W_A.shape[0]
    alpha = torch.ones(n) * (1/n)
    # Calculate the euclidean distance between the supports
    distance = ot.dist(support_x, support_y)
    # Calculate beta
    m = W_B.shape[0]
    beta = torch.ones(m) * (1/m)
    # Calculate the transport matrix using optimal transport
    transport_matrix = torch.from_numpy(ot.emd(alpha.numpy(), beta.numpy(), distance.detach().numpy())).float().reshape((n, m))
    # Align model neurons
    aligned_model = torch.matmul(torch.diag(1 / beta), torch.matmul(transport_matrix.T, aligned_W))
    # Get the weights at layer "idx" from the second model
    fused = (aligned_model + W_B) / 2
    dim = layer * weightB.shape[0]
    dim_to = layer * weightB.shape[0] + weightB.shape[0]
    return  fused[dim:dim_to, :], transport_matrix_new[dim:dim_to, dim:dim_to], beta

In [27]:
def fusion_multihead_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, heads, layers):
    layer = int(nameA.partition("layers.")[2][0])
    
    W_A = torch.empty((weightA.shape[0] * layers, weightA.shape[1]))
    W_B = torch.empty((weightB.shape[0] * layers, weightB.shape[1]))
    
    stride = weightB.shape[0] // heads
    for l in range(layers):
        W_A[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelA.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]
        W_B[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelB.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]

    support_y = W_B
    support_x = W_A

    # Initialize the fused model and transport matrix
    fused = torch.empty(W_B.shape)
    transport_matrix_new = torch.zeros((W_A.shape[0], W_B.shape[0]))
    stride = W_B.shape[0] // (heads * layers)
    for i in range(0, W_B.shape[0], stride):
        # Align the weights from the first model
        aligned_W = torch.matmul(W_A[i:i+stride, :], torch.matmul(transport_matrix, torch.diag(1 / beta)))
        # Get the X-Support
        n = W_A.shape[0] // (heads * layers)
        alpha = torch.ones(n) * (1/n)
        # Calculate the euclidean distance between the supports
        distance = ot.dist(support_x[i:i+stride, :], support_y[i:i+stride, :])
        # Calculate beta
        m = W_B.shape[0] // (heads * layers)
        beta_new = torch.ones(m) * (1/m)
        # Calculate the transport matrix using optimal transport
        transport_matrix_new[i:i+stride, i:i+stride] = torch.from_numpy(ot.emd(alpha.numpy(), beta_new.numpy(), distance.detach().numpy())).float().reshape((n, m))
        # Align model neurons
        aligned_model = torch.matmul(torch.diag(1 / beta_new), torch.matmul(transport_matrix_new[i:i+stride, i:i+stride].T, aligned_W))
        # Get the weights at layer "idx" from the second model
        fused[i:i+stride, :] = (aligned_model + W_B[i:i+stride, :]) / 2 
    dim = layer * weightB.shape[0]
    dim_to = layer * weightB.shape[0] + weightB.shape[0]
    return  fused[dim:dim_to, :], transport_matrix_new[dim:dim_to, dim:dim_to], beta_new

# Fusion via Optimal Transport

In [28]:
scores = {'loss': {'A_test_set': [], 'B_test_set': [], 'OT_pre': [], 'OT_calibr': [], 'OT_last_layer': [], 'OT_test_set' : []},
          'accuracy': {'A_test_set': [], 'B_test_set': [], 'OT_pre': [], 'OT_calibr': [], 'OT_last_layer': [], 'OT_test_set' : []},
          'f1': {'A_test_set': [], 'B_test_set': [], 'OT_pre': [], 'OT_calibr': [], 'OT_last_layer': [], 'OT_test_set' : []}}

## Method 1

In [45]:
fusedModel = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding = embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.2,
                              device = device)

In [46]:
def obj(trial):
    a = trial.suggest_float('a', 0, 1)
    
    # Create the fused weights matrix
    W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
    # Initialize the algorithm
    m = list(modelB.state_dict().items())[1][1].shape[1]
    beta = torch.ones(m) * (1/m)
    transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))

    # Fusion via Optimal Transport
    for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
        if nameA == "encoder.emb.tok_emb.embedding.weight":
            W_fusion[nameA] = weightA
        else:   
            if "weight" in nameA:
                if "encoder" in nameA:
                    if "concat" not in nameA and "linear" not in nameA: 
                        transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                        W_fusion[nameA], transport_matrix_triplet, _ = fusion_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, 2)
                    else:
                        W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)

                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            elif "bias" in nameA:
                if "encoder" in nameA: 
                    if "concat" not in nameA and "linear" not in nameA: 
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                    else:
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB

    # Assign the weights
    with torch.no_grad():
        for name, param in fusedModel.named_parameters():
            param.data = torch.nn.Parameter(W_fusion[name])

    # Validate the fused model
    criterion = torch.nn.CrossEntropyLoss()
    val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, criterion, device)
    return val_loss

In [47]:
study = optuna.create_study()
study.optimize(obj, n_trials=20)

[32m[I 2023-01-14 11:53:57,171][0m A new study created in memory with name: no-name-ce8cac9a-d997-4937-b0fb-8e09021aa32f[0m
[32m[I 2023-01-14 11:54:04,032][0m Trial 0 finished with value: 0.8183954656124115 and parameters: {'a': 0.723256023267953}. Best is trial 0 with value: 0.8183954656124115.[0m
[32m[I 2023-01-14 11:54:12,721][0m Trial 1 finished with value: 0.4955558280150096 and parameters: {'a': 0.03416997858031012}. Best is trial 1 with value: 0.4955558280150096.[0m
[32m[I 2023-01-14 11:54:22,662][0m Trial 2 finished with value: 0.5235801140467325 and parameters: {'a': 0.20763937634628304}. Best is trial 1 with value: 0.4955558280150096.[0m
[32m[I 2023-01-14 11:54:31,604][0m Trial 3 finished with value: 0.6849202513694763 and parameters: {'a': 0.550675425060867}. Best is trial 1 with value: 0.4955558280150096.[0m
[32m[I 2023-01-14 11:54:40,478][0m Trial 4 finished with value: 0.7962923844655355 and parameters: {'a': 0.6934538825629025}. Best is trial 1 with valu

In [48]:
a = study.best_params["a"]

# Create the fused weights matrix
W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
# Initialize the algorithm
m = list(modelB.state_dict().items())[1][1].shape[1]
beta = torch.ones(m) * (1/m)
transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))
# Fusion via Optimal Transport
for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
    if nameA == "encoder.emb.tok_emb.embedding.weight":
        W_fusion[nameA] = weightA
    else:   
        if "weight" in nameA:
            if "encoder" in nameA:
                if "concat" not in nameA and "linear" not in nameA: 
                    transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                    W_fusion[nameA], transport_matrix_triplet, _ = fusion_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, 2)
                else:
                    W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        elif "bias" in nameA:
            if "encoder" in nameA: 
                if "concat" not in nameA and "linear" not in nameA: 
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        else:
            W_fusion[nameA] = a * weightA + (1-a) * weightB
# Assign the weights
with torch.no_grad():
    for name, param in fusedModel.named_parameters():
        param.data = torch.nn.Parameter(W_fusion[name])
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, criterion, device)
scores['loss']['OT_pre'], scores['accuracy']['OT_pre'], scores['f1']['OT_pre'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())
print("Validation F+:     ", val_f1.item())

Validation Loss:      0.49432618419329327
Validation Accuracy:      0.7745908682318104
Validation F+:      0.7742044925689697


## Method 2

In [34]:
fusedModel = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding = embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.2,
                              device = device)

In [35]:
def obj(trial):
    a = trial.suggest_float('a', 0, 1)
    
    # Create the fused weights matrix
    W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
    # Initialize the algorithm
    m = list(modelB.state_dict().items())[1][1].shape[1]
    beta = torch.ones(m) * (1/m)
    transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))

    # Fusion via Optimal Transport
    for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
        if nameA == "encoder.emb.tok_emb.embedding.weight":
            W_fusion[nameA] = weightA
        else:   
            if "weight" in nameA:
                if "encoder" in nameA:
                    if "concat" not in nameA and "linear" not in nameA: 
                        transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                        W_fusion[nameA], transport_matrix_triplet, _ = fusion_multihead_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, fusedModel.n_head, 2)
                    else:
                        W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)

                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            elif "bias" in nameA:
                if "encoder" in nameA: 
                    if "concat" not in nameA and "linear" not in nameA: 
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                    else:
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB

    # Assign the weights
    with torch.no_grad():
        for name, param in fusedModel.named_parameters():
            param.data = torch.nn.Parameter(W_fusion[name])

    # Validate the fused model
    criterion = torch.nn.CrossEntropyLoss()
    val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, criterion, device)
    return val_loss

In [37]:
study = optuna.create_study()
study.optimize(obj, n_trials=20)

[32m[I 2023-01-14 11:44:22,037][0m A new study created in memory with name: no-name-6b09a687-c6f9-4b9b-83ef-48f0d66a1684[0m
  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device)  # put to cpu/gpu
[32m[I 2023-01-14 11:44:31,048][0m Trial 0 finished with value: 0.7893851002057394 and parameters: {'a': 0.4174552341685299}. Best is trial 0 with value: 0.7893851002057394.[0m
[32m[I 2023-01-14 11:44:40,141][0m Trial 1 finished with value: 0.6804643074671427 and parameters: {'a': 0.2509478447327288}. Best is trial 1 with value: 0.6804643074671427.[0m
[32m[I 2023-01-14 11:44:48,585][0m Trial 2 finished with value: 0.9434540271759033 and parameters: {'a': 0.6104170066096141}. Best is trial 1 with value: 0.6804643074671427.[0m
[32m[I 2023-01-14 11:44:57,550][0m Trial 3 finished with value: 1.1019093990325928 and parameters: {'a': 0.7998780503285836}. Best is trial 1 with value: 0.6804643074671427.[0m
[32m[I 2023-01-14 11:45:06,376][0m Trial 4 finished with valu

In [39]:
a = study.best_params["a"]

# Create the fused weights matrix
W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
# Initialize the algorithm
m = list(modelB.state_dict().items())[1][1].shape[1]
beta = torch.ones(m) * (1/m)
transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))
# Fusion via Optimal Transport
for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
    if nameA == "encoder.emb.tok_emb.embedding.weight":
        W_fusion[nameA] = weightA
    else:   
        if "weight" in nameA:
            if "encoder" in nameA:
                if "concat" not in nameA and "linear" not in nameA: 
                    transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                    W_fusion[nameA], transport_matrix_triplet, _ = fusion_multihead_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, fusedModel.n_head, 2)
                else:
                    W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        elif "bias" in nameA:
            if "encoder" in nameA: 
                if "concat" not in nameA and "linear" not in nameA: 
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        else:
            W_fusion[nameA] = a * weightA + (1-a) * weightB
# Assign the weights
with torch.no_grad():
    for name, param in fusedModel.named_parameters():
        param.data = torch.nn.Parameter(W_fusion[name])
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, criterion, device)
scores['loss']['OT_pre'], scores['accuracy']['OT_pre'], scores['f1']['OT_pre'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())
print("Validation F1:     ", val_f1.item())

Validation Loss:      0.5696245928605398
Validation Accuracy:      0.7417479642554992
Validation Accuracy:      0.7413676381111145


## Recalibrate the normalization layers

In [49]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "weight" in name or "bias" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 47/47 [01:44<00:00,  2.22s/it]


In [50]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, criterion, device)
scores['loss']['OT_pre'], scores['accuracy']['OT_pre'], scores['f1']['OT_pre'] = val_loss, val_acc, val_f1
scores['loss']['OT_calibr'], scores['accuracy']['OT_calibr'], scores['f1']['OT_calibr'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())
print("Validation F1:     ", val_f1.item())

Validation Loss:      0.5758781035741171
Validation Accuracy:      0.7533989662648054
Validation F1:      0.7518618702888489


In [51]:
for name, param in fusedModel.named_parameters():
    param.requires_grad = True

## Train the last layer

In [52]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "encoder" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 47/47 [01:11<00:00,  1.52s/it]


In [53]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, criterion, device)
scores['loss']['OT_last_layer'], scores['accuracy']['OT_last_layer'], scores['f1']['OT_last_layer'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())
print("Validation F1:     ", val_f1.item())

Validation Loss:      0.46658990780512494
Validation Accuracy:      0.8097768612521151
Validation F1:      0.8090723752975464


## Test Set

In [54]:
# Test the models
criterion = torch.nn.CrossEntropyLoss()

test_loss_fused, test_acc_fused, test_f1_fused  = validation(fusedModel, test_iter, criterion, device)
scores['loss']['OT_test_set'], scores['accuracy']['OT_test_set'], scores['f1']['OT_test_set'] = test_loss_fused, test_acc_fused, test_f1_fused
test_loss_A, test_acc_A, test_f1_A  = validation(modelA, test_iter, criterion, device)
scores['loss']['A_test_set'], scores['accuracy']['A_test_set'], scores['f1']['A_test_set'] = test_loss_A, test_acc_A, test_f1_A
test_loss_B, test_acc_B, test_f1_B  = validation(modelB, test_iter, criterion, device)
scores['loss']['B_test_set'], scores['accuracy']['B_test_set'], scores['f1']['B_test_set'] = test_loss_A, test_acc_A, test_f1_A

print("Test Loss Fused Model:     ", test_loss_fused)
print("Test Accuracy Fused Model:     ", test_acc_fused.item())
print("Test F1 Fused Model:     ", test_f1_fused.item())
print("Test Loss Model A:     ", test_loss_A)
print("Test Accuracy Model A:     ", test_acc_A.item())
print("Test F1 Model A:     ", test_f1_A.item())
print("Test Loss Model B:     ", test_loss_B)
print("Test Accuracy Model B:     ", test_acc_B.item())
print("Test F1 Model B:     ", test_f1_B.item())

Test Loss Fused Model:      0.43884028991063434
Test Accuracy Fused Model:      0.8128881526898734
Test F1 Fused Model:      0.8131979703903198
Test Loss Model A:      0.4548064172267914
Test Accuracy Model A:      0.8403274986814346
Test F1 Model A:      0.838917076587677
Test Loss Model B:      0.41694880028565723
Test Accuracy Model B:      0.8556393393987342
Test F1 Model B:      0.8558376431465149


In [22]:
df = pd.DataFrame.from_dict(scores)

In [23]:
df.head(10)

Unnamed: 0,loss,accuracy,f1
A_test_set,0.473577,"tensor(0.8281, dtype=torch.float64)",tensor(0.8281)
B_test_set,0.473577,"tensor(0.8281, dtype=torch.float64)",tensor(0.8281)
OT_pre,0.485952,"tensor(0.7773, dtype=torch.float64)",tensor(0.7776)
OT_calibr,0.485952,"tensor(0.7773, dtype=torch.float64)",tensor(0.7776)
OT_last_layer,0.470886,"tensor(0.7895, dtype=torch.float64)",tensor(0.7891)
OT_test_set,0.47969,"tensor(0.7863, dtype=torch.float64)",tensor(0.7868)


In [24]:
# rename rows, cols
df.columns = ['Loss', 'Accuracy', 'F1 score']
df.index = ['Model A on test set', 'Model B on test set', 'Optimal transport + weighted fusion', 'OT + weighted fusion (recalibrated)',
            'OT + weighted fusion (last layer retrained)', 'OT + weighted fusion (last layer retrained) on test set']

In [26]:
# convert to latex
latex = df.to_latex(index=True,
                    bold_rows=True,
                    caption='Model performance (5-fold CV)',
                    position='H').replace('BOLD\\', r'\textbf').replace('\}', '}')

  latex = df.to_latex(index=True,


In [27]:
print(latex)

\begin{table}[H]
\centering
\caption{Model performance (5-fold CV)}
\begin{tabular}{lrll}
\toprule
{} &      Loss &                             Accuracy &        F1 score \\
\midrule
\textbf{Model A on test set                               } &  0.473577 &  tensor(0.8281, dtype=torch.float64) &  tensor(0.8281) \\
\textbf{Model B on test set                               } &  0.473577 &  tensor(0.8281, dtype=torch.float64) &  tensor(0.8281) \\
\textbf{Optimal transport + weighted fusion               } &  0.485952 &  tensor(0.7773, dtype=torch.float64) &  tensor(0.7776) \\
\textbf{OT + weighted fusion (recalibrated)               } &  0.485952 &  tensor(0.7773, dtype=torch.float64) &  tensor(0.7776) \\
\textbf{OT + weighted fusion (last layer retrained)       } &  0.470886 &  tensor(0.7895, dtype=torch.float64) &  tensor(0.7891) \\
\textbf{OT + weighted fusion (last layer retrained) on ...} &  0.479690 &  tensor(0.7863, dtype=torch.float64) &  tensor(0.7868) \\
\bottomrule
\end{tabular}