## BERT&co

ML-часть, распиленная на пайплайны:

    Обучатор берта. Возьмём версию на PyTorch от huggingface. Для английского отсутствует — можно взять предобученный от гугла.
    Первый дообучатор берта. Обучается на вопросах-ответах как болталка. После этого ответная башня выкидывается и сохранаятся только вопросная — болталка нам не нужна.
    Второй дообучатор берта. Обучается ранжировать (если данных совсем много — классифицировать) только вопросы по близости через триплет лосс. Требует много реальной разметки, не обязателен.
    Парсер диалоговых данных. Сначала возьмём какой-нибудь ubuntu dialogue corpus, но в будущем нужно будет напарсить какой-нибудь твиттер или реддит и хорошо дообучиться на них.

В репозитории ml должны быть скрипты для сбора данных (изначально только wget убунту диалог корпуса) и пайплайн для дообучения берта под диалоги. That's it. Результатом основного скрипта для обучения будут два файла — сериализованная моделька и токенизатор — и, возможно, какие-нибудь скрипты, чтобы их можно было использовать бэкэнду на чистом сервере.

За основу имеет смысл взять тот репозиторий от huggingface. ЕМНИП, там токенизатор встроен в модель или куда-то на высоком уровне.

Там можно несложными хаками докрутить поверх эмбеддера ещё голову, которая будет делать ранжирование (нужно два раза инициализировать берт — сиамская сеть же, нужны две разные башни). Само обучение будет выглядеть так: нарезать данные формата вопрос-правильный_ответ и засунуть в большой батч (скажем, 64 примера), внутри которого для каждого вопроса все остальные 63 ответа считаются негативными. Векторизовав весь батч и посчитав «матрицу умножения», то есть все попарные скалярные произведения, можно эффективнее считать какой-нибудь лосс для ранжирования (см. презентацию).


In [1]:
#!pip3 install pytorch_pretrained_bert

In [2]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForQuestionAnswering, BertModel
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
from torchvision import transforms
from pytorch_pretrained_bert.optimization import BertAdam
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
import PIL
from IPython.display import HTML
import pickle
from torch.utils.data import Dataset, DataLoader
import os
import csv
import random
from sklearn.utils import shuffle
import time
%matplotlib inline

In [3]:
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

In [4]:
logger = logging.getLogger(__name__)

In [5]:
!./download_datasets.sh

Корпус влезет в оперативную память. 

In [6]:
import re
def remove_urls (vTEXT):
    vTEXT = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[link]', vTEXT, flags=re.MULTILINE)
    return(vTEXT)


print( remove_urls("this is a test https://sdfs.sdfsdf.com/sdfsdf/sdfsdf?233/sd/sdfsdfs?bob=%20tree&jef=man lets see this too https://sdfsdf.fdf.com/sdf/f end"))

this is a test [link] lets see this too [link] end


In [7]:
device = torch.device('cuda:0') #('cpu')
bert_type = 'bert-base-uncased'
max_seq_len = 512 # BERT-BASE restriction
cache_dir = './pretrained-' + bert_type
tokenizer = BertTokenizer.from_pretrained(bert_type, cache_dir=cache_dir)

INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./pretrained-bert-base-uncased/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [8]:
class UbuntuCorpus(Dataset):
  def __init__(self, tokenizer, rootdir='./dialogs'):
    super(UbuntuCorpus, self).__init__()
    dialogs = []
    _cnt = 3000 # debug constant
    
    # punctuations signs after which we put [SEP] token
    punctuation_seps = ['?!', '!?', '?', '...', '. '] 
    
    qa_pairs = []
    
    for subdir in os.listdir(rootdir):
      for dialog in os.listdir(rootdir + '/' + subdir):
        path = rootdir + '/' + subdir +'/' + dialog 
        with open(path) as tsvfile:
          reader = csv.reader(tsvfile, delimiter='\t')
          rows = [(row[1], row[-1]) for row in reader]
          replicas = []
          authors = set()
          author = -1
          for row in rows:
            if author == row[0]:
              replicas[-1].append(row[1])
            else:
              author = row[0]
              authors.add(author)
              replicas.append([row[1]])
              
          '''
          Answer replic is a replic without ?
          Question replic is a replic with ? followed by answer replic
          
          Both must be longer than thr (after link replacemenets)
          
          And due to BERT restrictions in tokenized form shorter than max_seq_len
          '''
          
          for i in range(len(replicas)):
            replicas[i] = '[CLS] ' + remove_urls(' '.join(replicas[i]))
            
            codephrase = 'evilcyborgswillkillhumanity'
            
            sep_token = '[SEP]'
            
            for (ind, el) in enumerate(punctuation_seps):
              crouch = codephrase + f'{ind} ' + sep_token + ' '
              replicas[i] = replicas[i].replace(el, crouch)
            
            for (ind, el) in enumerate(punctuation_seps):
              replicas[i] = replicas[i].replace(codephrase + f'{ind}', el)
              
            
            if replicas[i].rstrip()[-len(sep_token):] != sep_token:
              replicas[i] = replicas[i] + ' ' + sep_token
          
          thr = 120
          
          for i in range(len(replicas) - 1):
            if replicas[i].count('?') > 0 and replicas[i + 1].count('?') == 0 \
              and min(len(replicas[i]), len(replicas[i + 1])) > thr \
              and len(tokenizer.tokenize(replicas[i])) <= max_seq_len \
              and len(tokenizer.tokenize(replicas[i + 1])) <= max_seq_len:
              qa_pairs.append([replicas[i], replicas[i + 1]])
              _cnt -= 1
              if _cnt <=0:
                break
          
          
    
          #for replica in replicas:
          #  print('>>>', replica)
          #  print()
          #print(authors)
          #print()
          #print()
        
        if _cnt <= 0:
          break
      if _cnt <=0:
          break
    '''for el in qa_pairs:
      print('>>', el[0])
      print('>>>', el[1])
      print()'''
    
    self.qa_pairs = qa_pairs
  
  def __len__(self):
    return len(self.qa_pairs)
  
  def __getitem__(self, ind):
    '''
    #self.tokenizer = ??
    answ = [self.qa_pairs[ind][1]]
    used = {ind}
    while len(answ) != batch_size:
      ind2 = random.randint(0, len(self) - 1)     
      if ind2 not in used:
        used.add(ind2)
        answ.append(self.qa_pairs[ind2][1])
    #print('#%^^&', answ)
    '''
    return (self.qa_pairs[ind][0], self.qa_pairs[ind][1])#answ)      
        
corpus = UbuntuCorpus(tokenizer) # full corpus, 1,917,802 qa pairs 
print(len(corpus))

3000


In [9]:
!nvidia-smi

Mon Jul 15 19:50:48 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.67       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce RTX 2070    On   | 00000000:06:00.0  On |                  N/A |
| 27%   45C    P8    22W / 175W |    357MiB /  7949MiB |     11%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [10]:
#pickle.dump(corpus, open( "./corpus.p", "wb" ))

In [11]:
item = corpus[random.randint(0, len(corpus) - 1)]
#print(item)
print(item[0])
print(tokenizer.tokenize(item[0]))
print('----------------')

print(item[1])

[CLS] I am attempting to install GnuGo & Quarry (a GUI to handle GnuGo); I was going to install the tarball, and had downloaded the tars for both, but then I found them both on aptitude.  [SEP] So I installed them (I think? [SEP] ) with aptitude's terminal UI.  [SEP] How do I use them now? [SEP]  typing in "quarry" gives me nothing.  [SEP] I am attempting to install GnuGo & Quarry (a GUI to handle GnuGo); I was going to install the tarball, and had downloaded the tars for both, but then I found them both on aptitude.  [SEP] So I installed them (I think? [SEP] ) with aptitude's terminal UI.  [SEP] How do I use them now? [SEP]  Typing in "quarry" gives me nothing.  [SEP] Can anyone help me understand Aptitude? [SEP]  kubuntu? [SEP]  I'm not using KDE, I use fluxbox.  [SEP] Mark, I'm using terminal.  [SEP] I've never used the gui, don't particularly want to.  [SEP] mark ryan: I just used sudo aptitude, found two programs I wanted (gnugo & quarry), downloaded them (or so it seems? [SEP] ) 

