# Setup: environment, imports and configuration

In [0]:
!nvidia-smi
import os
import logging
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

if not os.path.exists("drive"):
    from google.colab import drive
    drive.mount('/gdrive')
    !ln -s /gdrive/My\ Drive/ML drive_ml
else:
    !ln -s /gdrive/My\ Drive/ML drive_ml

In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, models

from PIL import Image
import matplotlib.pyplot as plt
import time, os, json, logging, random, h5py, shutil
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)




# config
exp_name = 'inception224'       # run name; checkpoints go to drive_ml/log/<exp_name>/

learning_rate = 1e-4

num_epochs = 100
decay_epoch = [40, 70, 90]      # MultiStepLR milestones (epochs at which the LR is decayed)

crop_size = 44                  # crop side taken from the 48x48 source images
input_ch  = 1                   # grayscale plane; values > 1 replicate it across channels
batch_size = 32

valid_epoch = 5                 # validate and checkpoint every `valid_epoch` epochs

output_ch = 7                   # number of classes
num_workers = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG in use (Python, NumPy, PyTorch) so weight init, shuffling and
# augmentation repeat across runs (cuDNN kernels may still be nondeterministic).
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

## Data loading, model definitions and training utilities

In [0]:
# Load every split from the HDF5 file into memory:
# raw_data[split] = {'data': (N, 48, 48) uint8 images, 'label': label array}.
raw_data = {}
with h5py.File('/content/drive_ml/data.h5', 'r') as h5:
    for split in ['train', 'valid', 'test']:
        group = h5[split]
        images = np.array(group['data']).reshape((-1, 48, 48)).astype(np.uint8)
        raw_data[split] = {'data': images, 'label': np.array(group['label'])}

class FerDataset(Dataset):
    """Map-style dataset over in-memory images and integer class labels.

    Args:
        input_ch: number of output channels; each single-plane image is
            replicated `input_ch` times along a trailing channel axis.
        data_label: dict with 'data' (images indexed along axis 0, e.g. the
            (N, 48, 48) uint8 arrays built by the data-loading cell) and
            'label' (one integer label per image).
        transform: callable applied to each (H, W, input_ch) array. If None,
            the array is turned into a float (input_ch, H, W) tensor; uint8
            input is scaled to [0, 1], like torchvision's ToTensor.
    """

    def __init__(self, input_ch, data_label, transform=None):
        self.data = data_label['data']
        # as_tensor accepts any integer label dtype, not only int64.
        self.label = torch.as_tensor(data_label['label'], dtype=torch.long)
        self.transform = transform
        self.input_ch = input_ch

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        img = self.data[idx, :]
        img = np.stack([img] * self.input_ch, axis=2)   # (H, W) -> (H, W, input_ch)
        if self.transform is None:
            # A None transform is the default; calling it used to raise
            # "'NoneType' object is not callable".
            tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous()
            img = tensor.float().div(255) if tensor.dtype == torch.uint8 else tensor.float()
        else:
            img = self.transform(img)
        return img, self.label[idx]
    
def create_dataloader(crop_size, input_size):
    """Build FerDataset objects and DataLoaders for the train/valid/test splits.

    Train images get a random horizontal flip and a random `crop_size` crop;
    valid/test images get a center crop. Every split is then resized to
    `input_size` and converted with ToTensor. When crop_size equals the image
    side (48), the crops are no-ops, so that case uses the same pipeline.

    Reads the notebook globals raw_data, input_ch, batch_size and num_workers.

    Returns:
        (datasets, dataloaders): two dicts keyed by 'train', 'valid', 'test'.
        Only the train loader shuffles.
    """
    splits = ['train', 'valid', 'test']
    # 'L' is only valid for one channel. With replicated channels, let
    # ToPILImage infer the mode (e.g. 'RGB' for 3).
    pil_mode = 'L' if input_ch == 1 else None

    def build_transform(train):
        """Preprocessing pipeline for the training (True) or evaluation (False) splits."""
        ops = [transforms.ToPILImage(mode=pil_mode)]
        if train:
            ops += [transforms.RandomHorizontalFlip(p=0.5), transforms.RandomCrop(crop_size)]
        else:
            ops.append(transforms.CenterCrop(crop_size))
        ops += [transforms.Resize(input_size), transforms.ToTensor()]
        return transforms.Compose(ops)

    datasets = {
        split: FerDataset(input_ch, raw_data[split], build_transform(train=(split == 'train')))
        for split in splits
    }

    logging.info(', '.join([
        "{} {}".format(split, len(datasets[split]))
        for split in splits
    ]))

    dataloaders = {
        split: DataLoader(
            datasets[split],
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers
        )
        for split in splits
    }
    return datasets, dataloaders

