In [None]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
import sys, os, math

sys.path.insert(0, '../dlp')

asdasfknas
from data_access import PQDataAccess

from data_process import *

pd.set_option('future.no_silent_downcasting', True)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

batch_size = 64
block_size = 32
da = PQDataAccess("/home/aac/Alireza/datasets/taxseq/corpus_1000", batch_size)
epochs= 10_000
val_epoch = 100
num_val = 25

model_name = "Karpethy_GPT"
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 


  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.
cuda:0
../checkpoints/Karpethy_GPT_checkpoints


In [2]:
def estimate_loss(eval_iters):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            tensor_batch = GPT_data_to_tensor_batch(da.get_batch(), block_size)
            tensor_batch.gpu(device)
            
            _, loss = model(tensor_batch.input_ids, tensor_batch.output_ids)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [3]:
from models.GPT import GPTLanguageModel

vocab_size = 23
n_embd = 512
n_head = 8
n_layer = 6
dropout = 0.2

model = GPTLanguageModel(vocab_size, block_size, n_embd, n_head, n_layer, dropout, device).to(device)
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

18.946071 M parameters


In [4]:
for iter_ in range(epochs):
    # every once in a while evaluate the loss on train and val sets
    if iter_ % val_epoch == 0 or iter_ == epochs - 1:
        losses = estimate_loss(num_val)
        print(f"step {iter_}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    tensor_batch = GPT_data_to_tensor_batch(da.get_batch(), block_size)
    tensor_batch.gpu(device)
    
    _, loss = model(tensor_batch.input_ids, tensor_batch.output_ids)
    
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 3.2579, val loss 3.2510
step 100: train loss 2.9004, val loss 2.8926
step 200: train loss 2.8879, val loss 2.8944
step 300: train loss 2.8537, val loss 2.8737
step 400: train loss 2.8711, val loss 2.8640
step 500: train loss 2.8649, val loss 2.8716
step 600: train loss 2.8567, val loss 2.8571
step 700: train loss 2.8507, val loss 2.8595
step 800: train loss 2.8605, val loss 2.8842
step 900: train loss 2.8532, val loss 2.8572
step 1000: train loss 2.8559, val loss 2.8585
step 1100: train loss 2.8546, val loss 2.8500
step 1200: train loss 2.8557, val loss 2.8508
step 1300: train loss 2.8412, val loss 2.8513
step 1400: train loss 2.8464, val loss 2.8610
step 1500: train loss 2.8533, val loss 2.8527
step 1600: train loss 2.8424, val loss 2.8646
step 1700: train loss 2.8585, val loss 2.8465
step 1800: train loss 2.8480, val loss 2.8575
step 1900: train loss 2.8423, val loss 2.8462
step 2000: train loss 2.8299, val loss 2.8461
step 2100: train loss 2.8542, val loss 2.8447


In [15]:
# generate from the model
for _ in range(1):
    context = torch.ones((1, 1), device=device).to(torch.long)
    output = generate(context, max_new_tokens=50, block_size=block_size)[0].tolist()

    print(*[special_idx_to_char[s] for s in output])

22
12
21
4
17
11
22
8
21
22
5
19
8
8
18
15
3
17
17
6
15
7
15
3
11
10
15
16
5
19
20
9
10
15
20
17
8
18
8
14
11
5
7
8
5
5
15
19
3
15
<s> Y L W C R K Y G W Y D T G G S P A R R E P F P A K I P Q D T V H I P V R G S G N K D F G D D P T A P


In [14]:
from torch.nn import functional as F


def generate(idx, max_new_tokens, block_size):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
        # crop idx to the last block_size tokens
        idx_cond = idx[:, -block_size:]
        # get the predictions
        logits, loss = model(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :]  # becomes (B, C)
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)  # (B, C)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        # append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)

        print(idx_next[0].item())
    return idx


In [5]:
class BartEncoder(nn.Module):
    def __init__(self, config, embed_tokens=None):
        super().__init__()
        self.config = config
        
        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
            
        self.layers = nn.ModuleList(
            [BartEncoderLayer(config) for _ in range(config.encoder_layers)]
        )
        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)
            
        hidden_states = self.layernorm_embedding(inputs_embeds)
        
        # Create attention mask based on inputs_embeds if none provided
        if attention_mask is None:
            attention_mask = torch.ones(
                inputs_embeds.size()[:2],  # [batch_size, sequence_length]
                device=inputs_embeds.device
            )
            
        # Extend attention mask for multi-head attention
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        
        # Apply encoder layers
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
            )
            
        hidden_states = self.layer_norm(hidden_states)
        
        return (hidden_states,)

