# Mine Homologies in ImageNet

Author: YinTaiChen

## Import packages

In [1]:
import os
import random

import dionysus as d
import numpy
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

## Parameters

In [2]:
N1 = 100  # number of outer rounds, each with a fresh random selection of classes
N2 = 100  # landmark subsets (and output files) generated per round
DIMENSION = 4  # passed as the maximum-simplex-dimension argument of d.fill_rips

# Image preprocessing for the pretrained AlexNet.
# - `transforms.Scale` was deprecated in favour of `transforms.Resize` (same
#   behaviour) and has since been removed from torchvision.
# - torchvision's pretrained models were trained on inputs normalized with the
#   ImageNet channel mean/std below; a generic (0.5, 0.5, 0.5) normalization
#   feeds the network out-of-distribution inputs.
transform = transforms.Compose(
    [transforms.Resize(size=(256, 256)),
     transforms.ToTensor(),
     transforms.Normalize(mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225))])

## Load pretrained AlexNet as a feature extractor (final classification layer removed)

In [3]:
# Load ImageNet-pretrained AlexNet and drop the last layer of its classifier so
# the network returns the activations that fed the final classification layer;
# these are used as per-image feature vectors.
model = torchvision.models.alexnet(pretrained=True)
model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])
# torchvision's AlexNet classifier contains Dropout layers, and a freshly
# constructed module is in training mode, so dropout would randomly zero
# activations. Switch to evaluation mode for deterministic features.
model.eval()

## Subclass torch.utils.data.Dataset to randomly pick 20 images from each selected class

In [4]:
def default_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    ``convert`` produces a new, fully loaded image, so the file handle that
    ``Image.open`` keeps open is closed as soon as the conversion finishes
    instead of leaking until garbage collection.
    """
    # NOTE(review): `Image` (Pillow's PIL.Image) is not imported anywhere in
    # the visible notebook; add `from PIL import Image` to the imports cell.
    with Image.open(path) as img:
        return img.convert('RGB')

def default_flist_reader(flist, classes, per_class=20):
    """Read an image file list and randomly pick ``per_class`` images per class.

    flist format: impath label\\nimpath label\\n ... (same as caffe's filelist)

    Args:
        flist: path to a text file with one ``<image path> <integer label>``
            entry per line. Blank lines are ignored; the path may contain
            spaces (the label is taken from the last whitespace-separated
            token).
        classes: iterable of integer labels to sample from, or ``None`` to
            use every label present in ``flist`` in ascending order.
        per_class: number of images drawn for each class (default 20).

    Returns:
        A list of ``(impath, label)`` tuples: ``per_class`` consecutive
        entries for each class, in ``classes`` order. Images are drawn with
        replacement, so the same image can appear more than once.

    Raises:
        ValueError: if a non-blank line cannot be parsed, or a requested
            class has no images in ``flist``.
    """
    # Group entries by their integer label. The previous implementation
    # compared an int counter to the string label (never equal), re-read the
    # exhausted file to detect the last line, and assumed contiguous labels.
    by_label = {}
    with open(flist, 'r') as rf:
        for lineno, line in enumerate(rf, start=1):
            line = line.strip()
            if not line:
                continue
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError(
                    'line {}: expected "<impath> <label>", got {!r}'.format(lineno, line))
            impath, imlabel = parts
            label = int(imlabel)
            by_label.setdefault(label, []).append((impath, label))

    if classes is None:
        classes = sorted(by_label)

    random_list = []
    for c in classes:
        candidates = by_label.get(c)
        if not candidates:
            raise ValueError('class {} has no images in {}'.format(c, flist))
        # randomly choose `per_class` images (with replacement) in that class
        for _ in range(per_class):
            random_list.append(random.choice(candidates))
    return random_list

class ImageFilelist(data.Dataset):
    """Map-style dataset over the ``(path, label)`` pairs returned by ``flist_reader``.

    Args:
        root: directory that the image paths in the list are joined onto.
        flist: file list handed to ``flist_reader`` (together with ``classes``).
        transform: optional callable applied to every loaded image.
        target_transform: optional callable applied to every label.
        flist_reader: callable ``(flist, classes) -> list of (path, label)``.
        loader: callable mapping a full image path to an image object.
        classes: forwarded unchanged to ``flist_reader``.
    """

    def __init__(self, root, flist, transform=None, target_transform=None,
                 flist_reader=default_flist_reader, loader=default_loader, classes=None):
        self.root = root
        # The reader decides which images belong to this dataset instance.
        self.imlist = flist_reader(flist, classes)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        path, label = self.imlist[index]
        image = self.loader(os.path.join(self.root, path))
        image = image if self.transform is None else self.transform(image)
        label = label if self.target_transform is None else self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.imlist)

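## The algorithm

For each of `N1` rounds: pick 10 random class labels and extract AlexNet features for 20 random images of each class. Then, `N2` times, sample 40–60% of the feature vectors (with replacement) as landmarks and build a Vietoris–Rips filtration on them with `dionysus.fill_rips`. Write each filtration to `javaplex_<alpha>_<beta>.m` as JavaPlex `stream.addElement` calls. The Rips radius is the largest distance from any feature vector to its nearest landmark.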

In [None]:
def _extract_features(net, loader):
    """Run ``net`` over every batch from ``loader`` and stack the outputs.

    Returns:
        A 2-D numpy array with one row per sample; each row is that sample's
        network output flattened to 1-D.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    rows = []
    # Pure inference: no autograd graph is needed (replaces the deprecated,
    # and here unimported, `Variable` wrapper).
    with torch.no_grad():
        for images, _labels in loader:
            output = net(images)
            rows.append(output.reshape(output.size(0), -1).cpu().numpy())
    if not rows:
        raise ValueError('the dataloader produced no samples')
    return numpy.concatenate(rows, axis=0)


def _sample_landmarks(points, fraction):
    """Return ``int(len(points) * fraction)`` rows of ``points`` drawn uniformly.

    Rows are drawn with replacement, so the result may contain duplicates.
    """
    count = int(len(points) * fraction)
    indices = [random.randint(0, len(points) - 1) for _ in range(count)]
    return points[indices]


def _covering_radius(points, landmarks):
    """Return the largest Euclidean distance from any row of ``points`` to its nearest landmark.

    Raises:
        ValueError: if ``landmarks`` is empty.
    """
    if len(landmarks) == 0:
        raise ValueError('at least one landmark is required')
    nearest = numpy.full(len(points), numpy.inf)
    # One vectorized pass per landmark instead of a Python loop per point pair.
    for landmark in landmarks:
        numpy.minimum(nearest, numpy.linalg.norm(points - landmark, axis=1), out=nearest)
    return float(nearest.max())


def _write_javaplex(filtration, path):
    """Write one ``stream.addElement([v0,v1,...]);`` line per simplex of ``filtration``."""
    lines = []
    for simplex in filtration:
        # Join however many vertices the simplex has. With fill_rips's maximum
        # dimension set to DIMENSION = 4, simplices can have 5 vertices, which
        # the old fixed 1-4 vertex branches could not write.
        vertices = ','.join(str(simplex[i]) for i in range(len(simplex)))
        lines.append('stream.addElement([' + vertices + ']);\n')
    with open(path, 'w') as out:
        out.writelines(lines)


# AlexNet's classifier contains Dropout; evaluation mode keeps the extracted
# features deterministic (calling eval() more than once is harmless).
model.eval()

for alpha in range(N1):
    # Fresh random choice of 10 distinct class labels from range(1000) per round.
    classes = random.sample(range(0, 1000), 10)

    dataset = ImageFilelist(root="./data", flist="dataset.txt", transform=transform, classes=classes)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=3)
    # 2-D array: one row of features per sampled image.
    features = _extract_features(model, dataloader)

    for beta in range(N2):
        percent = random.randint(40, 60) / 100
        landmarks = _sample_landmarks(features, percent)
        # Rips scale: the smallest radius at which every feature vector lies
        # within reach of some landmark.
        radius = _covering_radius(features, landmarks)
        f = d.fill_rips(landmarks, DIMENSION, radius)
        _write_javaplex(f, 'javaplex_' + str(alpha) + '_' + str(beta) + '.m')