In [None]:
import os
os.chdir("../")

## Description Generation using Flan-T5

In [None]:

import os
import numpy as np
import torch
import torch.nn as nn
import argparse
import json
from utils import process_config

from datasets import load_dataset

from utils import set_seed, create_synthetic_column
import pickle

import sys


from src import TapexModelForConditionalGeneration, TapexModelForMaskedLanguageModelling, TapexModelForSequenceClassification
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from data import SciGenDataset

from utils import prepare_dataloaders, prepare_models
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings
warnings.filterwarnings("ignore")

from utils import Trainer, Logger, LightningTrainer


from src import BartModelForMaskedLM, BartModelForConditionalGeneration, BartModelForSequenceClassification


from src import compute_metrics

import wandb

from src import GPT2ModelForConditionalGeneration, T5ModelForConditionalGeneration

## Column Reasoning using Flan T5

In [None]:
import torch
import torch.nn as nn

from src import T5ModelForConditionalGeneration
from datasets import load_dataset
import json

from utils import process_config
import pandas as pd

In [None]:
with open("configs/column_reasoning/flant5.json", "r") as f:
    config = json.load(f)

In [None]:
config = process_config(config)

In [None]:
dataset = load_dataset(config.data.data_path)

In [None]:
from data import SciGenDataset
from utils import create_synthetic_column

In [None]:
dataset = create_synthetic_column(dataset, "test")

In [None]:
test_dataset = SciGenDataset(dataset, config, data_type = "test")

In [None]:
tokenizer = test_dataset.tokenizer

In [None]:
model = T5ModelForConditionalGeneration(config)
model.load_state_dict(torch.load("logs/column_reasoning_flant5/checkpoints/epoch=5.pt"))

In [None]:
def predict(index):
    """Generate a prediction for one test example and print it beside the reference.

    Reads the notebook-level ``test_dataset``, ``model``, ``tokenizer`` and
    ``config``.

    Args:
        index: Position of the example inside ``test_dataset``.

    Returns:
        None. The decoded input, reference output and predicted output are
        printed.
    """
    # An item unpacks into five fields; only three of them are needed here.
    input_ids, attention_mask, _token_type_ids, _decoder_input_ids, labels = test_dataset[index]

    # generate() expects a batch dimension, hence unsqueeze(0) on both inputs.
    output_ids = model.model.generate(
        input_ids = input_ids.unsqueeze(0),
        attention_mask = attention_mask.unsqueeze(0),
        max_new_tokens = config.tokenizer.output_max_length,
        num_beams = 3,
        early_stopping = True,
    )

    print("Input sequence: \t\t", tokenizer.decode(input_ids, skip_special_tokens = True), end = "\n\n")
    print("Actual output: \t\t", tokenizer.decode(labels, skip_special_tokens = True), end = "\n\n")
    print("Predicted output: \t", tokenizer.decode(output_ids.squeeze(0), skip_special_tokens = True), end = "\n\n")

In [None]:
predict(0)

In [None]:
predict(1)

In [None]:
predict(4)

In [None]:
predict(6)

In [None]:
predict(15)

In [None]:
predict(21)

In [None]:
predict(27)

In [None]:
predict(50)

## TAPEX with MLM Checkpoint Accuracy evaluation

In [None]:
from datasets import load_dataset
import torch
from utils import process_config
import json

In [None]:
with open("configs/wiki_tq/tapex.json", "r") as f:
    config = json.load(f)

In [None]:
config = process_config(config)

In [None]:
dataset = load_dataset(config.data.data_path)

In [None]:
from data import WikiTQDataset

In [None]:
test_dataset = WikiTQDataset(dataset, config, data_type="test")

In [None]:
tokenizer = test_dataset.tokenizer

In [None]:
from src import BartModelForGenerativeQuestionAnswering

In [None]:
model = BartModelForGenerativeQuestionAnswering(config)
model.load_state_dict(torch.load("logs/table_question_answering_tapex_mlm_pretrained_epochs30/checkpoints/epoch=15.pt"))

In [None]:
total = 0
correct = 0

In [None]:
import os
import re
import sys
import argparse
import unicodedata
import numpy as np
import pandas as pd
from tqdm import tqdm
from codecs import open
from math import isnan, isinf
from easydict import EasyDict
from torch.utils.data import DataLoader
from abc import ABCMeta, abstractmethod

