In [None]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, RobertaModel, RobertaTokenizer, AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch so DataLoader shuffling and any random init are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Teacher: GPT-2, used only for inference to produce target tokens.
gpt2_model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(device)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
# GPT-2 ships without a padding token, and batched tokenization with padding=True
# raises without one. Reuse EOS (padded positions are excluded via attention_mask).
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

In [None]:
# Student: roberta-base with an LM head. is_decoder=True configures it as a decoder
# (causal self-attention).
# NOTE(review): for bidirectional masked-token prediction, AutoModelForMaskedLM would be
# the usual choice — confirm the causal setup is intended.
# Moved to `device` because collate_fn places every batch tensor there.
bert_model = AutoModelForCausalLM.from_pretrained("FacebookAI/roberta-base", is_decoder=True).to(device)
bert_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")

In [None]:
# Token-string -> id vocabularies: vocab1 for the GPT-2 teacher, vocab2 for the RoBERTa student.
vocab1, vocab2 = gpt2_tokenizer.get_vocab(), bert_tokenizer.get_vocab()

In [None]:
# Token strings present in both vocabularies (dict key views support `&` and yield a set).
common_tokens = vocab1.keys() & vocab2.keys()

In [None]:
# Sanity-check the shared vocabulary that id_mapping is built from.
vocab1_keys = set(vocab1)
vocab2_keys = set(vocab2)
intersection = vocab1_keys & vocab2_keys
# These assertions can actually fail, unlike the size bound printed below,
# which holds for any set intersection.
assert intersection, "GPT-2 and RoBERTa vocabularies share no tokens"
assert intersection == set(common_tokens), "common_tokens is out of sync with the tokenizer vocabularies"
print(len(intersection) <= min(len(vocab1_keys), len(vocab2_keys)))  # Must be True

True


In [None]:
# GPT-2 token id -> RoBERTa token id, for every token string both vocabularies contain.
id_mapping = {vocab1[tok]: vocab2[tok] for tok in common_tokens}

In [None]:
# TODO: replace with the real training corpus. Any map-style dataset whose items are
# mappings with a "text" string works — that is the only field collate_fn reads.
dataset = [
    {"text": "The quick brown fox jumps over the lazy dog."},
    {"text": "Knowledge distillation transfers predictions from a teacher model to a student model."},
]

In [None]:
# итеративно
def collate_fn(batch, max_length=64, window_size=3):
    texts = [item["text"] for item in batch]
    gpt_inputs = gpt2_tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    bert_inputs = bert_tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

    input_ids = bert_inputs["input_ids"]
    attention_mask = bert_inputs["attention_mask"]

    all_bert_inputs = []
    all_labels = []

    for i in range(input_ids.size(0)):
        non_pad = (input_ids[i] != bert_tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
        seq_len = len(non_pad)

        # несколько маскирований, передвигая окно
        for pos in range(seq_len - window_size + 1):
            masked_input = input_ids[i].clone()
            labels = torch.full_like(masked_input, -100)

            # маскирование текущего окна
            mask_indices = non_pad[pos:pos+window_size]
            labels[mask_indices] = masked_input[mask_indices]
            masked_input[mask_indices] = bert_tokenizer.mask_token_id

            all_bert_inputs.append(masked_input)
            all_labels.append(labels)

    # stack все маскирования
    if len(all_bert_inputs) > 0:
        bert_input_ids = torch.stack(all_bert_inputs)
        bert_labels = torch.stack(all_labels)

        # attention_mask для каждого маскирования
        bert_attention_mask = attention_mask.repeat(len(all_bert_inputs) // attention_mask.size(0), 1)

        # для gpt
        gpt_input_ids = gpt_inputs["input_ids"].repeat(len(all_bert_inputs) // gpt_inputs["input_ids"].size(0), 1)
        gpt_attention_mask = gpt_inputs["attention_mask"].repeat(len(all_bert_inputs) // gpt_inputs["attention_mask"].size(0), 1)

    else:
        # если последовательность меньше окна
        bert_input_ids = input_ids
        bert_labels = torch.full_like(input_ids, -100)
        bert_attention_mask = attention_mask
        gpt_input_ids = gpt_inputs["input_ids"]
        gpt_attention_mask = gpt_inputs["attention_mask"]

    return {
        "gpt_input_ids": gpt_input_ids.to(device),
        "gpt_attention_mask": gpt_attention_mask.to(device),
        "bert_input_ids": bert_input_ids.to(device),
        "bert_attention_mask": bert_attention_mask.to(device),
        "bert_labels": bert_labels.to(device)
    }

In [None]:
# collate_fn expands every batch of texts into one masked copy per sliding window.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)
# Only the student's parameters are optimised.
optimizer = torch.optim.AdamW(params=bert_model.parameters(), lr=2e-5)
# Targets equal to -100 contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

In [None]:
# Lookup table GPT-2 id -> RoBERTa id built from id_mapping. Teacher predictions with no
# RoBERTa counterpart map to -100, so they are dropped from the loss.
gpt_to_roberta = torch.full((gpt2_model.config.vocab_size,), -100, dtype=torch.long, device=device)
gpt_to_roberta[torch.tensor(list(id_mapping.keys()), device=device)] = torch.tensor(
    list(id_mapping.values()), device=device
)
roberta_special_ids = torch.tensor(bert_tokenizer.all_special_ids, device=device)

gpt2_model.eval()  # the teacher is never trained; keep dropout off

for epoch in range(3):
    bert_model.train()
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    for batch in pbar:
        labels = batch["bert_labels"]

        with torch.no_grad():
            gpt_logits = gpt2_model(input_ids=batch["gpt_input_ids"], attention_mask=batch["gpt_attention_mask"]).logits
            # GPT-2 logits at index j predict token j+1; translate the argmax to RoBERTa ids.
            teacher_ids = gpt_to_roberta[gpt_logits.argmax(dim=-1)]

        # Align teacher predictions with RoBERTa positions. ASSUMPTION (confirm for your data):
        # both tokenizers split the text identically (RoBERTa reuses GPT-2's byte-level BPE)
        # and RoBERTa only prepends <s>. RoBERTa index p then holds GPT-2 token p-1, which
        # GPT-2 predicts at index p-2.
        targets = torch.full_like(labels, -100)
        n_aligned = max(0, min(labels.size(1) - 2, teacher_ids.size(1)))
        targets[:, 2:2 + n_aligned] = teacher_ids[:, :n_aligned]
        # Distil only at masked, non-special positions.
        distil_mask = (labels != -100) & ~torch.isin(labels, roberta_special_ids)
        targets[~distil_mask] = -100

        valid = targets != -100
        if not valid.any():
            # Nothing to learn from (e.g. every text shorter than the window); a loss over
            # zero targets would be NaN.
            continue

        # The model returns a ModelOutput; the logits are over RoBERTa's own vocabulary.
        bert_logits = bert_model(input_ids=batch["bert_input_ids"], attention_mask=batch["bert_attention_mask"]).logits
        loss = criterion(bert_logits[valid], targets[valid])

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_postfix({"loss": loss.item()})

In [None]:
# save_pretrained("") fails (os.makedirs("") raises), so write to a named directory.
# The tokenizer is saved alongside so the checkpoint can be reloaded on its own.
OUTPUT_DIR = "roberta-distilled-from-gpt2"
bert_model.save_pretrained(OUTPUT_DIR)
bert_tokenizer.save_pretrained(OUTPUT_DIR)