# Loaders used by train_model / test_model: crop to crop_size, then resize to
# train_input_size (224x224 for the current Inception3 run). Other sizes used
# in this notebook: crop_size for the small CNNs, 299 for Inception3,
# 224 for DenseNet121.
train_input_size = 224
datasets, dataloaders = create_dataloader(crop_size, train_input_size)

In [0]:
class ResNet18(nn.Module):
    """torchvision ResNet-18 (no pretrained weights) for `input_ch` inputs and `output_ch` classes."""

    name = 'ResNet18'

    def __init__(self, input_ch, output_ch):
        super().__init__()

        backbone = models.resnet18(pretrained=False)

        if input_ch != 3:
            # Replacement stem: 7x7 with Conv2d defaults (stride 1, padding 0, bias).
            # Its parameters (including conv1.bias) are part of saved state_dicts,
            # so changing these arguments would break loading existing checkpoints.
            backbone.conv1 = torch.nn.Conv2d(input_ch, 64, kernel_size=7)

        backbone.fc = nn.Linear(backbone.fc.in_features, output_ch)

        for p in backbone.parameters():
            p.requires_grad = True

        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)
    
    
    
class ResNet50(nn.Module):
    """torchvision ResNet-50 (no pretrained weights) for `input_ch` inputs and `output_ch` classes."""

    name = 'ResNet50'

    def __init__(self, input_ch, output_ch):
        super().__init__()

        backbone = models.resnet50(pretrained=False)

        if input_ch != 3:
            # Replacement stem: 7x7 with Conv2d defaults (stride 1, padding 0, bias).
            # Its parameters are part of saved state_dicts, so changing these
            # arguments would break loading existing checkpoints.
            backbone.conv1 = torch.nn.Conv2d(input_ch, 64, kernel_size=7)

        backbone.fc = nn.Linear(backbone.fc.in_features, output_ch)

        for p in backbone.parameters():
            p.requires_grad = True

        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)
    

    

class VGG11(nn.Module):
    """torchvision VGG-11 (no pretrained weights) with an `output_ch`-way head and an `input_ch` first conv."""

    name = 'VGG11'

    def __init__(self, input_ch, output_ch):
        super().__init__()

        net = models.vgg11(pretrained=False)

        last_layer = net.classifier[6]
        net.classifier[6] = nn.Linear(last_layer.in_features, output_ch)

        for p in net.parameters():
            p.requires_grad = True

        if input_ch != 3:
            # 7x7 with Conv2d defaults (stride 1, padding 0). The weight shape is
            # stored in checkpoints, so keep these arguments for them to load.
            net.features[0] = torch.nn.Conv2d(input_ch, 64, kernel_size=7)

        self.vgg11 = net

    def forward(self, x):
        return self.vgg11(x)
    

