In [2]:
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
import numpy as np
import os
import math 
import glob
import pandas as pd
from PIL import Image 
from PIL import ImageFile
from PIL import ImageEnhance
ImageFile.LOAD_TRUNCATED_IMAGES = True
import sklearn
from torchvision import transforms
from torch.utils.data import Dataset,Subset,DataLoader
import matplotlib.pyplot as plt
import pickle
from sklearn.model_selection import train_test_split
from torchvision.transforms import Compose, Resize, ToTensor
import random
import timm
import torchvision.transforms as transforms
from einops.layers.torch import Rearrange
from timm.models.efficientnet_blocks import SqueezeExcite, DepthwiseSeparableConv
from typing import Type


In [3]:
SEED = 2023
# Seed every RNG the notebook touches (Python `random`, NumPy, PyTorch; torch.manual_seed
# also seeds all CUDA devices) so shuffling, augmentation and weight init are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Current Device : {device}')

Current Device : cuda


In [None]:
# Flip + rotation augmentation only.
# NOTE(review): not referenced by any visible cell; the same two ops also appear
# inside vit_train_transform_fn.
enhance_1 = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.6),
        transforms.RandomRotation(degrees=30),
    ]
)

In [None]:
class BirdDataset(Dataset):
    """Image-folder dataset: one sub-directory per class containing ``*.jpg`` files.

    Class directories are sorted by name before labels are assigned, so two instances
    built from different splits (e.g. ``train`` and ``valid``) give the same integer
    label to the same class name. Only the first ``max_classes`` classes and, per class,
    the first ``max_images_per_class`` images (sorted path order) are kept.

    Args:
        dataset_path: root directory of one split; a falsy value gives an empty dataset.
        transform_fn: callable applied to every loaded RGB PIL image.
        enhance_path: unused; kept for backward compatibility.
        max_classes: number of class directories to keep (default 73).
        max_images_per_class: cap on the images taken from each class (default 200).
    """

    def __init__(self, dataset_path, transform_fn, enhance_path=None,
                 max_classes=73, max_images_per_class=200):
        self.dataset_path = dataset_path
        self.transform = transform_fn
        self.label_idx2name = {}
        self.label_name2idx = {}
        self.img2label = {}
        self.img_path = []
        if not dataset_path:
            return
        # os.listdir order is filesystem-dependent: sort it so label indices are stable
        # across runs and identical between splits. Skip stray files (not classes).
        class_names = sorted(
            name for name in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, name))
        )[:max_classes]
        self.label_idx2name = np.array(class_names)
        for label, class_name in enumerate(class_names):
            self.label_name2idx[class_name] = label
            # Escape the directory part so names containing glob metacharacters match literally.
            pattern = os.path.join(glob.escape(os.path.join(dataset_path, class_name)), "*.jpg")
            images = sorted(glob.glob(pattern))[:max_images_per_class]
            self.img_path.extend(images)
            for image in images:
                self.img2label[image] = label

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, index):
        path = self.img_path[index]
        label = self.img2label[path]
        # convert() loads the pixel data, so the file can be closed right after.
        with Image.open(path) as image:
            img = image.convert("RGB")
        img = self.transform(img)
        return (img, label)

In [None]:
# ImageNet per-channel (RGB) mean/std, applied after ToTensor() has scaled pixels to [0, 1].
channel_mean = torch.Tensor([0.485, 0.456, 0.406])
channel_std = torch.Tensor([0.229, 0.224, 0.225])


def _imagenet_pipeline(pre=(), mid=()):
    """Compose ``pre`` -> Resize(256) -> CenterCrop(224) -> ``mid`` -> ToTensor -> Normalize."""
    return transforms.Compose([
        *pre,
        transforms.Resize(256),
        transforms.CenterCrop(224),
        *mid,
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])


# Training: colour jitter on the raw image, then random flip/rotation on the 224x224 crop.
vit_train_transform_fn = _imagenet_pipeline(
    pre=[transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)],
    mid=[transforms.RandomHorizontalFlip(p=0.6), transforms.RandomRotation(degrees=30)],
)

# Deterministic pipelines (same steps) for un-augmented training data and validation.
train_transfrom_noaug_fn = _imagenet_pipeline()
vit_valid_transform_fn = _imagenet_pipeline()

In [None]:
train_dataset = BirdDataset(dataset_path='../input/100-bird-species/train', transform_fn=vit_train_transform_fn)
# The constructor already installs the validation transform; no reassignment is needed.
valid_dataset = BirdDataset(dataset_path='../input/100-bird-species/valid', transform_fn=vit_valid_transform_fn)
# Labels are positions in each split's class list, so both splits must list the same
# classes in the same order or every validation metric is meaningless.
assert list(train_dataset.label_idx2name) == list(valid_dataset.label_idx2name), \
    "train/valid class lists differ; integer labels would not line up"
# Prints "number of training images" / "number of test images".
print(f"训练集图片的个数为：{len(train_dataset)}")
print(f"测试集图片的个数为：{len(valid_dataset)}")

In [None]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True
)

# Validation order does not affect learning. A fixed order keeps per-epoch validation
# metrics (averaged per batch, including the smaller last batch) reproducible.
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=16,
    shuffle=False
)

In [None]:
def show_samples(batch_img, batch_label=None, num_samples=16, class_names=None):
    """Plot the first ``num_samples`` images of a normalised batch in a 4-column grid.

    Args:
        batch_img: images normalised with ``channel_mean``/``channel_std``; assumed to be
            (N, 3, H, W) — TODO confirm for non-default transforms.
        batch_label: optional integer labels; when given they become subplot titles.
        num_samples: number of images to draw; clipped to the batch size.
        class_names: indexable label -> name mapping. Defaults to
            ``train_dataset.label_idx2name`` (the original behaviour).
    """
    total_col = 4
    num_samples = min(num_samples, len(batch_img))
    total_row = max(1, math.ceil(num_samples / total_col))
    if batch_label is not None and class_names is None:
        class_names = train_dataset.label_idx2name

    # squeeze=False keeps `axs` 2-D even for a single row (otherwise axs[r, c] fails).
    fig, axs = plt.subplots(total_row, total_col, figsize=(12, 12), squeeze=False)

    for sample_idx, ax in enumerate(axs.flat):
        if sample_idx >= num_samples:
            ax.axis("off")  # hide empty cells in the last row
            continue
        img = batch_img[sample_idx].detach().cpu()
        # Undo the per-channel normalisation (works for any H x W) and clamp to the
        # [0, 1] range imshow expects for float images.
        img = img * channel_std.view(3, 1, 1) + channel_mean.view(3, 1, 1)
        ax.imshow(img.clamp(0, 1).permute(1, 2, 0).numpy())
        if batch_label is not None:
            ax.set_title(class_names[int(batch_label[sample_idx])])


