In [1]:
import torch
import os
import random
import numpy as np

config = {
    'train_file_path': 'data/train.json',
    'dev_file_path': 'data/dev.json',
    'test_file_path': 'data/test.json',
    'output_path': '.',
    'model_path': '../NLP_Project/dataset/BERT_model/',
    'batch_size': 16,
    'num_epochs': 1,
    'max_seq_len': 64,
    'learning_rate': 2e-5,
    'eps': 0.1,
    'alpha': 0.3,
    'adv': 'fgm',
    'warmup_ratio': 0.05,
    'weight_decay': 0.01,
    'use_bucket': True,
    'bucket_multiplier': 200,
    'device': 'cuda',
    'n_gpus': 0,
    'use_amp': True,
    'logging_step': 400,
    'ema_start_step': 500,
    'ema_start': False,
    'seed': 2021
}
if not torch.cuda.is_available():
    config['device'] = 'cpu'
else:
    config['n_gpus'] = torch.cuda.device_count()
    config['batch_size'] *= config['n_gpus']
    
if not os.path.exists(config['output_path']):
    os.makedirs((config['output_path']))
    
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed

seed_everything(config['seed'])

2021

In [2]:
from tqdm.notebook import tqdm
import json
import pandas as pd

def parse_data(path, data_type='train'):
    sentence_a, sentence_b, labels = [], [], []
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f.readlines(), desc=f'Reading {data_type} data'):
            line = json.loads(line)
            sentence_a.append(line['sentence1'])
            sentence_b.append(line['sentence2'])
            if data_type != 'test':
                labels.append(int(line['label']))
            else:
                labels.append(0)
    df = pd.DataFrame(zip(sentence_a, sentence_b, labels), columns=['text_a', 'text_b', 'labels'])
    
    return df

In [3]:
# inputs: defaultdict(list)
def builder_bert_inputs(inputs, label, sentence_a, sentence_b, tokenizer):
    # add_special_tokens [CLS] [SEP]
    # return token_type_ids sentence_a 0, sentence_b 1
    # return_attention_mask 不是pad的部分1 是pad的部分0
    inputs_dict = tokenizer.encode_plus(sentence_a, sentence_b, add_special_tokens=True,
                                        return_token_type_ids=True, return_attention_mask=True)
    inputs['input_ids'].append(inputs_dict['input_ids'])
    inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
    inputs['attention_mask'].append(inputs_dict['attention_mask'])
    inputs['labels'].append(label)

In [4]:
from collections import defaultdict

def read_data(config, tokenizer):
    train_df = parse_data(config['train_file_path'], data_type='train')
    dev_df = parse_data(config['dev_file_path'], data_type='dev')
    test_df = parse_data(config['test_file_path'], data_type='test')
    
    data_df = {'train': train_df, 'dev': dev_df, 'test': test_df}
    # 保存BERT的输入
    processed_data = {}
    for data_type, df in data_df.items():
        inputs = defaultdict(list)
        for i, row in tqdm(df.iterrows(), desc=f'Preprocessing {data_type} data', total=len(df)):
            sentence_a, sentence_b, label = row[0], row[1], row[2]
            builder_bert_inputs(inputs, label, sentence_a, sentence_b, tokenizer)
        processed_data[data_type] = inputs
        
    return processed_data

In [5]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(config['model_path'])
dt = read_data(config, tokenizer)

Reading train data:   0%|          | 0/34334 [00:00<?, ?it/s]

Reading dev data:   0%|          | 0/4316 [00:00<?, ?it/s]

Reading test data:   0%|          | 0/3861 [00:00<?, ?it/s]

Preprocessing train data:   0%|          | 0/34334 [00:00<?, ?it/s]

Preprocessing dev data:   0%|          | 0/4316 [00:00<?, ?it/s]

Preprocessing test data:   0%|          | 0/3861 [00:00<?, ?it/s]

In [6]:
print(dt['train']['input_ids'][0])

[101, 6010, 6009, 955, 1446, 5023, 7583, 6820, 3621, 1377, 809, 2940, 2768, 1044, 2622, 1400, 3315, 1408, 102, 955, 1446, 3300, 1044, 2622, 1168, 3309, 6820, 3315, 1408, 102]