class ResBlock(nn.Module):
    """Basic residual block: (3x3 conv-BN-ReLU, 3x3 conv-BN) + shortcut, then ReLU.

    Args:
        input_ch: number of input channels.
        expand: if True, the block doubles the channel count.
        stride: stride of the first 3x3 conv and of the projection shortcut.
            The default of 1 keeps the spatial size unchanged.

    The shortcut is a 1x1 conv + BN projection whenever the identity would not
    match the output shape (expand or stride != 1); otherwise it is the
    identity. Attribute names (`conv_block`, `downsample`) are kept so saved
    state_dict keys still load.
    """

    def __init__(self, input_ch, expand=True, stride=1):
        super(ResBlock, self).__init__()

        output_ch = input_ch*2 if expand else input_ch

        self.conv_block = nn.Sequential(
            nn.Conv2d(input_ch, output_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(output_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_ch, output_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(output_ch)
        )

        needs_projection = expand or stride != 1
        self.downsample = nn.Sequential(
            nn.Conv2d(input_ch, output_ch, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(output_ch)
        ) if needs_projection else None

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv_block(x)

        # Explicit None check instead of relying on nn.Sequential's truthiness.
        identity = x if self.downsample is None else self.downsample(x)

        out = out + identity
        return self.relu(out)


class LightResNet(nn.Module):
    """Small residual CNN intended for 44x44 crops.

    Layout: 5x5 stem conv (no padding, no pooling) -> ResBlock(64, expand=False)
    -> ResBlock(64) -> ResBlock(128) -> global average pool -> linear head with
    `output_ch` outputs. Attribute names form the state_dict keys of saved
    checkpoints, so they are kept stable.
    """

    name = 'LightResNet'

    def __init__(self, input_ch, output_ch):
        super(LightResNet, self).__init__()

        # First conv uses a smaller (5x5) kernel.
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_ch, 64, kernel_size=5, bias=False),   # 44x44 -> 40x40, 64ch
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        # The first pooling layer is deliberately omitted.

        # ResBlock's convs here are stride 1 with padding 1, so the blocks keep
        # the 40x40 resolution and only change the channel count.
        self.conv2_block = ResBlock(64, expand=False)   # 40x40, 64ch
        self.conv3_block = ResBlock(64)                 # 40x40, 128ch
        self.conv4_block = ResBlock(128)                # 40x40, 256ch

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))     # 256ch
        # Head size follows output_ch (it was hard-coded to 7).
        self.fc1 = nn.Linear(256, output_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2_block(x)
        x = self.conv3_block(x)
        x = self.conv4_block(x)
        x = self.avgpool(x)

        x = x.reshape(x.size(0), -1)   # (batch, 256, 1, 1) -> (batch, 256)
        x = self.fc1(x)
        return x
    
    
    
from torchvision.models import inception

class Inception3(nn.Module):
    """torchvision Inception v3 (random init, no auxiliary head) with an `input_ch`-channel first conv."""

    name = 'Inception3'

    def __init__(self, input_ch, output_ch):
        super().__init__()

        net = inception.Inception3(num_classes=output_ch, aux_logits=False)
        # Replace the first conv so the network accepts `input_ch` channels
        # (32 filters, 3x3 kernel, stride 2).
        net.Conv2d_1a_3x3 = inception.BasicConv2d(input_ch, 32, kernel_size=3, stride=2)

        for p in net.parameters():
            p.requires_grad = True

        self.inception = net

    def forward(self, x):
        return self.inception(x)
    
    
    
class DenseNet121(nn.Module):
    """torchvision DenseNet-121 (no pretrained weights) for `input_ch` inputs and `output_ch` classes."""

    name = 'DenseNet121'

    def __init__(self, input_ch, output_ch):
        super().__init__()

        net = models.densenet121(pretrained=False, progress=False)

        if input_ch != 3:
            # Replacement stem conv: 7x7, stride 2, padding 3, no bias.
            net.features[0] = nn.Conv2d(input_ch, 64, kernel_size=7, stride=2, padding=3, bias=False)

        net.classifier = nn.Linear(net.classifier.in_features, output_ch)

        for p in net.parameters():
            p.requires_grad = True

        self.densenet = net

    def forward(self, x):
        return self.densenet(x)

In [0]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, logdir, resume_checkpoint=None, override=False):
    """Train `model`, validating and checkpointing every `valid_epoch` epochs.

    Each epoch runs one pass over dataloaders['train'] and then steps
    `scheduler`. Every `valid_epoch` epochs, the model is evaluated on
    dataloaders['valid'] and a checkpoint is written to
    <logdir>/latest.pth.tar. The checkpoint holds epoch, model, optimizer and
    scheduler state, val_acc and best_val_acc. It is copied to
    <logdir>/best.pth.tar when validation accuracy beats the best so far.

    Args:
        model: network whose forward returns raw class logits.
        criterion: loss on (logits, labels). Assumed to average over the batch
            (reduction='mean', the nn.CrossEntropyLoss default).
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler stepped once per epoch.
        num_epochs: total number of epochs, counted from 0 (a resumed run
            continues up to this count).
        logdir: directory for checkpoints.
        resume_checkpoint: optional path to a checkpoint written by this
            function. Training continues with the epoch after the stored one.
        override: when not resuming, allow reusing an existing `logdir`.

    Returns:
        Best validation accuracy seen (fraction in [0, 1]).

    Relies on the notebook globals datasets, dataloaders, device and valid_epoch.
    """
    latest_path = os.path.join(logdir, 'latest.pth.tar')
    best_path   = os.path.join(logdir, 'best.pth.tar')

    def run_epoch(phase):
        """One pass over dataloaders[phase]; returns (mean per-sample loss, accuracy)."""
        training = phase == 'train'
        model.train(training)

        tot_loss = 0.0
        tot_corr = 0
        n_samples = len(datasets[phase])

        # Track gradients only while training; validation runs without autograd.
        with torch.set_grad_enabled(training):
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Pass raw logits: nn.CrossEntropyLoss applies log-softmax
                # itself, so an extra Softmax would squash loss and gradients.
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # `loss` is a batch mean; weight by batch size for a per-sample average.
                tot_loss += loss.item() * inputs.size(0)
                tot_corr += torch.sum(torch.argmax(outputs, dim=1) == labels).item()

        return tot_loss / n_samples, tot_corr / n_samples

    def save_checkpoint(epoch, val_acc, best_so_far):
        """Write latest.pth.tar; copy it to best.pth.tar if val_acc beats best_so_far."""
        torch.save({
            'epoch'       : epoch,
            'state_dict'  : model.state_dict(),
            'optimizer'   : optimizer.state_dict(),
            'scheduler'   : scheduler.state_dict(),
            'val_acc'     : val_acc,
            # Stored separately so a resumed run keeps the true best, not the latest.
            'best_val_acc': max(best_so_far, val_acc),
        }, latest_path)
        if val_acc > best_so_far:
            shutil.copyfile(latest_path, best_path)
        logging.info("Checkpoint saved")

    start_epoch = 0
    best_val_acc = 0.0
    last_saved_epoch = None

    if resume_checkpoint is not None:
        # map_location lets GPU-written checkpoints load on a CPU-only runtime.
        checkpoint = torch.load(resume_checkpoint, map_location=device)
        # The stored epoch already finished; continue with the next one.
        start_epoch = checkpoint['epoch'] + 1
        last_saved_epoch = checkpoint['epoch']
        # Older checkpoints lack 'best_val_acc'; fall back to their own val_acc.
        best_val_acc = checkpoint.get('best_val_acc', checkpoint['val_acc'])
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        os.makedirs(logdir, exist_ok=True)
    else:
        assert override or not os.path.exists(logdir), logdir
        os.makedirs(logdir, exist_ok=override)

    logging.info("Log directory: "+logdir)

    try:
        for epoch in range(start_epoch, num_epochs):
            logging.info('Epoch {}/{}'.format(epoch+1, num_epochs))

            train_loss, train_acc = run_epoch('train')
            logging.info('Train loss: {:.5f}, acc: {:.3f}%'.format(train_loss, 100 * train_acc))
            # Since PyTorch 1.1 the scheduler must step after the epoch's optimizer steps.
            scheduler.step()

            if (epoch+1) % valid_epoch == 0:
                val_loss, val_acc = run_epoch('valid')
                logging.info('Valid Loss: {:.5f} Acc: {:.3f}%'.format(val_loss, 100 * val_acc))
                save_checkpoint(epoch, val_acc, best_val_acc)
                best_val_acc = max(best_val_acc, val_acc)
                last_saved_epoch = epoch

    except KeyboardInterrupt:
        logging.warning("interrupted, latest checkpoint is at epoch {}".format(last_saved_epoch))

    logging.info('Training complete with best val acc: {:.4f}'.format(best_val_acc))
    return best_val_acc

    
    
