In [2]:
import sys
sys.path.append('src')

In [3]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast as autocast
from transformers import AutoTokenizer, logging

from codeblip import CodeBlip
from codeblip_qformer import CodeQformer
# from modelling_t5 import T5Config, T5ForConditionalGeneration
from transformers import AutoTokenizer, T5ForConditionalGeneration

logging.set_verbosity_error()  # Only show errors, not warnings

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU plus all CUDA devices), then switches cuDNN to deterministic mode
    with autotuning disabled.

    Args:
        seed (int): Seed value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels on, benchmark-driven algorithm selection off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def update_loss_dict(loss_dict, new_loss_dict):
    """Accumulate scalar losses into a running totals dictionary.

    Args:
        loss_dict (dict): Running totals, updated in place. Keys missing
            from it start at 0.
        new_loss_dict (dict): Maps loss names to objects exposing ``.item()``
            (e.g. 0-dim tensors).

    Returns:
        dict: The same ``loss_dict`` object, after the update.
    """
    for name, value in new_loss_dict.items():
        loss_dict[name] = loss_dict.get(name, 0) + value.item()
    return loss_dict

set_seed(42)

# Stage 1

In [17]:
class CodeQformer(CodeBlip):  # Inherits from Blip2Base
    """
    Stage-1 CodeBlip Q-Former model for the code translation task.

    A frozen code encoder embeds the source snippet, and a fixed set of
    learned query tokens cross-attends to those embeddings through the
    Q-Former. ``forward`` returns the sum of three losses:

    * ``loss_stc`` - source/target contrastive loss,
    * ``loss_stm`` - source/target matching loss over sampled negatives,
    * ``loss_lm``  - language-modelling loss on the target code, conditioned
      on the cached query states.

    In the in-line remarks, "image" refers to the source code and "text" to
    the target code.
    """

    def __init__(
        self,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        max_source_len=128,
        max_target_len=128,
    ):
        """
        Args:
            num_query_token (int): Number of learned query tokens requested
                from ``init_Qformer``.
            cross_attention_freq (int): Forwarded to ``init_Qformer``.
            embed_dim (int): Output width of the source/target projection
                heads; also forwarded to ``init_Qformer`` as its second
                argument.
            max_source_len (int): Pad/truncate length for source tokens.
            max_target_len (int): Pad/truncate length for target tokens.
        """
        super().__init__()

        self.code_encoder, self.ln_code = self.init_code_encoder()
        # Freeze the encoder: no gradients, and eval mode at construction.
        for param in self.code_encoder.parameters():
            param.requires_grad = False
        self.code_encoder.eval()
        # NOTE(review): a later model.train() call will presumably flip the
        # encoder back to train mode unless CodeBlip prevents it - confirm.

        # Initialize the Qformer and Query Tokens
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, embed_dim, cross_attention_freq
        )
        # Tokenizer for encoding source and target code
        self.tokenizer = self.init_tokenizer()
        self.Qformer.resize_token_embeddings(len(self.tokenizer))

        # Initialise every "*_query" parameter from the parameter that has
        # the same name without the "_query" suffix.
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        # Projection layers for source and target code
        self.source_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.target_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)

        # Temperature parameter for contrastive loss
        self.temp = nn.Parameter(0.07 * torch.ones([]))

        # Max lengths for source and target tokens
        self.max_source_len = max_source_len
        self.max_target_len = max_target_len

        # Binary match / no-match classifier applied to the query outputs.
        self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2)

    def forward(self, samples):
        """
        Compute the three stage-1 losses for a batch of code pairs.

        Args:
            samples (dict): Must contain ``"source_code"`` and
                ``"target_code"``; both are passed straight to the tokenizer
                (lists of strings in this notebook).

        Returns:
            dict: ``"loss"`` (sum of the three terms), ``"loss_stc"``,
            ``"loss_stm"`` and ``"loss_lm"``.
        """
        device = self.Qformer.device
        source_code = samples["source_code"]  # image
        target_code = samples["target_code"]  # text

        # ------------------------------------------------------------------
        # Encode the source code and run the queries over it
        # ------------------------------------------------------------------
        source_tokens = self.tokenizer(
            source_code, padding="max_length", truncation=True,
            max_length=self.max_source_len, return_tensors="pt"
        ).to(device)

        # code embedding
        source_output = self.ln_code(
            self.code_encoder(**source_tokens, return_dict=True).last_hidden_state
        )

        # expand query tokens to batch size
        query_tokens = self.query_tokens.expand(source_output.shape[0], -1, -1)

        # use_cache=True keeps past_key_values for the LM loss further down.
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=source_output,
            encoder_attention_mask=source_tokens.attention_mask,
            use_cache=True,
            return_dict=True,
        )

        # One normalised feature per query token: (bs, num_query, embed_dim).
        source_features = F.normalize(
            self.source_proj(query_output.last_hidden_state), dim=-1
        )

        # ------------------------------------------------------------------
        # Encode the target code (text-only pass through the Q-Former)
        # ------------------------------------------------------------------
        target_tokens = self.tokenizer(
            target_code, padding="max_length", truncation=True,
            max_length=self.max_target_len, return_tensors="pt"
        ).to(device)
        target_output = self.Qformer.bert(
            target_tokens.input_ids,
            attention_mask=target_tokens.attention_mask,
            return_dict=True,
        )

        # First-position hidden state only: (bs, embed_dim).
        target_features = F.normalize(
            self.target_proj(target_output.last_hidden_state[:, 0, :]), dim=-1
        )

        # ------------------------------------------------------------------
        # Loss 1: source-target contrastive
        # ------------------------------------------------------------------
        # Without distributed training the "all" tensors are the local batch.
        source_features_all = source_features
        target_features_all = target_features

        # (bs, 1, nq, D) @ (bs, D, 1) -> (bs, bs, nq, 1). Squeeze ONLY the
        # trailing singleton: a bare squeeze() would also drop the batch axes
        # when bs == 1 and break the cross-entropy below.
        sim_q2t = torch.matmul(
            source_features.unsqueeze(1), target_features_all.unsqueeze(-1)
        ).squeeze(-1)

        # source-target similarity ; aggregate the max across all query tokens
        sim_s2t, _ = sim_q2t.max(-1)
        sim_s2t /= self.temp

        # (bs, 1, 1, D) @ (bs, D, nq) -> (bs, bs, 1, nq); drop only axis -2.
        sim_t2q = torch.matmul(
            target_features.unsqueeze(1).unsqueeze(1), source_features_all.permute(0, 2, 1)
        ).squeeze(-2)
        # target-source similarity ; aggregate the max across all query tokens
        sim_t2s, _ = sim_t2q.max(-1)
        sim_t2s /= self.temp

        # rank would be dist.get_rank() under distributed training; fixed to
        # 0 because this notebook runs single-process.
        rank = 0
        bs = source_features.size(0)
        targets = torch.arange(rank * bs, rank * bs + bs, device=sim_s2t.device)

        loss_stc = (
            F.cross_entropy(sim_s2t, targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2s, targets, label_smoothing=0.1)
        ) / 2

        # ------------------------------------------------------------------
        # Loss 2: source-target matching
        # ------------------------------------------------------------------
        # If distributed, these three would be all_gather'ed across ranks.
        target_input_ids_world = target_tokens.input_ids
        target_attention_mask_world = target_tokens.attention_mask
        source_features_world = source_features_all
        with torch.no_grad():
            # Exclude the true pair from negative sampling.
            sim_s2t[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)
            sim_t2s[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)

            weights_t2s = F.softmax(sim_t2s, dim=1)
            weights_s2t = F.softmax(sim_s2t, dim=1)

        # select a negative source for each target
        source_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2s[b], 1).item()
            source_embeds_neg.append(source_features_world[neg_idx])
        source_embeds_neg = torch.stack(source_embeds_neg, dim=0)

        # select a negative target for each source
        target_ids_neg = []
        target_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_s2t[b], 1).item()
            target_ids_neg.append(target_input_ids_world[neg_idx])
            target_atts_neg.append(target_attention_mask_world[neg_idx])

        target_ids_neg = torch.stack(target_ids_neg, dim=0)
        target_atts_neg = torch.stack(target_atts_neg, dim=0)

        target_ids_all = torch.cat(
            [target_tokens.input_ids, target_tokens.input_ids, target_ids_neg], dim=0
        )  # pos, pos, neg
        target_atts_all = torch.cat(
            [target_tokens.attention_mask, target_tokens.attention_mask, target_atts_neg],
            dim=0,
        )

        query_tokens_itm = self.query_tokens.expand(target_ids_all.shape[0], -1, -1)
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(device)
        attention_mask_all = torch.cat([query_atts_itm, target_atts_all], dim=1)

        # NOTE(review): the matching pass cross-attends to the projected,
        # normalised query features (width embed_dim), whereas the first pass
        # cross-attends to the encoder output. This presumably only lines up
        # when embed_dim equals the encoder width - confirm this is intended.
        source_embeds_all = torch.cat(
            [source_features, source_embeds_neg, source_features], dim=0
        )  # pos, neg, pos
        source_atts_all = torch.ones(source_embeds_all.size()[:-1], dtype=torch.long).to(device)

        output_itm = self.Qformer.bert(
            target_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=source_embeds_all,
            encoder_attention_mask=source_atts_all,
            return_dict=True,
        )

        # Keep only the query positions, classify each, average the logits.
        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
        vl_output = self.itm_head(vl_embeddings)
        logits = vl_output.mean(dim=1)

        # First bs rows are true pairs (label 1); remaining 2*bs are negatives.
        itm_labels = torch.cat(
            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
            dim=0,
        ).to(device)
        loss_stm = F.cross_entropy(logits, itm_labels)

        # ------------------------------------------------------------------
        # Loss 3: language modelling on the target code
        # ------------------------------------------------------------------
        decoder_input_ids = target_tokens.input_ids.clone()
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions are ignored by the loss (-100).
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )

        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(device)
        attention_mask = torch.cat([query_atts, target_tokens.attention_mask], dim=1)
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        loss_lm = lm_output.loss

        total_loss = loss_stc + loss_stm + loss_lm

        return {
            "loss": total_loss,
            "loss_stc": loss_stc,
            "loss_stm": loss_stm,
            "loss_lm": loss_lm,
        }

    @classmethod
    def from_config(cls, cfg):
        """
        Build the model from a dict-like config and load pretrained weights.

        Recognised keys (with defaults): ``num_query_token`` (32),
        ``cross_attention_freq`` (2), ``embed_dim`` (768),
        ``max_source_len`` (512), ``max_target_len`` (512) and
        ``pretrained_path`` (``models/stage1_out/stage1_best.pt``).

        Returns:
            CodeQformer: Model on CUDA when available, otherwise CPU, with
            the checkpoint's state dict loaded.
        """
        num_query_token = cfg.get("num_query_token", 32)
        cross_attention_freq = cfg.get("cross_attention_freq", 2)
        embed_dim = cfg.get("embed_dim", 768)
        max_source_len = cfg.get("max_source_len", 512)
        max_target_len = cfg.get("max_target_len", 512)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = cls(
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            embed_dim=embed_dim,
            max_source_len=max_source_len,
            max_target_len=max_target_len,
        ).to(device)
        pretrained_path = cfg.get("pretrained_path", os.path.join("models", "stage1_out", "stage1_best.pt"))
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load(pretrained_path, map_location=device))
        return model

