# Libraries

In [1]:
import pandas as pd
from anytree import Node
from preprocessing import *
import torch
import torch.nn as nn
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

# Populate Taxonomy

In [2]:
root = Node("object", rank="root")

# Create root nodes
marine_life = Node("marine life", parent=root, rank="binary")
inanimate = Node("inanimate", parent=root, rank="binary")

# Create class nodes under the respective root nodes
asteroidea = Node("asteroidea", parent=marine_life, rank="class")
phaeophyceae = Node("phaeophyceae", parent=marine_life, rank="class")
bivalia = Node("bivalia", parent=marine_life, rank="class")
myxini = Node("myxini", parent=marine_life, rank="class")
artificial = Node("artificial", parent=inanimate, rank="class")
natural = Node("natural", parent=inanimate, rank="class")
chlorophyta = Node("chlorophyta", parent=marine_life, rank="class")
monocots = Node("monocots", parent=marine_life, rank="class")

# Create genus nodes under the respective class nodes
asterias = Node("asterias", parent=asteroidea, rank="genus")
fucus = Node("fucus", parent=phaeophyceae, rank="genus")
henrica = Node("Henrica", parent=asteroidea, rank="genus")
mya = Node("mya", parent=bivalia, rank="genus")
myxine = Node("myxine", parent=myxini, rank="genus")
cylindrical = Node("cylindrical", parent=artificial, rank="genus")
solid = Node("solid", parent=natural, rank="genus")
arboral = Node("arboral", parent=natural, rank="genus")
saccharina = Node("saccharina", parent=phaeophyceae, rank="genus")
ulva = Node("ulva", parent=chlorophyta, rank="genus")
urospora = Node("Urospora", parent=chlorophyta, rank="genus")
zostera = Node("zostera", parent=monocots, rank="genus")

# Create species nodes under the respective genus nodes
asterias_rubens = Node("asterias rubens", parent=asterias, rank="species")
fucus_vesiculosus = Node("fucus vesiculosus", parent=fucus, rank="species")
henrica_species = Node("henrica", parent=henrica, rank="species")  # Assuming "henrica" is a species
mytilus_edulis = Node("mytilus edulis", parent=mya, rank="species")
myxine_glurinosa = Node("myxine glurinosa", parent=myxine, rank="species")
pipe = Node("pipe", parent=cylindrical, rank="species")
rock = Node("rock", parent=solid, rank="species")
saccharina_latissima = Node("saccharina latissima", parent=saccharina, rank="species")
tree = Node("tree", parent=arboral, rank="species")
ulva_intestinalis = Node("ulva intestinalis", parent=ulva, rank="species")
urospora_species = Node("urospora", parent=urospora, rank="species")
zostera_marina = Node("zostera marina", parent=zostera, rank="species")

classes_file = '/mnt/RAID/datasets/label-studio/fjord/classes.txt'

species_names = []
with open(classes_file, 'r') as file:
    species_names = [line.strip() for line in file]

genus_names = [node.name for node in root.descendants if node.rank == 'genus']
class_names = [node.name for node in root.descendants if node.rank == 'class']
binary_names = [node.name for node in root.descendants if node.rank == 'binary']

# Read Dataset

In [3]:
df = pd.read_parquet('/mnt/RAID/projects/FjordVision/train_dataset.parquet')

# Pytorch Model

In [4]:
class BranchCNN(nn.Module):
    def __init__(self, num_classes, num_additional_features):
        super(BranchCNN, self).__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(128 * 8 * 8 + num_additional_features, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 256),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x, additional_features):
        x = x.view(x.size(0), -1)
        combined_input = torch.cat((x, additional_features), dim=1)
        return self.fc_layers(combined_input)

