# Code Similarity with Contrastive Learning

## Dependencies

In [None]:
# for data augmentation
%pip install python-minifier

In [None]:
%pip install pytorch-metric-learning

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
# Transformers (for CodeBERT etc.)
from transformers import AutoTokenizer, AutoModel

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
if device.type == 'cuda':
    # Report which GPU will be used.
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

## Dataset Access

In [None]:
import pandas as pd
from google.colab import userdata
# Google Drive direct-download links; the file ids are read from Colab secrets
# named 'labeledDataset' and 'unlabeledDataset' (in that order).
labeled_dataset_url, unlabeled_dataset_url = (
    f"https://drive.google.com/uc?export=download&id={userdata.get(secret_name)}"
    for secret_name in ('labeledDataset', 'unlabeledDataset')
)

## Dataset and Data Augmentation

In [None]:
# Code datasets
# (for labeled and unlabeled code snippets)

def _augment(df, *functions):
    """Calculates data augmentations on a dataframe with labeled source code."""
    augs = []
    _aug = df[df['label'].apply(lambda x: x[-1] == '1')]  # dataframe to augment
    for function in functions:
        aug = _aug.copy()
        aug.loc[:, 'source'] = aug['source'].apply(function)
        augs.append(aug)
    df = pd.concat([df, *augs], ignore_index=True)
    # Sort the dataframe so matching labels are next to eachother.
    df.sort_values(by='label', inplace=True)
    return df


def num_label(labels_map: dict, label: str):
    """Map a single string label to its numeric id.

    Labels whose last character is ``'1'`` are looked up in ``labels_map``;
    every other label maps to the placeholder ``-1``.
    """
    if label[-1] != '1':
        return -1
    return labels_map[label]


def num_labels(labels) -> torch.Tensor:
    """Convert string labels to numeric labels for the NT-Xent loss.

    Labels ending in ``'1'`` get an id (via `num_label`) shared by every
    occurrence of the same string. Every other label gets its own unique id,
    larger than all shared ids, so it is never grouped with another sample.

    Args:
        labels: iterable of string labels (e.g. a pandas Series).

    Returns:
        A float32 tensor with one id per input label, moved to ``device``.
        An empty input yields an empty tensor.
    """
    labels = list(labels)  # iterated twice below; also accepts one-shot iterables
    labels_map = {label: i for i, label in enumerate(sorted(set(labels)))}
    labels_num = torch.tensor(
        [num_label(labels_map, label) for label in labels], dtype=torch.float32
    )
    neg_mask = labels_num == -1
    n_neg = int(neg_mask.sum())
    if n_neg:
        # Unique ids above the current maximum, assigned in input order.
        start = int(labels_num.max().item()) + 1
        labels_num[neg_mask] = torch.arange(
            start, start + n_neg, dtype=labels_num.dtype
        )
    return labels_num.to(device)


class LabeledCodeDataset(Dataset):
    """Dataset of labeled code snippets.

    Each item is a ``(code, label)`` pair, where ``label`` is the numeric
    label produced by `num_labels` for the corresponding string label.
    """

    def __init__(self, codes, labels):
        # Store codes as a plain list so `self.codes[idx]` is positional, like
        # the label tensor. Integer indexing of a pandas Series is label-based
        # and would misalign codes and labels whenever the index is not 0..n-1.
        self.codes = list(codes)
        self.labels = num_labels(labels)

    def __getitem__(self, idx):
        code = self.codes[idx]
        label = self.labels[idx]
        return code, label

    def __len__(self):
        return len(self.codes)

    @classmethod
    def from_csv_data(cls, path: str, sample_size=0,
        augment=False, augment_functions=None):
        """Build the dataset from a CSV with ``source`` and ``label`` columns.

        Args:
            path: path or URL accepted by ``pd.read_csv``.
            sample_size: if non-zero, randomly sample this many rows (after
                augmentation).
            augment: whether to append augmented copies of positive rows.
            augment_functions: ``str -> str`` callables used when ``augment``
                is true; ``None`` means no augmentation functions.
        """
        df = pd.read_csv(path)
        print(df.shape)
        if augment:
            print('augmenting dataframe...')
            df = _augment(df, *(augment_functions or ()))
            print(df.shape)
        if sample_size:
            print('sampling dataframe...')
            df = df.sample(sample_size, ignore_index=True)
            print(df.shape)
        return cls(df['source'], df['label'])