In [34]:
def stage1_test():
    """Smoke-test the stage-1 model: build it untrained and run one forward pass.

    Prints the query-token parameter (with its shape) and the total loss for
    two dummy Java/C# pairs. To evaluate trained weights instead, build the
    model with ``CodeQformer.from_config`` and a ``"pretrained_path"`` of
    ``os.path.join("models", "stage1_out", "stage1_best.pt")``.
    """
    model = CodeQformer(
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=768,
        max_source_len=512,
        max_target_len=512,
    )

    queries = model.query_tokens
    print(queries, queries.shape)

    # Dummy parallel snippets: Java sources with their C# counterparts.
    java_snippets = [
        "public class HelloWorld { public static void main(String[] args) { System.out.println(\"Hello, world!\"); } }",
        "public class Test { public static int add(int a, int b) { return a + b; } }",
    ]
    csharp_snippets = [
        "class HelloWorld { static void Main(string[] args) { Console.WriteLine(\"Hello, world!\"); } }",
        "class Test { static int Add(int a, int b) { return a + b; } }",
    ]

    # Single forward pass over the two-pair batch.
    batch = {"source_code": java_snippets, "target_code": csharp_snippets}
    result = model(batch)
    total = result['loss'].item()
    print(f"Loss from forward pass: {total}")

