In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import struct
import torch.optim as optim
import torchvision
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import time
import random




In [2]:
src_vocab_size = 25002
d_model = 512
n_layers = 6
n_head = 8
d_k = 64
d_v = 64
d_ff = 2048
tgt_vocab_size = 25002
tgt_len = 25002
n_context = 8
d_context = 64

In [3]:
'''embeddings=torch.Tensor(25003,512)
embeddings.to("cuda")
with open('MetaAI04.vd','rb') as file:
    file.read(4)
    file.read(4)
    for i in range(0,25001):
        embeddings[i]=torch.Tensor(struct.unpack('<512f',file.read(4*512)))
'''

'embeddings=torch.Tensor(25003,512)\nembeddings.to("cuda")\nwith open(\'MetaAI04.vd\',\'rb\') as file:\n    file.read(4)\n    file.read(4)\n    for i in range(0,25001):\n        embeddings[i]=torch.Tensor(struct.unpack(\'<512f\',file.read(4*512)))\n'

In [4]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)

        out = (x - mean) / torch.sqrt(var + self.eps)
        #out = self.gamma * out + self.beta
        return out

In [5]:
def get_attn_pad_mask(seq_q, seq_k): 
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(25001).unsqueeze(1)
    return pad_attn_mask.expand(batch_size, len_q, len_k)

In [6]:
def get_attn_subsequence_mask(seq):
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask

In [7]:
class Embeddings(nn.Module):
    def __init__(self, vocab, d_model):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(src_vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

In [8]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=10000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].clone().detach()
        return self.dropout(x)

In [9]:
def get_sinusoid_encoding_table(max_len, d_model):
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

In [10]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, scale_factor, dropout=0.1):
        super().__init__()
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q / self.scale_factor, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 1, -1e9)
        attn = self.dropout(torch.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(scale_factor=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        batch_size, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
        residual = q
        q = self.layer_norm(q)
        k = self.layer_norm(k)
        v = self.layer_norm(v)

        q = self.w_qs(q).view(batch_size, len_q, n_head, d_k)
        k = self.w_ks(k).view(batch_size, len_k, n_head, d_k)
        v = self.w_vs(v).view(batch_size, len_v, n_head, d_v)

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)
        q, attn = self.attention(q, k, v, mask=mask)

        q = q.transpose(1, 2).contiguous().view(batch_size, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual
        q = self.layer_norm(q)
        return q, attn

In [12]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, dropout=0.1):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False))
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs):
        residual = inputs
        output = self.fc(inputs)
        output = self.dropout(output)
        return nn.LayerNorm(d_model).cuda()(output + residual)

In [13]:
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_head, d_model, d_k, d_v)
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, 
                                               enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

In [14]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.src_emb = Embeddings(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        enc_outputs = self.src_emb(enc_inputs)
        enc_outputs = self.pos_emb(enc_outputs)
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        enc_self_attns = []
        for layer in self.layers:
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        return enc_outputs, enc_self_attns

In [15]:
class Decider(nn.Module):
    def __init__(self):
        super(Decider, self).__init__()
        
        self.dec_context_attn = MultiHeadAttention(n_head, d_model, d_k, d_v)
        self.dec_input_attn = MultiHeadAttention(n_head, d_model, d_k, d_v)
        self.context = torch.Tensor(n_context * d_context, d_model)
        self.contexts = torch.stack([self.context for _ in range(1)]).cuda()
    
    def clean(self, dim = 16):
        self.contexts = torch.stack([self.context for _ in range(dim)]).cuda()

    def forward(self, dec_inputs, enc_inputs):
        #self.contexts.requires_grad_(False)
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        dec_outputs, attn = self.dec_input_attn(dec_inputs, self.contexts, self.contexts)
        contexts, attn = self.dec_context_attn(self.contexts, dec_inputs, dec_inputs, enc_self_attn_mask)
        return dec_outputs

In [16]:
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention(n_head, d_model, d_k, d_v)
        self.dec_enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v)
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, 
                                                 dec_inputs, dec_self_attn_mask)
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, 
                                                enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn

In [17]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = Embeddings(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        dec_outputs = self.tgt_emb(dec_inputs)
        dec_outputs = self.pos_emb(dec_outputs)
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda()
        dec_self_attn_subsequent_mask = get_attn_subsequence_mask(dec_inputs).cuda()
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)

        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns

In [18]:
class MetaAI(nn.Module):
    def __init__(self):
        super(MetaAI, self).__init__()
        self.encoder = Encoder()
        self.decider = Decider()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)

    def forward(self, enc_inputs, dec_inputs):
        dec_logits = []
        self.decider.clean(dim=1)
        for i in range(len(enc_inputs)):
            enc_outputs, enc_self_attns = self.encoder(enc_inputs[i].unsqueeze(0))
            #enc_outputs = self.decider(enc_outputs,enc_inputs[i].unsqueeze(0))
            dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs[i].unsqueeze(0), enc_inputs[i].unsqueeze(0), enc_outputs)
            dec_logits.append(self.projection(dec_outputs))
            dec_logits[i] = dec_logits[i].view(-1,dec_logits[i].size(-1))
        return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns

