In [1]:
import argparse
from contextlib import nullcontext
import torch
from tqdm import tqdm

from data import get_dataset
from utils.training_utils import get_lr, get_run_name, AverageMeter
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from evaluate import evaluate, evaluate_forced
from models import get_model
from tokenizing import get_tokenizer, NumeralTokenizer, Tokenizer
from data.reverse import prefix_target_list
import os
import wandb
from argparse import Namespace


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Spot-check the learning-rate schedule by evaluating it at a single step.
# NOTE(review): the argument roles below are presumed from the names used in
# the setup cell (warmup_iters / lr_decay_iters / min_lr) -- confirm against
# utils.training_utils.get_lr.
get_lr(
    3*500,   # step being queried
    1e-5,    # presumably the peak learning rate
    125,     # presumably the number of warm-up steps
    1000*3,  # presumably the step at which decay ends
    1e-6,    # presumably the minimum learning rate
)

5.807090860141019e-06

In [2]:

# Run configuration, built directly as a Namespace so that an accidentally
# repeated option is a SyntaxError instead of being silently overridden.
# The original dict literal listed "teacherless" twice (False, then True);
# the later value won, so True is kept here at the key's original position.
args = Namespace(
    # Model / dataset selection
    model='gpt',
    dataset='reverse',
    # Dataset size and shape
    n_train=300,
    n_test=5000,
    n_nodes=100,
    num_nodes=20000,
    deg=2,
    path_len=5,
    mate_in=2,
    unrolled=True,
    # Optimisation
    batch_size=128,
    lr=5e-4,
    weight_decay=1e-2,
    epochs=300,
    save_every=30,
    teacherless=True,
    reverse=False,
    # Evaluation / logging
    eval_train=False,
    eval_every=3750,
    use_wandb=False,
    wandb_entity='5000',  # NOTE(review): looks like a placeholder -- confirm
    # Architecture
    n_layer=36,
    n_head=20,
    n_embd=1280,
    block_size=11,  # overwritten with 17 in the setup cell before the model is built
)

In [31]:
# Show training example 1 twice: first as stored in `train_data.tokenized`,
# then as returned by indexing the dataset behind the loader.
for sample_view in (train_data.tokenized, train_loader.dataset):
    print(sample_view[1])

tensor([40006, 40018, 40007, 40012, 40015, 40001,     0, 40000,     1, 40002])
(tensor([40006, 40018, 40007, 40012, 40015, 40001, 40003, 40003, 40003],
       device='cuda:6'), tensor([   -1,    -1,    -1,    -1,    -1,     0, 40000,     1, 40002],
       device='cuda:6'))


In [4]:

# ---------------------------------------------------------------------------
# Environment, hyper-parameters, data, and checkpoint loading.
# Expects `args` (the run-configuration Namespace) to be defined already.
# ---------------------------------------------------------------------------

# System stuff
# Expose GPUs 0-6 to this process (use e.g. "0" to restrict to a single GPU).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6"

# Clear the cached device count to ensure PyTorch re-evaluates available devices
torch.cuda.device_count.cache_clear()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('device:', device)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Model stuff
top_k = 1

# Evaluation stuff
eval_iters = 1000
eval_interval = 5
log_interval = 10

# Optimiser
dtype = 'bfloat16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
beta1 = 0.9
beta2 = 0.999
decay_lr = True
# Compilation is disabled on every device (the original conditional yielded
# False on both branches).
args.compile = False
# `device` carries an index suffix ('cuda:0'), so test the prefix: an exact
# comparison with 'cuda' is never true and silently disabled flash attention.
args.use_flash = device.startswith('cuda')
warmup_iters = 100
min_lr = 1e-6

run_name = get_run_name(args)
path = './checkpoints/' + run_name + '.pt'

# Get tokenizer and de-tokenizer
tokenizer = get_tokenizer(args)

checkpoint_dir = "../../../../data/user_data/clding/checkpoints_onlyteacherless_12000"
checkpoint_path = os.path.join(checkpoint_dir, "model_checkpoint_epoch_1440.pt")

# Datasets and loaders
train_data, test_data = get_dataset(args, tokenizer, device)

train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, collate_fn=train_data.collate_fn)
test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=True, collate_fn=test_data.collate_fn)
target_len = train_data.num_tokens - train_data.num_prefix_tokens
max_iters = len(train_data) * args.epochs

lr_decay_iters = max_iters

block_size = train_data.num_tokens
# NOTE(review): the model is built with a fixed block size of 17, not the
# `block_size` derived from the data just above -- presumably to match the
# checkpoint's architecture; confirm.
args.block_size = 17
args.vocab_size = tokenizer.vocab_size
args.teacherless_token = tokenizer.encode('$')[0] if args.teacherless else None

# Load model and move to device
model = get_model(args)
print("checkpoint_path", checkpoint_path)
# SECURITY NOTE: torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the torch version allows).
checkpoint = torch.load(checkpoint_path, map_location=device)
# Drop the 'module.' component from parameter names (presumably added by the
# DistributedDataParallel wrapper at save time) so they match the bare model.
new_state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
model.load_state_dict(new_state_dict)
model.to(device)
model.eval()