stage1_test()

Parameter containing:
tensor([[[ 0.0137,  0.0076, -0.0196,  ..., -0.0229, -0.0487,  0.0145],
         [-0.0258, -0.0233, -0.0146,  ...,  0.0117, -0.0187, -0.0115],
         [ 0.0406,  0.0035, -0.0143,  ..., -0.0202,  0.0014, -0.0066],
         ...,
         [-0.0292, -0.0065, -0.0202,  ...,  0.0041, -0.0126, -0.0023],
         [-0.0280, -0.0186, -0.0026,  ...,  0.0067,  0.0158,  0.0110],
         [-0.0106, -0.0122,  0.0075,  ...,  0.0217,  0.0329,  0.0013]]],
       requires_grad=True) torch.Size([1, 32, 768])
Loss from forward pass: 12.496234893798828


In [None]:
# # Trained

# Parameter containing:
# tensor([[[ 0.0353, -0.0337, -0.0014,  ...,  0.0036,  0.0390, -0.0232],
#          [-0.0112,  0.0043,  0.0010,  ..., -0.0294, -0.0440, -0.0227],
#          [-0.0105, -0.0109, -0.0025,  ...,  0.0016, -0.0248, -0.0081],
#          ...,
#          [ 0.0352,  0.0090,  0.0372,  ...,  0.0266, -0.0092, -0.0088],
#          [ 0.0177,  0.0234,  0.0203,  ..., -0.0191, -0.0187,  0.0254],
#          [-0.0103,  0.0153,  0.0066,  ...,  0.0129, -0.0445, -0.0137]]],
#        device='cuda:0', requires_grad=True) torch.Size([1, 32, 768])
# Loss from forward pass: 6.058288097381592