def normalize(x):
    """Canonicalise an answer string so superficially different spellings compare equal.

    The steps are: decode non-``str`` input as UTF-8 (ignoring errors), strip
    diacritics, unify quote and dash variants, repeatedly peel trailing
    citation marks / parenthesised details / one pair of enclosing double
    quotes, drop a single final period, then collapse whitespace and
    lower-case.

    Args:
        x: A ``str``, or an object with a ``decode`` method (e.g. ``bytes``).

    Returns:
        The normalized ``str``.
    """
    text = x if isinstance(x, str) else x.decode('utf8', errors='ignore')

    # Remove diacritics: decompose, then drop combining marks (category 'Mn').
    decomposed = unicodedata.normalize('NFKD', text)
    text = ''.join(filter(lambda ch: unicodedata.category(ch) != 'Mn', decomposed))

    # Normalize quotes and dashes
    for pattern, replacement in (
        (r"[‘’´`]", "'"),
        (r"[“”]", "\""),
        (r"[‐‑‒–—−]", "-"),
    ):
        text = re.sub(pattern, replacement, text)

    # Keep peeling trailing decorations until a full pass changes nothing.
    previous = None
    while previous != text:
        previous = text

        # Remove citations
        text = re.sub(r"((?<!^)\[[^\]]*\]|\[\d+\]|[•♦†‡*#+])*$", "", text.strip())

        # Remove details in parenthesis
        text = re.sub(r"(?<!^)( \([^)]*\))*$", "", text.strip())

        # Remove outermost quotation mark
        text = re.sub(r'^"([^"]*)"$', r'\1', text.strip())

    # Remove final '.'
    if text.endswith('.'):
        text = text[:-1]

    # Collapse whitespaces and convert to lower case
    return re.sub(r'\s+', ' ', text, flags=re.U).lower().strip()


class Value(metaclass=ABCMeta):
    """Abstract base for answer values that can be compared after normalization.

    The metaclass is declared with the Python 3 ``metaclass=`` keyword; a
    ``__metaclass__`` class attribute is ignored by Python 3, which would leave
    ``@abstractmethod`` unenforced.
    """

    # Should be populated with the normalized string
    _normalized = None

    @abstractmethod
    def match(self, other):
        """Return True if the value matches the other value.

        Args:
            other (Value)
        Returns:
            a boolean
        """

    @property
    def normalized(self):
        """The normalized string stored in ``_normalized`` (``None`` until a subclass sets it)."""
        return self._normalized


class StringValue(Value):
    """Free-text value that is compared through its normalized form."""

    def __init__(self, content):
        """Store the normalized form of ``content`` and a hash derived from it.

        Args:
            content (str): Raw text; must be a ``str``.
        """
        assert isinstance(content, str)
        canonical = normalize(content)
        self._normalized = canonical
        self._hash = hash(canonical)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Only another StringValue with the same normalized text is equal.
        if not isinstance(other, StringValue):
            return False
        return self.normalized == other.normalized

    def __str__(self):
        return 'S{}'.format([self.normalized])

    __repr__ = __str__

    def match(self, other):
        """Return True when ``other`` (any Value) has the same normalized text."""
        assert isinstance(other, Value)
        return self.normalized == other.normalized


class NumberValue(Value):
    """Numeric value; matches on normalized text or on numeric closeness."""

    def __init__(self, amount, original_string=None):
        """Store ``amount`` and its normalized text.

        Args:
            amount (int or float): The number. Values within 1e-6 of an
                integer are stored as ``int``.
            original_string (str, optional): Source text; when given, its
                normalized form is used instead of ``str(amount)``.
        """
        assert isinstance(amount, (int, float))
        if abs(amount - round(amount)) < 1e-6:
            self._amount = int(amount)
        else:
            self._amount = float(amount)
        if not original_string:
            self._normalized = str(self._amount)
        else:
            self._normalized = normalize(original_string)
        self._hash = hash(self._amount)

    @property
    def amount(self):
        """The stored number (``int`` or ``float``)."""
        return self._amount

    def __eq__(self, other):
        return isinstance(other, NumberValue) and self.amount == other.amount

    def __hash__(self):
        return self._hash

    def __str__(self):
        return ('N(%f)' % self.amount) + str([self.normalized])

    __repr__ = __str__

    def match(self, other):
        """Return True on equal normalized text, or when both are numbers within 1e-6."""
        assert isinstance(other, Value)
        if self.normalized == other.normalized:
            return True
        if isinstance(other, NumberValue):
            return abs(self.amount - other.amount) < 1e-6
        return False

    @staticmethod
    def parse(text):
        """Try to parse into a number.

        Return:
            the number (int or float) if successful; otherwise None.
            NaN and infinities are rejected.
        """
        # Catch only conversion failures so interrupts and unrelated errors
        # are not silently swallowed.
        try:
            return int(text)
        except (TypeError, ValueError, OverflowError):
            pass

        try:
            amount = float(text)
        except (TypeError, ValueError, OverflowError):
            return None

        # Explicit check rather than `assert`, which is removed under `python -O`.
        if isnan(amount) or isinf(amount):
            return None
        return amount