In [12]:
batch_size = 2
trainloader = DataLoader(corpus, batch_size=batch_size, shuffle=True)
batch = None
for el in trainloader:
  batch = el
  break

print(batch)

[("[CLS] what version of ubuntu is this ? [SEP]  right click on the pannel and try to re-add gnome-network manager or launch it manually from the command line, see if it shows you why it's crashing [SEP]", '[CLS] howso? [SEP]  just before you said it does matter, i figured id put the cd i used to install 10.04 to begin with in, and tr it [SEP]'), ('[CLS] it start without problem but in wired connection and wireless network not appear nothing about adapters.  [SEP] they are disabled (but working at least eth0) [SEP]', "[CLS] If it's 10.10 Desktop that makes it easier because 10.10 has a loopback.cfg: [link] .  [SEP] If it's an Alternate install iso then you'll have problems with any version number of Ubuntu. [SEP]")]


In [13]:
def prepare_batch(batch):
  (quests, answs) = batch
  quests = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(el)) for el in quests]
  answs = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(el)) for el in answs]
  
  quest_segments = [torch.tensor([[0 for i in range(len(quests[j]))]]) for j in range(len(quests))]
  answ_segments = [torch.tensor([[0 for i in range(len(answs[j] ))]]) for j in range(len(answs))]
  
  quests = [torch.tensor([el]) for el in quests]
  answs = [torch.tensor([el]) for el in answs]
  
  return ((quests, quest_segments), (answs, answ_segments))
  
prepare_batch(batch)
0

0

In [14]:
def get_embedding(embeddings):
  '''
  using default bert-as-service strategy to get fixed-size vector
  1. considering only -2 layer
  2. "REDUCE_MEAN 	take the average of the hidden state of encoding layer on the time axis" @bert-as-service  
  '''
  embeddings = embeddings[-2]
  result = torch.sum(embeddings, dim=1)
  
  return result.to(device)

def embed_batch(batch, qembedder, aembedder):
  ((quests, quest_segments), (answs, answ_segments)) = batch
  
  #print(quests[0])
  #print(quest_segments[0])
  
  tmp_quest = [get_embedding(qembedder(quests[i].to(device), quest_segments[i].to(device))[0]) for i in range(len(quests))]
  tmp_answ = [get_embedding(aembedder(answs[i].to(device), answ_segments[i].to(device))[0]) for i in range(len(answs))]
  
  qembeddings = torch.cat(tmp_quest)
  aembeddings = torch.cat(tmp_answ)
    
  return (qembeddings, aembeddings)

#embed_batch(prepare_batch(batch), qembedder, aembedder)
0

0

In [15]:
def hinge_loss(X, Y, margin=0.1):
  batch_size = X.shape[0]
  similarities = cosine_similarity_table(X, Y)
  #^ см. ниже
  
  identity = torch.eye(batch_size, device=X.device)
  non_diagonal = torch.ones_like(similarities) - identity
  
  targets = identity - non_diagonal
  weights = identity + non_diagonal / (batch_size - 1)
  
  #всё то же самое, но лосс другой: учитываем только то, что не превосходит margin
  losses = torch.pow(F.relu(margin - targets * similarities), 2)
  return torch.mean(losses * weights)

def cosine_similarity_table(X, Y):
  X = F.normalize(X)
  Y = F.normalize(Y)
  return torch.mm(X, Y.transpose(0, 1))

In [16]:
def bce_loss(X, Y, conf_true=0.9, conf_false=0.1): 
  '''на вход пришел батч размера n,
  мы векторизовали контексты (X)
  и ответы (Y) и хотим сделать n*n
  независимых классификаций
  '''
  n = X.shape[0]

  logits = torch.mm(X, Y.transpose(0, 1)) # считаем таблицу умножения
  identity = torch.eye(n, device=X.device)
  
  non_diagonal = torch.ones_like(logits) - identity
  targets = identity * conf_true + non_diagonal * conf_false
  #получаем матрицу с conf_true на диагонали и conf_false где-либо ещё
  
  weights = identity + non_diagonal / (n - 1)
  # ^ чтобы не было перекоса в сторону негативов
  return F.binary_cross_entropy_with_logits(logits, targets, weights) * n

