In [2]:
import pickle
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed Python's `random` module and PyTorch's global RNG (the latter drives
# torch.randint in get_batch) so runs are repeatable.
# NOTE(review): NumPy's RNG is not seeded here — add a seed if np.random is used.
random.seed(0)
# The trailing ';' keeps the notebook from echoing the call's return value.
torch.manual_seed(220);

In [3]:
# Read the word file into `text`: a list with one string per line
# (line terminators are stripped by splitlines()).
with open('words_250000_train.txt', 'r', encoding='utf-8') as f:
    text = f.read().splitlines()

In [4]:
# Character vocabulary: every distinct character found in `text` gets an id
# starting at 2; ids 1 and 0 are then assigned to '.' and the empty string.
charset = sorted(set(''.join(text)))
stoi = dict(zip(charset, range(2, len(charset) + 2)))
itos = {token_id: ch for ch, token_id in stoi.items()}
stoi.update({'.': 1, '': 0})
itos.update({1: '.', 0: ''})


def encode(x):
    """Return the list of integer ids (looked up in ``stoi``) for each character of ``x``."""
    return [stoi[ch] for ch in x]


def decode(x):
    """Return the string obtained by mapping each id in ``x`` through ``itos``."""
    return ''.join(itos[token_id] for token_id in x)


vocab_size = len(stoi)
print(f'{vocab_size = }')

vocab_size = 28


In [6]:
# Three letter groups; a target token is later mapped to the id of its group.
common_letters = tuple('bcdfghklmnprsty')
vowels = tuple('aeiou')
uncommon_letters = tuple('jqxzvw')

# indicators: group id -> token ids (via stoi) of the letters in that group
#   0 = common_letters, 1 = vowels, 2 = uncommon_letters
indicators = {
    group_id: [stoi[letter] for letter in letters]
    for group_id, letters in enumerate((common_letters, vowels, uncommon_letters))
}

print(indicators)

{0: [3, 4, 5, 7, 8, 9, 12, 13, 14, 15, 17, 19, 20, 21, 26], 1: [2, 6, 10, 16, 22], 2: [11, 18, 25, 27, 23, 24]}


In [7]:
# def build_dataset(words):
#     x_temp = []
#     y = []
#     for word in words:
#         encoded = np.array(encode(word))
#         input = np.ones_like(encoded)
#         x_temp.append(input) 
#         sorted_letter_count = Counter(word).most_common()
#         g = [0]*26
#         prev = x_temp[-1].copy()
#         while not np.all(prev != 1):
#             prev = x_temp[-1].copy()
#             for letter, _ in sorted_letter_count:
#                 idx = ord(letter) - ord('a')
#                 if g[idx] != 1:
#                     g[idx] = 1
#                     g_in = stoi[letter]
#                     break
#             y.append(g_in)
#             idxs = np.where(encoded == g_in)[0]
#             prev[idxs] = g_in 
#             if np.any(prev == 1):
#                 x_temp.append(prev)
            
#     # add padding to every input
#     x = []
#     for ix in x_temp:
#         extra = maxlen - len(ix)
#         pad = np.array([0]*extra)
#         x_in = np.concatenate((ix, pad))
#         x.append(x_in)
    
#     x = torch.tensor(np.array(x), dtype=torch.long)
#     y = torch.tensor(y, dtype=torch.long)

#     return x,y 

# def queries_data(words):
#     x = []
#     y = [] 
#     '''
#         y can be info about the word 
#             1. length of the word
#             2. vowel count
#             3. consonant count
            
#     '''
    

# random.shuffle(text)
# n1 = int(0.9*len(text))
# n2 = int(0.95*len(text))

# # Xtr, Ytr = build_dataset(text[:n1])
# # Xval, Yval = build_dataset(text[n1:n2])
# # Xts, Yts  = build_dataset(text[n2:])

In [19]:
def load_split(split, data_dir='data'):
    """Load the pickled inputs and targets for one dataset split.

    Reads ``X_{split}.pkl`` and ``Y_{split}.pkl`` from ``data_dir``. Paths are
    built with ``pathlib`` so the same code works on Windows and POSIX (the
    previous ``'data\\X_...'`` literal only resolved correctly on Windows and
    relied on an invalid ``\\X`` escape sequence).

    Args:
        split: Split name embedded in the file names, e.g. ``'train'``,
            ``'val'`` or ``'test'``.
        data_dir: Directory containing the pickle files. Defaults to ``'data'``
            relative to the current working directory.

    Returns:
        Tuple ``(X, Y)`` holding the two unpickled objects.

    Raises:
        OSError: If either file cannot be opened (e.g. ``FileNotFoundError``).
        pickle.UnpicklingError: If a file is not a valid pickle.
    """
    # SECURITY: pickle.load can execute arbitrary code while deserialising.
    # Only point this at files you produced yourself or otherwise trust.
    def _load(prefix):
        # One file per array: '<prefix>_<split>.pkl' inside data_dir.
        path = Path(data_dir) / f'{prefix}_{split}.pkl'
        with open(path, 'rb') as f:
            return pickle.load(f)

    # Errors propagate unchanged; X is loaded before Y, as before.
    X = _load('X')
    Y = _load('Y')
    return X, Y

# Load the three dataset splits; each call returns an (inputs, targets) pair.
Xtr, Ytr = load_split('train')
Xval, Yval = load_split('val')
Xts, Yts = load_split('test')

In [20]:
# Report the input/target shapes of every split, one line per split.
print(
    f'Train Shape : {Xtr.shape} {Ytr.shape}',
    f'Val Shape : {Xval.shape} {Yval.shape}',
    f'Test Shape : {Xts.shape} {Yts.shape}',
    sep='\n',
)

Train Shape : torch.Size([1513485, 29]) torch.Size([1513485])
Val Shape : torch.Size([84078, 29]) torch.Size([84078])
Test Shape : torch.Size([83646, 29]) torch.Size([83646])


In [23]:
def get_batch(split, batch_size=32):
    """Sample a random mini-batch from one dataset split.

    Row indices are drawn uniformly with ``torch.randint``, i.e. with
    replacement, so a batch may contain the same row more than once.

    Args:
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        batch_size: Number of rows to sample. Defaults to 32.

    Returns:
        Tuple ``(xb, yb)``: the sampled rows of the split's inputs and the
        matching entries of its targets.

    Raises:
        KeyError: If ``split`` is not a known split name.
    """
    # Reads the module-level tensors produced by load_split.
    splits = {
        'train': (Xtr, Ytr),
        'val': (Xval, Yval),
        'test': (Xts, Yts),
    }
    if split not in splits:
        raise KeyError(f'unknown split {split!r}; expected one of {sorted(splits)}')

    X, Y = splits[split]
    # Same index vector for X and Y keeps inputs and targets aligned.
    ix = torch.randint(0, X.shape[0], (batch_size,))
    xb, yb = X[ix], Y[ix]
    return xb, yb


@torch.no_grad()
def split_loss(eval_iters=200):
    """Estimate the mean loss on the train and validation splits.

    For each split, ``eval_iters`` random mini-batches are drawn with
    ``get_batch`` and the loss returned by the global ``model`` is averaged.
    The model is put in evaluation mode for the duration of the estimate (so
    layers such as dropout or batch-norm do not distort it) and its previous
    train/eval mode is restored afterwards, even if an error occurs.

    Assumes the global ``model`` is an ``nn.Module`` whose call
    ``model(xb, yb)`` returns ``(logits, loss)`` — confirm against its definition.

    Args:
        eval_iters: Number of mini-batches averaged per split. Defaults to 200.

    Returns:
        Dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor with the mean loss.
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ('train', 'val'):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                xb, yb = get_batch(split)
                # Only the loss is needed here; the logits are discarded.
                _, loss = model(xb, yb)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Put the model back into whichever mode the caller had it in.
        model.train(was_training)
    return out

In [24]:
# Draw one random training mini-batch to experiment with in the following cells.
xb, yb = get_batch('train')

In [29]:
# Column vector of group ids: 32 zeros, then 32 ones, then 32 twos -> shape (96, 1).
batch_X2_rep = np.repeat(np.arange(3), 32).reshape(-1, 1)
print(batch_X2_rep.shape)

(96, 1)


In [30]:
# Stack three copies of the batch along a new leading axis (np.tile converts
# the tensor to a NumPy array first).
# NOTE(review): the recorded shape is (3, 32, 29), whereas batch_X2_rep and
# true_y both have 96 rows — confirm whether a flat (96, 29) layout
# (reps [3, 1]) was intended.
batch_X_rep = np.tile(xb, [3,1,1])

In [31]:
# Check the shape of the tiled batch.
batch_X_rep.shape

(3, 32, 29)

In [33]:
# NumPy view of the batch targets, used by the NumPy-based group lookup below.
y_np = yb.numpy()

In [34]:
# Inspect the sampled target ids (presumably stoi token ids — verify against
# how the Y pickles were built).
y_np

array([21,  2,  7,  6, 13, 10, 20,  2, 14,  4,  8, 17, 15, 10, 21,  4,  8,
       16, 24, 17, 17,  6,  6, 19,  3, 12,  5,  6, 19,  7, 10, 19],
      dtype=int64)

In [44]:
# For every target id, collect the indicator group id(s) whose member list
# contains it, then stack three copies of the resulting column vertically.
true_y = np.tile(
    np.array([
        [group_id for group_id, members in indicators.items() if token in members]
        for token in y_np
    ]),
    (3, 1),
)

In [45]:
# Inspect the group labels.
# NOTE(review): this echoes every row; consider slicing (e.g. true_y[:5]) to keep the output short.
true_y

array([[0],
       [1],
       [0],
       [1],
       [0],
       [1],
       [0],
       [1],
       [0],
       [0],
       [0],
       [0],
       [0],
       [1],
       [0],
       [0],
       [0],
       [1],
       [2],
       [0],
       [0],
       [1],
       [1],
       [0],
       [0],
       [0],
       [0],
       [1],
       [0],
       [0],
       [1],
       [0],
       [0],
       [1],
       [0],
       [1],
       [0],
       [1],
       [0],
       [1],
       [0],
       [0],
       [0],
       [0],
       [0],
       [1],
       [0],
       [0],
       [0],
       [1],
       [2],
       [0],
       [0],
       [1],
       [1],
       [0],
       [0],
       [0],
       [0],
       [1],
       [0],
       [0],
       [1],
       [0],
       [0],
       [1],
       [0],
       [1],
       [0],
       [1],
       [0],
       [1],
       [0],
       [0],
       [0],
       [0],
       [0],
       [1],
       [0],
       [0],
       [0],
       [1],
       [2],
    