In [19]:
#torch.serialization.add_safe_globals([MetaAI, set, Encoder, Embeddings, PositionalEncoding, nn.Embedding, nn.Dropout, nn.ModuleList, EncoderLayer])
AI=MetaAI()
AI.load_state_dict(torch.load('MetaAI4.pth',weights_only = True),strict=False)
# for p in AI.parameters():
#     if p.dim() > 1:
#         nn.init.xavier_uniform_(p)
if torch.cuda.is_available():
    AI.to("cuda")

In [20]:
from datasets import load_dataset
dataset=load_dataset("chujiezheng/wizard_of_wikipedia",download_mode="reuse_dataset_if_exists")

In [21]:
def list_of_numbers_to_strings(lst):
    return [str(item) if isinstance(item, int) else list_of_numbers_to_strings(item) for item in lst]
def flatten_array(arr):
    return [item for sublist in arr for item in sublist]
def flatten_array_a(arr):
    return ['<|beginoftext|>'+item for sublist in arr for item in sublist]
def flatten_array_b(arr):
    return [item+'<|endoftext|>' for sublist in arr for item in sublist]
def _a(arr):
    for i in range(len(arr)):
        arr[i]='<|beginoftext|>'+arr[i]
    return arr
def _b(arr):
    for i in range(len(arr)):
        arr[i]=arr[i]+'<|endoftext|>'
    return arr

In [22]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("my-new-tokenizer")
tokenizer.add_tokens(['<|beginoftext|>','<pad>'])
tokenizer.add_special_tokens({'pad_token': '<pad>'})
tokenizer.add_special_tokens({'bos_token': '<|beginoftext|>'})
tokenizer.add_special_tokens({'eos_token': '<|endoftext|>'})

batch = 1000

data_input_t=tokenizer(flatten_array(dataset["train"]['post'][:batch]),padding='longest')['input_ids']
data_output_t=tokenizer(flatten_array_a(dataset["train"]['response'][:batch]),padding='longest')['input_ids']
data_pred_t=tokenizer(flatten_array_b(dataset["train"]['response'][:batch]),padding='longest')['input_ids']

data_input=data_input_t
data_output=data_output_t
data_pred=data_pred_t
batchstart=[0]
fastreader = dataset["train"]['post'][:batch]
for i in range(1,batch):
    batchstart.append(len(fastreader[i]))
    batchstart[i]+=batchstart[i-1]

# for i in range(1000):
#     data_input.append([data_input_t[lst+_] for _ in range(len(dataset["train"]['post'][:1000][i]))])
#     lst+=len(dataset["train"]['post'][:1000][i])

# for i in range(len(dataset["train"]['post'][:1000])):
#     data_input.append(tokenizer(dataset["train"]['post'][:1000][i],padding='longest')['input_ids'])
#     data_output.append(tokenizer(_a(dataset["train"]['response'][:1000][i]),padding='longest')['input_ids'])
#     data_pred.append(tokenizer(_b(dataset["train"]['response'][:1000][i]),padding='longest')['input_ids'])

In [23]:
def get_batch(batch):
    return torch.tensor(data_input[batchstart[batch]:batchstart[batch]+len(fastreader[batch])]),torch.tensor(data_output[batchstart[batch]:batchstart[batch]+len(fastreader[batch])]),torch.tensor(data_pred[batchstart[batch]:batchstart[batch]+len(fastreader[batch])])

In [24]:
AI.eval()
a,b,c=get_batch(11)
output, un1, un2, un3 = AI([a[2].cuda()],[b[2].cuda()])
torch.argmax(output[0],dim=1)

tensor([   46,    79,   329,  9930,   812,   329,  5749,     7,    84,    27,
           84,   553,   290,   693,  4225,    12,   329,    27,    77,   473,
         1408,   403,   261, 18357,  2210,   281,  3560,  1462,   819,   401,
         2287,   490,   364,  2181,    14,     0,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667],
       device='cuda:0')

In [25]:
print("output: {} \n answer: {} \n correction: {}".format(torch.argmax(output[0],dim=1), c[2], (c[2].cuda()==25001)+(c[2].cuda()==torch.argmax(output[0],dim=1))))

