In [None]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
from torch.cuda.amp import GradScaler, autocast
import torchvision.utils as vutils
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import random
from torch.autograd import Variable


In [None]:
class FerDataset(Dataset):
    """Map-style dataset over in-memory FER2013 images.

    Args:
        feats: indexable collection of pixel arrays (e.g. the ``[num, 48, 48]``
            array from ``prepare_data``); each item is wrapped in a PIL image
            before ``transform`` is applied.
        labels: indexable collection of integer class ids aligned with ``feats``.
        transform: callable applied to the PIL image; its result is returned
            as the feature.

    Raises:
        ValueError: if ``feats`` and ``labels`` have different lengths.
    """

    def __init__(self, feats, labels, transform):
        if len(feats) != len(labels):
            raise ValueError(
                f"feats ({len(feats)}) and labels ({len(labels)}) must have the same length")
        self.feats = feats
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.feats)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` where ``label`` is a 0-dim long tensor."""
        # Samplers may hand over a 0-dim tensor index; convert it to a Python int.
        if torch.is_tensor(index):
            index = index.tolist()
        feat = self.transform(Image.fromarray(np.array(self.feats[index])))
        label = torch.as_tensor(self.labels[index], dtype=torch.long)
        return (feat, label)

def prepare_data(data):
    """Decode FER2013 rows into image and label arrays.

    Args:
        data: DataFrame with an ``emotion`` column (class id) and a ``pixels``
            column holding 48*48 space-separated integer intensities per row.

    Returns:
        (images, labels): ``images`` is a float64 array of shape
        ``(len(data), 48, 48)``; ``labels`` is an array of the emotion ids
        converted with ``int``.
    """
    labels = np.array([int(emotion) for emotion in data['emotion']])
    images = np.zeros(shape=(len(data), 48, 48))

    for pos, key in enumerate(data.index):
        flat = np.fromstring(data.loc[key, 'pixels'], dtype=int, sep=' ')
        images[pos] = flat.reshape(48, 48)

    return images, labels


def get_dataloaders(path, batch_size, augment=False, eval_batch_size=64, num_workers=2):
    """Build train / validation / test DataLoaders from the FER2013 CSV.

    The CSV ``Usage`` column selects the split: ``Training`` -> train,
    ``PrivateTest`` -> validation, ``PublicTest`` -> test.

    Every sample is a stack of 40x40 single-channel crops: TenCrop for
    evaluation, FiveCrop when ``augment`` is on. Batches therefore have shape
    ``(batch, ncrops, 1, 40, 40)``.

    Args:
        path: path to ``fer2013.csv``.
        batch_size: training batch size.
        augment: if True, the training set uses the random augmentation
            pipeline. If False (default), training uses the deterministic
            evaluation transform, which is what this function previously did
            unconditionally.
        eval_batch_size: batch size of the validation and test loaders.
        num_workers: worker processes per DataLoader.

    Returns:
        (trainloader, valloader, testloader)
    """
    fer2013 = pd.read_csv(path)

    # images: [num, 48, 48], labels: [num]
    xtrain, ytrain = prepare_data(fer2013[fer2013['Usage'] == 'Training'])
    xval, yval = prepare_data(fer2013[fer2013['Usage'] == 'PrivateTest'])
    xtest, ytest = prepare_data(fer2013[fer2013['Usage'] == 'PublicTest'])

    # NOTE(review): ToTensor() already scales pixels to [0, 1]. Normalize(mean=0, std=255)
    # then divides by 255 a second time, so inputs end up in [0, 1/255]. This is kept
    # because changing it alters training; confirm whether the double scaling is intended.
    mu, st = 0, 255

    test_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.TenCrop(40),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda tensors: torch.stack(
            [transforms.Normalize(mean=(mu,), std=(st,))(t) for t in tensors])),
    ])

    if augment:
        train_transform = transforms.Compose([
            # convert to grayscale
            transforms.Grayscale(),
            # random crop, then resize back to 48x48
            transforms.RandomResizedCrop(48, scale=(0.8, 1.2)),
            # jitter brightness, contrast and saturation
            transforms.RandomApply([transforms.ColorJitter(
                brightness=0.5, contrast=0.5, saturation=0.5)], p=0.5),
            # affine transform (random translation)
            transforms.RandomApply(
                [transforms.RandomAffine(0, translate=(0.2, 0.2))], p=0.5),
            # horizontal flip
            transforms.RandomHorizontalFlip(),
            # rotation
            transforms.RandomApply([transforms.RandomRotation(10)], p=0.5),
            # four corner crops + center crop
            transforms.FiveCrop(40),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            # normalize
            transforms.Lambda(lambda tensors: torch.stack(
                [transforms.Normalize(mean=(mu,), std=(st,))(t) for t in tensors])),
            # random erasing, applied to each crop independently
            transforms.Lambda(lambda tensors: torch.stack(
                [transforms.RandomErasing()(t) for t in tensors])),
        ])
    else:
        # previous behaviour: train on the deterministic evaluation transform
        train_transform = test_transform

    train = FerDataset(xtrain, ytrain, train_transform)
    val = FerDataset(xval, yval, test_transform)
    test = FerDataset(xtest, ytest, test_transform)

    # NOTE(review): the transforms contain lambdas, which cannot be pickled. With
    # num_workers > 0, iterating the loaders likely fails where workers are spawned
    # (e.g. Windows/macOS). Confirm, or pass num_workers=0 there.
    trainloader = DataLoader(train, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers)
    # validation/test samples do not need a random order
    valloader = DataLoader(val, batch_size=eval_batch_size, shuffle=False,
                           num_workers=num_workers)
    testloader = DataLoader(test, batch_size=eval_batch_size, shuffle=False,
                            num_workers=num_workers)

    return trainloader, valloader, testloader

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Main path: 3x3 conv (``stride``) -> BN -> ReLU -> 3x3 conv -> BN. The result
    is added to the shortcut and passed through a final ReLU.

    Args:
        in_planes: number of input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution. When it is not 1, or the
            channel count changes, the shortcut becomes a strided 1x1 conv + BN
            projection so its output matches the main path. Otherwise the
            shortcut is the identity.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # an empty Sequential returns its input unchanged
            self.shortcut = nn.Sequential()

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y + self.shortcut(x)
        return F.relu(y)

class ResNet(nn.Module):
    """ResNet backbone for small images.

    A 3x3 stride-1 stem convolution (no max-pool) feeds four residual stages
    with 64/128/256/512 base channels. Stages 2-4 use stride 2. The remaining
    feature map is globally average-pooled and classified by a linear layer.

    Args:
        block: residual block class constructed as ``block(in_planes, planes, stride)``
            and exposing an ``expansion`` class attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.
        in_channels: channels of the input images (1 = grayscale).
    """

    def __init__(self, block, num_blocks, num_classes=7, in_channels=1):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            # the next block consumes this block's output channels
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling over whatever spatial size remains. A fixed
        # avg_pool2d(out, 4) assumes a 4x4 map (32x32 input). On 40x40 crops
        # (5x5 map) it ignores the last row/column, and inputs of 64x64 or
        # larger produce too many features for self.linear.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

def get_net():
    """Build the ResNet-18-style model: BasicBlock units, two per stage, 7 classes."""
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage)

In [None]:
def cross_entropy(outputs, smooth_labels):
    """Loss between softmax(outputs) and soft target distributions.

    Despite the name, this is KL(smooth_labels || softmax(outputs)). It is summed
    over classes and divided by the batch size (``reduction='batchmean'``). It
    differs from soft-label cross-entropy only by the targets' entropy, which
    does not depend on ``outputs``, so the gradients are the same.

    Args:
        outputs: raw logits with classes along dim 1.
        smooth_labels: target probabilities, same shape as ``outputs``.
    """
    log_probs = F.log_softmax(outputs, dim=1)
    return F.kl_div(log_probs, smooth_labels, reduction='batchmean')


def smooth_one_hot(true_labels: torch.Tensor, classes: int, smoothing=0.0):
    """Turn integer class labels into (optionally) label-smoothed target distributions.

    Each row puts ``1 - smoothing`` on the true class and
    ``smoothing / (classes - 1)`` on every other class. With ``smoothing == 0``
    this is plain one-hot encoding.

    Args:
        true_labels: 1-D tensor of class indices in ``[0, classes)`` (any integer dtype).
        classes: number of classes.
        smoothing: probability mass spread over the wrong classes,
            ``0 <= smoothing < 1``.

    Returns:
        Float tensor of shape ``(len(true_labels), classes)`` on the same
        device as ``true_labels``.

    Raises:
        ValueError: if a label lies outside ``[0, classes)``.
    """
    assert 0 <= smoothing < 1
    with torch.no_grad():
        # Build the distribution directly on the labels' device. This avoids a
        # one_hot -> CPU -> argmax round-trip that just recovers the labels.
        labels = true_labels.detach().long()
        if labels.numel() and (labels.min() < 0 or labels.max() >= classes):
            raise ValueError(f"labels must lie in [0, {classes})")
        # with a single class there is no "other" class to spread mass over
        off_value = smoothing / (classes - 1) if classes > 1 else 0.0
        true_dist = torch.full((labels.size(0), classes), off_value,
                               device=labels.device)
        true_dist.scatter_(1, labels.unsqueeze(1), 1.0 - smoothing)
    return true_dist


class LabelSmoothingLoss(torch.nn.Module):
    """Cross-entropy with label smoothing applied on the loss side.

    ``loss = smoothing * mean_over_classes(-log p) + (1 - smoothing) * NLL(target)``

    Args:
        smoothing: weight of the uniform term. ``forward`` asserts
            ``0 <= smoothing < 1``.
        reduction: ``'mean'``, ``'sum'``, or any other value for no reduction.
        weight: optional per-class weights. They are passed to ``nll_loss``
            only; the uniform term is unweighted. They are moved to the
            predictions' device on each call.
    """

    def __init__(self, smoothing: float = 0.1,
                 reduction="mean", weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.reduction = reduction
        self.weight = weight

    def reduce_loss(self, loss):
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

    def linear_combination(self, x, y):
        s = self.smoothing
        return s * x + (1 - s) * y

    def forward(self, preds, target):
        assert 0 <= self.smoothing < 1

        if self.weight is not None:
            self.weight = self.weight.to(preds.device)

        num_classes = preds.size(-1)
        log_probs = F.log_softmax(preds, dim=-1)
        uniform_term = self.reduce_loss(-log_probs.sum(dim=-1))
        nll = F.nll_loss(
            log_probs, target, reduction=self.reduction, weight=self.weight
        )
        return self.linear_combination(uniform_term / num_classes, nll)


def mixup_data(x, y, alpha=0.2):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Each sample is blended with a randomly permuted partner:
    ``mixed_x = lam * x + (1 - lam) * x[perm]``. ``lam`` is drawn from
    Beta(alpha, alpha) via NumPy's global RNG, or is 1.0 when ``alpha <= 0``.

    Args:
        x: batch tensor, samples along dim 0 (any device).
        y: targets aligned with ``x`` along dim 0.
        alpha: Beta distribution parameter.

    Returns:
        (mixed_x, y_a, y_b, lam): ``y_a`` is ``y``, ``y_b`` is ``y`` permuted
        like the partners, and ``lam`` is a Python float.
    '''
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0

    # Draw on the CPU generator (as before) but follow the input's device
    # instead of assuming CUDA.
    index = torch.randperm(x.size(0)).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``criterion`` against each target set, blended with weights ``lam`` and ``1 - lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def random_seed(seed=0):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs, and make cuDNN deterministic."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:

def train(model, train_loader, loss_fn, optimizer, epoch, device, scaler, writer,
          label_smooth, label_smooth_value, mixup, mixup_alpha, Ncrop):
    """Run one training epoch with mixed precision (autocast + GradScaler).

    Args:
        model: network to train; it is switched to train mode.
        train_loader: iterable of ``(images, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, targets)``. Assumed to
            return a batch-averaged scalar. Targets are 7-class soft
            distributions when ``label_smooth`` is on, else class ids.
        optimizer: optimizer stepped once per batch through ``scaler``.
        epoch: epoch number. On epoch 1 every batch's input grid is logged to ``writer``.
        device: device the batches are moved to.
        scaler: GradScaler used for the backward/step.
        writer: TensorBoard SummaryWriter (only used when ``epoch == 1``).
        label_smooth, label_smooth_value: build smoothed targets via ``smooth_one_hot``.
        mixup, mixup_alpha: apply ``mixup_data`` with Beta(alpha, alpha) weights.
        Ncrop: batches are ``(bs, ncrops, c, h, w)``. Crops are folded into the
            batch and each label is repeated once per crop.

    Returns:
        (avg_loss, accuracy): ``avg_loss`` is a detached 0-dim tensor holding the
        per-sample mean loss of the epoch. ``accuracy`` is a float in [0, 1]; under
        mixup it is measured against the un-permuted labels. Both are 0 for an
        empty loader.
    """
    def to_targets(y):
        # soft distributions when label smoothing is on, raw class ids otherwise
        if label_smooth:
            return smooth_one_hot(y, classes=7, smoothing=label_smooth_value)
        return y

    model.train()
    count = 0
    correct = 0
    train_loss = torch.zeros((), device=device)
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        with autocast():
            if Ncrop:
                # (bs, ncrops, c, h, w) -> (bs * ncrops, c, h, w); one label per crop
                _, ncrops, c, h, w = images.shape
                images = images.view(-1, c, h, w)
                labels = torch.repeat_interleave(labels, repeats=ncrops, dim=0)

            if mixup:
                images, labels_a, labels_b, lam = mixup_data(
                    images, labels, mixup_alpha)

            if epoch == 1:
                img_grid = vutils.make_grid(
                    images, nrow=10, normalize=True, scale_each=True)
                writer.add_image("Augemented image", img_grid, i)

            outputs = model(images)

            if mixup:
                loss = mixup_criterion(
                    loss_fn, outputs, to_targets(labels_a), to_targets(labels_b), lam)
            else:
                loss = loss_fn(outputs, to_targets(labels))

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_n = labels.size(0)
        # Detach: accumulating the live loss would keep every batch's autograd
        # graph alive for the whole epoch. Re-weight the batch mean by batch size
        # so the epoch value is a true per-sample mean.
        train_loss += loss.detach() * batch_n
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        count += batch_n

    denom = max(count, 1)
    return train_loss / denom, correct / denom

def evaluate(model, val_loader, device, Ncrop):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    When ``Ncrop`` is true, batches are ``(bs, ncrops, c, h, w)``. Every crop is
    scored and the logits are averaged per sample (test-time augmentation).

    Returns:
        (avg_loss, accuracy): ``avg_loss`` is a 0-dim tensor holding the per-sample
        mean cross-entropy. ``accuracy`` is a float in [0, 1]. Both are 0 for an
        empty loader.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    count = 0
    correct = 0
    val_loss = torch.zeros((), device=device)
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            if Ncrop:
                # fuse crops into the batch, run the model, then average over crops
                bs, ncrops, c, h, w = images.shape
                outputs = model(images.view(-1, c, h, w))
                outputs = outputs.view(bs, ncrops, -1).mean(dim=1)
            else:
                outputs = model(images)

            batch_n = labels.size(0)
            # criterion averages over the batch; re-weight so the epoch value is a per-sample mean
            val_loss += criterion(outputs, labels) * batch_n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            count += batch_n

    denom = max(count, 1)
    return val_loss / denom, correct / denom


In [None]:
def cross_entropy(outputs, smooth_labels):
    """KL(smooth_labels || softmax(outputs)), summed over classes and divided by batch size.

    NOTE(review): this cell re-defines ``cross_entropy`` identically to an earlier
    cell; the duplicate can be removed.

    Args:
        outputs: raw logits with classes along dim 1.
        smooth_labels: target probabilities, same shape as ``outputs``.
    """
    return F.kl_div(F.log_softmax(outputs, dim=1), smooth_labels,
                    reduction='batchmean')
def mixup_data(x, y, alpha=0.2):
    '''Returns mixed inputs, pairs of targets, and lambda.

    NOTE(review): this cell re-defines ``mixup_data`` from an earlier cell. Because
    it runs later, this is the definition ``train`` resolves at call time, so keep
    the two in sync or delete one.

    Each sample is blended with a randomly permuted partner:
    ``mixed_x = lam * x + (1 - lam) * x[perm]``. ``lam`` is drawn from
    Beta(alpha, alpha) via NumPy's global RNG, or is 1.0 when ``alpha <= 0``.

    Returns:
        (mixed_x, y_a, y_b, lam): ``y_a`` is ``y``, ``y_b`` is ``y`` permuted
        like the partners, and ``lam`` is a Python float.
    '''
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0

    # Draw on the CPU generator but follow the input's device instead of assuming CUDA.
    index = torch.randperm(x.size(0)).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam

In [None]:
# --- experiment configuration ---
net_name = 'ResNet18'
epochs = 300
batch_size = 128
lr = 0.1
scheduler_name = 'reduce'
momentum = 0.9
weight_decay = 1e-4

# regularisation switches
label_smooth = True
label_smooth_value = 0.1
mixup = True
mixup_alpha = 1.0
Ncrop = True

# other
data_path = 'data/fer13/fer2013.csv'
results = './results'
writer = SummaryWriter(results + '/tensorboard_log')
# fall back to CPU when no GPU is present
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
save_freq = 10
resume = 0
seed = 0
name = 'official'
best_acc = 0.

# seed every RNG before data loading and weight init so runs are reproducible
random_seed(seed)

train_loader, val_loader, test_loader = get_dataloaders(
    path=data_path, batch_size=batch_size)
net = get_net().to(device)

# cross_entropy (KL divergence) expects soft target distributions, which train()
# only builds when label_smooth is on; with integer labels use standard cross-entropy.
loss = cross_entropy if label_smooth else nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr,
                            momentum=momentum, weight_decay=weight_decay, nesterov=True)
# mode='max': the scheduler is stepped with validation accuracy
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.75, patience=5)
scaler = GradScaler()

try:
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(
            net, train_loader, loss, optimizer, epoch, device, scaler, writer,
            label_smooth, label_smooth_value, mixup, mixup_alpha, Ncrop)
        val_loss, val_acc = evaluate(net, val_loader, device, Ncrop)

        scheduler.step(val_acc)
        # log the learning rate instead of relying on the deprecated verbose=True
        writer.add_scalar("Train/LR", optimizer.param_groups[0]['lr'], epoch)

        writer.add_scalar("Train/Loss", train_loss.item(), epoch)
        writer.add_scalar("Train/Accuracy", train_acc, epoch)
        writer.add_scalar("Valid/Loss", val_loss.item(), epoch)
        writer.add_scalar("Valid/Accuracy", val_acc, epoch)

        writer.add_scalars("Loss", {"Train": train_loss.item()}, epoch)
        writer.add_scalars("Accuracy", {"Train": train_acc}, epoch)
        writer.add_scalars("Loss", {"Valid": val_loss.item()}, epoch)
        writer.add_scalars("Accuracy", {"Valid": val_acc}, epoch)

        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        writer.add_scalar("Valid/Best Accuracy", best_acc, epoch)
finally:
    # flush TensorBoard events even if training is interrupted
    writer.close()