Архитектура модели анализа кода

В данном файле проводится анализ архитектуры модели, токенизатора и подготовка к обучению модели

Импортируем необходимые модули

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
import torch.nn as nn 
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence

from torch.utils.tensorboard import summary, writer, SummaryWriter
from tqdm import tqdm

Устанавливаем SEED

In [3]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

import warnings
warnings.filterwarnings("ignore")

Далее считываем исходный датасет и немного дорабатываем его

In [4]:
code_dataset = pd.read_parquet('code_dataset.parquet', engine='pyarrow')

Импортируем дополнительеные функции

In [5]:
from create_dataset_m2t_tokens import final_preparations

Получаем датасет

In [6]:
code_dataset = final_preparations(code_dataset_copy=code_dataset)

100%|██████████| 280458/280458 [00:00<00:00, 391088.22it/s]
100%|██████████| 280458/280458 [00:00<00:00, 308981.95it/s]
100%|██████████| 280458/280458 [00:00<00:00, 318787.39it/s]
100%|██████████| 280458/280458 [00:00<00:00, 359814.29it/s]
100%|██████████| 280458/280458 [00:00<00:00, 317530.69it/s]
100%|██████████| 280458/280458 [00:00<00:00, 294418.12it/s]
100%|██████████| 280458/280458 [00:01<00:00, 154419.45it/s]
100%|██████████| 280458/280458 [00:02<00:00, 108791.82it/s]


In [7]:
code_dataset.head()

