In [None]:
#!/usr/bin/env python3
"""torch_logistic_regression.ipynb
James Gardner 2019
with help from Matthew Alger wrt pytorch

performs logistic regression on feature vectors
against positional matching labels using pytorch

requires patch_catalogue.csv (as output by feature_vectors.ipynb)
and manual_labels.csv to be present in the cwd

will save the following:
weights.csv, predictions.csv, objects.csv, multi_objects.csv,
torch_lr_losses.pdf, torch_lr_weights.pdf, torch_lr_predictions.pdf, torch_lr_partition.pdf
"""

import csv
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt

import torch
from torch.autograd import Variable
import torch.nn.functional as F

In [None]:
# read the patch catalogue and index every candidate match by its pair of names
catalogue = pd.read_csv('patch_catalogue.csv').set_index(['name_TGSS','name_NVSS'])

# the separation score is the label source, so pull it out of the feature table
scores = catalogue.pop('score')
# positions are dropped too; could test to see if the model recovers separation?
catalogue = catalogue.drop(columns=['ra_TGSS','dec_TGSS','ra_NVSS','dec_NVSS'])

# derived log features, these prove more useful than the regular values
# (appended in a fixed order since the column order defines the feature order)
peak_tgss = catalogue['peak_TGSS']
integrated_tgss = catalogue['integrated_TGSS']
catalogue['log_flux_TGSS'] = np.log10(peak_tgss)
catalogue['log_integrated_TGSS'] = np.log10(integrated_tgss)
catalogue['log_ratio_flux_TGSS'] = np.log10(peak_tgss/integrated_tgss)
catalogue['log_flux_NVSS'] = np.log10(catalogue['peak_NVSS'])

In [None]:
# create features and labels within pytorch
# scores are out of separation scorer, so 0.1 should likely be 0
labels = (scores.values > 0.1)
features = catalogue.values

# train on half the catalogue (A, every second row starting from the first),
# predict against the whole thing (A+B); B would be the rows [1::2]
# plain tensors are used directly: the Variable wrapper is deprecated and
# no longer needed for autograd
labels_A = torch.from_numpy(labels[::2]).float()
features_A = torch.Tensor(features[::2])

In [None]:
# create the model class
class LogisticRegression(torch.nn.Module):
    """Binary logistic regression: one linear layer followed by a sigmoid.

    input_dim is the number of features per sample. forward(x) maps a tensor
    whose last dimension is input_dim to a tensor of class probabilities whose
    last dimension is 1.
    """
    def __init__(self, input_dim):
        super().__init__()
        # for our uses, the output layer is binary classification
        self.linear = torch.nn.Linear(input_dim, 1)

    def forward(self, x):
        # and here's the sigmoid! torch.sigmoid replaces the deprecated
        # torch.nn.functional.sigmoid, the result is the same
        return torch.sigmoid(self.linear(x))

In [None]:
# training hyper-parameters
# one epoch is one full pass over the training data;
# raise this value if the losses plot doesn't appear to stabilise
num_epochs = 10000
# step size of each update, cf. time-step in physical simulations
learning_rate = 1e-3
# number of features per sample, i.e. number of weights to fit
input_dim = features_A.shape[1]

In [None]:
# instantiate the model, criterion (i.e. loss), and optimizer classes
model = LogisticRegression(input_dim)
# binary cross entropy averaged over the samples, standard use
# (reduction='mean' is the current spelling of the deprecated size_average=True)
criterion = torch.nn.BCELoss(reduction='mean')
# stochastic gradient descent cf. unbiased estimate of a noisy observation
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# train the model, this can take some time depending on num_epochs
# every epoch is one full-batch gradient step over the training half (A)
losses = []
for epoch in tqdm(range(num_epochs)):
    # reset gradient accumulation
    optimizer.zero_grad()
    # forward step: predict and find loss
    predictions = model(features_A)
    # the model returns a column of shape (N,1) while the labels are (N,);
    # BCELoss needs matching shapes (a size mismatch is an error in current
    # pytorch), so drop the trailing axis before comparing
    loss = criterion(predictions.squeeze(-1), labels_A)
    # store a plain float via .item() rather than the loss tensor itself,
    # this stops the memory leak, advice from M.Alger
    losses.append(loss.item())
    # backwards step: use loss to optimize a little bit
    loss.backward()
    optimizer.step()

