In [19]:
from pydantic import BaseModel
from typing import Optional
import os
import sys
import torch
import random
import argparse
import numpy as np

from GPT2.model import (GPT2LMHeadModel)
from GPT2.utils import load_weight
from GPT2.config import GPT2Config
from GPT2.sample import sample_sequence
from GPT2.encoder import get_encoder



## INITIALISATION

class Args(BaseModel):
    """Generation settings for this notebook, validated by pydantic.

    Only ``text`` is required; every other field has a default. All fields
    are read by the sampling script that follows this class in the cell.
    """
    # Prompt string; it is printed, encoded with the GPT-2 encoder, and used as
    # the sampling context unless `unconditional` is set.
    text: str
    # When False, a "SAMPLE <n>" banner line is printed for each generated sample.
    quiet: Optional[bool] = False
    # Total number of samples to generate; must be a multiple of batch_size
    # (asserted by the script).
    nsamples: Optional[int] = 1
    # When True, no prompt context is passed to sample_sequence; the
    # '<|endoftext|>' token is passed as start_token instead.
    unconditional: Optional[bool] = False
    # Number of samples requested per sample_sequence call.
    batch_size: Optional[int] = 1
    # Requested sample length; -1 is replaced by config.n_ctx // 2, and values
    # above config.n_ctx raise ValueError in the script.
    length: Optional[int] = -1
    # Forwarded unchanged to sample_sequence as `temperature`.
    temperature: Optional[float] = 0.7
    # Forwarded unchanged to sample_sequence as `top_k`.
    # NOTE(review): top_k=1 presumably makes decoding greedy — confirm against
    # GPT2.sample.sample_sequence.
    top_k: Optional[int] = 1

# Generation settings for this run; every field except the prompt keeps its default.
args = Args(text="The capital of spain is ")

# Load the pretrained weights. When CUDA is unavailable the tensors are remapped
# to the CPU; otherwise torch's default device mapping is used.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted checkpoint.
state_dict = torch.load(
    'checkpoints/gpt2-pytorch_model.bin',
    map_location='cpu' if not torch.cuda.is_available() else None
)

# The sampling loop below runs nsamples // batch_size full batches, so the
# sample count must divide evenly by the batch size.
assert args.nsamples % args.batch_size == 0

# Draw one random seed and apply it to numpy and torch (including torch.cuda).
seed = random.randint(0, 2147483647)
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = get_encoder()
config = GPT2Config()
model = GPT2LMHeadModel(config)

# Copy the checkpoint weights into the model, move it to the chosen device and
# switch it to evaluation mode before sampling.
model = load_weight(model, state_dict)
model.to(device)
model.eval()

# A length of -1 means "use half of the model's context window"; anything
# larger than the window is rejected.
if args.length == -1:
    args.length = config.n_ctx // 2
elif args.length > config.n_ctx:
    raise ValueError("Can't get samples longer than window size: %s" % config.n_ctx)

