<a href="https://colab.research.google.com/github/StaniszewskiA/Loss-Functions/blob/main/Untitled4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [72]:
import json
from typing import Dict
import torch
from torch import optim
from torchvision import datasets, transforms


def get_config_cnnh_cifar():
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    config = {
        "T": 30,
        "hash_save_path": "save/CNNH/",
        "optimizer": {
            "type": optim.Adam,
            "optim_params": {
                "lr": 1e-4,
                "betas": (0.9, 0.999)
            }
        },
        "info": "[CNNH]",
        "resize_size": 64,
        "crop_size": 32,
        "batch_size": 8,
        "net": "ResNet",
        "dataset": "cifar",
        "epochs": 100,
        "save_interval": 10,
        "test_MAP": 10,
        "device": device,
        "bit_list": [12],
    }

    return config


transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
])

cifar_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(cifar_dataset, batch_size=8, shuffle=True, num_workers=4)

bits = 12
n = 10000
train_labels = torch.tensor([], dtype=torch.long)

for data, labels in train_loader:
    train_labels = torch.cat((train_labels, labels), dim=0)

train_labels_capped = train_labels[:3]


Files already downloaded and verified


In [9]:
from torch import nn
from torchvision import models

class AlexNet(nn.Module):
      def __init__(self, hash_bits: int, pretrained: bool = True):
        super(AlexNet, self).__init__()

        model_alexnet = models.alexnet(pretrained=pretrained)
        self.features = model_alexnet.features

        self.hash_layer = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, hash_bits),
        )

      def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = x.view(x.size(0), -1)
        hash_code = self.hash_layer(x)
        return hash_code

In [10]:
from torch import nn
from torchvision import models

