In [12]:

# Alternatively, create a custom model class
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertConfig, BertForSequenceClassification, BertTokenizer, BertModel

# Load the default BERT configuration
config = BertConfig.from_pretrained('bert-base-uncased')
# Customize the configuration
# reduce vocab size
config.vocab_size = 8 * 11 * 11
config.max_position_embeddings = 82
config.num_hidden_layers = 6          # Reduce the number of transformer layers
config.hidden_size = 384              # Decrease the hidden size
config.num_attention_heads = 6        # Adjust the number of attention heads
config.intermediate_size = 1536       # Modify the size of the feedforward layers
config.hidden_dropout_prob = 0.2      # Increase dropout to prevent overfitting
config.attention_probs_dropout_prob = 0.2
config.num_labels = 40                 # For a 3-class classification problem

In [13]:
model = BertModel(config)

In [17]:
with torch.no_grad():
    outputs = model(inputs_embeds=torch.randn(1, 81, 384))
    print(outputs)


BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-2.2080,  2.6604,  2.1954,  ...,  0.4806,  0.5610, -2.0110],
         [ 0.2337, -1.2299,  0.5211,  ...,  0.2094, -0.5270, -0.5702],
         [ 0.2861,  0.8851,  0.0446,  ...,  0.5752,  0.3057, -0.4804],
         ...,
         [ 0.1296,  0.5067, -0.1854,  ..., -0.3859, -0.2433, -0.6874],
         [ 0.5463, -0.5763,  2.0978,  ..., -0.1174, -0.4010,  1.9679],
         [-3.2281,  2.2028,  0.2580,  ...,  1.0704,  0.2995, -0.0545]]]), pooler_output=tensor([[-7.5671e-02, -3.9718e-02,  1.7146e-02,  2.7881e-01,  1.8024e-01,
          1.3249e-01, -1.7843e-01,  3.4028e-01,  2.8295e-01, -3.4014e-01,
         -6.5024e-01, -9.9486e-02,  2.4463e-01, -3.7889e-01, -3.5526e-01,
         -2.1448e-01, -3.2601e-01, -5.1166e-01,  4.8278e-01, -8.3381e-01,
          1.6468e-02,  3.1327e-01, -6.6662e-01, -7.4291e-01,  5.2439e-01,
          1.6485e-01, -2.3875e-01, -5.4885e-01,  7.3861e-01, -2.4308e-01,
          5.2492e-01, -5.3238e-01, -

In [19]:
import torch as th
import torch.nn.functional as F
import sys
sys.path.append('/n/home12/binxuwang/Github/DiffusionReasoning/')
from GPT_models.GPT_RAVEN_model_lib import SepWordEmbed, CmbWordEmbed, SepLMhead, CmbLMhead
class MultiIdxBERTModel(nn.Module):
    def __init__(self, attribute_dims=(7,10,10), vocab_size=0, max_length=128, n_embd=768, n_class=40, is_sep_embed=True, **kwargs):

        super().__init__()
        # Combine embeddings
        combined_embedding_size = n_embd  # Adjust based on your combination strategy
        if is_sep_embed:
            self.sep_word_embed = SepWordEmbed(attribute_dims, embed_size=n_embd//3)
            self.multi_lmhead = SepLMhead(attribute_dims, embed_size=n_embd//3)
        else:
            self.sep_word_embed = CmbWordEmbed(attribute_dims, embed_size=n_embd)
            self.multi_lmhead = CmbLMhead(attribute_dims, embed_size=n_embd)
        config = BertConfig(vocab_size=vocab_size, 
                            max_position_embeddings=max_length, 
                            hidden_size=combined_embedding_size, **kwargs)
        self.bert = BertModel(config)
        self.context_embed = nn.Embedding(1, n_embd) # dummy embedding for start token
        self.classifier = nn.Linear(n_embd, n_class)

    def forward(self, input_ids, y=None):
        # input_ids is expected to be a list of three tensors [attr1, attr2, attr3]
        SOS = torch.zeros(input_ids.shape[0], dtype=th.long).to(input_ids[0].device)
        SOS_vec = self.context_embed(SOS)
        combined_embedding = self.sep_word_embed(input_ids)
        combined_embedding = torch.concat([SOS_vec[:,None,:], combined_embedding, ], dim=1)
        outputs = self.bert(inputs_embeds=combined_embedding)
        logits = self.classifier(outputs.pooler_output)
        return logits, outputs.last_hidden_state
    

# def multi_attr_loss(outputs, targets, loss_fn=F.cross_entropy, ):
#     loss1 = loss_fn(outputs[0].permute(0,2,1), targets[..., 0])
#     loss2 = loss_fn(outputs[1].permute(0,2,1), targets[..., 1])
#     loss3 = loss_fn(outputs[2].permute(0,2,1), targets[..., 2])
#     return loss1 + loss2 + loss3


# def multi_attr_loss_vec(outputs, targets, loss_fn=F.cross_entropy, ):
#     logits1, logits2, logits3 = outputs[0], outputs[1], outputs[2]
#     loss1 = loss_fn(logits1.reshape(-1, logits1.size(-1)), targets[..., 0].view(-1))
#     loss2 = loss_fn(logits2.reshape(-1, logits2.size(-1)), targets[..., 1].view(-1))
#     loss3 = loss_fn(logits3.reshape(-1, logits3.size(-1)), targets[..., 2].view(-1))
#     return loss1 + loss2 + loss3


# def next_token_loss(outputs, targets, loss_fn=F.cross_entropy):
#     logits1, logits2, logits3 = outputs[0], outputs[1], outputs[2]
#     loss1 = loss_fn(logits1[:, :-1, :].permute(0,2,1), targets[:, 1:, 0])
#     loss2 = loss_fn(logits2[:, :-1, :].permute(0,2,1), targets[:, 1:, 1])
#     loss3 = loss_fn(logits3[:, :-1, :].permute(0,2,1), targets[:, 1:, 2])
#     return loss1 + loss2 + loss3

In [20]:
import einops
import numpy as np
from os.path import join

def preprocess_ids(attr_seq_tsr, ):
    attr_seq_tsr_pps = attr_seq_tsr + 1 # clone() removed
    return attr_seq_tsr_pps

cmb_per_class = 4000
heldout_id = [1, 16, 20, 34, 37]
# Create a mask with all True values
# Set the specified rows to False
train_mask = torch.ones(40, dtype=torch.bool)
train_mask[heldout_id] = False
# old version
# data_dir = '/n/home12/binxuwang/Github/DiffusionReasoning/'
# attr_all = np.load(data_dir+'attr_all.npy')
data_dir = "/n/holylfs06/LABS/kempner_fellow_binxuwang/Users/binxuwang/Datasets/RPM_dataset/RPM1000k"
attr_all = np.load(join(data_dir, "attr_all_1000k.npy"))
print(attr_all.shape)
attr_all_rows = torch.from_numpy(attr_all, )
del attr_all
# attr_img_tsr = einops.rearrange(attr_all_rows,  'class (B R) p (h w) attr -> class B attr (R h) (p w)', h=3,w=3,p=3,R=3)
attr_seq_tsr = einops.rearrange(attr_all_rows,  'class (B R) p (h w) attr -> class B (R p h w) attr', h=3,w=3,p=3,R=3)
del attr_all_rows
# Set the y of the dataset, which is the class index; split the y into training and validation
y_rule = th.arange(attr_seq_tsr.shape[0], dtype=th.long).unsqueeze(1)
y_rule = y_rule.repeat(1, attr_seq_tsr.shape[1])
attr_seq_tsr = preprocess_ids(attr_seq_tsr)
# if the cmb_per_class is too large, change it such that it won't overlap with the validation set
if cmb_per_class > attr_seq_tsr.shape[1] - 500:
    cmb_per_class = attr_seq_tsr.shape[1] - 500
attr_seq_tsr_train, attr_seq_tsr_val, attr_seq_tsr_val_eval = \
    attr_seq_tsr[train_mask, :cmb_per_class], attr_seq_tsr[:, -500:], attr_seq_tsr[:, -50:] # changed June 30, 2024, also eval on untrained rules.
y_rule_train, y_rule_val, y_rule_val_eval = \
    y_rule[train_mask, :cmb_per_class], y_rule[:, -500:], y_rule[:, -50:]
y_rule_train = einops.rearrange(y_rule_train, 'class B -> (class B)', )
y_rule_val = einops.rearrange(y_rule_val, 'class B -> (class B)', )
y_rule_val_eval = einops.rearrange(y_rule_val_eval, 'class B -> (class B)', )
# combine the first 2 axes into 1
attr_seq_tsr_train = einops.rearrange(attr_seq_tsr_train, 'class B (R p h w) attr -> (class B) (R p h w) attr', R=3, p=3, h=3, w=3)
attr_seq_tsr_val = einops.rearrange(attr_seq_tsr_val, 'class B (R p h w) attr -> (class B) (R p h w) attr', R=3, p=3, h=3, w=3)
attr_seq_tsr_val_eval = einops.rearrange(attr_seq_tsr_val_eval, 'class B (R p h w) attr -> (class B) (R p h w) attr', R=3, p=3, h=3, w=3)
print(attr_seq_tsr_train.shape, attr_seq_tsr_val.shape, attr_seq_tsr_val_eval.shape)
del attr_seq_tsr


(40, 1200000, 3, 9, 3)




torch.Size([140000, 81, 3]) torch.Size([20000, 81, 3]) torch.Size([2000, 81, 3])


In [21]:
del attr_seq_tsr

In [30]:
from tqdm.auto import trange
from torch.utils.data import DataLoader, TensorDataset
from transformers import AdamW, get_linear_schedule_with_warmup
batch_size = 64
train_dataset = TensorDataset(attr_seq_tsr_train, y_rule_train)
val_dataset = TensorDataset(attr_seq_tsr_val, y_rule_val)
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, drop_last=False)

bert_raven = MultiIdxBERTModel(attribute_dims=(7,10,10), vocab_size=27, max_length=83, 
                                n_class=40, n_embd=384, is_sep_embed=True, n_layer=6, n_head=6)
# train loop
lr = 1e-4 #2e-5
num_warmup_steps = 1000
epoch_total = 10
eval_every_step = 500
total_steps = len(data_loader) * epoch_total
# bug fix @2024-08-18, before which, the num_training_steps is not the total_steps, so wrong scheduler 
# num_training_steps = len(data_loader) * epoch_total
optimizer = AdamW(bert_raven.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps)
bert_raven.train().to('cuda')



MultiIdxBERTModel(
  (sep_word_embed): SepWordEmbed(
    (embedding1): Embedding(8, 128)
    (embedding2): Embedding(11, 128)
    (embedding3): Embedding(11, 128)
  )
  (multi_lmhead): SepLMhead(
    (lmhead1): Linear(in_features=128, out_features=8, bias=True)
    (lmhead2): Linear(in_features=128, out_features=11, bias=True)
    (lmhead3): Linear(in_features=128, out_features=11, bias=True)
  )
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(27, 384, padding_idx=0)
      (position_embeddings): Embedding(83, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key)

In [31]:
pbar = trange(total_steps)
data_iter = iter(data_loader)
for step in pbar:
    try:
        inputs, ys = next(data_iter)
    except StopIteration:
        data_iter = iter(data_loader)
        inputs, ys = next(data_iter)
    inputs = inputs.cuda()
    ys = ys.cuda()
    optimizer.zero_grad()
    logits, outputs = bert_raven(inputs, y=ys)
    loss = F.cross_entropy(logits, ys)
    loss.backward()
    optimizer.step()
    scheduler.step()
    pbar.set_postfix(loss=loss.item())
    # evaluate test set
    if (step + 1) % eval_every_step == 0 or step == total_steps - 1 or step == 0:
        bert_raven.eval()
        with torch.no_grad():
            val_loss = 0
            acc_cnt = 0
            for inputs, ys in val_loader:
                inputs = inputs.cuda()
                ys = ys.cuda()
                logits, outputs = bert_raven(inputs, y=ys)
                loss = F.cross_entropy(logits, ys)
                val_loss += loss.item()
                acc_cnt += (logits.argmax(dim=-1) == ys).float().sum().item()
            acc_ratio = acc_cnt / len(val_loader.dataset)
            loss_avg = val_loss / len(val_loader)
            print(f"Step {step} Validation loss: {loss_avg}, accuracy: {acc_ratio}")
        bert_raven.train()

  0%|          | 0/21880 [00:00<?, ?it/s]

Step 0 Validation loss: 3.710217418549936, accuracy: 0.0283
Step 249 Validation loss: 3.572138068042224, accuracy: 0.0682
Step 499 Validation loss: 3.59108993071544, accuracy: 0.08995
Step 749 Validation loss: 3.518622908411147, accuracy: 0.11245
Step 999 Validation loss: 3.520975873440127, accuracy: 0.12895
Step 1249 Validation loss: 3.4385248603700083, accuracy: 0.1527
Step 1499 Validation loss: 3.3835425965393644, accuracy: 0.16985
Step 1749 Validation loss: 3.3505380010303063, accuracy: 0.1835
Step 1999 Validation loss: 3.2733103814004343, accuracy: 0.20335
Step 2249 Validation loss: 3.284302406861812, accuracy: 0.21135
Step 2499 Validation loss: 3.175745446847964, accuracy: 0.2245
Step 2749 Validation loss: 3.1793598615670504, accuracy: 0.2381
Step 2999 Validation loss: 3.089216758933248, accuracy: 0.24835
Step 3249 Validation loss: 3.071441326714769, accuracy: 0.2623
Step 3499 Validation loss: 2.990386594134041, accuracy: 0.27275
Step 3749 Validation loss: 2.991743972029867, accu

KeyboardInterrupt: 

In [10]:
# Initialize the model with the custom configuration
model = BertForSequenceClassification(config)

In [11]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(968, 384, padding_idx=0)
      (position_embeddings): Embedding(82, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.2, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.2, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, 

In [2]:

# class CustomBERTModel(nn.Module):
#     def __init__(self, config):
#         super(CustomBERTModel, self).__init__()
#         self.bert = BertForSequenceClassification(config)
#         # Add additional layers if needed
#         self.additional_layer = nn.Linear(config.hidden_size, config.hidden_size)
#         self.relu = nn.ReLU()

#     def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
#         outputs = self.bert(
#             input_ids, 
#             attention_mask=attention_mask, 
#             token_type_ids=token_type_ids, 
#             labels=labels,
#             return_dict=False
#         )
#         pooled_output = outputs[1]
#         # Apply additional layers
#         x = self.additional_layer(pooled_output)
#         x = self.relu(x)
#         logits = self.bert.classifier(x)
#         loss = None
#         if labels is not None:
#             loss = nn.CrossEntropyLoss()(logits.view(-1, self.config.num_labels), labels.view(-1))
#         return (loss, logits) if loss is not None else logits

# # Instantiate the custom model
# custom_model = CustomBERTModel(config)

In [None]:
from torch.utils.data import DataLoader, TensorDataset
from transformers import AdamW

# Create a dataset and dataloader
dataset = TensorDataset(
    encoded_inputs['input_ids'], 
    encoded_inputs['attention_mask'], 
    labels
)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Define the optimizer
optimizer = AdamW(model.parameters(), lr=2e-5)

# Training loop
model.train()
for epoch in range(3):  # Number of epochs
    for batch in dataloader:
        optimizer.zero_grad()
        input_ids, attention_mask, labels = batch
        outputs = model(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            labels=labels
        )
        loss = outputs[0]
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")