In [7]:
from torch.utils.data import Dataset

class AFQMCDataset(Dataset):
    def __init__(self, data_dict):
        super(AFQMCDataset, self).__init__()
        self.data_dict = data_dict
    
    def __getitem__(self, idx):
        example = (self.data_dict['input_ids'][idx],
                   self.data_dict['token_type_ids'][idx],
                   self.data_dict['attention_mask'][idx],
                   self.data_dict['labels'][idx])
        return example
    
    def __len__(self):
        return len(self.data_dict['input_ids'])

In [8]:
class Collator:
    def __init__(self, max_seq_len, tokenizer):
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenizer
        
    def pad_and_truncate(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):
        input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
        token_type_ids = torch.zeros_like(input_ids)
        attention_mask = torch.zeros_like(input_ids)
        
        for i in range(len(input_ids_list)):
            seq_len = len(input_ids_list[i])
            
            if seq_len <= max_seq_len:
                input_ids[i, :seq_len] = torch.tensor(input_ids_list[i], dtype=torch.long)
                token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i], dtype=torch.long)
                attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i], dtype=torch.long)
            else:
                # input_ids 最后一位上放一个特殊的token
                input_ids[i] = torch.tensor(input_ids_list[i][:max_seq_len-1] + [self.tokenizer.sep_token_id], dtype=torch.long)
                token_type_ids[i] = torch.tensor(token_type_ids_list[i][:max_seq_len], dtype=torch.long)
                attention_mask[i] = torch.tensor(attention_mask_list[i][:max_seq_len], dtype=torch.long)
        labels = torch.tensor(labels_list, dtype=torch.long)
        
        return input_ids, token_type_ids, attention_mask, labels
        
    def __call__(self, examples):
        input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))
        cur_max_seq_len = max(len(input_id) for input_id in input_ids_list)
        max_seq_len = min(cur_max_seq_len, self.max_seq_len)
        input_ids, token_type_ids, attention_mask, labels = self.pad_and_truncate(input_ids_list,
                                                                                  token_type_ids_list,
                                                                                  attention_mask_list,
                                                                                  labels_list, max_seq_len)
        data_dict = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
        
        return data_dict

In [9]:
collate_fn = Collator(config['max_seq_len'], tokenizer)

### Sampler Summary
- 所有采样器都继承Sampler这个类
- 需要实现`__iter__`, `__len__`方法

##### SequentialSampler
- 在初始化时拿到数据集, 按顺序对元素采样, 每次只返回一个索引

In [10]:
# 模拟数据 (2, 3, 4) -> batch_size 2, seq_len 3, embedding_dim 4 每个batch2条数据, 每个句子3个词, 每个词维度4
a = torch.randperm(60).reshape(5, 3, 4)
# print(a)
b = torch.utils.data.SequentialSampler(a)
for i in b:
    print(i)

0
1
2
3
4


##### RandomSampler
- replacement: True表示可以重复采样(类似有放回)
- num_samples: 指定采样的数量
- 当replacement=False时不应指定num_samples

In [11]:
a = torch.randperm(60).reshape((5, 3, 4))
b = torch.utils.data.RandomSampler(a, replacement=True, num_samples=3)
for i in b:
    print(i)

0
4
0


##### SubsetRandomSampler
- Samples elements randomly from a given list of indices, without replacement

In [12]:
a = torch.arange(5)
print(a)
b = torch.utils.data.SubsetRandomSampler(indices=a[2:])
for i in b:
    print(i)

tensor([0, 1, 2, 3, 4])
tensor(4)
tensor(3)
tensor(2)


##### BatchSampler
- sampler: 基采样器
- batch_size
- drop_last: True如果一个batch的长度小于batch_size则丢弃

In [13]:
a = torch.randperm(60).reshape((5, 3, 4))
base_b = torch.utils.data.SequentialSampler(a)
for i in base_b:
    print(i)
b = torch.utils.data.BatchSampler(base_b, 2, drop_last=True)
for i in b:
    print(i)

0
1
2
3
4
[0, 1]
[2, 3]


