In [22]:
import os
import sys
from PIL import Image
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from tensorboardX import SummaryWriter

sys.path.append('..')
from resnet import resnet18

# Collect ours

In [26]:
def train_val_test_split(data, val_size=0.15, test_size=0.15, random_state=9):
    assert val_size + test_size < 1
    val_test_size = val_size + test_size
    test_relative_size = test_size / val_test_size
    
    indices = range(data.shape[0])
    train_indices, test_indices = train_test_split(indices, test_size=val_test_size, random_state=random_state)
    val_indices, test_indices = train_test_split(test_indices, test_size=test_relative_size,
                                                 random_state=random_state)
    
    return data[train_indices], data[val_indices], data[test_indices]

def crop_center_or_pad(img, new_side):
    _, y, x = img.shape
    if x > new_side:
        startx = x//2-(new_side//2)
        return img[:, startx:startx+new_side,startx:startx+new_side]
    elif x < new_side:
        padx = (new_side//2) - x//2
        return np.pad(img, ((0,), (padx,), (padx,)), mode='constant', constant_values=-1)
    return img

In [7]:
root='../rebuild_dataset'
sz_names = np.load(f'{root}/sz_oleg.npz')['sz_names']
_, _, oleg_test = train_val_test_split(sz_names)
oleg_all = [f'PSZ2 G{name}' for name in sz_names]
oleg_test = [f'PSZ2 G{name}' for name in oleg_test]

In [18]:
class Planck(Dataset):
    def __init__(self, sz, split='None'):
        _, _, sz_test = train_val_test_split(sz)
        self.X = sz_test
        self.y = np.array([1] * sz_test.shape[0], dtype=np.uint8)

    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        X, y = self.X[idx], self.y[idx]

        X = X[:, ::2, ::2]
        X = crop_center_or_pad(X, 128)
        return torch.from_numpy(X), y

In [53]:
root='../rebuild_dataset'
sz = np.load(f'{root}/sz_oleg.npz')['sz_data'].astype(np.float32)
print('sz.shape', sz.shape)

test_dataset = Planck(sz, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

sz.shape (1000, 5, 256, 256)


In [72]:
checkpoint_name = '../checkpoints/v2_128_oleg_14/net_best.pt'
threshold = 0.4
device = torch.device('cpu')

model = resnet18(num_classes=1).to(device)
model.load_state_dict(torch.load(checkpoint_name, map_location=device))
model.eval();

In [73]:
ours_pred = []
for (X, y_true), name in zip(test_dataloader, oleg_test):
    X, y_true = X.to(device), y_true.unsqueeze(1).float().to(device)
    y_pred = torch.sigmoid(model(X))

    if y_pred.item() > threshold:
        ours_pred.append(name)

In [74]:
ours_pred = set(ours_pred)
len(ours_pred)

122

# Now find intersection

In [75]:
with open('./mmf1.tsv') as f:
    mmf1 = set([line[:-1] for line in f])
with open('./mmf3.tsv') as f:
    mmf3 = set([line[:-1] for line in f])
with open('./pws.tsv') as f:
    pws = set([line[:-1] for line in f])
with open('./all.tsv') as f:
    all = set([line[:-1] for line in f])
# with open('./catalogue_comparison/ours.tsv') as f:
#     ours_test = set([line[:-1] for line in f])
# with open('./catalogue_comparison/ours_test.tsv') as f:
#     ours_test = set([line[:-1] for line in f])

In [76]:
mmf1_score = len(mmf1.intersection(oleg_all)) / len(oleg_all)
mmf3_score = len(mmf3.intersection(oleg_all)) / len(oleg_all)
pws_score = len(pws.intersection(oleg_all)) / len(oleg_all)
all_score = len(all.intersection(oleg_all)) / len(oleg_all)
ours_score = len(ours_pred.intersection(oleg_all)) / len(oleg_all)

print(f'mmf1: {mmf1_score:.2f}')
print(f'mmf3: {mmf3_score:.2f}')
print(f'pws: {pws_score:.2f}')
print(f'all: {all_score:.2f}')
print(f'ours: {ours_score:.2f}')

mmf1: 0.75
mmf3: 0.77
pws: 0.64
all: 1.00
ours: 0.12


In [78]:
mmf1_score = len(mmf1.intersection(oleg_test)) / len(oleg_test)
mmf3_score = len(mmf3.intersection(oleg_test)) / len(oleg_test)
pws_score = len(pws.intersection(oleg_test)) / len(oleg_test)
all_score = len(all.intersection(oleg_test)) / len(oleg_test)
ours_score = len(ours_pred.intersection(oleg_test)) / len(oleg_test)

print(f'mmf1: {mmf1_score:.3f}')
print(f'mmf3: {mmf3_score:.3f}')
print(f'pws: {pws_score:.3f}')
print(f'all: {all_score:.3f}')
print(f'ours: {ours_score:.3f}')

mmf1: 0.747
mmf3: 0.740
pws: 0.620
all: 1.000
ours: 0.813


In [52]:
len(ours_test)

NameError: name 'ours_test' is not defined