## Imports

In [4]:
from transformer import TransformerClassifier
from dataloader import *
import torch
import ot
from copy import deepcopy
from tqdm import tqdm
import optuna

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load the Data

In [6]:
#init
tokenizer = Tokenizer()
loader = DataLoader(tokenize = tokenizer.tokenize)

# import data (combine train/test as we split afterwards anyways)
data = pd.read_csv("./Data/IMDB Dataset.csv", encoding='ISO-8859-1')
# convert string label to binary (int) label (spam:1, non-spam:0)
data["sentiment"] = data['sentiment'].apply(lambda x : int(x == "positive"))

# train, test, val split
train, valid, test = loader.make_dataset(data)
vocab = loader.get_vocab(train.iloc[:, 0])
train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,
                                                     batch_size=512,
                                                     device=device)

# NLP stuff
pad_idx = vocab['__PAD__']
voc_size = len(vocab)
print("Vocabulary Size : ", voc_size)

dataset initializing start
Tokenizing the data...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data["len"] = data.iloc[:, 0].apply(lambda x : len(self.tokenize(x)))


Length of the data :  29544
0
review       [[CLS], one, of, the, many, silent, comedies, ...
sentiment                                                    0
len                                                        186
Name: 46539, dtype: object


100%|██████████| 23635/23635 [00:01<00:00, 14783.00it/s]
100%|██████████| 2954/2954 [00:00<00:00, 13424.14it/s]
100%|██████████| 2955/2955 [00:00<00:00, 14698.97it/s]


dataset initializing done
Vocabulary Size :  23050


In [7]:
def validation(model, iterator, criterion, device):
    # set model into evaluation mode
    model.eval()

    # validation
    # loss, metrics for current epoch
    val_epoch_loss = 0
    val_epoch_accuracy = 0

    with torch.no_grad(): # stop graph
        # batches
        for i, batch in enumerate(iterator):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            output = model(src)
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)

            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy

            val_epoch_loss += loss.item()
            val_epoch_accuracy += accuracy

    # return mean loss w.r.t. batches
    return val_epoch_loss / len(iterator), val_epoch_accuracy / len(iterator)

In [8]:
# Load the embedding matrix
embedding = torch.load("Models/embedding_16.pt")

## Load the Model Weights

In [9]:
modelA = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 1,
                              drop_prob = 0.5,
                              device = device)

modelA.load_state_dict(torch.load("Models\modelA_IMDB_256"))
modelA.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(23050, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.5, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.5, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

In [10]:
modelB = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding=embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 1,
                              drop_prob = 0.5,
                              device = device)

modelB.load_state_dict(torch.load("Models\modelB_IMBD_256"))
modelB.eval()