class DateValue(Value):
    """Calendar date whose year, month or day may be an unknown placeholder (-1)."""

    def __init__(self, year, month, day, original_string=None):
        """Create a new DateValue. Placeholders are marked as -1.

        Args:
            year (int): Year, or -1 if unknown.
            month (int): 1-12, or -1 if unknown.
            day (int): 1-31, or -1 if unknown.
            original_string (str, optional): Source text; when given, its
                normalized form is used instead of the ``y-m-d`` rendering.
        """
        assert isinstance(year, int)
        assert isinstance(month, int) and (month == -1 or 1 <= month <= 12)
        assert isinstance(day, int) and (day == -1 or 1 <= day <= 31)
        assert not (year == month == day == -1)

        self._year = year
        self._month = month
        self._day = day

        if not original_string:
            # `day` is an int (asserted above), so it must be compared with the
            # integer placeholder; comparing with the string '-1' never matched
            # and rendered an unknown day as "-1" instead of "xx".
            self._normalized = '{}-{}-{}'.format(
                year if year != -1 else 'xx',
                month if month != -1 else 'xx',
                day if day != -1 else 'xx')
        else:
            self._normalized = normalize(original_string)

        self._hash = hash((self._year, self._month, self._day))

    @property
    def ymd(self):
        """The ``(year, month, day)`` tuple, with -1 for unknown parts."""
        return (self._year, self._month, self._day)

    def __eq__(self, other):
        return isinstance(other, DateValue) and self.ymd == other.ymd

    def __hash__(self):
        return self._hash

    def __str__(self):
        return (('D(%d,%d,%d)' % (self._year, self._month, self._day))
                + str([self._normalized]))

    __repr__ = __str__

    def match(self, other):
        """Return True on equal normalized text, or when both are dates with the same ymd."""
        assert isinstance(other, Value)

        if self.normalized == other.normalized:
            return True

        if isinstance(other, DateValue):
            return self.ymd == other.ymd

        return False

    @staticmethod
    def parse(text):
        """Try to parse into a date.

        Return:
            tuple (year, month, date) if successful; otherwise None.
        """
        # Catch only the failures that splitting and int() conversion can
        # raise, instead of a bare `except:`.
        try:
            ymd = text.lower().split('-')
            if len(ymd) != 3:
                return None
            year = -1 if ymd[0] in ('xx', 'xxxx') else int(ymd[0])
            month = -1 if ymd[1] == 'xx' else int(ymd[1])
            day = -1 if ymd[2] == 'xx' else int(ymd[2])
        except (AttributeError, TypeError, ValueError):
            return None

        # Explicit range checks rather than `assert`, which is removed under `python -O`.
        if year == month == day == -1:
            return None
        if not (month == -1 or 1 <= month <= 12):
            return None
        if not (day == -1 or 1 <= day <= 31):
            return None
        return (year, month, day)


def to_value(original_string, corenlp_value=None):
    """Wrap a string in the most specific Value type it parses as.

    Parsing is attempted on ``corenlp_value`` when it is truthy, otherwise on
    ``original_string``; numbers are tried first, then dates, then plain text.

    Args:
        original_string (basestring): Original string
        corenlp_value (basestring): Optional value returned from CoreNLP
    Returns:
        Value
    """
    # Already a Value: hand it back untouched.
    if isinstance(original_string, Value):
        return original_string

    candidate = corenlp_value if corenlp_value else original_string

    # Number?
    amount = NumberValue.parse(candidate)
    if amount is not None:
        return NumberValue(amount, original_string)

    # Date? Anything that is neither number nor date is a plain string.
    ymd = DateValue.parse(candidate)
    if ymd is None:
        return StringValue(original_string)

    year, month, day = ymd
    if month == day == -1:
        # Only the year is known, so it is treated as a number.
        return NumberValue(year, original_string)
    return DateValue(year, month, day, original_string)


def to_value_list(original_strings, corenlp_values=None):
    """Convert a collection of strings to a de-duplicated list of Values.

    Args:
        original_strings (list[basestring])
        corenlp_values (list[basestring or None]): When given, must have the
            same length as ``original_strings``; items are paired by position.
    Returns:
        list[Value]
    """
    assert isinstance(original_strings, (list, tuple, set))

    if corenlp_values is None:
        return list({to_value(raw) for raw in original_strings})

    assert isinstance(corenlp_values, (list, tuple, set))
    assert len(original_strings) == len(corenlp_values)
    paired = zip(original_strings, corenlp_values)
    return list({to_value(raw, parsed) for raw, parsed in paired})


def check_denotation(target_values, predicted_values):
    """Return True if the predicted denotation is correct.

    The two lists must have the same length, and every target must be
    matched (via ``target.match``) by at least one prediction.

    Args:
        target_values (list[Value])
        predicted_values (list[Value])
    Returns:
        bool
    """
    # A size mismatch can never be a correct denotation.
    if len(target_values) != len(predicted_values):
        return False

    return all(
        any(target.match(candidate) for candidate in predicted_values)
        for target in target_values
    )