In [17]:
def calc_acc(X, Y):
    '''на вход пришел батч размера n,
    мы векторизовали контексты (X)
    и ответы (Y)'''
    
    csim = cosine_similarity_table(X, Y)
    confidence, predictions = csim.max(-1)
    avg = confidence.mean().item()
    predictions = list(predictions.cpu())
    right = 0
    for i in range(len(predictions)):
        right += predictions[i] == i
    return right
    
X = torch.tensor([[0, .1], [.1, 0]], device=device)
Y = torch.tensor([[0, .1], [.1, 0]], device=device)

print(calc_acc(X, Y))

X = torch.tensor([[0, .1], [.1, 0]], device=device)
Y = torch.tensor([[.1, 0], [0, .1]], device=device)

print(calc_acc(X, Y))

tensor(2, dtype=torch.uint8)
tensor(0, dtype=torch.uint8)


In [18]:
def get_optimizer_params(model):
  param_optimizer = list(model.named_parameters())
  no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
  
  optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
  ]
  
  return optimizer_grouped_parameters

#get_optimizer_params(qembedder)
0

0

In [19]:
test_size = int(len(corpus) * .33)
train_size = len(corpus) - test_size
train_corpus, test_corpus = torch.utils.data.random_split(corpus, [train_size, test_size])

In [20]:
def train(epochs):
  batch_size = 15
  trainloader = DataLoader(corpus, batch_size=batch_size, shuffle=True)
  num_train_optimization_steps = len(corpus) * epochs
  
  '''
  optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=5e-5,
                                 warmup=0.1,
                                 t_total=num_train_optimization_steps)
  
  '''
  
  lr = 5e-5
  warmup = 0.1
  
  qoptim = BertAdam(get_optimizer_params(qembedder),
                                  lr=lr,
                                  warmup=warmup,
                                  t_total=num_train_optimization_steps)
  aoptim = BertAdam(get_optimizer_params(aembedder),
                                  lr=lr,
                                  warmup=warmup,
                                  t_toal=num_train_optimization_steps)
  criterion = hinge_loss
  
  total = right = 0
  with torch.no_grad():
    for batch in trainloader:
        total += len(batch[0])
        embeddings = embed_batch(prepare_batch(batch), qembedder, aembedder)
        right += calc_acc(*embeddings) 

  qembedder.train()
  aembedder.train()
  
  logger.info("***** Running training *****")
  logger.info("  Num steps = %d", num_train_optimization_steps)  
  logger.info(f" right: {right} of {total}")
  
  start_training = time.time()
  for epoch in range(epochs):
    total_loss = 0
    start_epoch = time.time()
    qembedder.train()
    aembedder.train()
    for bidx, batch in enumerate(trainloader):
      qoptim.zero_grad()
      aoptim.zero_grad()
      print('batch_index', bidx)
      embeddings = embed_batch(prepare_batch(batch), qembedder, aembedder)
      loss = bce_loss(*embeddings)
      total_loss += loss.item()
      loss.backward()

      qoptim.step()
      aoptim.step()
    
    end_epoch = time.time()
    
    total = right = 0
    qembedder.eval()
    aembedder.eval()
    with torch.no_grad():
        for batch in trainloader:
            total += len(batch[0])
            embeddings = embed_batch(prepare_batch(batch), qembedder, aembedder)
            right += calc_acc(*embeddings) 

    logger.info(f'epoch {epoch} loss: {total_loss} time: {int(end_epoch - start_epoch)}')
    logger.info(f" right: {right} of {total}")
    
  end_training = time.time()
  logger.info(f'Training is compleated time: {int(end_training - start_training)}')
  torch.cuda.empty_cache()

In [21]:
qembedder = BertModel.from_pretrained(bert_type, cache_dir=cache_dir).to(device)
aembedder = BertModel.from_pretrained(bert_type, cache_dir=cache_dir).to(device)
train(10)

INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./pretrained-bert-base-uncased/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
INFO:pytorch_pretrained_bert.modeling:extracting archive file ./pretrained-bert-base-uncased/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpfl9w6j_n
INFO:pytorch_pretrained_bert.modeling:Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.am