TransformerClassifier(
  (encoder): Encoder(
    (emb): TransformerEmbedding(
      (tok_emb): TokenEmbedding(
        (embedding): Embedding(23050, 16)
      )
      (pos_emb): PositionalEncoding()
      (drop_out): Dropout(p=0.5, inplace=False)
    )
    (layers): ModuleList(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (attention): ScaleDotProductAttention(
            (softmax): Softmax(dim=-1)
          )
          (w_q): Linear(in_features=16, out_features=16, bias=True)
          (w_k): Linear(in_features=16, out_features=16, bias=True)
          (w_v): Linear(in_features=16, out_features=16, bias=True)
          (w_concat): Linear(in_features=16, out_features=16, bias=True)
        )
        (norm1): LayerNorm()
        (dropout1): Dropout(p=0.5, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=16, out_features=32, bias=True)
          (linear2): Linear(in_features=32, out_features=16, bias=True)
     

## OT Functions 

In [11]:
def getSupport(model, trainloader, l, alignment = "acts", numOfBatches= 10):
    '''
    Get the support matrices using Activation-based ("acts") or Weight-based ("wts") alignment 
    '''
    if alignment == "acts":
        activation = None
        for i, data in enumerate(trainloader, 0):
            if i >= numOfBatches:
                break
            
            inputs, targets = data
            outputs = model(inputs)

            if activation is None:
                activation = model.actMatrix[l]
            else:
                activation = torch.cat((activation, model.actMatrix[l]))

        return activation
    elif alignment == "wts":
        return model.state_dict()[l]


In [12]:
def fusion(nameA, nameB, weightA, weightB, transport_matrix, beta):
    support_y = getSupport(modelB, train_iter, nameB, alignment="wts")
    # Get the weights at layer "idx" from the first model
    W_A = weightA
    W_B = weightB
    # Align the weights from the first model
    aligned_W = torch.matmul(W_A, torch.matmul(transport_matrix, torch.diag(1 / beta)))
    # Get the X-Support
    n = W_A.shape[0]
    alpha = torch.ones(n) * (1/n)
    support_x = getSupport(modelA, train_iter, nameA, alignment="wts")
    # Calculate the euclidean distance between the supports
    distance = ot.dist(support_x, support_y)
    # Calculate beta
    m = W_B.shape[0]
    beta = torch.ones(m) * (1/m)
    # Calculate the transport matrix using optimal transport
    transport_matrix = torch.from_numpy(ot.emd(alpha.numpy(), beta.numpy(), distance.detach().numpy())).float().reshape((n, m))
    # Align model neurons
    aligned_model = torch.matmul(torch.diag(1 / beta), torch.matmul(transport_matrix.T, aligned_W))
    # Get the weights at layer "idx" from the second model
    fused = (aligned_model + W_B) / 2 
    return  fused, transport_matrix, beta

# Fusion via Optimal Transport

In [46]:
fusedModel = TransformerClassifier(src_pad_idx = pad_idx,
                              embedding = embedding,
                              enc_voc_size = voc_size,
                              max_len = 256,
                              d_model = 16,
                              ffn_hidden = 32,
                              n_head = 1,
                              n_layers = 1,
                              drop_prob = 0.2,
                              device = device)

In [47]:
def obj(trial):
    a = trial.suggest_float('a', 0, 1)
    
    # Create the fused weights matrix
    W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
    # Initialize the algorithm
    m = list(modelB.state_dict().items())[1][1].shape[1]
    beta = torch.ones(m) * (1/m)
    transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))

    # Fusion via Optimal Transport
    for (nameA, weightA), (nameB, weightB) in zip(modelA.named_parameters(), modelB.named_parameters()):
        if nameA == "encoder.emb.tok_emb.embedding.weight":
            W_fusion[nameA] = weightA
        else:
            if "weight" in nameA:
                if "encoder" in nameA: 
                    W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)
                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            elif "bias" in nameA:
                if "encoder" in nameA: 
                    m = weightB.shape[0]
                    beta_bias = torch.ones(m) * (1/m)
                    W_A_bias = weightA.reshape(m, 1)
                    aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                    aligned_bias = aligned_bias.reshape(m)
                    W_fusion[nameA] = (aligned_bias + weightB) / 2
                else:
                    W_fusion[nameA] = a * weightA + (1-a) * weightB
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB

    # Assign the weights
    with torch.no_grad():
        for name, param in fusedModel.named_parameters():
            param.data = torch.nn.Parameter(W_fusion[name])

    # Validate the fused model
    criterion = torch.nn.CrossEntropyLoss()
    val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
    return val_loss

In [48]:
study = optuna.create_study()
study.optimize(obj, n_trials=20)

