## Imports

In [3]:
from transformer import TransformerClassifier
from dataloader import *
from utils import *
import torch
import ot
from copy import deepcopy
from tqdm import tqdm
import optuna

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load the Data

In [5]:
#init
tokenizer = Tokenizer()
loader = DataLoader(tokenize = tokenizer.tokenize)

# import data (combine train/test as we split afterwards anyways)
data = pd.read_csv("./Data/IMDB Dataset.csv", encoding='ISO-8859-1')
# convert string label to binary (int) label (spam:1, non-spam:0)
data["sentiment"] = data['sentiment'].apply(lambda x : int(x == "positive"))

# train, test, val split
train, valid, test = loader.make_dataset(data)
vocab = loader.get_vocab(train.iloc[:, 0])
train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,
                                                     batch_size=512,
                                                     device=device)

# NLP stuff
pad_idx = vocab['__PAD__']
voc_size = len(vocab)
print("Vocabulary Size : ", voc_size)

dataset initializing start
Tokenizing the data...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data["len"] = data.iloc[:, 0].apply(lambda x : len(self.tokenize(x)))


Length of the data :  29544
0
review       [[CLS], one, of, the, many, silent, comedies, ...
sentiment                                                    0
len                                                        186
Name: 46539, dtype: object


100%|██████████| 23635/23635 [00:03<00:00, 7592.40it/s]
100%|██████████| 2954/2954 [00:00<00:00, 7348.25it/s]
100%|██████████| 2955/2955 [00:00<00:00, 4323.46it/s]


dataset initializing done
Vocabulary Size :  23050


In [141]:
# Load the embedding matrix
embedding = torch.load("Models/embedding_16.pt")

## Load the Model Weights

In [143]:
modelA = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.5,
                              device = device)

modelA.load_state_dict(torch.load("Models\modelA_IMDB_256_multilayer"))
modelA.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(23050, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.5, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.5, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

In [144]:
modelB = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.5,
                              device = device)

modelB.load_state_dict(torch.load("Models\modelB_IMDB_256_multilayer"))
modelB.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(23050, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.5, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.5, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

## OT Functions 

In [129]:
def getSupport(model, trainloader, l, alignment = "acts", numOfBatches= 10):
    '''
    Get the support matrices using Activation-based ("acts") or Weight-based ("wts") alignment 
    '''
    if alignment == "acts":
        activation = None
        for i, data in enumerate(trainloader, 0):
            if i >= numOfBatches:
                break
            
            inputs, targets = data
            outputs = model(inputs)

            if activation is None:
                activation = model.actMatrix[l]
            else:
                activation = torch.cat((activation, model.actMatrix[l]))

        return activation
    elif alignment == "wts":
        return model.state_dict()[l]


In [130]:
def fusion(nameA, nameB, weightA, weightB, transport_matrix, beta):
    support_y = getSupport(modelB, train_iter, nameB, alignment="wts")
    # Get the weights at layer "idx" from the first model
    W_A = weightA
    W_B = weightB
    # Align the weights from the first model
    aligned_W = torch.matmul(W_A, torch.matmul(transport_matrix, torch.diag(1 / beta)))
    # Get the X-Support
    n = W_A.shape[0]
    alpha = torch.ones(n) * (1/n)
    support_x = getSupport(modelA, train_iter, nameA, alignment="wts")
    # Calculate the euclidean distance between the supports
    distance = ot.dist(support_x, support_y)
    # Calculate beta
    m = W_B.shape[0]
    beta = torch.ones(m) * (1/m)
    # Calculate the transport matrix using optimal transport
    transport_matrix = torch.from_numpy(ot.emd(alpha.numpy(), beta.numpy(), distance.detach().numpy())).float().reshape((n, m))
    # Align model neurons
    aligned_model = torch.matmul(torch.diag(1 / beta), torch.matmul(transport_matrix.T, aligned_W))
    # Get the weights at layer "idx" from the second model
    fused = (aligned_model + W_B) / 2 
    return  fused, transport_matrix, beta

In [131]:
def fusion_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, layers):
    layer = int(nameA.partition("layers.")[2][0])
    
    W_A = torch.empty((weightA.shape[0] * layers, weightA.shape[1]))
    W_B = torch.empty((weightB.shape[0] * layers, weightB.shape[1]))

    for l in range(layers):
        W_A[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelA.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]
        W_B[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelB.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]

    support_y = W_B
    support_x = W_A

    # Initialize the fused model and transport matrix
    fused = torch.empty(W_B.shape)
    transport_matrix_new = torch.zeros((W_A.shape[0], W_B.shape[0]))
    # Align the weights from the first model
    aligned_W = torch.matmul(W_A, torch.matmul(transport_matrix, torch.diag(1 / beta)))
    # Get the X-Support
    n = W_A.shape[0]
    alpha = torch.ones(n) * (1/n)
    # Calculate the euclidean distance between the supports
    distance = ot.dist(support_x, support_y)
    # Calculate beta
    m = W_B.shape[0]
    beta = torch.ones(m) * (1/m)
    # Calculate the transport matrix using optimal transport
    transport_matrix = torch.from_numpy(ot.emd(alpha.numpy(), beta.numpy(), distance.detach().numpy())).float().reshape((n, m))
    # Align model neurons
    aligned_model = torch.matmul(torch.diag(1 / beta), torch.matmul(transport_matrix.T, aligned_W))
    # Get the weights at layer "idx" from the second model
    fused = (aligned_model + W_B) / 2
    dim = layer * weightB.shape[0]
    dim_to = layer * weightB.shape[0] + weightB.shape[0]
    return  fused[dim:dim_to, :], transport_matrix_new[dim:dim_to, dim:dim_to], beta