def test_model(model, criterion, checkpoint_path, data_type='test'):
    """Load `checkpoint_path` into `model` and report loss/accuracy on one split.

    Args:
        model: network whose forward returns raw class logits.
        criterion: loss on (logits, labels). Assumed to average over the batch.
        checkpoint_path: checkpoint written by train_model (needs 'state_dict').
        data_type: split key in the notebook globals datasets/dataloaders.

    Returns:
        (mean per-sample loss, accuracy as a fraction in [0, 1]).

    Relies on the notebook globals datasets, dataloaders and device.
    """
    assert os.path.exists(checkpoint_path), checkpoint_path
    logging.info("Checkpoint: "+checkpoint_path)

    # map_location lets GPU-written checkpoints load on a CPU-only runtime.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    tot_loss = 0.0
    tot_corr = 0
    n_samples = len(datasets[data_type])

    with torch.no_grad():
        for inputs, labels in dataloaders[data_type]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Raw logits go straight into the criterion (CrossEntropyLoss
            # applies log-softmax itself); argmax is unaffected by softmax.
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # `loss` is a batch mean; weight by batch size for a per-sample average.
            tot_loss += loss.item() * inputs.size(0)
            tot_corr += torch.sum(torch.argmax(outputs, dim=1) == labels).item()

    mean_loss = tot_loss / n_samples
    accuracy = tot_corr / n_samples
    logging.info('On {} set: Loss: {:.5f} Acc: {:.3f}%'.format(
        data_type,
        mean_loss,
        100 * accuracy
    ))
    return mean_loss, accuracy
    


def model_list_inference(model_list, checkpoint_path_list, input_sizes=None):
    """Load a checkpoint into each model and collect every model's test-set logits.

    Args:
        model_list: models exposing a class attribute `name` (as every model
            class in this notebook does).
        checkpoint_path_list: one checkpoint path per model, in the same order.
        input_sizes: optional dict mapping a model `name` to the side length
            its test images are resized to. Defaults to 299 for 'Inception3'
            and 224 for 'DenseNet121'. Models not listed get `crop_size`.

    Returns:
        One (outputs, labels) pair per test batch. `outputs` holds each
        model's logits (on CPU) in model_list order; `labels` is the batch's
        label tensor.

    Raises:
        ValueError: if the two lists have different lengths.

    Relies on the notebook globals create_dataloader, crop_size and device.
    """
    if len(model_list) != len(checkpoint_path_list):
        # zip() would otherwise silently leave trailing models with random weights.
        raise ValueError('got {} models but {} checkpoint paths'.format(
            len(model_list), len(checkpoint_path_list)))
    if input_sizes is None:
        input_sizes = {'Inception3': 299, 'DenseNet121': 224}

    for model, checkpoint_path in zip(model_list, checkpoint_path_list):
        assert os.path.exists(checkpoint_path), checkpoint_path
        logging.info("Checkpoint: "+checkpoint_path)

        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()

    # Build one test loader per distinct input size. Test loaders do not
    # shuffle, so their batches line up sample-for-sample.
    model_sizes = [input_sizes.get(model.name, crop_size) for model in model_list]
    sizes = sorted(set(model_sizes))
    test_loaders = [create_dataloader(crop_size, size)[1]['test'] for size in sizes]

    inference_result = []
    with torch.no_grad():   # inference only: don't keep autograd graphs on the device
        for batches in zip(*test_loaders):
            labels = batches[0][1]
            inputs_by_size = {
                size: inputs.to(device) for size, (inputs, _) in zip(sizes, batches)
            }
            outputs = [
                model(inputs_by_size[size]).cpu()
                for model, size in zip(model_list, model_sizes)
            ]
            inference_result.append((outputs, labels))
    return inference_result


def test_fusion(inference_result, softmax=False, weight=None, verbose=True):
    if weight is None:
        weight = [1]*len(inference_result[0][0])
    
        
    tot_corr = 0
    tot_iter = len(datasets['test'])

    for i, (outputs, labels)  in enumerate(inference_result):
        preds = []
        for j, x in enumerate(outputs):
            if softmax:
                x = nn.Softmax(dim=1)(x)
            preds.append(x*weight[j])
        
        preds = torch.stack(preds, dim=0).sum(dim=0)
        preds = torch.argmax(preds, dim=1)
        
        tot_corr += torch.sum(preds == labels).item()
    if verbose:
        logging.info('Test Acc: {:.3f}%'.format(
            100 * tot_corr / tot_iter
        ))
    
    return tot_corr / tot_iter

## Training and evaluation

In [0]:
# ---------------------------------------------------------------------------
# Training: build model, loss, optimiser and LR schedule, then train.
# To resume, pass resume_checkpoint='drive_ml/log/inception224/latest.pth.tar';
# pass override=True to reuse an existing log directory.
# ---------------------------------------------------------------------------
model = Inception3(input_ch, output_ch).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=decay_epoch, gamma=0.1)