In [None]:
# line plot of the loss per epoch on log-log axes, hopefully shows some
# stabilisation; if it doesn't, try increasing num_epochs
plt.rcParams.update({'font.size': 18})
fig, ax = plt.subplots(figsize=(14,7))
ax.plot(losses)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('torch_lr - losses showing stabilisation')
fig.savefig('torch_lr_losses.pdf',bbox_inches='tight')

In [None]:
# write the fitted weights followed by the bias to weights.csv,
# one value per line, so the model can be reconstructed if needs be
parameters = list(model.parameters())
weight_param, bias_param = parameters[0], parameters[1]
weights = weight_param.detach().numpy().ravel()
bias = bias_param.detach().numpy()

np.savetxt('weights.csv', np.append(weights, bias), delimiter=',')

In [None]:
# bar chart of the classifier weights, one bar per feature column
# of particular interest: separation, alpha, log_flux_NVSS
plt.rcParams.update({'font.size': 18})
fig, ax = plt.subplots(figsize=(14,7))
bar_positions = range(len(weights))
ax.bar(bar_positions,weights)
ax.set_xlabel('weights')
ax.set_xticks(bar_positions)
ax.set_xticklabels(catalogue.columns,rotation='vertical')
ax.set_ylabel('co-eff')
ax.set_title('torch_lr - weights')
fig.savefig('torch_lr_weights.pdf',bbox_inches='tight')

In [None]:
# classify the entire catalogue and compare to labels
features_cat = torch.Tensor(features)
# inference only, so no gradients are tracked; the (N,1) model output is
# flattened to one probability per catalogue row
with torch.no_grad():
    predictions_cat = model(features_cat).numpy().ravel()

# where the two populations cross is where we say the classifier decides the split
# 101 edges give 100 bins covering all of [0,1]; np.arange(0,1,0.01) stops at
# 0.99, which drops the most confident predictions and skews the densities
bin_edges = np.linspace(0, 1, 101)
nc_100 = np.histogram(predictions_cat[labels == 0],bins=bin_edges,density=True)[0]
pc_100 = np.histogram(predictions_cat[labels == 1],bins=bin_edges,density=True)[0]
# the split is the left edge of the first bin in which the positive class is
# denser than the negative class, falling back to 0.5 if there is no such bin
midpoint = 0.5
crossings = np.flatnonzero(nc_100 < pc_100)
if crossings.size:
    midpoint = round(0.01*int(crossings[0]), 2)

pred_labels_cat = (predictions_cat > midpoint).astype(float)

# unconditional, note the native 63% negative class bias
accuracy = (pred_labels_cat == labels).mean()
# precision (true if said so)
precision = (labels[pred_labels_cat == True] == True).mean()
# recall (said so if true)
recall    = (pred_labels_cat[labels == True] == True).mean()
print(('over whole catalogue:\n accuracy = {0:.3f}, precision = {1:.3f}, recall = {2:.3f}')
      .format(accuracy,precision,recall))

# saves names of match and predicted label
# NOTE: this adds a pred_labels column to catalogue, so the cells above that
# read catalogue.values or catalogue.columns should not be re-run after it
# without first reloading the catalogue
catalogue['pred_labels'] = pred_labels_cat
catalogue['pred_labels'].to_csv('predictions.csv')

In [None]:
def inv_sigmoid(y):
    """inverse of the sigmoid, given: y = 1/(e^-x+1) returns x = log(y/(1-y))

    y may be a scalar or an array of probabilities; the result has the same
    shape. inputs are clipped to [eps, 1-eps], with eps the machine epsilon of
    the input's float type, so that saturated probabilities of exactly 0 or 1
    give large finite scores rather than -inf/+inf
    """
    y = np.asarray(y)
    if not np.issubdtype(y.dtype, np.floating):
        # e.g. integer 0/1 input, finfo below needs a float type
        y = y.astype(float)
    eps = np.finfo(y.dtype).eps
    y = np.clip(y, eps, 1 - eps)
    x = np.log(y/(1-y))
    return x

In [None]:
# create histogram of predictions with populations separated off of label
# three panels: the score h(x) (inverse sigmoid of the probability), the class
# probability g(x), and the thresholded class prediction f(x)
plt.rcParams.update({'font.size': 18})
fig, (ax1, ax2, ax3) = plt.subplots(3,figsize=(14,16))
ax1.set_title('logistic regression predictions \n \n score, h(x)')
ax2.set_title('class probability, g(x)')
ax3.set_title('class prediction, f(x)')
ax1.set_ylabel('pdf')
ax2.set_ylabel('pdf')

