In [None]:
# Code snippets to summarize: each entry is the source text of one function or
# method, and the inference cell feeds the entries to the model one at a time.
# The literals are raw strings (r'''...''') so that escapes written inside the
# snippets, such as '\n' and '\t', stay as backslash sequences: the model sees
# the code exactly as written instead of snippets with real newline/tab
# characters spliced into their string literals.
# The snippet text itself (including its comments) is runtime data, not
# documentation for this notebook.
Codes = [
    # utils/common: read_data, write_to_file, save_checkpoint
    r'''
    def read_data(filename):
        examples = []
        with open(filename, encoding="utf-8") as f:
            for idx, line in enumerate(tqdm(f.readlines(), desc='----- [Reading]')):
                line = line.strip()
                js = json.loads(line)
                if 'idx' not in js: js['idx'] = idx
                code=' '.join(js['code_tokens']).replace('\n', ' ')
                code=' '.join(code.strip().split())
                nl=' '.join(js['docstring_tokens']).replace('\n', '')
                nl=' '.join(nl.strip().split())
                examples.append(
                    Example(idx=idx, source=code, target=nl) 
                )

        return examples
    '''
    ,
    r'''
    def write_to_file(output_dir, lines, filename):
        if not os.path.exists(output_dir): os.makedirs(output_dir)      # 没有文件夹则创建
        with open(os.path.join(output_dir, filename), 'w') as f:
            for idx, line in enumerate(lines):
                f.write(str(idx) + '\t' + line + '\n')
            f.close()
    '''
    ,
    r'''
    def save_checkpoint(model, output_dir, desc):
        output_desc_dir = os.path.join(output_dir, desc)
        if not os.path.exists(output_desc_dir): os.makedirs(output_desc_dir)    # 没有文件夹则创建
        model_to_save = model.module if hasattr(model, 'module') else model     # Only save the model it-self
        output_model_file = os.path.join(output_desc_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)
    '''
    ,
    # Processor: __call__, encode, decode_one, decode, to_dataset, to_dataloader, metric
    r'''
    def __call__(self, examples, params, stage=None):
        features = self.encode(examples, stage)
        dataset = self.to_dataset(features)
        return self.to_dataloader(dataset, params)
    '''
    ,
    r'''
    def encode(self, examples, stage=None):
        features = []
        for example_index, example in enumerate(tqdm(examples, desc='----- [Encoding]')):
            # source
            source_ids = torch.LongTensor(self.tokenizer.encode(example.source, 
                add_special_tokens=True, max_length=self.config.max_source_length, truncation=True))
            source_mask = torch.ones_like(source_ids)
    
            # target
            target = 'None' if stage == 'test' else example.target           
            target_ids = torch.LongTensor(self.tokenizer.encode(target, 
                add_special_tokens=True, max_length=self.config.max_target_length, truncation=True))
            target_mask = torch.ones_like(target_ids)
        
            features.append(
                InputFeatures(example_index, source_ids, source_mask, target_ids, target_mask)
            )
        return features
    '''
    ,
    r'''
    def decode_one(self, pred):
        return self.tokenizer.decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    '''
    ,
    r'''
    def decode(self, preds):
        return [self.decode_one(pred) for pred in tqdm(preds, desc='----- [Decoding]')]
    '''
    ,
    r'''
    def to_dataset(self, features: InputFeatures):
        all_source_ids = pad_sequence([f.source_ids for f in features], batch_first=True, padding_value=self.tokenizer.pad_token_id)
        all_source_mask = pad_sequence([f.source_mask for f in features], batch_first=True, padding_value=0)
        all_target_ids = pad_sequence([f.target_ids for f in features], batch_first=True, padding_value=self.tokenizer.pad_token_id)
        all_target_mask = pad_sequence([f.target_mask for f in features], batch_first=True, padding_value=0)
        return TensorDataset(all_source_ids, all_source_mask, all_target_ids, all_target_mask)
    '''
    ,
    r'''
    def to_dataloader(self, dataset, params):
        return DataLoader(dataset, **params)
    '''
    ,
    r'''
    def metric(self, trues, preds, desc):
        write_to_file(self.config.output_dir, trues, (desc + '.gold'))
        write_to_file(self.config.output_dir, preds, (desc + '.output'))
        predictions = [str(idx) + '\t' + line for idx, line in enumerate(preds)]
        (goldMap, predictionMap) = bleu.computeMaps(predictions, os.path.join(self.config.output_dir, (desc + '.gold')))
        return round(bleu.bleuFromMaps(goldMap, predictionMap)[0], 2)
    '''
    ,
    # Trainer: train, valid, predict
    r'''
    def train(self, train_loader: DataLoader):
        print('[Training info] Num examples: {}, Batch size: {}, Batch: {}'
        .format(len(train_loader.dataset), train_loader.batch_size, len(train_loader)))

        self.model.train()
        loss_list = []
        for batch in tqdm(train_loader, desc='----- [Training]'):
            batch = tuple(t.to(self.device) for t in batch)
            source_ids, source_mask, target_ids, target_mask = batch

            if self.config.model_type.lower() == 'codet5':
                loss = self.model(input_ids=source_ids, attention_mask=source_mask.gt(0), 
                                  labels=target_ids, decoder_attention_mask=target_mask.gt(0)).loss
            else:
                loss, _, _  = self.model(source_ids=source_ids, source_mask=source_mask, 
                                         target_ids=target_ids, target_mask=target_mask)

            loss_list.append(loss.item())
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()
    
        # Loss of train dataset
        train_loss = round(sum(loss_list) / len(loss_list), 5)

        return train_loss, loss_list
    '''
    ,
    r'''
    def valid(self, val_loader: DataLoader):
        print('[Validing info] Num examples: {}, Batch size: {}, Batch: {}'
        .format(len(val_loader.dataset), val_loader.batch_size, len(val_loader)))

        self.model.eval()
        eval_loss, tokens_num = 0, 0
        true_ids, pred_ids = [], []
        for batch in tqdm(val_loader, desc='----- [Validing]'):
            batch = tuple(t.to(self.device) for t in batch)
            source_ids, source_mask, target_ids, target_mask = batch

            with torch.no_grad():
                if self.config.model_type.lower() == 'codet5':
                    loss = self.model(input_ids=source_ids, attention_mask=source_mask, 
                                    labels=target_ids, decoder_attention_mask=target_mask).loss
                    eval_loss += loss.item()
                    tokens_num += 1
                    preds = self.model.generate(source_ids, attention_mask=source_mask, use_cache=True, 
                                        num_beams=self.config.beam_size, early_stopping=True, max_length=128)
                else:
                    _, loss, num = self.model(source_ids=source_ids,source_mask=source_mask,
                                          target_ids=target_ids,target_mask=target_mask)
                    eval_loss += loss.sum().item()
                    tokens_num += num.sum().item()
                    preds = self.model(source_ids=source_ids, source_mask=source_mask)[:, 0]
                    self.processor.decode(preds)
                
                true_ids.extend(target_ids)
                pred_ids.extend(preds)

        # Metrics(ppl bleu) of dev dataset    
        eval_ppl = round(np.exp(eval_loss / tokens_num), 5)
        eval_bleu = self.processor.metric(self.processor.decode(true_ids), self.processor.decode(pred_ids), 'dev')

        return eval_ppl, eval_bleu
    '''
    ,
    r'''
    def predict(self, test_loader: DataLoader):
        print('Predicting info] Num examples: {}, Batch size: {}, Batch: {}'
        .format(len(test_loader.dataset), test_loader.batch_size, len(test_loader)))

        self.model.eval()
        pred_ids = []
        for batch in tqdm(test_loader, desc='----- [Predicting]'):
            batch = tuple(t.to(self.device) for t in batch)
            source_ids, source_mask, _, _ = batch
            with torch.no_grad():
                if self.config.model_type.lower() == 'codet5':
                    preds = self.model.generate(source_ids, attention_mask=source_mask, use_cache=True, 
                            num_beams=self.config.beam_size, early_stopping=True, max_length=128)
                else:
                    preds = self.model(source_ids=source_ids, source_mask=source_mask)[:, 0]
                
                pred_ids.extend(preds)

        return pred_ids
    '''
]

In [None]:
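# Uncomment these two lines when running on Colab (they install transformers and download the fine-tuned checkpoint)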
# !pip install transformers
# !wget https://storage.googleapis.com/sfr-codet5-data-research/finetuned_models/summarize_python_codet5_base.bin

In [None]:
import torch
from transformers import RobertaTokenizer, T5ForConditionalGeneration

if __name__ == '__main__':
    # Local checkpoint file, as saved into the working directory by the `wget`
    # command in the setup cell. torch.load() opens local paths / file objects
    # only; it cannot read a gs:// URL.
    checkpoint_path = 'summarize_python_codet5_base.bin'

    tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
    model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')

    # map_location='cpu' lets the weights load on machines without a GPU.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # inference only

    text_list = Codes

    # Summarize each snippet independently and print one summary per snippet.
    for text in text_list:
        input_ids = tokenizer(text, return_tensors="pt").input_ids

        generated_ids = model.generate(input_ids, use_cache=True,
                                       num_beams=10, early_stopping=True, max_length=128)
        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
        