batch_img, batch_label = next(iter(train_dataloader))
show_samples(batch_img, batch_label, 8)

# vit

In [None]:
class PretrainViT(nn.Module):
    """torchvision ViT-L/16 with pretrained weights and a new linear classification head.

    Every parameter outside the head (names not containing "heads") is frozen.

    Args:
        num_classes: output size of the new head. Defaults to 80, the previously
            hard-coded value. NOTE(review): the dataset cell keeps 73 classes, so some
            logits are never targeted.
    """

    def __init__(self, num_classes: int = 80):
        super(PretrainViT, self).__init__()
        model = models.vit_l_16(pretrained=True)
        num_classifier_feature = model.heads.head.in_features
        # Kept inside nn.Sequential so state_dict keys (heads.head.0.*) stay unchanged.
        model.heads.head = nn.Sequential(
            nn.Linear(num_classifier_feature, num_classes)
        )
        self.model = model

        # Freeze the backbone; only the classification head is trained.
        for name, param in self.model.named_parameters():
            if "heads" not in name:
                param.requires_grad = False

    def forward(self, x):
        return self.model(x)

In [None]:
net = PretrainViT()
net.to(device)
# Count only parameters that will receive gradients (the unfrozen head).
print(f"number of paramaters: {sum(p.numel() for p in net.parameters() if p.requires_grad)}")

In [None]:
criterion = nn.CrossEntropyLoss()
# NOTE(review): the ViT training cell replaces this optimizer (lr=0.001) before any
# training step runs.
optimizer = optim.Adam(params=net.parameters(), lr=0.009)

In [None]:
def get_accuracy(output, label):
    """Top-1 accuracy of a batch of logits.

    Args:
        output: logits of shape (batch, num_classes); the prediction is argmax over dim 1.
        label: integer class indices of shape (batch,), on the same device as ``output``.

    Returns:
        0-dim float tensor on the CPU holding the fraction of correct predictions.
    """
    # softmax is monotonic, so argmax of the raw logits picks the same class without
    # the extra exp/normalisation. Compare on the original device and move only the
    # scalar to the CPU instead of copying the whole logits matrix.
    predicted = output.argmax(dim=1)
    return (predicted == label).float().mean().cpu()

In [None]:
def train(model, dataloader, log_every=100):
    """Run one training epoch and return ``(mean_loss, mean_accuracy)`` over batches.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``device``. The optimizer
    must have been built over ``model``'s parameters.

    Args:
        model: network to train (switched to ``train()`` mode).
        dataloader: yields ``(images, labels)`` batches.
        log_every: print the loss/accuracy averaged over each run of this many batches.

    Returns:
        Mean per-batch loss (float) and mean per-batch accuracy.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    total_loss = 0.0
    running_acc = 0.0
    total_acc = 0.0
    num_batches = 0

    for batch_idx, (batch_img, batch_label) in enumerate(dataloader):
        batch_img = batch_img.to(device)
        batch_label = batch_label.to(device)

        optimizer.zero_grad()
        output = model(batch_img)
        loss = criterion(output, batch_label)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        total_loss += loss_value

        acc = get_accuracy(output, batch_label)
        running_acc += acc
        total_acc += acc
        num_batches += 1

        # Log after every `log_every` *completed* batches so the running sums really
        # cover `log_every` batches (checking batch_idx % 100 fired after 101 batches).
        if num_batches % log_every == 0:
            print(f"[step: {batch_idx:4d}/{len(dataloader)}] loss: {running_loss / log_every:.3f}, acc: {running_acc / log_every:.3f}")
            running_loss = 0.0
            running_acc = 0.0

    if num_batches == 0:
        raise ValueError("train() received a dataloader that yielded no batches")
    return total_loss / num_batches, total_acc / num_batches

In [None]:
def validate(model, dataloader):
    """Evaluate ``model`` and return ``(mean_loss, mean_accuracy)`` over batches.

    Uses the notebook-level ``criterion`` and ``device``. Runs under ``torch.no_grad()``
    so no autograd graph is recorded, which saves memory and compute during evaluation.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch_img, batch_label in dataloader:
            batch_img = batch_img.to(device)
            batch_label = batch_label.to(device)

            output = model(batch_img)
            total_loss += criterion(output, batch_label).item()
            total_acc += get_accuracy(output, batch_label)
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate() received a dataloader that yielded no batches")
    return total_loss / num_batches, total_acc / num_batches

In [None]:
class EarlyStopper:
    """Signal early stopping once the validation loss stops improving.

    A new minimum resets the counter. A loss more than ``min_delta`` above the best
    value counts as a bad epoch, and ``early_stop`` returns True once ``patience`` bad
    epochs have accumulated since the last minimum. Losses within ``min_delta`` of the
    best (or NaN) leave the counter unchanged.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True when training should stop."""
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        # Written as `not (a > b)` rather than `a <= b` so NaN behaves as before.
        if not (validation_loss > self.min_validation_loss + self.min_delta):
            return False
        self.counter += 1
        return self.counter >= self.patience


early_stopper = EarlyStopper(patience=3, min_delta=0.01)


In [None]:
# Train the ViT head; keep the checkpoint with the lowest validation loss so far.
train_loss_history, valid_loss_history = [], []
train_acc_history, valid_acc_history = [], []
optimizer = optim.Adam(net.parameters(), lr=0.001)
EPOCHS = 7
for epoch in range(EPOCHS):
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    for history, value in zip(
        (train_loss_history, valid_loss_history, train_acc_history, valid_acc_history),
        (train_loss, valid_loss, train_acc, valid_acc),
    ):
        history.append(value)

    # valid_loss is already in the history, so `<=` means "equals the best so far".
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "net.pt")
    if early_stopper.early_stop(valid_loss):
        break


In [None]:

epochs = range(1, len(train_loss_history) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: accuracy; right panel: loss. Each shows train vs. validation per epoch.
panels = (
    (axes[0], train_acc_history, valid_acc_history,
     'Vit Training and validation accuracy', 'Accuracy', 'upper left'),
    (axes[1], train_loss_history, valid_loss_history,
     'Vit Training and validation loss', 'Loss', 'upper right'),
)
for ax, train_curve, valid_curve, title, ylabel, legend_loc in panels:
    ax.plot(epochs, train_curve)
    ax.plot(epochs, valid_curve)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)

