import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertForSequenceClassification, BertConfig, AdamW
from transformers import get_linear_schedule_with_warmup
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision import transforms

from vit_model import vit_base_patch16_224_in21k as student_module
from vit_model import vit_base_patch16_224_in21k_teacher as teacher_module
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate

# Device
device = torch.device('cpu')

# Define models
# The teacher should be initialized with pre-trained weights and fine-tuned on the
# downstream task; those steps are omitted here for brevity.
teacher_model = teacher_module(num_classes=5, has_logits=False).to(device)
student_model = student_module(num_classes=5, has_logits=False).to(device)

# Dataset
train_images_path, train_images_label, val_images_path, val_images_label = \
    read_split_data(r"D:\swin\data\flower_photos")

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

dataset = MyDataSet(images_path=train_images_path,
                    images_class=train_images_label,
                    transform=data_transform["train"])
eval_dataset = MyDataSet(images_path=val_images_path,
                         images_class=val_images_label,
                         transform=data_transform["val"])

# Training loader used by the distiller, plus a validation loader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, pin_memory=True,
                        num_workers=0, collate_fn=dataset.collate_fn)
val_loader = DataLoader(eval_dataset, batch_size=8, shuffle=False, pin_memory=True,
                        num_workers=0, collate_fn=dataset.collate_fn)

num_epochs = 10
num_training_steps = len(dataloader) * num_epochs

# Optimizer and learning-rate scheduler
optimizer = AdamW(student_model.parameters(), lr=1e-4)
scheduler_class = get_linear_schedule_with_warmup
# scheduler arguments dict, everything except 'optimizer'
scheduler_args = {'num_warmup_steps': int(0.1 * num_training_steps),
                  'num_training_steps': num_training_steps}

# Display model parameter statistics
print("\nteacher_model's parameters:")
result, _ = textbrewer.utils.display_parameters(teacher_model, max_level=3)
print(result)
print("student_model's parameters:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)


def simple_adaptor(batch, model_outputs):
    # Assumes the model outputs expose the pre-softmax logits at index 1 and the
    # per-layer hidden states at index 2 (as in the original BERT-based example).
    # Image batches carry no attention mask, so no 'inputs_mask' key is returned.
    return {'logits': model_outputs[1], 'hidden': model_outputs[2]}


# Define callback function (run on the student model during distillation)
def predict(model, eval_dataset, step, device):
    '''
    eval_dataset: validation dataset
    '''
    model.eval()
    pred_logits = []
    label_ids = []
    dataloader = DataLoader(eval_dataset, batch_size=8, collate_fn=eval_dataset.collate_fn)
    for images, labels in dataloader:
        images = images.to(device)
        with torch.no_grad():
            outputs = model(images)
        # take the logits whether the model returns them directly or inside a tuple
        logits = outputs[1] if isinstance(outputs, (tuple, list)) else outputs
        cpu_logits = logits.detach().cpu()
        for i in range(len(cpu_logits)):
            pred_logits.append(cpu_logits[i].numpy())
            label_ids.append(int(labels[i]))
    pred_logits = np.array(pred_logits)
    label_ids = np.array(label_ids)
    accuracy = (pred_logits.argmax(axis=-1) == label_ids).mean()
    print("step {}: eval accuracy = {:.4f}".format(step, accuracy))
    model.train()
    return accuracy


from functools import partial

callback_fun = partial(predict, eval_dataset=eval_dataset, device=device)  # fill other arguments

# Initialize configurations and distiller
train_config = TrainingConfig(device=device)
distill_config = DistillationConfig(
    temperature=8,
    hard_label_weight=0,
    kd_loss_type='ce',
    probability_shift=False,
    intermediate_matches=[
        # Loss names follow TextBrewer's predefined intermediate matching losses;
        # adjust them (and the layer indices) to the models being distilled.
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': [0, 0], 'layer_S': [0, 0], 'feature': 'hidden', 'loss': 'nst', 'weight': 1},
        {'layer_T': [8, 8], 'layer_S': [2, 2], 'feature': 'hidden', 'loss': 'nst', 'weight': 1}]
)

print("train_config:")
print(train_config)
print("distill_config:")
print(distill_config)

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

# Start distilling
with distiller:
    distiller.train(optimizer,
                    num_epochs=num_epochs,
                    dataloader=dataloader,
                    scheduler_class=scheduler_class,
                    scheduler_args=scheduler_args,
                    callback=callback_fun)
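
# Optional sketch (not part of the original script): simple_adaptor assumes the
# models return something like (loss, logits, hidden_states). If the local
# vit_model's forward only returns logits, a thin wrapper along these lines could
# collect per-block hidden states via forward hooks. The attribute name `.blocks`
# and the (None, logits, hidden_states) output layout are assumptions that should
# be checked against vit_model.py before use.
import torch.nn as nn


class ViTWithHiddenStates(nn.Module):
    def __init__(self, vit):
        super().__init__()
        self.vit = vit
        self._hidden = []
        # assumes the backbone keeps its transformer blocks in `vit.blocks`
        for blk in self.vit.blocks:
            blk.register_forward_hook(lambda module, inputs, output: self._hidden.append(output))

    def forward(self, x, labels=None):
        # `labels` is accepted (and ignored) in case the whole (images, labels)
        # batch is unpacked positionally into the model call.
        self._hidden = []
        logits = self.vit(x)
        # mirror the (_, logits, hidden_states) layout that simple_adaptor indexes
        return None, logits, tuple(self._hidden)

# e.g. student_model = ViTWithHiddenStates(student_module(num_classes=5, has_logits=False)).to(device)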