In [1]:
import torch
import torch.nn as nn
from datasets import load_dataset
import random
import gc
from concurrent.futures import ProcessPoolExecutor
from torch.utils.data import DataLoader,Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers import AutoModel,AutoTokenizer,T5ForConditionalGeneration,T5Tokenizer,RobertaTokenizer

2025-05-27 16:14:21.791529: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748362462.055286      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748362462.132087      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def ruleBasedImputer(code):
    replacements = {
        '>': '<',
        ':': '',
        'def': '\n',
        '<': '>',
        '=': '!',
        '\t':' ',
        '+':'-',
        '*':'/',
        'and':'or',
        '&':'|',
        'in':'is'
    }
    for old, new in replacements.items():
        code = code.replace(old, new)
    return code


In [4]:
class LoadCodeAndNL(Dataset):
    def __init__(self,code_files,max_code_len=512,max_nl_len=512):
        self.code_files = code_files
        self.max_len = max_code_len
        self.max_nl_len = max_nl_len
        self.codebert_tokenizer =AutoTokenizer.from_pretrained(f'microsoft/codebert-base')
        self.NLtokenizer = T5Tokenizer.from_pretrained("t5-small")
        self.codeT5tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")

    def __len__(self):
        return len(self.code_files)

    def __getitem__(self,id):
        idx = id
        data = self.code_files[idx]
        
        code = data['func_code_string']
        doc = data['func_documentation_string']

        og_tok = self.codeT5tokenizer(code,return_tensors = 'pt',truncation=True,padding=True,max_length=self.max_len)
        og_tok_ids = og_tok.input_ids
        og_tok_mask = og_tok.attention_mask
        
        nl_tok = self.NLtokenizer(doc,return_tensors = 'pt',truncation=True,padding=True,max_length=self.max_nl_len)
        nl_tok_ids = nl_tok.input_ids
        nl_tok_mask = nl_tok.attention_mask
        
        if random.randint(1,100)%2==1 :
            code = ruleBasedImputer(code)
           # print('Imputed Code')
        codeTokens = self.codebert_tokenizer(code,return_tensors='pt', padding=True,truncation=True,max_length=self.max_len)
        
        return  {
            "code_input_ids": codeTokens['input_ids'].squeeze(0),
            "code_attention_mask": codeTokens['attention_mask'].squeeze(0),
            "doc_input_ids": nl_tok_ids.squeeze(0),
            "doc_attention_mask": nl_tok_mask.squeeze(0),
            "org_input_ids":og_tok_ids.squeeze(0),
            "org_tok_mask":og_tok_mask.squeeze(0),
        }
        

In [5]:
class EncoderBlock(nn.Module):
    def __init__(self,pretrained_model = 'microsoft/codebert-base'):
        super(EncoderBlock,self).__init__()
        self.encoder = AutoModel.from_pretrained(pretrained_model)

    def forward(self,input_ids,attention_mask):
        output = self.encoder(input_ids=input_ids,attention_mask=attention_mask)
        return output.last_hidden_state

In [6]:
class CrossAttentinBlock(nn.Module):
    def __init__(self,embd_size=512,h=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embd_size,h,batch_first=True)
        self.norm = nn.LayerNorm(embd_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self,query,key,value,key_padding_mask=None):
        self_att,_ = self.attn(query,key,value,key_padding_mask=key_padding_mask)
        out = self.dropout(self_att)
        out = self.norm(query + out)
        return out

In [7]:
class NlDecoder(nn.Module):
    def __init__(self,pretrained_model = 't5-small'):
        super().__init__()
        
        model = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        self.vocab_size = model.config.vocab_size
        self.decoder = model.decoder
        self.proj = nn.Linear(768,512)
        self.head = nn.Linear(512,self.vocab_size ,bias=False)
        
    def forward(self,encoder_inputs,de_inp_ids,de_attn_mask):
        #print("Running  decoder1Forward")
        d_inps = de_inp_ids
        att_mask = de_attn_mask
        
        x = self.proj(encoder_inputs)
        d_outputs = self.decoder(input_ids = d_inps,
                                 attention_mask = att_mask,    
                                 encoder_hidden_states = x,
                                 encoder_attention_mask=torch.ones(x.shape[:2], dtype=torch.long),
                                 return_dict=True
                                )

        hidden_states = d_outputs.last_hidden_state
        logits = self.head(hidden_states)

        return logits,hidden_states