# Stage 2

In [4]:
def codet5test():
    """Smoke-test the stock CodeT5 Java -> C# checkpoint on two hello-world snippets.

    Loads the ``Salesforce/codet5-base`` tokenizer and the
    ``Salesforce/codet5-base-codexglue-translate-java-cs`` model, then
    greedily decodes one snippet wrapped in a class and one bare method,
    printing each decoded result. Runs on CUDA when available, else CPU.
    """
    from transformers import T5ForConditionalGeneration, RobertaTokenizer
    import torch

    checkpoint = "Salesforce/codet5-base"
    sp_checkpoint = "Salesforce/codet5-base-codexglue-translate-java-cs"
    # Fall back to CPU so the cell does not crash on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    task_prefix = 'Translate Java to C#:'

    tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(
        sp_checkpoint,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        early_stopping=False,
    ).to(device)

    sample_java_code_with_class = "public class HelloWorld{ public static void main(String[] args) {System.out.println(\"Hello, world!\"); }"
    sample_java_code_without_class = "public static void main(String[] args) {System.out.println(\"Hello, world!\"); }"

    # Same encode -> generate -> decode pipeline for both snippets.
    for java_code in (sample_java_code_with_class, sample_java_code_without_class):
        encoding = tokenizer([task_prefix + java_code], return_tensors="pt", padding=True).to(device)
        outputs = model.generate(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            do_sample=False,
            max_length=750,
        )
        print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

codet5test()

['public static void Main(string[] args){throw new System.NotImplementedException();}']
['public static void Main(string[] args){WriteLine("Hello, world!");}']