class ResNet(nn.Module):
    """ResNet for Hashing"""
    def __init__(self, hash_bits: int, pretrained: bool = True):
        """Constructor"""
        super(ResNet, self).__init__()

        model_resnet = models.resnet18(pretrained=pretrained)
        self.features = nn.Sequential(*list(model_resnet.children())[:-1])

        self.hash_layer = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(model_resnet.fc.in_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, hash_bits),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass"""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        hash_code = self.hash_layer(x)
        return hash_code

In [41]:
def initialize_hash(n, bits):
  hashes = 2 * torch.rand((n, bits)) - 1
  return hashes

In [85]:
import numpy as np

def initalize_sim_matrix(bits):
  sim_matrix = np.outer(train_labels_capped, train_labels_capped)

  sim_matrix_scaled = np.kron(sim_matrix, np.ones((bits, bits)))

  return sim_matrix_scaled

In [86]:
import numpy as np

def compute_hash_matrix(hash):
  hash_numpy = np.array(hash)

  if len(hash_numpy.shape) == 1:
      hash_transposed = hash_numpy.reshape(-1, 1)
  else:
      hash_transposed = hash_numpy.T

  hash_matrix = np.outer(hash_numpy, hash_transposed)

  return hash_matrix


In [154]:
import torch
import torch.nn.functional as F

def reconstruction_triplet_loss(hash_matrix, sim_matrix, triplets):
  hash_tensor = torch.tensor(hash_matrix, dtype=torch.float32)
  sim_tensor = torch.tensor(sim_matrix, dtype=torch.float32)

  reconstruction_loss = torch.nn.functional.mse_loss(hash_tensor, sim_tensor)

  anchor, positive, negative = zip(*triplets)

  anchor = torch.tensor(anchor, dtype=torch.float32)
  positive = torch.tensor(positive, dtype=torch.float32)
  negative = torch.tensor(negative, dtype=torch.float32)

  pos_distances = F.pairwise_distance(anchor, positive)
  neg_distances = F.pairwise_distance(anchor, negative)

  margin = 1.0

  triplet_loss = F.relu(pos_distances - neg_distances + margin)
  #triplet_loss += sim_matrix[torch.arange(len(triplets)), torch.arange(len(triplets))]

  alpha = 0.3

  loss = alpha * reconstruction_loss + (1 - alpha) * triplet_loss

  return loss


In [38]:
class CNNHTrainer:
  def __init__(self, cnnh_config) -> None:
      self.config = cnnh_config
      self.net = self.initialize_model()
      self.optimizer = self.setup_optimizer()
      self.epochs = self.config["epochs"]
      self.hashes = initalize_hashes(10000, 12)
      self.train_labels = self.dataset_manager.load_dataset_labels(self.train_loader, self.device)
      self.criterion = reconstruction_triplet_loss(hash_matrix, sim_matrix, anchor,
                                                    anchor_positive_similarity,
                                                    anchor_negative_similarity)

  def initialize_model(self):
    return self.config["net"](self.config["bit_list"][0])

  def setup_optimizer(self):
        optimizer_type = self.config["optimizer"]["type"]
        optimizer_params = self.config["optimizer"]["optim_params"]

        optimizer = optimizer_type(self.net.parameters(), **optimizer_params)

        if 'weight_decay' in optimizer_params and optimizer_params['weight_decay'] > 0:
            optimizer_params["weight_decay"] = 1e-5

        return optimizer

  def train(self):
    train_losses = []

    for epoch in range(self.epochs):
      self.net.train()
      train_loss = self.train_epoch()
      train_losses.append(train_loss)

      if (epoch + 1) % self.config["test_MAP"] == 0:
          current_map = self.evaluate()
          if current_map > best_map:
              best_map = current_map

      if (epoch + 1) % self.config["save_interval"] == 0:
          self.save_model_state(epoch + 1, best_map, train_losses)

      if (epoch + 1) == self.epochs:
          self.evaluate()
          self.plot_results()

  def train_epoch(self):
    epoch_losses = []
    for batch_index, (images, labels) in enumerate(self.train_loader):
        torch.cuda.empty_cache()
        images, labels = images.to(self.device), labels.to(self.device)
        self.optimizer.zero_grad()
        # Forward pass
        predictions = self.net(images)
        self.em.update(labels, predictions)
        # Calculate loss
        loss = self.calculate_loss(predictions, labels, batch_index)
        # Backward pass
        loss.backward()
        # Update weights
        self.optimizer.step()

        epoch_losses.append(loss.item())

    return epoch_losses

  def calculate_loss(self, prediction, targets, index):
    loss = self.criterion(prediction, targets, index)

    if self.optimizer_params.get('l2_reg', False):
        l2_reg = torch.tensor(0.0, device=self.device)
        for param in self.net.parameters():
            l2_reg += torch.norm(param)
        loss += 0.5 * self.optimizer_params['weight_decay'] * l2_reg

    return loss

  def save_model_state(self, epoch, best_map, train_losses):
    bit = self.config["bit_list"]
    save_path = f"model_state_epoch_{epoch}_{self.config['dataset']}_{bit}.pt"
    torch.save(self.net.state_dict(), save_path)
    print(f"Model state saved at epoch {epoch}: {save_path}")

    result_dict = {
        "epoch": epoch,
        "best_map": best_map,
        "train_losses": train_losses
    }
    result_path = f"training_results_epoch_{epoch}.json"
    with open(result_path, 'w') as result_file:
        json.dump(result_dict, result_file)
    print(f"Training results saved at epoch {epoch}: {result_path}")

In [128]:
def generate_triplets(hashes):
  triplets = []

  for i in range(len(hashes)):
      anchor = hashes[i]

      # Select a positive sample (similar to the anchor)
      positive_index = np.random.choice(np.where(hashes == anchor)[0])
      positive = hashes[positive_index]

      # Select a negative sample (dissimilar to the anchor)
      negative_index = np.random.choice(np.where(hashes != anchor)[0])
      negative = hashes[negative_index]

      triplets.append((anchor, positive, negative))

  return triplets

In [157]:
cnnh_config_cifar = get_config_cnnh_cifar()

bits = 1
data_amount = train_labels_capped.size(0)

hashes = initialize_hash(data_amount, bits)
hash_matrix = compute_hash_matrix(hashes)
sim_matrix = initalize_sim_matrix(bits)

print("Hash Matrix:")
print(hash_matrix)
print("\nSimilarity Matrix:")
print(sim_matrix)

triplets = generate_triplets(hashes)

rec_loss = reconstruction_triplet_loss(hash_matrix, sim_matrix, triplets)

print(rec_loss)

Hash Matrix:
[[ 0.59496176 -0.23058371 -0.06509267]
 [-0.23058371  0.08936516  0.02522735]
 [-0.06509267  0.02522735  0.00712156]]

Similarity Matrix:
[[25. 45. 10.]
 [45. 81. 18.]
 [10. 18.  4.]]
tensor(403.2830)