train_log_dir = 'drive_ml/log/{}'.format(exp_name)
train_model(
    model, criterion, optimizer, scheduler,
    num_epochs=num_epochs,
    logdir=train_log_dir,
)

In [0]:
# ---------------------------------------------------------------------------
# Testing: evaluate this experiment's best checkpoint on the test split.
# ---------------------------------------------------------------------------
best_checkpoint_path = 'drive_ml/log/{}/best.pth.tar'.format(exp_name)
test_model(model, criterion, checkpoint_path=best_checkpoint_path)

In [0]:
# Evaluate the 'inception' run's best checkpoint on the training split.
# NOTE: the train loader applies random flips/crops, so this measures accuracy
# on augmented samples.
# NOTE(review): this checkpoint is from the 'inception' run, not exp_name;
# confirm it was trained at the current loaders' input size.
model = Inception3(input_ch, output_ch).to(device)
test_model(
    model, criterion,
    checkpoint_path='drive_ml/log/inception/best.pth.tar',
    data_type='train',
)

## Ensemble fusion of the trained models

In [0]:
# (model class, checkpoint) pairs for the ensemble. This order fixes the order
# of per-model outputs, and therefore of the weights passed to test_fusion.
fusion_members = [
    (DenseNet121, 'drive_ml/log/densenet/best.pth.tar'),
    (Inception3,  'drive_ml/log/inception/best.pth.tar'),
    (LightResNet, 'drive_ml/log/lightres/best.pth.tar'),
    (ResNet50,    'drive_ml/log/resnet50/best.pth.tar'),
    (ResNet18,    'drive_ml/log/resnet18/best.pth.tar'),
    (VGG11,       'drive_ml/log/vgg11/best.pth.tar'),
]
inference_result = model_list_inference(
    [model_cls(input_ch, output_ch).to(device) for model_cls, _ in fusion_members],
    [ckpt_path for _, ckpt_path in fusion_members],
)

In [0]:
# Sweep one ensemble member's fusion weight over [0, 2) and plot test accuracy.
base_weights = [1, 0.23, 1.5, 0.3, 0, 0]   # one weight per model in inference_result
sweep_index = 1                            # slot to sweep (1 = Inception3 in the fusion cell)

axisx, axisy = [], []

for k in range(0, 200):
    x = k / 100.
    weights = list(base_weights)
    # Plug x into the swept slot; otherwise every iteration scores identical weights.
    weights[sweep_index] = x
    acc = test_fusion(inference_result, softmax=False, weight=weights, verbose=False)
    axisx.append(x)
    axisy.append(acc)

fig, ax = plt.subplots()
ax.plot(axisx, axisy)
ax.set(xlabel='weight of member {}'.format(sweep_index), ylabel='test accuracy',
       title='Fusion weight sweep')
plt.show()

i = np.argmax(np.array(axisy))
print(axisx[i], axisy[i])

In [0]:
# Softmax-averaged fusion of the first two members only. test_fusion needs one
# weight per model, and inference_result holds six models.
test_fusion(inference_result, softmax=True, weight=[1, 0.9, 0, 0, 0, 0], verbose=False)