In [15]:
class CodeBlipT5(CodeBlip):
    """
    Stage-2 CodeBlip model: Q-Former query outputs feed a frozen T5 model.

    The source code is embedded by a frozen code encoder, summarised by the
    Q-Former's query tokens, linearly projected to the T5 hidden size and
    prepended to the T5 encoder's token embeddings.
    """

    def __init__(self, qformer, query_tokens, t5_tokenizer, t5_model, prompt, max_source_len=512,
        max_target_len=512, num_query_token=32, embed_dim=768, cross_attention_freq=2):
        """
        Args:
            qformer: Q-Former to reuse (e.g. from stage 1), or ``None``.
            query_tokens: Query-token parameter to reuse, or ``None``. A
                fresh Q-Former and query tokens are created unless BOTH
                ``qformer`` and ``query_tokens`` are given.
            t5_tokenizer: Tokenizer matching ``t5_model``.
            t5_model: T5 model instance. Its parameters are frozen and cast
                to fp16 IN PLACE, so the caller's object is modified.
            prompt (str): Text prepended (as T5 tokens) during ``generate``.
            max_source_len (int): Pad/truncate length for source tokens.
            max_target_len (int): Truncation length for target output tokens.
            num_query_token (int): Only used when a new Q-Former is created.
            embed_dim (int): Only used when a new Q-Former is created.
            cross_attention_freq (int): Only used when a new Q-Former is created.
        """
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Device: {self.device}")

        self.tokenizer = self.init_tokenizer()
        self.code_encoder, self.ln_code = self.init_code_encoder()

        # freeze the encoder
        for param in self.code_encoder.parameters():
            param.requires_grad = False
        self.code_encoder.eval()

        self.max_source_len = max_source_len
        self.max_target_len = max_target_len

        if qformer is not None and query_tokens is not None:
            self.Qformer = qformer
            self.query_tokens = query_tokens
        else:
            self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, embed_dim, cross_attention_freq)

        self.Qformer.to(self.device)

        self.t5_tokenizer = t5_tokenizer
        self.t5_model = t5_model

        # Freeze T5 and store its weights in half precision.
        for param in self.t5_model.parameters():
            param.requires_grad = False
            param.data = param.data.half()

        # Maps Q-Former hidden states into the T5 embedding space.
        self.t5_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.t5_model.config.hidden_size
        )

        self.prompt = prompt

    def forward(self, samples):
        """
        Compute the T5 loss for a batch.

        Args:
            samples (dict): Must contain ``"source_code"``,
                ``"target_code_input"`` (text given to the T5 encoder after
                the query embeddings) and ``"target_code_output"`` (text the
                T5 decoder must produce).

        Returns:
            dict: ``{"loss": <T5 loss>}``.
        """
        source_code = samples["source_code"]  # image

        # Process source code
        source_tokens = self.tokenizer(
            source_code, padding="max_length", truncation=True,
            max_length=self.max_source_len, return_tensors="pt"
        ).to(self.device)

        # code embedding
        source_output = self.ln_code(self.code_encoder(**source_tokens, return_dict=True).last_hidden_state)

        # expand query tokens to batch size
        query_tokens = self.query_tokens.expand(source_output.shape[0], -1, -1)

        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=source_output,
            encoder_attention_mask=source_tokens.attention_mask,
            use_cache=True,
            return_dict=True,
        )

        # Project queries to T5 width; every query position is attended to.
        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(self.device)

        with self.maybe_autocast(dtype=torch.float16):
            # NOTE(review): the input half is truncated with max_source_len
            # rather than max_target_len - confirm this is intended.
            input_tokens = self.t5_tokenizer(
                samples['target_code_input'],
                padding="longest",
                truncation=True,
                max_length=self.max_source_len,
                return_tensors="pt",
            ).to(self.device)
            output_tokens = self.t5_tokenizer(
                samples['target_code_output'],
                padding="longest",
                truncation=True,
                max_length=self.max_target_len,
                return_tensors="pt",
            ).to(self.device)

        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

        # Padding positions are ignored by the loss (-100).
        targets = output_tokens.input_ids.masked_fill(
            output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100
        )

        # Encoder input = [projected queries ; embedded input tokens].
        inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

        outputs = self.t5_model(
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_atts,
            decoder_attention_mask=output_tokens.attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss

        return {"loss": loss}

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=30,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=0.9,
        num_captions=1,
        temperature=1,
    ):
        """
        Generate target code for the given source code.

        Args:
            samples (dict): Must contain ``"source_code"``, which is passed
                to the tokenizer.
            use_nucleus_sampling (bool): Forwarded as ``do_sample``.
            num_beams (int): Number of beams for beam search.
            max_length (int): Forwarded as ``max_new_tokens``.
            min_length (int): The minimum length of the generated sequence.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): Forwarded to ``t5_model.generate``.
            length_penalty (float): Forwarded to ``t5_model.generate``.
            num_captions (int): Forwarded as ``num_return_sequences``.
            temperature (float): Forwarded to ``t5_model.generate``.

        Returns:
            str: The FIRST decoded sequence only, even when the batch or
            ``num_captions`` yields several.
        """
        source_code = samples["source_code"]  # image

        source_tokens = self.tokenizer(
            source_code, padding="max_length", truncation=True,
            max_length=self.max_source_len, return_tensors="pt"
        ).to(self.device)

        source_output = self.ln_code(self.code_encoder(**source_tokens, return_dict=True).last_hidden_state)
        query_tokens = self.query_tokens.expand(source_output.shape[0], -1, -1)

        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=source_output,
            encoder_attention_mask=source_tokens.attention_mask,
            return_dict=True,
        )

        inputs_t5 = self.t5_proj(query_output.last_hidden_state)
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(self.device)

        # Repeat a single prompt string once per batch row; otherwise the
        # prompt tensors have batch size 1 and the concatenations below fail
        # for any batch larger than one.
        prompt = self.prompt
        if isinstance(prompt, str):
            prompt = [prompt] * inputs_t5.size(0)

        input_tokens = self.t5_tokenizer(
            prompt,
            padding="longest",
            return_tensors="pt",
        ).to(self.device)

        encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

        with self.maybe_autocast(dtype=torch.float16):
            inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

            outputs = self.t5_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                do_sample=use_nucleus_sampling,
                top_p=top_p,
                temperature=temperature,
                num_beams=num_beams,
                max_new_tokens=max_length,
                min_length=min_length,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                num_return_sequences=num_captions,
            )
            output_text = self.t5_tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )

        return output_text[0]

    @classmethod
    def from_config(cls, config):
        """
        Build the model on top of a stage-1 checkpoint.

        Args:
            config (dict): Must contain ``"pretrained_path"`` (the stage-1
                checkpoint). Every other entry is forwarded to ``__init__``
                as a keyword argument, so it must also provide
                ``t5_tokenizer``, ``t5_model`` and ``prompt``.

        Returns:
            CodeBlipT5: Model reusing the stage-1 Q-Former and query tokens.
        """
        checkpoint = config['pretrained_path']
        stage1_model = CodeQformer.from_config({'pretrained_path': checkpoint})

        stage1_qformer = stage1_model.Qformer
        stage1_query_tokens = stage1_model.query_tokens

        # 'pretrained_path' is not an __init__ parameter; forwarding it
        # would raise TypeError (unexpected keyword argument).
        init_kwargs = {key: value for key, value in config.items() if key != 'pretrained_path'}

        return cls(stage1_qformer, stage1_query_tokens, **init_kwargs)