In [8]:
class BugDecoder(nn.Module):
    def __init__(self,pretrained_model='Salesforce/codet5-small'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        model = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        self.vocab_size = model.config.vocab_size
        self.decoder = model.decoder
        self.proj = nn.Linear(768,512)
        self.cross_attn = CrossAttentinBlock(512,8)
        self.head = nn.Linear(512,self.vocab_size ,bias=False)
        
    def forward(self,encoder_inputs,decoderNLinputs,de_inp_ids,de_att_mask):
       # print("Running  decoder2Forward")
        
        d_inps = de_inp_ids
        att_mask = de_att_mask

        x = self.proj(encoder_inputs)
        x = self.cross_attn(x,decoderNLinputs,decoderNLinputs)
        d_outputs = self.decoder(input_ids = d_inps,
                                 attention_mask = att_mask,    
                                 encoder_hidden_states = x,
                                 encoder_attention_mask=torch.ones(x.shape[:-1], dtype=torch.long),
                                 return_dict=True
                                )
        hidden_states = d_outputs.last_hidden_state
        logits = self.head(hidden_states)
        return logits
        

In [9]:
def collate_fn(batch):
    collated = {}
    for k in batch[0]:
        if isinstance(batch[0][k], torch.Tensor):
            if batch[0][k].dim() == 1:
                # Pad 1D sequences (token sequences)
                collated[k] = pad_sequence([item[k] for item in batch], batch_first=True, padding_value=0)
            else:
                collated[k] = torch.stack([item[k] for item in batch])
        else:
            collated[k] = [item[k] for item in batch]
    return collated

In [10]:
class Training(nn.Module):
    def __init__(self, data,val_data, batch=32): 
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dataset = LoadCodeAndNL(data)
        self.val_loader = LoadCodeAndNL(val_data)
        self.dataloader = DataLoader(self.dataset, batch_size=batch, shuffle=True, collate_fn=collate_fn)
        self.best_acc = 0
        self.encoder = EncoderBlock().to(self.device)
        self.decoder1 = NlDecoder().to(self.device)
        self.decoder2 = BugDecoder().to(self.device)
        
        self.lossFunction = nn.CrossEntropyLoss(ignore_index=0).to(self.device)

        self.freeze_base_layers()

      
        self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-4)

    
    def freeze_base_layers(self):
        for param in self.encoder.parameters():
            param.requires_grad = False

        for name, param in self.decoder1.named_parameters():
            if 'proj' not in name:
                param.requires_grad = False

        for name, param in self.decoder2.named_parameters():
            if 'proj' not in name and 'cross_attn' not in name:
                param.requires_grad = False

    def test(self):
        self.encoder.eval()
        self.decoder1.eval()
        self.decoder2.eval()
    
        nl_correct = 0
        code_correct = 0
        total_nl = 0
        total_code = 0
    
        with torch.no_grad():
            for batch in self.test_dataloader:
                input_ids = batch['code_input_ids'].to(self.device)
                attention_mask = batch['code_attention_mask'].to(self.device)
                doc_ids = batch["doc_input_ids"].to(self.device)
                doc_att = batch["doc_attention_mask"].to(self.device)
                og_ids = batch["org_input_ids"].to(self.device)
                og_att = batch["org_tok_mask"].to(self.device)
    
                e_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
                d_output1, hidden_states = self.decoder1(
                    encoder_inputs=e_output,
                    de_inp_ids=doc_ids,
                    de_attn_mask=doc_att
                )
                d_output2 = self.decoder2(
                    encoder_inputs=e_output,
                    decoderNLinputs=hidden_states.detach(),
                    de_inp_ids=og_ids,
                    de_att_mask=og_att
                )
    
                pred_nl = torch.argmax(d_output1, dim=-1)
                nl_mask = doc_ids != 0
                nl_correct += (pred_nl == doc_ids).masked_select(nl_mask).sum().item()
                total_nl += nl_mask.sum().item()
    
                pred_code = torch.argmax(d_output2, dim=-1)
                code_mask = og_ids != 0
                code_correct += (pred_code == og_ids).masked_select(code_mask).sum().item()
                total_code += code_mask.sum().item()
    
        nl_acc = 100 * nl_correct / total_nl
        code_acc = 100 * code_correct / total_code
        print(f" Test | NL Accuracy: {nl_acc:.2f}% | Code Accuracy: {code_acc:.2f}%")

    
    def validate(self):
        self.encoder.eval()
        self.decoder1.eval()
        self.decoder2.eval()
    
        nl_correct = 0
        code_correct = 0
        total_nl = 0
        total_code = 0
    
        with torch.no_grad():
            for batch in self.val_loader:
                input_ids = batch['code_input_ids'].to(self.device)
                attention_mask = batch['code_attention_mask'].to(self.device)
                doc_ids = batch["doc_input_ids"].to(self.device)
                doc_att = batch["doc_attention_mask"].to(self.device)
                og_ids = batch["org_input_ids"].to(self.device)
                og_att = batch["org_tok_mask"].to(self.device)
    
                e_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
                d_output1, hidden_states = self.decoder1(
                    encoder_inputs=e_output,
                    de_inp_ids=doc_ids,
                    de_attn_mask=doc_att
                )
    
                d_output2 = self.decoder2(
                    encoder_inputs=e_output,
                    decoderNLinputs=hidden_states.detach(),
                    de_inp_ids=og_ids,
                    de_att_mask=og_att
                )
    
                pred_nl = torch.argmax(d_output1, dim=-1)
                nl_mask = doc_ids != 0  
                nl_correct += (pred_nl == doc_ids).masked_select(nl_mask).sum().item()
                total_nl += nl_mask.sum().item()
    
                pred_code = torch.argmax(d_output2, dim=-1)
                code_mask = og_ids != 0
                code_correct += (pred_code == og_ids).masked_select(code_mask).sum().item()
                total_code += code_mask.sum().item()
    
        nl_acc = 100 * nl_correct / total_nl
        code_acc = 100 * code_correct / total_code
        print(f"Validation | NL Accuracy: {nl_acc:.2f}% | Code Accuracy: {code_acc:.2f}%")

    def train(self, epochs=3):
        for epoch in range(epochs):
            for batch in tqdm(self.dataloader, desc=f"Epoch {epoch+1}"):
                input_ids = batch['code_input_ids'].to(self.device)
                attention_mask = batch['code_attention_mask'].to(self.device)
                doc_ids = batch["doc_input_ids"].to(self.device)
                doc_att = batch["doc_attention_mask"].to(self.device)
                og_ids = batch["org_input_ids"].to(self.device)
                og_att = batch["org_tok_mask"].to(self.device)

             
                with torch.no_grad():
                    e_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

                d_output1, hidden_states = self.decoder1(
                    encoder_inputs=e_output,
                    de_inp_ids=doc_ids,
                    de_attn_mask=doc_att
                )

        
                d_output2 = self.decoder2(
                    encoder_inputs=e_output,
                    decoderNLinputs=hidden_states.detach(),  
                    de_inp_ids=og_ids,
                    de_att_mask=og_att
                )
                e_output = e_output
                doc_ids = doc_ids
                #doc_attn_mask = doc_attn_mask


                nl_loss = self.lossFunction(d_output1.view(-1, d_output1.size(-1)), doc_ids.view(-1))
                code_loss = self.lossFunction(d_output2.view(-1, d_output2.size(-1)), og_ids.view(-1))

                loss = nl_loss + code_loss

       
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                print(f"Epoch {epoch + 1} | NL Loss: {nl_loss.item():.4f} | Code Loss: {code_loss.item():.4f}")

                torch.cuda.empty_cache()
                gc.collect()
            nl_acc,code_acc = self.validate()
            if nl_acc + code_acc > self.best_acc:
                self.best_acc = nl_acc + code_acc
                torch.save({
                    'encoder': self.encoder.state_dict(),
                    'decoder1': self.decoder1.state_dict(),
                    'decoder2': self.decoder2.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                }, 'best_model.pt')