plt.show()

# 预训练MaxVit

In [None]:
import gc

# `del net` alone frees nothing: `optimizer` still references the model's parameters
# (plus Adam's moment buffers). Drop it too, then return cached GPU memory before the
# next model is built. Every following training cell creates a fresh optimizer.
del net, early_stopper, optimizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
net = timm.create_model(
    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k',
    pretrained=True,
    drop_rate=0.2,
)
torch.set_grad_enabled(True)
num_classes = 80

# Freeze all pretrained weights; the replacement classifier below is a new module and
# therefore trainable by default.
net.requires_grad_(False)
net.head.fc = nn.Linear(net.head.fc.in_features, num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)

In [None]:
early_stopper = EarlyStopper(patience=3, min_delta=0.01)
train_loss_history, valid_loss_history = [], []
train_acc_history, valid_acc_history = [], []
optimizer = optim.AdamW(net.parameters(), lr=0.0005)


def adjust_learning_rate(lr, epoch):
    """Return ``lr`` halved on every even, non-zero epoch, otherwise unchanged.

    It is called once per epoch with the previous epoch's rate, so the decay compounds:
    epochs 0-1 use the base rate, 2-3 half of it, 4-5 a quarter, and so on.
    """
    if epoch == 0 or epoch % 2 != 0:
        return lr
    return lr * 0.5


lr = 0.0005
EPOCHS = 7
for epoch in range(EPOCHS):
    lr = adjust_learning_rate(lr, epoch)
    # Push the scheduled rate into every optimizer parameter group.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    for history, value in zip(
        (train_loss_history, valid_loss_history, train_acc_history, valid_acc_history),
        (train_loss, valid_loss, train_acc, valid_acc),
    ):
        history.append(value)

    # Save when this epoch matches the best validation loss seen so far.
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "maxvit1-net.pt")
    if early_stopper.early_stop(valid_loss):
        break

In [None]:
epochs = range(1, len(train_loss_history) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: accuracy, right: loss — train vs. validation for each epoch.
panels = (
    (axes[0], train_acc_history, valid_acc_history,
     'maxvit Training and validation accuracy', 'Accuracy', 'upper left'),
    (axes[1], train_loss_history, valid_loss_history,
     'maxvit Training and validation loss', 'Loss', 'upper right'),
)
for ax, train_curve, valid_curve, title, ylabel, legend_loc in panels:
    ax.plot(epochs, train_curve)
    ax.plot(epochs, valid_curve)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)

plt.show()

# 预训练maxvit训练最后三层

In [None]:
import gc

# `del net` alone frees nothing: `optimizer` still references the model's parameters
# (plus its per-parameter state). Drop it too, then return cached GPU memory before the
# next model is built. Every following training cell creates a fresh optimizer.
del net, early_stopper, optimizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
net = timm.create_model(
    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k',
    pretrained=True,
    drop_rate=0.2,
)
torch.set_grad_enabled(True)
num_classes = 80

# Freeze everything, then re-enable gradients for stages[3] and the whole head.
net.requires_grad_(False)
# NOTE(review): stages[3] is presumably the final stage of this timm model; how many
# blocks it holds depends on the model config — confirm before relying on it.
net.stages[3].requires_grad_(True)
net.head.requires_grad_(True)

# New classifier for num_classes outputs (a fresh module, trainable by default).
net.head.fc = nn.Linear(net.head.fc.in_features, num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)

In [None]:
early_stopper = EarlyStopper(patience=3, min_delta=0.01)
train_loss_history, valid_loss_history = [], []
train_acc_history, valid_acc_history = [], []
optimizer = optim.Adam(net.parameters(), lr=0.001)
EPOCHS = 7
for epoch in range(EPOCHS):
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    for history, value in zip(
        (train_loss_history, valid_loss_history, train_acc_history, valid_acc_history),
        (train_loss, valid_loss, train_acc, valid_acc),
    ):
        history.append(value)

    # Checkpoint whenever this epoch ties or sets the best validation loss.
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "maxvit3-net.pt")
    if early_stopper.early_stop(valid_loss):
        break

In [None]:
epochs = range(1, len(train_loss_history) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: accuracy, right: loss — train vs. validation for each epoch.
panels = (
    (axes[0], train_acc_history, valid_acc_history,
     'maxvit Training and validation accuracy', 'Accuracy', 'upper left'),
    (axes[1], train_loss_history, valid_loss_history,
     'maxvit Training and validation loss', 'Loss', 'upper right'),
)
for ax, train_curve, valid_curve, title, ylabel, legend_loc in panels:
    ax.plot(epochs, train_curve)
    ax.plot(epochs, valid_curve)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)

plt.show()

# ResNet训练最后四层

In [None]:
import gc

# `del net` alone frees nothing: `optimizer` still references the model's parameters
# (plus its per-parameter state). Drop it too, then return cached GPU memory before the
# next model is built. Every following training cell creates a fresh optimizer.
del net, early_stopper, optimizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import torchvision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torchvision.models.resnet50(pretrained=True)

# Freeze the pretrained backbone, then re-enable the last residual stage (layer4).
for param in net.parameters():
    param.requires_grad = False
for param in net.layer4.parameters():
    param.requires_grad = True

# Replace the classifier *after* freezing so the new, randomly initialised head is
# trainable. Created before the freeze loop, it was frozen and could never learn.
in_features = net.fc.in_features
net.fc = torch.nn.Linear(in_features, num_classes)
net.to(device)

In [None]:
early_stopper = EarlyStopper(patience=3, min_delta=0.01)
train_loss_history, valid_loss_history = [], []
train_acc_history, valid_acc_history = [], []
optimizer = optim.Adam(net.parameters(), lr=0.001)

EPOCHS = 7
for epoch in range(EPOCHS):
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    for history, value in zip(
        (train_loss_history, valid_loss_history, train_acc_history, valid_acc_history),
        (train_loss, valid_loss, train_acc, valid_acc),
    ):
        history.append(value)

    # NOTE(review): "reset4" looks like a typo of "resnet4"; kept so existing
    # checkpoints and loaders still match.
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "reset4-net.pt")
    if early_stopper.early_stop(valid_loss):
        break

