In [1]:
import torch
from torch import nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Helpers
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def multi_label_acc(y_hat, y):
    return torch.count_nonzero(torch.round(y_hat) == y).item() / y.numel()

def classification_acc(y_hat, y):
    return torch.count_nonzero(torch.argmax(y_hat, dim=1) == y).item() / y.numel()

In [3]:
num_x = 64
num_c = 16
num_y = 10
k = 0.875 # ratio of x's to use to calculate concepts
num_xc = int(k * num_x)
num_samples = 10**5

In [4]:
class BinarySigmoid(nn.Module):
    def __init__(self):
        super(BinarySigmoid, self).__init__()

    def __repr__(self):
        return 'BinarySigmoid()'

    def forward(self, x):
        x = torch.nn.functional.sigmoid(x)
        return x + torch.round(x).detach() - x.detach()
    
def init_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.normal_(m.bias)
            m.requires_grad_(False)

x_to_c = nn.Sequential(
    nn.Linear(num_xc, 32),
    nn.Sigmoid(),
    nn.Linear(32, 32),
    nn.Sigmoid(),
    nn.Linear(32, num_c)
)

xc_to_y = nn.Sequential(
    nn.Linear(num_x-num_xc+num_c, 32),
    nn.Sigmoid(),
    nn.Linear(32, num_y)
)

xc_activation = BinarySigmoid()
cy_activation = nn.Softmax(dim=1)

init_weights(x_to_c)
init_weights(xc_to_y)

In [5]:
def gen_x(num_x, batch_size):
    return torch.rand((batch_size, num_x))

x = gen_x(num_x, num_samples)
c_logits = x_to_c(x[:,:num_xc])
mean_c_logits = torch.mean(c_logits, dim=0)
x_to_c[-1].bias -= mean_c_logits
c_logits = x_to_c(x[:,:num_xc])
c = xc_activation(c_logits)
y_logits = xc_to_y(torch.concat((x[:,num_xc:],c), dim=1))
mean_y_logits = torch.mean(y_logits, dim=0)
xc_to_y[-1].bias -= mean_y_logits
y_logits = xc_to_y(torch.concat((x[:,num_xc:],c), dim=1))
# y = cy_activation(y_logits)
y = cy_activation(y_logits / torch.std(y_logits, dim=0))
y_argmax = torch.argmax(y, dim=1)
print(torch.unique(y_argmax, return_counts=True))

(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([ 8618,  8669,  7904,  6727, 10109, 14912,  8984, 12321, 11590, 10166]))


In [6]:
x_to_new_x = nn.Sequential(
    nn.Linear(num_x, num_x)
)
init_weights(x_to_new_x)
new_x = x_to_new_x(x)

In [18]:
x_to_c_test = nn.Sequential(
    nn.Linear(num_x, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, num_c),
    nn.Sigmoid()
).to(device)

xc_to_y_test = nn.Sequential(
    nn.Linear(num_c, 256),
    nn.ReLU(),
    nn.Linear(256, num_y),
    nn.Softmax(dim=1)
).to(device)

In [19]:
epochs = 50
batch_size = 10**4
y_criterion = nn.CrossEntropyLoss()
c_criterion = nn.BCELoss()
optimizer = torch.optim.Adam(x_to_c_test.parameters(), lr=0.001)
loss_meter = AverageMeter()
acc_meter = AverageMeter()


for epoch in range(epochs):
    loss_meter.reset()
    acc_meter.reset()
    for index in range(num_samples//batch_size-1):
        X, C, Y = new_x[index*batch_size:(index+1)*batch_size].to(device), c[index*batch_size:(index+1)*batch_size].to(device), y_argmax[index*batch_size:(index+1)*batch_size].to(device)
        c_pred = x_to_c_test(X)
        c_loss = c_criterion(c_pred, C)
        
        optimizer.zero_grad()
        c_loss.backward()
        optimizer.step()
        loss_meter.update(c_loss.item(), X.shape[0])
        acc_meter.update(multi_label_acc(c_pred, C), X.shape[0])
    # print(f"Epoch: {epoch} Loss: {loss_meter.avg} y_acc: {acc_meter.avg}")

loss_meter.reset()
acc_meter.reset()
index = num_samples//batch_size-1
X, C, Y = new_x[index*batch_size:(index+1)*batch_size].to(device), c[index*batch_size:(index+1)*batch_size].to(device), y_argmax[index*batch_size:(index+1)*batch_size].to(device)
c_pred = x_to_c_test(X)
c_loss = c_criterion(c_pred, C)

loss_meter.update(c_loss.item(), X.shape[0])
acc_meter.update(multi_label_acc(c_pred, C), X.shape[0])
print(f"Test Loss: {loss_meter.avg} y_acc: {acc_meter.avg}")

Test Loss: 0.3597686290740967 y_acc: 0.72986875


In [20]:
optimizer = torch.optim.Adam(xc_to_y_test.parameters(), lr=0.001)

for epoch in range(epochs):
    loss_meter.reset()
    acc_meter.reset()
    for index in range(num_samples//batch_size-1):
        X, C, Y = new_x[index*batch_size:(index+1)*batch_size].to(device), c[index*batch_size:(index+1)*batch_size].to(device), y_argmax[index*batch_size:(index+1)*batch_size].to(device)
        c_pred = torch.round(x_to_c_test(X))
        y_pred = xc_to_y_test(c_pred)
        y_loss = y_criterion(y_pred, Y)
        
        optimizer.zero_grad()
        y_loss.backward()
        optimizer.step()
        loss_meter.update(y_loss.item(), X.shape[0])
        acc_meter.update(classification_acc(y_pred, Y), X.shape[0])
    # print(f"Epoch: {epoch} Loss: {loss_meter.avg} y_acc: {acc_meter.avg}")

loss_meter.reset()
acc_meter.reset()
index = num_samples//batch_size-1
X, C, Y = new_x[index*batch_size:(index+1)*batch_size].to(device), c[index*batch_size:(index+1)*batch_size].to(device), y_argmax[index*batch_size:(index+1)*batch_size].to(device)
c_pred = torch.round(x_to_c_test(X))
y_pred = xc_to_y_test(c_pred)
y_loss = y_criterion(y_pred, Y)

loss_meter.update(y_loss.item(), X.shape[0])
acc_meter.update(classification_acc(y_pred, Y), X.shape[0])
print(f"Test Loss: {loss_meter.avg} y_acc: {acc_meter.avg}")

Test Loss: 2.085350275039673 y_acc: 0.3687
