<a href="https://colab.research.google.com/github/0xVolt/whats-up-doc/blob/main/src/experimental-notebooks/code_trans_t5_small_code_documentation_generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
%pip install -q --no-cache transformers sentencepiece datasets pytorch_lightning

In [11]:
import transformers
from transformers import AutoTokenizer, AutoModelWithLMHead
from datasets import load_dataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.optim import AdamW
from torch.utils.data import DataLoader

In [12]:
class CodeTransDataModule(LightningDataModule):
    def __init__(self, dataset, tokenizer, batch_size=16):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.dataset = self.dataset.rename_columns({"instruction": "description", "output": "code"})

    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.collate_fn,
        )

    def collate_fn(self, batch):
        input_texts = [x["code"] for x in batch]
        description_texts = [x["description"] for x in batch]

        encoding = self.tokenizer(input_texts, description_texts, padding=True, truncation=True, return_tensors="pt", return_attention_mask=True)

        # Manually generate labels from description_texts
        labels = self.tokenizer(description_texts, padding=True, truncation=True, return_tensors="pt").input_ids

        return {
            "input_ids": encoding["input_ids"],
            "labels": labels,
            "attention_mask": encoding["attention_mask"],
        }

    def fit(self, model, trainer, **kwargs):
        # Pass the position bias tensor to the fit() method
        trainer.fit(model, train_dataloader=self.train_dataloader(), position_bias=self.position_bias, **kwargs)

In [13]:
class CodeTransLightningModule(LightningModule):
    def __init__(self, model, learning_rate=1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, input_ids, labels):
        outputs = self.model(input_ids=input_ids, labels=labels)
        return outputs.loss

    def training_step(self, batch, batch_idx):
        loss = self.model(batch["input_ids"], batch["labels"])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.model(batch["input_ids"], batch["labels"])
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self.model(batch["input_ids"], batch["labels"])
        self.log("test_loss", loss)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

In [14]:
# Load the pre-trained code trans model
model = AutoModelWithLMHead.from_pretrained("SEBIS/code_trans_t5_base_source_code_summarization_python_transfer_learning_finetune")

# Create the tokenizer
tokenizer = AutoTokenizer.from_pretrained("SEBIS/code_trans_t5_base_source_code_summarization_python_transfer_learning_finetune", model_max_length=29, use_fast=False)



In [15]:
# Load the dataset
dataset = load_dataset("iamtarun/python_code_instructions_18k_alpaca")
dataset = dataset['train']

In [16]:
# Create the LightningDataModule
data_module = CodeTransDataModule(dataset, tokenizer)

In [17]:
# Create the Trainer
trainer = Trainer()

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [18]:
# Train the model
trainer.fit(CodeTransLightningModule(model), data_module)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py:72: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing t

ValueError: ignored

In [None]:
# Evaluate the model
trainer.test(CodeTransLightningModule(model), data_module)