Unnamed: 0,response,focal_method,focal_cls,focal_method_ast,focal_cls_ast,focal_method_info,focal_cls_info,input_string_focal_method,input_string_focal_cls
0,"from microdot import Microdot, Response, abort...","<FUNC_TOKEN> def get(self, key, default=None):...",<CLS_TOKEN> <FUNC_TOKEN>,<AST_TOKEN> Module( body=[ FunctionDef( name='...,<AST_TOKEN>,<INFO_TOKEN>,<INFO_TOKEN>,"<FUNC_TOKEN> def get(self, key, default=None):...",<CLS_TOKEN> <FUNC_TOKEN> <INFO_TOKEN> <AST_TOKEN>
1,"from microdot import Microdot, Response, abort...","<FUNC_TOKEN> def get(self, url_pattern): retur...","<CLS_TOKEN> class Microdot: def route(self, ur...",<AST_TOKEN> Module( body=[ FunctionDef( name='...,<AST_TOKEN> Module( body=[ ClassDef( name='Mic...,<INFO_TOKEN> <DESCRIPTION_TOKEN> Decorator tha...,<INFO_TOKEN> Module( body=[ ClassDef( name='Mi...,"<FUNC_TOKEN> def get(self, url_pattern): retur...","<CLS_TOKEN> class Microdot: def route(self, ur..."
2,"from microdot import Microdot, Response, abort...","<FUNC_TOKEN> def post(self, url_pattern): retu...","<CLS_TOKEN> class Microdot: def route(self, ur...",<AST_TOKEN> Module( body=[ FunctionDef( name='...,<AST_TOKEN> Module( body=[ ClassDef( name='Mic...,<INFO_TOKEN> <DESCRIPTION_TOKEN> Decorator tha...,<INFO_TOKEN> Module( body=[ ClassDef( name='Mi...,"<FUNC_TOKEN> def post(self, url_pattern): retu...","<CLS_TOKEN> class Microdot: def route(self, ur..."
3,"from microdot import Microdot, Response, abort...","<FUNC_TOKEN> def mount(self, subapp, url_prefi...",<CLS_TOKEN> <FUNC_TOKEN>,<AST_TOKEN> Module( body=[ FunctionDef( name='...,<AST_TOKEN>,<INFO_TOKEN> <DESCRIPTION_TOKEN> Mount a sub-a...,<INFO_TOKEN>,"<FUNC_TOKEN> def mount(self, subapp, url_prefi...",<CLS_TOKEN> <FUNC_TOKEN> <INFO_TOKEN> <AST_TOKEN>
4,from pyner.named_entity.corpus import bio2bioe...,<FUNC_TOKEN> def iob2bio(tags): processed_tags...,<CLS_TOKEN> def split_tag(tag: str): if tag in...,<AST_TOKEN> Module( body=[ FunctionDef( name='...,<AST_TOKEN> Module( body=[ FunctionDef( name='...,<INFO_TOKEN> <DESCRIPTION_TOKEN> should be bio...,<INFO_TOKEN> Module( body=[ FunctionDef( name=...,<FUNC_TOKEN> def iob2bio(tags): processed_tags...,<CLS_TOKEN> def split_tag(tag: str): if tag in...


In [8]:
print(code_dataset['input_string_focal_method'].values[0])

<FUNC_TOKEN> def get(self, key, default=None): kl = key.lower() return super().get(self.keymap.get(kl, kl), default) <INFO_TOKEN> <AST_TOKEN> Module( body=[ FunctionDef( name='get', args=arguments( posonlyargs=[], args=[ arg(arg='self'), arg(arg='key'), arg(arg='default')], kwonlyargs=[], kw_defaults=[], defaults=[ Constant(value=None)]), body=[ Assign( targets=[ Name(id='kl', ctx=Store())], value=Call( func=Attribute( value=Name(id='key', ctx=Load()), attr='lower', ctx=Load()), args=[], keywords=[])), Return( value=Call( func=Attribute( value=Call( func=Name(id='super', ctx=Load()), args=[], keywords=[]), attr='get', ctx=Load()), args=[ Call( func=Attribute( value=Attribute( value=Name(id='self', ctx=Load()), attr='keymap', ctx=Load()), attr='get', ctx=Load()), args=[ Name(id='kl', ctx=Load()), Name(id='kl', ctx=Load())], keywords=[]), Name(id='default', ctx=Load())], keywords=[]))], decorator_list=[], type_params=[])], type_ignores=[])


Наконец, переходим к анализу архитектур нейросетей

Решено использовать подход, основанный на обучении (fine-tuning) нейросети CodeBERT, в основе которой лежит модель RoBERTa. Далее будем использовать метамодель в виде декодера (CodeGen или GPTBigCode)

In [9]:
from transformers import AutoTokenizer, AutoModel

Device:

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


Токенизаторы:

In [11]:
tokenizer_code_bert = AutoTokenizer.from_pretrained("microsoft/codebert-base")
tokenizerGPT = AutoTokenizer.from_pretrained("gpt2")
tokenizerGPT.add_special_tokens({'pad_token': '<PAD>'})

1

Посмотрим как работает базовый токенизатор для CodeBERT

Перед этим добавим новые служебные токены:

In [12]:
new_special_tokens = ['<FUNC_TOKEN>',
            '<INFO_TOKEN>',
            '<CLS_TOKEN>', 
            '<AST_TOKEN>', 
            '<DESCRIPTION_TOKEN>',
            '<COMMENTS_TOKEN>']

special_tokens_dict = {
    'additional_special_tokens': new_special_tokens
}

tokenizer_code_bert.add_special_tokens(special_tokens_dict)
# model_code_bert.resize_token_embeddings(len(tokenizer_code_bert))

6

In [13]:
def tokenization_example(input_str: str):
	'''Функция отображения результатов токенизации'''
	code_bert_tokens_example = tokenizer_code_bert.tokenize(input_str)
	code_bert_tokens_ids = tokenizer_code_bert.convert_tokens_to_ids(code_bert_tokens_example)
	code_bert_decoded = tokenizer_code_bert.decode(code_bert_tokens_ids)
	print(f"Длина закодированной последовательности: {len(code_bert_tokens_example)}")
	print(f"Как выглядят токены исходной фразы: {code_bert_tokens_example}")
	print(f"Индексы токенов: {code_bert_tokens_ids}")
	print(f"Декодированная строка: {code_bert_decoded}")

tokenization_example(code_dataset['input_string_focal_method'].values[0])

Длина закодированной последовательности: 325
Как выглядят токены исходной фразы: ['<FUNC_TOKEN>', 'Ġdef', 'Ġget', '(', 'self', ',', 'Ġkey', ',', 'Ġdefault', '=', 'None', '):', 'Ġk', 'l', 'Ġ=', 'Ġkey', '.', 'lower', '()', 'Ġreturn', 'Ġsuper', '().', 'get', '(', 'self', '.', 'key', 'map', '.', 'get', '(', 'kl', ',', 'Ġk', 'l', '),', 'Ġdefault', ')', 'Ġ', '<INFO_TOKEN>', 'Ġ', '<AST_TOKEN>', 'ĠModule', '(', 'Ġbody', '=[', 'ĠFunction', 'Def', '(', 'Ġname', "='", 'get', "',", 'Ġargs', '=', 'arg', 'uments', '(', 'Ġpos', 'only', 'args', '=[', '],', 'Ġargs', '=[', 'Ġarg', '(', 'arg', "='", 'self', "'),", 'Ġarg', '(', 'arg', "='", 'key', "'),", 'Ġarg', '(', 'arg', "='", 'default', "')", '],', 'Ġk', 'w', 'only', 'args', '=[', '],', 'Ġk', 'w', '_', 'default', 's', '=[', '],', 'Ġdefaults', '=[', 'ĠConstant', '(', 'value', '=', 'None', ')]', '),', 'Ġbody', '=[', 'ĠAss', 'ign', '(', 'Ġtargets', '=[', 'ĠName', '(', 'id', "='", 'kl', "',", 'Ġc', 'tx', '=', 'Store', '())', '],', 'Ġvalue', '=', 'Call', '

Далее необхоимо описать класс Dataset для нашей модели

In [14]:
class Code2TestDataset(Dataset):
	'''Класс датасет для задачи генерации тестов'''

	def __init__(self, code_dataset, tokenizer_code_bert, tokenizer_gpt, max_length=512):
		'''
		Конструктор датасета

		Параметры:
		- code_dataset: датасет pd.DataFrame
		- tokenizer_code_bert: токенизатор code_bert
		- tokenizer_gpt: токенизатор gpt
		- max_length: максимальная длина последовательности (default: 512)
		'''
		self.code_dataset = code_dataset
		self.tokenizer_code_bert = tokenizer_code_bert
		self.tokenizer_gpt = tokenizer_gpt
		self.max_length = max_length

	def __getitem__(self, idx, idx_to_token=False):
		'''
		Get-метод - возвращает сэмпл по индексу

		Параметры:
		- idx: индекс
		- idx_to_token: флаг для отображения токенов из индексов (default: False)
		'''
		focal_method_input = self.code_dataset.at[idx, 'input_string_focal_method']
		focal_cls_input = self.code_dataset.at[idx, 'input_string_focal_cls']
		response = self.code_dataset.at[idx, 'response']

		def encode_text(text, tokenizer):
			encoding = tokenizer.encode_plus(
				text,
				add_special_tokens=True,
				max_length=self.max_length if tokenizer == self.tokenizer_code_bert else self.max_length * 2,
				padding='max_length',
				truncation=True,
				return_attention_mask=True,
				return_tensors='pt',
			)
			input_ids = encoding['input_ids'].flatten()
			attention_mask = encoding['attention_mask'].flatten()
			return input_ids, attention_mask

		input_ids_focal_method, attention_mask_focal_method = encode_text(focal_method_input, self.tokenizer_code_bert)
		input_ids_focal_cls, attention_mask_focal_cls = encode_text(focal_cls_input, self.tokenizer_code_bert)
		input_ids_response, attention_mask_response = encode_text(response, self.tokenizer_gpt)

		if idx_to_token:
			return {
				'input_ids_focal_method': self.tokenizer_code_bert.convert_ids_to_tokens(input_ids_focal_method),
				'attention_mask_focal_method': attention_mask_focal_method,
				'input_ids_focal_cls': self.tokenizer_code_bert.convert_ids_to_tokens(input_ids_focal_cls),
				'attention_mask_focal_cls': attention_mask_focal_cls,
				'ids_response': self.tokenizer_gpt.convert_ids_to_tokens(input_ids_response),
				'attention_mask_response': attention_mask_response
			}
		return {
			'input_ids_focal_method': input_ids_focal_method,
			'attention_mask_focal_method': attention_mask_focal_method,
			'input_ids_focal_cls': input_ids_focal_cls,
			'attention_mask_focal_cls': attention_mask_focal_cls,
			'ids_response': input_ids_response,
			'attention_mask_response': attention_mask_response
		}
	
	def __len__(self):
		'''Функция возвращает длину датасета. В качестве длины берется размер датасета по axis = 0'''
		return self.code_dataset.shape[0]


Тестируем написанный класс

In [15]:
code2test_dataset = Code2TestDataset(code_dataset=code_dataset,
                                     tokenizer_code_bert=tokenizer_code_bert,
                                     tokenizer_gpt=tokenizerGPT)

In [16]:
print(code2test_dataset.__getitem__(490, idx_to_token=True)['input_ids_focal_method'])

['<s>', '<FUNC_TOKEN>', 'Ġdef', 'Ġarray', '_', 'to', '_', 's', 'log', '(', 'x', ':', 'ĠArray', ')', 'Ġ->', 'ĠSL', 'Array', ':', 'Ġreturn', 'Ġ(', 'j', 'np', '.', 'sign', '(', 'x', '),', 'Ġj', 'np', '.', 'log', '(', 'j', 'np', '.', 'abs', '(', 'x', ')))', 'Ġ', '<INFO_TOKEN>', 'Ġ', '<DESCRIPTION_TOKEN>', 'ĠCon', 'verts', 'Ġa', 'Ġregular', 'Ġarray', 'Ġinto', 'Ġ(', 'sign', ',', 'Ġlog', 'abs', ')', 'Ġform', '.', 'ĠAr', 'gs', ':', 'Ġx', 'Ġ(', 'Array', '):', 'Ġinput', 'Ġdata', '.', 'ĠReturns', ':', 'Ġ(', 'SL', 'Array', '):', 'Ġdata', 'Ġin', 'Ġform', 'Ġ(', 'sign', '(', 'x', '),', 'Ġlog', '(', 'abs', '(', 'x', ')))', 'Ġ', '<COMMENTS_TOKEN>', 'Ġ', '<AST_TOKEN>', 'ĠModule', '(', 'Ġbody', '=[', 'ĠFunction', 'Def', '(', 'Ġname', "='", 'array', '_', 'to', '_', 's', 'log', "',", 'Ġargs', '=', 'arg', 'uments', '(', 'Ġpos', 'only', 'args', '=[', '],', 'Ġargs', '=[', 'Ġarg', '(', 'Ġarg', "='", 'x', "',", 'Ġannotation', '=', 'Name', '(', 'id', "='", 'Array', "',", 'Ġc', 'tx', '=', 'Load', '()', '))', '],'

In [17]:
print(code2test_dataset.__getitem__(490, idx_to_token=True)['input_ids_focal_cls'])

['<s>', '<CLS_TOKEN>', 'Ġfrom', 'Ġ.', 'ty', 'ping', 'Ġimport', 'ĠArray', ',', 'ĠSL', 'Array', ',', 'ĠArray', 'List', ',', 'ĠSL', 'Array', 'List', 'import', 'Ġj', 'ax', '.', 'n', 'umpy', 'Ġas', 'Ġj', 'np', 'Ġ', '<FUNC_TOKEN>', 'Ġ', '<INFO_TOKEN>', 'ĠModule', '(', 'Ġbody', '=[', 'ĠImport', 'From', '(', 'Ġmodule', "='", 'ty', 'ping', "',", 'Ġnames', '=[', 'Ġalias', '(', 'name', "='", 'Array', "'),", 'Ġalias', '(', 'name', "='", 'SL', 'Array', "'),", 'Ġalias', '(', 'name', "='", 'Array', 'List', "'),", 'Ġalias', '(', 'name', "='", 'SL', 'Array', 'List', "')", '],', 'Ġlevel', '=', '1', '),', 'ĠImport', '(', 'Ġnames', '=[', 'Ġalias', '(', 'name', "='", 'j', 'ax', '.', 'n', 'umpy', "',", 'Ġas', 'name', "='", 'j', 'np', "')", ']),', 'ĠFunction', 'Def', '(', 'Ġname', "='", 'array', '_', 'to', '_', 's', 'log', "',", 'Ġargs', '=', 'arg', 'uments', '(', 'Ġpos', 'only', 'args', '=[', '],', 'Ġargs', '=[', 'Ġarg', '(', 'Ġarg', "='", 'x', "',", 'Ġannotation', '=', 'Name', '(', 'id', "='", 'Array', "',

In [18]:
print(code2test_dataset.__getitem__(490, idx_to_token=True)['ids_response'])

['from', 'Ġtests', '.', 'test', '_', 'utils', 'Ġimport', 'Ġassert', '_', 'py', 'tree', '_', 'all', 'close', 'Ċ', 'import', 'Ġv', 'mc', 'net', '.', 'utils', '.', 's', 'log', '_', 'help', 'ers', 'Ġas', 'Ġhelpers', 'Ċ', 'from', 'Ġtyping', 'Ġimport', 'ĠT', 'uple', 'Ċ', 'from', 'Ġv', 'mc', 'net', '.', 'utils', '.', 'ty', 'ping', 'Ġimport', 'ĠArray', ',', 'ĠSL', 'Array', 'Ċ', 'import', 'Ġj', 'ax', '.', 'n', 'umpy', 'Ġas', 'Ġj', 'np', 'Ċ', 'Ċ', 'def', 'Ġ_', 'get', '_', 'array', '_', 'and', '_', 's', 'log', '_', 'vals', '()', 'Ġ->', 'ĠT', 'uple', '[', 'Array', ',', 'ĠSL', 'Array', ']:', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġv', 'als', 'Ġ=', 'Ġj', 'np', '.', 'array', '([', 'j', 'np', '.', 'e', ',', 'Ġ-', 'j', 'np', '.', 'e', '**', '0', '.', '5', ',', 'Ġ0', ',', 'Ġ1', '])', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġsigns', 'Ġ=', 'Ġj', 'np', '.', 'array', '([', '1', ',', 'Ġ-', '1', ',', 'Ġ0', ',', 'Ġ1', '])', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġlogs', 'Ġ=', 'Ġj', 'np', '.', 'array', '([', '1', ',', 'Ġ0', '.', '5', ',', 'Ġ-', 'j', 'np', '.', 'inf

In [19]:
print(code2test_dataset[490]['attention_mask_response'])

tensor([1, 1, 1,  ..., 0, 0, 0])


In [20]:
print(code2test_dataset.__getitem__(490, idx_to_token=True)['attention_mask_focal_method'])

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [21]:
print(code2test_dataset.__getitem__(490, idx_to_token=True)['attention_mask_focal_cls'])

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [22]:
print(f"Длина датасета составляет: {len(code2test_dataset)}")

Длина датасета составляет: 280458


Всё работает корректно! Следующим шагом необходимо разбить датасет на train и val

In [23]:
def get_datasets(dataset_cls = Code2TestDataset,
				max_length = 512,
				data = code_dataset,
				tokenizer_code_bert = tokenizer_code_bert,
				tokenizer_gpt = tokenizerGPT,
				train_size = 0.7):
	'''
	Функция get_datasets() возвращает train и val датасеты на основе конструктора AccentDataset, делая train_val_spilt
	
	Параметры:
	-dataset_cls: класс датасета, конструктор которого будет вызываться (default: Code2TestDataset)
	-max_length: максимальная статья последовательности токенов
	-data: датасает pd.DataFrame (default: code_dataset)
	-tokenizer: токенизатор (default: tokenizer_code_bert)
	-train_size: размер тренировочной выборки (default: 0.7)
	
	'''
	
	dataset = dataset_cls(code_dataset = data,
					   	tokenizer_code_bert = tokenizer_code_bert,
						tokenizer_gpt=tokenizer_gpt,
						max_length=max_length)
	
	train_size = int(train_size * len(dataset))
	val_size = len(dataset) - train_size
	train_dataset, test_dataset = random_split(dataset, [train_size, val_size])

	return train_dataset, test_dataset

train_dataset, val_dataset = get_datasets()

Проверяем полученные датасеты

In [24]:
print(f"Количество данных в train и val выборках соответственно: {len(train_dataset), len(val_dataset)}")

Количество данных в train и val выборках соответственно: (196320, 84138)


In [25]:
def decode_sequence(tokens_ids, tokenizer):
	'''Декодирование последовательности токенов'''
	code_bert_decoded = tokenizer.decode(tokens_ids)
	print(f"Декодированная строка: {code_bert_decoded}")

Тренировочный сэмпл

In [26]:
decode_sequence(train_dataset[0]['input_ids_focal_method'], tokenizer_code_bert)

Декодированная строка: <s><FUNC_TOKEN> def score_to_label(pred_scores, outliers_fraction=0.1): pred_scores = column_or_1d(pred_scores) check_parameter(outliers_fraction, 0, 1) threshold = percentile(pred_scores, 100 * (1 - outliers_fraction)) pred_labels = (pred_scores > threshold).astype('int') return pred_labels <INFO_TOKEN> <DESCRIPTION_TOKEN> Turn raw outlier outlier scores to binary labels (0 or 1). Parameters ---------- pred_scores : list or numpy array of shape (n_samples,) Raw outlier scores. Outliers are assumed have larger values. outliers_fraction : float in (0,1) Percentage of outliers. Returns ------- outlier_labels : numpy array of shape (n_samples,) For each observation, tells whether or not it should be considered as an outlier according to the fitted model. Return the outlier probability, ranging in [0,1]. <COMMENTS_TOKEN> check input values <AST_TOKEN> Module( body=[ FunctionDef( name='score_to_label', args=arguments( posonlyargs=[], args=[ arg(arg='pred_scores'), arg

In [27]:
decode_sequence(train_dataset[0]['input_ids_focal_cls'], tokenizer_code_bert)

Декодированная строка: <s><CLS_TOKEN> from sklearn.utils import column_or_1dfrom pyod.utils.utility import check_parameterfrom numpy import percentile <FUNC_TOKEN> <INFO_TOKEN> Module( body=[ ImportFrom( module='sklearn.utils', names=[ alias(name='column_or_1d')], level=0), ImportFrom( module='pyod.utils.utility', names=[ alias(name='check_parameter')], level=0), ImportFrom( module='numpy', names=[ alias(name='percentile')], level=0), FunctionDef( name='score_to_label', args=arguments( posonlyargs=[], args=[ arg(arg='pred_scores'), arg(arg='outliers_fraction')], kwonlyargs=[], kw_defaults=[], defaults=[ Constant(value=0.1)]), body=[ Expr( value=Constant(value='Turn raw outlier outlier scores to binary labels (0 or 1).\n Parameters\n ----------\n pred_scores : list or numpy array of shape (n_samples,)\n Raw outlier scores. Outliers are assumed have larger values.\n outliers_fraction : float in (0,1)\n Percentage of outliers.\n Returns\n -------\n outlier_labels : numpy array of shape (n

In [28]:
decode_sequence(train_dataset[0]['ids_response'], tokenizerGPT)

Декодированная строка: import unittest
from numpy.testing import assert_allclose
from utils.utility import score_to_label

class TestMetrics(unittest.TestCase):
    def test_score_to_label(self):
        manual_scores = [0.1, 0.4, 0.2, 0.3, 0.5, 0.9, 0.7, 1, 0.8, 0.6]
        labels = score_to_label(manual_scores, outliers_fraction=0.1)
        assert_allclose(labels, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0])
        labels = score_to_label(manual_scores, outliers_fraction=0.3)
        assert_allclose(labels, [0, 0, 0, 0, 0, 1, 0, 1, 1, 0])
if __name__ == '__main__':
    unittest.main()
<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><

Валидационный сэмпл

In [29]:
decode_sequence(val_dataset[0]['input_ids_focal_method'], tokenizer_code_bert)

Декодированная строка: <s><FUNC_TOKEN> def valid_actions(self): num_raises_so_far = sum([p.raised for p in self.players]) if num_raises_so_far == self.num_players: return ['F', 'C'] else: if self.round == 0: return ['F', 'C', '2R'] else: return ['F', 'C', '4R'] <INFO_TOKEN> <AST_TOKEN> Module( body=[ FunctionDef( name='valid_actions', args=arguments( posonlyargs=[], args=[ arg(arg='self')], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[ Assign( targets=[ Name(id='num_raises_so_far', ctx=Store())], value=Call( func=Name(id='sum', ctx=Load()), args=[ ListComp( elt=Attribute( value=Name(id='p', ctx=Load()), attr='raised', ctx=Load()), generators=[ comprehension( target=Name(id='p', ctx=Store()), iter=Attribute( value=Name(id='self', ctx=Load()), attr='players', ctx=Load()), ifs=[], is_async=0)])], keywords=[])), If( test=Compare( left=Name(id='num_raises_so_far', ctx=Load()), ops=[ Eq()], comparators=[ Attribute( value=Name(id='self', ctx=Load()), attr='num_players', ctx=Load())]), b

In [30]:
decode_sequence(val_dataset[0]['input_ids_focal_cls'], tokenizer_code_bert)

Декодированная строка: <s><CLS_TOKEN> <FUNC_TOKEN> <INFO_TOKEN> <AST_TOKEN></s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [31]:
decode_sequence(val_dataset[0]['ids_response'], tokenizerGPT)

Декодированная строка: from leduc.state import Leduc
from leduc.state import State

def test_valid_actions():
    state = State([1, 2, 3], 2, None)
    actions = state.valid_actions()
    assert actions == ['F', 'C', '1R'], actions
    state.take('C')
    actions = state.valid_actions()
    assert actions == ['F', 'C', '1R'], actions
    state = State([1, 2, 3], 2, None)
    state.take('1R')
    actions = state.valid_actions()
    assert actions == ['F', 'C'], actions
<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><P

Корректно работает!

Далее получим DataLoader, по которому будем итерироваться

In [32]:
def get_loaders(train_dataset = train_dataset,
			val_dataset = val_dataset,
			shuffle_train = True,
			shuffle_val = False,
			batch_size = 32):
	
	'''
	Функция get_loaders() для получения train, val даталоадеров

	Параметры:
	-train_dataset: тренировочный датасет (default: train_dataset)
	-val_dataset: валидационный датасет (default: val_dataset)
	-shuffle_train: флаг перемешивания для train (default: True)
	-shuffle_val: флаг перемешивания для val (default: False)
	-batch_size: размер батча данных (default: 32)
	'''
	
	# train_dataloader
	train_dataloader = DataLoader(
			train_dataset,   
			batch_size = batch_size,
			shuffle = shuffle_train,
		)

	# validation_dataloader
	validation_dataloader = DataLoader(
			val_dataset, 
			batch_size = batch_size,
			shuffle = shuffle_val,
		)
	
	# Возвращаем даталоадеры
	return train_dataloader, validation_dataloader

train_dataloader, validation_dataloader = get_loaders(batch_size=2)

Проверка

In [33]:
decode_sequence(train_dataloader.dataset[0]['input_ids_focal_method'], tokenizer_code_bert)

Декодированная строка: <s><FUNC_TOKEN> def score_to_label(pred_scores, outliers_fraction=0.1): pred_scores = column_or_1d(pred_scores) check_parameter(outliers_fraction, 0, 1) threshold = percentile(pred_scores, 100 * (1 - outliers_fraction)) pred_labels = (pred_scores > threshold).astype('int') return pred_labels <INFO_TOKEN> <DESCRIPTION_TOKEN> Turn raw outlier outlier scores to binary labels (0 or 1). Parameters ---------- pred_scores : list or numpy array of shape (n_samples,) Raw outlier scores. Outliers are assumed have larger values. outliers_fraction : float in (0,1) Percentage of outliers. Returns ------- outlier_labels : numpy array of shape (n_samples,) For each observation, tells whether or not it should be considered as an outlier according to the fitted model. Return the outlier probability, ranging in [0,1]. <COMMENTS_TOKEN> check input values <AST_TOKEN> Module( body=[ FunctionDef( name='score_to_label', args=arguments( posonlyargs=[], args=[ arg(arg='pred_scores'), arg

In [34]:
decode_sequence(train_dataloader.dataset[0]['input_ids_focal_cls'], tokenizer_code_bert)

Декодированная строка: <s><CLS_TOKEN> from sklearn.utils import column_or_1dfrom pyod.utils.utility import check_parameterfrom numpy import percentile <FUNC_TOKEN> <INFO_TOKEN> Module( body=[ ImportFrom( module='sklearn.utils', names=[ alias(name='column_or_1d')], level=0), ImportFrom( module='pyod.utils.utility', names=[ alias(name='check_parameter')], level=0), ImportFrom( module='numpy', names=[ alias(name='percentile')], level=0), FunctionDef( name='score_to_label', args=arguments( posonlyargs=[], args=[ arg(arg='pred_scores'), arg(arg='outliers_fraction')], kwonlyargs=[], kw_defaults=[], defaults=[ Constant(value=0.1)]), body=[ Expr( value=Constant(value='Turn raw outlier outlier scores to binary labels (0 or 1).\n Parameters\n ----------\n pred_scores : list or numpy array of shape (n_samples,)\n Raw outlier scores. Outliers are assumed have larger values.\n outliers_fraction : float in (0,1)\n Percentage of outliers.\n Returns\n -------\n outlier_labels : numpy array of shape (n

In [35]:
decode_sequence(train_dataloader.dataset[0]['ids_response'], tokenizerGPT)

Декодированная строка: import unittest
from numpy.testing import assert_allclose
from utils.utility import score_to_label

class TestMetrics(unittest.TestCase):
    def test_score_to_label(self):
        manual_scores = [0.1, 0.4, 0.2, 0.3, 0.5, 0.9, 0.7, 1, 0.8, 0.6]
        labels = score_to_label(manual_scores, outliers_fraction=0.1)
        assert_allclose(labels, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0])
        labels = score_to_label(manual_scores, outliers_fraction=0.3)
        assert_allclose(labels, [0, 0, 0, 0, 0, 1, 0, 1, 1, 0])
if __name__ == '__main__':
    unittest.main()
<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><

In [36]:
decode_sequence(validation_dataloader.dataset[0]['input_ids_focal_method'], tokenizer_code_bert)

Декодированная строка: <s><FUNC_TOKEN> def valid_actions(self): num_raises_so_far = sum([p.raised for p in self.players]) if num_raises_so_far == self.num_players: return ['F', 'C'] else: if self.round == 0: return ['F', 'C', '2R'] else: return ['F', 'C', '4R'] <INFO_TOKEN> <AST_TOKEN> Module( body=[ FunctionDef( name='valid_actions', args=arguments( posonlyargs=[], args=[ arg(arg='self')], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[ Assign( targets=[ Name(id='num_raises_so_far', ctx=Store())], value=Call( func=Name(id='sum', ctx=Load()), args=[ ListComp( elt=Attribute( value=Name(id='p', ctx=Load()), attr='raised', ctx=Load()), generators=[ comprehension( target=Name(id='p', ctx=Store()), iter=Attribute( value=Name(id='self', ctx=Load()), attr='players', ctx=Load()), ifs=[], is_async=0)])], keywords=[])), If( test=Compare( left=Name(id='num_raises_so_far', ctx=Load()), ops=[ Eq()], comparators=[ Attribute( value=Name(id='self', ctx=Load()), attr='num_players', ctx=Load())]), b

In [37]:
decode_sequence(validation_dataloader.dataset[0]['input_ids_focal_cls'], tokenizer_code_bert)

Декодированная строка: <s><CLS_TOKEN> <FUNC_TOKEN> <INFO_TOKEN> <AST_TOKEN></s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><

In [38]:
decode_sequence(validation_dataloader.dataset[0]['ids_response'], tokenizerGPT)

Декодированная строка: from leduc.state import Leduc
from leduc.state import State

def test_valid_actions():
    state = State([1, 2, 3], 2, None)
    actions = state.valid_actions()
    assert actions == ['F', 'C', '1R'], actions
    state.take('C')
    actions = state.valid_actions()
    assert actions == ['F', 'C', '1R'], actions
    state = State([1, 2, 3], 2, None)
    state.take('1R')
    actions = state.valid_actions()
    assert actions == ['F', 'C'], actions
<PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><P

Проверка итерирования

In [39]:
for i, batch in enumerate(tqdm(train_dataloader)):
    if i == 0:
        break

  0%|          | 0/98160 [00:00<?, ?it/s]


Корректно отрабатывает!

Далее, собираем архитектуру и готовимся обучать

In [40]:
model_code_bert = AutoModel.from_pretrained("microsoft/codebert-base", output_hidden_states= True)
model_code_bert.resize_token_embeddings(len(tokenizer_code_bert))

Embedding(50271, 768, padding_idx=1)

Как работает модель codeBERT:

In [41]:
for i, batch in enumerate(train_dataloader):
	
	# Проверка корректности работы
	b_input_ids = batch['input_ids_focal_method'].to(device)
	b_input_mask = batch['attention_mask_focal_method'].to(device)
	
	outputs_code_bert = model_code_bert(b_input_ids, attention_mask=b_input_mask)
	last_hidden_state_code_bert = outputs_code_bert['last_hidden_state']
	print(last_hidden_state_code_bert.size())
	break

torch.Size([2, 512, 768])


Таким образом, для каждого токена мы получим свое закодированное значение размерности 768

Модель GPT2:

In [42]:
from transformers import AutoConfig

modelGPT2Path = "gpt2"
config = AutoConfig.from_pretrained(modelGPT2Path, is_decoder=True, add_cross_attention= True)
config.add_cross_attention = True  # Включение cross-attention

modelGPT2 = AutoModel.from_pretrained(modelGPT2Path, config=config)
modelGPT2.resize_token_embeddings(len(tokenizerGPT))

Some weights of GPT2Model were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossattentio

Embedding(50258, 768)

Как работает модель GPTBigCode

In [43]:
for i, batch in enumerate(train_dataloader):

	b_input_ids = batch['input_ids_focal_method'].to(device)
	b_input_mask = batch['attention_mask_focal_method'].to(device)
	
	outputs_code_bert = model_code_bert(b_input_ids, attention_mask=b_input_mask)
	last_hidden_state_code_bert = outputs_code_bert['last_hidden_state']

	print(last_hidden_state_code_bert.size())
	
	# Проверка корректности работы
	response_input_ids = batch['ids_response'].to(device)
	response_input_mask = batch['attention_mask_response'].to(device)
	gpt_output = modelGPT2(input_ids=response_input_ids, 
							  attention_mask=response_input_mask, 
							  encoder_hidden_states = last_hidden_state_code_bert)
	print(gpt_output['last_hidden_state'].size())
	
	
	# outputs_code_bert = model_code_bert(b_input_ids, attention_mask=b_input_mask)
	# last_hidden_state_code_bert = outputs_code_bert['last_hidden_state']
	# print(last_hidden_state_code_bert.size())
	break

torch.Size([2, 512, 768])
torch.Size([2, 1024, 768])


Ну, как-то худо-бедно всё это дело запускается. Пробуем строить модель

In [None]:
from transformers import GPT2LMHeadModel

class LargeCodeModel(nn.Module):
	'''Класс для сложной языковой модели, которая обрабатывает входной код'''
	def __init__(self, bert_model_name, gpt2_name):
		super(LargeCodeModel, self).__init__()
		
		self.bert1 = AutoModel.from_pretrained(bert_model_name, output_hidden_states= True)
		self.bert2 = AutoModel.from_pretrained(bert_model_name, output_hidden_states= True)
		self.tokenizer_code_bert = AutoTokenizer.from_pretrained(bert_model_name)

		self.new_special_tokens = ['<FUNC_TOKEN>',
            '<INFO_TOKEN>',
            '<CLS_TOKEN>', 
            '<AST_TOKEN>', 
            '<DESCRIPTION_TOKEN>',
            '<COMMENTS_TOKEN>']

		self.special_tokens_dict = {
			'additional_special_tokens': new_special_tokens
		}

		self.tokenizer_code_bert.add_special_tokens(self.special_tokens_dict)
		self.bert1.resize_token_embeddings(len(self.tokenizer_code_bert))
		self.bert2.resize_token_embeddings(len(self.tokenizer_code_bert))

		self.gpt2_config = AutoConfig.from_pretrained(gpt2_name, is_decoder=True, add_cross_attention= True)
		self.gpt2_config.add_cross_attention = True  # Включение cross-attention
		self.tokenizerGPT = AutoTokenizer.from_pretrained(gpt2_name)
		self.tokenizerGPT.add_special_tokens({'pad_token': '<PAD>'})
		self.gpt2 = GPT2LMHeadModel.from_pretrained(modelGPT2Path, config=config)
		self.gpt2.resize_token_embeddings(len(self.tokenizerGPT))

		self.layer_norm = nn.LayerNorm(self.bert1.config.hidden_size)

		self.projection = nn.Linear(
            self.bert1.config.hidden_size + self.bert2.config.hidden_size,
            self.gpt2.config.hidden_size
        )

	# forward call
	def forward(self, focal_method_input_ids, 
			 			focal_method_attention_masks, 
						focal_cls_input_ids,
						focal_cls_attention_masks,
						response_ids, response_attention_masks):
		
		print(focal_method_input_ids.size())
		print(focal_method_attention_masks.size())
		
		bert1_outputs = self.bert1(focal_method_input_ids, focal_method_attention_masks)
		last_hidden_state_bert1 = bert1_outputs['last_hidden_state']

		bert2_outputs = self.bert2(focal_cls_input_ids, focal_cls_attention_masks)
		last_hidden_state_bert2 = bert2_outputs['last_hidden_state']

		# print(last_hidden_state_bert1.size())
		# print(last_hidden_state_bert2.size())

		concat_hidden_states = torch.cat([last_hidden_state_bert1, last_hidden_state_bert2], dim=1)

		# print(concat_hidden_states.size())

		# LayerNormalization
		normalized_hidden_states = self.layer_norm(concat_hidden_states)

		# Для BatchNorm
		# batch_norm_input = concat_hidden_states.view(-1, 768)
		# normalized_hidden_states = self.batch_norm(batch_norm_input)
		# normalized_hidden_states = normalized_hidden_states.view(2, 1024, 768)
		# print(normalized_hidden_states.size())
		# print(torch.cat([focal_method_attention_masks, focal_cls_attention_masks], dim=1).size())
		# print(response_ids.size())
		# print(response_input_mask.size())

		# print(response_attention_masks.size())
		
		gpt2_outputs = self.gpt2(
            input_ids=response_ids,
            attention_mask=response_attention_masks,
            encoder_hidden_states=normalized_hidden_states,
            encoder_attention_mask=torch.cat([focal_method_attention_masks, focal_cls_attention_masks], dim=1),
			labels=response_ids
        )

		return gpt2_outputs

		

Отлаживаем модель

In [47]:
CodeModel = LargeCodeModel(bert_model_name="microsoft/codebert-base",
                           gpt2_name="gpt2")

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.crossattention.c_attn.bias', 'h.0.crossattention.c_attn.weight', 'h.0.crossattention.c_proj.bias', 'h.0.crossattention.c_proj.weight', 'h.0.crossattention.q_attn.bias', 'h.0.crossattention.q_attn.weight', 'h.0.ln_cross_attn.bias', 'h.0.ln_cross_attn.weight', 'h.1.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.weight', 'h.1.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.weight', 'h.1.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.1.ln_cross_attn.weight', 'h.10.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.10.crossattention.c_proj.bias', 'h.10.crossattention.c_proj.weight', 'h.10.crossattention.q_attn.bias', 'h.10.crossattention.q_attn.weight', 'h.10.ln_cross_attn.bias', 'h.10.ln_cross_attn.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.c_attn.weight', 'h.11.crossat

In [49]:
for i, batch in enumerate(train_dataloader):
    
	focal_method_input_ids = batch['input_ids_focal_method']
	focal_method_attention_masks = batch['attention_mask_focal_method']

	focal_cls_input_ids = batch['input_ids_focal_cls']
	focal_cls_attention_masks = batch['attention_mask_focal_cls']

	response_ids = batch['ids_response']
	response_attention_masks = batch['attention_mask_response']

	output = CodeModel(focal_method_input_ids, focal_method_attention_masks,
						focal_cls_input_ids, focal_cls_attention_masks,
						response_ids, response_attention_masks)
	
	print(output['logits'].size())
	print(output['loss'])
	break

torch.Size([2, 512])
torch.Size([2, 512])
torch.Size([2, 1024, 768])
torch.Size([2, 1024, 768])
torch.Size([2, 1024])
torch.Size([2, 1024, 50258])
tensor(15.0449, grad_fn=<NllLossBackward0>)


Далее необходимо объявить функцию train-val loop