In [1]:
from fastai import *
from fastai.vision import *
from fastai.vision.image import *

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

import os
import cv2

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.utils import class_weight
import scipy.optimize as opt
from sklearn.metrics import f1_score

import math
import cv2
import subprocess
from operator import itemgetter

In [2]:
def create_class_weight(labels_dict, mu=0.5):
    """Turn per-class counts into inverse-frequency weights.

    For every key, two values are derived from ``total`` (the sum of all
    counts) and that key's ``count``:

    * plain weight:        ``total / count``
    * log-dampened weight: ``log(mu * total / count)``

    Each value is rounded to two decimals when it is greater than 1.0;
    otherwise it is replaced by 1.0.

    Args:
        labels_dict: mapping of class key -> count for that class.
        mu: multiplier applied to the ratio inside the logarithm.

    Returns:
        Tuple ``(plain_weights, log_weights)``; both dicts use the same keys
        (and key order) as ``labels_dict``.
    """
    total = np.sum(list(labels_dict.values()))
    plain_weights = {}
    log_weights = {}

    for label, count in labels_dict.items():
        count_f = float(count)
        ratio = total / count_f
        log_ratio = math.log(mu * total / count_f)

        # Plain inverse-frequency weight, never below 1.0.
        if ratio > 1.0:
            plain_weights[label] = round(ratio, 2)
        else:
            plain_weights[label] = 1.0

        # Log-dampened weight, never below 1.0.
        if log_ratio > 1.0:
            log_weights[label] = round(log_ratio, 2)
        else:
            log_weights[label] = 1.0

    return plain_weights, log_weights

# Class abundance for the protein dataset: class index -> count.
labels_dict = {
    0: 12885,
    1: 1254,
    2: 3621,
    3: 1561,
    4: 1858,
    5: 2513,
    6: 1008,
    7: 2822,
    8: 53,
    9: 45,
    10: 28,
    11: 1093,
    12: 688,
    13: 537,
    14: 1066,
    15: 21,
    16: 530,
    17: 210,
    18: 902,
    19: 1482,
    20: 172,
    21: 3777,
    22: 802,
    23: 2965,
    24: 322,
    25: 8228,
    26: 328,
    27: 11
}

# Compute both weight dicts with a single call; `weights` (the log-dampened
# variant) is the one reused by later cells.
true_weights, weights = create_class_weight(labels_dict)

print('\nTrue class weights:')
print(true_weights)
# end='' keeps the dict on the same line as the label.
print('\nLog-dampened class weights:', end='')
print(weights)


True class weights:
{0: 3.94, 1: 40.5, 2: 14.02, 3: 32.53, 4: 27.33, 5: 20.21, 6: 50.38, 7: 18.0, 8: 958.15, 9: 1128.49, 10: 1813.64, 11: 46.46, 12: 73.81, 13: 94.57, 14: 47.64, 15: 2418.19, 16: 95.82, 17: 241.82, 18: 56.3, 19: 34.27, 20: 295.24, 21: 13.45, 22: 63.32, 23: 17.13, 24: 157.71, 25: 6.17, 26: 154.82, 27: 4616.55}

Log-dampened class weights:{0: 1.0, 1: 3.01, 2: 1.95, 3: 2.79, 4: 2.61, 5: 2.31, 6: 3.23, 7: 2.2, 8: 6.17, 9: 6.34, 10: 6.81, 11: 3.15, 12: 3.61, 13: 3.86, 14: 3.17, 15: 7.1, 16: 3.87, 17: 4.8, 18: 3.34, 19: 2.84, 20: 4.99, 21: 1.91, 22: 3.46, 23: 2.15, 24: 4.37, 25: 1.13, 26: 4.35, 27: 7.74}


In [3]:
# Flatten the log-dampened weights into a list ordered by class index.
# The number of classes is taken from `weights` itself rather than hard-coded,
# so the list always matches the class table; keys are expected to be the
# contiguous integers 0..n-1 (a gap raises KeyError instead of being skipped).
class_weights = [weights[i] for i in range(len(weights))]
print(class_weights)

[1.0, 3.01, 1.95, 2.79, 2.61, 2.31, 3.23, 2.2, 6.17, 6.34, 6.81, 3.15, 3.61, 3.86, 3.17, 7.1, 3.87, 4.8, 3.34, 2.84, 4.99, 1.91, 3.46, 2.15, 4.37, 1.13, 4.35, 7.74]


In [4]:
# Loss function: binary cross-entropy on logits with the `weight` argument
# pre-bound to the class-weight tensor.
# NOTE(review): `data` is not defined in any earlier cell visible here --
# confirm it is created before this cell runs on a fresh kernel.
class_weight_tensor = to_device(torch.tensor(class_weights), data.device)
weighted_cross_entropy = partial(
    F.binary_cross_entropy_with_logits,
    weight=class_weight_tensor,
)