In [None]:
%load_ext autoreload
%autoreload 2
from transformers import AutoModel
import torch

import pandas as pd
import numpy as np

import sys
import os

module_path = os.path.abspath('..')
# Make the project's helper-script folders importable. Paths are built with
# os.path.join (not hard-coded "\\") so this also works on Linux/macOS/Colab,
# and each folder is checked individually so re-running the cell never appends
# duplicates (the old guard tested `module_path`, which was never appended).
for _helper_dir in (
    os.path.join(module_path, "py_scripts"),
    os.path.join(module_path, "py_scripts", "codeXGLUE_code-to-text"),
):
    if _helper_dir not in sys.path:
        sys.path.append(_helper_dir)
del _helper_dir

from transformers import RobertaTokenizer, T5ForConditionalGeneration

import torch
import torchvision as thv
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
import torch.utils.data as data_utils
import torch.nn as nn

from tqdm import tqdm

from helper import get_j_c_data_loaders, to_device, get_device, plot_data


In [None]:
# GLOBALS
device = get_device()
BEAM_SIZE = 10  # NOTE(review): not referenced by any cell shown here
MAX_SEQ_LEN = 200
SOURCE_LEN = 200  # NOTE(review): not referenced by any cell shown here
LEARNING_RATE = 5e-5
EPOCHS = 1
BATCH_SIZE = 10 # change depending on the GPU Colab gives you
SEED = 42

# Seed torch (CPU and every CUDA device) and NumPy before the data loaders and
# models are created, so torch/NumPy-driven randomness (e.g. loader shuffling,
# dropout) is reproducible across Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

torch.cuda.empty_cache()

In [None]:
# Tokenizer for the same CodeT5-small checkpoint the models load; it is handed
# to the data-loader helper to build the train / validation / test splits.
token = RobertaTokenizer.from_pretrained("Salesforce/codet5-small")
(train_dl, valid_dl, test_dl) = get_j_c_data_loaders(BATCH_SIZE, token)

In [None]:
class SeqSeqModel(torch.nn.Module):
    """Two-stage CodeT5 pipeline: code -> generated doc text -> code.

    ``model_to_doc`` generates text from the input code ids; that text is
    re-encoded with ``doc_tokenizer`` and fed to ``model_to_code``, whose loss
    against ``labels`` is what ``forward`` returns.

    Args:
        seq_len: maximum length used for generation and for re-encoding the
            generated text.
        tokenizer: tokenizer used to decode ids produced by either model and
            whose ``pad_token_id`` marks padding in ``labels``.
        doc_tokenizer: tokenizer used to re-encode the generated doc text.

    NOTE(review): generated docs are decoded to Python strings before being
    re-encoded, so no autograd path connects ``model_to_doc`` to the loss;
    ``forward`` only trains ``model_to_code``.
    """

    def __init__(self, seq_len, tokenizer, doc_tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.doc_tokenizer = doc_tokenizer
        self.seq_len = seq_len
        self.model_to_doc = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
        self.model_to_code = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")

    def get_text_embedding(self, input_ids, attention_mask=None):
        """Generate docs for ``input_ids`` and re-encode them as model inputs.

        Args:
            input_ids: code token ids passed to ``generate_doc``.
            attention_mask: optional mask for ``input_ids``; forwarded to
                generation so padded positions are ignored.

        Returns:
            dict with ``"ids"`` and ``"mask"`` long tensors, one row of length
            ``seq_len`` per generated doc, on the same device as ``input_ids``.
            A doc that tokenizes to nothing becomes a row of pad ids with an
            all-zero mask.
        """
        pad_id = self.doc_tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        ids_rows, mask_rows = [], []
        for doc in self.generate_doc(input_ids, attention_mask=attention_mask):
            if not self.doc_tokenizer.tokenize(doc):
                ids_rows.append([pad_id] * self.seq_len)
                mask_rows.append([0] * self.seq_len)
                continue
            # Encode the raw string (not a pre-tokenized list) so both slow and
            # fast tokenizers accept it.
            encoded = self.doc_tokenizer.encode_plus(
                doc,
                add_special_tokens=True,
                max_length=self.seq_len,
                padding='max_length',
                truncation=True,
            )
            ids_rows.append(encoded['input_ids'])
            mask_rows.append(encoded['attention_mask'])

        # generate() needs input_ids on the model's device, so reuse that device
        # instead of relying on a notebook-global.
        target = input_ids.device
        return {
            "ids": torch.tensor(ids_rows, dtype=torch.long, device=target),
            "mask": torch.tensor(mask_rows, dtype=torch.long, device=target),
        }

    def forward(self, input_ids, attention_mask, labels):
        """Return the code-reconstruction loss of ``model_to_code``.

        Args:
            input_ids: code token ids used to generate the intermediate docs.
            attention_mask: mask for ``input_ids``.
            labels: target code ids. Entries equal to ``tokenizer.pad_token_id``
                are replaced with -100 so the T5 loss ignores padding
                (assumes labels are padded with the pad id — confirm in the
                data-loader helper).

        Returns:
            The scalar loss reported by ``model_to_code``.
        """
        # The previous extra model_to_doc forward pass only supplied
        # `decoder_attentions` (None unless attentions are requested, and never
        # a valid `encoder_outputs`), so it was pure overhead and is dropped.
        doc_embedding = self.get_text_embedding(input_ids, attention_mask)

        pad_id = self.tokenizer.pad_token_id
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, -100)

        code_out = self.model_to_code(
            input_ids=doc_embedding['ids'],
            attention_mask=doc_embedding['mask'],
            labels=labels,
        )
        return code_out.loss

    def generate_doc(self, input_ids, attention_mask=None):
        """Generate doc text for each row of ``input_ids``; returns a list of str."""
        gen = self.model_to_doc.generate(input_ids, max_length=self.seq_len, attention_mask=attention_mask)
        return [self.tokenizer.decode(entry, skip_special_tokens=True) for entry in gen]

    def generate_code(self, input_ids, attention_mask=None):
        """Generate code for each row of ``input_ids``; returns a list of str.

        Decodes row by row: ``tokenizer.decode`` expects one sequence, so
        passing the whole 2-D batch (as before) fails.
        """
        gen = self.model_to_code.generate(input_ids, max_length=self.seq_len, attention_mask=attention_mask)
        return [self.tokenizer.decode(entry, skip_special_tokens=True) for entry in gen]