batch_index 0
batch_index 1
batch_index 2
batch_index 3
batch_index 4
batch_index 5
batch_index 6
batch_index 7
batch_index 8
batch_index 9
batch_index 10
batch_index 11
batch_index 12
batch_index 13
batch_index 14
batch_index 15
batch_index 16
batch_index 17
batch_index 18
batch_index 19
batch_index 20
batch_index 21
batch_index 22
batch_index 23
batch_index 24
batch_index 25
batch_index 26
batch_index 27
batch_index 28
batch_index 29
batch_index 30
batch_index 31
batch_index 32
batch_index 33
batch_index 34
batch_index 35
batch_index 36
batch_index 37
batch_index 38
batch_index 39
batch_index 40
batch_index 41
batch_index 42
batch_index 43
batch_index 44
batch_index 45
batch_index 46
batch_index 47
batch_index 48
batch_index 49
batch_index 50
batch_index 51
batch_index 52
batch_index 53
batch_index 54
batch_index 55
batch_index 56
batch_index 57
batch_index 58
batch_index 59
batch_index 60
batch_index 61
batch_index 62
batch_index 63
batch_index 64
batch_index 65
batch_index 66
batch

INFO:__main__:epoch 0 loss: 1347817.1661987305 time: 241
INFO:__main__: right: 204 of 3000


batch_index 0
batch_index 1
batch_index 2
batch_index 3
batch_index 4
batch_index 5
batch_index 6
batch_index 7
batch_index 8
batch_index 9
batch_index 10
batch_index 11
batch_index 12
batch_index 13
batch_index 14
batch_index 15
batch_index 16
batch_index 17
batch_index 18
batch_index 19
batch_index 20
batch_index 21
batch_index 22
batch_index 23
batch_index 24
batch_index 25
batch_index 26
batch_index 27
batch_index 28
batch_index 29
batch_index 30
batch_index 31
batch_index 32
batch_index 33
batch_index 34
batch_index 35
batch_index 36
batch_index 37
batch_index 38
batch_index 39
batch_index 40
batch_index 41
batch_index 42
batch_index 43
batch_index 44
batch_index 45
batch_index 46
batch_index 47
batch_index 48
batch_index 49
batch_index 50
batch_index 51
batch_index 52
batch_index 53
batch_index 54
batch_index 55
batch_index 56
batch_index 57
batch_index 58
batch_index 59
batch_index 60
batch_index 61
batch_index 62
batch_index 63
batch_index 64
batch_index 65
batch_index 66
batch

INFO:__main__:epoch 1 loss: 130649.66033935547 time: 238
INFO:__main__: right: 205 of 3000


batch_index 0
batch_index 1
batch_index 2
batch_index 3
batch_index 4
batch_index 5
batch_index 6
batch_index 7
batch_index 8
batch_index 9
batch_index 10
batch_index 11
batch_index 12
batch_index 13
batch_index 14
batch_index 15
batch_index 16
batch_index 17
batch_index 18
batch_index 19
batch_index 20
batch_index 21
batch_index 22
batch_index 23
batch_index 24
batch_index 25
batch_index 26
batch_index 27
batch_index 28
batch_index 29
batch_index 30
batch_index 31
batch_index 32
batch_index 33
batch_index 34
batch_index 35
batch_index 36
batch_index 37
batch_index 38
batch_index 39
batch_index 40
batch_index 41
batch_index 42
batch_index 43
batch_index 44
batch_index 45
batch_index 46
batch_index 47
batch_index 48
batch_index 49
batch_index 50
batch_index 51
batch_index 52
batch_index 53
batch_index 54
batch_index 55
batch_index 56
batch_index 57
batch_index 58
batch_index 59
batch_index 60
batch_index 61
batch_index 62
batch_index 63
batch_index 64
batch_index 65
batch_index 66
batch

INFO:__main__:epoch 2 loss: 111859.43228149414 time: 237
INFO:__main__: right: 201 of 3000