class UnlabeledCodeDataset(Dataset):
    """Dataset of unlabeled code snippets; each item is a single code string."""

    def __init__(self, codes):
        # Positional storage: integer indexing of a pandas Series is label-based.
        self.codes = list(codes)

    def __getitem__(self, idx):
        return self.codes[idx]

    def __len__(self):
        return len(self.codes)

    @classmethod
    def from_csv_data(cls, path: str, sample_size=0):
        """Build the dataset from the ``file_content`` column of a CSV.

        Args:
            path: path, URL or buffer accepted by ``pd.read_csv``.
            sample_size: if non-zero, randomly sample this many rows.
        """
        df = pd.read_csv(path)
        if sample_size:
            print('sampling dataframe...')
            df = df.sample(sample_size, ignore_index=True)
            print(df.shape)
        # Select the column only after sampling so `sample_size` takes effect.
        return cls(df['file_content'])

In [None]:
# Augmentation functions

import python_minifier

def minify(code: str) -> str:
    """Return a minified version of ``code`` (used as a data augmentation).

    Best effort: if ``python_minifier.minify`` raises (e.g. on snippets it
    cannot handle), the error is logged at DEBUG level and ``code`` is
    returned unchanged, so augmentation never aborts training.
    """
    import logging

    try:
        return python_minifier.minify(code)
    except Exception as error:  # deliberate best-effort over arbitrary snippets
        logging.getLogger(__name__).debug('Error while minifying: %s', error)
        return code

## Model

In [None]:
# Model

class LinearWithBatchNorm(nn.Module):
    """Fully connected block: Linear -> BatchNorm1d -> ReLU -> Dropout.

    Args:
        in_feat: number of input features.
        out_feat: number of output features.
        dropout_rate: dropout probability applied after the activation.
    """
    def __init__(self, in_feat, out_feat, dropout_rate: float = 0.2):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_feat, out_feat),
            nn.BatchNorm1d(out_feat),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """Apply the block to ``x``."""
        out = self.layer(x)
        return out


