In [25]:
from datasets import load_dataset
from dataloader import *
from transformer import *
import pandas as pd
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import ot
import torch
import optuna

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [78]:
def validation(model, iterator, criterion, device):
    # set model into evaluation mode
    model.eval()

    # validation
    # loss, metrics for current epoch
    val_epoch_loss = 0
    val_epoch_accuracy = 0

    with torch.no_grad(): # stop graph
        # batches
        for i, batch in enumerate(iterator):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            output = model(src)
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)

            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy

            val_epoch_loss += loss.item()
            val_epoch_accuracy += accuracy

    # return mean loss w.r.t. batches
    return val_epoch_loss / len(iterator), val_epoch_accuracy / len(iterator)

def plot_training(history, marker=None):
    # put everything on cpu
    for key, value in history.items():
        history[key] = [element.cpu() if isinstance(element, torch.Tensor) else element for element in value]

    plt.subplots_adjust(left=0.1,
                    bottom=0.01,
                    right=1.5,
                    top=0.6,
                    wspace=0.4,
                    hspace=0.4)

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'])
    plt.plot(history['val_loss'])
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.title('Training loss')

    # vertical line for marking best epoch
    if marker is not None:
        y_min = min(history['train_loss'] + history['val_loss'])
        y_max = max(history['train_loss'] + history['val_loss'])
        plt.vlines(x=marker, ymin=y_min, ymax=y_max, color='red')

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'])
    plt.plot(history['val_acc'])
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.title('Training metric')

    # vertical line for marking best epoch
    if marker is not None:
        y_min = min(history['train_acc'] + history['val_acc'])
        y_max = max(history['train_acc'] + history['val_acc'])
        plt.vlines(x=marker, ymin=y_min, ymax=y_max, color='red')

    plt.show()