print(args.text)
context_tokens = enc.encode(args.text)
generated = 0
for _ in range(args.nsamples // args.batch_size):
    out = sample_sequence(
        model=model,
        length=args.length,
        context=context_tokens if not args.unconditional else None,
        start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
        batch_size=args.batch_size,
        temperature=args.temperature, top_k=args.top_k, device=device
    )
    # Drop the leading prompt tokens so only the continuation is decoded.
    # NOTE(review): in unconditional mode no prompt is passed to sample_sequence,
    # yet the prompt's token count is still sliced off here — confirm this is intended.
    out = out[:, len(context_tokens):].tolist()
    for i in range(args.batch_size):
        generated += 1
        text = enc.decode(out[i])
        if args.quiet is False:
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
        # Print inside the loop so every sample is shown under its own banner;
        # previously only the last decoded sample was printed, after all loops.
        print(text)







The capital of spain is 


100%|██████████████████████████████████████████████████████████████| 512/512 [00:46<00:00, 10.96it/s]

 the capital of the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city of  the city 




In [None]:
from pydantic import BaseModel
from typing import Optional
import os
import sys
import torch
import random
import argparse
import numpy as np

from GPT2.model import (GPT2LMHeadModel)
from GPT2.utils import load_weight
from GPT2.config import GPT2Config
from GPT2.sample import sample_sequence
from GPT2.encoder import get_encoder



## INITIALISATION

class Args(BaseModel):
    """Generation settings for this notebook, validated by pydantic.

    Only ``text`` is required; every other field has a default. All fields
    are read by the sampling script that follows this class in the cell.
    """
    # Prompt string; it is printed, encoded with the GPT-2 encoder, and used as
    # the sampling context unless `unconditional` is set.
    text: str
    # When False, a "SAMPLE <n>" banner line is printed for each generated sample.
    quiet: Optional[bool] = False
    # Total number of samples to generate; must be a multiple of batch_size
    # (asserted by the script).
    nsamples: Optional[int] = 1
    # When True, no prompt context is passed to sample_sequence; the
    # '<|endoftext|>' token is passed as start_token instead.
    unconditional: Optional[bool] = False
    # Number of samples requested per sample_sequence call.
    batch_size: Optional[int] = 1
    # Requested sample length; -1 is replaced by config.n_ctx // 2, and values
    # above config.n_ctx raise ValueError in the script.
    length: Optional[int] = -1
    # Forwarded unchanged to sample_sequence as `temperature`.
    temperature: Optional[float] = 0.7
    # Forwarded unchanged to sample_sequence as `top_k`.
    # NOTE(review): top_k=1 presumably makes decoding greedy — confirm against
    # GPT2.sample.sample_sequence.
    top_k: Optional[int] = 1

# Generation settings for this run; every field except the prompt keeps its default.
args = Args(text="The capital of spain is ")

# Load the pretrained weights. When CUDA is unavailable the tensors are remapped
# to the CPU; otherwise torch's default device mapping is used.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted checkpoint.
state_dict = torch.load(
    'checkpoints/gpt2-pytorch_model.bin',
    map_location='cpu' if not torch.cuda.is_available() else None
)

# The sampling loop below runs nsamples // batch_size full batches, so the
# sample count must divide evenly by the batch size.
assert args.nsamples % args.batch_size == 0

# Draw one random seed and apply it to numpy and torch (including torch.cuda).
seed = random.randint(0, 2147483647)
np.random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc = get_encoder()
config = GPT2Config()
model = GPT2LMHeadModel(config)

# Copy the checkpoint weights into the model, move it to the chosen device and
# switch it to evaluation mode before sampling.
model = load_weight(model, state_dict)
model.to(device)
model.eval()

# A length of -1 means "use half of the model's context window"; anything
# larger than the window is rejected.
if args.length == -1:
    args.length = config.n_ctx // 2
elif args.length > config.n_ctx:
    raise ValueError("Can't get samples longer than window size: %s" % config.n_ctx)

print(args.text)
context_tokens = enc.encode(args.text)
generated = 0
for _ in range(args.nsamples // args.batch_size):
    out = sample_sequence(
        model=model,
        length=args.length,
        context=context_tokens if not args.unconditional else None,
        start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
        batch_size=args.batch_size,
        temperature=args.temperature, top_k=args.top_k, device=device
    )
    # Drop the leading prompt tokens so only the continuation is decoded.
    # NOTE(review): in unconditional mode no prompt is passed to sample_sequence,
    # yet the prompt's token count is still sliced off here — confirm this is intended.
    out = out[:, len(context_tokens):].tolist()
    for i in range(args.batch_size):
        generated += 1
        text = enc.decode(out[i])
        if args.quiet is False:
            print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
        # Print inside the loop so every sample is shown under its own banner;
        # previously only the last decoded sample was printed, after all loops.
        print(text)








In [14]:
!pip install tqdm

Collecting tqdm
  Downloading tqdm-4.65.0-py3-none-any.whl (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.1/77.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tqdm
Successfully installed tqdm-4.65.0
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.1.2[0m[39;49m -> [0m[32;49m23.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [5]:
# NOTE(review): `a` is not defined in any visible cell, so this fails on a fresh
# kernel (Restart & Run All). The recorded output lists OrderedDict methods
# (e.g. move_to_end) plus a `_metadata` attribute — presumably `a` was the
# checkpoint returned by torch.load; confirm, then rename or drop this cell.
dir(a)

['__class__',
 '__class_getitem__',
 '__contains__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__ior__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__ne__',
 '__new__',
 '__or__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__ror__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_metadata',
 'clear',
 'copy',
 'fromkeys',
 'get',
 'items',
 'keys',
 'move_to_end',
 'pop',
 'popitem',
 'setdefault',
 'update',
 'values']