In [19]:
def stage2_test():
    """Smoke-test stage 2: wrap the stage-1 Q-Former around CodeT5 and run one forward pass.

    Loads the stage-1 checkpoint and the Java -> C# CodeT5 checkpoint, splits
    each dummy target snippet at a random character position into an
    input/output pair, then prints the batch and the resulting loss.
    """
    stage_1_checkpoint = 'models/stage1_out/stage1_best.pt'
    stage1_model = CodeQformer.from_config({'pretrained_path': stage_1_checkpoint})

    stage1_qformer = stage1_model.Qformer
    stage1_query_tokens = stage1_model.query_tokens

    source_lang='java'
    target_lang='cs'

    t5_tokenizer = AutoTokenizer.from_pretrained('Salesforce/codet5-base')
    t5_model = T5ForConditionalGeneration.from_pretrained(f'Salesforce/codet5-base-codexglue-translate-{source_lang}-{target_lang}')

    # prompt = f'Translate {source_lang} to {target_lang}'
    prompt = ''

    model = CodeBlipT5(stage1_qformer, stage1_query_tokens, t5_tokenizer=t5_tokenizer, t5_model=t5_model, prompt=prompt).to('cuda' if torch.cuda.is_available() else 'cpu')

    # To test trained stage-2 weights, load them here:
    # stage_2_checkpoint = os.path.join("models", "stage2_out", "stage2_best.pt")
    # model.load_state_dict(torch.load(stage_2_checkpoint))

    # Dummy source and target code samples
    source_code_samples = [
        "public class HelloWorld { public static void main(String[] args) { System.out.println(\"Hello, world!\"); } }",
        "public class Test { public static int add(int a, int b) { return a + b; } }"]
    target_code_samples = [
        "class HelloWorld { static void Main(string[] args) { Console.WriteLine(\"Hello, world!\"); } }",
        "class Test { static int Add(int a, int b) { return a + b; } }"]

    # Split each target sample into a prefix (T5 encoder input) and the
    # remainder (decoder target).
    tci = []
    tco = []
    for target_code in target_code_samples:
        # randint's upper bound is inclusive, so stop one character short of
        # the end; otherwise the output half can be an empty string.
        middle = random.randint(1, max(1, len(target_code) - 1))
        tci.append(target_code[:middle])
        tco.append(target_code[middle:])

    # Prepare input for the model
    samples = {"source_code": source_code_samples, "target_code_input": tci, "target_code_output": tco}
    print(samples)
    # Perform a forward pass
    losses = model(samples)
    print(f"Loss from forward pass: {losses['loss'].item()}")

