In [1]:
import os
import json
import torch
import lightning as L
import datasets

from typing import Union, Any
from transformers import get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from torch.utils.data import default_collate, Dataset, DataLoader
from torch.optim import AdamW
from datasets import load_dataset
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict
)
from pathlib import Path

In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=1)

In [3]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

In [4]:
tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True)

In [5]:
tokenizer("time flies like an arrow", return_tensors="pt")["input_ids"][0].dtype

torch.int64

In [6]:
data = torch.randint_like(torch.eye(8, 1024), 50257, dtype=torch.int64, device=device)
data

tensor([[38012, 44670, 11951,  ..., 48251, 26915, 43453],
        [36739, 31031, 25731,  ..., 29147,  1624, 25526],
        [20617, 35459, 44564,  ..., 13836,  4164,   934],
        ...,
        [16699, 27205, 13850,  ..., 25998, 38129, 40700],
        [32565, 30652,  4647,  ...,  8812, 15416, 42065],
        [12800, 46815, 18155,  ..., 42894, 34912, 17417]], device='cuda:1')

In [7]:
model.train()
model(data)


CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[-104.5942, -104.3012, -112.0728,  ..., -117.4850, -114.3545,
          -108.7051],
         [ -70.5815,  -71.6304,  -76.4634,  ...,  -80.6910,  -79.7442,
           -72.1127],
         [ -63.2612,  -64.2107,  -65.7738,  ...,  -74.2361,  -71.4001,
           -64.3803],
         ...,
         [ -41.8209,  -41.5882,  -43.1094,  ...,  -43.7227,  -45.6557,
           -41.3914],
         [ -68.6626,  -67.7986,  -69.6115,  ...,  -70.9322,  -73.2276,
           -67.2836],
         [ -80.7582,  -80.3847,  -81.6098,  ...,  -82.9875,  -83.4734,
           -79.7137]],

        [[ -82.0731,  -78.5272,  -86.0155,  ...,  -91.4481,  -88.2522,
           -82.8735],
         [ -68.9144,  -69.6701,  -73.5118,  ...,  -75.1127,  -77.5842,
           -69.7680],
         [ -85.4981,  -84.7686,  -88.4837,  ...,  -94.1243,  -92.7000,
           -85.4709],
         ...,
         [ -82.5117,  -82.8144,  -82.6691,  ...,  -88.4224,  -86.8283,
          