In [6]:
import json

from tokenizers.implementations import BertWordPieceTokenizer
from Bert import Bert
from layers import Train
import torch
from torch.nn import Linear
if torch.cuda.is_available():
    device = torch.device("cuda:0")  # first visible GPU
else:
    device = torch.device("cpu")  # CPU fallback when no CUDA device is present

In [7]:
_tokenizer = BertWordPieceTokenizer("../custom/vocab.txt")
_embedding_dim = 384
_hidden_size = 3072
_num_head = 12
_out_dim = 512  # NOTE(review): not referenced in the visible cells
# Fourth positional argument of Bert; presumably the maximum sequence length — confirm against Bert's signature.
_max_seq_len = 128
max_epoch = 17
batch = 140
_num_layers = 12
vocab_size = _tokenizer.get_vocab_size()
bert = Bert(_embedding_dim, _hidden_size, _num_head, _max_seq_len, _num_layers, _tokenizer)
# map_location lets weights that were saved from a GPU load on a CPU-only machine,
# matching the CPU fallback chosen for `device` in the imports cell.
bert.load_state_dict(torch.load("../bert_impl_weights/bert.pth", map_location=device))

<All keys matched successfully>

In [8]:
from torch.nn.utils.rnn import pad_sequence

# Each non-blank line of the training file is "<sentence>;<emotion label>".
with open("../bert_impl_data/emo_train.txt", "r", encoding="utf-8") as fp:
    data = fp.readlines()

sentence = []
labels = []
for line_no, line in enumerate(data, start=1):
    if not line.strip():
        continue  # tolerate blank / trailing lines instead of crashing
    # rpartition takes the LAST ';' as the separator, so a ';' inside the sentence stays in the text.
    text, sep, lab = line.rstrip("\n").rpartition(";")
    if not sep:
        raise ValueError(f"emo_train.txt line {line_no}: expected '<sentence>;<label>', got {line!r}")
    sentence.append(text)
    labels.append(lab)
if not sentence:
    raise ValueError("emo_train.txt contains no examples")

# Sort the label names so the class-index mapping is the same on every run;
# iterating a set of str depends on per-process hash randomisation.
label = {name: idx for idx, name in enumerate(sorted(set(labels)))}
labels = torch.tensor([label[lab] for lab in labels])

# Tokenise in one batch call, then right-pad with id 0 to the longest sentence
# and move the whole matrix to the device in a single transfer.
sentence = [torch.tensor(enc.ids, dtype=torch.long) for enc in _tokenizer.encode_batch(sentence)]
max_len = max(len(ids) for ids in sentence)
sentence = pad_sequence(sentence, batch_first=True).to(device)

# Persist the label -> index mapping so inference code can decode predictions.
with open("../bert_impl_data/emo_class.json", "w", encoding='utf-8') as f:
    json.dump(label, f, indent=4)
sentence

tensor([[    2,    50,  7569,  ...,     0,     0,     0],
        [    2,    50,  1935,  ...,     0,     0,     0],
        [    2,  2399, 12844,  ...,     0,     0,     0],
        ...,
        [    2,    50,  5271,  ...,     0,     0,     0],
        [    2,    50,  5271,  ...,     0,     0,     0],
        [    2,    50,  3821,  ...,     0,     0,     0]], device='cuda:0')

In [9]:
# Validation split: same "<sentence>;<label>" line format; reuses the label -> index
# mapping (`label`) built from the training split.
with open("../bert_impl_data/emo_val.txt", "r", encoding="utf-8") as fp:
    data = fp.readlines()

sentence_val = []
labels_val = []
for line_no, line in enumerate(data, start=1):
    if not line.strip():
        continue  # tolerate blank / trailing lines
    # The label is the text after the LAST ';' so sentences may themselves contain ';'.
    text, sep, lab = line.rstrip("\n").rpartition(";")
    if not sep:
        raise ValueError(f"emo_val.txt line {line_no}: expected '<sentence>;<label>', got {line!r}")
    if lab not in label:
        raise KeyError(f"emo_val.txt line {line_no}: label {lab!r} does not occur in the training split")
    sentence_val.append(text)
    labels_val.append(label[lab])
if not sentence_val:
    raise ValueError("emo_val.txt contains no examples")
labels_val = torch.tensor(labels_val)

# Tokenise in one batch and right-pad with id 0 to the longest sentence, exactly
# like the training split (no extra padding column that could exceed the model's max length).
sentence_val = [torch.tensor(enc.ids, dtype=torch.long) for enc in _tokenizer.encode_batch(sentence_val)]
max_len = max(len(ids) for ids in sentence_val)
sentence_val = pad_sequence(sentence_val, batch_first=True).to(device)
sentence_val

tensor([[    2,  2399, 10133,  ...,     0,     0,     0],
        [    2,    50,  5271,  ...,     0,     0,     0],
        [    2,    50,  5271,  ...,     0,     0,     0],
        ...,
        [    2,    50,  5271,  ...,     0,     0,     0],
        [    2,    50, 13385,  ...,     0,     0,     0],
        [    2,    50,  5271,  ...,     0,     0,     0]], device='cuda:0')

In [10]:
import time