##### BucketBatchSampler
- Dataset ->RandomSampler-> 关于Dataset的随机序列 ->BatchSampler-> min(n*batch_size, len(Dataset)) <br/> 得到bucket ->SortedSampler-> 得到bucket的排序索引 ->BatchSampler-> 得到若干batch_size大小的小bucket <br/> ->随机抽取小bucket->返回小bucket在大bucket中大位置

In [14]:
# source code
import math
from torch.utils.data import Sampler ,BatchSampler, RandomSampler, SubsetRandomSampler

class SortedSampler(Sampler):
    def __init__(self, data, sort_key):
        super().__init__(data)
        self.sort_key = sort_key
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indice = [item[0] for item in zip_]
        
    def __iter__(self):
        return iter(self.sorted_indice)
    
    def __len__(self):
        return len(self.data)
    
        
class BucketBatchSampler(BatchSampler):
    def __init__(self, sampler, batch_size, drop_last, sort_key, bucket_size_multiplier=100):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        self.bucket_sampler = BatchSampler(sampler, min(batch_size * bucket_size_multiplier, len(sampler)), False)
        
    def __iter__(self):
        for bucket in self.bucket_sampler:
            sorted_sampler = SortedSampler(bucket, self.sort_key)
            for batch in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                yield [bucket[i] for i in batch]
    
    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / batch_size)

In [15]:
mini_dataset = {k: v[:6] for k, v in dt['train'].items()}
mini_data = AFQMCDataset(mini_dataset)
# mini_data前6条数据
for i, d in enumerate(mini_data):
    print(len(d[0]))

30
24
25
24
59
28


In [16]:
from extra_file.bucket_sampler import SortedSampler

random_sampler = RandomSampler(mini_data, replacement=False)
print(list(random_sampler))

[5, 3, 1, 0, 4, 2]


In [17]:
batch_sampler = BatchSampler(random_sampler, 4, drop_last=True)
for sample in batch_sampler:
    print(sample)
    sorted_sampler = SortedSampler(sample, sort_key=lambda x: len(mini_data[x][0]))
    print(list(sorted_sampler))

[4, 5, 0, 1]
[3, 1, 2, 0]


```
30 24 25 24 59 28
[1, 5, 2, 3]
[0, 3, 2, 1]:
    0->1->24
    3->3->24
    2->2->25
    1->5->28
```

In [18]:
c = list(BatchSampler(sorted_sampler, 2, drop_last=True))
print(c)

[[3, 1], [2, 0]]


In [19]:
for batch in SubsetRandomSampler(c):
    print('从给定的索引列表中随机采样元素')
    print(batch)
    print('所对应的原序列是')
    print([sample[i] for i in batch])

从给定的索引列表中随机采样元素
[3, 1]
所对应的原序列是
[1, 5]
从给定的索引列表中随机采样元素
[2, 0]
所对应的原序列是
[0, 4]


In [20]:
from torch.utils.data import DataLoader

def build_dataloader(config, data, collate_fn):
    train_dataset = AFQMCDataset(data['train'])
    dev_dataset = AFQMCDataset(data['dev'])
    test_dataset = AFQMCDataset(data['test'])
    
    if config['use_bucket']:
        train_sampler = RandomSampler(train_dataset)
        bucket_sampler = BucketBatchSampler(train_sampler, batch_size=config['batch_size'],
                                            drop_last=True, sort_key=lambda x: len(train_dataset[x][0]),
                                            bucket_size_multiplier=config['bucket_multiplier'])
        train_dataloader = DataLoader(train_dataset, batch_sampler=bucket_sampler,
                                      num_workers=4, collate_fn=collate_fn)
    else:
        train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                                      num_workers=4, collate_fn=collate_fn)
    
    dev_dataloader = DataLoader(dev_dataset, batch_size=config['batch_size'], 
                                     shuffle=False, num_workers=4, collate_fn=collate_fn)
    
    test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], 
                                     shuffle=False, num_workers=4, collate_fn=collate_fn)
    
    return train_dataloader, dev_dataloader, test_dataloader

In [21]:
train_dataloader, dev_dataloader, test_dataloader = build_dataloader(config, dt, collate_fn)