def tsv_unescape(x):
    """Unescape strings in the TSV file.
    Escaped characters include:
        newline (0x0A) -> backslash + n
        vertical bar (0x7C) -> backslash + p
        backslash (0x5C) -> backslash + backslash

    The string is decoded in a single left-to-right pass. Chained
    ``str.replace`` calls would mis-read an escaped backslash that happens to
    be followed by ``n`` or ``p`` (e.g. backslash-backslash-n) as a newline or
    bar escape.

    Args:
        x (str)
    Returns:
        str: The unescaped text. Backslashes not followed by ``n``, ``p`` or
        another backslash are left unchanged.
    """
    replacements = {'n': '\n', 'p': '|', '\\': '\\'}
    return re.sub(r'\\([np\\])', lambda match: replacements[match.group(1)], x)


def tsv_unescape_list(x):
    """Split a TSV list field and unescape each item.
    List items are joined with vertical bars (0x7C).

    Args:
        x (str)
    Returns:
        list[str]: one unescaped string per item
    """
    return list(map(tsv_unescape, x.split('|')))

In [None]:
model.to("cuda:0")

In [None]:
def evaluate(index):
    """Score one test example by denotation match.

    Reads the notebook-level ``test_dataset``, ``model``, ``tokenizer`` and
    ``config``.

    Args:
        index: Position of the example inside ``test_dataset``.

    Returns:
        bool: True when the decoded prediction matches the decoded reference
        according to ``check_denotation``.
    """
    input_ids, attention_mask, _token_type_ids, decoder_input_ids, _labels = test_dataset[index]

    # Send the inputs to whichever device the model currently lives on rather
    # than a hard-coded "cuda:0", so the cell survives a different model.to(...).
    device = next(model.parameters()).device

    # [0] selects the single generated sequence; a bare squeeze() would also
    # collapse the sequence dimension if it happened to have length 1.
    output_ids = model.model.generate(
        input_ids = input_ids.unsqueeze(0).to(device),
        attention_mask = attention_mask.unsqueeze(0).to(device),
        max_new_tokens = config.tokenizer.output_max_length,
        num_beams = 3,
        early_stopping = True,
    )[0].detach().cpu()

    predicted_sequence = tokenizer.decode(output_ids, skip_special_tokens = True)
    actual_sequence = tokenizer.decode(decoder_input_ids, skip_special_tokens = True)

    pred = to_value_list([predicted_sequence])
    gold = to_value_list([actual_sequence])

    # check_denotation's signature is (target_values, predicted_values).
    return check_denotation(gold, pred)

In [None]:
correct = 0
total = 0

In [None]:
# Tally how many test examples evaluate() accepts.
for i in tqdm(range(len(test_dataset)), total = len(test_dataset), position = 0, leave = True):
    verdict = evaluate(i)
    total += 1
    correct += 1 if verdict else 0

In [None]:
correct / total

In [None]:
model = BartModelForGenerativeQuestionAnswering(config)
model.load_state_dict(torch.load("logs/table_question_answering_tapex_mlm_pretrained_epochs30/checkpoints/epoch=20.pt"))

In [None]:
model.to("cuda:1")

In [None]:
def evaluate(index):
    """Score one test example by denotation match.

    Reads the notebook-level ``test_dataset``, ``model``, ``tokenizer`` and
    ``config``.

    Args:
        index: Position of the example inside ``test_dataset``.

    Returns:
        bool: True when the decoded prediction matches the decoded reference
        according to ``check_denotation``.
    """
    input_ids, attention_mask, _token_type_ids, decoder_input_ids, _labels = test_dataset[index]

    # Send the inputs to whichever device the model currently lives on rather
    # than a hard-coded "cuda:1", so the cell survives a different model.to(...).
    device = next(model.parameters()).device

    # [0] selects the single generated sequence; a bare squeeze() would also
    # collapse the sequence dimension if it happened to have length 1.
    output_ids = model.model.generate(
        input_ids = input_ids.unsqueeze(0).to(device),
        attention_mask = attention_mask.unsqueeze(0).to(device),
        max_new_tokens = config.tokenizer.output_max_length,
        num_beams = 3,
        early_stopping = True,
    )[0].detach().cpu()

    predicted_sequence = tokenizer.decode(output_ids, skip_special_tokens = True)
    actual_sequence = tokenizer.decode(decoder_input_ids, skip_special_tokens = True)

    pred = to_value_list([predicted_sequence])
    gold = to_value_list([actual_sequence])

    # check_denotation's signature is (target_values, predicted_values).
    return check_denotation(gold, pred)

