In [None]:
#block_size = 256
class SiLU(nn.Module):
    """SiLU / swish activation: ``x * sigmoid(x)``, applied element-wise."""

    def forward(self, x):
        return x * torch.sigmoid(x)

class GELU(nn.Module):
    """GELU activation using the tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    element-wise.
    """

    def forward(self, x):
        cubic_term = 0.044715 * torch.pow(x, 3.0)
        coeff = math.sqrt(2.0 / math.pi)
        gate = 1.0 + torch.tanh(coeff * (x + cubic_term))
        return 0.5 * x * gate

class SA(nn.Module):
    """Multi-head causal (masked) self-attention.

    Reads ``config.n_embd``, ``config.n_head``, ``config.block_size``,
    ``config.attn_pdrop`` and ``config.resid_pdrop``. ``forward`` maps a
    (B, T, n_embd) tensor to a tensor of the same shape, with
    T <= ``config.block_size``; position t only attends to positions <= t.
    """

    def __init__(self, config):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        # One projection per role, each producing all heads at once
        # (n_head * head_dim == n_embd); heads are split in forward().
        self.key = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.query = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # Output projection back into the residual stream.
        self.C = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.attn_pdrop)    # applied to attention weights
        self.dropout2 = nn.Dropout(config.resid_pdrop)  # applied to the projected output
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular causal mask: tril[i, j] == 1 iff j <= i.
        self.register_buffer('tril', torch.tril(torch.ones(config.block_size, config.block_size)))

    def forward(self, x):
        B, T, C = x.size()
        head_dim = C // self.n_head
        # (B, T, C) -> (B, n_head, T, head_dim)
        k = self.key(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        # Scaled dot-product scores: (B, n_head, T, T)
        att = (q @ k.transpose(-2, -1)) / (k.size(-1)) ** 0.5
        # Future positions get -inf so softmax assigns them zero weight.
        att = att.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)
        y = att @ v
        # Re-assemble heads: (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.dropout2(self.C(y))

