In [None]:
%cd /kaggle/input/tran-qswin-distill
!pip uninstall timm --yes
import math
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader,RandomSampler,SequentialSampler
from torch.utils.data import Subset

from main import * 
import argparse
import quant_swin_transformer
# from losses import DistillationLoss
from engine import evaluate
import swin_transformer

import timm
from timm.models import create_model
from timm.optim import create_optimizer
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy
from timm.data import Mixup
import utils
from typing import Iterable, Optional


# Build the run configuration from the project's argument parser. An empty
# argv list is parsed, so every option takes its declared default and the
# kernel's own command-line arguments are ignored.
parser = argparse.ArgumentParser(parents=[get_args_parser()])
args = parser.parse_args(args=[])

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### datasets

In [None]:
# Preprocessing shared by the train and validation splits: resize to 224,
# convert to a tensor, then normalise each channel. No random augmentation
# (crop / horizontal flip) is applied.
preprocess_steps = [
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762]),
]
transform = transforms.Compose(preprocess_steps)

# CIFAR-100 is downloaded into the working directory on first use.
dataset_train = datasets.CIFAR100(
    root='/kaggle/working/data', train=True, download=True, transform=transform
)
dataset_val = datasets.CIFAR100(
    root='/kaggle/working/data', train=False, download=True, transform=transform
)

# Notebook-level overrides of the parsed defaults.
args.batch_size = 32
args.pin_mem = False

# Shuffle the training split; walk the validation split in order.
sampler_train = RandomSampler(dataset_train)
sampler_val = SequentialSampler(dataset_val)

# Training loader: incomplete final batches are dropped.
data_loader_train = DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    sampler=sampler_train,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
)

# Validation loader: 1.5x the training batch size, every sample is kept.
data_loader_val = DataLoader(
    dataset_val,
    batch_size=int(1.5 * args.batch_size),
    sampler=sampler_val,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=False,
)

### load teacher_model

In [None]:
# Teacher network: a Swin-Tiny model requested from timm with pretrained
# weights and a 100-class head. It is only ever used in eval mode.
teacher_model_name = "swin_tiny_patch4_window7_224"
teacher_model = create_model(teacher_model_name, pretrained=True, num_classes=100)

teacher_model.to(device)
teacher_model.eval()

# Report the teacher's size for comparison with the student.
teacher_model_params = sum(map(lambda tensor: tensor.numel(), teacher_model.parameters()))
print(f"Total number of parameters in teacher_model: {teacher_model_params}")

In [None]:
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """
    Combine a supervised criterion with a feature-matching distillation term.

    The teacher is run on the raw inputs and the second element of the output
    of ``teacher_model.layers[-1].blocks[1].attn`` is captured with a forward
    hook. That captured feature is compared against the student's feature
    using ``s_t_criterion`` and blended with the base loss:

        loss = (1 - alpha) * base_loss + alpha * s_t_loss

    ``distillation_type`` and ``tau`` are validated/stored but are not read by
    ``forward``; the feature term is applied regardless of their values.
    """
    def __init__(self, base_criterion: torch.nn.Module, s_t_criterion: torch.nn.Module,teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        """
        Args:
            base_criterion: loss applied to (student outputs, labels).
            s_t_criterion: callable applied to (student feature, teacher feature).
            teacher_model: model exposing ``layers[-1].blocks[1].attn``.
            distillation_type: one of 'none', 'soft', 'hard' (stored only).
            alpha: weight of the student/teacher feature term.
            tau: temperature (stored only).
        """
        super().__init__()
        self.base_criterion = base_criterion
        self.s_t_criterion = s_t_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels, model_feature):
        """
        Args:
            inputs: the original inputs, fed to the teacher model
            outputs: the outputs of the model being trained, passed to the
                base criterion together with ``labels``
            labels: the labels for the base criterion
            model_feature: the student feature compared against the teacher's

        Returns:
            The weighted sum of the base loss and the feature-matching loss.

        Raises:
            RuntimeError: if the teacher forward pass did not trigger the hook.
        """
        captured = {}

        def _capture_teacher_feature(module, input, output):
            # The hooked module returns a sequence; its second element is the
            # feature used for distillation.
            captured['feature'] = output[1]

        hooked_module = self.teacher_model.layers[-1].blocks[1].attn
        handle = hooked_module.register_forward_hook(_capture_teacher_feature)
        try:
            base_loss = self.base_criterion(outputs, labels)

            # The teacher's logits are not needed; the pass only fires the hook.
            with torch.no_grad():
                self.teacher_model(inputs)
        finally:
            # Remove the hook again so that repeated calls do not pile up one
            # extra hook (and closure) on the teacher per batch.
            handle.remove()

        if 'feature' not in captured:
            raise RuntimeError("teacher forward pass did not produce a feature for distillation")
        teacher_feature = captured['feature']

        s_t_loss = self.s_t_criterion(model_feature, teacher_feature)

        loss = base_loss * (1 - self.alpha) + s_t_loss * self.alpha
        return loss