# each population as a flat 1-D sample: recent matplotlib rejects a tuple of
# (n,1) column arrays in hist(), and ravel is a no-op if they are already flat
negative_probs = np.ravel(predictions_cat[labels == 0])
positive_probs = np.ravel(predictions_cat[labels == 1])
class_names = ('negative class','positive class')
class_colors = ('red','blue')

ax1.hist((inv_sigmoid(negative_probs), inv_sigmoid(positive_probs)), bins=100,
         histtype='step', label=class_names, color=class_colors, density=True)
ax1.legend()
ax2.hist((negative_probs, positive_probs), bins=100,
         histtype='step', label=class_names, color=class_colors, density=True)
ax2.legend()

# two bins per population, split at the classifier's decision point
negative_class = np.histogram(negative_probs,bins=[0,midpoint,1],density=True)[0]
positive_class = np.histogram(positive_probs,bins=[0,midpoint,1],density=True)[0]

# bars at x = 0 and 0.25 are the predicted-0 bin, those at 0.75 and 1 the predicted-1 bin
ax3.bar(np.array((0,0.75)),negative_class,0.2, label='negative class', edgecolor='red', color='None')
ax3.bar(np.array((0.25,1)),positive_class,0.2, label='positive class', edgecolor='blue', color='None')
ax3.set_xticks((0,0.25,0.75,1))
ax3.set_xticklabels(('0','0','1','1'))
ax3.set_ylabel('pdf, binned as [0,{},1]'.format(midpoint))
ax3.legend()

plt.savefig('torch_lr_predictions.pdf',bbox_inches='tight')

In [None]:
# score the classifier against the hand-made labels in manual_labels.csv
manual_labels = pd.read_csv('manual_labels.csv').set_index(['name_TGSS','name_NVSS'])
# catalogue rows for exactly the manually labelled pairs, in the same order
man_cat = catalogue.loc[manual_labels.index.values]

label_man = manual_labels['manual_label'].values
pred_man = man_cat['pred_labels']

said_match = (pred_man == True)
is_match = (label_man == True)

accuracy = (pred_man == label_man).mean()
# precision: fraction of predicted matches that are labelled as matches
precision = is_match[said_match].mean()
# recall: fraction of labelled matches that were predicted as matches
recall = said_match[is_match].mean()
summary = 'on manual labels:\n accuracy = {0:.3f}, precision = {1:.3f}, recall = {2:.3f}'
print(summary.format(accuracy,precision,recall))

In [None]:
# partition the sky into physical objects using classifier
# we do this naively, by transitively linking together matches
# critical is that this naive partitioning can be bad given a good classifier

# keep only the pairs the classifier calls a match; one boolean mask replaces
# a .loc lookup per pair and keeps the catalogue's row order, so the
# partition comes out the same on every run
obj_pairs = catalogue.index[(catalogue['pred_labels'] == 1).values].tolist()

objects = {}  # object id -> list of source names in that object
tnames = {}   # TGSS name -> id of the object containing it
nnames = {}   # NVSS name -> id of the object containing it

# a fresh id per pair, ids only need to be unique
index = 0
for pair in tqdm(obj_pairs):
    tname, nname = pair[0], pair[1]
    i = tnames.get(tname)
    j = nnames.get(nname)

    if i is None and j is None:
        # neither source seen before: start a new object
        objects[index] = [tname,nname]
        tnames[tname] = index
        nnames[nname] = index
    elif j is None:
        # known TGSS source gains a new NVSS component
        objects[i].append(nname)
        nnames[nname] = i
    elif i is None:
        # known NVSS source gains a new TGSS component
        objects[j].append(tname)
        tnames[tname] = j
    elif i != j:
        # must merge objects, zig-zag problem
        # (i == j means the pair is already inside one object, nothing to do)
        merged_obj = list(dict.fromkeys(objects.pop(i) + objects.pop(j)))
        objects[index] = merged_obj
        # re-point every member at the merged object, going by the lookup
        # tables rather than by the first letter of the name so that this
        # works whatever the naming convention of the two surveys is
        for name in merged_obj:
            if tnames.get(name) in (i, j):
                tnames[name] = index
            if nnames.get(name) in (i, j):
                nnames[name] = index

    index += 1