In [None]:
correct = 0
total = 0

In [None]:
# Tally how many test examples evaluate() accepts.
for i in tqdm(range(len(test_dataset)), total = len(test_dataset), position = 0, leave = True):
    verdict = evaluate(i)
    total += 1
    correct += 1 if verdict else 0

In [None]:
correct / total

In [None]:
model = BartModelForGenerativeQuestionAnswering(config)
model.load_state_dict(torch.load("logs/table_question_answering_tapex_mlm_pretrained_epochs30/checkpoints/epoch=10.pt"))

In [None]:
model.to("cuda:3")

In [None]:
def evaluate(index):
    """Score one test example by denotation match, echoing both sequences.

    Reads the notebook-level ``test_dataset``, ``model``, ``tokenizer`` and
    ``config``.

    Args:
        index: Position of the example inside ``test_dataset``.

    Returns:
        bool: True when the decoded prediction matches the decoded reference
        according to ``check_denotation``.
    """
    input_ids, attention_mask, _token_type_ids, decoder_input_ids, _labels = test_dataset[index]

    # Send the inputs to whichever device the model currently lives on rather
    # than a hard-coded "cuda:3", so the cell survives a different model.to(...).
    device = next(model.parameters()).device

    # [0] selects the single generated sequence; a bare squeeze() would also
    # collapse the sequence dimension if it happened to have length 1.
    output_ids = model.model.generate(
        input_ids = input_ids.unsqueeze(0).to(device),
        attention_mask = attention_mask.unsqueeze(0).to(device),
        max_new_tokens = config.tokenizer.output_max_length,
        num_beams = 3,
        early_stopping = True,
    )[0].detach().cpu()

    predicted_sequence = tokenizer.decode(output_ids, skip_special_tokens = True)
    actual_sequence = tokenizer.decode(decoder_input_ids, skip_special_tokens = True)

    # Echo prediction and reference for eyeballing.
    print(predicted_sequence)
    print(actual_sequence)

    pred = to_value_list([predicted_sequence])
    gold = to_value_list([actual_sequence])

    # check_denotation's signature is (target_values, predicted_values).
    return check_denotation(gold, pred)

In [None]:
correct = 0
total = 0

In [None]:
# Tally how many test examples evaluate() accepts.
for i in tqdm(range(len(test_dataset)), total = len(test_dataset), position = 0, leave = True):
    verdict = evaluate(i)
    total += 1
    correct += 1 if verdict else 0

In [None]:
correct / total

In [None]:
model = BartModelForGenerativeQuestionAnswering(config)
model.load_state_dict(torch.load("logs/table_question_answering_tapex_mlm_pretrained_epochs30/checkpoints/epoch=10.pt"))

In [None]:
model.to("cuda:0")

In [None]:
def evaluate(index):
    """Predict one test example and return the values alongside the verdict.

    Reads the notebook-level ``test_dataset``, ``model``, ``tokenizer`` and
    ``config``.

    Args:
        index: Position of the example inside ``test_dataset``.

    Returns:
        tuple: ``(pred, gold, verdict)`` where ``pred`` and ``gold`` are the
        Value lists built from the decoded prediction and reference, and
        ``verdict`` is the result of ``check_denotation``.
    """
    input_ids, attention_mask, _token_type_ids, decoder_input_ids, _labels = test_dataset[index]

    # Use the model's own device. The inputs used to be sent to "cuda:3"
    # while the model had just been moved to "cuda:0", which makes generate()
    # fail with a device mismatch.
    device = next(model.parameters()).device

    # [0] selects the single generated sequence; a bare squeeze() would also
    # collapse the sequence dimension if it happened to have length 1.
    output_ids = model.model.generate(
        input_ids = input_ids.unsqueeze(0).to(device),
        attention_mask = attention_mask.unsqueeze(0).to(device),
        max_new_tokens = config.tokenizer.output_max_length,
        num_beams = 3,
        early_stopping = True,
    )[0].detach().cpu()

    predicted_sequence = tokenizer.decode(output_ids, skip_special_tokens = True)
    actual_sequence = tokenizer.decode(decoder_input_ids, skip_special_tokens = True)

    pred = to_value_list([predicted_sequence])
    gold = to_value_list([actual_sequence])

    # check_denotation's signature is (target_values, predicted_values).
    verdict = check_denotation(gold, pred)

    return pred, gold, verdict

In [None]:
output_dict = {}

In [None]:
for i in tqdm(range(test_dataset.__len__()), position = 0, leave = True, total = test_dataset.__len__()):
    pred, gold, verdict = evaluate(i)
    output_dict[i] = {"pred": pred, "gold": gold}

In [None]:
x = {}

In [None]:
for key, val in output_dict.items():
    x[str(key)] = val