class Emo_trainer(Train):
    """Fine-tunes a BERT encoder together with a classification head on the emotion data.

    Builds on the project's ``Train`` helper (progress printing, TensorBoard); the
    training loop itself is :meth:`down_stream`.
    """

    # Checkpoint files, written whenever an epoch's mean training loss improves.
    LAYER_CKPT_PATH = '../bert_impl_weights/down_stream_bert_emo_layer.pth'
    MODEL_CKPT_PATH = '../bert_impl_weights/down_stream_bert_emo.pth'

    def __init__(self, model, optimizer):
        super().__init__(model, optimizer)

    def _save_emo_checkpoint(self, layer):
        """Best-effort save of the head and encoder weights; errors are printed, not raised."""
        try:
            torch.save(layer.state_dict(), self.LAYER_CKPT_PATH)
            torch.save(self._model.state_dict(), self.MODEL_CKPT_PATH)
        except Exception as e:  # a failed disk write should not abort a long training run
            print(e)

    def down_stream(self, batch_size, max_epoch, layer: torch.nn.Module, log=True, log_dir=None,
                    Tensorboard_reloadInterval=30, log_file_name='', monitor=True, pick_params=False,
                    train_question=None, train_answer=None, test_question=None, test_answer=None):
        """Fine-tune ``self._model`` and ``layer`` with cross-entropy on the position-0 output.

        Args:
            batch_size: examples per optimisation step; a trailing partial batch is skipped.
            max_epoch: number of passes over the training split.
            layer: classification head applied to the encoder output at position 0
                (where the WordPiece tokenizer places [CLS]).
            log: call ``open_tensorboard`` before training.
            log_dir, Tensorboard_reloadInterval, log_file_name: forwarded to ``open_tensorboard``.
            monitor, pick_params: accepted for backward compatibility; unused.
            train_question, train_answer, test_question, test_answer: padded token-id tensors
                and integer label tensors. They default to the notebook globals ``sentence``,
                ``labels``, ``sentence_val`` and ``labels_val`` respectively.

        Raises:
            ValueError: if the training split has fewer than ``batch_size`` examples.
        """
        from torch.cuda.amp import GradScaler, autocast

        # Fall back to the tensors built by the data-loading cells.
        if train_question is None:
            train_question = sentence
        if train_answer is None:
            train_answer = labels
        if test_question is None:
            test_question = sentence_val
        if test_answer is None:
            test_answer = labels_val

        max_iter = len(train_question) // batch_size
        if max_iter == 0:
            raise ValueError(f"batch_size={batch_size} exceeds the training set size ({len(train_question)})")

        # Mixed precision only makes sense on CUDA; disabled scaler/autocast are pass-throughs on CPU.
        use_amp = device.type == 'cuda'
        scaler = GradScaler(enabled=use_amp)
        loss_func = torch.nn.CrossEntropyLoss()
        begin = time.time()
        loss = 0
        count = 0
        best_loss = float('inf')

        if log:
            self.open_tensorboard(log_dir, Tensorboard_reloadInterval, f"({log_file_name})")
        try:
            for epoch in range(max_epoch):
                # Validation below may switch the model to eval mode; restore training mode each epoch.
                self._model.train()
                epoch_loss = 0.0
                for it in range(max_iter):
                    lo = it * batch_size
                    hi = lo + batch_size
                    batch_question = train_question[lo:hi].to(device=device)
                    answer = train_answer[lo:hi].to(device=device)
                    mask = batch_question != 0  # token id 0 is padding

                    self._optimizer.zero_grad()
                    with autocast(enabled=use_amp):
                        hidden = self._model.forward(batch_question, mask)
                        # Classify from position 0 only; the head is position-wise, so this equals
                        # applying it to every position and then selecting position 0.
                        logits = layer(hidden[:, 0, :])
                        loss = loss_func(logits, answer)
                    scaler.scale(loss).backward()
                    scaler.step(self._optimizer)
                    scaler.update()

                    loss = loss.item()
                    epoch_loss += loss
                    self.print_result((epoch, max_epoch), (it, max_iter), loss, begin=begin, timing=True)

                mean_loss = epoch_loss / max_iter
                if self.writer:
                    count += 1
                    # NOTE(review): presumably validation accuracy; higher is better — confirm in Bert.down_stream.
                    correctness = self._model.down_stream(test_question, test_answer, batch_size, layer)
                    self.writer.add_scalar("loss", mean_loss, count)
                    self.writer.add_scalar("correctness", correctness, count)
                self.print_result((epoch, max_epoch), (max_iter, max_iter), loss, begin=begin, timing=True)

                # Keep only the epoch with the lowest mean training loss on disk.
                if mean_loss <= best_loss:
                    best_loss = mean_loss
                    self._save_emo_checkpoint(layer)

            self.print_result((max_epoch, max_epoch), (max_iter, max_iter), loss, begin=begin, timing=True)
        finally:
            # Close the writer and stop TensorBoard even if training is interrupted.
            if self.writer is not None and self.tensorboard_process is not None:
                self.writer.close()
                self.tensorboard_process.terminate()
bert.train()
# Classification head: one logit per emotion class from the encoder's hidden vector.
# Its input size must equal the width `bert` was built with, so reuse the config constant.
layers = Linear(_embedding_dim, len(label), device=bert.device)
# Encoder and head are fine-tuned jointly with a single optimizer.
optimizer = torch.optim.Adam(list(layers.parameters()) + list(bert.parameters()), lr=1e-4)
trainer = Emo_trainer(bert, optimizer)
trainer.add_bar('Epoch', 'Iter')
trainer.add_metrics('loss', float)
trainer.down_stream(batch, max_epoch, layers, log_dir="Bert_down_stream_emo", log=True,
                    log_file_name="Bert_down_stream_emo", monitor=False)

copy to run: tensorboard --logdir=C:\Users\123\PycharmProjects\torch-models\bert_impl\Bert_down_stream_emo --port=6006 --reload_interval=30
 ▏Epoch: │████████████████████│ 100.00% ▏Iter: │████████████████████│ 100.00% ▏Time: 2min23s ▏loss: 0.05397 