In [10]:
Data =  load_dataset("code_search_net", "python")
train = Training(Data['train'],Data["validation"])
train.train(epochs=3)

README.md:   0%|          | 0.00/12.9k [00:00<?, ?B/s]

code_search_net.py:   0%|          | 0.00/8.44k [00:00<?, ?B/s]

The repository for code_search_net contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/code_search_net.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


python.zip:   0%|          | 0.00/941M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/412178 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/22176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/23107 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/498 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


tokenizer_config.json:   0%|          | 0.00/1.48k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/703k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/294k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/242M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]


Epoch 1:   0%|          | 0/12881 [00:00<?, ?it/s][APassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.

Epoch 1:   0%|          | 1/12881 [02:25<519:45:30, 145.27s/it][A

Epoch 1 | NL Loss: 10.4803 | Code Loss: 17.0875



Epoch 1:   0%|          | 2/12881 [04:37<492:36:47, 137.70s/it][A

Epoch 1 | NL Loss: 10.4434 | Code Loss: 16.0720



Epoch 1:   0%|          | 3/12881 [06:41<469:31:44, 131.26s/it][A

Epoch 1 | NL Loss: 10.4159 | Code Loss: 15.1116



Epoch 1:   0%|          | 4/12881 [08:51<467:24:02, 130.67s/it][A

Epoch 1 | NL Loss: 10.3861 | Code Loss: 14.4257



Epoch 1:   0%|          | 5/12881 [10:59<464:31:12, 129.88s/it][A

Epoch 1 | NL Loss: 10.3715 | Code Loss: 13.7674


Epoch 1:   0%|          | 5/12881 [11:33<496:04:17, 138.70s/it]


Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_35/1537798478.py", line 3, in <cell line: 0>
    train.train(epochs=3)
  File "/tmp/ipykernel_35/3629609618.py", line 141, in train
    d_output1, hidden_states = self.decoder1(
                               ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_35/2563809120.py", line 17, in forward
    d_outputs = self.decoder(input_ids = d_inps,
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/loc

TypeError: object of type 'NoneType' has no len()