In [3]:
import torch
import argparse
import os
from pytorch_lightning import Trainer, seed_everything
from transformers import AutoTokenizer
from customized import CustomizedGPTJForCausalLM
from data import SoftPromptDataModule
from module import GPTJModule
from pytorch_lightning.loggers import TensorBoardLogger
import json
from utils import TrainArgs

In [4]:
# Run configuration for this notebook. Later code reads args.precision,
# args.model_name_or_path and args.output_path from this object.
# NOTE(review): the meaning of the positional arguments ("Emb", 27, [1]) is
# defined by utils.TrainArgs and is not visible here — confirm against it.
args = TrainArgs("Emb", 27, [1])

In [5]:
# Maps each supported value of args.precision to a pair:
#   (torch dtype used when loading the model weights, training precision setting).
# NOTE(review): the second element looks like a Lightning Trainer `precision`
# value; its consumer is outside this section — confirm.
torch_dtypes = {
    "fp32": (torch.float32, 32),
    "fp16": (torch.float16, "16-mixed"),
    "bf16": (torch.bfloat16, "bf16-mixed"),
}

# Allow reduced-precision float32 matmuls on the listed GPUs.
# get_device_name must be *called*: comparing the function object itself
# against the list is always False, so the setting would never be applied.
# The is_available() guard keeps this check from raising on a CPU-only host.
if torch.cuda.is_available() and torch.cuda.get_device_name() in [
    "NVIDIA H100 PCIe",
    "NVIDIA A100 80GB PCIe",
]:
    torch.set_float32_matmul_precision('medium')

# Raises KeyError if args.precision is not one of the keys above.
torch_dtype, training_precision = torch_dtypes[args.precision]

# Refuse to continue for any checkpoint whose name/path does not contain
# the substring "gpt-j-6b" (case-sensitive).
if "gpt-j-6b" not in args.model_name_or_path:
    raise ValueError("This script is only for GPT-J-6B.")

# TensorBoard logger rooted at args.output_path, using the name "logs".
logger = TensorBoardLogger(
    save_dir=args.output_path,
    name="logs",
)

# Load the weights in the dtype selected above, then move the model to the GPU.
model = CustomizedGPTJForCausalLM.from_pretrained(
    args.model_name_or_path,
    torch_dtype=torch_dtype,
)
model = model.cuda()

# Tokenizer is configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path,
    padding_side="left",
)

# If the tokenizer ships without a pad token, register one and resize the
# model's token embeddings to the new vocabulary size.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens(dict(pad_token='<|PAD|>'))
    model.resize_token_embeddings(len(tokenizer))

Some weights of the model checkpoint at EleutherAI/gpt-j-6b were not used when initializing CustomizedGPTJForCausalLM: ['transformer.h.10.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.9.attn.masked_bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.23.attn.bias', 'transformer.h.26.attn.bias', 'transformer.h.22.attn.bias', 'transformer.h.2.attn.bias', 'transformer.h.11.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.7.attn.bias', 'transformer.h.22.attn.masked_bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.6.attn.masked_bias', 'transformer.h.24.attn.masked_bias', 'transformer.h.6.attn.bias', 'transformer.h.20.attn.bias', 'transformer.h.16.attn.bias', 'transformer.h.12.attn.bias', 'transformer.h.9.attn.bias', 'transformer.h.26.attn.masked_bias', 'transformer.h.5.attn.masked_bias', 'transformer.h.27.