In [96]:
from transformers import AutoModelForSeq2SeqLM,AutoTokenizer
import pandas as pd
from bertviz import head_view
import torch
from torch import nn
from sklearn.model_selection import train_test_split
import tqdm
import pickle
import numpy as np
import datetime
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data
import json

In [120]:
path_train = "../data/labeled-train.model-aware.v2.json"
df = pd.read_json(path_train)

In [121]:
path_val = "../data/val.model-aware.v2.json"
val_df = pd.read_json(path_val)
val_df["label"] = val_df["label"].astype('category')
val_df = val_df.dropna()

In [138]:
TASK = "PG"
FILTER = False

task_df = df[(df["task"] == TASK)] 
model_name = task_df["model"].unique().item()
if FILTER:
    task_df = task_df[(task_df["p(Hallucination)"] == 0) | (task_df["p(Hallucination)"]==1)]
task_df = task_df[["src","hyp","label"]]
task_df["label"] = task_df["label"].astype('category')
task_df = task_df.dropna()

In [141]:
task_df["label"]

10000        Hallucination
10001        Hallucination
10002        Hallucination
10003        Hallucination
10004        Hallucination
               ...        
19995        Hallucination
19996    Not Hallucination
19997    Not Hallucination
19998        Hallucination
19999    Not Hallucination
Name: label, Length: 9999, dtype: category
Categories (2, object): ['Hallucination', 'Not Hallucination']

In [123]:
# Check that MPS is available
if torch.cuda.is_available():
    device = "cuda"
elif not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = "cpu"
else:
    device = "mps"
print(device)

cuda


In [105]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [124]:
def get_hidden_states(outputs):
    eos_indices = []
    
    for seq in outputs.sequences:
        #skip first pad token
        if seq[0] in [tokenizer.eos_token_id,tokenizer.pad_token_id]:
            seq = seq[1:]
        nonzero = (seq == tokenizer.eos_token_id).nonzero()
        if nonzero.nelement() != 0:
            idx = nonzero.item()
        else:
            idx = seq.size(0)-1 #last token
        eos_indices.append(idx)        
    eos_indices = torch.tensor(eos_indices,device=device)

    decoder_hidden_states = []
    for t in outputs.decoder_hidden_states:
        decoder_hidden_states.append(torch.stack(t))
    decoder_hidden_states = torch.stack(decoder_hidden_states)
    seq_len,layers,batch, _, hidden_dim = decoder_hidden_states.size()
    decoder_hidden_states = decoder_hidden_states.reshape((layers,batch,seq_len,-1))

    eos_indices = eos_indices.reshape((1,-1,1,1)).repeat(layers,1,1,hidden_dim)
    hiddens = decoder_hidden_states.gather(dim=2,index=eos_indices).squeeze(2)
    return hiddens


In [125]:
def infer_df(model,tokenizer,df,start_batch = 0,batch_size = 16):        
    batched_df = [df[i:i+batch_size] for i in range(0,len(df),batch_size)]
    dataset = []
    for batch in tqdm.tqdm(batched_df[start_batch:],total=len(batched_df[start_batch:])):
        tokens = tokenizer(batch["src"].to_list(),
                        return_tensors="pt",
                        padding=True,
                        truncation=True).to(device)
        with torch.no_grad():
            outputs = model.generate(**tokens,
                                    max_new_tokens=25,
                                    # num_beams=10, 
                                    num_return_sequences=1,
                                    # temperature=0,
                                    return_dict_in_generate=True,
                                    output_hidden_states=True)
            hiddens = get_hidden_states(outputs)
            hiddens = hiddens.detach().to("cpu")
        dataset.append((hiddens,batch["label"].cat.codes.values))
    hiddens,labels = list(zip(*dataset))
    hiddens = torch.cat(hiddens,dim=1)
    labels = np.concatenate(labels)
    return hiddens,labels

# hiddens,labels = infer_df(model,tokenizer,task_df)
# with open(f"../data/{TASK}-train.pkl", "wb") as f:
#     pickle.dump((hiddens,labels),f)