class HierarchicalCNN(nn.Module):
    def __init__(self, num_classes_hierarchy, num_additional_features):
        super(HierarchicalCNN, self).__init__()

        # Convolutional layers for feature extraction
        self.conv_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(3, 128, kernel_size=3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2)
            ),
            nn.Sequential(
                nn.Conv2d(128, 256, kernel_size=3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2)
            ),
            nn.Sequential(
                nn.Conv2d(256, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2)
            ),
            nn.Sequential(
                nn.Conv2d(512, 512, kernel_size=3, padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
        ])

        # Branches for each hierarchical level
        self.branches = nn.ModuleList([
            BranchCNN(num_classes, num_additional_features)
            for num_classes in num_classes_hierarchy
        ])

    def forward(self, x, conf, iou, pred_species):
        outputs = []

        # Reshape tensors to have the same dimensions
        conf = conf.view(-1, 1)
        iou = iou.view(-1, 1)
        pred_species = pred_species.view(-1, 1)  # Reshape pred_species

        # Concatenate additional inputs
        additional_features = torch.cat((conf, iou, pred_species), dim=1)

        for conv_layer, branch in zip(self.conv_layers, self.branches):
            x = conv_layer(x)
            branch_output = branch(x, additional_features)
            outputs.append(branch_output)

        return outputs


# Create a defaultdict to store the counts for each rank
rank_counts = defaultdict(int)

# Iterate over the nodes of the tree
for node in root.descendants:
    rank = node.rank
    rank_counts[rank] += 1

# Example instantiation of the model
num_classes_hierarchy = list(rank_counts.values())  # Example: [num_species, num_genus, num_class, num_binary]
num_additional_features = 3  # Assuming 3 additional features: conf, iou, pred_species

model = HierarchicalCNN(num_classes_hierarchy, num_additional_features)

# Dataloader

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

class CustomDataset(Dataset):
    def __init__(self, dataframe, species_names, transform=None):
        self.dataframe = dataframe
        self.species_names = species_names
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image = row['masked_image']
        image = deserialize_image(image)
        image_pil = Image.fromarray(image)
        image_resized = image_pil.resize((128, 128))  # Resize image to 128x128

        if self.transform:
            image_resized = self.transform(image_resized)

        # Convert images to tensor
        image_tensor = torch.tensor(np.array(image_resized), dtype=torch.float32).permute(2, 0, 1) / 255.0  # Normalize if necessary

        # Convert additional features to tensor
        conf_tensor = torch.tensor(row['confidence'], dtype=torch.float32)
        iou_tensor = torch.tensor(row['iou_with_best_gt'], dtype=torch.float32)

        # Convert species names to indices
        pred_species_index = self.species_names.index(row['predicted_species']) if row['predicted_species'] in self.species_names else -1
        label_index = self.species_names.index(row['species']) if row['species'] in self.species_names else -1

        pred_species_tensor = torch.tensor(pred_species_index, dtype=torch.long)
        label_tensor = torch.tensor(label_index, dtype=torch.long)s

        return image_tensor, conf_tensor, iou_tensor, pred_species_tensor, label_tensor

# Instantiate CustomDataset with the species names list
dataset = CustomDataset(df, species_names)

# DataLoader
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Hierarchical Cross Entropy loss

In [6]:
class HierarchicalCrossEntropyLoss(nn.Module):
    def __init__(self, weights):
        super(HierarchicalCrossEntropyLoss, self).__init__()
        self.weights = weights

    def forward(self, outputs, targets):
        total_loss = 0
        for k, (output, target, weight) in enumerate(zip(outputs, targets, self.weights)):
            loss = nn.CrossEntropyLoss()(output, target)
            total_loss += weight * loss
        return total_loss
    

# Calculate weights based on their normalised relative size of the sum
# of the number of classes at each hierarchical level
weights = [num_classes / sum(num_classes_hierarchy) for num_classes in num_classes_hierarchy]
criterion = HierarchicalCrossEntropyLoss(weights)

# Training Loop

In [7]:
# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Optimizer setup
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Number of training epochs
num_epochs = 10

for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    running_loss = 0.0

    for images, conf, iou, pred_species, labels in dataloader:
        # Move data to the appropriate device
        images, conf, iou, pred_species, labels = images.to(device), conf.to(device), iou.to(device), pred_species.to(device), labels.to(device)

        # Forward pass
        outputs = model(images, conf, iou, pred_species)

        # Loss computation
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print average loss for this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x524291 and 8195x512)