stage2_test()

Device: cuda
{'source_code': ['public class HelloWorld { public static void main(String[] args) { System.out.println("Hello, world!"); } }', 'public class Test { public static int add(int a, int b) { return a + b; } }'], 'target_code_input': ['class HelloWorld { static void Main(', 'class Test { sta'], 'target_code_output': ['string[] args) { Console.WriteLine("Hello, world!"); } }', 'tic int Add(int a, int b) { return a + b; } }']}
Loss from forward pass: 5.6953125


In [None]:
Device: cuda
Loss from forward pass: 30.375

# Inference

In [20]:
def inference_test():
    """Smoke-test generation: translate two dummy Java snippets one at a time.

    Reuses the stage-1 Q-Former and query tokens from the stage-1 checkpoint
    and pairs them with the Java -> C# CodeT5 checkpoint. Stage-2 weights are
    not loaded; to use them, call ``model.load_state_dict`` with
    ``os.path.join("models", "stage2_out", "stage2_best.pt")`` after
    construction.
    """
    source_lang = 'java'
    target_lang = 'cs'

    # Stage-1 components that get reused inside the stage-2 wrapper.
    stage1_model = CodeQformer.from_config(
        {'pretrained_path': 'models/stage1_out/stage1_best.pt'}
    )
    qformer = stage1_model.Qformer
    queries = stage1_model.query_tokens

    t5_tokenizer = AutoTokenizer.from_pretrained('Salesforce/codet5-base')
    t5_model = T5ForConditionalGeneration.from_pretrained(
        f'Salesforce/codet5-base-codexglue-translate-{source_lang}-{target_lang}'
    )

    # An empty prompt is used; the alternative would be
    # f'Translate {source_lang} to {target_lang}'.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = CodeBlipT5(
        qformer,
        queries,
        t5_tokenizer=t5_tokenizer,
        t5_model=t5_model,
        prompt='',
    ).to(run_device)

    # Dummy Java inputs, fed to the model as single-item batches.
    java_snippets = (
        "public static void main(String[] args) { System.out.println(\"Hello, world!\"); }",
        "public class Test { public static int add(int a, int b) { return a + b; } }",
    )

    for snippet in java_snippets:
        print(f'Current Sample: {snippet}')
        translation = model.generate({"source_code": [snippet]}, max_length=750)
        print(f"Current Output: {translation}")
        print()

inference_test()

Device: cuda
Current Sample: public static void main(String[] args) { System.out.println("Hello, world!"); }




Current Output: public override void SetQuality(int quality){if (quality >= 0.0f){return;}if (quality >= 0.0f){return;}if (quality >= 0.0f){return;}if (quality >= 0.0f){return;}if (quality >= 0.0f){return;}

Current Sample: public class Test { public static int add(int a, int b) { return a + b; } }
Current Output: public override void Finish(object o){if (o == null){throw new ArgumentNullException("o");}else{o = new object(o);}}



In [7]:
# Java snippets used as ad-hoc generation inputs by the loop in the next cell.
source_code_samples = [
    "public class HelloWorld { public static void main(String[] args) { System.out.println(\"Hello, world!\"); } }",
    "public class Test { public static int add(int a, int b) { return a + b; } }",
    "public class Main { public static void main(String[] args) { int a = 5; int b = 10; System.out.println(a + b); } }",
]

In [15]:
# Print each sample followed by the model's sampled generation for it.
# NOTE(review): `model` is not defined at module level in any cell shown here
# (the test functions keep theirs local) - presumably it was left in the
# kernel by an earlier session; confirm before relying on Restart & Run All.
# NOTE(review): the recorded output below shows lists, which does not match a
# generate() that returns a single string - the output looks stale; re-run.
for sample in source_code_samples:
    print(sample)
    print(model.generate({'source_code': sample}, use_nucleus_sampling=True, num_beams=5, max_length=30, min_length=1, top_p=0.9, repetition_penalty=1.0, length_penalty=1.0, num_captions=1, temperature=1))

public class HelloWorld { public static void main(String[] args) { System.out.println("Hello, world!"); } }
[';;}.}']
public class Test { public static int add(int a, int b) { return a + b; } }
[';;']
public class Main { public static void main(String[] args) { int a = 5; int b = 10; System.out.println(a + b); } }
['']
