In [None]:
# 4 Logistic regression on example data

In [1]:
import torch

## 4.1 Example data
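- input: three binary data sources (3 columns), one row per possible on/off scenario (all 8 combinations, 8 rows)
- output: the ground-truth diagnosis (1 = positive, 0 = negative); it is positive whenever the 1st or 2nd source is positive, regardless of the 3rd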

In [2]:

## Toy example of a simplified diagnostic rule
# X: one column per modality (data source), one row per on/off scenario;
# the 8 rows cover every combination of the three binary modalities.
X = torch.tensor(
    [
        # modality 0 positive
        [1.0, 0.0, 0.0],
        [1.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
        [1.0, 1.0, 1.0],
        # modality 0 negative
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0],
    ]
)
# y: ground-truth diagnosis per row (1 = positive, 0 = negative), as an (8, 1) column.
# y is 1 exactly when modality 0 or modality 1 is positive, so the 3rd modality
# (X[:, 2]) is irrelevant -- this lets us test the fits for invariance to it.
y = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0]).reshape(-1, 1)


## 4.2 Linear regression
### Weighted sum without bias: $Xa = y$

In [3]:
# Ordinary least squares with no intercept.
# Normal equations: X'X a = X'y  =>  a = (X'X)^(-1) X'y.
# pinv(X) equals (X'X)^(-1) X' when X has full column rank, and avoids
# forming the inverse explicitly.
a_nobias = torch.linalg.pinv(X).matmul(y)
res_nobias = X.matmul(a_nobias) - y
print(a_nobias)                  # per-modality weights
print(res_nobias.pow(2).mean())  # mean-square residual
# NB: residuals are relatively large, and a[2] != 0 even though y does not
# depend on the 3rd modality (compare with the with-bias fit below).


tensor([[0.6250],
        [0.6250],
        [0.1250]])
tensor(0.0781)


### Weighted sum with bias: $[X, \mathbf{1}]\,a = y$

In [4]:
# Append a column of ones so the last weight acts as an intercept (bias).
# Row count, dtype and device are taken from X itself rather than a hard-coded 8,
# so this cell keeps working if the example data above changes.
X1 = torch.cat([X, torch.ones(X.shape[0], 1, dtype=X.dtype, device=X.device)], dim=1)
a = torch.linalg.pinv(X1) @ y
res = X1 @ a - y
print(a)  # modality weights followed by the bias term
print(torch.mean(res**2))  # mean-square residual
# NB: smaller residuals, and now a[2] is zero up to floating-point round-off:
# with an intercept available, the fit ignores the irrelevant 3rd modality.

tensor([[5.0000e-01],
        [5.0000e-01],
        [2.9802e-08],
        [2.5000e-01]])
tensor(0.0625)


## 4.3 Logistic regression

In [5]:
# Logistic regression: sigmoid(X w + b), trained with binary cross-entropy.
# nn.Linear initialises its weights randomly, so fix the seed for reproducible results.
torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(X.shape[1], 1, bias=True), torch.nn.Sigmoid())
cross_entropy = torch.nn.BCELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-6)

for _ in range(10_000):
    optimiser.zero_grad()
    pred = model(X)  # the bias lives inside nn.Linear, so no column of ones is needed
    loss = cross_entropy(pred, y)
    loss.backward()
    optimiser.step()

# The last `pred` from the loop was computed *before* the final optimiser.step(),
# so re-evaluate with the final weights. no_grad keeps the reported values detached
# from the autograd graph.
with torch.no_grad():
    pred = model(X)
    a_lreg = torch.cat([model[0].weight.flatten(), model[0].bias])
    res_lreg = pred - y

print(a_lreg)  # modality weights followed by the bias term
print(torch.mean(res_lreg**2))  # mean-square residual
# NB: these data are linearly separable, so cross-entropy keeps decreasing as the
# weights grow; the magnitudes reached presumably depend on the iteration count
# (and the small weight_decay) rather than converging to a finite optimum.

tensor([15.6619, 15.8789, -0.1659, -7.3679])
tensor(1.1634e-07, grad_fn=<MeanBackward0>)