In [None]:
# Student network: the quantized Swin Transformer with a 100-class head.
model = quant_swin_transformer.SwinTransformer(num_classes=100)
model.to(device)

n_parameters = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in model: {n_parameters}")

# Linear LR scaling (lr * batch_size * world_size / 512). It has to be applied
# BEFORE the optimizer is built: the optimizer is constructed from `args`, so a
# later change to `args.lr` would not reach its parameter groups.
# The unscaled value is remembered so that re-running this cell does not
# compound the scaling.
if not hasattr(args, 'unscaled_lr'):
    args.unscaled_lr = args.lr
linear_scaled_lr = args.unscaled_lr * args.batch_size * utils.get_world_size() / 512.0
args.lr = linear_scaled_lr

optimizer = create_optimizer(args, model)
loss_scaler = NativeScaler()
lr_scheduler, _ = create_scheduler(args, optimizer)

# Supervised loss on the labels, wrapped together with an MSE feature-matching
# term against the teacher.
base_criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
s_t_criterion = F.mse_loss
criterion = DistillationLoss(
    base_criterion, s_t_criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
)

output_dir = Path(args.output_dir)

In [None]:
def train_one_epoch(model: torch.nn.Module, teacher_model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True,print_freq = 10):
    """
    Train the student for one pass over ``data_loader``.

    A forward hook on ``model.layers[-1].blocks[1].attn`` captures the second
    element of that module's output on every student forward pass; the
    captured feature is handed to ``criterion`` alongside the logits.

    Args:
        model: student model being optimised.
        teacher_model: teacher model; only switched to eval mode here (the
            criterion is what actually runs it).
        criterion: called as ``criterion(samples, outputs, targets, feature)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        optimizer: optimizer stepping the student's parameters.
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        loss_scaler, max_norm, model_ema, mixup_fn: accepted for interface
            compatibility; not used by this implementation.
        set_training_mode: value passed to ``model.train``.
        print_freq: logging interval in iterations.

    Returns:
        Dict mapping each meter name to its global average for the epoch.
    """
    model.train(set_training_mode)
    teacher_model.eval()
    student_feature = None

    def hook_fn_1(module, input, output):
        # Rebind the enclosing function's variable via `nonlocal`.
        nonlocal student_feature
        # Keep the feature produced by the last layer's hooked attention module.
        student_feature = output[1]

    hook_handle = model.layers[-1].blocks[1].attn.register_forward_hook(hook_fn_1)

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    try:
        for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
            samples = samples.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # The forward pass fires the hook, refreshing `student_feature`.
            outputs = model(samples)
            loss = criterion(samples, outputs, targets,student_feature)
            loss_value = loss.item()
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Synchronising is only meaningful (and only legal) with CUDA.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            metric_logger.update(loss=loss_value)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    finally:
        # Detach the hook so each epoch does not leave another one on the model.
        hook_handle.remove()

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

In [None]:
def _save_checkpoint(checkpoint_path, epoch):
    """Write the current model/optimizer/scheduler/scaler state to ``checkpoint_path``."""
    utils.save_on_master({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }, checkpoint_path)


# Make sure the checkpoint/log directory exists before the first write.
if args.output_dir:
    output_dir.mkdir(parents=True, exist_ok=True)

# Start of the wall-clock measurement reported after the loop.
start_time = time.time()
max_accuracy = 0.0
for epoch in range(args.start_epoch, args.epochs):

    train_stats = train_one_epoch(
        model, teacher_model, criterion, data_loader_train,
        optimizer, device, epoch, loss_scaler,
        args.clip_grad, None, None,
        set_training_mode=args.finetune == '',  # keep in eval mode during finetuning
        print_freq = 100
    )

    lr_scheduler.step(epoch)

    # Rolling checkpoint, overwritten after every epoch.
    if args.output_dir:
        _save_checkpoint(output_dir / 'checkpoint.pth', epoch)

    test_stats = evaluate(data_loader_val, model, device)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

    # Separate checkpoint for the best top-1 accuracy seen so far.
    if max_accuracy < test_stats["acc1"]:
        max_accuracy = test_stats["acc1"]
        if args.output_dir:
            _save_checkpoint(output_dir / 'best_checkpoint.pth', epoch)

    print(f'Max accuracy: {max_accuracy:.2f}%')

    log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                 **{f'test_{k}': v for k, v in test_stats.items()},
                 'epoch': epoch,
                 'n_parameters': n_parameters}

    # Append one JSON line per epoch to the run log.
    if args.output_dir and utils.is_main_process():
        with (output_dir / "log.txt").open("a") as f:
            f.write(json.dumps(log_stats) + "\n")

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))