In [1]:
'''
Installs needed:
pip install datasets
pip install torchvision
'''
import os
import torch
import datasets
from datasets import load_dataset
#from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import pandas as pd
import numpy as np
from torchvision import transforms
from tree import Tree

dogs = [
    1,  # "Hunting Dog"
    [
        2,  # "Sporting Dog"
        [
            3,  # "Spaniel"
            [156],  # Blenheim spaniel
            [215],  # Brittany spaniel
            [216],  # Clumber, clumber spaniel
            [219],  # Cocker spaniel, English cocker spaniel, cocker
            [217],  # English springer, English springer spaniel
            [218],  # Welsh springer spaniel
            [220],  # Sussex spaniel
            [221],  # Irish water spaniel
        ],
        [
            4,  # "Retriever"
            [205],  # Flat-coated retriever
            [206],  # Curly-coated retriever
            [207],  # Golden retriever
            [208],  # Labrador retriever
            [209],  # Chesapeake Bay retriever
        ],
        [
            5,  # "Pointer"
            [210],  # German short-haired pointer
            [211],  # Vizsla, Hungarian pointer
        ],
        [
            6,  # "Setter"
            [212],  # English setter
            [213],  # Irish setter, red setter
            [214],  # Gordon setter
        ],
    ],
    [
        7,  # "Terrier"
        [
            8,  # "Wirehair"
            [189],  # Lakeland terrier
            [190],  # Sealyham terrier, Sealyham
        ],
        [
            9,  # "Bullterrier"
            [179],  # Staffordshire bullterrier, Staffordshire bull terrier
            [180],  # American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
        ],
        [
            10,  # "Fox Terrier"
            [188],  # Wire-haired fox terrier
        ],
        [
            11,  # "Schnauzer"
            [196],  # Miniature schnauzer
            [197],  # Giant schnauzer
            [198],  # Standard schnauzer
        ],
        [191],  # Airedale, Airedale terrier
        [193],  # Australian terrier
        [181],  # Bedlington terrier
        [182],  # Border terrier
        [192],  # Cairn, cairn terrier
        [194],  # Dandie Dinmont, Dandie Dinmont terrier
        [195],  # Boston bull, Boston terrier
        [184],  # Irish terrier
        [183],  # Kerry blue terrier
        [185],  # Norfolk terrier
        [186],  # Norwich terrier
        [199],  # Scotch terrier, Scottish terrier, Scottie
        [200],  # Tibetan terrier, chrysanthemum dog
        [201],  # Silky terrier, Sydney silky
        [202],  # Soft-coated wheaten terrier
        [203],  # West Highland white terrier
        [187],  # Yorkshire terrier
    ],
    [
        12,  # "Hound"
        [
            13,  # "Coonhound"
            [165],  # Black-and-tan coonhound
            [166],  # Walker hound, Walker foxhound
        ],
        [
            14,  # "Foxhound"
            [167],  # English foxhound
            [168],  # Redbone
        ],
        [
            15,  # "Greyhound"
            [171],  # Italian greyhound
            [172],  # Whippet
        ],
        [
            16,  # "Wolfhound"
            [169],  # Borzoi, Russian wolfhound
            [170],  # Irish wolfhound
        ],
        [
            17,  # "Other Hounds"
            [160],  # Afghan hound, Afghan
            [161],  # Basset, basset hound
            [162],  # Beagle
            [163],  # Bloodhound, sleuthhound
            [164],  # Bluetick
            [173],  # Ibizan hound, Ibizan Podenco
            [174],  # Norwegian elkhound, elkhound
            [175],  # Otterhound, otter hound
            [176],  # Saluki, gazelle hound
            [177],  # Scottish deerhound, deerhound
            [178],  # Weimaraner
            [159],  # Rhodesian ridgeback
        ],
    ],
]

def calculate_num_classes(array):
    if isinstance(array[0], int):
        return len(array)
    return sum(calculate_num_classes(subarray) for subarray in array)