In [None]:
@torch.no_grad()
def validate(model, val_load, ids_key='j_ids', mask_key='j_mask'):
    """Return the mean loss of ``model`` over all batches of ``val_load``.

    Each batch is a mapping. Its ``ids_key`` entry is used as both input ids
    and labels, matching how ``train`` calls the model. Its ``mask_key`` entry
    is the attention mask. (The previous ``model(batch)`` call passed the whole
    dict and could not match ``forward(input_ids, attention_mask, labels)``.)

    The model is put in eval mode for the pass. Every submodule's previous
    train/eval flag is then restored exactly, including mixed states.

    Returns:
        The average loss as a float, or NaN if ``val_load`` yields no batches.
    """
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()  # disable dropout while measuring validation loss
    try:
        losses = [
            float(model(batch[ids_key], batch[mask_key], batch[ids_key]))
            for batch in val_load
        ]
    finally:
        for module, mode in saved_modes:
            module.training = mode
    return sum(losses) / len(losses) if losses else float('nan')

In [None]:
def train(clf, opt, dl, val, epochs=None, val_every=100, ids_key='j_ids', mask_key='j_mask'):
    """Train ``clf`` with ``opt`` on ``dl`` and periodically validate on ``val``.

    Args:
        clf: model whose call ``clf(ids, mask, labels)`` returns a scalar loss.
        opt: optimizer over ``clf``'s parameters.
        dl: training batches (mappings with ``ids_key`` / ``mask_key`` entries).
        val: validation loader passed to ``validate``.
        epochs: number of passes over ``dl``; defaults to the ``EPOCHS`` global.
        val_every: run validation every this many batches within an epoch
            (never at batch 0); ``0``/``None`` disables it.
        ids_key, mask_key: batch keys for the code ids and attention mask.

    Returns:
        ``(local_loss, val_loss)``: one detached CPU loss tensor per training
        batch, and one validation loss per validation run.
    """
    n_epochs = EPOCHS if epochs is None else epochs
    local_loss = []
    val_loss = []
    # NOTE(review): HF from_pretrained() returns models in eval mode and nothing
    # here calls clf.train(), so dropout is likely off during fine-tuning —
    # confirm whether that is intended.
    for _ in range(n_epochs):
        # Wrap the loader itself (not enumerate) so tqdm can read len(dl) and
        # show a progress bar with a known total.
        for i, data in enumerate(tqdm(dl)):
            ids = data[ids_key]
            opt.zero_grad()
            # The code ids double as labels: the pipeline reconstructs its input.
            mod_out = clf(ids, data[mask_key], ids)
            local_loss.append(mod_out.detach().cpu())
            mod_out.backward()
            opt.step()

            if val_every and i and i % val_every == 0:
                val_loss.append(validate(clf, val))

    return local_loss, val_loss
        

In [None]:
# Build the code->doc->code pipeline (one tokenizer serves both stages), move it
# to the configured device, and optimise all of its parameters with Adam.
model = to_device(SeqSeqModel(seq_len=MAX_SEQ_LEN, tokenizer=token, doc_tokenizer=token), device)
rob_optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
loss_arr, val_arr = train(clf=model, opt=rob_optimizer, dl=train_dl, val=valid_dl)  # per-batch train losses, periodic val losses

In [None]:
# The training curve has one point per training batch. The validation curve has
# one point per validation *run*, so its length is len(val_arr), not the number
# of validation batches (len(valid_dl)).
# NOTE(review): assumes plot_data's third argument is the number of points in
# each series — confirm in helper.plot_data. The x-axis is really batches, not epochs.
plot_data(
    (loss_arr, val_arr),
    ('training loss', 'validation loss'),
    (len(loss_arr), len(val_arr)),
    'epochs',
    'loss',
    'loss/epoch',
)

In [None]:
def sample(clf, dl, ids_key='j_ids', mask_key='j_mask', doc_key='j_doc'):
    """Print the reference docs and generated docs for the first batch of ``dl``.

    Does nothing if ``dl`` is empty. Generation runs under ``no_grad`` with
    ``clf`` in eval mode, so dropout does not perturb the output. Every
    submodule's previous train/eval flag is restored afterwards.

    Args:
        clf: model exposing ``generate_doc(ids, mask)``.
        dl: iterable of batch mappings.
        ids_key, mask_key, doc_key: batch keys for the code ids, attention mask
            and reference docs.
    """
    # The old nested epoch/batch loops broke out after the first batch, so just
    # take that batch directly.
    batch = next(iter(dl), None)
    if batch is None:
        return
    saved_modes = [(module, module.training) for module in clf.modules()]
    clf.eval()
    try:
        with torch.no_grad():
            mod_out = clf.generate_doc(batch[ids_key], batch[mask_key])
    finally:
        for module, mode in saved_modes:
            module.training = mode
    print("jdoc: " , batch[doc_key])
    print("mod_out: " , mod_out)

In [None]:
sample(clf=model, dl=train_dl)  # eyeball generated docs for one training batch