output: tensor([   46,    79,   329,  9930,   812,   329,  5749,     7,    84,    27,
           84,   553,   290,   693,  4225,    12,   329,    27,    77,   473,
         1408,   403,   261, 18357,  2210,   281,  3560,  1462,   819,   401,
         2287,   490,   364,  2181,    14,     0,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,
         8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667,  8667],
       device='cuda:0') 
 answer: tensor([   46,    79,   329,

In [36]:
criterion = nn.CrossEntropyLoss(ignore_index=25001)
# optimizer = optim.Adam(AI.parameters(), lr=0.005, betas=(0.9, 0.98), eps=1e-9)
optimizer = optim.SGD(AI.parameters(), lr=0.016)

batchsize = 16
warmup_steps = 200

#inputd = inputd.view(-1,batchsize,inputd.size(1))
# inputd = torch.split(inputd.long(), batchsize, dim=0)
# outputd = torch.split(outputd.long(), batchsize, dim=0)
# pred = torch.split(pred.long(), batchsize, dim=0)
#for i in range(len(pred)):
#    pred[i] = F.one_hot(pred[i], num_classes=tgt_vocab_size).float()

if torch.cuda.is_available():
    criterion=criterion.cuda()

work_dir = 'C:/'
writer = SummaryWriter("{}/logs/{}".format(work_dir,str(random.randint(10000,99999))), comment=str(random.randint(10000,99999)))

In [37]:
last_epoch = 0
last_batch = 0

In [40]:
AI.train()

t0=time.time()

for epoch in range(last_epoch, 100):
    for i in range(last_batch,batch):
        optimizer.zero_grad()
        inputd,outputd,pred = get_batch(i)
        outputs, un1, un2, un3 = AI(inputd.cuda(), outputd.cuda())
        loss_total = 0
        for j in range(len(inputd)):
#             print(outputs[j].size())
#             print(pred[j].size())
            loss = criterion(outputs[j].contiguous().view(-1, tgt_vocab_size), pred[j].contiguous().view(-1).cuda())
            loss.backward()
            loss_total += loss.item()
#             print(loss.item())
        optimizer.step()
        train_step = batch*epoch+i+1
        #rate = d_model ** (-0.5) * min(train_step ** (-0.5), train_step * warmup_steps ** (-1.5))
        #for p in optimizer.param_groups:
        #    p['lr'] = rate
        if train_step % 20 == 0:
            print("train time：{}, Loss: {}".format(train_step, loss_total / len(inputd)))
        if train_step % 500 == 0:
            print("output: {} \n answer: {} \n correction: {}".format(torch.argmax(outputs[0],dim=1), pred[0].cuda(), (pred[0].cuda()==25001)+(pred[0].cuda()==torch.argmax(outputs[0],dim=1)) ))
        writer.add_scalar("train_loss", loss_total / len(inputd), train_step)
        last_batch=i
    print("time per epoch: {} minutes".format((time.time()-t0)/(epoch+1)/60))
    last_batch=0
    last_epoch=epoch

train time：320, Loss: 3.425010532140732
train time：340, Loss: 2.995445156097412
train time：360, Loss: 1.2482065558433533
train time：380, Loss: 1.7308690786361693
train time：400, Loss: 2.0347853899002075
train time：420, Loss: 2.513591468334198
train time：440, Loss: 2.675833895802498
train time：460, Loss: 2.6052913665771484
train time：480, Loss: 2.0045367181301117
train time：500, Loss: 1.7898331334193547
output: tensor([  55,  479,   12, 1643, 2716,    7,   84, 2255,  358,  358,   12, 1643,
        1643,  281, 1421,   12, 1643,  747,    1, 3322, 3914,  677,  434, 3914,
         286,  286,    1,    0,   70,  286,  286,  364,  261,  434,  271,  677,
         290,   83,  261,  358,  261, 3883, 8224,    1,    0,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1, 

train time：1520, Loss: 1.6414456764856975
train time：1540, Loss: 1.9752378344535828
train time：1560, Loss: 3.046063482761383
train time：1580, Loss: 1.2331477254629135
train time：1600, Loss: 1.6425396278500557
train time：1620, Loss: 2.6439135372638702
train time：1640, Loss: 2.0651017785072328
train time：1660, Loss: 2.4619945883750916
train time：1680, Loss: 1.6339953641096752
train time：1700, Loss: 3.420452296733856
train time：1720, Loss: 2.3279868960380554
train time：1740, Loss: 0.9197007775306701
train time：1760, Loss: 1.4801776558160782
train time：1780, Loss: 2.858782251675924
train time：1800, Loss: 2.7661605775356293
train time：1820, Loss: 2.679353803396225
train time：1840, Loss: 2.3312071204185485
train time：1860, Loss: 3.8055390119552612
train time：1880, Loss: 3.345185915629069
train time：1900, Loss: 2.8990938663482666
train time：1920, Loss: 2.6920234858989716
train time：1940, Loss: 3.5279040336608887
train time：1960, Loss: 2.5074410140514374
train time：1980, Loss: 1.41499919891357

train time：2800, Loss: 2.610824406147003
train time：2820, Loss: 2.5534353733062742
train time：2840, Loss: 2.1811409473419188
train time：2860, Loss: 3.612252414226532
train time：2880, Loss: 3.000482678413391
train time：2900, Loss: 2.5829732020696006
train time：2920, Loss: 2.431773394346237
train time：2940, Loss: 3.2104247212409973
train time：2960, Loss: 1.9994394481182098
train time：2980, Loss: 1.1707877576351167
train time：3000, Loss: 3.0588440895080566
output: tensor([  41,  769,   69,  364,  559,  281,  261,    7,   83,  693,  364, 4828,
         284,  329,  553,  358,  364, 3883,  358,  364, 6696,  364,  261, 1339,
          14,  221,   12,  290,  288,  290,   12,  290,  290,   12,  453,   12,
         290,  290,  290,  358,   12,  453,   14,   12,  290,   14,    1,    1,
         290,  290,  290,  290,  290,   12,  221,   12,  302,    1,  812,   12,
         358,  290,  463,   14,  290,  288,  290,  290,  358,   14,  290,  290,
          14,   14,   12,  358,   14,  290,   12,  302

train time：4020, Loss: 1.5191735494881868
train time：4040, Loss: 1.7588470876216888
train time：4060, Loss: 2.5214683413505554
train time：4080, Loss: 1.8416639417409897
train time：4100, Loss: 1.2198469787836075
train time：4120, Loss: 1.284528911113739
train time：4140, Loss: 2.0512676437695823
train time：4160, Loss: 2.3831084966659546
train time：4180, Loss: 1.5940485447645187
train time：4200, Loss: 0.8243550608555476
train time：4220, Loss: 1.3594954162836075
train time：4240, Loss: 1.4380584160486858
train time：4260, Loss: 1.4434857666492462
train time：4280, Loss: 1.2866633057594299
train time：4300, Loss: 1.4133170396089554
train time：4320, Loss: 2.618557393550873
train time：4340, Loss: 2.27355101108551
train time：4360, Loss: 0.7962979475657145
train time：4380, Loss: 0.974596482515335
train time：4400, Loss: 1.191584825515747
train time：4420, Loss: 1.57528904825449
train time：4440, Loss: 2.130998358130455
train time：4460, Loss: 1.9447204321622849
train time：4480, Loss: 1.1285261027514935
t

train time：5300, Loss: 1.2572564333677292
train time：5320, Loss: 2.4391315579414368
train time：5340, Loss: 2.1029442071914675
train time：5360, Loss: 0.8001591339707375
train time：5380, Loss: 0.8551692187786102
train time：5400, Loss: 1.0225048899650573
train time：5420, Loss: 1.5083471462130547
train time：5440, Loss: 1.8892849832773209
train time：5460, Loss: 1.7186316028237343
train time：5480, Loss: 1.0932536274194717
train time：5500, Loss: 1.0888535653551419
output: tensor([  55,  479,   12, 1643, 2857,    7,   84, 2255,  756,  358,  508, 1421,
         673,  855, 1421,   12, 1643,  747, 3428, 3322,  290,  677,  434, 9180,
        2478,  286,  364,    0,   70,  286,  286,  364, 9970,  677,  271,  677,
        1146,   83, 2491,  358, 5678, 3883, 2013,    1,    0,  473, 2255,  473,
         281,  281,  677,  281,  281,  281,  677,  473,  281,  281,  473, 2255,
        2255,  677, 3883,  812,  281, 2255, 2255,    1,  473,  473, 3883, 1408,
         473,  281,  281,  281,  281, 2255,  473, 

train time：6500, Loss: 0.9879982235531012
output: tensor([  55,  479,   12,  329, 2857,    7,   84, 2255,  756,  358,  508, 1421,
         673,  261, 3428,   12, 1643,  747,    1, 3322, 3914,  677,  434, 9180,
         326,  286,    1,    0,   70,  286,  286,    1, 9970,  677,  271,  677,
        1146,   83, 2491,  358,  261, 3883, 2013,    1,    0,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1],
       device='cuda:0') 
 answer: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787, 13379, 

train time：7520, Loss: 0.6503598491350809
train time：7540, Loss: 1.1113403171300889
train time：7560, Loss: 1.6469959914684296
train time：7580, Loss: 0.6122087389230728
train time：7600, Loss: 0.6448778044432402
train time：7620, Loss: 1.372016303241253
train time：7640, Loss: 1.0606144458055495
train time：7660, Loss: 1.4390905102094014
train time：7680, Loss: 1.079775631427765
train time：7700, Loss: 2.19858655333519
train time：7720, Loss: 1.143805593252182
train time：7740, Loss: 0.2719944268465042
train time：7760, Loss: 0.6925760060548782
train time：7780, Loss: 1.5714968641599019
train time：7800, Loss: 1.6058399230241776
train time：7820, Loss: 1.907886154949665
train time：7840, Loss: 1.2550556004047393
train time：7860, Loss: 2.5740757882595062
train time：7880, Loss: 1.941337029139201
train time：7900, Loss: 1.6351089477539062
train time：7920, Loss: 1.536693498492241
train time：7940, Loss: 1.9981980323791504
train time：7960, Loss: 0.8614312708377838
train time：7980, Loss: 0.44531192928552626

train time：8660, Loss: 1.361119379599889
train time：8680, Loss: 0.8178021311759949
train time：8700, Loss: 1.930748075246811
train time：8720, Loss: 1.0164958760142326
train time：8740, Loss: 0.2297824054956436
train time：8760, Loss: 0.6030746698379517
train time：8780, Loss: 1.4975757797559102
train time：8800, Loss: 1.4332445561885834
train time：8820, Loss: 1.7559938699007034
train time：8840, Loss: 1.157717490196228
train time：8860, Loss: 2.319817155599594
train time：8880, Loss: 1.7379385630289714
train time：8900, Loss: 1.4662968317667644
train time：8920, Loss: 1.3808779567480087
train time：8940, Loss: 1.8589484691619873
train time：8960, Loss: 0.6961897537112236
train time：8980, Loss: 0.45557025223970415
train time：9000, Loss: 1.446939930319786
output: tensor([   41,   769,    69,   364,   559,   281,   261,     7,    83,  5747,
          362,  4828,    14,   329,  2716,   358,   356,  1911,   358,   356,
         6696,  3690,   261,  1284,    14,     0,   281,   812,    14,   290,
      

train time：9800, Loss: 1.3067445382475853
train time：9820, Loss: 1.5470577970147132
train time：9840, Loss: 1.0471519857645035
train time：9860, Loss: 2.1130621135234833
train time：9880, Loss: 1.7535741925239563
train time：9900, Loss: 1.4005050261815388
train time：9920, Loss: 1.2824032306671143
train time：9940, Loss: 1.6894978880882263
train time：9960, Loss: 0.531804770231247
train time：9980, Loss: 0.38908875435590745
train time：10000, Loss: 1.2504704594612122
output: tensor([   41,   769,    69,   364,   559,   281,   261,     7,    83,   988,
          362,  4828,    14,   329, 21581,   358,   364,  1911,   358,   364,
         6696,  3690,   261,  1284,    14,     0,   302,   463,   302,    12,
          463,   463,  2491,   281,   463,   463,   463,     1,   463,     1,
           14,   302,   463,    14,   221,    14,   302,    14,    14,   302,
          302,   221,   281,   302,   281,   297,   302,   281,   281,    14,
         1156,   302,   463,    12,    14,    12,    14,   28

train time：10920, Loss: 1.1787843108177185
train time：10940, Loss: 1.5017244517803192
train time：10960, Loss: 0.477465208619833
train time：10980, Loss: 0.2846684232354164
train time：11000, Loss: 1.1732934415340424
output: tensor([   41,   769,    69,   356,   559,   281,  1822,     7,    83,   988,
         8905,  4828,    14,   329, 21581,   358,   385,  1911,   358,   356,
         6696,  3690,   261,  1284,    14,     0, 10498,  2255,   302, 13344,
          302,  8905,  6591, 13344,  2095, 10498,   302,  1156, 10498,   281,
        13344,   302,   302,   302,   281,   302,   302,   344,  2049,   302,
          302, 13344, 13344, 10498,   463,  8905, 10498,    68,  8905,  2095,
         8905,   281, 10498,   302,   302, 10498,   302,   302,   302,   302,
          281, 10498,  8667, 10498,   302,   463, 13344,  1419,  8905,  1156,
         2095,   302,  2255, 10498,   861,  2255,   302,  8905, 10498,   302,
        10498,  6481,  2255,   677, 10498,   302,   463,   281,   463,  8905

train time：12020, Loss: 0.6022599381394684
train time：12040, Loss: 0.7213764265179634
train time：12060, Loss: 1.1629134342074394
train time：12080, Loss: 0.594383429735899
train time：12100, Loss: 0.3527379222214222
train time：12120, Loss: 0.38614675775170326
train time：12140, Loss: 0.9126023799180984
train time：12160, Loss: 0.9513915002346038
train time：12180, Loss: 0.5401250198483467
train time：12200, Loss: 0.4598227192958196
train time：12220, Loss: 0.524180586139361
train time：12240, Loss: 0.5425337553024292
train time：12260, Loss: 0.6739733353257179
train time：12280, Loss: 0.4714317828416824
train time：12300, Loss: 0.5180784463882446
train time：12320, Loss: 1.4151592962443829
train time：12340, Loss: 1.206474393606186
train time：12360, Loss: 0.38259972631931305
train time：12380, Loss: 0.24869754910469055
train time：12400, Loss: 0.2985100507736206
train time：12420, Loss: 0.7500749565660954
train time：12440, Loss: 0.8512558601796627
train time：12460, Loss: 0.9000923931598663
train time：

train time：13120, Loss: 0.3522893004119396
train time：13140, Loss: 0.8293793002764384
train time：13160, Loss: 0.8296401798725128
train time：13180, Loss: 0.4872649386525154
train time：13200, Loss: 0.2517426485816638
train time：13220, Loss: 0.528001227726539
train time：13240, Loss: 0.4784950812657674
train time：13260, Loss: 0.6053914800286293
train time：13280, Loss: 0.3351176343858242
train time：13300, Loss: 0.33632177487015724
train time：13320, Loss: 1.3051776550710201
train time：13340, Loss: 1.0387099862098694
train time：13360, Loss: 0.3381088810662429
train time：13380, Loss: 0.25885607302188873
train time：13400, Loss: 0.26576599925756456
train time：13420, Loss: 0.5910102427005768
train time：13440, Loss: 0.8871560320258141
train time：13460, Loss: 1.0130474269390106
train time：13480, Loss: 0.36529452726244926
train time：13500, Loss: 0.46566570550203323
output: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787, 13379,    12, 

train time：14220, Loss: 0.39838681059579056
train time：14240, Loss: 0.36279748380184174
train time：14260, Loss: 0.5313649788498879
train time：14280, Loss: 0.394034144282341
train time：14300, Loss: 0.26997949555516243
train time：14320, Loss: 1.2275180462747812
train time：14340, Loss: 0.9704928815364837
train time：14360, Loss: 0.3922846739490827
train time：14380, Loss: 0.22117257118225098
train time：14400, Loss: 0.2385665088891983
train time：14420, Loss: 0.650890402495861
train time：14440, Loss: 0.7126604542136192
train time：14460, Loss: 0.8125459514558315
train time：14480, Loss: 0.27814004197716713
train time：14500, Loss: 0.40436673040191334
output: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787,  1421,    12,  1643,   747,  3428,  3322,
        11974, 21793,   434,  9180, 22294,   286,     1,  3370,    70, 22294,
          286,     1,  9970, 13328,   271,   677,  1146,    83,  2491,   358,
         8827,  3883,  2013,   

train time：15320, Loss: 1.121601689606905
train time：15340, Loss: 0.8839160561561584
train time：15360, Loss: 0.23955782875418663
train time：15380, Loss: 0.23083326369524002
train time：15400, Loss: 0.18107970654964448
train time：15420, Loss: 0.48360348492860794
train time：15440, Loss: 0.6588537804782391
train time：15460, Loss: 0.6920915059745312
train time：15480, Loss: 0.2810746170580387
train time：15500, Loss: 0.36189491860568523
output: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787, 13379,    12,  1643,   747,  3428,  3322,
         3914, 21793,   434,  9180, 22294,   286,     1,  3370,    70, 22294,
          286,     1,  9970, 13328,   271,   677,  1146,    83,  2491,   358,
         8827,  3883,  2013,     1,     0,     1,     1,     1,     1,     1,
            1,   812,     1,     1,     1,     1,   812,     1,     1,     1,
            1,  2255,     1,  9970,   812,     1,     1,     1,     1,     1,
            

train time：16420, Loss: 0.49170559272170067
train time：16440, Loss: 0.5787157528102398
train time：16460, Loss: 0.672347404062748
train time：16480, Loss: 0.2442345693707466
train time：16500, Loss: 0.4338065932194392
output: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787, 13379,    12,  1643,   747,  3428,  3322,
        11974, 21793,   434,  9180, 22294,   286,     1,  3370,    70, 22294,
          286,     1,  9970, 13328,   271,   677,  1146,    83,  2491,   358,
         8827,  3883,  2013,     1,   221,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     

train time：17520, Loss: 0.12417921920617421
train time：17540, Loss: 0.32116073966026304
train time：17560, Loss: 0.47707731276750565
train time：17580, Loss: 0.16304443466166654
train time：17600, Loss: 0.2316549881361425
train time：17620, Loss: 0.4322680216282606
train time：17640, Loss: 0.32097321897745135
train time：17660, Loss: 0.5281858593225479
train time：17680, Loss: 0.3750653713941574
train time：17700, Loss: 0.8752381429076195
train time：17720, Loss: 0.359677379950881
train time：17740, Loss: 0.15720960032194853
train time：17760, Loss: 0.21258042380213737
train time：17780, Loss: 0.6501675645510355
train time：17800, Loss: 0.6011014319956303
train time：17820, Loss: 0.886514800041914
train time：17840, Loss: 0.413235804438591
train time：17860, Loss: 1.1754458248615265
train time：17880, Loss: 0.8643080294132233
train time：17900, Loss: 0.6534556796153387
train time：17920, Loss: 0.5478517226874828
train time：17940, Loss: 0.7928371578454971
train time：17960, Loss: 0.2597496509552002
train t

train time：18680, Loss: 0.30574719111124676
train time：18700, Loss: 0.7657223790884018
train time：18720, Loss: 0.3232599589973688
train time：18740, Loss: 0.12546982653439046
train time：18760, Loss: 0.18393797054886818
train time：18780, Loss: 0.6106311877568563
train time：18800, Loss: 0.6323317475616932
train time：18820, Loss: 0.8894570484757424
train time：18840, Loss: 0.40210162103176117
train time：18860, Loss: 1.0185579434037209
train time：18880, Loss: 0.7707404096921285
train time：18900, Loss: 0.6219789187113444
train time：18920, Loss: 0.6469021737575531
train time：18940, Loss: 0.6948318853974342
train time：18960, Loss: 0.16092479787766933
train time：18980, Loss: 0.11894767805933952
train time：19000, Loss: 0.4471251778304577
output: tensor([  447,   769,    69,   364,   559,   281,  1822,     7,    83,   988,
         4753,  4828,    14,   329, 21581,   358,   364,  1911,   358,   356,
         6696,  3690,   261,  1284,    14,     0,    14,    14,    14,    14,
           14,    14,

train time：19840, Loss: 0.33260240256786344
train time：19860, Loss: 0.9091772064566612
train time：19880, Loss: 0.774943600098292
train time：19900, Loss: 0.5869205892086029
train time：19920, Loss: 0.5174745842814445
train time：19940, Loss: 0.6030955202877522
train time：19960, Loss: 0.20143266022205353
train time：19980, Loss: 0.14249256774783134
train time：20000, Loss: 0.4641747325658798
output: tensor([  447,   769,    69,   364,   559,   281,  1822,     7,    83,   988,
         4753,  4828,    14,   329, 21581,   358,   356,  1911,   358,   356,
         6696,  3690,   259,  1284,    14,     0,    14,   463,   382,  1911,
          281,     1,  1911,     1,    14,     1,  1156,   281,   302,   382,
          302,  1911,   302,   302,  1911, 13344,  1156, 13344,   281,  1911,
          463,  1911,   302,   463,  7789,   302,   382,  1015,   281,    14,
          302,     1,  1911,     1,   337,  1911,   463,  1911,  2255,   302,
          281,   281,   281,  6591,  1911,   302,    12, 

train time：21000, Loss: 0.34933337569236755
output: tensor([   41,   769,    69,   364,   559,   281,  1822,     7,    83,   988,
          362,  4828,    14,   329, 21581,   358,   364,  1911,   358,   364,
         6696,  3690,   261,  1284,    14,     0,   281,   358,     1,     1,
          281, 13344,  1911,   281,    14,  2095,    14,  7789,   281,     1,
          358,   281,     1,   281,    14,   812,   281,    20,     0,    12,
            1,   358,   281,  1911,     1,     1,     0,   812,    14,     1,
           14,     1,     1,    14,   812,     0,    14,     1,     1,     1,
            1,   281,  8905,    14,     1, 13344,    12,     1,    12,     1,
           12,   812,     1,   281,  1057,     1,   281,     1,   281,   281,
            1,   358,   812,  8905,     1,     1,   281,     0,     1,   618,
         8905,     1,     0,     0,    14,     1,     1,   281,    14,   281,
           12,   281,    14,     1,     0,  2491,     1,   463, 13344],
       device='cud

train time：22020, Loss: 0.19206967297941446
train time：22040, Loss: 0.28950487822294235
train time：22060, Loss: 0.5138293914496899
train time：22080, Loss: 0.16532990988343954
train time：22100, Loss: 0.10827352758497
train time：22120, Loss: 0.1165336612612009
train time：22140, Loss: 0.3543196991086006
train time：22160, Loss: 0.3523018181324005
train time：22180, Loss: 0.2534201256930828
train time：22200, Loss: 0.06617357954382896
train time：22220, Loss: 0.2154853499184052
train time：22240, Loss: 0.16473279893398285
train time：22260, Loss: 0.21580106830224394
train time：22280, Loss: 0.13347210697829723
train time：22300, Loss: 0.11197397019714117
train time：22320, Loss: 0.6890613473951817
train time：22340, Loss: 0.4985862076282501
train time：22360, Loss: 0.14430364469687143
train time：22380, Loss: 0.09245973452925682
train time：22400, Loss: 0.18538100719451905
train time：22420, Loss: 0.2968836072832346
train time：22440, Loss: 0.3141029980033636
train time：22460, Loss: 0.42767530400305986
t

train time：23120, Loss: 0.1309307785704732
train time：23140, Loss: 0.4136055260896683
train time：23160, Loss: 0.3556163996458054
train time：23180, Loss: 0.22621121630072594
train time：23200, Loss: 0.24006553987661997
train time：23220, Loss: 0.18990069814026356
train time：23240, Loss: 0.13727054248253504
train time：23260, Loss: 0.21464740987867117
train time：23280, Loss: 0.13814825154840946
train time：23300, Loss: 0.1531132636591792
train time：23320, Loss: 0.6170124318450689
train time：23340, Loss: 0.44745790064334867
train time：23360, Loss: 0.16330541794498762
train time：23380, Loss: 0.12042324841022492
train time：23400, Loss: 0.12873955816030502
train time：23420, Loss: 0.29790376499295235
train time：23440, Loss: 0.3397391140460968
train time：23460, Loss: 0.3741942085325718
train time：23480, Loss: 0.1205559391528368
train time：23500, Loss: 0.17128281046946844
output: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787, 13379,

train time：24280, Loss: 0.12101424336433411
train time：24300, Loss: 0.10285876225680113
train time：24320, Loss: 0.6031548995524645
train time：24340, Loss: 0.47326569855213163
train time：24360, Loss: 0.1532973935827613
train time：24380, Loss: 0.15416546538472176
train time：24400, Loss: 0.07203001882880926
train time：24420, Loss: 0.23752171359956264
train time：24440, Loss: 0.3365989848971367
train time：24460, Loss: 0.335148680023849
train time：24480, Loss: 0.10377231193706393
train time：24500, Loss: 0.14136449247598648
output: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787, 13379,    12,  1643,   747,  3428,  3322,
        11974, 21793,   434,  9180, 22294,   286,     1,  3370,    70, 22294,
          286,   364,  9970, 13328,   271,   677,  1146,    83,  2491,   358,
         8827,  3883,  2013,     1,     0,     1,     1,     1,   281,     1,
          281,   281,     1,     1,   677,     1,     1,     1,     1,   709,
 

train time：25440, Loss: 0.26081401016563177
train time：25460, Loss: 0.3325518788769841
train time：25480, Loss: 0.08756757713854313
train time：25500, Loss: 0.19752181321382523
output: tensor([   55,   479,    12,   329,  2857,     7,    84,  2255,   756,   358,
          508,  1421,   673,  7787, 13379,    12,  1643,   747,  3428,  3322,
        11974, 21793,   434,  9180, 22294,   286,   364,  3370,    70, 22294,
          286,   364,  9970, 13328,   271,   677,  1146,    83,  2491,   358,
         8827,  3883,  2013,     1,     0,     1,   812,     1,   364,   709,
            1,   709,     1,   709,  1408,     1,  1408,     1,   394,   394,
         1408,     1,   434,     1,   434,     1,     1,   434,     1,  1408,
            1,     1,     1,     1,     1,     1,     1,  1421,   434,  1408,
          364,  1408,     1,     1,   364,     1,     1,     1,     1,   281,
          364,     1,  1408,  1408,   606,  1574,     1,   434,   434,     1,
          434,     1,     1,   364,  

train time：26520, Loss: 0.05798732861876488
train time：26540, Loss: 0.1407766687683761
train time：26560, Loss: 0.20435765199363232
train time：26580, Loss: 0.07890138542279601
train time：26600, Loss: 0.07464280258864164
train time：26620, Loss: 0.1601993814110756
train time：26640, Loss: 0.14882045965641738
train time：26660, Loss: 0.3132480134566625
train time：26680, Loss: 0.17736243704954782
train time：26700, Loss: 0.36645600013434887
train time：26720, Loss: 0.1444625398144126
train time：26740, Loss: 0.12787606874480845
train time：26760, Loss: 0.13152771443128586
train time：26780, Loss: 0.3959488794207573
train time：26800, Loss: 0.2705882843583822
train time：26820, Loss: 0.516626306436956
train time：26840, Loss: 0.19460764080286025
train time：26860, Loss: 0.5548914410173893
train time：26880, Loss: 0.39443185925483704
train time：26900, Loss: 0.3573960984746615
train time：26920, Loss: 0.34417247027158737
train time：26940, Loss: 0.33158006332814693
train time：26960, Loss: 0.0943971034139394

train time：27620, Loss: 0.17585923429578543
train time：27640, Loss: 0.1678120207041502
train time：27660, Loss: 0.25635263820489246
train time：27680, Loss: 0.1443654124935468
train time：27700, Loss: 0.3413214888423681
train time：27720, Loss: 0.1302690990269184
train time：27740, Loss: 0.09789410643279553
train time：27760, Loss: 0.20855445973575115
train time：27780, Loss: 0.3075302590926488
train time：27800, Loss: 0.2289806269109249
train time：27820, Loss: 0.4547509679570794
train time：27840, Loss: 0.14927096664905548
train time：27860, Loss: 0.5772981569170952
train time：27880, Loss: 0.3888806601365407
train time：27900, Loss: 0.314778375128905
train time：27920, Loss: 0.2775045596063137
train time：27940, Loss: 0.2869945038110018
train time：27960, Loss: 0.08899808302521706
train time：27980, Loss: 0.033781392872333525
train time：28000, Loss: 0.2557854223996401
output: tensor([  447,   769,    69,   364,   559,   281,  1822,     7,    83,   988,
         4753,  4828,    14,   329, 21581,   35

train time：28780, Loss: 0.2845877756675084
train time：28800, Loss: 0.19835292268544436
train time：28820, Loss: 0.45505195800215004
train time：28840, Loss: 0.16601889878511428
train time：28860, Loss: 0.4925236999988556
train time：28880, Loss: 0.3441740075747172
train time：28900, Loss: 0.3170030315717061
train time：28920, Loss: 0.25975281838327646
train time：28940, Loss: 0.2138468325138092
train time：28960, Loss: 0.05496000777930021
train time：28980, Loss: 0.030901755578815936
train time：29000, Loss: 0.17685313615947962
output: tensor([  447,   769,    69,   364,   559,   281,  1822,     7,    83,   988,
         4753,  4828,    14,   329, 21581,   358,   364,  1911,   358,   364,
         6696,  3690,   261,  1284,    14,     0,    12,   358,   281,   358,
           14,     0,    14,    14,   358,   281,     0,   358,   463,   463,
           14,  1156,   358,     0,   358,  1156,    14,   358,     0,   281,
          358,   281,  8905,   358,     0,   281,   281,   281,   302,   281,


KeyboardInterrupt: 

In [41]:
torch.save(AI.state_dict(),'MetaAI4.pth')

In [42]:
writer.close()