In [None]:
x["0"]

In [None]:
import pickle
with open("tapex_mlm_ckpt_wikitq_preds_epoch10.pkl", "wb") as f:
    pickle.dump(x, f)

In [None]:
import json
with open("tapex_mlm_ckpt_wikitq_preds_epoch10.json", "w") as f:
    # The "pred"/"gold" entries hold Value objects, which json cannot encode
    # on its own (TypeError). Fall back to each object's normalized text, or
    # to str() for anything without that attribute.
    json.dump(x, f, default = lambda obj: getattr(obj, "normalized", str(obj)))

In [None]:
from utils import process_config

In [None]:
from data import WikiTQDataset

In [None]:
import json
with open("configs/wiki_tq/tapex.json", "r") as f:
    config = json.load(f)
config = process_config(config)

In [None]:
import json
with open("tapex_mlm_ckpt_wikitq_preds_epoch10.json", "r") as f:
    l = json.load(f)

In [None]:
from datasets import load_dataset
dataset = load_dataset(config.data.data_path)
test_dataset = WikiTQDataset(dataset, config, data_type="test")

In [None]:
tokenizer = test_dataset.tokenizer

In [None]:
total = 0
correct = 0

In [None]:
import torch

In [None]:
from src import BartModelForGenerativeQuestionAnswering
model = BartModelForGenerativeQuestionAnswering(config)
model.load_state_dict(torch.load("logs/table_question_answering_tapex_mlm_pretrained_epochs30/checkpoints/epoch=10.pt"))

In [None]:
model.to("cuda:3")

In [None]:
for i in tqdm(range(len(test_dataset)), position = 0, leave = True, total = len(test_dataset)):
    result = evaluate(i)
    # The most recently defined evaluate() returns (pred, gold, verdict). A
    # non-empty tuple is always truthy, so testing it directly would count
    # every example as correct; read the verdict element instead.
    verdict = result[-1] if isinstance(result, tuple) else result
    if verdict:
        correct += 1

    total += 1

In [None]:
correct / total

# ROW COL Embeddings Code fix

In [None]:
from transformers import AutoTokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/tapex-large-finetuned-wtq", )

In [None]:
text = "<s> how many runs batted in did darren daulton have? col : name | season(s) | position(s) | notes row 1 : omar daal | 2000–2001 | pitcher |  15–16 record\n4.52 earned run average\n158 row 2 : babe dahlgren | 1943 | first baseman | .287 batting average\n5 home runs\n56 runs batted in row 3 : sam dailey | 1929 | pitcher |  2–2 record\n7.54 earned run average\n18 row 4 : ed daily | 1885–1887 | outfielder\npitcher | .230 batting average\n6 home runs\n42–36 record row 5 : clay dalrymple | 1960–1968 | catcher | .234 batting average\n50 home runs\n312 runs batted in row 6 : tony daniels | 1945 | second baseman | .200 batting average\n2 triples\n10 runs batted in row 7 : alvin dark | 1960 | third baseman | .242 batting average\n3 home runs\n14 runs batted in row 8 : george darrow | 1934 | pitcher |  2–6 record\n5.51 earned run average\n14 row 9 : darren daulton§ | 1983\n1985–1997 | catcher | .245 batting average\n134 home runs\n567 runs batted row 10 : curt davis | 1934–1936 | pitcher |  37–35 record\n3.42 earned run average\n191 row 11 : dick davis | 1981–1982 | right fielder | .311 batting average\n4 home runs\n26 runs batted in row 12 : dixie davis | 1918 | pitcher |  0–2 record\n3.06 earned run average\n18 row 13 : jacke davis | 1962 | left fielder | .213 batting average\n1 home run\n6 runs batted in row 14 : kane davis | 2007 | pitcher |  0–1 record\n5.56 earned run average\n10 row 15 : kiddo davis | 1932\n1934 | center fielder | .302 batting average\n11 triples\n105 runs batted in row 16 : mark davis | 1980–1981\n1993 | pitcher |  2–6 record\n6.31 earned run average\n62 row 17 : spud davis | 1928–1933\n1938–1939 | catcher | .321 batting average\n53 home runs\n363 runs batted in row 18 : bill dawley | 1988 | pitcher |  0–2 record\n13.50 earned run average\n3 row 19 : bill day | 1889–1890 | pitcher |  1–4 record\n4.10 earned run average\n29 row 20 : justin de fratus | 2011 | pitcher |  1–0 record\n2.25 earned run average\n3 row 21 : valerio de los santos | 2003 
| pitcher |  1–0 record\n9.00 earned run average\n4 row 22 : wayland dean | 1926–1927 | pitcher |  8–17 record\n5.01 earned run average\n53 row 23 : art decatur | 1925–1927 | pitcher |  7–18 record\n6.18 earned run average\n58 row 24 : harry decker | 1889–1890 | second baseman |.204 batting average\n1 double\n4 runs batted in row 25 : pep deininger | 1908–1909 | center fielder |.260 batting average\n9 doubles\n16 runs batted in row 26 : bill deitrick | 1927–1928 | left fielder\nshortstop |.198 batting average\n6 doubles\n7 runs batted in row 27 : iván dejesús | 1982–1984 | shortstop | .249 batting average\n15 triples\n139 runs batted in row 28 : josé dejesús | 1990–1991 | pitcher |  17–17 record\n3.55 earned run average\n205 row 29 : bobby del greco | 1960–1961\n1965 | center fielder | .240 batting average\n12 home runs\n37 runs batted in row 30 : garton del savio | 1943 | shortstop |.091 batting average\n12 plate appearances\n1 walk row</s>"