def train_save_best(model, iterator, valid_iter, optimizer, criterion, epoch, clip, device):

    # set model into training mode
    model.train()

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)

    # save data - init
    history = {'train_loss': [],
               'val_loss': [],
               'train_acc': [],
               'val_acc': [],
               'learning_rate': []}
    best_model = None
    best_model_score = 1e9
    best_model_epoch = 0

    # training
    for e in range(epoch):
        # loss, metrics for current epoch
        epoch_loss = 0
        epoch_acc = 0

        # batches
        for i, batch in enumerate(tqdm(iterator)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = model(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            epoch_loss += loss.item()
            epoch_acc += accuracy / len(iterator)

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step() # optimize model

        # validation
        val_loss, val_acc = validation(model, valid_iter, optimizer, criterion, device)
        
        scheduler.step(val_loss)

        # save data
        with torch.no_grad():
            current_lr = optimizer.param_groups[0]['lr']

            for key, value in zip(history.keys(), [epoch_loss / len(iterator), val_loss, epoch_acc, val_acc, current_lr]):
                history[key].append(value)

            # save best model (w.r.t validation loss)
            if val_loss < best_model_score:
                best_model = model.state_dict()
                best_model_score = val_loss
                best_model_epoch = e

        # visualization
        print(f"Epoch: {e + 1}  Train Loss: {epoch_loss / len(iterator):.4f} \
              Validation Loss: {val_loss:.4f} \
              Train acc: {epoch_acc:.4f}, \
              Val acc: {val_acc:.4f}, \
              Learning Rate : {optimizer.param_groups[0]['lr'] :.4f}")

    # print training curve
    plot_training(history, marker=best_model_epoch)

    return history, best_model, best_model_score

def concat_dataloaders(dataloader1, dataloader2):
    concat_x = []
    concat_y = []

    for elem in dataloader1:
        for i in range(len(elem[0])):
            concat_x.append(elem[0][i].tolist())
            concat_y.append(elem[1][i].tolist())
    for elem in dataloader2:
        for i in range(len(elem[0])):
            concat_x.append(elem[0][i].tolist())
            concat_y.append(elem[1][i].tolist())
    
    x_tensor = torch.tensor(concat_x, device=device)
    y_tensor = torch.tensor(concat_y, device=device)
    dataset = torch.utils.data.TensorDataset(x_tensor, y_tensor)
    iter = torch.utils.data.DataLoader(dataset = dataset, batch_size = 512, shuffle = True)

    return iter  


In [33]:
# init
tokenizer = Tokenizer()
loader = DataLoader(tokenize = tokenizer.tokenize)

dataset = load_dataset("sentiment140")
data_train = pd.DataFrame({'text': dataset['train']['text'], 'sentiment' : dataset['train']['sentiment']})
data_test = pd.DataFrame({'text': dataset['test']['text'], 'sentiment' : dataset['test']['sentiment']})

concat = [data_train, data_test]
data = pd.concat(concat)
data = data.sample(n = 70000, random_state = 42)
# convert string label to binary (int) label (positive:1, negative:0)
data["sentiment"] = data['sentiment'].apply(lambda x : int(x == 4))
# train, test, val split
train_A, valid_A, test_A = loader.make_dataset(data)

dataset = load_dataset("sst2")

data_train = pd.DataFrame({'text': dataset['train']['sentence'], 'sentiment' : dataset['train']['label']})
data_test = pd.DataFrame({'text': dataset['test']['sentence'], 'sentiment' : dataset['test']['label']})
data_val = pd.DataFrame({'text': dataset['validation']['sentence'], 'sentiment' : dataset['validation']['label']})

concat = [data_train, data_test, data_val]
data = pd.concat(concat)
data = data.sample(n = len(data), random_state = 42)
# convert string label to binary (int) label (positive:1, negative:0)
data["sentiment"] = data['sentiment'].apply(lambda x : int(x == 1))
# train, test, val split
train_B, valid_B, test_B = loader.make_dataset(data)

vocab = loader.get_vocab(pd.concat([train_A, train_B], ignore_index=True).iloc[:, 0])

train_iter_A, valid_iter_A, test_iter_A = loader.make_iter(train_A, valid_A, test_A,
                                                     batch_size=512,
                                                     device=device,
                                                     vocab=vocab)

train_iter_B, valid_iter_B, test_iter_B = loader.make_iter(train_B, valid_B, test_B,
                                                     batch_size=512,
                                                     device=device,
                                                     vocab=vocab)

# NLP stuff
pad_idx = vocab['__PAD__']
voc_size = len(vocab)
print("Vocabulary Size : ", voc_size)

dataset initializing start


Found cached dataset sentiment140 (C:/Users/atace/.cache/huggingface/datasets/sentiment140/sentiment140/1.0.0/f81c014152931b776735658d8ae493b181927de002e706c4d5244ecb26376997)
100%|██████████| 2/2 [00:00<00:00, 21.18it/s]


Length of data after first step of preprocessing:  70000
Tokenizing the data...
Length of the data :  70000
1
<class 'pandas.core.frame.DataFrame'>


Found cached dataset sst2 (C:/Users/atace/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)
100%|██████████| 3/3 [00:00<00:00, 373.29it/s]


Length of data after first step of preprocessing:  70042
Tokenizing the data...
Length of the data :  70042
1
text         [[CLS], wicked, ##bit, ##ch, it, is, because, ...
sentiment                                                    1
len                                                         15
Name: 1102814, dtype: object


100%|██████████| 56000/56000 [00:04<00:00, 13282.07it/s]
100%|██████████| 7000/7000 [00:00<00:00, 14410.37it/s]
100%|██████████| 7000/7000 [00:00<00:00, 14561.11it/s]


dataset initializing done
text         [[CLS], elegant, ##ly, appointed, [SEP]]
sentiment                                           1
len                                                 5
Name: 60611, dtype: object


100%|██████████| 56033/56033 [00:03<00:00, 14925.91it/s]
100%|██████████| 7004/7004 [00:00<00:00, 14422.72it/s]
100%|██████████| 7005/7005 [00:00<00:00, 14366.04it/s]


dataset initializing done
Vocabulary Size :  20369


In [79]:
# Create a mixed validation and test set
train_iter = concat_dataloaders(train_iter_A, train_iter_B)
valid_iter = concat_dataloaders(valid_iter_A, valid_iter_B)
test_iter = concat_dataloaders(test_iter_A, test_iter_B)

In [97]:
# Creating the embedding matrix
embedding = torch.load("Models/embedding_16_trained_downstream.pt")

Parameter containing:
tensor([[ 0.1391,  0.6202,  0.2282,  ...,  1.7034,  0.4591, -2.2606],
        [-0.6741, -0.7514,  0.2021,  ..., -1.3650, -0.4353,  0.8555],
        [-0.0928,  0.1908,  1.9213,  ...,  0.4029,  0.7751, -1.9303],
        ...,
        [-0.7450,  0.9792,  0.7684,  ..., -1.5404,  0.2547, -0.2025],
        [ 0.7792,  1.7217,  2.0281,  ...,  0.7905,  0.2468, -0.4598],
        [-1.8177,  2.1274,  0.2408,  ..., -1.1197, -0.0551, -0.2766]],
       requires_grad=True)

## Load the Models

In [98]:
modelA = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 2,
                              n_layers = 1,
                              drop_prob = 0.7,
                              device = device)

modelA.load_state_dict(torch.load("./Models/modelA_sentiment140_256"))
modelA.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(20369, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.7, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.7, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

In [99]:
modelB = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 2,
                              n_layers = 1,
                              drop_prob = 0.7,
                              device = device)

modelB.load_state_dict(torch.load("./Models/modelB_sst2_256"))
modelB.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(20369, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.7, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.7, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

## OT Functions

In [69]:
def getSupport(model, trainloader, l, alignment = "acts", numOfBatches= 10):
    '''
    Get the support matrices using Activation-based ("acts") or Weight-based ("wts") alignment 
    '''
    if alignment == "acts":
        activation = None
        for i, data in enumerate(trainloader, 0):
            if i >= numOfBatches:
                break
            
            inputs, targets = data
            outputs = model(inputs)

            if activation is None:
                activation = model.actMatrix[l]
            else:
                activation = torch.cat((activation, model.actMatrix[l]))

        return activation
    elif alignment == "wts":
        return model.state_dict()[l]


In [70]:
def fusion(nameA, nameB, weightA, weightB, train_iter_sentiment140, train_iter_sst2, transport_matrix, beta):
    support_y = getSupport(modelB, train_iter_sst2, nameB, alignment="wts")
    # Get the weights at layer "idx" from the first model
    W_A = weightA
    W_B = weightB
    # Align the weights from the first model
    aligned_W = torch.matmul(W_A, torch.matmul(transport_matrix, torch.diag(1 / beta)))
    # Get the X-Support
    n = W_A.shape[0]
    alpha = torch.ones(n) * (1/n)
    support_x = getSupport(modelA, train_iter_sentiment140, nameA, alignment="wts")
    # Calculate the euclidean distance between the supports
    distance = ot.dist(support_x, support_y)
    # Calculate beta
    m = W_B.shape[0]
    beta = torch.ones(m) * (1/m)
    # Calculate the transport matrix using optimal transport
    transport_matrix = torch.from_numpy(ot.emd(alpha.numpy(), beta.numpy(), distance.detach().numpy())).float().reshape((n, m))
    # Align model neurons
    aligned_model = torch.matmul(torch.diag(1 / beta), torch.matmul(transport_matrix.T, aligned_W))
    # Get the weights at layer "idx" from the second model
    fused = (aligned_model + W_B) / 2 
    return  fused, transport_matrix, beta

In [71]:
def fusion_multihead(nameA, nameB, weightA, weightB, transport_matrix, beta, head):
    support_y = weightB
    support_x = weightA
    # Get the weights at layer "idx" from the first model
    W_A = weightA
    W_B = weightB
    # Initialize the fused model and transport matrix
    fused = torch.empty(W_B.shape)
    transport_matrix_new = torch.zeros((weightA.shape[0], weightB.shape[0]))
    stride = weightB.shape[0] // head
    for i in range(0, weightB.shape[0], stride):
        # Align the weights from the first model
        aligned_W = torch.matmul(W_A[i:i+stride, :], torch.matmul(transport_matrix, torch.diag(1 / beta)))
        # Get the X-Support
        n = W_A.shape[0] // head
        alpha = torch.ones(n) * (1/n)
        # Calculate the euclidean distance between the supports
        distance = ot.dist(support_x[i:i+stride, :], support_y[i:i+stride, :])
        # Calculate beta
        m = W_B.shape[0] // head
        beta_new = torch.ones(m) * (1/m)
        # Calculate the transport matrix using optimal transport
        transport_matrix_new[i:i+stride, i:i+stride] = torch.from_numpy(ot.emd(alpha.numpy(), beta_new.numpy(), distance.detach().numpy())).float().reshape((n, m))
        # Align model neurons
        aligned_model = torch.matmul(torch.diag(1 / beta_new), torch.matmul(transport_matrix_new[i:i+stride, i:i+stride].T, aligned_W))
        # Get the weights at layer "idx" from the second model
        fused[i:i+stride, :] = (aligned_model + W_B[i:i+stride, :]) / 2 
    return  fused, transport_matrix_new, beta_new

In [72]:
def fusion_crossmultihead(nameA, nameB, weightA, weightB, train_iter_sentiment140, train_iter_sst2, transport_matrix, beta, head):
    W_A_head = weightA.view(head, -1)
    W_B_head = weightB.view(head, -1)
    
    m = W_B_head.shape[1]
    beta_head = torch.ones(m) * (1/m)
    transport_matrix_head = torch.matmul(torch.diag(beta_head), torch.eye(m))

    support_y = getSupport(modelB, train_iter_sst2, nameB, alignment="wts")
    support_x = getSupport(modelA, train_iter_sentiment140, nameA, alignment="wts")

    aligned_W = torch.matmul(W_A_head, torch.matmul(transport_matrix_head, torch.diag(1 / beta_head)))

    dist_head = ot.dist(support_x.view(head, -1), support_y.view(head, -1))

    n = W_A_head.shape[0]
    alpha_head = torch.ones(n) * (1/n)

    m = W_B_head.shape[0]
    beta_head = torch.ones(m) * (1/m)

    transport_matrix_new = torch.from_numpy(ot.emd(alpha_head.numpy(), beta_head.numpy(), dist_head.detach().numpy())).float().reshape((n, m))

    aligned_W_A = torch.matmul(torch.diag(1 / beta_head), torch.matmul(transport_matrix_new.T, aligned_W))
    aligned_W_A = aligned_W_A.view(weightA.shape)
    return fusion_multihead(nameA, nameB, aligned_W_A, weightB, transport_matrix, beta, head)

## Fusion via Optimal Transport

In [84]:
fusedModel = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding = embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 2,
                              n_layers = 1,
                              drop_prob = 0.2,
                              device = device)

## Method 1

In [85]:
def obj(trial):
    a = trial.suggest_float('a', 0, 1)
    
    # Create the fused weights matrix
    W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
    # Initialize the algorithm
    m = list(modelB.state_dict().items())[1][1].shape[1]
    beta = torch.ones(m) * (1/m)
    transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))

    # Fusion via Optimal Transport
    for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
        if nameA == "encoder.emb.tok_emb.embedding.weight":
            W_fusion[nameA] = weightA
        else:
            if "weight" in nameA:
                if "encoder" in nameA:
                    if "concat" not in nameA and "linear" not in nameA: 
                        W_fusion[nameA], transport_matrix_triplet, _ = fusion(nameA, nameB, weightA, weightB, train_iter_A, train_iter_B, transport_matrix, beta)
                    else:
                        W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, train_iter_A, train_iter_B, transport_matrix, beta)

                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            elif "bias" in nameA:
                if "encoder" in nameA: 
                    if "concat" not in nameA and "linear" not in nameA: 
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                    else:
                        m = weightB.shape[0]
                        beta_bias = torch.ones(m) * (1/m)
                        W_A_bias = weightA.reshape(m, 1)
                        aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                        aligned_bias = aligned_bias.reshape(m)
                        W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB

    # Assign the weights
    with torch.no_grad():
        for name, param in fusedModel.named_parameters():
            param.data = torch.nn.Parameter(W_fusion[name])

    # Validate the fused model
    criterion = torch.nn.CrossEntropyLoss()
    val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
    return val_loss

