# Hyperbolic & hierarchical image classification

In [None]:
# Imports
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import sys
import re
import datetime
import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.io import ImageReadMode, read_image
import torch.nn as nn
import torch.nn.functional as F
import csv

# installing geoopt
!pip install -q git+https://github.com/geoopt/geoopt.git
! [ ! -f mobius_linear_example.py ] && wget -q https://raw.githubusercontent.com/geoopt/geoopt/master/examples/mobius_linear_example.py
import geoopt

import argparse
import random
import math
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from   torch.autograd import Variable
import matplotlib.pyplot as plt

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for geoopt (pyproject.toml) ... [?25l[?25hdone


In [None]:
def get_class_mapping(dataset):
    """Return a dict mapping class index -> class name for ``dataset``.

    Args:
        dataset: dataset identifier; only ``'cifar100'`` is supported.

    Returns:
        A new dict ``{0: 'apple', 1: 'aquarium_fish', ..., 99: 'worm'}``.

    Raises:
        ValueError: for an unsupported dataset (previously this surfaced as an
            UnboundLocalError on the undefined result variable).
    """
    if dataset != 'cifar100':
        raise ValueError(f"unsupported dataset: {dataset!r}")

    # CIFAR-100 fine-label names in label-index order (5 per row, 20 rows).
    names = (
        "apple", "aquarium_fish", "baby", "bear", "beaver",
        "bed", "bee", "beetle", "bicycle", "bottle",
        "bowl", "boy", "bridge", "bus", "butterfly",
        "camel", "can", "castle", "caterpillar", "cattle",
        "chair", "chimpanzee", "clock", "cloud", "cockroach",
        "couch", "crab", "crocodile", "cup", "dinosaur",
        "dolphin", "elephant", "flatfish", "forest", "fox",
        "girl", "hamster", "house", "kangaroo", "keyboard",
        "lamp", "lawn_mower", "leopard", "lion", "lizard",
        "lobster", "man", "maple", "motorcycle", "mountain",
        "mouse", "mushroom", "oak", "orange", "orchid",
        "otter", "palm", "pear", "pickup_truck", "pine",
        "plain", "plate", "poppy", "porcupine", "possum",
        "rabbit", "raccoon", "ray", "road", "rocket",
        "rose", "sea", "seal", "shark", "shrew",
        "skunk", "skyscraper", "snail", "snake", "spider",
        "squirrel", "streetcar", "sunflower", "sweet_pepper", "table",
        "tank", "telephone", "television", "tiger", "tractor",
        "train", "trout", "tulip", "turtle", "wardrobe",
        "whale", "willow", "wolf", "woman", "worm",
    )
    return dict(enumerate(names))

def get_embedding(dataset, hierarch, root='/content'):
    """Load the pre-trained hierarchy embedding checkpoint for a dataset.

    Args:
        dataset: dataset identifier; only ``'cifar100'`` is supported.
        hierarch: hierarchy variant, one of ``'balanced'``, ``'base'``,
            ``'expert'``, ``'random'``, ``'root'`` or ``'size'``.
        root: directory holding the ``<hierarch>.pth.best`` files (defaults
            to the original hard-coded ``/content``).

    Returns:
        Whatever ``torch.load`` returns for ``<root>/<hierarch>.pth.best``.

    Raises:
        ValueError: for an unsupported dataset or hierarchy name (previously
            an UnboundLocalError).
    """
    known_hierarchies = ('balanced', 'base', 'expert', 'random', 'root', 'size')
    if dataset != 'cifar100':
        raise ValueError(f"unsupported dataset: {dataset!r}")
    if hierarch not in known_hierarchies:
        raise ValueError(
            f"unknown hierarchy {hierarch!r}; expected one of {known_hierarchies}")
    return torch.load(os.path.join(root, f'{hierarch}.pth.best'))

def get_translator(mapping, embedding):
    """Map each class index to its row(s) in a hierarchy embedding checkpoint.

    Args:
        mapping: dict from class index (``0 .. len(mapping)-1``) to class name,
            e.g. the result of ``get_class_mapping``.
        embedding: dict with an ``'objects'`` sequence of node names and an
            ``'embeddings'`` tensor whose rows follow the same order (assumed
            checkpoint layout -- TODO confirm).

    Returns:
        dict from class index to ``embedding['embeddings'][positions]``, where
        ``positions`` lists every index whose node name is ``<class>.n01``.

    Raises:
        KeyError: if a class has no matching node. Previously this silently
            yielded an empty tensor, which misaligned the prototypes later on.
    """
    # Build name -> positions once instead of rescanning 'objects' per class.
    positions = {}
    for pos, name in enumerate(embedding['objects']):
        positions.setdefault(name, []).append(pos)

    vectors = embedding['embeddings']
    tensors = {}
    for i in range(len(mapping)):
        node = mapping[i] + '.n01'
        if node not in positions:
            raise KeyError(f"class {mapping[i]!r} ({node}) not found in embedding objects")
        tensors[i] = vectors[positions[node]]
    return tensors

def translating(inputs, translator):
    """Look up the embedding row for each label in ``inputs``.

    Args:
        inputs: iterable of 0-d label tensors (e.g. a 1-D label tensor).
        translator: dict from class index to a ``(1, dim)`` tensor, as built
            by ``get_translator``.

    Returns:
        Tensor of shape ``(len(inputs), dim)``. The width is inferred rather
        than fixed at 64, so other embedding sizes work too.
    """
    rows = [translator[label.tolist()] for label in inputs]
    return torch.cat(rows, dim=0).reshape(len(inputs), -1)

def get_leaf_nodes_tensors(inputs, translator):
    """Return ``[translator[k] for k in inputs]``, preserving the order of ``inputs``."""
    return [translator[key] for key in inputs]

# Selected dataset and hierarchy variant; these globals are also read by later cells.
DATASET = "cifar100"
HIERARCH = "balanced"
embedding = get_embedding(DATASET, HIERARCH)
class_mapping = get_class_mapping(DATASET)
translator = get_translator(class_mapping, embedding)

# One prototype row per class, ordered by class index.
leaf_nodes_tensors = get_leaf_nodes_tensors(range(len(class_mapping)), translator)
leaf_nodes_tensors = torch.cat(leaf_nodes_tensors, dim=0)

# Convert explicitly: np.save cannot take a tensor that requires grad or lives on
# the GPU. The "_64dim" suffix is kept because the default --prot path in the
# training cell expects that filename.
np.save("prototypes_" + HIERARCH + '_64dim', leaf_nodes_tensors.detach().cpu().numpy(),
        allow_pickle=True, fix_imports=True)

In [None]:
# HIERARCHY CLASS
def add_relations_from_csv(csv_file, hierarchy):
    """Load child/parent edges from a CSV file into ``hierarchy``.

    The first row is always treated as a header and skipped. Every following
    row with at least two columns results in
    ``hierarchy.add_relation(row[0], row[1])``. Shorter rows (e.g. blank
    lines) are ignored, and extra columns are dropped.

    Args:
        csv_file: path to the CSV file.
        hierarchy: object exposing ``add_relation(child, parent)``.
    """
    # newline='' is what the csv module expects for correct quoting handling.
    with open(csv_file, 'r', newline='') as file:
        rows = csv.reader(file)
        # Skip the header; the default avoids StopIteration on an empty file.
        next(rows, None)
        for row in rows:
            if len(row) >= 2:
                child, parent = row[:2]
                hierarchy.add_relation(child, parent)

class Hierarchy:
    """Child -> parents adjacency for a class taxonomy, plus tree queries.

    Nodes are plain strings. Root-ward walks stop at ``ROOT``.
    """

    # Name of the node at which root-ward walks stop.
    ROOT = "root.n01"

    def __init__(self):
        # Maps each child node to the list of its parents, in insertion order.
        self.hierarchy = {}

    def add_relation(self, child, parent):
        """Record ``parent`` as a parent of ``child`` (duplicates are kept)."""
        self.hierarchy.setdefault(child, []).append(parent)

    def get_parents(self, child):
        """Return the recorded parents of ``child`` (``[]`` if none)."""
        return self.hierarchy.get(child, [])

    def get_siblings(self, child):
        """Return the other children of each of ``child``'s parents.

        A node shared by several parents may appear more than once.
        """
        siblings = []
        for parent in self.get_parents(child):
            siblings.extend(self.get_children(parent))
        # child appears once per parent; drop every occurrence, not just the first.
        return [node for node in siblings if node != child]

    def get_children(self, parent):
        """Return every node listing ``parent`` among its parents (linear scan)."""
        return [child for child, parents in self.hierarchy.items() if parent in parents]

    def get_cousins(self, child):
        """Return the siblings of ``child``'s parents, i.e. its uncles/aunts.

        Despite the name, these are not cousins. A cousin is a child of one of
        the returned nodes.
        """
        uncles = []
        for parent in self.get_parents(child):
            uncles.extend(self.get_siblings(parent))
        return [node for node in uncles if node != child]

    def get_grandparents(self, child):
        """Return the distinct parents of ``child``'s parents (order unspecified)."""
        grandparents = []
        for parent in self.get_parents(child):
            grandparents.extend(self.get_parents(parent))
        return list(set(grandparents))  # Remove duplicates

    def find_lowest_common_ancestor(self, class1, class2):
        """Return the deepest node shared by both root paths, or None.

        Each node counts as its own ancestor, so LCA(x, x) is x and the LCA of
        an ancestor/descendant pair is the ancestor. (Previously only proper
        ancestors were compared, so LCA(x, x) was x's parent.) Returns None if
        either node has no recorded parents.
        """
        if class1 not in self.hierarchy or class2 not in self.hierarchy:
            return None

        path1 = self.get_path_to_root(class1) + [class1]
        path2 = self.get_path_to_root(class2) + [class2]

        lowest_common_ancestor = None
        for node1, node2 in zip(path1, path2):
            if node1 != node2:
                break
            lowest_common_ancestor = node1
        return lowest_common_ancestor

    def get_path_to_root(self, class_name):
        """Return the ancestors of ``class_name`` ordered root-first.

        Follows the first-listed parent at each step and excludes
        ``class_name`` itself. The walk stops at ``ROOT``, at a node without
        parents, or when a node repeats (a cycle in the input), so it always
        terminates.
        """
        path = []
        seen = {class_name}
        current = class_name

        while current != self.ROOT:
            parents = self.get_parents(current)
            if not parents:
                break
            current = parents[0]
            if current in seen:
                break
            seen.add(current)
            path.append(current)

        return path[::-1]  # Reverse the path

    def find_step_distance(self, class1, class2):
        """Return the number of edges between two nodes along their root paths.

        The result is:
        - 0 for identical nodes;
        - the depth difference when one node is an ancestor of the other;
        - otherwise the sum of both distances to their lowest common ancestor;
        - None when no lowest common ancestor is found.
        """
        if class1 == class2:
            return 0
        if class1 == self.ROOT:
            return len(self.get_path_to_root(class2))
        if class2 == self.ROOT:
            return len(self.get_path_to_root(class1))

        path1 = self.get_path_to_root(class1)
        path2 = self.get_path_to_root(class2)

        # Check if class1 is an ancestor of class2
        if class1 in path2:
            return len(path2) - path2.index(class1)

        # Check if class2 is an ancestor of class1
        if class2 in path1:
            return len(path1) - path1.index(class2)

        # Neither is an ancestor of the other, so the LCA is a proper ancestor
        # of both and therefore appears in both paths.
        lowest_common_ancestor = self.find_lowest_common_ancestor(class1, class2)
        if lowest_common_ancestor is None:
            return None
        distance1 = len(path1) - path1.index(lowest_common_ancestor)
        distance2 = len(path2) - path2.index(lowest_common_ancestor)
        return distance1 + distance2

# HIERARCHY FUNCS
# Build the module-level hierarchy for the selected dataset/hierarchy from
# "<DATASET>_<HIERARCH>.csv" (child, parent rows).
csv_file = f"{DATASET}_{HIERARCH}.csv"
hierarchy = Hierarchy()
add_relations_from_csv(csv_file, hierarchy)

def is_sibling(class1, class2):
    """Return True if ``class2`` is among ``class1``'s siblings in the module-level ``hierarchy``."""
    return class2 in hierarchy.get_siblings(class1)

def is_cousin(class1, class2):
    """Return True if ``class2`` is a child of a sibling of one of ``class1``'s parents."""
    # hierarchy.get_cousins returns the siblings of class1's parents (uncles/aunts).
    return any(is_parent(uncle, class2) for uncle in hierarchy.get_cousins(class1))

def is_parent(class1, class2):
    """Return True if ``class1`` is a parent of ``class2``."""
    return class2 in hierarchy.get_children(class1)

def is_grandparent(class1, class2):
    """Return True if ``class1`` is a grandparent of ``class2``."""
    return class1 in hierarchy.get_grandparents(class2)

def lca(class1, class2):
    """Return the lowest common ancestor of two classes, or None if none is found."""
    return hierarchy.find_lowest_common_ancestor(class1, class2)

def step_distance(class1, class2):
    """Return the number of hierarchy edges between two classes (None if unconnected)."""
    return hierarchy.find_step_distance(class1, class2)

def step_distance_from_lca(class1, class2):
    """Return how many edges ``class1`` is from its lowest common ancestor with ``class2``."""
    return step_distance(class1, lca(class1, class2))

In [None]:
def count_parameters(model):
    """Return the total element count of the trainable (requires_grad) parameters of ``model``."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

def get_model(network, dims, prototypes=None, tau=1.0, curvature=1.00):
    """Build the requested backbone and move it to the GPU.

    Args:
        network: ``"resnet34"`` or ``"resnet50"``.
        dims: passed through as the ResNet ``num_classes`` argument.
        prototypes: optional fixed prototype tensor for the final layer.
        tau: scale applied to the negative prototype distances.
        curvature: Poincare ball curvature.

    Returns:
        The model after ``.cuda()``.

    Raises:
        ValueError: if ``network`` is not recognised (previously this surfaced
            as an UnboundLocalError).
    """
    if network == "resnet34":
        model = resnet34(dims, prototypes, tau, curvature)
    elif network == "resnet50":
        # Construct directly so that tau and curvature are forwarded for
        # ResNet-50 too instead of silently falling back to defaults.
        model = ResNet(BottleNeck, [3, 4, 6, 3], num_classes=dims,
                       prototypes=prototypes, tau=tau, curvature=curvature)
    else:
        raise ValueError(f"unknown network: {network!r}")
    return model.cuda()

def get_cifar100(batch_size, ex_class=-1):
    """Build CIFAR-100 train/test DataLoaders with per-channel normalization.

    The training split is downloaded to ``data/`` if missing. It is augmented
    with a random 32px crop (padding 4), a horizontal flip and a random
    rotation of up to 15 degrees. The test split is only normalized.
    ``ex_class`` is accepted but not used.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    channel_mean = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    channel_std = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

    def with_normalization(*augmentations):
        # Shared tail of both pipelines: tensor conversion, then normalization.
        steps = list(augmentations)
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(channel_mean, channel_std))
        return transforms.Compose(steps)

    train_pipeline = with_normalization(
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    )
    test_pipeline = with_normalization()

    train_set = torchvision.datasets.CIFAR100(root="data/", train=True, transform=train_pipeline, download=True)
    train_loader = DataLoader(train_set, shuffle=True, num_workers=1, batch_size=batch_size)

    test_set = torchvision.datasets.CIFAR100(root="data/", train=False, transform=test_pipeline)
    test_loader = DataLoader(test_set, shuffle=False, num_workers=1, batch_size=batch_size)

    return train_loader, test_loader

In [None]:
class BasicBlock(nn.Module):
    """Two 3x3 conv/BN layers with an identity or 1x1-projection shortcut."""

    # Output channels = out_channels * expansion.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        widened = out_channels * BasicBlock.expansion

        # Main path: conv(stride) -> BN -> ReLU -> conv -> BN.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, widened, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(widened),
        )

        # Project the input only when the spatial size or channel count changes.
        needs_projection = stride != 1 or in_channels != widened
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return relu(residual(x) + shortcut(x))."""
        merged = self.residual_function(x) + self.shortcut(x)
        return F.relu(merged, inplace=True)

class BottleNeck(nn.Module):
    """1x1 reduce -> 3x3 (strided) -> 1x1 expand block with a residual shortcut."""

    # Output channels = out_channels * expansion.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        widened = out_channels * BottleNeck.expansion

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, widened, kernel_size=1, bias=False),
            nn.BatchNorm2d(widened),
        )

        # Identity shortcut only when neither stride nor channel count changes.
        if stride == 1 and in_channels == widened:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, widened, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(widened),
            )

    def forward(self, x):
        """Return relu(residual(x) + shortcut(x))."""
        return F.relu(self.residual_function(x) + self.shortcut(x), inplace=True)

class ResNet(nn.Module):
    """ResNet backbone whose output features are mapped onto a Poincare ball.

    Without prototypes, ``forward`` returns points on the ball. With
    prototypes, it returns logits ``-tau * d(x, p_k)`` for each prototype
    ``p_k``, which are larger for nearer prototypes.

    Args:
        block: residual block class (``BasicBlock`` or ``BottleNeck``).
        num_block: number of blocks in each of the four stages.
        num_classes: kept for interface compatibility; not used to size any layer.
        prototypes: optional ``(num_prototypes, dim)`` tensor of fixed class
            prototypes on the ball. Its last dimension sets the embedding size.
        tau: scale applied to the negative distances.
        curvature: ball curvature ``c`` for ``geoopt.PoincareBallExact``.
    """

    def __init__(self, block, num_block, num_classes=64, prototypes=None, tau=1.0, curvature=1.0):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The embedding must match the prototype dimension. Without prototypes,
        # keep the original fixed size of 64.
        embed_dim = prototypes.shape[-1] if prototypes is not None else 64
        self.fc = nn.Linear(512 * block.expansion, embed_dim)

        # Fixed (non-trainable) final layer. A non-persistent buffer moves with
        # .to()/.cuda() but leaves state_dict keys unchanged.
        self.register_buffer("prototypes", prototypes, persistent=False)
        self.ball = geoopt.PoincareBallExact(c=curvature)
        self.tau = tau

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, y=None, l=None):
        """Embed ``x`` on the ball; with prototypes, return distance logits.

        ``y`` and ``l`` are accepted but unused.
        """
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)
        # Map the Euclidean feature onto the ball exactly once. A second
        # expmap0 (as before) treats a ball point as a tangent vector and caps
        # the norm at tanh(1) for c=1, so the embedding could never reach
        # prototypes placed near the boundary.
        output = self.ball.expmap0(output)

        if self.prototypes is not None:
            # (batch, 1, dim) against (num_prototypes, dim) -> pairwise distances.
            output = -self.ball.dist(output[:, None, :], self.prototypes) * self.tau
        return output

def resnet34(dims, prototypes, tau, curvature):
    """Build the ResNet-34 layout: BasicBlock stages of depth 3, 4, 6 and 3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, num_classes=dims, prototypes=prototypes,
                  tau=tau, curvature=curvature)