In [None]:
epochs = range(1, len(train_loss_history) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Both titles use the same "resnet50(4)" tag (layer4 fine-tuned). The previous loss
# title mixed in CJK text, which matplotlib's default font cannot render (boxes plus
# missing-glyph warnings) and which did not match the accuracy title.
panels = (
    (axes[0], train_acc_history, valid_acc_history,
     'resnet50(4) Training and validation accuracy', 'Accuracy', 'upper left'),
    (axes[1], train_loss_history, valid_loss_history,
     'resnet50(4) Training and validation loss', 'Loss', 'upper right'),
)
for ax, train_curve, valid_curve, title, ylabel, legend_loc in panels:
    ax.plot(epochs, train_curve)
    ax.plot(epochs, valid_curve)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)

plt.show()

In [None]:
import gc

# `del net` alone frees nothing: `optimizer` still references the model's parameters
# (plus its per-parameter state). Drop it too, then return cached GPU memory before the
# next model is built. Every following training cell creates a fresh optimizer.
del net, early_stopper, optimizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# (Leftover debug cell: a stray `print(1)` was removed; nothing to run here.)

# MBConv+vit

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from einops.layers.torch import Rearrange
from timm.models.efficientnet_blocks import SqueezeExcite, DepthwiseSeparableConv

from typing import Type
from torchvision import models
def _gelu_ignore_parameters(
        *args,
        **kwargs
) -> nn.Module:
    activation = nn.GELU()
    return activation

# MBConv block definition.
class MBConv(nn.Module):
    """Pre-norm MBConv block with a residual connection.

    Main path: norm -> 1x1 conv -> depthwise-separable conv (stride 2 when ``downscale``)
    -> squeeze-excite -> 1x1 conv, optionally followed by stochastic depth.
    Skip path: identity, or 2x2 max-pool + 1x1 conv when ``downscale``.

    Args:
        in_channels: input channels.
        out_channels: output channels; must equal ``in_channels`` unless ``downscale``.
        downscale: halve the spatial resolution.
        act_layer: activation class; ``nn.GELU`` is swapped for ``_gelu_ignore_parameters``.
        norm_layer: normalisation class.
        drop_path: stochastic-depth rate for the main path (0 disables it).
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            downscale: bool = False,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            drop_path: float = 0.,
    ) -> None:
        super(MBConv, self).__init__()
        self.drop_path_rate: float = drop_path
        if not downscale:
            # Without downscaling the skip path is an identity, so the channels must match.
            assert in_channels == out_channels, \
                "If downscaling is not utilized, input and output channels must be equal."
        if act_layer == nn.GELU:
            act_layer = _gelu_ignore_parameters
        self.main_path = nn.Sequential(
            norm_layer(in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1)),
            DepthwiseSeparableConv(in_chs=in_channels, out_chs=out_channels, stride=2 if downscale else 1,
                                   act_layer=act_layer, norm_layer=norm_layer, drop_path_rate=drop_path),
            SqueezeExcite(in_chs=out_channels, rd_ratio=0.25),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1))
        )
        self.skip_path = nn.Sequential(
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1))
        ) if downscale else nn.Identity()

    def forward(
            self,
            input: torch.Tensor
    ) -> torch.Tensor:
        output = self.main_path(input)
        if self.drop_path_rate > 0.:
            # This cell's imports do not provide `drop_path`; import it here so a
            # non-zero rate works instead of raising NameError.
            from timm.models.layers import drop_path
            output = drop_path(output, self.drop_path_rate, self.training)
        output = output + self.skip_path(input)
        return output
    
class con_vit(nn.Module):
    """MBConv stem (3 -> 3 channels, no downscale) followed by torchvision ViT-L/16.

    The ViT gets a new ``num_classes`` linear head, and every ViT parameter outside the
    head is frozen. The MBConv stem's parameters stay trainable.
    """

    def __init__(self, num_classes):
        super(con_vit, self).__init__()
        self.mbconv = MBConv(in_channels=3, out_channels=3, downscale=False)
        backbone = models.vit_l_16(pretrained=True)
        head_in = backbone.heads.head.in_features
        backbone.heads.head = nn.Sequential(nn.Linear(head_in, num_classes))
        # Freeze the ViT except for its classification head.
        for name, weight in backbone.named_parameters():
            if "heads" not in name:
                weight.requires_grad = False
        self.vit = backbone

    def forward(self, x):
        return self.vit(self.mbconv(x))
    
# `num_classes` comes from the pretrained-MaxViT cells above (set to 80 there).
net = con_vit(num_classes)
net.to(device)


In [None]:
early_stopper = EarlyStopper(patience=3, min_delta=0.01)
train_loss_history = []
valid_loss_history = []
train_acc_history = []
valid_acc_history = []
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Use a dedicated checkpoint file: the plain-ViT training cell already saves to
# "net.pt", and reusing that name silently overwrote the ViT's best weights.
CONVIT_CHECKPOINT = "convit-net.pt"

EPOCHS = 7
for epoch in range(EPOCHS):
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)

    train_acc_history.append(train_acc)
    valid_acc_history.append(valid_acc)

    # valid_loss is already in the history, so this fires when it equals the best so far.
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), CONVIT_CHECKPOINT)
    if early_stopper.early_stop(valid_loss):
        break

In [None]:
epochs = range(1, len(train_loss_history) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: accuracy, right: loss — train vs. validation for each epoch.
panels = (
    (axes[0], train_acc_history, valid_acc_history,
     'Convit Training and validation accuracy', 'Accuracy', 'upper left'),
    (axes[1], train_loss_history, valid_loss_history,
     'Convit Training and validation loss', 'Loss', 'upper right'),
)
for ax, train_curve, valid_curve, title, ylabel, legend_loc in panels:
    ax.plot(epochs, train_curve)
    ax.plot(epochs, valid_curve)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)
plt.show()

# Resnet50（只训练最后一层）

In [None]:
import gc

# `del net` alone frees nothing: `optimizer` still references the model's parameters
# (plus its per-parameter state). Drop it too, then return cached GPU memory before the
# next model is built. Every following training cell creates a fresh optimizer.
del net, early_stopper, optimizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
net = timm.create_model('resnet50', pretrained=True, drop_rate=0.2)
print('Classifier layer:', net.get_classifier())

# Freeze all pretrained weights; the replacement fc below is a new module and trainable.
net.requires_grad_(False)
net.fc = nn.Linear(net.fc.in_features, num_classes)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
net.to(device)

In [None]:
early_stopper = EarlyStopper(patience=3, min_delta=0.01)
train_loss_history, valid_loss_history = [], []
train_acc_history, valid_acc_history = [], []
optimizer = optim.Adam(net.parameters(), lr=0.001)

EPOCHS = 7
for epoch in range(EPOCHS):
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}")

    for history, value in zip(
        (train_loss_history, valid_loss_history, train_acc_history, valid_acc_history),
        (train_loss, valid_loss, train_acc, valid_acc),
    ):
        history.append(value)

    # Keep the weights of the best (lowest) validation loss seen so far.
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "resnet1-net.pt")
    if early_stopper.early_stop(valid_loss):
        break

In [None]:
epochs = range(1, len(train_loss_history) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Both titles use the same "resnet50(1)" tag (only the classifier is trained). The
# previous loss title mixed in CJK text, which matplotlib's default font cannot render
# and which did not match the accuracy title.
panels = (
    (axes[0], train_acc_history, valid_acc_history,
     'resnet50(1) Training and validation accuracy', 'Accuracy', 'upper left'),
    (axes[1], train_loss_history, valid_loss_history,
     'resnet50(1) Training and validation loss', 'Loss', 'upper right'),
)
for ax, train_curve, valid_curve, title, ylabel, legend_loc in panels:
    ax.plot(epochs, train_curve)
    ax.plot(epochs, valid_curve)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)
plt.show()

# maxvit（自己写）

In [None]:
import gc

# `del net` alone frees nothing: `optimizer` still references the model's parameters
# (plus its per-parameter state). Drop it too and return cached GPU memory. Any later
# training must build a new optimizer over the new model's parameters anyway.
del net, early_stopper, optimizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from torchvision import transforms
from torch.utils.data import Dataset, Subset, DataLoader
import torch
import torch.nn.functional as F
from torch import nn
from PIL import Image
import os

from re import T
from typing import Type, Callable, Tuple, Optional, Set, List, Union

from matplotlib.dates import relativedelta
from matplotlib.pyplot import grid
from numpy import intp
import torch
import torch.nn as nn

from timm.models.efficientnet_blocks import SqueezeExcite, DepthwiseSeparableConv
from timm.models.layers import drop_path, trunc_normal_, Mlp, DropPath


def _gelu_ignore_parameters(*args, **kwargs) -> nn.Module:
    activation = nn.GELU()
    return activation


class MBConv(nn.Module):
    """MBConv block: pre-norm convolutional main path plus a skip path.

    main_path: norm -> 1x1 conv -> DepthwiseSeparableConv (stride 2 if ``downscale``)
    -> SqueezeExcite(rd_ratio=0.25) -> 1x1 conv. skip_path: identity, or 2x2 max-pool
    + 1x1 conv when ``downscale``. With ``drop_path > 0`` the main-path output also goes
    through timm's ``drop_path`` before the residual addition.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        downscale: bool = False,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        drop_path: float = 0.0,
    ):
        super(MBConv, self).__init__()
        self.drop_path_rate: float = drop_path
        if not downscale:
            # The skip path is an identity here, so the channel counts must already match.
            assert in_channels == out_channels
        # nn.GELU is replaced by a factory that tolerates extra constructor arguments.
        activation = _gelu_ignore_parameters if act_layer == nn.GELU else act_layer
        stride = 2 if downscale else 1
        main_layers = [
            norm_layer(in_channels),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=(1, 1)),
            DepthwiseSeparableConv(
                in_chs=in_channels,
                out_chs=out_channels,
                stride=stride,
                act_layer=activation,  # type: ignore
                norm_layer=norm_layer,  # type:ignore
                drop_path_rate=drop_path,
            ),
            SqueezeExcite(in_chs=out_channels, rd_ratio=0.25),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(1, 1)),
        ]
        self.main_path = nn.Sequential(*main_layers)
        if downscale:
            self.skip_path = nn.Sequential(
                nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1)),
            )
        else:
            self.skip_path = nn.Identity()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.main_path(input)
        if self.drop_path_rate > 0.0:
            out = drop_path(out, self.drop_path_rate, self.training)
        return out + self.skip_path(input)