def load_data():

    train_ds = load_dataset("imagenet-1k", split='train', streaming=True, trust_remote_code=True)
    val_ds = load_dataset("imagenet-1k", split='validation', streaming=True, trust_remote_code=True)

    # List of dog class indices in ImageNet
    num_total_classes = calculate_num_classes(dogs)
    class_tree =  Tree(dogs)
    dog_classes = np.array(class_tree.nodes_at_depth(num_total_classes))
    print ("Dog Classes")
    print (dog_classes)
    print (len(dog_classes))

    # Define transformations for the dataset
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x),  # Ensure 3 channels
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])


    dog_dataset = train_ds.filter(lambda example: example['label'] in dog_classes)

    dog_dataset_val = val_ds.filter(lambda example: example['label'] in dog_classes)

    def collate_fn(batch):
        """
        Custom collate function to apply transformations on streaming data.
        """
        images, labels = [], []
        for example in batch:
            image = transform(example['image'])  # Apply transformations
            images.append(image)
            labels.append(example['label'])
        return torch.stack(images), torch.tensor(labels)

    # Create a DataLoader for the filtered dataset
    train_dataloader = DataLoader(dog_dataset, batch_size=30, num_workers=4, collate_fn=collate_fn)
    val_dataloader = DataLoader(dog_dataset_val, batch_size=30, shuffle=False, num_workers=4, collate_fn=collate_fn)

    return train_dataloader, val_dataloader, dogs, num_total_classes

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import numpy as np

from tree import Tree

class_tree = None
distance_matrix = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use conda install pytorch torchvision -c pytorch in python env to import; if using Colab probably use pip

In [3]:
!export CUDA_LAUNCH_BLOCKING=1

In [4]:
def precompute_lca_distances(labels, contrastive_classes, contrastive_class_to_id, class_tree):
    print (labels.shape)
    num_labels = labels.shape[0]
    lca_matrix = torch.zeros((num_labels, num_labels), dtype=torch.long)
    distance_matrix = torch.zeros((num_labels, num_labels),  dtype=torch.float)

    # Precompute LCAs and distances
    for i in range(num_labels):
        for j in range(i + 1, num_labels):

            class_i = contrastive_classes[labels[i].item()]
            class_j = contrastive_classes[labels[j].item()]
            lca = class_tree.find_lca(class_i, class_j)
            distance_i = class_tree.find_distance_to_ancestor(class_i, lca)
            distance_j = class_tree.find_distance_to_ancestor(class_j, lca)
            min_distance = min(distance_i, distance_j)
            lca_matrix[i, j] = lca
            lca_matrix[j, i] = lca
            distance_matrix[i, j] = min_distance
            distance_matrix[j, i] = min_distance

    return distance_matrix

