In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer, AdamW, get_scheduler
from datasets import load_dataset
import tree_sitter
from tree_sitter import Language, Parser
import numpy as np
from tqdm import tqdm
from tree_sitter_java import language

# Step 1: Set up tree-sitter for AST parsing. The parser is built from the Java
# grammar (tree_sitter_java) because the code-refinement corpus is Java source.
JAVA = Language(language())

parser = Parser(JAVA)

import os

# Step 2: Load the CodeXGLUE code-refinement dataset ("medium" subset).
# Only a slice of each split is loaded to keep experiments fast; widen the
# split expressions below for a full run.
DATASET_NAME = "google/code_x_glue_cc_code_refinement"
DATASET_CONFIG = "medium"
TRAIN_SPLIT = "train[:40%]"
VAL_SPLIT = "validation[:2%]"

dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
val_dataset = load_dataset(DATASET_NAME, DATASET_CONFIG, split=VAL_SPLIT)

In [None]:
# Step 3: Load the UniXcoder encoder and the two tokenizers.
# `tokenizer` encodes buggy inputs for UniXcoder; `gpt_tokenizer` encodes the
# fixed-code targets for the GPT-2 decoder.
model_name = "microsoft/unixcoder-base"  # Hugging Face Hub id of the encoder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 defines no pad token; reuse EOS so padding="max_length" can pad targets.
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

