#### Import Library

In [1]:
import torch
import torch.nn.functional as F
from torch import nn,optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import torchvision.models as models
from torch.optim import SGD
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image

# Hyperparameters
torch.manual_seed(1)
batch_size= 60
learning_rate=0.002
num_epoches=15
num_classes=7

#### Read data from images and excel-labels

In [2]:
# read the images and names
IMAGE_DIR = 'C_img_jpg'
def read_images(image_path=IMAGE_DIR):
    images = []
    images_names = [image for image in os.listdir(image_path) if not image.startswith('.')] 
    for image_name in images_names: 
            img = Image.open (os.path.join(image_path, image_name))
            if Random:
                img_transforms = transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.RandomHorizontalFlip(0.5),
                                       transforms.ToTensor()
                                       ])
                tensor = img_transforms(img)
                images.append(tensor)
            else:
                images.append(img)
    return images,images_names
images,names = read_images()

# read the label
excel=pd.read_csv('C_all_data_add_20211007_xz.csv', encoding = 'gb2312')
excel = np.array(excel)
excel = excel.tolist()
# Alligning the labels with images. (i.e. the first label corresponds to the first label.)
labels = [] 
sub_labels = [] 
for i in range(len(names)):
    for j in range(len(excel)):
        if names[i] == excel[j][0]:
            labels.append(excel[j][1])
            sub_labels.append(excel[j][2:])

#### Split the dataset

In [3]:
Random = 1
if Random:
    from sklearn.model_selection import train_test_split
    
    final_input=[]
    for i in range(len(names)):
        final_input.append((torch.Tensor(sub_labels[i]),images[i]))
    
    X_train_3,X_test_3,y_train_3,y_test_3=train_test_split(final_input,labels,test_size=0.2,random_state=42)
    train_data_3=[]
    test_data_3=[]
    for i in range(len(X_train_3)):
        train_data_3.append((X_train_3[i],y_train_3[i]))
    for i in range(len(X_test_3)):
        test_data_3.append((X_test_3[i],y_test_3[i]))

else:
    # Stratified Random Sampling
    from sklearn.model_selection import StratifiedShuffleSplit
    split = StratifiedShuffleSplit(n_splits = 1,test_size = 0.2,random_state = 42)

    names = np.array(names)
    labels = np.array(labels)
    names = names.reshape((names.shape[0],-1))
    labels = labels.reshape((labels.shape[0],-1))
    data = np.hstack((names,labels)) # hstack:each name corresponds to one label 
    for train_index,test_index in split.split(data,data[:,-1]):
        train_set = data[train_index,:]
        test_set = data[test_index,:]

#### Data Preprocessing

In [None]:
# data transformation
train_transforms = transforms.Compose([transforms.RandomHorizontalFlip(0.5),
                                        transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor()
                                        ])

train_data_1 = []
train_data_3 = []
for row in range(len(train_set)):
    for i in range(len(names)):
        if names[i] == train_set[row][0]: 
            img_data = train_transforms(images[i]) 
            # train_set is [names, labels], we need to get the train_data_1 [images, labels], train_data_3 [images, sub_labels, labels]
            train_data_1.append((img_data,int(train_set[row][1]))) 
            train_data_3.append((img_data,torch.Tensor(sub_labels[i][:]),int(train_set[row][1])))

test_transforms =  transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       
                                       ])
test_data_1 = []
test_data_3 = []
for row in range(len(test_set)):
    for i in range(len(names)):
        if names[i] == test_set[row][0]:
            img_data = test_transforms(images[i])
            test_data_1.append((img_data,int(test_set[row][1]))) 
            test_data_3.append((img_data,torch.Tensor(sub_labels[i][:]),int(train_set[row][1])))

In [4]:
# data loader
test_size = 0.2
# sdudent_train_loader = DataLoader(student_train , batch_size = batch_size ,shuffle=True)
# student_test_loader = DataLoader(student_test, batch_size = batch_size, shuffle=False) #int(batch_size*test_size/(1-test_size))
teacher_train_loader = DataLoader(train_data_3 , batch_size=batch_size ,shuffle=True)
teacher_test_loader = DataLoader(test_data_3, batch_size=batch_size,shuffle=False )