In [None]:
labels

In [137]:
with open(f"../data/{TASK}-train.pkl", "rb") as f:
    hiddens,labels = pickle.load(f)
with open(f"../data/{TASK}-val.pkl", "rb") as f:
    val_hiddens,val_y = pickle.load(f)
hiddens.size(),labels.shape

(torch.Size([17, 9999, 1024]), (9999,))

In [127]:
hidden_dim = hiddens.size(-1)
class StatesClassifier(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.ln1 = nn.Linear(hidden_dim,256)
        self.ln2 = nn.Linear(256,128)
        self.ln3 = nn.Linear(128,64)
        self.head = nn.Linear(64,1)

    def forward(self,x):
        x = self.ln1(x).relu()
        x = self.ln2(x).relu()
        x = self.ln3(x).relu()
        return self.head(x) #.sigmoid()

In [128]:
def get_weights(labels):
    uniques, counts = np.unique(labels,return_counts=True)
    weights_dict = {u:1/c for u,c in zip(uniques,counts)}
    return weights_dict

In [129]:
def train_one_epoch(model,loader,optimizer,loss_fn):
    running_loss = 0.
    
    for i, data in enumerate(loader):        
        inputs, labels = data
        labels = labels.unsqueeze(1).float()
        optimizer.zero_grad()

        outputs = model(inputs.to(device))

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels.to(device))
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()        

    return running_loss / len(loader)

In [130]:
EPOCHS = 20
rates = [0.01,0.001,0.0001]
use_sampler_p = [False,True]

best_vloss = 1_000_000.
best_acc = 0.0

model_path = None

total_layers = hiddens.size(0)
results = []
now = datetime.datetime.now().strftime("%H:%M:%S")
for use_sampler in use_sampler_p:
    for rate in rates:
        for layer_num in range(total_layers):
            writer = SummaryWriter(f'runs/{TASK}/{now}/layer_{layer_num}')
            train_x = hiddens[layer_num]
            train_y = labels
            val_x = val_hiddens[layer_num]
            # train_x, val_x, train_y,val_y = train_test_split(hiddens[layer_num],labels,test_size=0.1,random_state=42,stratify=labels)    
            weights_dict = get_weights(train_y)
            weights = [weights_dict[l]for l in train_y]
            if use_sampler:
                sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(train_y), replacement=True)
            else:
                sampler = None
            shuffle = sampler is None
            pos_weight = torch.tensor(weights_dict[1]) if sampler is None else None
            training_loader = torch.utils.data.DataLoader(list(zip(train_x,train_y)), batch_size=64,sampler=sampler,shuffle=shuffle)
            validation_loader = torch.utils.data.DataLoader(list(zip(val_x,val_y)), batch_size=64, shuffle=False)

            classifier = StatesClassifier().to(device)
            optimizer = torch.optim.Adam(classifier.parameters(), lr=rate)
            loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
            for epoch_number in range(EPOCHS):
                print(f'EPOCH {epoch_number+1}:')

                # Make sure gradient tracking is on, and do a pass over the data
                classifier.train(True)
                avg_loss = train_one_epoch(classifier,training_loader,optimizer,loss_fn)

                running_vloss = 0.0
                running_vacc = 0.0

                classifier.eval()
                with torch.no_grad():
                    for i, vdata in enumerate(validation_loader):
                        vinputs, vlabels = vdata
                        vinputs= vinputs.to(device)
                        vlabels = vlabels.unsqueeze(1).float().to(device)

                        voutputs = classifier(vinputs)
                        vloss = loss_fn(voutputs, vlabels)
                        running_vloss += vloss
                        running_vacc += (voutputs.sigmoid().round() == vlabels).float().mean()

                avg_vloss = running_vloss / (i + 1)
                avg_acc = running_vacc / (i+1)
                print(f'LOSS train {avg_loss} valid {avg_vloss}')
                print(f'Valid ACC {avg_acc}')

                # Log the running loss averaged per batch
                # for both training and validation
                writer.add_scalars('Training Loss',
                                { 'Training' : avg_loss},
                                epoch_number + 1)
                writer.add_scalars('Validation Loss',
                                {'Validation' : avg_vloss },
                                epoch_number + 1)
                writer.add_scalars('Validation accuracy',
                                { 'ACC' : avg_acc},
                                epoch_number + 1)
                writer.flush()

                # Track best performance, and save the model's state
                if avg_acc > best_acc:
                    best_acc = avg_acc    
                    model_path = f'{TASK}_model_best.pt'
                    torch.save(classifier.state_dict(), model_path)
                    results.append(
                        {
                            "accuracy":avg_acc.item(),
                            "loss":avg_vloss.item(),
                            "epoch":epoch_number,
                            "layer_num":layer_num,
                            "lr":rate,
                            "weighted_sampler": use_sampler
                        }
                        )                        

                epoch_number += 1

