## Imports

In [1]:
from transformer import TransformerClassifier
from dataloader import *
import torch
import ot
from copy import deepcopy

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load the Data

In [3]:
# init
tokenizer = Tokenizer()
loader = DataLoader(tokenize = tokenizer.tokenize)

# import data (combine train/test as we split afterwards anyways)
data = pd.concat([pd.read_csv("./Data/SMS_train.csv", encoding='ISO-8859-1'),
                  pd.read_csv("./Data/SMS_test.csv", encoding='ISO-8859-1')])

# convert string label to binary (int) label (spam:1, non-spam:0)
labels = pd.Series((data['Label'] == 'Spam').astype(int))
data['Label'] = labels

# train, test, val split
train, valid, test = loader.make_dataset(data[['Message_body', 'Label']])
vocab = loader.get_vocab(train.iloc[:, 0])
train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,
                                                     batch_size=128,
                                                     device=device)

# NLP stuff
pad_idx = vocab['__PAD__']
voc_size = len(vocab)
print("Vocabulary Size : ", voc_size)

dataset initializing start
1
Message_body    [[CLS], free, ##ms, ##g, hey, there, darling, ...
Label                                                           1
len                                                           147
Name: 64, dtype: object


100%|██████████| 856/856 [00:00<00:00, 11003.14it/s]
100%|██████████| 107/107 [00:00<00:00, 11860.84it/s]
100%|██████████| 108/108 [00:00<00:00, 10153.88it/s]

dataset initializing done
Vocabulary Size :  2707





## Load the Model Weights

In [4]:
modelA = TransformerClassifier(src_pad_idx = pad_idx,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 512,
                              ffn_hidden = 2048,
                              n_head = 1,
                              n_layers = 1,
                              drop_prob = 0.1,
                              device = device)

modelA.load_state_dict(torch.load("Models\modelA"))
modelA.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(2707, 512, padding_idx=1)
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (w_concat): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu)

In [5]:
modelB = TransformerClassifier(src_pad_idx = pad_idx,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 512,
                              ffn_hidden = 2048,
                              n_head = 1,
                              n_layers = 1,
                              drop_prob = 0.1,
                              device = device)

modelB.load_state_dict(torch.load("Models\modelB"))
modelB.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(2707, 512, padding_idx=1)
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (w_concat): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu)

## OT Functions 

In [49]:
def getSupport(model, trainloader, l, alignment = "acts", numOfBatches= 10):
    '''
    Get the support matrices using Activation-based ("acts") or Weight-based ("wts") alignment 
    '''
    if alignment == "acts":
        activation = None
        for i, data in enumerate(trainloader, 0):
            if i >= numOfBatches:
                break
            
            inputs, targets = data
            outputs = model(inputs)

            if activation is None:
                activation = model.actMatrix[l]
            else:
                activation = torch.cat((activation, model.actMatrix[l]))

        return activation
    elif alignment == "wts":
        return model.state_dict()[l]


In [50]:
def fusion(nameA, nameB, weightA, weightB, transport_matrix, beta):
    support_y = getSupport(modelB, train_iter, nameB, alignment="wts")
    # Get the weights at layer "idx" from the first model
    W_A = weightA
    W_B = weightB
    # Align the weights from the first model
    print(W_A.shape, transport_matrix.shape, torch.diag(1 / beta).shape)
    aligned_W = torch.matmul(W_A, torch.matmul(transport_matrix, torch.diag(1 / beta)))
    # Get the X-Support
    n = W_A.shape[0]
    alpha = torch.ones(n) * (1/n)
    support_x = getSupport(modelA, train_iter, nameA, alignment="wts")
    # Calculate the euclidean distance between the supports
    distance = ot.dist(support_x, support_y)
    # Calculate beta
    m = W_B.shape[0]
    beta = torch.ones(m) * (1/m)
    # Calculate the transport matrix using optimal transport
    print(alpha.shape, beta.shape, distance.shape)
    transport_matrix = torch.from_numpy(ot.emd(alpha.numpy(), beta.numpy(), distance.detach().numpy())).float().reshape((n, m))
    # Align model neurons
    print(torch.diag(1 / beta).shape, transport_matrix.T.shape, aligned_W.shape)
    aligned_model = torch.matmul(torch.diag(1 / beta), torch.matmul(transport_matrix.T, aligned_W))
    # Get the weights at layer "idx" from the second model
    fused = (aligned_model + W_B).T / 2 if "ffn" in nameA else (aligned_model + W_B) / 2 
    return  fused, transport_matrix, beta

# Fusion via Optimal Transport

In [51]:
fusedModel = TransformerClassifier(src_pad_idx = pad_idx,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 512,
                              ffn_hidden = 2048,
                              n_head = 1,
                              n_layers = 1,
                              drop_prob = 0.1,
                              device = device)

In [68]:
# Create the fused weights matrix
W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
# Initialize the algorithm
m = list(modelB.state_dict().items())[1][1].shape[1]
beta = torch.ones(m) * (1/m)
transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))


for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
    if nameA == "encoder.emb.tok_emb.weight":
        W_fusion[nameA] = weightA
    else:
        if "weight" in nameA:
            if "encoder" in nameA: 
                W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)
            else:
                W_fusion[nameA] = weightA
        elif "bias" in nameA:
            if "encoder" in nameA: 
                m = weightB.shape[0]
                beta_bias = torch.ones(m) * (1/m)
                W_A_bias = weightA.reshape(m, 1)
                aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                W_fusion[nameA] = (aligned_bias + weightB) / 2
            else:
                W_fusion[nameA] = weightA
        else:
            continue

torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([512]) torch.Size([512]) torch.Size([512, 512])
torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([512]) torch.Size([512]) torch.Size([512, 512])
torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([512]) torch.Size([512]) torch.Size([512, 512])
torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([512]) torch.Size([512]) torch.Size([512, 512])
torch.Size([512, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([2048, 512]) torch.Size([512, 512]) torch.Size([512, 512])
torch.Size([2048]) torch.Size([2048]) torch.Size([2048, 2048])
torch.Size([2048, 2048]) torch.Size([2048, 2048]) torch.Size([2048, 512])
torch

## Validation of the fused model

In [70]:
criterion = torch.nn.CrossEntropyLoss()

In [69]:
def validation(model, iterator, criterion, device):
    # set model into evaluation mode
    model.eval()

    # validation
    # loss, metrics for current epoch
    val_epoch_loss = 0
    val_epoch_accuracy = 0

    with torch.no_grad(): # stop graph
        # batches
        for i, batch in enumerate(iterator):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            output = model(src)
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)

            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy

            val_epoch_loss += loss.item()
            val_epoch_accuracy += accuracy

    # return mean loss w.r.t. batches
    return val_epoch_loss / len(iterator), val_epoch_accuracy / len(iterator)

In [73]:
val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Validation Loss:      0.5577811002731323
Validation Accuracy:      0.822429906542056
