<a href="https://colab.research.google.com/github/0xVolt/whats-up-doc/blob/main/src/experimental-notebooks/code_trans_t5_small_code_documentation_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import transformers
from transformers import AutoTokenizer, AutoModelWithLMHead
from datasets import load_dataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.optim import AdamW
from torch.utils.data import DataLoader

class DataModule(LightningDataModule):
    """Serve (code, description) pairs as tokenized seq2seq batches.

    The incoming dataset must expose ``"instruction"`` and ``"output"``
    columns; they are renamed to ``"description"`` and ``"code"``. The code
    is the model input and the description is the target.

    NOTE(review): the same dataset backs the train, validation and test
    loaders, so validation/test losses are measured on training data.
    Split the dataset upstream if held-out metrics are needed.

    Args:
        dataset: Dataset object providing ``rename_columns`` and indexable
            rows with the columns described above.
        tokenizer: Callable tokenizer returning objects with ``input_ids``
            and ``attention_mask`` attributes.
        batch_size: Number of examples per batch.
    """

    def __init__(self, dataset, tokenizer, batch_size=16):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        # Change the instruction column to description and the output column to code
        self.dataset = dataset.rename_columns({"instruction": "description", "output": "code"})

    def _make_loader(self, shuffle):
        """Build a DataLoader over the dataset; only shuffling differs per stage."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (deterministic order)."""
        return self._make_loader(shuffle=False)

    def test_dataloader(self):
        """Return the test loader (deterministic order)."""
        return self._make_loader(shuffle=False)

    def collate_fn(self, batch):
        """Tokenize a list of examples into padded tensors.

        Args:
            batch: List of rows, each with ``"code"`` and ``"description"``.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` for the code,
            and ``"labels"`` for the description with padding set to -100.
        """
        codes = [example["code"] for example in batch]
        descriptions = [example["description"] for example in batch]

        # Tokenize the inputs once so the attention mask is padded and
        # truncated exactly like input_ids.
        encoded_inputs = self.tokenizer(codes, padding=True, truncation=True, return_tensors="pt")
        labels = self.tokenizer(descriptions, padding=True, truncation=True, return_tensors="pt").input_ids

        # Hugging Face seq2seq losses ignore label positions equal to -100;
        # without this the model is also trained to predict padding.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, -100)

        return {
            "input_ids": encoded_inputs.input_ids,
            "attention_mask": encoded_inputs.attention_mask,
            "labels": labels,
        }

class ModelModule(LightningModule):
    """Lightning wrapper that trains a Hugging Face seq2seq model on its own loss.

    Args:
        model: Model whose call accepts ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments and returns an object with ``.loss``.
        learning_rate: Learning rate passed to AdamW.
    """

    def __init__(self, model, learning_rate=1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, input_ids, labels, attention_mask=None):
        """Run the wrapped model and return its loss.

        Args:
            input_ids: Token ids of the source sequences.
            labels: Token ids of the target sequences.
            attention_mask: Optional padding mask for ``input_ids``.

        Returns:
            The loss reported by the wrapped model.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs.loss

    def _shared_step(self, batch, stage):
        """Compute the loss for one batch and log it as ``<stage>_loss``."""
        # Go through forward() with keyword arguments: passing labels
        # positionally to the model would bind them to the wrong parameter,
        # and the raw model output is not a scalar loss.
        loss = self(
            batch["input_ids"],
            batch["labels"],
            attention_mask=batch.get("attention_mask"),
        )
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for ``batch``."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log and return the test loss for ``batch``."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters."""
        return AdamW(self.parameters(), lr=self.learning_rate)

def main():
    """Fine-tune and then evaluate a CodeTrans T5 checkpoint.

    Loads the pre-trained model and tokenizer, loads the ``"train"`` split of
    the instruction dataset, and runs fit followed by test with a
    default-configured Trainer.
    """
    checkpoint = "SEBIS/code_trans_t5_base_source_code_summarization_python_transfer_learning_finetune"

    # Load the pre-trained code trans model. AutoModelWithLMHead is
    # deprecated; AutoModelForSeq2SeqLM is the replacement for
    # encoder-decoder checkpoints such as T5.
    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    # Create the tokenizer from the same checkpoint as the model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Load the dataset; only its "train" split is used
    dataset = load_dataset("iamtarun/python_code_instructions_18k_alpaca")["train"]

    # Create the LightningDataModule
    data_module = DataModule(dataset, tokenizer)

    # Wrap the model once so training and evaluation share one module
    module = ModelModule(model)

    # Create the Trainer
    trainer = Trainer()

    # Train the model
    trainer.fit(module, data_module)

    # Evaluate the model
    trainer.test(module, data_module)

# Run the training/evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()