#### Load the Network

In [5]:
"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, embed_dim=128, norm_layer=None):
        super().__init__()
        '''img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)'''
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        '''B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C'''
        #x = x.transpose(1, 2)  #[B,2,128]
        x = self.norm(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim,   # 输入token的dim
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.1,
                 proj_drop_ratio=0.1):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim] [B, 3, 128]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim][B, 3, 3*128]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads /8, num_patches + 1 /3, embed_dim_per_head /16]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head] [B,8,3,16]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1] [B,8,3,3]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, num_classes=7, embed_dim=128, depth=12, num_heads=8, mlp_ratio=4.0,num_patches=8, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.3,
                 attn_drop_ratio=0.3, drop_path_ratio=0.3, embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(embed_dim=embed_dim)
        

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 2, 128]
        # [1, 1, 128]-> [B, 1, 128]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 3, 128]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)
        

In [35]:
def my_forward(model, x):
    mo = nn.Sequential(*list(model.children())[:-1])
    feature = mo(x)
    feature = feature.view(x.size(0), -1)
    output= model.fc(feature)
    return feature, output

if Random:
    class Net(nn.Module):
        def __init__(self, input_size, hidden_size, num_classes=7):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.fc = nn.Linear(hidden_size, num_classes)

        def forward(self, x):
            out= self.fc1(x)
            out= self.relu(out)
            self.feature=out
            out= self.fc(out)
            return out

else:
    class Net(nn.Module):
        def __init__(self, input_size, hidden_size, num_classes=7):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(p=0.5)
            self.fc2 = nn.Linear(hidden_size, hidden_size)
            self. fc = nn.Linear(hidden_size, 7)


        def forward(self, x):
            out= self.fc1(x)
            out= self.relu(out)
            out= self.dropout(out)
            out= self.fc2(out)
            out= self.relu(out)
            out= self.fc(out)
            return out

device=torch.device('cuda')
torch.cuda.set_device(7)
model_2=Net(6, 512, 7).to(device)

if Random:
    model_1 = torch.load('model_1_Random.pth').to(device)
    model_2.load_state_dict(torch.load('model_2_param_Random.pkl'))
else:
    model_1 = torch.load('model_1_分层.pth').to(device)
    model_2.load_state_dict(torch.load('model_2_param.pkl'))

per_patch = 1
# Teacher model
fmodel = torch.load('fmodel_with 1 token.pth').to(device)

# Student model
student = models.resnet18(pretrained = True)   
student.fc = nn.Sequential(
    nn.Linear(student.fc.in_features, 7)
)
student = student.to(device)

#### Train the Network

In [40]:
learning_rate = 0.1
num_epoches = 10
# distillation temperature
temp = 6
# hard_loss
hard_loss = nn.CrossEntropyLoss()
# hard_loss weight
alpha = 0.1
# soft_loss
soft_loss = nn.KLDivLoss(reduction='batchmean') # KL divergence loss
#optimizer = optim.Adam(student.parameters(),lr = learning_rate)
optimizer = optim.SGD(student.parameters(),lr=learning_rate)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)     # learning rate decay
per_patch = 1

model_1.eval()
model_2.eval()
fmodel.eval()
student.train()

train_loss_all = []   # store the loss of training set
train_accur_all = []  # store the accuracy of training set
test_loss_all = []    # store the loss of test set
test_accur_all = []   # store the accuracy of test set