EPOCH 1:
LOSS train 14.101819457426952 valid 10.013069152832031
Valid ACC 0.41126179695129395
EPOCH 2:
LOSS train 22.101242303278795 valid 21.555482864379883
Valid ACC 0.41126179695129395
EPOCH 3:
LOSS train 13.04271861094578 valid 2.2027077674865723
Valid ACC 0.41126179695129395
EPOCH 4:
LOSS train 5.578089771661789 valid 4.79072904586792
Valid ACC 0.41126179695129395
EPOCH 5:
LOSS train 39.79845873261713 valid 22.727550506591797
Valid ACC 0.41126179695129395
EPOCH 6:
LOSS train 76.33103069226453 valid 29.958980560302734
Valid ACC 0.41126179695129395
EPOCH 7:
LOSS train 60.126797256765855 valid 16.16668128967285
Valid ACC 0.41126179695129395
EPOCH 8:
LOSS train 17.6792485972119 valid 9.15337085723877
Valid ACC 0.41126179695129395
EPOCH 9:
LOSS train 9.487126863686143 valid 4.380025863647461
Valid ACC 0.41126179695129395
EPOCH 10:
LOSS train 3.482228779441612 valid 6.862470626831055
Valid ACC 0.41126179695129395
EPOCH 11:
LOSS train 13.814428764164068 valid 9.164053916931152
Valid ACC 

In [131]:
best_results = list(sorted(results,key=lambda x:x["accuracy"],reverse=True))[:10]
best_results = [{k: v.item() if isinstance(v,torch.Tensor) else v for k,v in r.items()} for r in best_results]
best_results

