In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the global torch RNG (CPU and CUDA) so weight initialisation and
# DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Load pi as one flat string of decimal digits: "3" followed by the fractional digits.
with open("pi_million.txt") as f:
    pi_raw = f.read()

# Remove all whitespace (trailing newline, or line wrapping in some pi dumps),
# then the decimal point. Any leftover non-digit would make int() fail later in Dataset.
pi = "".join(pi_raw.split()).replace(".", "", 1)
assert set(pi) <= set("0123456789"), "pi_million.txt must contain only digits and one '.'"
print(len(pi))  # "3" + 1,000,000 decimal digits (the '.' has been removed)

1000001


In [None]:
class Dataset:
    """Sliding-window dataset over a string of decimal digits.

    Item ``k`` is the list of ``_len`` integer digits
    ``pi_str[_start + k : _start + k + _len]``, for window starts ``_start .. _end - 1``.

    Args:
        pi_str: string containing only the characters '0'-'9'.
        _start: first window start index (inclusive).
        _end: last window start index (exclusive).
        _len: number of digits per window.
    """

    def __init__(self, pi_str, _start, _end, _len=512):
        self.data = self.fetch_data(pi_str,
                                    start=_start,
                                    end=_end,
                                    seq_len=_len)

    def fetch_data(self, pi_str, start, end, seq_len):
        """Return one list of ``seq_len`` ints per window start in ``range(start, end)``.

        Raises:
            ValueError: if ``start`` is negative, ``seq_len`` < 1, or the last window
                would run past the end of ``pi_str``. Such a window would otherwise be
                silently truncated, and the collate step would fail on ragged rows.
        """
        if end <= start:
            return []
        if start < 0 or seq_len < 1:
            raise ValueError(
                f"start must be >= 0 and seq_len >= 1, got start={start}, seq_len={seq_len}"
            )
        stop = end - 1 + seq_len  # exclusive end index of the last window
        if stop > len(pi_str):
            raise ValueError(
                f"window starting at {end - 1} needs {seq_len} digits, "
                f"but pi_str has only {len(pi_str)} characters"
            )
        # Parse each character once, then slice the int list. Parsing inside every
        # overlapping window would repeat each int() call ~seq_len times.
        digits = [int(c) for c in pi_str[start:stop]]
        return [digits[i:i + seq_len] for i in range(len(digits) - seq_len + 1)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def col_fn(batch):
    """Collate equal-length lists of digit ints into one (batch, seq_len) int64 tensor."""
    rows = [torch.tensor(sample, dtype=torch.long) for sample in batch]
    return torch.stack(rows, dim=0)

# Windows of 512 digits (Dataset's default length). The model is trained to
# predict the last digit of each window. Training windows start at 0..99,999
# (targets up to index 100,510). Validation windows start at 100,000..102,999
# (targets from index 100,511), so no validation target is seen during training.
BATCH_SIZE = 256
train_set = Dataset(pi, _start=0, _end=100000)
test_set  = Dataset(pi, _start=100000, _end=103000)
# Shuffle training windows. Consecutive windows overlap by all but one digit,
# so unshuffled batches would be highly correlated.
train_dl  = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, collate_fn=col_fn)
test_dl   = DataLoader(test_set,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, collate_fn=col_fn)

In [None]:
class Model(nn.Module):
    """Two-stage GRU that predicts the final digit of a window from the digits before it.

    Pipeline: Embedding -> GRU(in_dim -> hidden_dim) -> Linear(hidden_dim -> hidden_dim)
    -> GRU(hidden_dim -> out_dim) -> Linear(out_dim -> num_digit) logits, read off at
    the last context position.

    Args:
        num_digit: vocabulary size (number of distinct digit tokens).
        in_dim: embedding size.
        hidden_dim: hidden size of the first GRU and of the latent projection.
        out_dim: hidden size of the second GRU.
    """

    def __init__(self, num_digit, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_digit, in_dim)
        self.in_gru = nn.GRU(in_dim, hidden_dim, batch_first=True)
        self.latent_fc = nn.Linear(hidden_dim, hidden_dim)
        self.out_gru = nn.GRU(hidden_dim, out_dim, batch_first=True)
        self.out_fc = nn.Linear(out_dim, num_digit)

    def forward(self, x, return_loss=True):
        """Score the last digit of each window in ``x``.

        Args:
            x: (batch, seq_len) LongTensor of digit ids, with seq_len >= 2. The last
                column is the target; only ``x[:, :-1]`` is used as context.
            return_loss: if True return the scalar cross-entropy against ``x[:, -1]``,
                otherwise return the (batch, num_digit) logits.

        Raises:
            ValueError: if the window has fewer than 2 positions (no context).
        """
        if x.size(1) < 2:
            raise ValueError(f"need at least 2 positions per window, got {x.size(1)}")
        # Both GRUs are unidirectional, so the output at the last context position
        # depends only on x[:, :-1]. Running the target digit through the network
        # would cost compute and its output would be discarded anyway.
        context = x[:, :-1]
        emb_out = self.embedding(context)
        hidden, _ = self.in_gru(emb_out)
        latent = self.latent_fc(hidden)
        out, _ = self.out_gru(latent)
        out_digit = self.out_fc(out[:, -1])

        if return_loss:
            return self.get_loss(out_digit, x)
        return out_digit

    def get_loss(self, digit_logit, x):
        """Cross-entropy between ``digit_logit`` and the last digit of each window."""
        x_last = x[:, -1]

        loss = F.cross_entropy(digit_logit, x_last)
        return loss

    @torch.no_grad()
    def generate(self, x):
        """Greedy prediction of the last digit of each window in ``x`` (which is ignored).

        Returns:
            (batch,) LongTensor of predicted digit ids.
        """
        digit_logit = self.forward(x, return_loss=False)
        # softmax is monotonic, so argmax over logits equals argmax over probabilities.
        return digit_logit.argmax(dim=-1)

In [None]:
# Architecture hyper-parameters. They must match the values used when the
# saved checkpoint is loaded back into a fresh Model.
model_config = dict(num_digit=10, in_dim=512, hidden_dim=1024, out_dim=512)

# nn.Module.to() moves parameters in place and returns the same module.
model = Model(**model_config).to(device)
optim = torch.optim.Adam(params=model.parameters(), lr=0.001, betas=[0.9, 0.999])

In [None]:
# Optimise the last-digit cross-entropy. `idx` counts optimizer steps across all epochs.
idx = 0
epochs = 50
n_batches = len(train_dl)
for e in range(epochs):
    for i, x in enumerate(train_dl):
        x = x.to(device)

        # Clearing gradients before the forward pass is equivalent to clearing them
        # between forward and backward: the forward pass never touches .grad.
        optim.zero_grad()
        loss = model(x)
        loss.backward()
        optim.step()

        # Overwrite one console line per step; leave a permanent line every 500 steps.
        msg = f"\ridx: {idx:,} [{i} / {n_batches}],  loss= {loss.item():.3f}"
        print(msg, end='')
        if not idx % 500:
            print("")

        idx += 1


idx: 0 [0 / 391],  loss= 2.306
idx: 500 [109 / 391],  loss= 2.308
idx: 1,000 [218 / 391],  loss= 2.310
idx: 1,500 [327 / 391],  loss= 2.318
idx: 2,000 [45 / 391],  loss= 2.309
idx: 2,500 [154 / 391],  loss= 2.302
idx: 3,000 [263 / 391],  loss= 2.305
idx: 3,500 [372 / 391],  loss= 2.310
idx: 4,000 [90 / 391],  loss= 2.296
idx: 4,500 [199 / 391],  loss= 2.312
idx: 5,000 [308 / 391],  loss= 2.313
idx: 5,500 [26 / 391],  loss= 2.298
idx: 6,000 [135 / 391],  loss= 2.303
idx: 6,500 [244 / 391],  loss= 2.315
idx: 7,000 [353 / 391],  loss= 2.303
idx: 7,500 [71 / 391],  loss= 2.302
idx: 8,000 [180 / 391],  loss= 2.304
idx: 8,500 [289 / 391],  loss= 2.309
idx: 9,000 [7 / 391],  loss= 2.304
idx: 9,500 [116 / 391],  loss= 2.300
idx: 10,000 [225 / 391],  loss= 2.308
idx: 10,500 [334 / 391],  loss= 2.306
idx: 11,000 [52 / 391],  loss= 2.309
idx: 11,500 [161 / 391],  loss= 2.304
idx: 12,000 [270 / 391],  loss= 2.317
idx: 12,500 [379 / 391],  loss= 2.308
idx: 13,000 [97 / 391],  loss= 2.318
idx: 13,50

In [None]:
# Persist only the parameters (state_dict); the architecture is rebuilt from
# model_config before loading. NOTE(review): the "10k" in the filename does not
# obviously match this run's 100,000 training windows — confirm before relying on it.
torch.save(obj=model.state_dict(), f="512_1024_10k.ckpt")

# **Evaluation**
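Check whether the digit predictor generalizes: measure last-digit accuracy on the held-out validation windows (`test_set`), whose target digits all lie beyond the training range. With essentially random digits, anything far above 10% would be suspicious.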

In [None]:
def check_correct_wrong(sample, true):
    """Count matching and mismatching predictions in one batch.

    Args:
        sample: predicted digit tensor (assumed 1-D, one entry per window).
        true: target digit tensor of the same shape.

    Returns:
        (correct_num, wrong_num) as 0-d tensors.
    """
    hits = torch.eq(sample, true).sum()
    misses = len(sample) - hits
    return hits, misses

model_config = {
    "num_digit": 10,
    "in_dim": 512,
    "hidden_dim": 1024,
    "out_dim": 512
}

# Load into a separate name. Rebinding `model` would leave the training cell's
# `optim` pointing at the old model's parameters, and re-running training would
# then silently fail to update the model in use.
eval_model = Model(**model_config)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
state_dict = torch.load("512_1024_10k.ckpt", map_location=device)
eval_model.load_state_dict(state_dict)
eval_model.to(device)
eval_model.eval()

# Evaluate on the held-out validation windows, not the training windows.
total_correct, total_wrong = 0, 0
for x in test_dl:
    x = x.to(device)
    sample = eval_model.generate(x)
    correct_num, wrong_num = check_correct_wrong(sample, x[:, -1])
    # Accumulate as Python ints so the summary prints plain numbers, not tensors.
    total_correct += int(correct_num)
    total_wrong += int(wrong_num)
print(f"total_correct = {total_correct}, total_wrong = {total_wrong}")
total = total_correct + total_wrong
print(f"accuracy = {total_correct / max(total, 1):.4f}  (uniform guessing = 0.1)")

total_correct = 9999, total_wrong = 1