In [132]:
def fusion_multihead_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, heads, layers):
    layer = int(nameA.partition("layers.")[2][0])
    
    W_A = torch.empty((weightA.shape[0] * layers, weightA.shape[1]))
    W_B = torch.empty((weightB.shape[0] * layers, weightB.shape[1]))
    
    stride = weightB.shape[0] // heads
    for l in range(layers):
        W_A[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelA.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]
        W_B[l * weightB.shape[0] : l * weightB.shape[0] + weightB.shape[0], :] = modelB.state_dict()[nameA.replace(f"layers.{layer}", f"layers.{l}")]

    support_y = W_B
    support_x = W_A

    # Initialize the fused model and transport matrix
    fused = torch.empty(W_B.shape)
    transport_matrix_new = torch.zeros((W_A.shape[0], W_B.shape[0]))
    stride = W_B.shape[0] // (heads * layers)
    for i in range(0, W_B.shape[0], stride):
        # Align the weights from the first model
        aligned_W = torch.matmul(W_A[i:i+stride, :], torch.matmul(transport_matrix, torch.diag(1 / beta)))
        # Get the X-Support
        n = W_A.shape[0] // (heads * layers)
        alpha = torch.ones(n) * (1/n)
        # Calculate the euclidean distance between the supports
        distance = ot.dist(support_x[i:i+stride, :], support_y[i:i+stride, :])
        # Calculate beta
        m = W_B.shape[0] // (heads * layers)
        beta_new = torch.ones(m) * (1/m)
        # Calculate the transport matrix using optimal transport
        transport_matrix_new[i:i+stride, i:i+stride] = torch.from_numpy(ot.emd(alpha.numpy(), beta_new.numpy(), distance.detach().numpy())).float().reshape((n, m))
        # Align model neurons
        aligned_model = torch.matmul(torch.diag(1 / beta_new), torch.matmul(transport_matrix_new[i:i+stride, i:i+stride].T, aligned_W))
        # Get the weights at layer "idx" from the second model
        fused[i:i+stride, :] = (aligned_model + W_B[i:i+stride, :]) / 2 
    dim = layer * weightB.shape[0]
    dim_to = layer * weightB.shape[0] + weightB.shape[0]
    return  fused[dim:dim_to, :], transport_matrix_new[dim:dim_to, dim:dim_to], beta_new

# Fusion via Optimal Transport

In [154]:
scores = {'loss': {'A_test_set': [], 'B_test_set': [], 'OT_pre': [], 'OT_calibr': [], 'OT_last_layer': [], 'OT_test_set' : []},
          'accuracy': {'A_test_set': [], 'B_test_set': [], 'OT_pre': [], 'OT_calibr': [], 'OT_last_layer': [], 'OT_test_set' : []},
          'f1': {'A_test_set': [], 'B_test_set': [], 'OT_pre': [], 'OT_calibr': [], 'OT_last_layer': [], 'OT_test_set' : []}}