# train the student model
for epoch in range(num_epoches):
    running_loss=0.0
    running_acc=0.0
    student.train()
    for j,data in enumerate(teacher_train_loader,1):
        img_plus, label = data
        img = img_plus[1]
        sub_label = img_plus[0]
        img = Variable(img).to(device)
        sub_label = Variable(sub_label).to(device)
        label = Variable(label).to(device) 
        
        # calculate the teacher model prediction
        with torch.no_grad():
            model_1, model_2, fmodel = model_1.to(device), model_2.to(device), fmodel.to(device)
            feature1, out1 = my_forward(model_1,img)
            feature2, out2 = my_forward(model_2,sub_label)
            '''for i in range(per_patch):
                exec("img_token_%i = feature1[:,128*i:128*(i+1)].unsqueeze(1)"%i)
                exec("label_token_%i = feature2[:,128*i:128*(i+1)].unsqueeze(1)"%i)
            input3 = torch.cat((img_token_0,img_token_1,img_token_2,img_token_3),1)
            input3 = torch.cat((input3,label_token_0,label_token_1,label_token_2,label_token_3),1)'''
            feature1 = feature1.unsqueeze(1)
            feature2 = feature2.unsqueeze(1)
            input3 = torch.cat((feature1,feature2),1)
            teacher_preds = fmodel(input3)

        # calculate the student model prediction
        student_preds = student(img)
        
        # calculate the hard_loss
        student_loss = hard_loss(student_preds, label)
        # calculate the soft_loss
        distillation_loss = soft_loss(
                            F.log_softmax(student_preds/temp, dim=1),
                            F.softmax(teacher_preds/temp, dim=1)
        )
        
        # calculate the total loss by weighted sum of hard_loss and soft_loss
        loss = alpha * student_loss + (1 - alpha) * distillation_loss
        
        running_loss += loss.data * label.size(0)
        _, pred = torch.max(student_preds,1)
        num_correct = (pred==label).sum()
        running_acc += num_correct.data
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #scheduler.step() # update learning rate
    
    print('Train {} epoch, Loss: {:.6f},Acc: {:.6f}'.format(epoch+1,running_loss/(len(train_data_3)),running_acc/len(train_data_3)))
    train_loss_all.append(running_loss / len(train_data_3))   # store the loss of training set, and then plot it
    train_accur_all.append(running_acc/len(train_data_3))     # store the accuracy of training set, and then plot it
    
    # evaluate the student model
    student.eval()
    with torch.no_grad():
        eval_loss = 0
        eval_acc = 0
        train_loss_results = []
        
        for data in teacher_test_loader:
            img_plus, label = data
            img = img_plus[1]
            sub_label = img_plus[0]
            img = Variable(img).to(device)
            sub_label = Variable(sub_label).to(device)
            label = Variable(label).to(device) 
            
            student_preds = student(img)
            loss = hard_loss(student_preds, label)
            eval_loss += loss.data
            _,pred = torch.max(student_preds,1)
            num_correct = (pred == label).sum()
            eval_acc += num_correct.data

    print('Test Loss: {:,.6f}, Acc: {:,.6f}'.format(eval_loss/(len(test_data_3)), eval_acc/(len(test_data_3))))
    test_loss_all.append(eval_loss/(len(test_data_3)))
    test_accur_all.append(eval_acc/(len(test_data_3)))

Train 1 epoch, Loss: 0.020741,Acc: 0.985465
Test Loss: 0.013985, Acc: 0.776744
Train 2 epoch, Loss: 0.017469,Acc: 0.991279
Test Loss: 0.013172, Acc: 0.765116
Train 3 epoch, Loss: 0.016094,Acc: 0.993605
Test Loss: 0.012403, Acc: 0.776744
Train 4 epoch, Loss: 0.016114,Acc: 0.993023
Test Loss: 0.013892, Acc: 0.755814
Train 5 epoch, Loss: 0.015421,Acc: 0.995349
Test Loss: 0.012216, Acc: 0.790698
Train 6 epoch, Loss: 0.016023,Acc: 0.991860
Test Loss: 0.013071, Acc: 0.788372
Train 7 epoch, Loss: 0.013930,Acc: 0.995930
Test Loss: 0.012938, Acc: 0.783721
Train 8 epoch, Loss: 0.014895,Acc: 0.991860
Test Loss: 0.012438, Acc: 0.795349
Train 9 epoch, Loss: 0.013312,Acc: 0.995349
Test Loss: 0.014450, Acc: 0.751163
Train 10 epoch, Loss: 0.015565,Acc: 0.993023
Test Loss: 0.013276, Acc: 0.755814


#### Save the model


In [17]:
torch.save(student,'student_new.pth')