def window_partition(
    input: torch.Tensor, window_size: Tuple[int, int] = (7, 7)
) -> torch.Tensor:
    """Split a (B, C, H, W) map into non-overlapping windows, channels last.

    Returns a tensor of shape (B * H/wh * W/ww, wh, ww, C). Windows are ordered by
    batch, then window row, then window column. H and W must be divisible by the
    window size, otherwise ``view`` raises.
    """
    batch, channels, height, width = input.shape
    win_h, win_w = window_size[0], window_size[1]
    tiles = input.view(batch, channels, height // win_h, win_h, width // win_w, win_w)
    # (B, C, rows, wh, cols, ww) -> (B, rows, cols, wh, ww, C)
    tiles = tiles.permute(0, 2, 4, 3, 5, 1).contiguous()
    return tiles.view(-1, win_h, win_w, channels)


def window_reverse(
    windows: torch.Tensor,
    original_size: Tuple[int, int],
    window_size: Tuple[int, int] = (7, 7),
) -> torch.Tensor:
    """Inverse of ``window_partition``: reassemble windows into a (B, C, H, W) map.

    ``windows`` holds B * H/wh * W/ww windows of wh*ww tokens with channels last, either
    as (N, wh, ww, C) or flattened to (N, wh*ww, C). B is recovered from the count.
    """
    height, width = original_size
    win_h, win_w = window_size[0], window_size[1]
    windows_per_image = height * width / win_h / win_w
    batch = int(windows.shape[0] / windows_per_image)
    tiles = windows.view(batch, height // win_h, width // win_w, win_h, win_w, -1)
    # (B, rows, cols, wh, ww, C) -> (B, C, rows, wh, cols, ww)
    tiles = tiles.permute(0, 5, 1, 3, 2, 4).contiguous()
    return tiles.view(batch, -1, height, width)


def grid_partition(
    input: torch.Tensor, grid_size: Tuple[int, int] = (7, 7)
) -> torch.Tensor:
    """Split a (B, C, H, W) map into dilated grid groups, channels last.

    Each group holds gh x gw tokens spaced H/gh rows and W/gw columns apart. Returns
    (B * H/gh * W/gw, gh, gw, C). H and W must be divisible by the grid size.
    """
    batch, channels, height, width = input.shape
    grid_h, grid_w = grid_size[0], grid_size[1]
    step_h, step_w = height // grid_h, width // grid_w
    # Pixel row = i * step_h + j: i indexes inside a group, j selects the group.
    cells = input.view(batch, channels, grid_h, step_h, grid_w, step_w)
    # (B, C, gh, step_h, gw, step_w) -> (B, step_h, step_w, gh, gw, C)
    cells = cells.permute(0, 3, 5, 2, 4, 1).contiguous()
    return cells.view(-1, grid_h, grid_w, channels)


def grid_reverse(
    grid: torch.Tensor,
    original_size: Tuple[int, int],
    grid_size: Tuple[int, int] = (7, 7),
) -> torch.Tensor:
    """Inverse of ``grid_partition``: reassemble grid groups into a (B, C, H, W) map.

    ``grid`` is (N, gh, gw, C) or flattened (N, gh*gw, C) with channels last. B is
    recovered from the number of groups.
    """
    height, width = original_size
    channels = grid.shape[-1]
    grid_h, grid_w = grid_size[0], grid_size[1]
    groups_per_image = height * width / grid_h / grid_w
    batch = int(grid.shape[0] / groups_per_image)
    cells = grid.view(batch, height // grid_h, width // grid_w, grid_h, grid_w, channels)
    # (B, step_h, step_w, gh, gw, C) -> (B, C, gh, step_h, gw, step_w)
    cells = cells.permute(0, 5, 3, 1, 4, 2).contiguous()
    return cells.view(batch, channels, height, width)


def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor:
    """Lookup table from token pairs to relative-position-bias indices.

    Tokens of a win_h x win_w window are numbered row-major. Entry [i, j] of the
    returned (win_h*win_w, win_h*win_w) int64 tensor is
    (dr + win_h - 1) * (2*win_w - 1) + (dc + win_w - 1), where (dr, dc) is token i's
    row/column offset from token j. Values lie in [0, (2*win_h-1)*(2*win_w-1)).
    """
    rows = torch.arange(win_h).repeat_interleave(win_w)  # row of each token
    cols = torch.arange(win_w).repeat(win_h)             # column of each token
    d_row = rows[:, None] - rows[None, :] + (win_h - 1)
    d_col = cols[:, None] - cols[None, :] + (win_w - 1)
    return d_row * (2 * win_w - 1) + d_col


class RelativeSelfAttention(nn.Module):
    """Multi-head self-attention over a window/grid of tokens with a learned
    relative-position bias (one scalar per head and relative offset).

    Args:
        in_channels: token dimension C; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        grid_window_size: (h, w) of the token window; inputs must have N = h*w tokens.
        attn_drop: dropout probability on the attention weights.
        drop: dropout probability after the output projection.

    Raises:
        ValueError: if ``in_channels`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        in_channels: int,
        num_heads: int = 32,
        grid_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
    ):
        super(RelativeSelfAttention, self).__init__()
        if in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by num_heads ({num_heads})"
            )
        self.in_channels: int = in_channels
        self.num_heads: int = num_heads
        self.grid_window_size: Tuple[int, int] = grid_window_size
        # Scaled dot-product attention divides the logits by sqrt(head_dim).
        self.scale: float = (in_channels // num_heads) ** -0.5
        self.attn_area: int = grid_window_size[0] * grid_window_size[1]
        self.qkv_mapping = nn.Linear(
            in_features=in_channels, out_features=3 * in_channels, bias=True
        )
        self.attn_drop = nn.Dropout(p=attn_drop)
        self.proj = nn.Linear(
            in_features=in_channels, out_features=in_channels, bias=True
        )
        self.proj_drop = nn.Dropout(p=drop)
        self.softmax = nn.Softmax(dim=-1)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * grid_window_size[0] - 1) * (2 * grid_window_size[1] - 1), num_heads
            )
        )
        self.register_buffer(
            "relative_position_index",
            get_relative_position_index(grid_window_size[0], grid_window_size[1]),
        )
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def _get_relative_position_bias(self) -> torch.Tensor:
        """Gather the bias table into a (1, num_heads, N, N) tensor for broadcasting."""
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(self.attn_area, self.attn_area, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        return relative_position_bias.unsqueeze(0)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Attend over (B_, N, C) tokens and return a tensor of the same shape."""
        B_, N, C = input.shape
        qkv = (
            self.qkv_mapping(input)
            .reshape(B_, N, 3, self.num_heads, -1)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        # Without this scaling the logits grow with head_dim and saturate the softmax.
        q = q * self.scale
        attn = self.softmax(
            q @ k.transpose(-2, -1) + self._get_relative_position_bias()
        )
        # The attention-dropout module was built but never applied before.
        attn = self.attn_drop(attn)
        output = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
        output = self.proj(output)
        output = self.proj_drop(output)
        return output


class MaxViTTransformerBlock(nn.Module):
    """Pre-norm transformer block applied inside partitions of a feature map.

    forward: (B, C, H, W) -> ``partition_function`` into groups of h*w tokens ->
    x + drop_path(attention(norm_1(x))) -> x + drop_path(mlp(norm_2(x))) ->
    ``reverse_function`` back to (B, C, H, W). The partition/reverse pair selects
    window (block) or grid attention.
    """

    def __init__(
        self,
        in_channels: int,
        partition_function: Callable,
        reverse_function: Callable,
        num_heads: int = 32,
        grid_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        mlp_ratio: float = 4.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super(MaxViTTransformerBlock, self).__init__()
        self.partition_function: Callable = partition_function
        self.reverse_function: Callable = reverse_function
        self.grid_window_size: Tuple[int, int] = grid_window_size
        # Submodules are created in the same order as before so seeded init is unchanged.
        self.norm_1 = norm_layer(in_channels)
        self.attention = RelativeSelfAttention(
            in_channels=in_channels,
            num_heads=num_heads,
            grid_window_size=grid_window_size,
            attn_drop=attn_drop,
            drop=drop,
        )
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm_2 = norm_layer(in_channels)
        self.mlp = Mlp(
            in_features=in_channels,
            hidden_features=int(mlp_ratio * in_channels),
            act_layer=act_layer,  # type:ignore
            drop=drop,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        _, channels, height, width = input.shape
        tokens_per_group = self.grid_window_size[0] * self.grid_window_size[1]
        tokens = self.partition_function(input, self.grid_window_size)
        tokens = tokens.view(-1, tokens_per_group, channels)
        tokens = tokens + self.drop_path(self.attention(self.norm_1(tokens)))
        tokens = tokens + self.drop_path(self.mlp(self.norm_2(tokens)))
        return self.reverse_function(tokens, (height, width), self.grid_window_size)


class MaxViTBlock(nn.Module):
    """One MaxViT block: MBConv, then window ("block") attention, then grid attention.

    Args:
        in_channels: Channels entering the MBConv.
        out_channels: Channels produced by the MBConv and used by both attention stages.
        downscale: Forwarded to ``MBConv``.
        num_heads, grid_window_size, attn_drop, drop, mlp_ratio: Forwarded to both
            ``MaxViTTransformerBlock`` instances.
        drop_path: Stochastic-depth rate shared by the MBConv and both attention stages.
        act_layer: Activation class for the MBConv and the transformer MLPs.
        norm_layer: Normalization class for the MBConv.
        norm_layer_transformer: Normalization class for the two transformer blocks.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        downscale: bool = False,
        num_heads: int = 32,
        grid_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        mlp_ratio: float = 4.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        norm_layer_transformer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.mb_conv = MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            downscale=downscale,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop_path=drop_path,
        )
        # Both attention stages share every setting; only the partitioning differs.
        transformer_kwargs = dict(
            in_channels=out_channels,
            num_heads=num_heads,
            grid_window_size=grid_window_size,
            attn_drop=attn_drop,
            drop=drop,
            drop_path=drop_path,
            mlp_ratio=mlp_ratio,
            act_layer=act_layer,
            norm_layer=norm_layer_transformer,
        )
        # Local (window) attention first, then sparse global (grid) attention.
        self.block_transformer = MaxViTTransformerBlock(
            partition_function=window_partition,
            reverse_function=window_reverse,
            **transformer_kwargs,
        )
        self.grid_transformer = MaxViTTransformerBlock(
            partition_function=grid_partition,
            reverse_function=grid_reverse,
            **transformer_kwargs,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run MBConv -> window attention -> grid attention."""
        features = self.mb_conv(input)
        features = self.block_transformer(features)
        return self.grid_transformer(features)


class MaxViTStage(nn.Module):
    """A stage of ``depth`` MaxViT blocks.

    Only the first block downscales and maps ``in_channels`` -> ``out_channels``.
    All later blocks keep ``out_channels``.

    Args:
        depth: Number of ``MaxViTBlock`` instances in the stage.
        in_channels: Channels of the stage input (used by the first block only).
        out_channels: Channels produced by every block.
        num_heads, grid_window_size, attn_drop, drop, mlp_ratio, act_layer,
        norm_layer, norm_layer_transformer: Forwarded unchanged to every block.
        drop_path: Either a single stochastic-depth rate (``int`` or ``float``)
            shared by all blocks, or a sequence with one rate per block. The
            sequence is indexed by block position, so it needs at least ``depth``
            entries.
    """

    def __init__(
        self,
        depth: int,
        in_channels: int,
        out_channels: int,
        num_heads: int = 32,
        grid_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop: float = 0.0,
        drop_path: Union[List[float], float] = 0.0,
        mlp_ratio: float = 4.0,
        act_layer: Type[nn.Module] = nn.GELU,
        norm_layer: Type[nn.Module] = nn.BatchNorm2d,
        norm_layer_transformer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super(MaxViTStage, self).__init__()
        # Normalise drop_path to one rate per block. Scalars may be ints (e.g. 0);
        # an isinstance(..., float) check alone would try to index an int and crash.
        if isinstance(drop_path, (int, float)):
            drop_path_rates = [float(drop_path)] * depth
        else:
            drop_path_rates = list(drop_path)
        self.blocks = nn.Sequential(
            *[
                MaxViTBlock(
                    in_channels=in_channels if index == 0 else out_channels,
                    out_channels=out_channels,
                    downscale=index == 0,
                    num_heads=num_heads,
                    grid_window_size=grid_window_size,
                    attn_drop=attn_drop,
                    drop=drop,
                    drop_path=drop_path_rates[index],
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    norm_layer_transformer=norm_layer_transformer,
                )
                for index in range(depth)
            ]
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the blocks of the stage sequentially."""
        # `input` was previously declared as `input=torch.Tensor`, which made the
        # Tensor *class* a default value instead of annotating the parameter.
        output = self.blocks(input)
        return output


class MaxViT(nn.Module):
    """MaxViT image classifier: conv stem -> MaxViT stages -> global pooling -> head.

    Args:
        in_channels: Channels of the input image.
        depths: Number of blocks in each stage.
        channels: Output channels of each stage (one entry per element of ``depths``).
        num_classes: Output size of the linear head. ``<= 0`` uses ``nn.Identity``
            instead, so ``forward`` returns pooled features. This is the same rule
            as ``reset_classifier``.
        embed_dim: Channels produced by the convolutional stem.
        num_heads, grid_window_size, attn_drop, drop, mlp_ratio, act_layer,
        norm_layer, norm_layer_transformer: Forwarded unchanged to every stage.
        drop_path: Maximum stochastic-depth rate. Per-block rates rise linearly
            from 0 to this value across all blocks of the network.
        global_pool: ``"avg"`` or ``"max"`` spatial pooling before the head.
    """

    def __init__(
        self,
        in_channels: int = 3,
        depths: Tuple[int, ...] = (2, 2, 5, 2),
        channels: Tuple[int, ...] = (64, 128, 256, 512),
        num_classes: int = 1000,
        embed_dim: int = 64,
        num_heads: int = 32,
        grid_window_size: Tuple[int, int] = (7, 7),
        attn_drop: float = 0.0,
        drop=0.0,
        drop_path=0.0,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        norm_layer=nn.BatchNorm2d,
        norm_layer_transformer=nn.LayerNorm,
        global_pool: str = "avg",
    ) -> None:
        super(MaxViT, self).__init__()
        # Check parameters
        assert len(depths) == len(
            channels
        ), "For each stage a channel dimension must be given."
        assert global_pool in [
            "avg",
            "max",
        ], f"Only avg and max is supported but {global_pool} is given"
        # Save parameters
        self.num_classes: int = num_classes
        # Width of the pooled features fed to the head. reset_classifier needs this
        # value; it used to be missing, so that method raised AttributeError.
        self.num_features: int = channels[-1]
        # Init convolutional stem (halves the spatial resolution via stride 2)
        self.stem = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
            ),
            act_layer(),
            nn.Conv2d(
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
            ),
            act_layer(),
        )
        # One linearly increasing drop-path rate per block, sliced per stage below.
        drop_path_rates = torch.linspace(0.0, drop_path, sum(depths)).tolist()
        stages = []
        for index, (depth, channel) in enumerate(zip(depths, channels)):
            first_block = sum(depths[:index])
            stages.append(
                MaxViTStage(
                    depth=depth,
                    in_channels=embed_dim if index == 0 else channels[index - 1],
                    out_channels=channel,
                    num_heads=num_heads,
                    grid_window_size=grid_window_size,
                    attn_drop=attn_drop,
                    drop=drop,
                    drop_path=drop_path_rates[first_block : first_block + depth],
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    norm_layer_transformer=norm_layer_transformer,
                )
            )
        self.stages = nn.ModuleList(stages)
        self.global_pool: str = global_pool
        self.head = self._build_head(num_classes)

    def _build_head(self, num_classes: int) -> nn.Module:
        """Return a Linear head over ``num_features``, or Identity when ``num_classes <= 0``."""
        if num_classes > 0:
            return nn.Linear(self.num_features, num_classes)
        return nn.Identity()

    @torch.jit.ignore  # type: ignore
    def no_weight_decay(self) -> Set[str]:
        """Names of parameters that should be excluded from weight decay.

        Selects every parameter whose name contains ``relative_position_bias_table``.
        """
        return {
            name
            for name, _ in self.named_parameters()
            if "relative_position_bias_table" in name
        }

    def reset_classifier(
        self, num_classes: int, global_pool: Optional[str] = None
    ) -> None:
        """Replace the classification head and optionally change the pooling mode.

        Args:
            num_classes: New output size. ``<= 0`` installs ``nn.Identity``.
            global_pool: New pooling mode (``"avg"`` or ``"max"``). ``None`` keeps
                the current one.
        """
        if global_pool is not None:
            # Validate before mutating anything. An unknown mode would otherwise make
            # forward_head silently skip pooling.
            assert global_pool in [
                "avg",
                "max",
            ], f"Only avg and max is supported but {global_pool} is given"
            self.global_pool = global_pool
        self.num_classes = num_classes
        self.head = self._build_head(num_classes)

    def forward_features(self, input: torch.Tensor) -> torch.Tensor:
        """Run all stages on stem output; returns a ``(B, C, H, W)`` feature map."""
        output = input
        for stage in self.stages:
            output = stage(output)
        return output

    def forward_head(self, input: torch.Tensor, pre_logits: bool = False):
        """Pool over the spatial dims (2, 3), then apply the head unless ``pre_logits``."""
        if self.global_pool == "avg":
            input = input.mean(dim=(2, 3))
        elif self.global_pool == "max":
            input = torch.amax(input, dim=(2, 3))
        return input if pre_logits else self.head(input)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Full forward pass: stem -> stages -> pooling -> head."""
        output = self.forward_features(self.stem(input))
        output = self.forward_head(output)
        return output


def max_vit_base_224(**kwargs) -> MaxViT:
    """Construct the MaxViT-Base variant intended for 224 x 224 inputs.

    The architecture is fixed: four stages with 2/6/14/2 blocks, 96/192/384/768
    channels, and a 64-channel stem. Every other ``MaxViT`` argument may be
    passed through ``kwargs``. Passing ``depths``, ``channels`` or ``embed_dim``
    again raises ``TypeError`` (duplicate keyword argument).
    """
    stage_depths = (2, 6, 14, 2)
    stage_channels = (96, 192, 384, 768)
    stem_channels = 64
    model = MaxViT(
        depths=stage_depths,
        channels=stage_channels,
        embed_dim=stem_channels,
        **kwargs,
    )
    return model


In [None]:
# Train `net` with SGD for up to EPOCHS epochs. Weights are checkpointed whenever
# validation loss ties or beats the best so far, and training stops early once
# validation loss keeps getting worse.
# NOTE(review): `net`, `train`, `validate`, `train_dataloader` and `valid_dataloader`
# come from earlier cells; `train` presumably uses the global `optimizer` — confirm.
EPOCHS = 20
PATIENCE = 3       # consecutive "worse" epochs tolerated before stopping
MIN_DELTA = 0.01   # a validation loss above best + MIN_DELTA counts as "worse"

# Move the model to the device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
net.to(device)
optimizer = optim.SGD(net.parameters(), lr=0.015)

train_loss_history = []
valid_loss_history = []
train_acc_history = []
valid_acc_history = []

# Inline early stopping: the EarlyStopper object created here used to be
# constructed but never consulted, so every run went the full EPOCHS.
best_valid_loss = float("inf")
epochs_worse = 0

for epoch in range(EPOCHS):
    train_loss, train_acc = train(net, train_dataloader)
    valid_loss, valid_acc = validate(net, valid_dataloader)
    print(
        f"Epoch: {epoch:2d}, training loss: {train_loss:.3f}, training acc: {train_acc:.3f} validation loss: {valid_loss:.3f}, validation acc: {valid_acc:.3f}"
    )
    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)

    train_acc_history.append(train_acc)
    valid_acc_history.append(valid_acc)

    # Checkpoint the best (lowest validation loss) weights; ties also overwrite.
    if valid_loss <= min(valid_loss_history):
        torch.save(net.state_dict(), "maxvit-net.pt")

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        epochs_worse = 0
    elif valid_loss > best_valid_loss + MIN_DELTA:
        epochs_worse += 1
        if epochs_worse >= PATIENCE:
            print(f"Early stopping at epoch {epoch:2d}: validation loss worse than best for {PATIENCE} epochs")
            break

In [None]:
# Accuracy and loss curves for the MaxViT run trained above. The x-axis length
# follows the recorded history, so an early-stopped run plots correctly.
# (Titles previously said "resnet50(1)" / "resnet50 (train last layer only)",
# which mislabelled this MaxViT run, where all parameters are optimised.)
epochs = range(1, len(train_loss_history) + 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = [
    (axes[0], train_acc_history, valid_acc_history, "Accuracy", "upper left"),
    (axes[1], train_loss_history, valid_loss_history, "Loss", "upper right"),
]
for ax, train_curve, valid_curve, metric, legend_loc in panels:
    ax.plot(epochs, train_curve)
    ax.plot(epochs, valid_curve)
    ax.set_title(f"MaxViT Training and validation {metric.lower()}",
                 fontsize=12, fontweight='bold')
    ax.set_ylabel(metric)
    ax.set_xlabel('epoch')
    ax.legend(['Train', 'Validation'], loc=legend_loc)
plt.show()