## Method 1

In [155]:
fusedModel = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding = embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.2,
                              device = device)

In [156]:
def obj(trial):
    a = trial.suggest_float('a', 0, 1)
    
    # Create the fused weights matrix
    W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
    # Initialize the algorithm
    m = list(modelB.state_dict().items())[1][1].shape[1]
    beta = torch.ones(m) * (1/m)
    transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))

    # Fusion via Optimal Transport
    for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
        if nameA == "encoder.emb.tok_emb.embedding.weight":
            W_fusion[nameA] = weightA
        else:   
            if "weight" in nameA:
                if "encoder" in nameA:
                    if "concat" not in nameA and "linear" not in nameA: 
                        transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                        W_fusion[nameA], transport_matrix_triplet, _ = fusion_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, 2)
                    else:
                        W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)

                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            elif "bias" in nameA:
                if "encoder" in nameA: 
                    if "concat" not in nameA and "linear" not in nameA: 
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                    else:
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB

    # Assign the weights
    with torch.no_grad():
        for name, param in fusedModel.named_parameters():
            param.data = torch.nn.Parameter(W_fusion[name])

    # Validate the fused model
    criterion = torch.nn.CrossEntropyLoss()
    val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, None, criterion, device)
    return val_loss

In [157]:
study = optuna.create_study()
study.optimize(obj, n_trials=20)

