In [1]:
from datasets import load_from_disk
from transformers import (
    RobertaTokenizer,
    T5ForConditionalGeneration,
)
import torch

In [2]:
# Number of threads PyTorch may use for intra-op CPU parallelism.
NUM_CPU_THREADS = 16
torch.set_num_threads(NUM_CPU_THREADS)

In [3]:
# The tokenizer comes from the public base checkpoint; the weights come from a local fine-tuned run.
# NOTE(review): absolute machine-specific path — consider making it configurable.
TOKENIZER_NAME = 'Salesforce/codet5-small'
CHECKPOINT_DIR = '/data/nicolasmaier/model/codet5-finetuned-split/checkpoint-312000'

tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_NAME)
model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT_DIR)

In [6]:
# DatasetDict written earlier with `save_to_disk`. Each split already carries
# input_ids / attention_mask / labels next to the raw code and XMI text.
DATASET_DIR = "/data/nicolasmaier/dataset/hf_clean_dataset"

dataset = load_from_disk(DATASET_DIR)
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['code', 'contents', 'xmi', 'originalLine', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 425631
    })
    valid: Dataset({
        features: ['code', 'contents', 'xmi', 'originalLine', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 14634
    })
    test: Dataset({
        features: ['code', 'contents', 'xmi', 'originalLine', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 25156
    })
})


In [7]:
# Take one test-split example; show its Java source and the XMI it is paired with.
EXAMPLE_INDEX = 21
example = dataset['test'][EXAMPLE_INDEX]
for field in ('contents', 'xmi'):
    print(example[field])

package targetpackage;

public class TargetClass {
    @CheckReturnValue
    @NonNull
    @SchedulerSupport(SchedulerSupport.CUSTOM)
    public static Observable<Long> interval(long initialDelay, long period, TimeUnit unit, Scheduler scheduler) {
        ObjectHelper.requireNonNull(unit, "unit is null");
        ObjectHelper.requireNonNull(scheduler, "scheduler is null");

        return RxJavaPlugins.onAssembly(new ObservableInterval(Math.max(0L, initialDelay), Math.max(0L, period), unit, scheduler));
    }
}

<?xml version="1.0" encoding="ASCII"?>
<java:Model xmi:version="2.0" xmlns:xmi="http://www.omg.org/XMI" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:java="http://www.eclipse.org/MoDisco/Java/0.2.incubation/java" name="TargetProject">
  <ownedElements name="targetpackage">
    <ownedElements xsi:type="java:ClassDeclaration" originalCompilationUnit="//@compilationUnits.0" name="TargetClass">
      <modifier visibility="public"/>
      <bodyDeclarations xsi:type="jav

In [8]:
def generate_token(model, input_ids, decoder_input_ids, attention_mask=None):
    """Greedily predict the next decoder token of a seq2seq model.

    Runs a single forward pass of ``model`` on the encoder ``input_ids`` and
    the decoder prefix ``decoder_input_ids``. Returns the id with the highest
    logit at the last decoder position.

    Args:
        model: seq2seq model (e.g. ``T5ForConditionalGeneration``) whose first
            output is a logits tensor indexed as ``[batch, position, vocab]``.
        input_ids: encoder token-id tensor; batch size 1 is assumed, because
            the decoder prefix is wrapped as a single-row batch.
        decoder_input_ids: decoder prefix as a list of token ids.
        attention_mask: optional encoder attention mask matching ``input_ids``.
            It is only forwarded to ``model`` when given.

    Returns:
        int: id of the arg-max token at the last decoder position.
    """
    extra = {} if attention_mask is None else {"attention_mask": attention_mask}
    # Pure inference. Without no_grad, every call in a token-by-token loop
    # would build and keep an autograd graph for nothing.
    with torch.no_grad():
        # Build the decoder batch on the encoder input's device so a GPU model also works.
        decoder_ids = torch.tensor([decoder_input_ids], device=input_ids.device)
        outputs = model(input_ids=input_ids, decoder_input_ids=decoder_ids, **extra)
    # outputs[0] holds the logits: take batch 0, last decoder position.
    next_token = torch.argmax(outputs[0][0, -1, :])
    return next_token.item()

In [194]:
# Seed the manual decoding loop with the model's decoder start token.
# `model.generate` starts from `config.decoder_start_token_id` (its output begins
# with `<pad><s>`). Starting directly from `<s>` (id 1) skips that token, so the
# manual loop would not match what `generate` feeds the decoder.
res = [model.config.decoder_start_token_id]

In [None]:
# Manual greedy decoding: extend `res` by up to MAX_NEW_TOKENS tokens, one forward
# pass each. Re-running the cell continues from the current `res`.
MAX_NEW_TOKENS = 200

# The encoder input does not change between steps, so tokenize it once.
encoded = tokenizer(example["contents"], return_tensors='pt')
for _ in range(MAX_NEW_TOKENS):
    # Stop once the sequence has ended. This also makes re-running the cell
    # after EOS a no-op instead of appending tokens past </s>.
    if res[-1] == tokenizer.eos_token_id:
        break
    res.append(generate_token(model, encoded.input_ids, res))

# Print once rather than reprinting the whole growing sequence every step.
print(tokenizer.decode(torch.tensor(res)))

In [195]:
# Beam-search generation (5 beams) of the XMI for the example's Java source.
# Unpacking the encoding passes `attention_mask` along with `input_ids`.
# The name `encoded` avoids shadowing the built-in `input`.
encoded = tokenizer(example["contents"], return_tensors='pt')
outputs = model.generate(**encoded, max_length=1000, num_beams=5, early_stopping=True)

In [196]:
# Decode the top beam. Special tokens (e.g. <pad>, <s>) are kept, so the raw output is visible.
best_beam = outputs[0]
print(tokenizer.decode(best_beam))

<pad><s><?xml version="1.0" encoding="ASCII"?>
<java:Model xmi:version="2.0" xmlns:xmi="http://www.omg.org/XMI" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:java="http://www.eclipse.org/MoDisco/Java/0.2.incubation/java" name="TargetProject">
  <ownedElements name="targetpackage">
    <ownedElements xsi:type="java:ClassDeclaration" originalCompilationUnit="//@compilationUnits.0" name="TargetClass">
      <modifier visibility="public"/>
      <bodyDeclarations xsi:type="java:MethodDeclaration" originalCompilationUnit="//@compilationUnits.0" name="interval">
        <annotations originalCompilationUnit="//@compilationUnits.0">
          <type type="//@unresolvedItems.0"/>
        </annotations>
        <annotations originalCompilationUnit="//@compilationUnits.0">
          <type type="//@unresolvedItems.1"/>
        </annotations>
        <modifier visibility="public" static="true"/>
        <body originalCompilationUnit="//@compilationUnits.0">
          <statements xsi:ty

In [197]:
# Token length of the top beam, including the leading decoder start token.
print(outputs[0].shape)

torch.Size([513])