In [22]:
import warnings
warnings.filterwarnings('ignore')

for i in train_dataloader:
    print(i)
    break

{'input_ids': tensor([[ 101, 5709, 1446,  ..., 2512, 1510,  102],
        [ 101, 2769, 4638,  ...,  809, 1408,  102],
        [ 101,  711,  784,  ..., 4638, 1660,  102],
        ...,
        [ 101, 3221, 1728,  ..., 3326,  749,  102],
        [ 101, 2769,  955,  ...,  955, 1446,  102],
        [ 101, 3221,  679,  ..., 6820, 1408,  102]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
        1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0,
        1, 1, 0

##### 混合精度训练

作用: 训练时尽量不降低性能, 并提升速度 <br/>

Float16优点:
- 减少内存的使用
- 加快训练和推理计算

Float16缺点:
- 溢出错误
- 舍入误差

当进入autocast()时, 系统自动切换为Float16, autocast上下文只包含前向传播

scaler.scale(loss)将损失乘以缩放器当前比例因子, 进行反向传播

scaler.step(optimizer)取消缩放梯度并调用optimizer.step()

scaler.update()更新缩放器的比例因子


In [23]:
from sklearn.metrics import f1_score, accuracy_score

def evaluation(config, model, val_dataloader):
    model.eval()
    preds, labels, val_loss = [], [], 0.
    val_iterator = tqdm(val_dataloader, desc='Evaluation', total=len(val_dataloader))
    
    with torch.no_grad():
        for batch in val_iterator:
            labels.append(batch['labels'])
            batch_cuda = {item: value.to(config['device']) for item, value in list(batch.items())}
            loss, logits = model(**batch_cuda)[:2]
            
            if config['n_gpus'] > 1:
                loss = loss.mean()
                
            val_loss += loss.item()
            preds.append(logits.argmax(dim=-1).detach().cpu())
    
    avg_val_loss = val_loss / len(val_dataloader)
    labels = torch.cat(labels, dim=0).numpy()
    preds = torch.cat(preds, dim=0).numpy()
    f1 = f1_score(labels, preds)
    acc = accuracy_score(labels, preds)
    
    return avg_val_loss, f1, acc

In [24]:
class EMA:
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        self.register()

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

In [25]:
from transformers import BertForSequenceClassification
from torch.cuda import amp
from torch.optim import AdamW
from extra_file.extra_pgd import *
from extra_file.extra_fgm import *
from extra_file.extra_loss import *
from extra_file.extra_optim import *
from tqdm.notebook import trange

def train(config, train_dataloader,dev_dataloader):
    # 封装好 BertForSequenceClassification
    model = BertForSequenceClassification.from_pretrained(config['model_path'])
    param_optimizer = list(model.named_parameters())
    # 实例化scaler对象使用梯度缩放
    scaler = amp.GradScaler(enabled=config['use_amp'])
    # 权重缩减
    no_decay = ['bias', 'LayerNorm.weight']
    # ['bias', 'LayerNorm.weight']权重衰减因子为0
    # 其它权重衰减因子为0.01
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weigth_decay': config['weight_decay']},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate'], eps=1e-8)
    # Lookahead 预先查看AdamW生成的快权重来选择搜索方向
    optimizer = Lookahead(optimizer, 5, 1)
    total_steps = config['num_epochs'] * len(train_dataloader)
    # 使用warmup调整学习率
    lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(config['warmup_ratio'] * total_steps),
                                        t_total=total_steps)
    model.to(config['device'])
    
    if config['adv'] == 'fgm':
        fgm = FGM(model)
    else:
        pgd = PGD(model)
        K = 3
        
    epoch_iterator = trange(config['num_epochs'])
    global_steps, train_loss, logging_loss, best_acc, best_model_path = 0, 0., 0., 0., ''
    
    if config['n_gpus'] > 1:
        model = nn.DataParallel(model)
        
    for _ in epoch_iterator:
        train_iterator = tqdm(train_dataloader, desc='Training', total=len(train_dataloader))
        model.train()
        for batch in train_iterator:
            batch_cuda = {item: value.to(config['device']) for item, value in list(batch.items())}
            # 前向传播
            with amp.autocast(enabled=config['use_amp']):
                loss = model(**batch_cuda)[0]
                # 多卡取平均
                if config['n_gpus'] > 1:
                    loss = loss.mean()
                    
            scaler.scale(loss).backward()
            
            if config['adv'] == 'fgm':
                # 在embedding上加扰动
                fgm.attack(epsilon=config['eps'])
                # autocast
                with amp.autocast(enabled=config['use_amp']):
                    loss_adv = model(**batch_cuda)[0]
                    if config['n_gpus'] > 1:
                        loss_adv = loss_adv.mean()
                
                scaler.scale(loss_adv).backward()
                # 恢复embedding参数
                fgm.restore()
            else:
                pgd.backup_grad()
                for t in range(K):
                    pgd.attack(epsilon=config['eps'], alpha=config['alpha'], is_first_attack=(t == 0))
                    if t != K - 1:
                        model.zero_grad()
                    else:
                        pgd.restore_grad()
                    with amp.autocast(enabled=config['use_amp']):
                        loss_adv = model(**batch_cuda)[0]
                        if config['n_gpus'] > 1:
                            loss_adv = loss_adv.mean()

                    scaler.scale(loss_adv).backward()
                pgd.restore()
            
            scaler.step(optimizer)
            scaler.update()
            
            lr_scheduler.step()
            optimizer.zero_grad()
            
            if config['ema_start']:
                ema.update()

            train_loss += loss.item()
            global_steps += 1

            train_iterator.set_postfix_str(f'running training loss: {loss.item():.4f}')

            if global_steps % config['logging_step'] == 0:
                if global_steps >= config['ema_start_step'] and not config['ema_start']:
                    print('\n>>> EMA starting ...')
                    config['ema_start'] = True
                    
                    ema = EMA(model.module if hasattr(model, 'module') else model, decay=0.999)

                print_train_loss = (train_loss - logging_loss) / config['logging_step']
                logging_loss = train_loss


                if config['ema_start']:
                    ema.apply_shadow()
                val_loss, f1, acc = evaluation(config, model, dev_dataloader)

                print_log = f'\n>>> training loss: {print_train_loss:.6f}, valid loss: {val_loss:.6f}, '

                if acc > best_acc:
                    model_save_path = os.path.join(config['output_path'],
                                                   f'checkpoint-{global_steps}-{acc:.6f}')
                    model_to_save = model.module if hasattr(model, 'module') else model
                    model_to_save.save_pretrained(model_save_path)
                    best_acc = acc
                    best_model_path = model_save_path

                print_log += f'valid f1: {f1:.6f}, valid acc: {acc:.6f}'

                print(print_log)
                model.train()

                if config['ema_start']:
                    ema.restore()

    return model, best_model_path