[32m[I 2022-12-30 16:56:08,884][0m A new study created in memory with name: no-name-e1ebd106-0f98-485f-a676-417bc2ebc8ae[0m
  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device)  # put to cpu/gpu
[32m[I 2022-12-30 16:56:20,988][0m Trial 0 finished with value: 0.8179209729035696 and parameters: {'a': 0.8970449330060377}. Best is trial 0 with value: 0.8179209729035696.[0m
[32m[I 2022-12-30 16:56:32,516][0m Trial 1 finished with value: 2.98838218053182 and parameters: {'a': 0.003952219253880895}. Best is trial 0 with value: 0.8179209729035696.[0m
[32m[I 2022-12-30 16:56:43,397][0m Trial 2 finished with value: 0.9196842014789581 and parameters: {'a': 0.8180080947523416}. Best is trial 0 with value: 0.8179209729035696.[0m
[32m[I 2022-12-30 16:56:53,306][0m Trial 3 finished with value: 1.1408438086509705 and parameters: {'a': 0.6931171086467981}. Best is trial 0 with value: 0.8179209729035696.[0m
[32m[I 2022-12-30 16:57:02,673][0m Trial 4 finished with valu

In [158]:
a = study.best_params["a"]

# Create the fused weights matrix
W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
# Initialize the algorithm
m = list(modelB.state_dict().items())[1][1].shape[1]
beta = torch.ones(m) * (1/m)
transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))
# Fusion via Optimal Transport
for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
    if nameA == "encoder.emb.tok_emb.embedding.weight":
        W_fusion[nameA] = weightA
    else:   
        if "weight" in nameA:
            if "encoder" in nameA:
                if "concat" not in nameA and "linear" not in nameA: 
                    transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                    W_fusion[nameA], transport_matrix_triplet, _ = fusion_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, 2)
                else:
                    W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        elif "bias" in nameA:
            if "encoder" in nameA: 
                if "concat" not in nameA and "linear" not in nameA: 
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        else:
            W_fusion[nameA] = a * weightA + (1-a) * weightB
# Assign the weights
with torch.no_grad():
    for name, param in fusedModel.named_parameters():
        param.data = torch.nn.Parameter(W_fusion[name])
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, None, criterion, device)
scores['loss']['OT_pre'], scores['accuracy']['OT_pre'], scores['f1']['OT_pre'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

Validation Loss:      0.7510955631732941
Validation Accuracy:      0.6117759226945855


## Method 2

In [None]:
fusedModel = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding = embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 2,
                              drop_prob = 0.2,
                              device = device)

In [86]:
def obj(trial):
    a = trial.suggest_float('a', 0, 1)
    
    # Create the fused weights matrix
    W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
    # Initialize the algorithm
    m = list(modelB.state_dict().items())[1][1].shape[1]
    beta = torch.ones(m) * (1/m)
    transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))

    # Fusion via Optimal Transport
    for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
        if nameA == "encoder.emb.tok_emb.embedding.weight":
            W_fusion[nameA] = weightA
        else:   
            if "weight" in nameA:
                if "encoder" in nameA:
                    if "concat" not in nameA and "linear" not in nameA: 
                        transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                        W_fusion[nameA], transport_matrix_triplet, _ = fusion_multihead_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, fusedModel.n_head, 2)
                    else:
                        W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)

                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            elif "bias" in nameA:
                if "encoder" in nameA: 
                    if "concat" not in nameA and "linear" not in nameA: 
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                    else:
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB

    # Assign the weights
    with torch.no_grad():
        for name, param in fusedModel.named_parameters():
            param.data = torch.nn.Parameter(W_fusion[name])

    # Validate the fused model
    criterion = torch.nn.CrossEntropyLoss()
    val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, None, criterion, device)
    return val_loss

In [87]:
study = optuna.create_study()
study.optimize(obj, n_trials=20)

[32m[I 2022-12-30 15:19:39,165][0m A new study created in memory with name: no-name-4acba308-fb73-42c5-8496-dca3631aa836[0m
[32m[I 2022-12-30 15:19:58,384][0m Trial 0 finished with value: 0.854739228884379 and parameters: {'a': 0.9832036567452003}. Best is trial 0 with value: 0.854739228884379.[0m
[32m[I 2022-12-30 15:20:16,796][0m Trial 1 finished with value: 1.0881722768147786 and parameters: {'a': 0.1758490581392247}. Best is trial 0 with value: 0.854739228884379.[0m
[32m[I 2022-12-30 15:20:35,644][0m Trial 2 finished with value: 0.8684331278006235 and parameters: {'a': 0.9970797215861974}. Best is trial 0 with value: 0.854739228884379.[0m
[32m[I 2022-12-30 15:20:53,443][0m Trial 3 finished with value: 0.8378826479117075 and parameters: {'a': 0.9632188911445807}. Best is trial 3 with value: 0.8378826479117075.[0m
[32m[I 2022-12-30 15:21:10,874][0m Trial 4 finished with value: 0.7351060807704926 and parameters: {'a': 0.6722490075535558}. Best is trial 4 with value: 0

In [91]:
a = study.best_params["a"]

# Create the fused weights matrix
W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
# Initialize the algorithm
m = list(modelB.state_dict().items())[1][1].shape[1]
beta = torch.ones(m) * (1/m)
transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))
# Fusion via Optimal Transport
for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
    if nameA == "encoder.emb.tok_emb.embedding.weight":
        W_fusion[nameA] = weightA
    else:   
        if "weight" in nameA:
            if "encoder" in nameA:
                if "concat" not in nameA and "linear" not in nameA: 
                    transport_matrix_triplet = torch.empty((weightA.shape[0], weightB.shape[0]))
                    W_fusion[nameA], transport_matrix_triplet, _ = fusion_multihead_multilayer(nameA, nameB, weightA, weightB, transport_matrix, beta, fusedModel.n_head, 2)
                else:
                    W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        elif "bias" in nameA:
            if "encoder" in nameA: 
                if "concat" not in nameA and "linear" not in nameA: 
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        else:
            W_fusion[nameA] = a * weightA + (1-a) * weightB
# Assign the weights
with torch.no_grad():
    for name, param in fusedModel.named_parameters():
        param.data = torch.nn.Parameter(W_fusion[name])
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, None, criterion, device)
scores['loss']['OT_pre'], scores['accuracy']['OT_pre'], scores['f1']['OT_pre'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device)  # put to cpu/gpu


Validation Loss:      0.7322940131028494
Validation Accuracy:      0.6673408417935702


## Recalibrate the normalization layers

In [159]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "weight" in name or "bias" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 47/47 [02:06<00:00,  2.70s/it]


In [160]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, None, criterion, device)
scores['loss']['OT_pre'], scores['accuracy']['OT_pre'], scores['f1']['OT_pre'] = val_loss, val_acc, val_f1
scores['loss']['OT_calibr'], scores['accuracy']['OT_calibr'], scores['f1']['OT_calibr'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

Validation Loss:      0.5179742028315862
Validation Accuracy:      0.7538269352791879


In [161]:
for name, param in fusedModel.named_parameters():
    param.requires_grad = True

## Train the last layer

In [162]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "encoder" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 47/47 [01:23<00:00,  1.78s/it]


In [163]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc, val_f1 = validation(fusedModel, valid_iter, None, criterion, device)
scores['loss']['OT_last_layer'], scores['accuracy']['OT_last_layer'], scores['f1']['OT_last_layer'] = val_loss, val_acc, val_f1
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

Validation Loss:      0.4185267935196559
Validation Accuracy:      0.8062291798857868


## Test Set

In [164]:
# Test the models
criterion = torch.nn.CrossEntropyLoss()

test_loss_fused, test_acc_fused, test_f1_fused  = validation(fusedModel, test_iter, None, criterion, device)
scores['loss']['OT_test_set'], scores['accuracy']['OT_test_set'], scores['f1']['OT_test_set'] = test_loss_fused, test_acc_fused, test_f1_fused
test_loss_A, test_acc_A, test_f1_A  = validation(modelA, test_iter, None, criterion, device)
scores['loss']['A'], scores['accuracy']['A'], scores['f1']['A'] = test_loss_A, test_acc_A, test_f1_A
test_loss_B, test_acc_B, test_f1_B  = validation(modelB, test_iter, None, criterion, device)
scores['loss']['B'], scores['accuracy']['B'], scores['f1']['B'] = test_loss_A, test_acc_A, test_f1_A

print("Validation Loss Fused Model:     ", test_loss_fused)
print("Validation Accuracy Fused Model:     ", test_acc_fused.item())
print("Validation Loss Model A:     ", test_loss_A)
print("Validation Accuracy Model A:     ", test_acc_A.item())
print("Validation Loss Model B:     ", test_loss_B)
print("Validation Accuracy Model B:     ", test_acc_B.item())

Validation Loss Fused Model:      0.4156232972939809
Validation Accuracy Fused Model:      0.8086926424050632
Validation Loss Model A:      0.5493571062882742
Validation Accuracy Model A:      0.7044064807489452
Validation Loss Model B:      0.5096772213776907
Validation Accuracy Model B:      0.8205556104957806


In [165]:
df = pd.DataFrame.from_dict(scores)

In [169]:
df.head(10)

Unnamed: 0,Loss,Accuracy,F1 score
A_test_set,[],[],[]
B_test_set,[],[],[]
OT_pre,0.517974,"tensor(0.7538, dtype=torch.float64)",tensor(0.7552)
OT_calibr,0.517974,"tensor(0.7538, dtype=torch.float64)",tensor(0.7552)
OT_last_layer,0.418527,"tensor(0.8062, dtype=torch.float64)",tensor(0.8057)
OT_test_set,0.415623,"tensor(0.8087, dtype=torch.float64)",tensor(0.8091)
A,0.549357,"tensor(0.7044, dtype=torch.float64)",tensor(0.7042)
B,0.549357,"tensor(0.7044, dtype=torch.float64)",tensor(0.7042)


In [166]:
# rename rows, cols
df.columns = ['Loss', 'Accuracy', 'F1 score']
df.index = ['Model A on test set', 'Model B on test set', 'Optimal transport + weighted fusion', 'OT + weighted fusion (recalibrated)',
            'OT + weighted fusion (last layer retrained)', 'OT + weighted fusion (last layer retrained) on test set']

ValueError: Length mismatch: Expected axis has 8 elements, new values have 6 elements

In [None]:
# boldify highest score
for col in (0, 1, 2):
    if col == 0:
        index_max = np.argmin([float(entry.split('±')[0]) for entry in df.iloc[:, col]])
    else:
        index_max = np.argmax([float(entry.split('±')[0]) for entry in df.iloc[:, col]])
    entry = df.iloc[index_max, col]
    entry = 'BOLD{' + entry + '}'
    df.iloc[index_max, col] = entry

In [None]:
# convert to latex
latex = df.to_latex(index=True,
                    bold_rows=True,
                    caption='Model performance (5-fold CV)',
                    position='H').replace('BOLD\\', r'\textbf').replace('\}', '}')

In [None]:
print(latex)