class CodeRefinementDataset(torch.utils.data.Dataset):
    """Buggy/fixed code pairs for training a UniXcoder encoder + GPT-2 decoder.

    Each item holds the buggy code tokenized for the encoder, the fixed code
    tokenized for the decoder (used both as decoder input and as labels), and a
    small float vector of AST-derived features of the buggy code.

    Args:
        data: Indexable collection (e.g. a Hugging Face ``Dataset``) whose items
            are mappings containing ``source_field`` and ``target_field``.
        unix_tokenizer: Tokenizer for the encoder input.
        gpt_tokenizer: Tokenizer for the decoder targets. It must define
            ``pad_token`` and ``eos_token``.
        max_length: Fixed padded/truncated length for both input and target.
        source_field: Record key holding the buggy code (``"buggy"`` in
            CodeXGLUE code refinement).
        target_field: Record key holding the fixed code (``"fixed"`` in
            CodeXGLUE code refinement).
    """

    def __init__(self, data, unix_tokenizer, gpt_tokenizer, max_length=512,
                 source_field="buggy", target_field="fixed"):
        self.data = data
        self.unix_tokenizer = unix_tokenizer
        self.gpt_tokenizer = gpt_tokenizer
        self.max_length = max_length
        self.source_field = source_field
        self.target_field = target_field

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tensors for one example.

        Returns:
            dict with 1-D tensors of length ``max_length`` for ``input_ids``,
            ``attention_mask``, ``decoder_input_ids`` and ``labels``, plus a
            float tensor ``ast_features`` built from ``extract_ast_features``.
        """
        item = self.data[idx]
        buggy_code = item[self.source_field]
        fixed_code = item[self.target_field]

        input_enc = self.unix_tokenizer(
            buggy_code, truncation=True, padding="max_length",
            max_length=self.max_length, return_tensors="pt",
        )
        # Append EOS so the decoder is trained to emit a stop token; the GPT-2
        # tokenizer does not add one by itself.
        target_enc = self.gpt_tokenizer(
            fixed_code + self.gpt_tokenizer.eos_token, truncation=True,
            padding="max_length", max_length=self.max_length, return_tensors="pt",
        )

        decoder_input_ids = target_enc['input_ids'].squeeze(0)
        # Padding must not contribute to the loss. GPT-2 pads with EOS, so pad
        # positions are identified by the attention mask (not by token id) to
        # keep the genuine trailing EOS as a target. -100 is the ignore index
        # of the Hugging Face LM loss, which shifts labels internally, so
        # labels can mirror decoder_input_ids.
        labels = decoder_input_ids.masked_fill(target_enc['attention_mask'].squeeze(0) == 0, -100)

        # A float tensor collates to (batch, ast_feature_dim) and feeds nn.Linear directly.
        ast_features = torch.as_tensor(self.extract_ast_features(buggy_code), dtype=torch.float)

        return {
            'input_ids': input_enc['input_ids'].squeeze(0),
            'attention_mask': input_enc['attention_mask'].squeeze(0),
            'decoder_input_ids': decoder_input_ids,
            'labels': labels,
            'ast_features': ast_features,
        }

    def extract_ast_features(self, code):
        """Compute placeholder AST features for ``code`` with the module-level tree-sitter ``parser``.

        Returns:
            ``[node_count, line_count]``: the number of nodes in the parse tree
            and the number of source lines. Expand this for richer features
            (node types, paths, depth).
        """
        tree = parser.parse(bytes(code, "utf8"))

        # Iterative traversal: avoids Python's recursion limit on deep trees.
        node_count = 0
        pending = [tree.root_node]
        while pending:
            node = pending.pop()
            node_count += 1
            pending.extend(node.children)

        return [node_count, len(code.splitlines())]

# Step 5: DataLoader setup.
# Both datasets need the GPT-2 tokenizer to encode the decoder targets.
train_dataset = CodeRefinementDataset(dataset, tokenizer, gpt_tokenizer)
# Keep the raw `val_dataset` split intact under its own name so re-running this
# cell never wraps an already-wrapped dataset.
val_refinement_dataset = CodeRefinementDataset(val_dataset, tokenizer, gpt_tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_refinement_dataset, batch_size=8)

ValueError: Unrecognized configuration class <class 'transformers.models.roberta.configuration_roberta.RobertaConfig'> for this kind of AutoModel: AutoModelForSeq2SeqLM.
Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, GPTSanJapaneseConfig, LEDConfig, LongT5Config, M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, NllbMoeConfig, PegasusConfig, PegasusXConfig, PLBartConfig, ProphetNetConfig, Qwen2AudioConfig, SeamlessM4TConfig, SeamlessM4Tv2Config, SwitchTransformersConfig, T5Config, UMT5Config, XLMProphetNetConfig.

In [None]:
from transformers import GPT2LMHeadModel

class UniXcoderWithASTEncoder(nn.Module):
    """Encoder whose token states are shifted by a learned projection of AST features.

    Args:
        base_model: Hugging Face encoder (UniXcoder here) exposing
            ``config.hidden_size`` and returning ``last_hidden_state``.
        ast_feature_dim: Length of the AST feature vector.
    """

    def __init__(self, base_model, ast_feature_dim=2):
        super().__init__()
        self.base_model = base_model
        self.ast_projection = nn.Linear(ast_feature_dim, base_model.config.hidden_size)

    def forward(self, input_ids, attention_mask, ast_features=None):
        """Encode tokens and optionally add projected AST features.

        Args:
            input_ids: Encoder token ids, (batch, seq_len).
            attention_mask: Encoder attention mask; returned unchanged.
            ast_features: Optional features, either per example
                (batch, ast_feature_dim) or per token
                (batch, seq_len, ast_feature_dim). Any numeric dtype.

        Returns:
            ``(hidden, attention_mask)``, where ``hidden`` is
            (batch, seq_len, hidden_size).
        """
        # Only the last layer is used, so intermediate hidden states are not requested.
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        hidden = outputs.last_hidden_state
        if ast_features is not None:
            # nn.Linear needs a floating input matching its weight dtype.
            ast_embeds = self.ast_projection(ast_features.to(self.ast_projection.weight.dtype))
            if ast_embeds.dim() == 2:
                # Per-example features: add the same embedding to every token.
                ast_embeds = ast_embeds.unsqueeze(1)
            hidden = hidden + ast_embeds
        return hidden, attention_mask


class UniXcoderSeq2Seq(nn.Module):
    """Seq2seq model: AST-aware UniXcoder encoder feeding a GPT-2 decoder via cross-attention.

    Args:
        encoder: Module returning ``(hidden_states, attention_mask)``; it must
            expose ``base_model.config.hidden_size``.
        decoder_name: Hugging Face id of the GPT-2 checkpoint for the decoder.
    """

    def __init__(self, encoder, decoder_name="gpt2"):
        super().__init__()
        self.encoder = encoder
        # Stock GPT-2 checkpoints are built without cross-attention layers, so
        # encoder_hidden_states cannot be attended to unless
        # add_cross_attention=True. The added cross-attention weights start
        # randomly initialised and are learned during fine-tuning.
        self.decoder = GPT2LMHeadModel.from_pretrained(decoder_name, add_cross_attention=True)
        self.decoder.config.pad_token_id = self.decoder.config.eos_token_id
        # Maps encoder width to decoder width so cross-attention shapes agree.
        self.enc_to_dec = nn.Linear(
            encoder.base_model.config.hidden_size,
            self.decoder.config.hidden_size
        )

    def forward(self, input_ids, attention_mask, decoder_input_ids, ast_features=None, labels=None):
        """Run encoder and decoder; returns the GPT-2 LM output (includes ``loss`` when ``labels`` is given)."""
        encoder_hidden_states, encoder_attention_mask = self.encoder(input_ids, attention_mask, ast_features)
        encoder_hidden_states = self.enc_to_dec(encoder_hidden_states)

        return self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            labels=labels
        )
def generate_with_beam_search(
    model, input_ids, attention_mask, ast_features, tokenizer,
    max_length=64, num_beams=4
):
    """Beam-search decode fixed code conditioned on encoded buggy code.

    Args:
        model: ``UniXcoderSeq2Seq``-style model exposing ``encoder``,
            ``enc_to_dec`` and a GPT-2 ``decoder`` with ``generate``. Put it in
            ``eval()`` mode first; this function does not change its mode.
        input_ids: Encoder input ids.
        attention_mask: Encoder attention mask.
        ast_features: Optional AST features forwarded to the encoder.
        tokenizer: The decoder's (GPT-2) tokenizer. Its bos/pad ids seed and pad
            generation, and it decodes the output.
        max_length: Maximum generated sequence length.
        num_beams: Beam width.

    Returns:
        list[str]: One decoded string per input example.
    """
    # Inference only: skip building an autograd graph for the encoder pass.
    with torch.no_grad():
        encoder_hidden_states, encoder_attention_mask = model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            ast_features=ast_features
        )
        encoder_hidden_states = model.enc_to_dec(encoder_hidden_states)

        # GPT-2 is decoder-only: with input_ids=None, generation starts from
        # bos_token_id (decoder_start_token_id is only used by encoder-decoder
        # models). generate() expands the encoder tensors per beam.
        # NOTE(review): whether encoder_hidden_states reach every decoding step
        # depends on GPT2LMHeadModel.prepare_inputs_for_generation in the
        # installed transformers version - confirm outputs vary with the input.
        generated_ids = model.decoder.generate(
            input_ids=None,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            bos_token_id=tokenizer.bos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )

    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED = 42
NUM_EPOCHS = 3
LEARNING_RATE = 5e-5
TRAIN_BATCH_SIZE = 4

# Seed before building the model: the cross-attention, enc_to_dec and
# ast_projection weights are randomly initialised.
torch.manual_seed(SEED)

# Instantiate model.
# NOTE: this rebinds `model` (the bare UniXcoder encoder from the loading cell)
# to the full seq2seq model. Re-run the loading cell before re-running this one.
encoder = UniXcoderWithASTEncoder(model)
model = UniXcoderSeq2Seq(encoder).to(device)

# Prepare data
train_dataset = CodeRefinementDataset(dataset, tokenizer, gpt_tokenizer)
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)

# Optimizer: transformers.AdamW is deprecated in favour of torch.optim.AdamW.
# weight_decay=0.0 keeps the transformers default (torch defaults to 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0)

# Training loop
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"epoch {epoch + 1}"):
        ast_features = batch['ast_features']
        if isinstance(ast_features, (list, tuple)):
            # default_collate turns per-item Python lists into a list of
            # per-position tensors; reassemble them into (batch, ast_feature_dim).
            ast_features = torch.stack(list(ast_features), dim=-1)

        outputs = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            decoder_input_ids=batch['decoder_input_ids'].to(device),
            labels=batch['labels'].to(device),
            ast_features=ast_features.to(device, dtype=torch.float)
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    # max(..., 1) guards against an empty loader.
    print(f"Epoch {epoch + 1} - Loss: {total_loss / max(len(train_loader), 1)}")