In [26]:
train(config, train_dataloader, dev_dataloader)

Some weights of the model checkpoint at ../NLP_Project/dataset/BERT_model/ were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the mo

  0%|          | 0/1 [00:00<?, ?it/s]

Training:   0%|          | 0/536 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/68 [00:00<?, ?it/s]


>>> training loss: 0.582753, valid loss: 0.531598, valid f1: 0.483151, valid acc: 0.705051


(DataParallel(
   (module): BertForSequenceClassification(
     (bert): BertModel(
       (embeddings): BertEmbeddings(
         (word_embeddings): Embedding(21128, 768, padding_idx=1)
         (position_embeddings): Embedding(512, 768)
         (token_type_embeddings): Embedding(2, 768)
         (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
         (dropout): Dropout(p=0.1, inplace=False)
       )
       (encoder): BertEncoder(
         (layer): ModuleList(
           (0): BertLayer(
             (attention): BertAttention(
               (self): BertSelfAttention(
                 (query): Linear(in_features=768, out_features=768, bias=True)
                 (key): Linear(in_features=768, out_features=768, bias=True)
                 (value): Linear(in_features=768, out_features=768, bias=True)
                 (dropout): Dropout(p=0.1, inplace=False)
               )
               (output): BertSelfOutput(
                 (dense): Linear(in_features=768, out_