# Locations of the raw train/test text files for the configured graph size
data_path = './data/datasets/reverse/'
train_path = data_path + f'train_{args.n_nodes}.txt'
test_path = data_path + f'test_{args.n_nodes}.txt'


device: cuda:0
[('AAAAAAAAAAf=', 'a0-b0>'), ('UFFBOPAXYIf=', 'a0-b0>'), ('AAAAAAAAAA@=', 'b0-a0>'), ('UFFBOPAXYI@=', 'b0-a0>'), ('AAAAAAAAAAf=', 'a1-b1>'), ('JRFXURCNFDf=', 'a1-b1>'), ('AAAAAAAAAAf=', 'a2-b2>'), ('FYNBZXJMKDf=', 'a2-b2>'), ('AAAAAAAAAA@=', 'b2-a2>'), ('FYNBZXJMKD@=', 'b2-a2>'), ('AAAAAAAAAAf=', 'a3-b3>'), ('CUEFEJDERAf=', 'a3-b3>'), ('AAAAAAAAAAf=', 'a4-b4>'), ('TGHNGEQXDRf=', 'a4-b4>'), ('AAAAAAAAAA@=', 'b4-a4>'), ('TGHNGEQXDR@=', 'b4-a4>'), ('AAAAAAAAAAf=', 'a5-b5>'), ('DQEEHYTXFAf=', 'a5-b5>'), ('AAAAAAAAAAf=', 'a6-b6>'), ('IWOAHKSGEDf=', 'a6-b6>'), ('AAAAAAAAAA@=', 'b6-a6>'), ('IWOAHKSGED@=', 'b6-a6>'), ('AAAAAAAAAAf=', 'a7-b7>'), ('PWSMEYVFSZf=', 'a7-b7>'), ('AAAAAAAAAAf=', 'a8-b8>'), ('QITHVKDKGAf=', 'a8-b8>'), ('AAAAAAAAAA@=', 'b8-a8>'), ('QITHVKDKGA@=', 'b8-a8>'), ('AAAAAAAAAAf=', 'a9-b9>'), ('NDKDJWNFBAf=', 'a9-b9>'), ('AAAAAAAAAAf=', 'a10-b10>'), ('KEDZZSFSMUf=', 'a10-b10>'), ('AAAAAAAAAA@=', 'b10-a10>'), ('KEDZZSFSMU@=', 'b10-a10>'), ('AAAAAAAAAAf=', 'a11-b1

In [30]:
# Rebuild the tokenizer from the current args and show how it encodes the
# fragment '=a2'.
tokenizer = get_tokenizer(args)
print(
    tokenizer.encode('=a2')
)

[]


In [40]:
import re

# Input (train) and output (test) files for the 10-char-hash dataset.
data_path = './data/datasets/reverse/'
train_path = f'{data_path}train_10hashfb20000.txt'
test_path = f'{data_path}test_10hashfb20000.txt'

# Matches forward lines of the form "<prefix>=a<i>-b<i>>".
# group(1) is everything before "=a"; group(2) is the index <i>, which the
# backreference forces to be identical on the a- and b-side. Parity of <i> is
# NOT checked by the regex itself.
pattern = re.compile(r"^(.+?)=a(\d+)-b\2>")

def is_odd(number):
    """Return True when ``number`` is odd.

    Args:
        number: An int, or anything ``int()`` accepts (e.g. the decimal
            string captured by a regex group).

    Returns:
        bool: ``True`` for odd values, ``False`` for even ones.

    Raises:
        ValueError: If ``number`` cannot be converted by ``int()``.
    """
    # The former extra clause ``int(number) % 4 != 2`` was always true for odd
    # values (an odd number is 1 or 3 mod 4), so it has been dropped.
    return int(number) % 2 != 0

# Derive the test file from the train file: for every forward line
# "<prefix>=a<i>-b<i>>" with odd <i>, emit one reversed line
# "<prefix minus its last char>@=b<i>-a<i>>". Consecutive lines that share
# the same <i> produce a single output line (only the first one is used).
#
# The input is opened before the output so that a missing/unreadable train
# file does not leave a freshly truncated, empty test file behind.
last_num = None  # index (as a string) of the most recently written line
with open(train_path, "r") as input_file, open(test_path, "w") as output_file:
    for line in input_file:
        match = pattern.match(line.strip())
        if not match:
            continue
        prefix, a_val = match.groups()
        # Keep odd indices only, and skip repeats of the index just written.
        if not is_odd(a_val) or a_val == last_num:
            continue
        last_num = a_val
        # prefix[:-1] drops the final character before '=' (the direction
        # marker, 'f' in the sample data), which is replaced by '@'.
        output_file.write(f"{prefix[:-1]}@=b{a_val}-a{a_val}>\n")


0


In [48]:
# Single-prompt smoke test: greedily decode a completion for one prompt and
# check it against the expected reversal "<prompt>-a<i>>".
# Requires `tokenizer`, `model` and `device` from the setup cell.
count = 0

prompt = '@=b1'.split('-')[0]          # text fed to the model
i = int(prompt.split('b')[1])          # index that should be echoed on the a-side
input_tokens = tokenizer.encode(prompt)
if not input_tokens:
    # An earlier probe in this notebook shows encode() can return [] for text
    # the tokenizer does not recognise; fail clearly instead of inside the model.
    raise ValueError(f"tokenizer produced no tokens for prompt {prompt!r}")

# Start from the prompt and append one predicted token per step.
generated_tokens = input_tokens.copy()

# Upper bound on the number of generated tokens, to prevent infinite loops.
max_length = 7

# NOTE(review): autocast uses float16 here while the setup cell selects
# bfloat16 (`ptdtype`); kept as-is to preserve results -- confirm intent.
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    for _ in range(max_length):
        # Re-feed the whole sequence generated so far (batch of one).
        input_tensor = torch.tensor(generated_tokens).unsqueeze(0).to(device)
        logits, _, _ = model(input_tensor)

        # Greedy choice from the distribution at the last position.
        next_token_id = torch.argmax(logits[:, -1, :], dim=-1).item()
        generated_tokens.append(next_token_id)
        if tokenizer.decode([next_token_id])[0] == ">":
            break  # '>' terminates a sequence

# Decode the complete generated token sequence into text
prediction_text = "".join(tokenizer.decode(generated_tokens))
print("Generated Text:", prediction_text)
if prediction_text == f"{prompt}-a{i}>":
    count += 1
print(count)



SyntaxError: 'break' outside loop (2857526405.py, line 16)

In [47]:
# Evaluate greedy decoding on the reversed test file: for each line
# "<prefix>@=b<i>-a<i>>", feed everything before '-' to the model and count
# the completions that exactly equal "<prompt>-a<i>>".
# Requires `tokenizer`, `model` and `device` from the setup cell.
count = 0
testfile = './data/datasets/reverse/' + 'test_10hashfb20000' + '.txt'

# Upper bound on the number of generated tokens, to prevent infinite loops.
max_length = 7
# Stop at the first line whose index reaches this value (the loop breaks, so
# this relies on the file being ordered by index).
max_index = 4000

with open(testfile, "r") as file:
    for line in file:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines instead of raising IndexError

        prompt = line.split('-')[0]
        # The index is whatever follows the last 'b' of the prompt.
        i = int(prompt.rsplit('b', 1)[1])
        if i >= max_index:
            break

        # Start from the prompt and append one predicted token per step.
        generated_tokens = tokenizer.encode(prompt).copy()

        # NOTE(review): autocast uses float16 here while the setup cell selects
        # bfloat16 (`ptdtype`); kept as-is to preserve results -- confirm intent.
        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
            for _ in range(max_length):
                # Re-feed the whole sequence generated so far (batch of one).
                input_tensor = torch.tensor(generated_tokens).unsqueeze(0).to(device)
                logits, _, _ = model(input_tensor)

                # Greedy choice from the distribution at the last position.
                next_token_id = torch.argmax(logits[:, -1, :], dim=-1).item()
                generated_tokens.append(next_token_id)
                if tokenizer.decode([next_token_id])[0] == ">":
                    break  # '>' terminates a sequence

        # Decode the complete generated token sequence into text
        prediction_text = "".join(tokenizer.decode(generated_tokens))
        print("Generated Text:", prediction_text)
        if prediction_text == f"{prompt}-a{i}>":
            count += 1
        print(count)
print(count)



Generated Text: ZHEMBCWSQP@=b1-a1>
1
Generated Text: TJVXOVNHKN@=b3-a1618>
1
Generated Text: XJMBFIXXRB@=b5-a3004>
1
Generated Text: YGINLJMWGJ@=b7-a7>
2
Generated Text: QMLCTMJQBS@=b9-a3136>
2
Generated Text: HLSURFPCQI@=b11-a11>
3
Generated Text: HZYDIADJGH@=b13-a2670>
3
Generated Text: NGNKBUUAWG@=b15-a2422>
3
Generated Text: GLVBAUPOZL@=b17-a17>
4
Generated Text: WHUGKPSQKD@=b19-a2802>
4
Generated Text: JMFPVZXKHR@=b21-a21>
5
Generated Text: HYGGCVATOL@=b23-a23>
6
Generated Text: WDSDRCSRPZ@=b25-a25>
7
Generated Text: OLONMUYFKN@=b27-a1350>
7
Generated Text: YPPWBDIXAP@=b29-a1840>
7
Generated Text: CDGRLJTNMH@=b31-a214>
7
Generated Text: AOSLGFCOFN@=b33-a33>
8
Generated Text: TBQZUNXQMJ@=b35-a1960>
8
Generated Text: LHWUTQCTFP@=b37-a37>
9
Generated Text: PNCXBPQWQD@=b39-a39>
10
Generated Text: VNUGMZKOAU@=b41-a41>
11
Generated Text: DZLBWJXUCL@=b43-a2648>
11
Generated Text: PRSNHFCTXB@=b45-a45>
12
Generated Text: CCOAIPQIBB@=b47-a1598>
12
Generated Text: KSKADSFRAN@=b49-a49>
13
Gen

KeyboardInterrupt: 