class Block(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention, then an MLP.

    Each sub-layer is applied to a LayerNorm-ed input and its output is added
    back to the residual stream (``x + f(LN(x))``).
    """

    def __init__(self, config):
        super().__init__()
        self.LN1 = nn.LayerNorm(config.n_embd)
        self.LN2 = nn.LayerNorm(config.n_embd)
        self.attention = SA(config)
        # Position-wise MLP: expand to 4 * n_embd, GELU, project back, dropout.
        # A ModuleDict is only a container (it has no forward), so it is
        # applied by _feed_forward; keeping it preserves the parameter names.
        self.FeedForward = nn.ModuleDict(dict(
            l1=nn.Linear(config.n_embd, 4 * config.n_embd),
            l2=nn.Linear(4 * config.n_embd, config.n_embd),
            l3=GELU(),
            dropout=nn.Dropout(config.resid_pdrop),
        ))

    def _feed_forward(self, x):
        """Apply the MLP: l1 -> GELU (l3) -> l2 -> dropout."""
        ff = self.FeedForward
        return ff.dropout(ff.l2(ff.l3(ff.l1(x))))

    def forward(self, x):
        x = x + self.attention(self.LN1(x))
        x = x + self._feed_forward(self.LN2(x))
        return x

class Model(nn.Module):
    """GPT-style decoder-only language model.

    Token embeddings (``wte``) plus learned position embeddings (``wpe``) pass
    through dropout, ``config.n_layer`` ``Block`` layers and a final LayerNorm;
    ``lm_head`` then maps each position to vocabulary logits.

    ``config`` must provide vocab_size, block_size, n_embd, n_layer and
    embd_pdrop, plus whatever ``Block`` reads.
    """

    # Parameters that write back into the residual stream (attention output
    # projection ``SA.C`` and MLP output ``FeedForward.l2``) get the scaled
    # std 0.02 / sqrt(2 * n_layer) init. 'c_proj.weight' was the only suffix
    # matched before (it matched nothing in Block/SA) and is kept so any such
    # parameter is still re-initialised.
    _RESIDUAL_PROJ_SUFFIXES = ('c_proj.weight', 'attention.C.weight', 'FeedForward.l2.weight')

    def __init__(self, config):
        super().__init__()
        # NOTE(review): Reinforcer is defined outside this class; it is only
        # used by generate() to rerank greedy candidates.
        self.Re = Reinforcer()
        self.block_size = config.block_size
        # generate() scores every vocabulary id, so the size is needed there.
        self.vocab_size = config.vocab_size
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.embd_pdrop),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith(self._RESIDUAL_PROJ_SUFFIXES):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
        # Counts only self.transformer parameters (lm_head is excluded).
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print("number of parameters: %.2fM" % (n_params / 1e6,))

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02), biases to 0, LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    @classmethod
    def from_pretrained(cls, model_type):
        """Load HuggingFace GPT-2 weights -- not supported by this model.

        The previous implementation called an undefined ``get_default_config``,
        downloaded the HF checkpoint and then returned the freshly initialised
        model without copying any weights. This model's layout (separate
        bias-free key/query/value projections, ``attention.C``,
        ``FeedForward.l1/l2``) does not match GPT-2's, so weights cannot be
        copied 1:1.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"Model.from_pretrained({model_type!r}) is not supported: this model's "
            "parameter names and shapes do not match the GPT-2 checkpoint layout."
        )

    def configure_optimizers(self, train_config):
        """Build AdamW with weight decay on Linear weights only.

        Biases and LayerNorm/Embedding weights go into a group with
        ``weight_decay=0.0``.

        Args:
            train_config: object providing ``weight_decay``, ``learning_rate``
                and ``betas``.

        Returns:
            A ``torch.optim.AdamW`` over the two parameter groups.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full parameter name
                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # Every parameter must land in exactly one group.
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, \
            "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, \
            "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        """Compute logits (and optionally the LM loss).

        Args:
            idx: token ids of shape (b, t), t <= block_size.
            targets: optional token ids of shape (b, t); -1 entries are ignored.

        Returns:
            ``(logits, loss)`` with logits of shape (b, t, vocab_size) and
            ``loss`` None when ``targets`` is None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # (1, t)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None, tokenizer=None):
        """Autoregressively append ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: token ids of shape (b, t).
            max_new_tokens: number of tokens to append.
            temperature: the last-position logits are divided by this.
            do_sample: if True, sample from the softmax distribution;
                otherwise choose greedily (see ``tokenizer``).
            top_k: if not None, only the ``top_k`` largest logits per row are kept.
            tokenizer: optional object with ``encode``/``decode``. When given,
                greedy decoding is replaced by reranking every vocabulary id with
                ``self.Re.search`` (batch size 1 only), and a token following
                the "#" token is forced to the "M" token in every mode.

        Returns:
            Token ids of shape (b, t + max_new_tokens).
        """
        hash_id = forced_next = None
        if tokenizer is not None:
            # NOTE(review): argument forms kept from the original code; the
            # comparison below assumes encode() yields something comparable to
            # an int, and forced_next must be a (b, 1) long tensor for cat.
            hash_id = tokenizer.encode([["#"]])
            forced_next = tokenizer.encode([["M"]])
        for _ in range(max_new_tokens):
            # forward() only accepts block_size positions; keep the most recent.
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            elif tokenizer is None:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            else:
                idx_next = self._rerank_next_token(idx, tokenizer)
            if tokenizer is not None and idx[-1][-1].item() == hash_id:
                idx_next = forced_next
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def _rerank_next_token(self, idx, tokenizer):
        """Return, as a (1, 1) tensor, the vocabulary id whose extended sequence scores highest.

        NOTE(review): reconstruction of a scoring loop that was not runnable;
        it ignores the model's probabilities. Confirm the intended
        ``Reinforcer.search`` / ``tokenizer.decode`` contracts.
        """
        if idx.size(0) != 1:
            raise ValueError("Reinforcer reranking supports batch size 1 only")
        scores = []
        for token_id in range(self.vocab_size):
            candidate = torch.cat((idx, idx.new_full((1, 1), token_id)), dim=1)
            scores.append(self.Re.search(tokenizer.decode(candidate)).tolist()[0])
        # max() returns the first maximum, matching argmax tie-breaking.
        best_id = max(range(len(scores)), key=scores.__getitem__)
        return idx.new_full((1, 1), best_id)