In [None]:
import time
import pandas as pd
import torch
import torchvision
import copy
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch import nn
from PIL import Image
from d2l import torch as d2l

In [None]:
def mytransform():
    """Build the augmentation pipeline applied to each upsampled PIL face image.

    Steps, in order:
      1. random horizontal flip;
      2. random resized crop to 224x224, keeping 80-100% of the image area with an
         aspect ratio between 0.8 and 1.25;
      3. random brightness and contrast jitter with strength 0.7 each.
    """
    return transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((224, 224), scale=(0.8, 1), ratio=(0.8, 1.25)),
        transforms.ColorJitter(brightness=0.7, contrast=0.7),
    ])

In [None]:
class FacialExpressionDataset(Dataset):
    """Facial-expression dataset held entirely in memory as tensors.

    Built from a CSV with ``emotion``, ``pixels`` and ``Usage`` columns. Each row's
    space-separated ``pixels`` string is decoded into a 48x48 grayscale image. The image
    is upsampled, optionally augmented, and stored as a 1x224x224 float tensor next to
    its ``emotion`` label. The processed tensors are cached with ``torch.save``, so
    later runs can pass ``readtensor=True`` and skip the slow preprocessing.

    Args:
        csv_path: path of the CSV. Rows with ``Usage == 'Training'`` form the train
            split and rows with ``Usage == 'Test'`` form the test split. Unused when
            ``readtensor=True``.
        train: build/load the train split (True) or the test split (False).
        readtensor: if True, load the cached tensors instead of processing the CSV.
        transform: optional callable applied to each 288x288 PIL image. It must return
            a 224x224 image (as ``mytransform`` does). When given, every row is
            expanded into several independently augmented copies to rebalance the
            classes: 20 for label 1 (Disgust), 2 otherwise. Unused when
            ``readtensor=True``.
        cache_dir: directory holding the ``*_aug_balance.pt`` cache files.

    Attributes:
        Xlst: image tensor; shape (N, 1, 224, 224) when built by this class.
        ylst: int64 label tensor of shape (N,) when built by this class.
        L: number of samples N.
    """

    def __init__(self, csv_path, train=True, readtensor=False, transform=None, cache_dir='../data'):
        x_path, y_path = self._cache_paths(train, cache_dir)
        if readtensor:
            # Skip preprocessing and reuse the tensors saved by an earlier run.
            self.Xlst = torch.load(x_path)
            self.ylst = torch.load(y_path)
            self.L = len(self.ylst)
            return

        usage = 'Training' if train else 'Test'
        data = pd.read_csv(csv_path)
        data = data[data.Usage == usage].reset_index(drop=True)
        print('Total%d' % len(data))
        if len(data) == 0:
            raise ValueError('no rows with Usage == %r in %s' % (usage, csv_path))

        st = time.time()
        to_tensor = transforms.ToTensor()  # loop-invariant, build once
        xs, ys = [], []
        for i, (emotion, pixels) in enumerate(zip(data.emotion, data.pixels)):
            if i % 500 == 0:
                print('complete%d, %d s' % (i, int(time.time() - st)))
            img = self._pixels_to_image(pixels)
            label = int(emotion)
            y = torch.tensor([label], dtype=torch.int64)
            if transform:
                # Oversample label 1 (Disgust) 10x relative to the other classes (20 vs 2 copies).
                turns = 20 if label == 1 else 2
                # Upsample to 288x288 so the augmentation can crop back down to 224x224.
                img = img.resize((288, 288), resample=Image.HAMMING)
            else:
                turns = 1
                img = img.resize((224, 224), resample=Image.HAMMING)
            for _ in range(turns):
                # Each copy draws a fresh random augmentation; without a transform the
                # already-224x224 image is used as-is.
                x_img = transform(img) if transform else img
                # ToTensor maps pixels to [0, 1]. The x256 rescale is kept so rebuilt caches
                # match existing ones. NOTE(review): 255, or ImageNet mean/std
                # normalisation, is probably what a pretrained backbone expects.
                xs.append(256 * to_tensor(x_img).reshape(1, 1, 224, 224))
                ys.append(y)

        count = len(xs)
        # Concatenate once at the end: growing a tensor with torch.cat per sample
        # re-copies everything accumulated so far.
        self.Xlst = torch.cat(xs, 0)
        self.ylst = torch.cat(ys, 0)
        del xs, ys  # free the per-sample tensors before writing the cache

        torch.save(self.Xlst, x_path)
        torch.save(self.ylst, y_path)
        self.L = len(self.ylst)
        print('correct!' if self.L == count else 'error!')

    @staticmethod
    def _cache_paths(train, cache_dir):
        """Return the (images, labels) cache file paths for the selected split."""
        split = 'train' if train else 'test'
        return ('%s/%s_data_Xlst_aug_balance.pt' % (cache_dir, split),
                '%s/%s_data_ylst_aug_balance.pt' % (cache_dir, split))

    @staticmethod
    def _pixels_to_image(pixels):
        """Decode a space-separated string of 48*48 grey levels into a 48x48 PIL image."""
        flat = torch.tensor([int(v) for v in pixels.split()], dtype=torch.uint8)
        return transforms.ToPILImage()(flat.view(48, 48))

    def __len__(self):
        return self.L

    def __getitem__(self, idx):
        """Return ``(X, y)`` where X repeats the stored image 3 times along dim 0.

        For images built by this class this turns (1, 224, 224) into (3, 224, 224),
        so the single grey channel can feed a 3-channel (RGB) model.
        """
        x = self.Xlst[idx]
        return torch.cat((x, x, x), 0), self.ylst[idx]