class CodeSimilarityModel(nn.Module):
    """Code-embedding model: input features -> two linear layers -> MLP head.

    Two feature modes, chosen by ``token_embeddings``:

    - ``False``: snippets are encoded by the pretrained transformer and its
      ``pooler_output`` is used as the feature vector.
    - ``True``: the tokenizer's ``input_ids`` (padded/truncated to
      ``seq_len``) are cast to float32 and used directly as features; no
      transformer is loaded.

    Args:
        pretrained_model: Hugging Face model id for the tokenizer (and the
            transformer when ``token_embeddings`` is false).
        seq_len: token length used for padding/truncation.
        fc_hidden_size: output width of the first linear layer.
        fc_out: output width of the second linear layer (MLP head input).
        mlp_hidden_sizes: hidden widths of the projection head; one
            ``LinearWithBatchNorm`` is created per entry (any length,
            including empty).
        out_feat: size of the output embedding.
        dropout_rate: dropout used in the MLP head's hidden layers.
        token_embeddings: feature mode (see above).
    """

    def __init__(self,
        pretrained_model="microsoft/codebert-base", seq_len=512,
        fc_hidden_size=512, fc_out=256,
        mlp_hidden_sizes=(128, 64),
        out_feat=16,
        dropout_rate=0.2,
        token_embeddings: bool = False,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.token_embeddings = token_embeddings
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
        if not token_embeddings:
            # Only load the transformer when its features are used.
            self.transformer = AutoModel.from_pretrained(pretrained_model).to(device)
            self.in_feat = self.transformer.config.hidden_size
        else:
            # One feature per token position.
            self.in_feat = self.seq_len

        self.relu = nn.ReLU()
        # Linear layers: in_feat -> fc_hidden_size -> fc_out
        self.lin1 = nn.Linear(self.in_feat, fc_hidden_size)
        self.lin2 = nn.Linear(fc_hidden_size, fc_out)
        # MLP projection head: fc_out -> *mlp_hidden_sizes -> out_feat
        widths = [fc_out, *mlp_hidden_sizes]
        mlp_layers = [
            LinearWithBatchNorm(w_in, w_out, dropout_rate)
            for w_in, w_out in zip(widths[:-1], widths[1:])
        ]
        mlp_layers.append(nn.Linear(widths[-1], out_feat))
        self.mlp = nn.Sequential(*mlp_layers)

    def _tokenize(self, code: str | list[str]):
        """Tokenize ``code`` into PyTorch tensors padded/truncated to ``seq_len``."""
        # TODO: throw away sequences that are too long instead of truncating them.
        return self.tokenizer(code, return_tensors='pt', truncation=True,
                              padding='max_length', max_length=self.seq_len)

    def embeddings(self, codes: list[str] | tuple[str, ...]) -> torch.Tensor:
        """Compute input features for a batch of code snippets.

        Returns the float32 token ids when ``token_embeddings`` is true,
        otherwise the transformer's ``pooler_output``.
        """
        with torch.device(device):
            inputs = self._tokenize(codes).to(device)
            if self.token_embeddings:
                return inputs['input_ids'].type(torch.float32)
            return self.transformer(**inputs).pooler_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features ``x``: lin1 -> ReLU -> lin2 -> ReLU -> MLP head."""
        with torch.device(device):
            x = self.relu(self.lin1(x))
            x = self.relu(self.lin2(x))
            return self.mlp(x)

## NTXent Loss Function

In [None]:
from pytorch_metric_learning import losses
# NT-Xent loss (temperature 0.5) wrapped in SelfSupervisedLoss for training on
# unlabeled data: it is called with two embedding batches (references and
# their augmentations); per the library docs, row i of one batch is the
# positive pair of row i of the other.
ntxent_loss = losses.SelfSupervisedLoss(losses.NTXentLoss(temperature=0.5))

## Training

In [None]:
# Create the dataset.
# Supervised alternative (labeled snippets, positives augmented with `minify`).
# NOTE: the training loop passes is_labeled_data=False to `compute_loss`, so
# switching to this dataset also requires changing that flag.
#dataset = LabeledCodeDataset.from_csv_data(path=labeled_dataset_url, augment=True, augment_functions=[minify])
dataset = UnlabeledCodeDataset.from_csv_data(unlabeled_dataset_url, sample_size=24_000)

In [None]:
# Split the data: 80% training, the remainder for validation.
tsize = int(len(dataset) * 0.8)
vsize = len(dataset) - tsize
training_data, validation_data = random_split(dataset, lengths=[tsize, vsize])

In [None]:
# Create the data loaders
BATCH_SIZE = 20  # NOTE: Bigger batch size generally leads to better results in contrastive learning
SHUFFLE = True  # reshuffle the training batches every epoch
training_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=SHUFFLE)
# Validation is not shuffled: an in-batch contrastive loss depends on which
# samples share a batch, so fixed batches keep the validation loss comparable
# from epoch to epoch.
validation_loader = DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Training loop for the NTXEnt loss function
from tqdm.auto import tqdm


def compute_loss(batched_data, model, loss_func, is_labeled_data: bool, augment=None):
    """Compute the loss for one batch of data.

    The expected ``loss_func`` depends on ``is_labeled_data``:

    - ``True``: ``batched_data`` is ``(codes, labels)`` and the loss is called
      as ``loss_func(embeddings, labels)``; any such loss can be used.
    - ``False``: ``batched_data`` is a sequence of code strings. Each snippet
      is paired with an augmented copy and the loss is called as
      ``loss_func(ref_embeddings, aug_embeddings)``, so
      ``losses.SelfSupervisedLoss`` should be used.

    Args:
        batched_data: one batch from a DataLoader (see above).
        model: model exposing ``embeddings(codes)`` (feature extraction) and
            ``__call__`` (fc layers + MLP head).
        loss_func: the loss (see above).
        is_labeled_data: whether ``batched_data`` carries labels.
        augment: ``str -> str`` augmentation for the unlabeled branch;
            defaults to ``minify``.

    Returns:
        The value returned by ``loss_func``.
    """
    if is_labeled_data:
        codes, labels = batched_data
        embeddings = model(model.embeddings(codes))  # features, then fc1-fc2 + MLP
        return loss_func(embeddings, labels)

    if augment is None:
        # TODO: select a random augmentation once more methods are implemented.
        augment = minify
    ref_codes = list(batched_data)
    aug_codes = [augment(code) for code in ref_codes]
    batch_size = len(ref_codes)
    # References and augmentations go through the model in one forward pass;
    # the first `batch_size` rows are the references.
    embeddings = model(model.embeddings(ref_codes + aug_codes))
    ref_emb, aug_emb = embeddings[:batch_size], embeddings[batch_size:]
    return loss_func(ref_emb, aug_emb)


def train_epoch(
    model: CodeSimilarityModel,
    loader: DataLoader,
    loss_func,
    optimizer,
    epochs: int                  = 0,     # number of epochs so far (for logging)
    writer: SummaryWriter | None = None,  # for logging loss values
    *,
    is_labeled_data: bool = False,
    log_every: int = 50,
):
    """Train the model for one epoch and return the mean batch loss.

    Runs one optimizer step per batch of ``loader``. When ``writer`` is given,
    the mean loss over every window of ``log_every`` batches is logged under
    ``loss/train``. The final window may be shorter and is logged too. The
    global step is ``epochs * len(loader) + batch_index + 1``.

    Args:
        model: model passed to `compute_loss`; put into training mode here.
        loader: training data loader.
        loss_func: loss passed to `compute_loss`.
        optimizer: optimizer over the model's parameters.
        epochs: number of epochs completed so far (only used for the log step).
        writer: optional TensorBoard ``SummaryWriter``.
        is_labeled_data: forwarded to `compute_loss`.
        log_every: number of batches averaged into each logged value.

    Returns:
        Mean per-batch loss over the epoch, or ``nan`` if ``loader`` is empty.

    Raises:
        ValueError: if ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError(f'log_every must be >= 1, got {log_every}')
    n_batches = len(loader)
    if n_batches == 0:
        return float('nan')  # nothing to train on; avoid dividing by zero

    # Make sure dropout/BatchNorm are in training mode even if the model was
    # left in eval mode by the caller.
    model.train()
    epoch_loss = 0.0   # summed over the whole epoch
    window_loss = 0.0  # summed over the current logging window
    window_len = 0
    with tqdm(total=n_batches) as progress_bar:
        for i, data in enumerate(loader):
            optimizer.zero_grad()
            loss = compute_loss(data, model, loss_func, is_labeled_data=is_labeled_data)
            # Adjust the weights
            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            epoch_loss += loss_val
            window_loss += loss_val
            window_len += 1
            progress_bar.update(1)

            # Flush the window every `log_every` batches and at the last batch.
            if window_len == log_every or i == n_batches - 1:
                if writer is not None:
                    writer.add_scalar('loss/train', window_loss / window_len,
                                      epochs * n_batches + i + 1)
                window_loss, window_len = 0.0, 0

    return epoch_loss / n_batches


def validate(
    model: CodeSimilarityModel,
    loader: DataLoader,
    loss_func,
    *,
    is_labeled_data: bool = False,
):
    """Compute the mean validation loss over ``loader`` without training.

    The model is put into eval mode for the pass: dropout is disabled and
    BatchNorm normalizes with its running statistics instead of updating
    them from validation batches. Its previous train/eval mode is restored
    afterwards, even if an exception is raised. Gradients are not tracked.

    Args:
        model: model passed to `compute_loss`.
        loader: validation data loader.
        loss_func: loss passed to `compute_loss`.
        is_labeled_data: forwarded to `compute_loss`.

    Returns:
        Mean per-batch loss over ``loader``, or ``nan`` if ``loader`` is empty.
    """
    n_batches = len(loader)
    if n_batches == 0:
        return float('nan')  # avoid dividing by zero on an empty validation set

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            sum_loss = 0.0
            for data in loader:
                loss = compute_loss(data, model, loss_func, is_labeled_data=is_labeled_data)
                sum_loss += loss.item()
    finally:
        model.train(was_training)
    return sum_loss / n_batches


def train(
    model: CodeSimilarityModel,
    dataloaders,
    loss_func,
    optimizer, scheduler,
    epochs: int = 5,
):
    """Run the training/validation loop.

    Each epoch trains on the first loader, computes the validation loss on
    the second, steps ``scheduler`` once (if given) and prints a summary.

    Args:
        model: the model to train.
        dataloaders: ``(training_loader, validation_loader)``.
        loss_func: loss passed to the training and validation passes.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        epochs: number of epochs to run.

    Returns:
        ``(training_losses, validation_losses)``: one average loss per epoch each.
    """
    train_loader, val_loader = dataloaders
    writer = None  # set to SummaryWriter() to enable TensorBoard logging
    history_train, history_val = [], []
    model.train()

    for epoch_idx in range(epochs):
        header = f'EPOCH {epoch_idx + 1}/{epochs}'
        print(header)
        train_loss = train_epoch(model, train_loader, loss_func, optimizer, epoch_idx, writer)
        val_loss = validate(model, val_loader, loss_func)
        if scheduler is not None:
            scheduler.step()
        print(f"{header}, AVG loss: {train_loss}, AVG validation loss: {val_loss}")
        history_train.append(train_loss)
        history_val.append(val_loss)

    if writer is not None:
        writer.close()
    return history_train, history_val

In [None]:
# Create the model (token-id features, so no transformer is loaded), the
# optimizer and the learning-rate schedule.
model = CodeSimilarityModel(
    pretrained_model='huggingface/CodeBERTa-small-v1',
    seq_len=512,
    token_embeddings=True,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Multiply the learning rate by 0.1 every 4 scheduler steps (one step per epoch in `train`).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

In [None]:
# Train the model.
epochs = 12
loss_func = ntxent_loss
# NOTE(review): this rebinds `losses`, shadowing the pytorch_metric_learning
# `losses` module imported in the loss cell; re-run that cell before using it again.
losses = train(
    model,
    (training_loader, validation_loader),
    loss_func,
    optimizer,
    lr_scheduler,
    epochs=epochs,
)

In [None]:
# Per-epoch average training vs. validation loss.
plt.plot(losses[0], label='training loss')
plt.plot(losses[1], label='validation loss')
plt.legend()
plt.show()

## Notes:

### Pooling's effect on the loss:

| Pooling | Effect |
| ------- | ------ |
|mean pooling after embedding w/ transformer|avg-loss/epoch converges to ~1.6|
|max pooling before loss function|avg-loss/epoch converges to ~1.4|
|max pooling before final MLP|loss fluctuates between 2 and 3 in the first epoch|

### Data TODOs:
- ✅ - Throw away code snippets that are too long
- ❎ - Many code snippets mined from GitHub can't be minified; filter the unlabeled code dataset!
- ❎ - Precompute data augmentations

### Model TODOs:
- ❎ - Log hyperparameters before training
- ✅ - Try tokenizer embeddings
- ❎ - Try code2vec embedding
- ❎ - Try training with transformer's `pooler_output`