In [None]:
tokenized_text = tokenizer.tokenize(text)

In [None]:
len(tokenized_text)

In [None]:
tokenized_text

In [None]:
from datasets import load_dataset

In [None]:
train_dataset = load_dataset("wikitablequestions")["train"]

In [None]:
question = train_dataset[11063]["question"]
table = train_dataset[11063]["table"]

In [None]:
question

In [None]:
import pandas as pd

In [None]:
table_column_names = table["header"]
table_content_values = table["rows"]
table = pd.DataFrame.from_dict({str(col).lower(): [str(table_content_values[j][i]).lower() for j in range(len(table_content_values))] for i, col in enumerate(table_column_names)})

In [None]:
display(table)

In [None]:
tokenized_input = tokenizer(table, question, add_special_tokens = True, padding = "max_length", truncation = True, max_length = 960, 
          return_tensors = "pt", return_token_type_ids = True, return_attention_mask = True)

In [None]:
text = tokenizer.decode(tokenized_input["input_ids"].squeeze())

In [None]:
tokenized_text = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"].squeeze())

In [None]:
len(tokenized_text)

# Random

In [None]:
from utils import process_config
from src import EEDBartModelForGenerativeQuestionAnswering
import json

In [None]:
with open("configs/wiki_tq/tapex.json", "r") as f:
    config = json.load(f)

In [None]:
config = process_config(config)

In [None]:
model = EEDBartModelForGenerativeQuestionAnswering(config)

In [None]:
learning_rate = 1e-5
layerwise_learning_rate_decay = 0.9
weight_decay = config.training.weight_decay


In [None]:
no_decay = ["bias"]


def _names(module, keep = lambda name: True):
    """Return the parameter names of ``module`` accepted by ``keep``, in registration order."""
    return [name for name, _ in module.named_parameters() if keep(name)]


def _split_by_bias(layer, lr):
    """Build the decayed (non-bias) and undecayed (bias) name groups of one layer."""
    return [
        {
            "params": _names(layer, lambda name: "bias" not in name),
            "weight_decay": weight_decay,
            "lr": lr,
        },
        {
            "params": _names(layer, lambda name: "bias" in name),
            "weight_decay": 0.0,
            "lr": lr,
        },
    ]


# Parameter *names* (not tensors) are collected throughout, so the resulting
# list is meant for printing and inspection.
optimizer_grouped_parameters = [
    # Task-specific parts at the base learning rate: lm_head, shared
    # embeddings and the token classifier.
    {
        "params": _names(model.model.lm_head)
                  + _names(model, lambda name: "shared" in name)
                  + _names(model, lambda name: "token_classifier" in name),
        "weight_decay": 0.0,
        "lr": learning_rate,
    },
    # layernorm_embedding of the decomposer, encoder and decoder.
    {
        "params": _names(model.model.model.decomposer.layernorm_embedding)
                  + _names(model.model.model.encoder.layernorm_embedding)
                  + _names(model.model.model.decoder.layernorm_embedding),
        "weight_decay": 0.0,
        "lr": learning_rate,
    },
    # Every parameter whose name contains "bias".
    {
        "params": _names(model, lambda name: "bias" in name),
        "weight_decay": 0.0,
        "lr": learning_rate,
    },
]

# Positional embeddings followed by the transformer layers, listed last-first
# so the learning rate shrinks from the last layer towards the embeddings.
decomposer_layers = [model.model.model.decomposer.embed_positions, *model.model.model.decomposer.layers][::-1]
encoder_layers = [model.model.model.encoder.embed_positions, *model.model.model.encoder.layers][::-1]
decoder_layers = [model.model.model.decoder.embed_positions, *model.model.model.decoder.layers][::-1]

lr = learning_rate
for layer in decomposer_layers:
    # Printed with the rate as it stands *before* this layer's decay step.
    print(_split_by_bias(layer, lr))

    lr *= layerwise_learning_rate_decay
    optimizer_grouped_parameters += _split_by_bias(layer, lr)