In [None]:
# With readtensor=True both splits come from the tensors cached under ../data, so the CSV
# is not parsed and `transform` is not applied here (it only matters when rebuilding).
# NOTE(review): the test split is also given the augmentation pipeline. If its cache was
# built that way, evaluation runs on augmented, class-rebalanced copies rather than on
# the original test images; consider rebuilding it with transform=None.
train_data = FacialExpressionDataset(csv_path='../data/data.csv', train=True,
                                     readtensor=True, transform=mytransform())
test_data = FacialExpressionDataset(csv_path='../data/data.csv', train=False,
                                    readtensor=True, transform=mytransform())

In [None]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs, device):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Each epoch runs a 'train' phase over ``dataloaders['train']``: forward, backward,
    optimizer step. It then runs a 'val' phase over ``dataloaders['val']``, forward only.
    The mean loss and accuracy of each phase are printed.

    Args:
        model: the network; it is moved to ``device``.
        dataloaders: dict with 'train' and 'val' iterables of ``(inputs, labels)``
            batches. Each must expose a non-empty ``.dataset`` (e.g. a ``DataLoader``),
            which is used to average over samples.
        criterion: loss called as ``criterion(outputs, labels)``. Presumably it returns
            a per-batch mean, since it is re-weighted by batch size below; TODO confirm
            for non-default reductions.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of epochs to run.
        device: device (``torch.device`` or string) to train on.

    Returns:
        ``(model, val_acc_history)``: the model reloaded with the best validation
        weights, and the per-epoch validation accuracies as 0-dim double tensors.

    Raises:
        ValueError: if a phase's dataset is empty.
    """
    model = model.to(device)
    since = time.time()
    val_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'
            n_samples = len(dataloaders[phase].dataset)
            if n_samples == 0:
                raise ValueError("dataloaders['{}'] has an empty dataset".format(phase))
            running_loss = 0.0
            # Tensor accumulator so `.double()` below is valid even if no batch is yielded.
            running_corrects = torch.zeros((), dtype=torch.int64, device=device)
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                if is_train:
                    optimizer.zero_grad()
                # Build the autograd graph only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                # Re-weight the batch loss by batch size so the epoch loss is a per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum()

            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects.double() / n_samples
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if not is_train:
                val_acc_history.append(epoch_acc)
                # Deep copy: state_dict() tensors alias the live parameters.
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # `.4f` = 4 decimal places (the previous `4f` only set a minimum width of 4).
    print('Best val Acc: {:.4f}'.format(best_acc))
    # Restore the best validation weights before returning.
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [None]:
# The output layer trains with a 10x learning rate.
def train_fine_tuning_vgg(net, train_iter, test_iter, learning_rate=5e-5, num_epochs=10, device=torch.device('cpu')):
    """Fine-tune a VGG-style ``net`` with cross-entropy loss via ``train_model``.

    Parameters of the final classifier layer ``net.classifier[6]`` use
    ``learning_rate * 10``; every other parameter uses ``learning_rate``. The optimizer
    is SGD with ``weight_decay=0.001``. ``train_iter``/``test_iter`` serve as the
    'train'/'val' loaders.

    Returns:
        ``(finetuned_net, val_acc_history)`` as produced by ``train_model``.
    """
    head = net.classifier[6]
    # Split parameters by identity: everything that is not in the head is the backbone.
    head_param_ids = {id(p) for p in head.parameters()}
    backbone_params = [p for p in net.parameters() if id(p) not in head_param_ids]
    param_groups = [
        {'params': backbone_params},
        {'params': head.parameters(), 'lr': learning_rate * 10},
    ]
    optimizer = torch.optim.SGD(param_groups, lr=learning_rate, weight_decay=0.001)
    return train_model(net, {'train': train_iter, 'val': test_iter},
                       nn.CrossEntropyLoss(), optimizer, num_epochs, device)

In [None]:
# Start from torchvision's pretrained VGG11-BN and replace its last classifier layer with
# a fresh 7-output linear head (one logit per emotion class), Xavier-initialising its weights.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of `weights=`.
pretrained_net = torchvision.models.vgg11_bn(pretrained=True)
pretrained_net.classifier[6] = nn.Linear(pretrained_net.classifier[6].in_features, 7)
nn.init.xavier_uniform_(pretrained_net.classifier[6].weight);  # `;` hides the 7x4096 tensor repr

In [None]:
# for name, param in pretrained_net.named_parameters():
#     print('\n',name)

In [None]:
learning_rate, num_epochs, batch_size = 0.001, 25, 64
# Prefer the second GPU this experiment was run on, but fall back so the notebook still
# runs on machines with a single GPU (a hard-coded 'cuda:1' raises "invalid device
# ordinal" there) or with no GPU at all.
if torch.cuda.device_count() > 1:
    train_device = torch.device('cuda:1')
elif torch.cuda.is_available():
    train_device = torch.device('cuda:0')
else:
    train_device = torch.device('cpu')
train_iter = DataLoader(train_data, batch_size, shuffle=True, num_workers=4)
test_iter = DataLoader(test_data, batch_size, shuffle=False, num_workers=4)
finetune_vgg11net, hist = train_fine_tuning_vgg(pretrained_net, train_iter, test_iter, learning_rate, num_epochs, train_device)

In [None]:
# Pickle the whole fine-tuned module (architecture + weights).
# NOTE(review): reloading needs the torchvision model classes importable and, on recent
# PyTorch releases, torch.load(..., weights_only=False). Saving
# finetune_vgg11net.state_dict() would be the more portable choice.
torch.save(obj=finetune_vgg11net, f='./finetune_vgg11net.pkl')