In [None]:
# find the most interesting objects, those with many components
# (more than two names, i.e. more than a single TGSS-NVSS pair)
multi_objects = {key: val for key, val in objects.items() if len(val) > 2}
most_components = 0
most_components_i = 0
if multi_objects:
    # first object with the largest number of components
    most_components_i = max(multi_objects, key=lambda key: len(multi_objects[key]))
    most_components = len(multi_objects[most_components_i])
    # the extreme amount of components here is a sign that the naive partitioning is indeed naive
    print(most_components, multi_objects[most_components_i])
else:
    # nothing to show; indexing multi_objects here would raise a KeyError
    print('no objects with more than two components were found')

In [None]:
# save object partition
def dict_to_csv(dict_to_convert, filename):
    """write each value of dict_to_convert as one row of the csv file filename

    the values are expected to be sequences (one cell per element); the keys
    are not written. an existing file of that name is overwritten
    """
    rows = list(dict_to_convert.values())
    with open(filename, 'w', newline = '') as handle:
        csv.writer(handle).writerows(rows)
    
# the full partition, and separately the objects with more than two components
for partition, partition_file in ((objects,'objects.csv'),
                                  (multi_objects,'multi_objects.csv')):
    dict_to_csv(partition, partition_file)

In [None]:
def connect_the_dots(centre,field_of_view):
    """creates a picture of the object partition in a square patch of sky

    centre is a tuple of (ra, dec) in degrees and field_of_view is in degrees;
    a match is drawn when both of its sources lie strictly within
    centre +/- field_of_view in both RA and Dec, so the square is
    2*field_of_view on a side. only shows the links between sources across
    surveys, i.e. 'matches'

    reads the positions from patch_catalogue.csv in the cwd, takes the
    predicted labels from the notebook-level pred_labels_cat (assigned row by
    row, so it must be in the same order as that file), and saves the figure
    to torch_lr_partition.pdf
    """
    c_ra, c_dec = centre
    w_fov = field_of_view
    position_cols = ['ra_TGSS','dec_TGSS','ra_NVSS','dec_NVSS']

    lookup_cat = pd.read_csv('patch_catalogue.csv',
                             usecols=['name_TGSS','name_NVSS'] + position_cols)
    lookup_cat.set_index(['name_TGSS','name_NVSS'],inplace=True)
    lookup_cat['pred_labels'] = pred_labels_cat

    # predicted matches with both sources within +/- field_of_view of the centre
    in_window = (lookup_cat['pred_labels'] == 1)
    for ra_col, dec_col in (('ra_TGSS','dec_TGSS'),('ra_NVSS','dec_NVSS')):
        in_window = (in_window &
                     (lookup_cat[ra_col] > c_ra-w_fov) &
                     (lookup_cat[ra_col] < c_ra+w_fov) &
                     (lookup_cat[dec_col] > c_dec-w_fov) &
                     (lookup_cat[dec_col] < c_dec+w_fov))

    # pick the coordinates by column name: usecols does not reorder columns,
    # so positional indexing would depend on the column order of the csv
    walues = lookup_cat.loc[in_window, position_cols].values
    del lookup_cat
    tgss_x, tgss_y, nvss_x, nvss_y = walues.T

    plt.figure(figsize=(14,14))
    plt.rcParams.update({'font.size': 18})
    # TGSS sources as red pixels, NVSS sources as blue pixels
    plt.plot(tgss_x,tgss_y,'r,')
    plt.plot(nvss_x,nvss_y,'b,')

    # one thin black segment per match
    for i in tqdm(range(len(walues))):
        plt.plot([tgss_x[i],nvss_x[i]],[tgss_y[i],nvss_y[i]],'k-',linewidth=0.5)

    # invert x-axis to read as RA from right to left
    ax = plt.gca()
    xlim = ax.get_xlim()
    ax.set_xlim(xlim[::-1])

    plt.title('Naive partition of TGSS to NVSS in sky around {0:.2f},{1:.2f}'.format(c_ra,c_dec))
    plt.ylabel('DEC / °')
    plt.xlabel('RA / °')
    plt.savefig('torch_lr_partition.pdf',bbox_inches='tight')

In [None]:
# draw the partition around a chosen (ra, dec) centre, in degrees
# interesting centre candidates:
#   153.65,-27.09 | J101436.8-270532, for the many-component object about the centre
#   166.10,-27.16
#   158.60,-15.58
#   152.64,-18.01
centre = (153.65, -27.09)
field_of_view = 5
connect_the_dots(centre, field_of_view)