class BartDecoder(nn.Module):
    def __init__(self, config, embed_tokens=None):
        super().__init__()
        self.config = config
        
        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
            
        self.layers = nn.ModuleList(
            [BartDecoderLayer(config) for _ in range(config.decoder_layers)]
        )
        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)
            
        hidden_states = self.layernorm_embedding(inputs_embeds)
        
        # Create attention masks based on inputs_embeds if none provided
        if attention_mask is None:
            attention_mask = torch.ones(
                inputs_embeds.size()[:2],
                device=inputs_embeds.device
            )
            
        if encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_attention_mask = torch.ones(
                encoder_hidden_states.size()[:2],
                device=hidden_states.device
            )
            
        # Extend attention masks for multi-head attention
        if attention_mask is not None:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
            extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        
        if encoder_attention_mask is not None:
            encoder_extended_attention_mask = encoder_attention_mask.unsqueeze(1).unsqueeze(2)
            encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=torch.float32)
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None
        
        # Apply decoder layers
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
            )
            
        hidden_states = self.layer_norm(hidden_states)
        
        return (hidden_states,)

In [6]:
import torch
import torch.nn as nn
from transformers import BartModel, BartConfig
from transformers import BartPretrainedModel
from transformers.modeling_outputs import Seq2SeqModelOutput

class ProteinBartConfig(BartConfig):
    def __init__(
        self,
        vocab_size=26,  # 20 amino acids + special tokens
        max_position_embeddings=1024,
        d_model=768,
        encoder_layers=6,
        decoder_layers=6,
        encoder_attention_heads=12,
        decoder_attention_heads=12,
        **kwargs
    ):
        super().__init__(
            vocab_size=vocab_size,
            max_position_embeddings=max_position_embeddings,
            d_model=d_model,
            encoder_layers=encoder_layers,
            decoder_layers=decoder_layers,
            encoder_attention_heads=encoder_attention_heads,
            decoder_attention_heads=decoder_attention_heads,
            **kwargs
        )

class ProteinBartModel(BartPretrainedModel):
    def __init__(self, config: ProteinBartConfig):
        super().__init__(config)
        
        # Standard BART architecture
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        
        # Protein-specific embeddings (optional)
        self.aa_type_embeddings = nn.Embedding(config.vocab_size, config.d_model)
        
        # Position encodings
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.d_model
        )
        
        # BART encoder and decoder
        self.encoder = BartEncoder(config, self.shared)
        self.decoder = BartDecoder(config, self.shared)
        
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        **kwargs
    ):
        # Add position IDs
        position_ids = torch.arange(
            input_ids.size(1), dtype=torch.long, device=input_ids.device
        ).unsqueeze(0).expand_as(input_ids)
        
        # Get embeddings
        inputs_embeds = self.shared(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        
        # Combine embeddings
        hidden_states = inputs_embeds + position_embeds
        
        # Optional: Add amino acid type embeddings
        aa_type_embeds = self.aa_type_embeddings(input_ids)
        hidden_states = hidden_states + aa_type_embeds
        
        # Encoder
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=None,
                attention_mask=attention_mask,
                inputs_embeds=hidden_states,
                **kwargs
            )
        
        # Decoder
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            **kwargs
        )
        
        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs[0],
            encoder_last_hidden_state=encoder_outputs[0],
        )

# Tasks-specific heads
class ProteinBartForSequenceClassification(BartPretrainedModel):
    def __init__(self, config: ProteinBartConfig, num_labels: int):
        super().__init__(config)
        self.bart = ProteinBartModel(config)
        self.classification_head = nn.Linear(config.d_model, num_labels)
        
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        outputs = self.bart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        
        # Use encoder's output for classification
        sequence_output = outputs.encoder_last_hidden_state
        
        # Pool the output (mean pooling)
        pooled_output = torch.mean(sequence_output, dim=1)
        
        # Classify
        logits = self.classification_head(pooled_output)
        
        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.encoder_last_hidden_state
        )

# Example usage
def train_protein_bart():
    config = ProteinBartConfig(
        vocab_size=26,  # Adjust based on your amino acid vocabulary
        max_position_embeddings=1024,
        d_model=768,
        encoder_layers=6,
        decoder_layers=6
    )
    
    model = ProteinBartForSequenceClassification(config, num_labels=2)
    
    # Training data preparation example
    batch_size = 32
    seq_length = 512
    fake_proteins = torch.randint(0, 26, (batch_size, seq_length))
    labels = torch.randint(0, 2, (batch_size,))
    
    # Forward pass
    outputs = model(input_ids=fake_proteins, labels=labels)
    loss = outputs.loss
    
    # Backward pass
    loss.backward()

train_protein_bart()

AssertionError: For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D but found 4-D tensor instead