In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

class CustomModel(nn.Module):
    def __init__(self, num_categories, hidden_size, device=device):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert")
        self.model = AutoModelForSequenceClassification.from_pretrained("Rostlab/prot_bert", num_labels=2).to(device)
        self.embedding = nn.Embedding(num_categories, hidden_size)
        self.classifier = nn.Linear(hidden_size + self.model.config.hidden_size, 1)  # Binary classification head

    def forward(self, input_ids, attention_mask, categorical_data):
        # Tokenized sequence
        sequence_output = self.model(input_ids, attention_mask=attention_mask).last_hidden_state

        # Embed categorical data
        cat_embed = self.embedding(categorical_data)  # Shape: (batch_size, hidden_size)

        # Concatenate sequence and categorical embeddings
        combined_rep = torch.cat([sequence_output[:, 0], cat_embed], dim=-1)  # Shape: (batch_size, hidden_size + bert_hidden_size)

        # Binary classification
        logits = self.classifier(combined_rep)
        return logits

# Example usage
model = CustomModel(num_categories=4, hidden_size=128)
input_ids = torch.tensor([[1, 2, 3, ...]])  # Tokenized input sequence
attention_mask = torch.tensor([[1, 1, 1, ...]])  # Attention mask
categorical_data = torch.tensor([0, 1, 2, ...])  # Categorical data (A=0, B=1, C=2, D=3)

logits = model(input_ids, attention_mask, categorical_data)


In [81]:
from transformers import BertModel, AutoTokenizer
model = BertModel.from_pretrained("Rostlab/prot_bert")#.to(device)
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert")
tokens = tokenizer("A B C", max_length=5, padding='max_length', return_tensors="pt")
tokens

{'input_ids': tensor([[ 2,  6, 27, 23,  3]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [80]:
seq_out = model(**tokens)#.logits
seq_out.last_hidden_state.shape

torch.Size([1, 5, 1024])

In [77]:
import torch.nn as nn
from torch import tensor
import torch
embedding = nn.Embedding(4, 1024)
embed = embedding(tensor([1, 0, 0, 0]))
embed.shape

torch.Size([4, 1024])

In [76]:
torch.cat((torch.ones(4,12),torch.ones((4,12))), dim=-1)

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1.]])

In [72]:
torch.cat((seq_out.last_hidden_state[:, 0], embed), dim=-1)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 4 for tensor number 1 in the list.