In [1]:
"""
@author: Prakhar
"""
import os
import argparse
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import warnings
warnings.filterwarnings('ignore')

import torch
torch.cuda.empty_cache()

def choose_from_top_k_top_n(probs, k=50, p=0.8):
    ind = np.argpartition(probs, -k)[-k:]
    top_prob = probs[ind]
    top_prob = {i: top_prob[idx] for idx, i in enumerate(ind)}
    sorted_top_prob = {k: v for k, v in sorted(
        top_prob.items(), key=lambda item: item[1], reverse=True)}

    t = 0
    f = []
    pr = []
    for k, v in sorted_top_prob.items():
        t += v
        f.append(k)
        pr.append(v)
        if t >= p:
            break
    top_prob = pr / np.sum(pr)
    token_id = np.random.choice(f, 1, p=top_prob)

    return int(token_id)


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device_str = "cuda" if torch.cuda.is_available() else "cpu"

device = torch.device("cpu")
device_str = "cpu"


def generate(tokenizer, model, sentences, label):
    with torch.no_grad():
        for idx in range(sentences):
            finished = False
            cur_ids = torch.tensor(tokenizer.encode(
                label)).unsqueeze(0).to(device_str)
                
            for i in range(100):
                outputs = model(cur_ids, labels=cur_ids)
                loss, logits = outputs[:2]

                softmax_logits = torch.softmax(logits[0, -1], dim=0)

                if i < 5:
                    n = 10
                else:
                    n = 5

                next_token_id = choose_from_top_k_top_n(softmax_logits.to(device_str).numpy())  # top-k-top-n sampling
                cur_ids = torch.cat([cur_ids, torch.ones((1, 1)).long().to(device) * next_token_id], dim=1)
                

                if next_token_id in tokenizer.encode('<|endoftext|>'):
                    finished = True
                    break

                

            if finished:
                output_list = list(cur_ids.squeeze().to(device_str).numpy())
                output_text = tokenizer.decode(output_list)
                print(output_text)
            else:
                output_list = list(cur_ids.squeeze().to(device_str).numpy())
                output_text = tokenizer.decode(output_list)
                print(output_text)




def load_models(model_name):

    os.environ['CUDA_VISIBLE_DEVICES'] ='0'
    
    """
    Summary:
            Loading the trained model
    """
    print('Loading Trained GPT-2 Model')
    tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')
    model = GPT2LMHeadModel.from_pretrained('distilgpt2')
    model_path = model_name
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device_str)))
    return tokenizer, model








In [2]:
SENTENCES = 5
MODEL_NAME = 'mymodel.pt'
# MODEL_NAME = 'gpt2_medium_spamham_4.pt'
LABEL = 'SPAM'

TOKENIZER, MODEL = load_models(MODEL_NAME)

generate(TOKENIZER, MODEL, SENTENCES, LABEL)


Loading Trained GPT-2 Model
SPAM: Thanks for the chance to interview me <|endoftext|>
SPAM: Japkala is here. Do you have any questions? KABATYN: Thanks for visiting. Send me a call on +51 4727383887 to get an email from japkala <|endoftext|>
SPAM: Sorry, if you are going to a pub you want to get 2. Call ahead on your phone for FREE. For details please call our Phone Call 072-795566 or call 07544666940. We are also open 5th street so call if there is a busy lunchtime, you can call <|endoftext|>
SPAM: We are trying to make sure that you have a quick and enjoyable start to what we want to do with your future. We have been looking at a few things which will help you make your life easier. <|endoftext|>
SPAM: You have a new phone app for your mobile phone <|endoftext|>


In [3]:
SENTENCES = 5
# MODEL_NAME = 'mymodel.pt'
MODEL_NAME = 'gpt2_medium_spamham_4.pt'
LABEL = 'HAM'




TOKENIZER, MODEL = load_models(MODEL_NAME)

generate(TOKENIZER, MODEL, SENTENCES, LABEL)


Loading Trained GPT-2 Model


RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel:
	Missing key(s) in state_dict: "transformer.h.0.attn.masked_bias", "transformer.h.1.attn.masked_bias", "transformer.h.2.attn.masked_bias", "transformer.h.3.attn.masked_bias", "transformer.h.4.attn.masked_bias", "transformer.h.5.attn.masked_bias". 
	Unexpected key(s) in state_dict: "transformer.h.6.ln_1.weight", "transformer.h.6.ln_1.bias", "transformer.h.6.attn.bias", "transformer.h.6.attn.c_attn.weight", "transformer.h.6.attn.c_attn.bias", "transformer.h.6.attn.c_proj.weight", "transformer.h.6.attn.c_proj.bias", "transformer.h.6.ln_2.weight", "transformer.h.6.ln_2.bias", "transformer.h.6.mlp.c_fc.weight", "transformer.h.6.mlp.c_fc.bias", "transformer.h.6.mlp.c_proj.weight", "transformer.h.6.mlp.c_proj.bias", "transformer.h.7.ln_1.weight", "transformer.h.7.ln_1.bias", "transformer.h.7.attn.bias", "transformer.h.7.attn.c_attn.weight", "transformer.h.7.attn.c_attn.bias", "transformer.h.7.attn.c_proj.weight", "transformer.h.7.attn.c_proj.bias", "transformer.h.7.ln_2.weight", "transformer.h.7.ln_2.bias", "transformer.h.7.mlp.c_fc.weight", "transformer.h.7.mlp.c_fc.bias", "transformer.h.7.mlp.c_proj.weight", "transformer.h.7.mlp.c_proj.bias", "transformer.h.8.ln_1.weight", "transformer.h.8.ln_1.bias", "transformer.h.8.attn.bias", "transformer.h.8.attn.c_attn.weight", "transformer.h.8.attn.c_attn.bias", "transformer.h.8.attn.c_proj.weight", "transformer.h.8.attn.c_proj.bias", "transformer.h.8.ln_2.weight", "transformer.h.8.ln_2.bias", "transformer.h.8.mlp.c_fc.weight", "transformer.h.8.mlp.c_fc.bias", "transformer.h.8.mlp.c_proj.weight", "transformer.h.8.mlp.c_proj.bias", "transformer.h.9.ln_1.weight", "transformer.h.9.ln_1.bias", "transformer.h.9.attn.bias", "transformer.h.9.attn.c_attn.weight", "transformer.h.9.attn.c_attn.bias", "transformer.h.9.attn.c_proj.weight", "transformer.h.9.attn.c_proj.bias", "transformer.h.9.ln_2.weight", "transformer.h.9.ln_2.bias", "transformer.h.9.mlp.c_fc.weight", "transformer.h.9.mlp.c_fc.bias", "transformer.h.9.mlp.c_proj.weight", "transformer.h.9.mlp.c_proj.bias", "transformer.h.10.ln_1.weight", "transformer.h.10.ln_1.bias", "transformer.h.10.attn.bias", "transformer.h.10.attn.c_attn.weight", "transformer.h.10.attn.c_attn.bias", "transformer.h.10.attn.c_proj.weight", "transformer.h.10.attn.c_proj.bias", "transformer.h.10.ln_2.weight", "transformer.h.10.ln_2.bias", "transformer.h.10.mlp.c_fc.weight", "transformer.h.10.mlp.c_fc.bias", "transformer.h.10.mlp.c_proj.weight", "transformer.h.10.mlp.c_proj.bias", "transformer.h.11.ln_1.weight", "transformer.h.11.ln_1.bias", "transformer.h.11.attn.bias", "transformer.h.11.attn.c_attn.weight", "transformer.h.11.attn.c_attn.bias", "transformer.h.11.attn.c_proj.weight", "transformer.h.11.attn.c_proj.bias", "transformer.h.11.ln_2.weight", "transformer.h.11.ln_2.bias", "transformer.h.11.mlp.c_fc.weight", "transformer.h.11.mlp.c_fc.bias", "transformer.h.11.mlp.c_proj.weight", "transformer.h.11.mlp.c_proj.bias", "transformer.h.12.ln_1.weight", "transformer.h.12.ln_1.bias", "transformer.h.12.attn.bias", "transformer.h.12.attn.c_attn.weight", "transformer.h.12.attn.c_attn.bias", "transformer.h.12.attn.c_proj.weight", "transformer.h.12.attn.c_proj.bias", "transformer.h.12.ln_2.weight", "transformer.h.12.ln_2.bias", "transformer.h.12.mlp.c_fc.weight", "transformer.h.12.mlp.c_fc.bias", "transformer.h.12.mlp.c_proj.weight", "transformer.h.12.mlp.c_proj.bias", "transformer.h.13.ln_1.weight", "transformer.h.13.ln_1.bias", "transformer.h.13.attn.bias", "transformer.h.13.attn.c_attn.weight", "transformer.h.13.attn.c_attn.bias", "transformer.h.13.attn.c_proj.weight", "transformer.h.13.attn.c_proj.bias", "transformer.h.13.ln_2.weight", "transformer.h.13.ln_2.bias", "transformer.h.13.mlp.c_fc.weight", "transformer.h.13.mlp.c_fc.bias", "transformer.h.13.mlp.c_proj.weight", "transformer.h.13.mlp.c_proj.bias", "transformer.h.14.ln_1.weight", "transformer.h.14.ln_1.bias", "transformer.h.14.attn.bias", "transformer.h.14.attn.c_attn.weight", "transformer.h.14.attn.c_attn.bias", "transformer.h.14.attn.c_proj.weight", "transformer.h.14.attn.c_proj.bias", "transformer.h.14.ln_2.weight", "transformer.h.14.ln_2.bias", "transformer.h.14.mlp.c_fc.weight", "transformer.h.14.mlp.c_fc.bias", "transformer.h.14.mlp.c_proj.weight", "transformer.h.14.mlp.c_proj.bias", "transformer.h.15.ln_1.weight", "transformer.h.15.ln_1.bias", "transformer.h.15.attn.bias", "transformer.h.15.attn.c_attn.weight", "transformer.h.15.attn.c_attn.bias", "transformer.h.15.attn.c_proj.weight", "transformer.h.15.attn.c_proj.bias", "transformer.h.15.ln_2.weight", "transformer.h.15.ln_2.bias", "transformer.h.15.mlp.c_fc.weight", "transformer.h.15.mlp.c_fc.bias", "transformer.h.15.mlp.c_proj.weight", "transformer.h.15.mlp.c_proj.bias", "transformer.h.16.ln_1.weight", "transformer.h.16.ln_1.bias", "transformer.h.16.attn.bias", "transformer.h.16.attn.c_attn.weight", "transformer.h.16.attn.c_attn.bias", "transformer.h.16.attn.c_proj.weight", "transformer.h.16.attn.c_proj.bias", "transformer.h.16.ln_2.weight", "transformer.h.16.ln_2.bias", "transformer.h.16.mlp.c_fc.weight", "transformer.h.16.mlp.c_fc.bias", "transformer.h.16.mlp.c_proj.weight", "transformer.h.16.mlp.c_proj.bias", "transformer.h.17.ln_1.weight", "transformer.h.17.ln_1.bias", "transformer.h.17.attn.bias", "transformer.h.17.attn.c_attn.weight", "transformer.h.17.attn.c_attn.bias", "transformer.h.17.attn.c_proj.weight", "transformer.h.17.attn.c_proj.bias", "transformer.h.17.ln_2.weight", "transformer.h.17.ln_2.bias", "transformer.h.17.mlp.c_fc.weight", "transformer.h.17.mlp.c_fc.bias", "transformer.h.17.mlp.c_proj.weight", "transformer.h.17.mlp.c_proj.bias", "transformer.h.18.ln_1.weight", "transformer.h.18.ln_1.bias", "transformer.h.18.attn.bias", "transformer.h.18.attn.c_attn.weight", "transformer.h.18.attn.c_attn.bias", "transformer.h.18.attn.c_proj.weight", "transformer.h.18.attn.c_proj.bias", "transformer.h.18.ln_2.weight", "transformer.h.18.ln_2.bias", "transformer.h.18.mlp.c_fc.weight", "transformer.h.18.mlp.c_fc.bias", "transformer.h.18.mlp.c_proj.weight", "transformer.h.18.mlp.c_proj.bias", "transformer.h.19.ln_1.weight", "transformer.h.19.ln_1.bias", "transformer.h.19.attn.bias", "transformer.h.19.attn.c_attn.weight", "transformer.h.19.attn.c_attn.bias", "transformer.h.19.attn.c_proj.weight", "transformer.h.19.attn.c_proj.bias", "transformer.h.19.ln_2.weight", "transformer.h.19.ln_2.bias", "transformer.h.19.mlp.c_fc.weight", "transformer.h.19.mlp.c_fc.bias", "transformer.h.19.mlp.c_proj.weight", "transformer.h.19.mlp.c_proj.bias", "transformer.h.20.ln_1.weight", "transformer.h.20.ln_1.bias", "transformer.h.20.attn.bias", "transformer.h.20.attn.c_attn.weight", "transformer.h.20.attn.c_attn.bias", "transformer.h.20.attn.c_proj.weight", "transformer.h.20.attn.c_proj.bias", "transformer.h.20.ln_2.weight", "transformer.h.20.ln_2.bias", "transformer.h.20.mlp.c_fc.weight", "transformer.h.20.mlp.c_fc.bias", "transformer.h.20.mlp.c_proj.weight", "transformer.h.20.mlp.c_proj.bias", "transformer.h.21.ln_1.weight", "transformer.h.21.ln_1.bias", "transformer.h.21.attn.bias", "transformer.h.21.attn.c_attn.weight", "transformer.h.21.attn.c_attn.bias", "transformer.h.21.attn.c_proj.weight", "transformer.h.21.attn.c_proj.bias", "transformer.h.21.ln_2.weight", "transformer.h.21.ln_2.bias", "transformer.h.21.mlp.c_fc.weight", "transformer.h.21.mlp.c_fc.bias", "transformer.h.21.mlp.c_proj.weight", "transformer.h.21.mlp.c_proj.bias", "transformer.h.22.ln_1.weight", "transformer.h.22.ln_1.bias", "transformer.h.22.attn.bias", "transformer.h.22.attn.c_attn.weight", "transformer.h.22.attn.c_attn.bias", "transformer.h.22.attn.c_proj.weight", "transformer.h.22.attn.c_proj.bias", "transformer.h.22.ln_2.weight", "transformer.h.22.ln_2.bias", "transformer.h.22.mlp.c_fc.weight", "transformer.h.22.mlp.c_fc.bias", "transformer.h.22.mlp.c_proj.weight", "transformer.h.22.mlp.c_proj.bias", "transformer.h.23.ln_1.weight", "transformer.h.23.ln_1.bias", "transformer.h.23.attn.bias", "transformer.h.23.attn.c_attn.weight", "transformer.h.23.attn.c_attn.bias", "transformer.h.23.attn.c_proj.weight", "transformer.h.23.attn.c_proj.bias", "transformer.h.23.ln_2.weight", "transformer.h.23.ln_2.bias", "transformer.h.23.mlp.c_fc.weight", "transformer.h.23.mlp.c_fc.bias", "transformer.h.23.mlp.c_proj.weight", "transformer.h.23.mlp.c_proj.bias". 
	size mismatch for transformer.wte.weight: copying a param with shape torch.Size([50257, 1024]) from checkpoint, the shape in current model is torch.Size([50257, 768]).
	size mismatch for transformer.wpe.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1024, 768]).
	size mismatch for transformer.h.0.ln_1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.0.ln_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([1024, 3072]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.0.attn.c_attn.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.0.attn.c_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.0.attn.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.0.ln_2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.0.ln_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.0.mlp.c_fc.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for transformer.h.0.mlp.c_fc.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for transformer.h.0.mlp.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.1.ln_1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.1.ln_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.1.attn.c_attn.weight: copying a param with shape torch.Size([1024, 3072]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.1.attn.c_attn.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.1.attn.c_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.1.attn.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.1.ln_2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.1.ln_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.1.mlp.c_fc.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for transformer.h.1.mlp.c_fc.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for transformer.h.1.mlp.c_proj.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for transformer.h.1.mlp.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.2.ln_1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.2.ln_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.2.attn.c_attn.weight: copying a param with shape torch.Size([1024, 3072]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.2.attn.c_attn.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.2.attn.c_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.2.attn.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.2.ln_2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.2.ln_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.2.mlp.c_fc.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for transformer.h.2.mlp.c_fc.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for transformer.h.2.mlp.c_proj.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for transformer.h.2.mlp.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.3.ln_1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.3.ln_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.3.attn.c_attn.weight: copying a param with shape torch.Size([1024, 3072]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.3.attn.c_attn.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.3.attn.c_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.3.attn.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.3.ln_2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.3.ln_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.3.mlp.c_fc.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for transformer.h.3.mlp.c_fc.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for transformer.h.3.mlp.c_proj.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for transformer.h.3.mlp.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.4.ln_1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.4.ln_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.4.attn.c_attn.weight: copying a param with shape torch.Size([1024, 3072]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.4.attn.c_attn.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.4.attn.c_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.4.attn.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.4.ln_2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.4.ln_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.4.mlp.c_fc.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for transformer.h.4.mlp.c_fc.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for transformer.h.4.mlp.c_proj.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for transformer.h.4.mlp.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.5.ln_1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.5.ln_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.5.attn.c_attn.weight: copying a param with shape torch.Size([1024, 3072]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for transformer.h.5.attn.c_attn.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for transformer.h.5.attn.c_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for transformer.h.5.attn.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.5.ln_2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.5.ln_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.h.5.mlp.c_fc.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
	size mismatch for transformer.h.5.mlp.c_fc.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([3072]).
	size mismatch for transformer.h.5.mlp.c_proj.weight: copying a param with shape torch.Size([4096, 1024]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
	size mismatch for transformer.h.5.mlp.c_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.ln_f.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for transformer.ln_f.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for lm_head.weight: copying a param with shape torch.Size([50257, 1024]) from checkpoint, the shape in current model is torch.Size([50257, 768]).