In [80]:
import pickle
import warnings
from tqdm import tqdm

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F

with open(f"data/text_data/train_text_dataset.pkl", "rb") as file:
    train_data = pickle.load(file)

In [21]:
from transformers import XGLMTokenizer, XGLMForCausalLM

tokenizer = XGLMTokenizer.from_pretrained("facebook/xglm-564M")
model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")

# data_samples = {
#     'el': [
#         {
#             "premise": "Σαν επίθετο ελληνικά/ελληνικός σημαίνει από, ή σχετικά με την Ελλάδα, το λαό της, ή την κουλτούρα της.",
#             "choice1": "Αρχαία ελληνική λογοτεχνία Ελληνική γλώσσα Ελληνική μυθολογία Ελληνική φιλοσοφία Ελληνικό αλφάβητο Ελληνικός καφές Ελληνικός πολιτισμός Ελληνική: Πρωινή εφημερίδα Αθηνών από το 1925.",
#         }
#     ]
# }

# prompt = data_samples['el'][0]['premise']

# inputs = tokenizer(prompt, return_tensors="pt")
# input_ids, output_ids = inputs["input_ids"], inputs["input_ids"][:, 1:]
# outputs = model(**inputs, labels=input_ids)
# logits = outputs.logits
# logprobs = torch.gather(F.log_softmax(logits, dim=2), 2, output_ids.unsqueeze(2))

In [71]:
toknized_text = tokenizer(train_data[0], return_tensors="pt")
input_ids, output_ids = toknized_text["input_ids"], toknized_text["input_ids"][:, 1:]
outputs = model(**toknized_text, labels=input_ids)
logits = outputs.logits

In [75]:
F.log_softmax(logits, dim=2)

tensor([[[-29.7202, -29.8207,  -3.3591,  ..., -29.4810, -29.9216, -29.5775],
         [-37.7262, -38.5448,  -9.7600,  ..., -37.8513, -38.6274, -38.3732],
         [-40.8652, -41.5271,  -9.3609,  ..., -40.8041, -41.8329, -41.0001],
         ...,
         [-39.2513, -39.8346,  -6.9689,  ..., -38.9556, -40.3755, -39.1697],
         [-52.7946, -53.6060,  -7.2909,  ..., -52.4221, -54.0735, -52.9135],
         [-33.3091, -34.0045,  -5.1757,  ..., -33.3538, -34.3533, -33.3771]]],
       grad_fn=<LogSoftmaxBackward0>)

In [73]:
logits.view(1, logits.size(2), logits.size(1))[:, :, 1:]

tensor([[[-2.7024, 23.7592, 17.5587,  ..., 18.5534, 18.1724, 13.6467],
         [12.0696, 13.5225, 18.1403,  ..., 21.0680, 13.8093, 16.1025],
         [16.4932,  9.5688, 18.4145,  ..., 20.8660, 10.4263, 20.6164],
         ...,
         [-0.4551,  3.1389,  9.3599,  ...,  7.9678,  4.7735,  7.7314],
         [-1.0877, 10.4731,  6.8372,  ...,  9.9263, 15.2117, 11.8275],
         [14.5214, 11.4354,  3.3154,  ..., -0.2465, -1.2460, -0.2698]]],
       grad_fn=<SliceBackward0>)

In [77]:
F.cross_entropy(
    logits.view(1, logits.size(2), logits.size(1))[:, :, 1:], output_ids
).item()

37.36952209472656

In [95]:
cross_entropy = 0
total_preds = 0
log_prob_sum = 0
bar = tqdm(train_data, bar_format="{l_bar}{bar:15}{r_bar}{bar:-15b}")
for text in bar:
    toknized_text = tokenizer(text, return_tensors="pt")

    if toknized_text["input_ids"].shape[1] < model.config.max_position_embeddings:
        input_ids, output_ids = (
            toknized_text["input_ids"],
            toknized_text["input_ids"][:, 1:],
        )

        outputs = model(**toknized_text, labels=input_ids)
        logits = outputs.logits

        total_preds += output_ids.shape[1]
        cross_entropy += F.cross_entropy(
            logits.view(1, logits.size(2), logits.size(1))[:, :, 1:], output_ids
        )

        log_prob_sum += torch.gather(
            F.softmax(logits, 2), 2, output_ids.unsqueeze(2)
        ).sum()
    else:
        diff = (
            toknized_text["input_ids"].shape[1] - model.config.max_position_embeddings
        )
        for i in range(toknized_text["input_ids"].shape[1] - diff):
            input_ids = toknized_text["input_ids"][
                :, i : i + model.config.max_position_embeddings
            ]
            output_ids = toknized_text["input_ids"][
                :, i + 1 : i + model.config.max_position_embeddings + 1
            ]

            attention_mask = toknized_text["attention_mask"][
                :, i : i + model.config.max_position_embeddings
            ]

            outputs = model(input_ids, attention_mask, labels=input_ids)
            logits = outputs.logits

            total_preds += output_ids.shape[1]
            cross_entropy += F.cross_entropy(
                logits.view(1, logits.size(2), logits.size(1))[:, :, 1:], output_ids
            )

            log_prob_sum += torch.gather(
                F.softmax(logits, 2), 2, output_ids.unsqueeze(2)
            ).sum()

    ce_string = f"Mean Cross Entropy: {round((cross_entropy/total_preds).item(), 3)}"
    ppl_string = f"Mean PPL: {round(torch.exp(-log_prob_sum/total_preds).item(), 3)}"

    bar.set_description(ce_string + " | " + ppl_string)

Mean Cross Entropy: 0.498 | Mean PPL: 0.784:   0%|               | 1/181467 [00:19<967:27:39, 19.19s/it]                                             


RuntimeError: Expected target size [1, 2047], got [1, 2048]

In [93]:
toknized_text["attention_mask"].shape

torch.Size([1, 2281])

In [96]:
i

0