[32m[I 2022-12-19 15:01:58,086][0m A new study created in memory with name: no-name-97ceef6f-f916-459f-b95a-0fed8ac49ec9[0m
  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
[32m[I 2022-12-19 15:02:01,966][0m Trial 0 finished with value: 1.1502549449602764 and parameters: {'a': 0.8042947919075609}. Best is trial 0 with value: 1.1502549449602764.[0m
[32m[I 2022-12-19 15:02:05,630][0m Trial 1 finished with value: 0.9582643111546835 and parameters: {'a': 0.5495655676815925}. Best is trial 1 with value: 0.9582643111546835.[0m
[32m[I 2022-12-19 15:02:09,230][0m Trial 2 finished with value: 0.9467704097429911 and parameters: {'a': 0.5288957040324305}. Best is trial 2 with value: 0.9467704097429911.[0m
[32m[I 2022-12-19 15:02:12,736][0m Trial 3 finished with value: 1.2615215579668682 and parameters: {'a': 0.9192471866080213}. Best is trial 2 with value: 0.9467704097429911.[0m
[32m[I 2022-12-19 15:02:16,214][0m Trial 4 finished with value

In [49]:
a = study.best_params["a"]

# Create the fused weights matrix
W_fusion = dict.fromkeys(list(modelA.state_dict().keys()))
# Initialize the algorithm
m = list(modelB.state_dict().items())[1][1].shape[1]
beta = torch.ones(m) * (1/m)
transport_matrix = torch.matmul(torch.diag(beta), torch.eye(m))
# Fusion via Optimal Transport
for (nameA, weightA), (nameB, weightB) in tqdm(zip(modelA.named_parameters(), modelB.named_parameters())):
    if nameA == "encoder.emb.tok_emb.embedding.weight":
        W_fusion[nameA] = weightA
    else:
        if "weight" in nameA:
            if "encoder" in nameA: 
                W_fusion[nameA], transport_matrix, beta = fusion(nameA, nameB, weightA, weightB, transport_matrix, beta)
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        elif "bias" in nameA:
            if "encoder" in nameA: 
                m = weightB.shape[0]
                beta_bias = torch.ones(m) * (1/m)
                W_A_bias = weightA.reshape(m, 1)
                aligned_bias = torch.matmul(torch.diag(1 / beta_bias), torch.matmul(transport_matrix.T, W_A_bias))
                aligned_bias = aligned_bias.reshape(m)
                W_fusion[nameA] = (aligned_bias + weightB) / 2
            else:
                W_fusion[nameA] = a * weightA + (1-a) * weightB
        else:
            W_fusion[nameA] = a * weightA + (1-a) * weightB
# Assign the weights
with torch.no_grad():
    for name, param in fusedModel.named_parameters():
        param.data = torch.nn.Parameter(W_fusion[name])
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

19it [00:00, 474.97it/s]
  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Validation Loss:      0.7737600008646647
Validation Accuracy:      0.7455881186548222


## Recalibrate the normalization layers

In [50]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "weight" in name or "bias" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 47/47 [01:17<00:00,  1.66s/it]


In [51]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Validation Loss:      0.5847417612870535
Validation Accuracy:      0.7797975491751269


In [52]:
for name, param in fusedModel.named_parameters():
    param.requires_grad = True

## Train the last layer

In [53]:
optimizer = torch.optim.Adam(fusedModel.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

clip = 1

fusedModel.train()
for name, param in fusedModel.named_parameters():
    if "encoder" in name:
        param.requires_grad = False
    
for i, batch in enumerate(tqdm(train_iter)):
            src = batch[0] # X
            trg = batch[1] # y
            src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
            optimizer.zero_grad() # reset optimizer
            output = fusedModel(src) # predict
            y_pred = torch.argmax(output, dim=-1) # logits -> labels
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg.to(torch.int64)
            loss = criterion(output_reshape, trg) # calculate loss
            agreements = torch.eq(y_pred, trg)
            accuracy = torch.mean(agreements.double()) # calculate accuracy
            loss.backward() # backward pass

            torch.nn.utils.clip_grad_norm_(fusedModel.parameters(), clip)
            optimizer.step() # optimize model

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu
100%|██████████| 47/47 [00:55<00:00,  1.18s/it]


In [54]:
# Validate the fused model
criterion = torch.nn.CrossEntropyLoss()
val_loss, val_acc = validation(fusedModel, valid_iter, criterion, device)
print("Validation Loss:     ", val_loss)
print("Validation Accuracy:     ", val_acc.item())

  src, trg = torch.tensor(src).to(device), torch.tensor(trg).to(device) # put to cpu/gpu


Validation Loss:      0.465085710088412
Validation Accuracy:      0.8296650274957699