batch_index 0
batch_index 1
batch_index 2
batch_index 3
batch_index 4
batch_index 5
batch_index 6
batch_index 7
batch_index 8
batch_index 9
batch_index 10
batch_index 11
batch_index 12
batch_index 13
batch_index 14
batch_index 15
batch_index 16
batch_index 17
batch_index 18
batch_index 19
batch_index 20
batch_index 21
batch_index 22
batch_index 23
batch_index 24
batch_index 25
batch_index 26
batch_index 27
batch_index 28
batch_index 29
batch_index 30
batch_index 31
batch_index 32
batch_index 33
batch_index 34
batch_index 35
batch_index 36
batch_index 37
batch_index 38
batch_index 39
batch_index 40
batch_index 41
batch_index 42
batch_index 43
batch_index 44
batch_index 45
batch_index 46
batch_index 47
batch_index 48
batch_index 49
batch_index 50
batch_index 51
batch_index 52
batch_index 53
batch_index 54
batch_index 55
batch_index 56
batch_index 57
batch_index 58
batch_index 59
batch_index 60
batch_index 61
batch_index 62
batch_index 63
batch_index 64
batch_index 65
batch_index 66
batch

INFO:__main__:epoch 3 loss: 93859.02690124512 time: 235
INFO:__main__: right: 197 of 3000


batch_index 0
batch_index 1
batch_index 2
batch_index 3
batch_index 4
batch_index 5
batch_index 6
batch_index 7
batch_index 8
batch_index 9
batch_index 10
batch_index 11
batch_index 12
batch_index 13
batch_index 14
batch_index 15
batch_index 16
batch_index 17
batch_index 18
batch_index 19
batch_index 20
batch_index 21
batch_index 22
batch_index 23
batch_index 24
batch_index 25
batch_index 26
batch_index 27
batch_index 28
batch_index 29
batch_index 30
batch_index 31
batch_index 32
batch_index 33
batch_index 34
batch_index 35
batch_index 36
batch_index 37
batch_index 38
batch_index 39
batch_index 40
batch_index 41
batch_index 42
batch_index 43
batch_index 44
batch_index 45
batch_index 46
batch_index 47
batch_index 48
batch_index 49
batch_index 50
batch_index 51
batch_index 52
batch_index 53
batch_index 54
batch_index 55
batch_index 56
batch_index 57
batch_index 58
batch_index 59
batch_index 60
batch_index 61
batch_index 62
batch_index 63
batch_index 64
batch_index 65
batch_index 66
batch

INFO:__main__:epoch 4 loss: 94527.99389648438 time: 233
INFO:__main__: right: 204 of 3000


batch_index 0
batch_index 1
batch_index 2
batch_index 3
batch_index 4
batch_index 5
batch_index 6
batch_index 7
batch_index 8
batch_index 9
batch_index 10
batch_index 11
batch_index 12
batch_index 13
batch_index 14
batch_index 15
batch_index 16
batch_index 17
batch_index 18
batch_index 19
batch_index 20
batch_index 21
batch_index 22
batch_index 23
batch_index 24
batch_index 25
batch_index 26
batch_index 27
batch_index 28
batch_index 29
batch_index 30
batch_index 31
batch_index 32
batch_index 33
batch_index 34
batch_index 35
batch_index 36
batch_index 37
batch_index 38
batch_index 39
batch_index 40
batch_index 41
batch_index 42
batch_index 43
batch_index 44
batch_index 45
batch_index 46
batch_index 47
batch_index 48
batch_index 49
batch_index 50
batch_index 51
batch_index 52
batch_index 53
batch_index 54
batch_index 55
batch_index 56
batch_index 57
batch_index 58
batch_index 59
batch_index 60
batch_index 61
batch_index 62
batch_index 63
batch_index 64
batch_index 65
batch_index 66
batch

INFO:__main__:epoch 5 loss: 76324.4669342041 time: 240
INFO:__main__: right: 209 of 3000


batch_index 0
batch_index 1
batch_index 2
batch_index 3
batch_index 4
batch_index 5
batch_index 6
batch_index 7
batch_index 8
batch_index 9
batch_index 10
batch_index 11
batch_index 12
batch_index 13
batch_index 14
batch_index 15
batch_index 16
batch_index 17
batch_index 18
batch_index 19
batch_index 20
batch_index 21
batch_index 22
batch_index 23
batch_index 24
batch_index 25
batch_index 26
batch_index 27


RuntimeError: CUDA out of memory. Tried to allocate 90.00 MiB (GPU 0; 7.76 GiB total capacity; 6.57 GiB already allocated; 49.06 MiB free; 59.19 MiB cached)