In [86]:
study = optuna.create_study()
study.optimize(obj, n_trials=5)

[32m[I 2023-01-08 15:06:58,055][0m A new study created in memory with name: no-name-a235a0ef-53e2-4db8-a015-859fef0bb7d7[0m
  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
[32m[I 2023-01-08 15:07:43,010][0m Trial 0 finished with value: 0.6837837249040604 and parameters: {'a': 0.6404820603956419}. Best is trial 0 with value: 0.6837837249040604.[0m
[32m[I 2023-01-08 15:08:20,363][0m Trial 1 finished with value: 0.6818252695458276 and parameters: {'a': 0.32792368419063855}. Best is trial 1 with value: 0.6818252695458276.[0m
[32m[I 2023-01-08 15:08:57,277][0m Trial 2 finished with value: 0.6835779973438808 and parameters: {'a': 0.024126054616033232}. Best is trial 1 with value: 0.6818252695458276.[0m
[32m[I 2023-01-08 15:09:34,913][0m Trial 3 finished with value: 0.685010667358126 and parameters: {'a': 0.9389929405376469}. Best is trial 1 with value: 0.6818252695458276.[0m
[32m[I 2023-01-08 15:10:18,087][0m Trial 4 finished with val

In [87]:
a = study.best_params["a"]

# Create the fused weights matrix
W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
# Initialize the algorithm
m = list(modelB.state_dict().items())[1][1].shape[1]
beta = torch.ones(m) * (1/m)
transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))
# Fusion via Optimal Transport
for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
    if nameA == "encoder.emb.tok_emb.embedding.weight":
        W_fusion[nameA] = weightA
    else:
        if "weight" in nameA:
            if "encoder" in nameA:
                if "concat" not in nameA and "linear" not in nameA: 
                    W_fusion[nameA], transport_matrix_triplet, _ = fusion(nameA, nameB, weightA, weightB, train_iter_A, train_iter_B, transport_matrix, beta)
                else:
                    W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, train_iter_A, train_iter_B, transport_matrix, beta)
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        elif "bias" in nameA:
            if "encoder" in nameA: 
                if "concat" not in nameA and "linear" not in nameA: 
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix_triplet.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        else:
            W_fusion[nameA] = a * weightA + (1-a) * weightB
# Assign the weights
with torch.no_grad():
    for name, param in fusedModel.named_parameters():
        param.data = torch.nn.Parameter(W_fusion[name])
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Validation Loss:      0.6827030032873154
Validation Accuracy:      0.6888470362103174


## Recalibrate the Normalization layers

In [88]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "weight" in name or "bias" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 219/219 [06:28<00:00,  1.77s/it]


In [89]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Validation Loss:      0.5135767949478967
Validation Accuracy:      0.7650127108134921


In [90]:
for name, param in fusedModel.named_parameters():
    param.requires_grad = True

## Train the last Layer

In [91]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "encoder" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 219/219 [05:10<00:00,  1.42s/it]


In [92]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Validation Loss:      0.5072341816765922
Validation Accuracy:      0.7600818452380952


## Test Set

In [100]:
# Test the models on the mixed test set
criterion = torch.nn.CrossEntropyLoss()

test_loss_fused, test_acc_fused = validation(fusedModel, test_iter, criterion, device)
test_loss_A, test_acc_A = validation(modelA, test_iter, criterion, device)
test_loss_B, test_acc_B = validation(modelB, test_iter, criterion, device)

print("Test Loss Fused Model:     ", test_loss_fused)
print("Test Accuracy Fused Model:     ", test_acc_fused.item())
print("Test Loss Model A:     ", test_loss_A)
print("Test Accuracy Model A:     ", test_acc_A.item())
print("Test Loss Model B:     ", test_loss_B)
print("Test Accuracy Model B:     ", test_acc_B.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Test Loss Fused Model:      0.4986152691500528
Test Accuracy Fused Model:      0.7626509933652328
Test Loss Model A:      0.6701862599168505
Test Accuracy Model A:      0.6961649689226519
Test Loss Model B:      0.6710367692368371
Test Accuracy Model B:      0.6956547207971587


In [101]:
# Test the models on the sentiment140 test set
criterion = torch.nn.CrossEntropyLoss()

test_loss_fused, test_acc_fused = validation(fusedModel, test_iter_A, criterion, device)
test_loss_A, test_acc_A = validation(modelA, test_iter_A, criterion, device)
test_loss_B, test_acc_B = validation(modelB, test_iter_A, criterion, device)

print("Test Loss Fused Model:     ", test_loss_fused)
print("Test Accuracy Fused Model:     ", test_acc_fused.item())
print("Test Loss Model A:     ", test_loss_A)
print("Test Accuracy Model A:     ", test_acc_A.item())
print("Test Loss Model B:     ", test_loss_B)
print("Test Accuracy Model B:     ", test_acc_B.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Test Loss Fused Model:      0.577849805355072
Test Accuracy Fused Model:      0.7081018999169435
Test Loss Model A:      0.7167302199772426
Test Accuracy Model A:      0.6839019674003322
Test Loss Model B:      0.7163661931242261
Test Accuracy Model B:      0.6835613060631228


In [102]:
# Test the models on the Stanford Treebank test set
criterion = torch.nn.CrossEntropyLoss()

test_loss_fused, test_acc_fused = validation(fusedModel, test_iter_B, criterion, device)
test_loss_A, test_acc_A = validation(modelA, test_iter_B, criterion, device)
test_loss_B, test_acc_B = validation(modelB, test_iter_B, criterion, device)

print("Test Loss Fused Model:     ", test_loss_fused)
print("Test Accuracy Fused Model:     ", test_acc_fused.item())
print("Test Loss Model A:     ", test_loss_A)
print("Test Accuracy Model A:     ", test_acc_A.item())
print("Test Loss Model B:     ", test_loss_B)
print("Test Accuracy Model B:     ", test_acc_B.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Test Loss Fused Model:      0.41766093032700674
Test Accuracy Fused Model:      0.8173796145875972
Test Loss Model A:      0.6251829138823918
Test Accuracy Model A:      0.7070072656569791
Test Loss Model B:      0.6267327538558415
Test Accuracy Model B:      0.7072678955178059