# Encoder first, then decoder; each stack restarts from the base rate.
for stack_layers in (encoder_layers, decoder_layers):
    lr = learning_rate
    for layer in stack_layers:
        lr *= layerwise_learning_rate_decay
        optimizer_grouped_parameters += _split_by_bias(layer, lr)

In [None]:
for x in optimizer_grouped_parameters:
    print(x)

In [None]:
decoder_layers

In [None]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)

In [None]:
x = [getattr(model.model.model, "encoder").embed_positions]

In [None]:
for layer in x:
    lr *= layerwise_learning_rate_decay
    print(layer)
    print({
            "params": [n for n, p in layer.named_parameters() if "bias" not in n],
            "weight_decay": weight_decay,
            "lr": lr,
        })

In [None]:
for name, param in model.named_parameters():
    if "bias" in name:
        print(name)

In [None]:
for name, param in model.model.model.encoder.named_parameters():
    if "bias" in name:
        print(name)

In [None]:
print([
            {
                "params": [n for n, p in model.model.model.decomposer.named_parameters() if "bias" in n] + [n for n, p in model.model.model.encoder.named_parameters() if "bias" in n] + [n for n, p in model.model.model.decoder.named_parameters() if "bias" in n],
                "weight_decay": 0.0,
                "lr": learning_rate,
            },
        ])

In [None]:
no_decay = ["bias"]


def _names_where(module, wanted):
    """Return the parameter names of ``module`` for which ``wanted(name)`` is true."""
    return [name for name, _ in module.named_parameters() if wanted(name)]


def _is_bias(name):
    """True when the parameter name contains "bias"."""
    return "bias" in name


def _not_bias(name):
    """True when the parameter name does not contain "bias"."""
    return "bias" not in name


# Parameter *names* (not tensors) are collected throughout, so the resulting
# list is meant for printing and inspection.
optimizer_grouped_parameters = [
    # Task-specific parts at the base learning rate: lm_head, shared
    # embeddings and the token classifier.
    {
        "params": _names_where(model.model.lm_head, lambda name: True)
                  + _names_where(model, lambda name: "shared" in name)
                  + _names_where(model, lambda name: "token_classifier" in name),
        "weight_decay": 0.0,
        "lr": learning_rate,
    },
    # Non-bias layernorm_embedding parameters of the three stacks.
    {
        "params": _names_where(model.model.model.decomposer.layernorm_embedding, _not_bias)
                  + _names_where(model.model.model.encoder.layernorm_embedding, _not_bias)
                  + _names_where(model.model.model.decoder.layernorm_embedding, _not_bias),
        "weight_decay": 0.0,
        "lr": learning_rate,
    },
    # All bias parameters of the three stacks, in one undecayed group.
    {
        "params": _names_where(model.model.model.decomposer, _is_bias)
                  + _names_where(model.model.model.encoder, _is_bias)
                  + _names_where(model.model.model.decoder, _is_bias),
        "weight_decay": 0.0,
        "lr": learning_rate,
    },
]

# Positional embeddings followed by the transformer layers, listed last-first
# so the learning rate shrinks from the last layer towards the embeddings.
decomposer_layers = [model.model.model.decomposer.embed_positions, *model.model.model.decomposer.layers][::-1]
encoder_layers = [model.model.model.encoder.embed_positions, *model.model.model.encoder.layers][::-1]
decoder_layers = [model.model.model.decoder.embed_positions, *model.model.model.decoder.layers][::-1]

# One decayed group of non-bias names per layer; each stack restarts from the
# base rate (biases are already covered by the group above).
for stack_layers in (decomposer_layers, encoder_layers, decoder_layers):
    lr = learning_rate
    for layer in stack_layers:
        lr *= layerwise_learning_rate_decay
        optimizer_grouped_parameters.append(
            {
                "params": _names_where(layer, _not_bias),
                "weight_decay": weight_decay,
                "lr": lr,
            }
        )

In [None]:
for x in optimizer_grouped_parameters:
    print(x, end = "\n\n")

In [None]:
for name, param in model.model.model.decomposer.embed_positions.named_parameters():
    print(name)

In [None]:
x = [getattr(model.model.model, "encoder").embed_positions]

In [None]:
for layer in x:
    print({"params": [n for n, p in layer.named_parameters() if "bias" not in n]})

In [None]:
count = 0
for x in optimizer_grouped_parameters:
    count += len(x["params"])

print(count)

In [None]:
# Print every parameter name and then how many there are. The counter has to
# be advanced inside the loop, otherwise the final print always shows 0.
i = 0
for name, param in model.named_parameters():
    print(name)
    i += 1

print(i)

In [None]:
decomposer_layers