def resnet50(dims, prototypes, tau=1.0, curvature=1.0):
    """Build the ResNet-50 layout: BottleNeck stages of depth 3, 4, 6 and 3.

    ``tau`` and ``curvature`` are now forwarded. Their defaults equal
    ResNet's own defaults, so existing two-argument calls behave as before.
    """
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=dims, prototypes=prototypes,
                  tau=tau, curvature=curvature)

In [None]:
def set_seed(seed):
    """Seed every RNG the notebook uses so that runs are reproducible.

    Seeds Python's ``random``, NumPy and PyTorch. When CUDA is available it
    also seeds all GPUs and forces deterministic, non-benchmarking cuDNN.
    Finally it records the seed in ``PYTHONHASHSEED``, which Python only reads
    at start-up, so it affects child processes rather than this one.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)

def train(model, train_loader, loss_function, epoch, prototypes=None):
    """Run one training epoch and return the mean batch loss.

    Calls ``zero_grad``/``step`` on the module-level ``optimizer``, which is
    not passed in, and moves every batch to the GPU with ``.cuda()``.
    ``epoch`` and ``prototypes`` are accepted for interface compatibility but
    unused.

    Returns:
        0-d tensor holding the average batch loss. It is detached, so callers
        can use ``.item()`` without keeping any autograd graph alive.
    """
    model.train()
    loss_sum = 0.
    n_batches = 0
    for images, labels in train_loader:
        # Images and labels to GPU.
        images = images.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()

        # Forward propagation.
        outputs = model(images)
        loss = loss_function(outputs, labels)

        # Backward propagation.
        loss.backward()
        optimizer.step()

        # Detach before accumulating. Summing the live loss (as before) chained
        # every batch's graph onto the running total for the whole epoch.
        loss_sum += loss.detach()
        n_batches += 1
    return loss_sum / n_batches

def test(model, test_loader, loss_function):
    """Evaluate ``model`` on ``test_loader`` with flat and hierarchy-aware metrics.

    Each predicted and true class index is converted to a hierarchy node name
    (``class_mapping[i] + '.n01'``) and compared with the module-level
    hierarchy helpers.

    Returns:
        ``(acc, loss, sibling_acc, cousin_acc, step_distances_from_lca,
        step_distances)``:
        - ``acc``, ``sibling_acc`` and ``cousin_acc`` are counts divided by
          the dataset size;
        - ``loss`` is the mean batch loss;
        - the two lists hold one entry per sample.
    """
    model.eval()

    step_distances = []
    step_distances_from_lca = []
    correct = 0.0
    sibling = 0.0
    cousin = 0.0
    total_test_loss = 0.0
    n_batches = 0
    # Evaluation only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.cuda()
            labels = labels.cuda()

            # Forward propagation.
            outputs = model(images)
            total_test_loss += loss_function(outputs, labels).item()
            n_batches += 1

            _, predictions = outputs.max(1)

            # Move each batch to Python ints once instead of a GPU sync per .item().
            for pred, label in zip(predictions.tolist(), labels.tolist()):
                pred_node = class_mapping[pred] + '.n01'
                label_node = class_mapping[label] + '.n01'

                if pred == label:
                    correct += 1
                if is_sibling(pred_node, label_node):
                    sibling += 1
                if is_cousin(pred_node, label_node):
                    cousin += 1
                step_distances_from_lca.append(step_distance_from_lca(pred_node, label_node))
                step_distances.append(step_distance(pred_node, label_node))

    n_samples = len(test_loader.dataset)
    sibling_acc = sibling / n_samples
    cousin_acc = cousin / n_samples
    acc = correct / n_samples
    loss = total_test_loss / n_batches
    return acc, loss, sibling_acc, cousin_acc, step_distances_from_lca, step_distances

if __name__ == '__main__':
    # Pin the process to GPU 0. NOTE(review): CUDA_VISIBLE_DEVICES only takes
    # effect if CUDA has not been initialised yet in this process -- confirm
    # this still holds when earlier notebook cells have already run.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # Parse user arguments. In a notebook kernel, the "-f <kernel.json>"
    # launcher argument lands in `resfile`. parse_known_args also tolerates
    # any other launcher arguments instead of exiting with an error.
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", dest="network", default="resnet34", type=str)
    parser.add_argument("-d", dest="dataset", default="cifar100", type=str)
    parser.add_argument("-b", dest="batch_size", default=512, type=int)
    parser.add_argument("-l", dest="learning_rate", default=0.1, type=float)
    parser.add_argument("-u", dest="use_scheduler", default=1, type=int)
    parser.add_argument("-s", dest="scale_factor", default=0.95, type=float)
    parser.add_argument("-t", dest="tau", default=10, type=float)
    parser.add_argument("-e", dest="epochs", default=200, type=int)
    parser.add_argument("-f", dest="resfile", default="", type=str)
    parser.add_argument("--prot", dest="prototype_file", default="prototypes_"+ HIERARCH +"_64dim.npy", type=str)
    parser.add_argument("--c", dest="curvature", default=1, type=float)
    parser.add_argument("--seed", dest="seed", default=42, type=int)
    args, _ = parser.parse_known_args()
    set_seed(args.seed)

    # Get data.
    if args.dataset == "cifar100":
        train_loader, test_loader = get_cifar100(args.batch_size, 0)
        nr_classes = 100
    else:
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    # Fixed class prototypes: L2-normalise each row, then scale to
    # scale_factor / sqrt(curvature) (the Poincare ball radius is 1/sqrt(c)).
    prototypes = torch.from_numpy(np.load(args.prototype_file)).float().cuda()
    prototypes = F.normalize(prototypes, p=2, dim=1)
    prototypes = prototypes * args.scale_factor / math.sqrt(args.curvature)

    # Get network, loss function, and optimizer. `optimizer` must remain a
    # module-level name because train() reads it as a global.
    model = get_model(args.network, nr_classes, prototypes, args.tau, args.curvature)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=5e-4)
    train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.2)

    print("------------------------------------------------")
    # Perform training and periodic testing.
    for epoch in range(args.epochs):

        # Train for one epoch.
        train_loss = train(model, train_loader, loss_function, epoch, prototypes)

        # Test every 10 epochs and after the final epoch.
        if epoch % 10 == 0 or epoch == args.epochs - 1:
            (acc, test_loss, sibling_acc, cousin_acc,
             step_distances_from_lca, step_distances) = test(model, test_loader, loss_function)

            print('[', epoch, ']')
            print('Train Loss:  ', round(train_loss.item(), 4))
            print('Test Loss:   ', round(test_loss, 4))
            print('Accuracy:    ', round(acc * 100, 2), '%')
            print('Sibling acc: ', round(sibling_acc * 100, 2), '%')
            print('Cousin acc:  ', round(cousin_acc * 100, 2), '%')
            print('Steps:       ', round(np.mean(step_distances), 4))
            print('LCA steps:   ', round(np.mean(step_distances_from_lca), 4))
            print("------------------------------------------------")

        # Learning rate scheduler update.
        if args.use_scheduler == 1:
            train_scheduler.step()

    print()

Files already downloaded and verified
------------------------------------------------
[ 0 ]
Train Loss:   5.2442
Test Loss:    4.8027
Accuracy:     1.54 %
Sibling acc:  3.71 %
Cousin acc:   4.49 %
Steps:        9.5604
LCA steps:    4.7956
------------------------------------------------
[ 10 ]
Train Loss:   4.0849
Test Loss:    4.1118
Accuracy:     4.92 %
Sibling acc:  13.81 %
Cousin acc:   4.03 %
Steps:        8.6216
LCA steps:    4.36
------------------------------------------------
[ 20 ]
Train Loss:   4.5925
Test Loss:    4.6326
Accuracy:     1.03 %
Sibling acc:  2.14 %
Cousin acc:   2.7 %
Steps:        10.098
LCA steps:    5.0593
------------------------------------------------
[ 30 ]
Train Loss:   3.4397
Test Loss:    3.5603
Accuracy:     11.19 %
Sibling acc:  24.5 %
Cousin acc:   2.62 %
Steps:        7.1598
LCA steps:    3.6918
------------------------------------------------
[ 40 ]
Train Loss:   2.9765
Test Loss:    3.419
Accuracy:     14.0 %
Sibling acc:  28.97 %
Cousin acc: 

## Helper functions

In [None]:
# A helper for creating a random hierarchical structure. A CSV is then created manually from the result.
def generate_random_integer_list(sum_value, min_length=2, min_value=2, max_overall_value=5):
    """Randomly split ``sum_value`` into integer parts.

    Every part lies in ``[min_value, max_overall_value]``, the parts sum to
    exactly ``sum_value``, and there are at least ``min_length`` of them. The
    previous version could strand a remainder smaller than ``min_value`` and
    silently return parts summing to less than ``sum_value``, e.g. dropping
    one of 100 classes.

    Raises:
        ValueError: if the constraints cannot be satisfied. Callers may catch
            this and retry.
    """
    if min_value < 1 or max_overall_value < min_value:
        raise ValueError("need 1 <= min_value <= max_overall_value")
    if sum_value < min_length * min_value:
        raise ValueError(
            f"cannot split {sum_value} into {min_length} parts of at least {min_value}")

    parts = []
    remaining = sum_value
    while remaining > 0:
        # How many more parts must follow this one to reach min_length.
        still_needed = max(min_length - len(parts) - 1, 0)
        upper = min(remaining, max_overall_value)
        # A choice is valid if it finishes exactly (when no more parts are
        # required) or leaves enough for the parts still required.
        choices = [size for size in range(min_value, upper + 1)
                   if (size == remaining and still_needed == 0)
                   or remaining - size >= max(still_needed, 1) * min_value]
        if not choices:
            raise ValueError(
                f"cannot split {remaining} into parts within [{min_value}, {max_overall_value}]")
        size = random.choice(choices)
        parts.append(size)
        remaining -= size

    return parts

depth = 3
leaves_vals = []
total = 0
results_list = []

# Partition the 100 class indices into consecutive leaf groups.
leaves = generate_random_integer_list(100, max_overall_value=6)
for group_size in leaves:
    leaves_vals.append(list(range(total, total + group_size)))
    total += group_size

# Stack `depth` coarser levels on top. Each entry is the number of nodes from
# the level below grouped under one node. Start over if any level cannot be
# split.
while not results_list:
    try:
        results_list.append(leaves_vals)
        for _ in range(depth):
            level = generate_random_integer_list(len(results_list[0]), max_overall_value=5)
            results_list.insert(0, level)
    except ValueError:
        results_list = []

print(results_list)