# Basic Contrastive Learning Model
class ContrastiveModel(nn.Module):
    def __init__(self, num_classes, embedding_dim=128):
        super(ContrastiveModel, self).__init__()
        # Using ResNet backbone since we are using ImageNet and therefore compatible, feel free to change
        self.backbone = models.resnet18(pretrained=True)

        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.projection_head = nn.Sequential(
            nn.Linear(num_features, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

        # Classification head
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        features = self.backbone(x)  # Extract features
        embeddings = self.projection_head(features)  # Project embeddings
        logits = self.classifier(embeddings)  # Classification logits
        return embeddings, logits


# Define the NCE Loss
class NCELoss(nn.Module):
    def __init__(self, temperature=0.07, dist_func_param=1):
        super(NCELoss, self).__init__()
        self.temperature = temperature
        self.dist_func_param = dist_func_param

    def forward(self, embeddings, labels):
        global distance_matrix
        # Normalize embeddings to unit vectors
        embeddings = nn.functional.normalize(embeddings, dim=1)

        similarity_matrix = torch.matmul(embeddings, embeddings.T) / self.temperature

        # Mask out self-similarity
        mask = torch.eye(similarity_matrix.size(0), device=similarity_matrix.device).bool()
        similarity_matrix = similarity_matrix.masked_fill(mask, float('-inf'))

        # Create targets: positive samples have the same label
        #labels = labels.unsqueeze(0) == labels.unsqueeze(1)
        #positives = labels.float()
        distances = torch.exp(-self.dist_func_param * distance_matrix[labels][:, labels])

        #print ("filtered distances")
        #print (distances[:10,:10])
        #print ("similarity matrix")
        #print (similarity_matrix[:10,:10])

        # Compute log-softmax and NCE loss
        log_prob = nn.functional.log_softmax(similarity_matrix, dim=1)
        #print ("Log prob")
        #print (log_prob[:10,:10])
        loss_matrix = log_prob * distances

        loss_matrix.fill_diagonal_(0)

        #loss = -torch.sum(log_prob * positives) / labels.sum() #loss is only how far apart the positives are
        loss = -torch.sum(loss_matrix) / loss_matrix.shape[0] # / labels.sum() #loss is only how far apart the positives are
        return loss


# Sample dataset preparation, making flexible
# TODO: make as function of level of specificity
class SampleDataset(Dataset):
    def __init__(self, size=100, num_classes=10, transform=None):
        self.size = size
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Random RGB image and label
        image = torch.rand(3, 224, 224)
        label = idx % self.num_classes
        if self.transform:
            image = self.transform(image)
        return image, label

'''
# DataLoader setup
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = SampleDataset(size=100, num_classes=10, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Initialize model, loss, and optimizer
num_classes = 10'''

train_dataloader, val_dataloader, dogs, num_total_classes = load_data()
print ("Dogs")
print (dogs)
print ("Num total classes")
print (num_total_classes)
class_tree = Tree(dogs)

contrastive_classes_specificity = 2
contrastive_classes = class_tree.nodes_at_depth(contrastive_classes_specificity)
contrastive_class_to_id = {_cls: i for i, _cls in enumerate(contrastive_classes)}

clf_classes_specificity = 4
clf_classes = class_tree.nodes_at_depth(clf_classes_specificity)
clf_class_to_id = {_cls: i for i, _cls in enumerate(clf_classes)}

print ("Contrastive classes")
print (contrastive_classes)
print ("Clf classes")
print (clf_classes)

num_classes = len(clf_classes)

model = ContrastiveModel(num_classes).cuda()
nce_loss_fn = NCELoss().cuda() # Generally probably add all of this to Colab to use the GPU
classification_loss_fn = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Precompute distance matrix
distance_matrix = torch.tensor(precompute_lca_distances(np.arange(len(contrastive_classes)), contrastive_classes, contrastive_class_to_id, class_tree), device=device)

print ("Distance matrix")
print (distance_matrix[:10,:10])

Dog Classes
[156 215 216 219 217 218 220 221 205 206 207 208 209 210 211 212 213 214
 189 190 179 180 188 196 197 198 191 193 181 182 192 194 195 184 183 185
 186 199 200 201 202 203 187 165 166 167 168 171 172 169 170 160 161 162
 163 164 173 174 175 176 177 178 159]
63
Dogs
[1, [2, [3, [156], [215], [216], [219], [217], [218], [220], [221]], [4, [205], [206], [207], [208], [209]], [5, [210], [211]], [6, [212], [213], [214]]], [7, [8, [189], [190]], [9, [179], [180]], [10, [188]], [11, [196], [197], [198]], [191], [193], [181], [182], [192], [194], [195], [184], [183], [185], [186], [199], [200], [201], [202], [203], [187]], [12, [13, [165], [166]], [14, [167], [168]], [15, [171], [172]], [16, [169], [170]], [17, [160], [161], [162], [163], [164], [173], [174], [175], [176], [177], [178], [159]]]]
Num total classes
4
Contrastive classes
[3, 4, 5, 6, 8, 9, 10, 11, 191, 193, 181, 182, 192, 194, 195, 184, 183, 185, 186, 199, 200, 201, 202, 203, 187, 13, 14, 15, 16, 17]
Clf classes
[156, 



(30,)
Distance matrix


  distance_matrix = torch.tensor(precompute_lca_distances(np.arange(len(contrastive_classes)), contrastive_classes, contrastive_class_to_id, class_tree), device=device)


tensor([[0., 1., 1., 1., 2., 2., 2., 2., 2., 2.],
        [1., 0., 1., 1., 2., 2., 2., 2., 2., 2.],
        [1., 1., 0., 1., 2., 2., 2., 2., 2., 2.],
        [1., 1., 1., 0., 2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 0., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 1., 0., 1., 1., 1., 1.],
        [2., 2., 2., 2., 1., 1., 0., 1., 1., 1.],
        [2., 2., 2., 2., 1., 1., 1., 0., 1., 1.],
        [2., 2., 2., 2., 1., 1., 1., 1., 0., 1.],
        [2., 2., 2., 2., 1., 1., 1., 1., 1., 0.]], device='cuda:0')


In [5]:
train_dataloader, val_dataloader, dogs, num_total_classes = load_data()
class_tree = Tree(dogs)

contrastive_classes_specificity = 1
contrastive_classes = class_tree.nodes_at_depth(contrastive_classes_specificity)
contrastive_class_to_id = {_cls: i for i, _cls in enumerate(contrastive_classes)}

clf_classes_specificity = 3
clf_classes = class_tree.nodes_at_depth(clf_classes_specificity)
clf_class_to_id = {_cls: i for i, _cls in enumerate(clf_classes)}

num_classes = len(clf_classes)

model = ContrastiveModel(num_classes).cuda()
nce_loss_fn = NCELoss().cuda() # Generally probably add all of this to Colab to use the GPU
classification_loss_fn = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Precompute distance matrix
distance_matrix = torch.tensor(precompute_lca_distances(np.arange(len(contrastive_classes)), contrastive_classes, contrastive_class_to_id, class_tree), device=device)

Dog Classes
[156 215 216 219 217 218 220 221 205 206 207 208 209 210 211 212 213 214
 189 190 179 180 188 196 197 198 191 193 181 182 192 194 195 184 183 185
 186 199 200 201 202 203 187 165 166 167 168 171 172 169 170 160 161 162
 163 164 173 174 175 176 177 178 159]
63
(3,)


  distance_matrix = torch.tensor(precompute_lca_distances(np.arange(len(contrastive_classes)), contrastive_classes, contrastive_class_to_id, class_tree), device=device)


In [6]:
print(device)



cuda


In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import torch

def calculate_metrics(predictions, labels, average='macro'):
    """
    Calculates precision, recall, and F1-score.
    """
    predictions = predictions.cpu().numpy()
    labels = labels.cpu().numpy()

    precision = precision_score(labels, predictions, average=average, zero_division=0)
    recall = recall_score(labels, predictions, average=average, zero_division=0)
    f1 = f1_score(labels, predictions, average=average, zero_division=0)

    return precision, recall, f1

def evaluate_model(model, dataloader, loss_fn, device, num_classes):
    """
    Evaluates the model on a given dataloader and returns classification statistics.
    """
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    num_batches = 0

    all_predictions = []
    all_labels = []

    with torch.no_grad():  # Disable gradients for evaluation
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            labels = torch.tensor([class_tree.which_ancestor(label.item(), clf_classes) for label in labels], device='cuda')
            labels = torch.tensor([clf_class_to_id[clf_label.item()] for clf_label in labels], device='cuda')

            # Forward pass
            _, logits = model(images)
            loss = loss_fn(logits, labels)

            # Update loss
            running_loss += loss.item()

            # Predictions and accuracy
            _, predictions = torch.max(logits, dim=1)
            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)

            # Store predictions and labels for metric computation
            all_predictions.append(predictions)
            all_labels.append(labels)
            num_batches += 1

    # Combine all predictions and labels across batches
    all_predictions = torch.cat(all_predictions, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    # Compute metrics
    precision, recall, f1 = calculate_metrics(all_predictions, all_labels)
    conf_matrix = confusion_matrix(all_labels.cpu().numpy(), all_predictions.cpu().numpy())

    # Compute average loss and accuracy
    avg_loss = running_loss / num_batches
    accuracy = total_correct / total_samples

    return avg_loss, accuracy, precision, recall, f1, conf_matrix



In [None]:
from tqdm import tqdm

# Training loop

train_dataloader, val_dataloader, dogs, num_total_classes = load_data()
class_tree = Tree(dogs)

contrastive_classes_specificity = 3
contrastive_classes = class_tree.nodes_at_depth(contrastive_classes_specificity)
contrastive_class_to_id = {_cls: i for i, _cls in enumerate(contrastive_classes)}

clf_classes_specificity = 2
clf_classes = class_tree.nodes_at_depth(clf_classes_specificity)
clf_class_to_id = {_cls: i for i, _cls in enumerate(clf_classes)}

num_classes = len(clf_classes)

model = ContrastiveModel(num_classes).cuda()
nce_loss_fn = NCELoss().cuda() # Generally probably add all of this to Colab to use the GPU
classification_loss_fn = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Precompute distance matrix
distance_matrix = torch.tensor(precompute_lca_distances(np.arange(len(contrastive_classes)), contrastive_classes, contrastive_class_to_id, class_tree), device=device)

for epoch in range(2):  # Using two epochs to test, adjust as needed
    model.train()
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    num_batches = 0
    all_predictions = []
    all_labels = []

    for i, batch in enumerate(tqdm(train_dataloader)):
        images, labels = batch
        images, labels = images.cuda(), labels.cuda()

        # Map labels to contrastive level
        contr_labels = torch.tensor([class_tree.which_ancestor(label.item(), contrastive_classes) for label in labels], device='cuda')
        contr_labels = torch.tensor([contrastive_class_to_id[contr_label.item()] for contr_label in contr_labels], device='cuda')

        # Map labels to classification level
        clf_labels = torch.tensor([class_tree.which_ancestor(label.item(), clf_classes) for label in labels], device='cuda')
        clf_labels = torch.tensor([clf_class_to_id[clf_label.item()] for clf_label in clf_labels], device='cuda')

        # Forward pass
        embeddings, logits = model(images)

        # Compute losses
        nce_loss = nce_loss_fn(embeddings, contr_labels) * len(contr_labels) / 75
        classification_loss = classification_loss_fn(logits, clf_labels)
        total_loss = nce_loss + classification_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Update running loss
        running_loss += total_loss.item()

        # Compute accuracy
        _, predictions = torch.max(logits, dim=1)  # Get class predictions
        total_correct += (predictions == clf_labels).sum().item()  # Count correct predictions
        total_samples += clf_labels.size(0)  # Update total samples
        print(f"Batch accuracy: {total_correct / total_samples:.4f}")
        num_batches += 1

        all_predictions.append(predictions)
        all_labels.append(labels)
        

    # Compute average loss and accuracy
    epoch_loss = running_loss / num_batches
    train_accuracy = total_correct / total_samples
    
    # Validation phase
    val_loss, val_accuracy, val_precision, val_recall, val_f1, val_conf_matrix = evaluate_model(
        model, val_dataloader, classification_loss_fn, device, num_classes
    )

    # Log epoch statistics
    print(f"Epoch [{epoch + 1}/{2}]")
    print(f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    print(f"Validation Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1-Score: {val_f1:.4f}")
    print("-" * 50)


Dog Classes
[156 215 216 219 217 218 220 221 205 206 207 208 209 210 211 212 213 214
 189 190 179 180 188 196 197 198 191 193 181 182 192 194 195 184 183 185
 186 199 200 201 202 203 187 165 166 167 168 171 172 169 170 160 161 162
 163 164 173 174 175 176 177 178 159]
63


  distance_matrix = torch.tensor(precompute_lca_distances(np.arange(len(contrastive_classes)), contrastive_classes, contrastive_class_to_id, class_tree), device=device)


(63,)


0it [00:00, ?it/s]

Batch accuracy: 0.0667


0it [00:11, ?it/s]
Too many dataloader workers: 4 (max is dataset.num_shards=1). Stopping 3 dataloader workers.


Epoch [1/2]
Train Loss: 9.2640, Train Accuracy: 0.0667
Validation Loss: 3.4427, Validation Accuracy: 0.0000
Validation Precision: 0.0000, Recall: 0.0000, F1-Score: 0.0000
Confusion Matrix:
[[0 6 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 3 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 3 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 4 0 0 0 0 0 0 0 0 0 0 0 0 0]]
--------------------------------------------------


0it [00:00, ?it/s]

Batch accuracy: 0.1667


0it [00:12, ?it/s]
Too many dataloader workers: 4 (max is dataset.num_shards=1). Stopping 3 dataloader workers.


Epoch [2/2]
Train Loss: 9.5206, Train Accuracy: 0.1667
Validation Loss: 3.4267, Validation Accuracy: 0.0000
Validation Precision: 0.0000, Recall: 0.0000, F1-Score: 0.0000
Confusion Matrix:
[[0 6 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 3 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 2 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 3 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 4 0 0 0 0 0 0 0 0 0 0 0 0 0]]
--------------------------------------------------


In [None]:
# Training Loop
for epoch in range(args.num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    total_correct_train = 0
    total_samples_train = 0

    for batch in train_dataloader:
        images, labels = batch
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        embeddings, logits = model(images)

        # Compute classification loss
        clf_labels = torch.tensor([class_tree.which_ancestor(label.item(), clf_classes) for label in labels], device=device)
        clf_labels = torch.tensor([clf_class_to_id[clf_label.item()] for clf_label in clf_labels], device=device)
        loss = classification_loss_fn(logits, clf_labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update loss and accuracy
        train_loss += loss.item()
        _, predictions = torch.max(logits, dim=1)
        total_correct_train += (predictions == clf_labels).sum().item()
        total_samples_train += clf_labels.size(0)

    # Compute training statistics
    train_accuracy = total_correct_train / total_samples_train
    train_loss /= len(train_dataloader)

    # Validation phase
    val_loss, val_accuracy, val_precision, val_recall, val_f1, val_conf_matrix = evaluate_model(
        model, val_dataloader, classification_loss_fn, device, num_classes
    )

    # Log epoch statistics
    print(f"Epoch [{epoch + 1}/{args.num_epochs}]")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    print(f"Validation Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1-Score: {val_f1:.4f}")
    print(f"Confusion Matrix:\n{val_conf_matrix}")
    print("-" * 50)

In [None]:
# Training loop
# TODO: add in evals or call to eval.py file perhaps?
for epoch in range(2):  # Using two epochs to test but add more epochs
    model.train()
    running_loss = 0.0

    for batch in train_dataloader:
        images, labels = batch
        #print ("Batch")
        #print (images.shape, labels.shape)
        #print (labels)
        images, labels = images.cuda(), labels.cuda()

        # for each image, find the label at the correct classification level
        contr_labels = torch.tensor([class_tree.which_ancestor(label.item(), contrastive_classes) for label in labels], device='cuda')
        contr_labels = torch.tensor([contrastive_class_to_id[contr_label.item()] for contr_label in contr_labels], device='cuda')
        #print("CONTRASTIVE LABELS: ")
        #print(contr_labels)

        # Map labels to their IDs
        clf_labels = torch.tensor([class_tree.which_ancestor(label.item(), clf_classes) for label in labels], device='cuda')
        clf_labels = torch.tensor([clf_class_to_id[clf_label.item()] for clf_label in clf_labels], device='cuda')
        #print ("CLASS LABELS")
        #print (clf_labels)

        # Print tensor shapes
        '''print(f"Images shape: {images.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"Contrastive labels shape: {contr_labels.shape}")
        print(f"Class labels shape: {clf_labels.shape}")'''

        # Forward pass
        embeddings, logits = model(images)

        # Compute losses
        nce_loss = nce_loss_fn(embeddings, contr_labels) * len(contr_labels) / 75
        classification_loss = classification_loss_fn(logits, clf_labels)
        total_loss = nce_loss + classification_loss  # Combined loss

        # Print losses
        print(f"NCE, CLF Loss: {nce_loss} {classification_loss}")
        print(f"Total Loss: {total_loss.item()}")

        # Backward pass and optimization
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    print(f"Epoch [{epoch+1}/2], Loss: {running_loss/len(train_dataloader):.4f}")


NCE, CLF Loss: 24.287818908691406 4.1379313468933105
Total Loss: 28.425750732421875
NCE, CLF Loss: 23.329919815063477 4.117332935333252
Total Loss: 27.44725227355957
NCE, CLF Loss: 24.186758041381836 4.11442756652832
Total Loss: 28.301185607910156
NCE, CLF Loss: 22.530248641967773 4.251932144165039
Total Loss: 26.782180786132812
NCE, CLF Loss: 22.803865432739258 4.272979736328125
Total Loss: 27.076845169067383


KeyboardInterrupt: 

In [None]:
train_dat