[{'accuracy': 0.6172980666160583,
  'loss': 1.5192617177963257,
  'epoch': 9,
  'layer_num': 9,
  'lr': 0.0001,
  'weighted_sampler': True},
 {'accuracy': 0.6075324416160583,
  'loss': 1.6313867568969727,
  'epoch': 13,
  'layer_num': 9,
  'lr': 0.001,
  'weighted_sampler': True},
 {'accuracy': 0.6051739454269409,
  'loss': 2.1101293563842773,
  'epoch': 11,
  'layer_num': 9,
  'lr': 0.001,
  'weighted_sampler': True},
 {'accuracy': 0.6032208204269409,
  'loss': 1.8017871379852295,
  'epoch': 11,
  'layer_num': 2,
  'lr': 0.001,
  'weighted_sampler': True},
 {'accuracy': 0.6008623242378235,
  'loss': 0.6885821223258972,
  'epoch': 0,
  'layer_num': 2,
  'lr': 0.01,
  'weighted_sampler': True},
 {'accuracy': 0.588738203048706,
  'loss': 0.68015056848526,
  'epoch': 0,
  'layer_num': 0,
  'lr': 0.01,
  'weighted_sampler': True},
 {'accuracy': 0.5198997855186462,
  'loss': 498.664306640625,
  'epoch': 17,
  'layer_num': 6,
  'lr': 0.001,
  'weighted_sampler': False},
 {'accuracy': 0.44055

In [132]:
with open(f"{TASK}_model_best.json","w") as f:
    json.dump(best_results,f,indent=2)

# Submition

In [154]:
path_test = "../data/labeled-test.model-aware.json"
test_df = pd.read_json(path_test)
test_df["label"] = test_df["label"].astype('category')
tasks = test_df["task"].unique()
tasks

array(['DM', 'MT', 'PG'], dtype=object)

In [157]:
task_results = pd.DataFrame({
        "id":[],
        "label":[],
        "p(Hallucinationa)":[]
    })
for task in tasks:
    current_df = test_df[test_df["task"]==task]
    model_name = current_df["model"].unique().item()
    print(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    hiddens,_ = infer_df(model,tokenizer,current_df)
    
    classifier = StatesClassifier().to(device)    
    classifier.load_state_dict(torch.load(f"{task}_model_best.pt"))

    with open(f"{task}_model_best.json","r"):
        best_params = json.load(f)
    layer_num = best_params[0]["layer_num"]

    data_loader = torch.utils.data.DataLoader(hiddens[layer_num], batch_size=64, shuffle=False)
    result_p = []
    for i, data in enumerate(data_loader):
        outputs = classifier(data)
        result_p.extend(outputs.sigmoid().tolist())
    result_l = [test_df["label"].cat.categories[p] for p in result_p]
    task_results = pd.conсat(task_results,
    new_df = pd.DataFrame({
        "id":current_df["id"],
        "label":result_l,
        "p(Hallucination)":result_p
    }),ignore_index=True)


ltg/flan-t5-definition-en-base


100%|██████████| 36/36 [00:06<00:00,  5.38it/s]


RuntimeError: Error(s) in loading state_dict for StatesClassifier:
	Missing key(s) in state_dict: "ln1.weight", "ln1.bias", "ln2.weight", "ln2.bias", "ln3.weight", "ln3.bias", "head.weight", "head.bias". 
	Unexpected key(s) in state_dict: "shared.weight", "encoder.embed_tokens.weight", "encoder.block.0.layer.0.SelfAttention.q.weight", "encoder.block.0.layer.0.SelfAttention.k.weight", "encoder.block.0.layer.0.SelfAttention.v.weight", "encoder.block.0.layer.0.SelfAttention.o.weight", "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", "encoder.block.0.layer.0.layer_norm.weight", "encoder.block.0.layer.1.DenseReluDense.wi_0.weight", "encoder.block.0.layer.1.DenseReluDense.wi_1.weight", "encoder.block.0.layer.1.DenseReluDense.wo.weight", "encoder.block.0.layer.1.layer_norm.weight", "encoder.block.1.layer.0.SelfAttention.q.weight", "encoder.block.1.layer.0.SelfAttention.k.weight", "encoder.block.1.layer.0.SelfAttention.v.weight", "encoder.block.1.layer.0.SelfAttention.o.weight", "encoder.block.1.layer.0.layer_norm.weight", "encoder.block.1.layer.1.DenseReluDense.wi_0.weight", "encoder.block.1.layer.1.DenseReluDense.wi_1.weight", "encoder.block.1.layer.1.DenseReluDense.wo.weight", "encoder.block.1.layer.1.layer_norm.weight", "encoder.block.2.layer.0.SelfAttention.q.weight", "encoder.block.2.layer.0.SelfAttention.k.weight", "encoder.block.2.layer.0.SelfAttention.v.weight", "encoder.block.2.layer.0.SelfAttention.o.weight", "encoder.block.2.layer.0.layer_norm.weight", "encoder.block.2.layer.1.DenseReluDense.wi_0.weight", "encoder.block.2.layer.1.DenseReluDense.wi_1.weight", "encoder.block.2.layer.1.DenseReluDense.wo.weight", "encoder.block.2.layer.1.layer_norm.weight", "encoder.block.3.layer.0.SelfAttention.q.weight", "encoder.block.3.layer.0.SelfAttention.k.weight", "encoder.block.3.layer.0.SelfAttention.v.weight", "encoder.block.3.layer.0.SelfAttention.o.weight", "encoder.block.3.layer.0.layer_norm.weight", "encoder.block.3.layer.1.DenseReluDense.wi_0.weight", "encoder.block.3.layer.1.DenseReluDense.wi_1.weight", "encoder.block.3.layer.1.DenseReluDense.wo.weight", "encoder.block.3.layer.1.layer_norm.weight", "encoder.block.4.layer.0.SelfAttention.q.weight", "encoder.block.4.layer.0.SelfAttention.k.weight", "encoder.block.4.layer.0.SelfAttention.v.weight", "encoder.block.4.layer.0.SelfAttention.o.weight", "encoder.block.4.layer.0.layer_norm.weight", "encoder.block.4.layer.1.DenseReluDense.wi_0.weight", "encoder.block.4.layer.1.DenseReluDense.wi_1.weight", "encoder.block.4.layer.1.DenseReluDense.wo.weight", "encoder.block.4.layer.1.layer_norm.weight", "encoder.block.5.layer.0.SelfAttention.q.weight", "encoder.block.5.layer.0.SelfAttention.k.weight", "encoder.block.5.layer.0.SelfAttention.v.weight", "encoder.block.5.layer.0.SelfAttention.o.weight", "encoder.block.5.layer.0.layer_norm.weight", "encoder.block.5.layer.1.DenseReluDense.wi_0.weight", "encoder.block.5.layer.1.DenseReluDense.wi_1.weight", "encoder.block.5.layer.1.DenseReluDense.wo.weight", "encoder.block.5.layer.1.layer_norm.weight", "encoder.block.6.layer.0.SelfAttention.q.weight", "encoder.block.6.layer.0.SelfAttention.k.weight", "encoder.block.6.layer.0.SelfAttention.v.weight", "encoder.block.6.layer.0.SelfAttention.o.weight", "encoder.block.6.layer.0.layer_norm.weight", "encoder.block.6.layer.1.DenseReluDense.wi_0.weight", "encoder.block.6.layer.1.DenseReluDense.wi_1.weight", "encoder.block.6.layer.1.DenseReluDense.wo.weight", "encoder.block.6.layer.1.layer_norm.weight", "encoder.block.7.layer.0.SelfAttention.q.weight", "encoder.block.7.layer.0.SelfAttention.k.weight", "encoder.block.7.layer.0.SelfAttention.v.weight", "encoder.block.7.layer.0.SelfAttention.o.weight", "encoder.block.7.layer.0.layer_norm.weight", "encoder.block.7.layer.1.DenseReluDense.wi_0.weight", "encoder.block.7.layer.1.DenseReluDense.wi_1.weight", "encoder.block.7.layer.1.DenseReluDense.wo.weight", "encoder.block.7.layer.1.layer_norm.weight", "encoder.block.8.layer.0.SelfAttention.q.weight", "encoder.block.8.layer.0.SelfAttention.k.weight", "encoder.block.8.layer.0.SelfAttention.v.weight", "encoder.block.8.layer.0.SelfAttention.o.weight", "encoder.block.8.layer.0.layer_norm.weight", "encoder.block.8.layer.1.DenseReluDense.wi_0.weight", "encoder.block.8.layer.1.DenseReluDense.wi_1.weight", "encoder.block.8.layer.1.DenseReluDense.wo.weight", "encoder.block.8.layer.1.layer_norm.weight", "encoder.block.9.layer.0.SelfAttention.q.weight", "encoder.block.9.layer.0.SelfAttention.k.weight", "encoder.block.9.layer.0.SelfAttention.v.weight", "encoder.block.9.layer.0.SelfAttention.o.weight", "encoder.block.9.layer.0.layer_norm.weight", "encoder.block.9.layer.1.DenseReluDense.wi_0.weight", "encoder.block.9.layer.1.DenseReluDense.wi_1.weight", "encoder.block.9.layer.1.DenseReluDense.wo.weight", "encoder.block.9.layer.1.layer_norm.weight", "encoder.block.10.layer.0.SelfAttention.q.weight", "encoder.block.10.layer.0.SelfAttention.k.weight", "encoder.block.10.layer.0.SelfAttention.v.weight", "encoder.block.10.layer.0.SelfAttention.o.weight", "encoder.block.10.layer.0.layer_norm.weight", "encoder.block.10.layer.1.DenseReluDense.wi_0.weight", "encoder.block.10.layer.1.DenseReluDense.wi_1.weight", "encoder.block.10.layer.1.DenseReluDense.wo.weight", "encoder.block.10.layer.1.layer_norm.weight", "encoder.block.11.layer.0.SelfAttention.q.weight", "encoder.block.11.layer.0.SelfAttention.k.weight", "encoder.block.11.layer.0.SelfAttention.v.weight", "encoder.block.11.layer.0.SelfAttention.o.weight", "encoder.block.11.layer.0.layer_norm.weight", "encoder.block.11.layer.1.DenseReluDense.wi_0.weight", "encoder.block.11.layer.1.DenseReluDense.wi_1.weight", "encoder.block.11.layer.1.DenseReluDense.wo.weight", "encoder.block.11.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight", "decoder.embed_tokens.weight", "decoder.block.0.layer.0.SelfAttention.q.weight", "decoder.block.0.layer.0.SelfAttention.k.weight", "decoder.block.0.layer.0.SelfAttention.v.weight", "decoder.block.0.layer.0.SelfAttention.o.weight", "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", "decoder.block.0.layer.0.layer_norm.weight", "decoder.block.0.layer.1.EncDecAttention.q.weight", "decoder.block.0.layer.1.EncDecAttention.k.weight", "decoder.block.0.layer.1.EncDecAttention.v.weight", "decoder.block.0.layer.1.EncDecAttention.o.weight", "decoder.block.0.layer.1.layer_norm.weight", "decoder.block.0.layer.2.DenseReluDense.wi_0.weight", "decoder.block.0.layer.2.DenseReluDense.wi_1.weight", "decoder.block.0.layer.2.DenseReluDense.wo.weight", "decoder.block.0.layer.2.layer_norm.weight", "decoder.block.1.layer.0.SelfAttention.q.weight", "decoder.block.1.layer.0.SelfAttention.k.weight", "decoder.block.1.layer.0.SelfAttention.v.weight", "decoder.block.1.layer.0.SelfAttention.o.weight", "decoder.block.1.layer.0.layer_norm.weight", "decoder.block.1.layer.1.EncDecAttention.q.weight", "decoder.block.1.layer.1.EncDecAttention.k.weight", "decoder.block.1.layer.1.EncDecAttention.v.weight", "decoder.block.1.layer.1.EncDecAttention.o.weight", "decoder.block.1.layer.1.layer_norm.weight", "decoder.block.1.layer.2.DenseReluDense.wi_0.weight", "decoder.block.1.layer.2.DenseReluDense.wi_1.weight", "decoder.block.1.layer.2.DenseReluDense.wo.weight", "decoder.block.1.layer.2.layer_norm.weight", "decoder.block.2.layer.0.SelfAttention.q.weight", "decoder.block.2.layer.0.SelfAttention.k.weight", "decoder.block.2.layer.0.SelfAttention.v.weight", "decoder.block.2.layer.0.SelfAttention.o.weight", "decoder.block.2.layer.0.layer_norm.weight", "decoder.block.2.layer.1.EncDecAttention.q.weight", "decoder.block.2.layer.1.EncDecAttention.k.weight", "decoder.block.2.layer.1.EncDecAttention.v.weight", "decoder.block.2.layer.1.EncDecAttention.o.weight", "decoder.block.2.layer.1.layer_norm.weight", "decoder.block.2.layer.2.DenseReluDense.wi_0.weight", "decoder.block.2.layer.2.DenseReluDense.wi_1.weight", "decoder.block.2.layer.2.DenseReluDense.wo.weight", "decoder.block.2.layer.2.layer_norm.weight", "decoder.block.3.layer.0.SelfAttention.q.weight", "decoder.block.3.layer.0.SelfAttention.k.weight", "decoder.block.3.layer.0.SelfAttention.v.weight", "decoder.block.3.layer.0.SelfAttention.o.weight", "decoder.block.3.layer.0.layer_norm.weight", "decoder.block.3.layer.1.EncDecAttention.q.weight", "decoder.block.3.layer.1.EncDecAttention.k.weight", "decoder.block.3.layer.1.EncDecAttention.v.weight", "decoder.block.3.layer.1.EncDecAttention.o.weight", "decoder.block.3.layer.1.layer_norm.weight", "decoder.block.3.layer.2.DenseReluDense.wi_0.weight", "decoder.block.3.layer.2.DenseReluDense.wi_1.weight", "decoder.block.3.layer.2.DenseReluDense.wo.weight", "decoder.block.3.layer.2.layer_norm.weight", "decoder.block.4.layer.0.SelfAttention.q.weight", "decoder.block.4.layer.0.SelfAttention.k.weight", "decoder.block.4.layer.0.SelfAttention.v.weight", "decoder.block.4.layer.0.SelfAttention.o.weight", "decoder.block.4.layer.0.layer_norm.weight", "decoder.block.4.layer.1.EncDecAttention.q.weight", "decoder.block.4.layer.1.EncDecAttention.k.weight", "decoder.block.4.layer.1.EncDecAttention.v.weight", "decoder.block.4.layer.1.EncDecAttention.o.weight", "decoder.block.4.layer.1.layer_norm.weight", "decoder.block.4.layer.2.DenseReluDense.wi_0.weight", "decoder.block.4.layer.2.DenseReluDense.wi_1.weight", "decoder.block.4.layer.2.DenseReluDense.wo.weight", "decoder.block.4.layer.2.layer_norm.weight", "decoder.block.5.layer.0.SelfAttention.q.weight", "decoder.block.5.layer.0.SelfAttention.k.weight", "decoder.block.5.layer.0.SelfAttention.v.weight", "decoder.block.5.layer.0.SelfAttention.o.weight", "decoder.block.5.layer.0.layer_norm.weight", "decoder.block.5.layer.1.EncDecAttention.q.weight", "decoder.block.5.layer.1.EncDecAttention.k.weight", "decoder.block.5.layer.1.EncDecAttention.v.weight", "decoder.block.5.layer.1.EncDecAttention.o.weight", "decoder.block.5.layer.1.layer_norm.weight", "decoder.block.5.layer.2.DenseReluDense.wi_0.weight", "decoder.block.5.layer.2.DenseReluDense.wi_1.weight", "decoder.block.5.layer.2.DenseReluDense.wo.weight", "decoder.block.5.layer.2.layer_norm.weight", "decoder.block.6.layer.0.SelfAttention.q.weight", "decoder.block.6.layer.0.SelfAttention.k.weight", "decoder.block.6.layer.0.SelfAttention.v.weight", "decoder.block.6.layer.0.SelfAttention.o.weight", "decoder.block.6.layer.0.layer_norm.weight", "decoder.block.6.layer.1.EncDecAttention.q.weight", "decoder.block.6.layer.1.EncDecAttention.k.weight", "decoder.block.6.layer.1.EncDecAttention.v.weight", "decoder.block.6.layer.1.EncDecAttention.o.weight", "decoder.block.6.layer.1.layer_norm.weight", "decoder.block.6.layer.2.DenseReluDense.wi_0.weight", "decoder.block.6.layer.2.DenseReluDense.wi_1.weight", "decoder.block.6.layer.2.DenseReluDense.wo.weight", "decoder.block.6.layer.2.layer_norm.weight", "decoder.block.7.layer.0.SelfAttention.q.weight", "decoder.block.7.layer.0.SelfAttention.k.weight", "decoder.block.7.layer.0.SelfAttention.v.weight", "decoder.block.7.layer.0.SelfAttention.o.weight", "decoder.block.7.layer.0.layer_norm.weight", "decoder.block.7.layer.1.EncDecAttention.q.weight", "decoder.block.7.layer.1.EncDecAttention.k.weight", "decoder.block.7.layer.1.EncDecAttention.v.weight", "decoder.block.7.layer.1.EncDecAttention.o.weight", "decoder.block.7.layer.1.layer_norm.weight", "decoder.block.7.layer.2.DenseReluDense.wi_0.weight", "decoder.block.7.layer.2.DenseReluDense.wi_1.weight", "decoder.block.7.layer.2.DenseReluDense.wo.weight", "decoder.block.7.layer.2.layer_norm.weight", "decoder.block.8.layer.0.SelfAttention.q.weight", "decoder.block.8.layer.0.SelfAttention.k.weight", "decoder.block.8.layer.0.SelfAttention.v.weight", "decoder.block.8.layer.0.SelfAttention.o.weight", "decoder.block.8.layer.0.layer_norm.weight", "decoder.block.8.layer.1.EncDecAttention.q.weight", "decoder.block.8.layer.1.EncDecAttention.k.weight", "decoder.block.8.layer.1.EncDecAttention.v.weight", "decoder.block.8.layer.1.EncDecAttention.o.weight", "decoder.block.8.layer.1.layer_norm.weight", "decoder.block.8.layer.2.DenseReluDense.wi_0.weight", "decoder.block.8.layer.2.DenseReluDense.wi_1.weight", "decoder.block.8.layer.2.DenseReluDense.wo.weight", "decoder.block.8.layer.2.layer_norm.weight", "decoder.block.9.layer.0.SelfAttention.q.weight", "decoder.block.9.layer.0.SelfAttention.k.weight", "decoder.block.9.layer.0.SelfAttention.v.weight", "decoder.block.9.layer.0.SelfAttention.o.weight", "decoder.block.9.layer.0.layer_norm.weight", "decoder.block.9.layer.1.EncDecAttention.q.weight", "decoder.block.9.layer.1.EncDecAttention.k.weight", "decoder.block.9.layer.1.EncDecAttention.v.weight", "decoder.block.9.layer.1.EncDecAttention.o.weight", "decoder.block.9.layer.1.layer_norm.weight", "decoder.block.9.layer.2.DenseReluDense.wi_0.weight", "decoder.block.9.layer.2.DenseReluDense.wi_1.weight", "decoder.block.9.layer.2.DenseReluDense.wo.weight", "decoder.block.9.layer.2.layer_norm.weight", "decoder.block.10.layer.0.SelfAttention.q.weight", "decoder.block.10.layer.0.SelfAttention.k.weight", "decoder.block.10.layer.0.SelfAttention.v.weight", "decoder.block.10.layer.0.SelfAttention.o.weight", "decoder.block.10.layer.0.layer_norm.weight", "decoder.block.10.layer.1.EncDecAttention.q.weight", "decoder.block.10.layer.1.EncDecAttention.k.weight", "decoder.block.10.layer.1.EncDecAttention.v.weight", "decoder.block.10.layer.1.EncDecAttention.o.weight", "decoder.block.10.layer.1.layer_norm.weight", "decoder.block.10.layer.2.DenseReluDense.wi_0.weight", "decoder.block.10.layer.2.DenseReluDense.wi_1.weight", "decoder.block.10.layer.2.DenseReluDense.wo.weight", "decoder.block.10.layer.2.layer_norm.weight", "decoder.block.11.layer.0.SelfAttention.q.weight", "decoder.block.11.layer.0.SelfAttention.k.weight", "decoder.block.11.layer.0.SelfAttention.v.weight", "decoder.block.11.layer.0.SelfAttention.o.weight", "decoder.block.11.layer.0.layer_norm.weight", "decoder.block.11.layer.1.EncDecAttention.q.weight", "decoder.block.11.layer.1.EncDecAttention.k.weight", "decoder.block.11.layer.1.EncDecAttention.v.weight", "decoder.block.11.layer.1.EncDecAttention.o.weight", "decoder.block.11.layer.1.layer_norm.weight", "decoder.block.11.layer.2.DenseReluDense.wi_0.weight", "decoder.block.11.layer.2.DenseReluDense.wi_1.weight", "decoder.block.11.layer.2.DenseReluDense.wo.weight", "decoder.block.11.layer.2.layer_norm.weight", "decoder.final_layer_norm.weight", "lm_head.weight". 