# Testing Cuda and installing packages

In [1]:
import torch
print(torch.cuda.is_available())  # Should return True if CUDA is correctly set up

True


In [2]:
!pip install datasets transformers huggingface_hub supervision timm sentence_transformers open_clip_torch faiss-gpu
# !pip install flash_attn
! pip install git+https://github.com/deepglint/unicom.git
!pip install faiss-gpu

Collecting datasets
  Downloading datasets-3.0.2-py3-none-any.whl.metadata (20 kB)
Collecting supervision
  Downloading supervision-0.24.0-py3-none-any.whl.metadata (14 kB)
Collecting open_clip_torch
  Downloading open_clip_torch-2.29.0-py3-none-any.whl.metadata (31 kB)
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting ftfy (from open_clip_torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading datasets-3.0.2-py3-none-any.whl (472 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.7/472.7 kB[0m [31m2

In [None]:
# !pip install flash-attn --no-build-isolation

Collecting flash-attn
  Using cached flash_attn-2.6.3.tar.gz (2.6 MB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hcanceled
[31mERROR: Operation cancelled by user[0m[31m
[0m

# Loading VLM4Bio data and creating embeddings

In [10]:
# dont skip if running training loop
!git lfs clone https://huggingface.co/datasets/sammarfy/VLM4Bio VLM4BIO_data

          with new flags from 'git clone'

'git clone' has been updated in upstream Git to have comparable
speeds to 'git lfs clone'.
Cloning into 'VLM4BIO_data'...
remote: Enumerating objects: 31716, done.[K
remote: Counting objects: 100% (31716/31716), done.[K
remote: Compressing objects: 100% (31648/31648), done.[K
remote: Total 31716 (delta 73), reused 31698 (delta 63), pack-reused 0 (from 0)[K
Receiving objects: 100% (31716/31716), 269.44 MiB | 9.21 MiB/s, done.
Resolving deltas: 100% (73/73), done.
Updating files: 100% (31484/31484), done.


In [None]:
# don't skip if running training loop
!mkdir downloaded_images

!mv VLM4BIO_data/datasets/Bird/chunk_0/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_1/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_2/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_3/* downloaded_images
!mv VLM4BIO_data/datasets/Bird/chunk_4/* downloaded_images

!mv VLM4BIO_data/datasets/Butterfly/chunk_0/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_1/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_2/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_3/* downloaded_images
!mv VLM4BIO_data/datasets/Butterfly/chunk_4/* downloaded_images

!mv VLM4BIO_data/datasets/Fish/chunk_0/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_1/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_2/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_3/* downloaded_images
!mv VLM4BIO_data/datasets/Fish/chunk_4/* downloaded_images

# Importing packages and downloading/cleaning data

In [3]:
import os
import pandas as pd
import requests
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [12]:
# skippable if loading embeds from .mat file
import requests
import os
import pandas as pd

def download_metadata_csv(metadata_url, metadata_file_path):
    if not os.path.exists(metadata_file_path):
        response = requests.get(metadata_url)
        if response.status_code == 200:
            with open(metadata_file_path, 'wb') as f:
                f.write(response.content)
            print(f"Metadata CSV file saved as {metadata_file_path}")
        else:
            print(f"Failed to download the metadata file. Status code: {response.status_code}")
            return
    else:
        print(f"Metadata file already exists at {metadata_file_path}")

def convert_metadata_to_dataframe(metadata_file_path):
    df = pd.read_csv(metadata_file_path)
    df.rename(columns={"fileNameAsDelivered": "image_filename", "scientificName": "scientificName"}, inplace=True)
    return df


def filter_existing_images_with_scientific_name(df, output_dir, taxa='Fish'):
    df['file_exists'] = df['image_filename'].apply(lambda x: os.path.exists(os.path.join(output_dir, x)))
    df_cleaned = df[df['file_exists']].copy()
    df_cleaned.drop(columns=['file_exists'], inplace=True)
    df_cleaned['taxa'] = taxa
    return df_cleaned

def save_dataframe_to_csv(df, csv_file_path="cleaned_images_with_scientific_names.csv"):
    df.to_csv(csv_file_path, index=False)
    print(f"DataFrame saved to {csv_file_path}")


taxa = ['Fish', 'Bird', 'Butterfly']
cleaned_dfs = []
for t in taxa:
  metadata_url = f"https://huggingface.co/datasets/sammarfy/VLM4Bio/resolve/main/datasets/{t}/metadata/metadata_10k.csv"
  output_dir = "downloaded_images"
  metadata_file_path = f"metadata_10k_{t}.csv"
  download_metadata_csv(metadata_url, metadata_file_path)
  df = convert_metadata_to_dataframe(metadata_file_path)
  df_cleaned = filter_existing_images_with_scientific_name(df, output_dir, t)
  cleaned_dfs.append(df_cleaned)
  save_dataframe_to_csv(df_cleaned, f"cleaned_images_with_scientific_names_{t}.csv")


labels = df_cleaned["scientificName"].unique().tolist()
full_species_df = pd.concat(cleaned_dfs)

# full_species_df = full_species_df.dropna(subset=['scientificName'])
# full_species_df = full_species_df[full_species_df['scientificName'].str.strip() != '']


Metadata CSV file saved as metadata_10k_Fish.csv
DataFrame saved to cleaned_images_with_scientific_names_Fish.csv
Metadata CSV file saved as metadata_10k_Bird.csv
DataFrame saved to cleaned_images_with_scientific_names_Bird.csv
Metadata CSV file saved as metadata_10k_Butterfly.csv
DataFrame saved to cleaned_images_with_scientific_names_Butterfly.csv


In [None]:
# skippable if loading embeds from .mat file
import open_clip
import torch
from PIL import Image
import os

device = "cuda" if torch.cuda.is_available() else "cpu"

def get_model():
    model, _, processing = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k", device=device)

    model.to(device)
    model = model.eval()
    return model, processing


def get_image_embeddings_clip(image_filename, model, preprocess):
    # folder where images are stored
    image_folder = "downloaded_images"

    try:
        image_path = os.path.join(image_folder, image_filename)
        image = Image.open(image_path)
        image = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            image_embeddings = model.encode_image(image)

        return image_embeddings.cpu().numpy()
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        return None  # Return None for failed image processing


In [None]:
# skippable if loading embeds from .mat file

model, processing = get_model()

full_species_df['image_embeddings'] = full_species_df['image_filename'].apply(lambda x: get_image_embeddings_clip(x, model, processing))

open_clip_pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

In [None]:
full_species_df

Unnamed: 0,image_filename,scientificName,taxa,image_embeddings
0,UWZM-F-0001570.JPG,Lepomis macrochirus,Fish,"[[0.2290793, -0.5630022, -0.8978242, 0.0935000..."
1,UWZM-F-0001664.JPG,Lepomis megalotis,Fish,"[[0.4251782, -0.5221929, -1.1134714, -0.167325..."
2,UWZM-F-0001696.JPG,Lepomis microlophus,Fish,"[[0.09046164, -0.50626826, -1.4220252, -0.0800..."
3,UWZM-F-0001697.JPG,Lepomis punctatus,Fish,"[[0.12731902, -0.4336033, -1.3776393, -0.21835..."
4,UWZM-F-0000002.JPG,Alosa aestivalis,Fish,"[[0.013927758, -0.9186851, -1.0240858, 0.88987..."
...,...,...,...,...
10008,Butterfly_imbalanced_test_Eueides_isabella_114...,Eueides isabella,Butterfly,"[[0.10432007, -0.7902194, 0.06585194, -0.51994..."
10009,Butterfly_imbalanced_test_Eueides_isabella_980...,Eueides isabella,Butterfly,"[[0.43515843, -0.9358704, 0.3825564, -0.465492..."
10010,Butterfly_imbalanced_test_Rhetus_periander_307...,Rhetus periander,Butterfly,"[[-0.05358781, -1.0487952, 0.012689903, -0.709..."
10011,Butterfly_imbalanced_test_Rhetus_periander_371...,Rhetus periander,Butterfly,"[[-0.008405313, -0.76264656, 0.0523196, -0.474..."


# Saving the embedding data in .mat and .pt format

In [None]:
# skippable if loading embeds from .mat file


from scipy.io import savemat
import numpy as np

# Convert the embeddings to a numpy array
np_embeds_list = np.array(full_species_df['image_embeddings'].tolist())

# Ensure embeddings are of the correct type (float32)
embeddings = np_embeds_list.astype('float32')
save_vars = {
    'scientificName': full_species_df['scientificName'].tolist(),
    'taxa': full_species_df['taxa'].tolist(),
    'image_filename': full_species_df['image_filename'].tolist(),
    'embeddings': embeddings
}

# Save as .mat file
savemat('embeddings.mat', save_vars)

tensor_embeddings = torch.tensor(embeddings)

# Save the tensor to a .pt file
torch.save(tensor_embeddings, 'embeddings.pt')

full_species_df[['image_filename', 'scientificName', 'taxa']].to_csv('selected_columns_embeddings.csv', index=False)

# Load the embedding data from .mat file

In [16]:
# skippable if generating embeddings from scratch

from scipy.io import loadmat

data = loadmat('embeddings.mat')
loaded_embeddings = data['embeddings']

# data['embeddings'] = [sub_array for sub_array in loaded_embeddings]

import pandas as pd

# Sample lists

# Create a DataFrame with each list as a column
full_species_df = pd.DataFrame({
    'image_filename': data['image_filename'],
    'scientificName': data['scientificName'],
    'taxa': data['taxa'],
    'image_embeddings': [sub_array for sub_array in loaded_embeddings]
})
full_species_df

Unnamed: 0,image_filename,scientificName,taxa,image_embeddings
0,UWZM-F-0001570.JPG ...,Lepomis macrochirus,Fish,"[[0.2290793, -0.5630022, -0.8978242, 0.0935000..."
1,UWZM-F-0001664.JPG ...,Lepomis megalotis,Fish,"[[0.4251782, -0.5221929, -1.1134714, -0.167325..."
2,UWZM-F-0001696.JPG ...,Lepomis microlophus,Fish,"[[0.09046164, -0.50626826, -1.4220252, -0.0800..."
3,UWZM-F-0001697.JPG ...,Lepomis punctatus,Fish,"[[0.12731902, -0.4336033, -1.3776393, -0.21835..."
4,UWZM-F-0000002.JPG ...,Alosa aestivalis,Fish,"[[0.013927758, -0.9186851, -1.0240858, 0.88987..."
...,...,...,...,...
31447,Butterfly_imbalanced_test_Eueides_isabella_114...,Eueides isabella,Butterfly,"[[0.10432007, -0.7902194, 0.06585194, -0.51994..."
31448,Butterfly_imbalanced_test_Eueides_isabella_980...,Eueides isabella,Butterfly,"[[0.43515843, -0.9358704, 0.3825564, -0.465492..."
31449,Butterfly_imbalanced_test_Rhetus_periander_307...,Rhetus periander,Butterfly,"[[-0.05358781, -1.0487952, 0.012689903, -0.709..."
31450,Butterfly_imbalanced_test_Rhetus_periander_371...,Rhetus periander,Butterfly,"[[-0.008405313, -0.76264656, 0.0523196, -0.474..."


In [17]:
full_species_df

Unnamed: 0,image_filename,scientificName,taxa,image_embeddings
0,UWZM-F-0001570.JPG ...,Lepomis macrochirus,Fish,"[[0.2290793, -0.5630022, -0.8978242, 0.0935000..."
1,UWZM-F-0001664.JPG ...,Lepomis megalotis,Fish,"[[0.4251782, -0.5221929, -1.1134714, -0.167325..."
2,UWZM-F-0001696.JPG ...,Lepomis microlophus,Fish,"[[0.09046164, -0.50626826, -1.4220252, -0.0800..."
3,UWZM-F-0001697.JPG ...,Lepomis punctatus,Fish,"[[0.12731902, -0.4336033, -1.3776393, -0.21835..."
4,UWZM-F-0000002.JPG ...,Alosa aestivalis,Fish,"[[0.013927758, -0.9186851, -1.0240858, 0.88987..."
...,...,...,...,...
31447,Butterfly_imbalanced_test_Eueides_isabella_114...,Eueides isabella,Butterfly,"[[0.10432007, -0.7902194, 0.06585194, -0.51994..."
31448,Butterfly_imbalanced_test_Eueides_isabella_980...,Eueides isabella,Butterfly,"[[0.43515843, -0.9358704, 0.3825564, -0.465492..."
31449,Butterfly_imbalanced_test_Rhetus_periander_307...,Rhetus periander,Butterfly,"[[-0.05358781, -1.0487952, 0.012689903, -0.709..."
31450,Butterfly_imbalanced_test_Rhetus_periander_371...,Rhetus periander,Butterfly,"[[-0.008405313, -0.76264656, 0.0523196, -0.474..."


# Forming clusters

In [18]:
import faiss
import numpy as np

# set seed for reproducibility
np.random.seed(5)

def set_clusters_on_df(df, ncentroids=3000, niter=20):

  # get embeddings from df
  np_embeds_list = np.array(df['image_embeddings'].tolist())

  embeddings = np.squeeze(np_embeds_list, axis=1).astype('float32')

  # number of clusters - cannot exceed number of samples
  verbose = True

  # num samples, and dimensionality of embeds
  n, d = embeddings.shape

  kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose)
  kmeans.train(embeddings)

  centroids = kmeans.centroids
  D, I = kmeans.index.search(embeddings, 1)
  df['cluster_id'] = I.flatten()
  df['square_dist'] = D.flatten()
  return df

In [None]:
# full_species_df[['image_filename', 'scientificName', 'taxa', 'cluster_id', 'square_dist']].to_csv('filenames_with_clusters.csv', index=False)

In [19]:
EMBEDDING_SIZE = full_species_df['image_embeddings'].iloc[0].shape[1]
EMBEDDING_SIZE

512

# Setting Up the Dataset

In [8]:
from sklearn.model_selection import train_test_split

def get_test_train_split(df_cleaned, test_size=0.2, stratify_by_scientific_name=False):
  if stratify_by_scientific_name:
    # if we want to get at least one from each we can filter out all options with only 1 image
    species_counts = df_cleaned['scientificName'].value_counts()
    valid_species = species_counts[species_counts > 1].index
    df_filtered = df_cleaned[df_cleaned['scientificName'].isin(valid_species)]
    # with stratification, we need to specify the test_size because we need to hit a minimum
    calc_min = len(valid_species)/len(df_filtered)
    min_split = max(test_size, calc_min)
    print(f'The split for stratification is: {min_split}')

    train_df, test_df = train_test_split(df_filtered, test_size=min_split, stratify=df_filtered['scientificName'])
  else:
    train_df, test_df = train_test_split(df_cleaned, test_size=test_size)
  return train_df, test_df


In [9]:
class VLM4BioEmbeddingDataset(torch.utils.data.Dataset):
    def __init__(self, df, transform=None):
        self.img_dir = 'downloaded_images/'
        self.image_files = df['image_filename'].tolist()
        if 'cluster_ids' in df.columns:
            self.cluster_ids = df['cluster_ids'].tolist()
        else:
            self.cluster_ids = [-1] * len(self.image_files)
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load image and apply transformations
        image_path = os.path.join(self.img_dir, self.image_files[idx]).strip()
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.cluster_ids[idx]
        return image, label


In [20]:
# split dataset to train and test
train_df, test_df = get_test_train_split(full_species_df, test_size=0.2, stratify_by_scientific_name=True)

NUM_CLUSTERS = 3000

# make sure clusters are set on training data ONLY
train_df = set_clusters_on_df(train_df, ncentroids=NUM_CLUSTERS)

# create datasets
# train_dataset = VLM4BioEmbeddingDataset(train_df)
# test_dataset = VLM4BioEmbeddingDataset(test_df)
# ^ these should be set in training loop, include transform

The split for stratification is: 0.2


# Partial FC


In [14]:
# this code is taken from https://github.com/deepglint/unicom/blob/main/partial_fc.py

import math
from typing import Callable

import torch
from torch import distributed
from torch.nn.functional import linear, normalize


class CombinedMarginLoss(torch.nn.Module):
    def __init__(self,
                 s,
                 m1,
                 m2,
                 m3,
                 interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # For ArcFace
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits, labels):
        index_positive = torch.where(labels != -1)[0]

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                mask.scatter_(1, labels[index_positive], 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        target_logit = logits[index_positive, labels[index_positive].view(-1)]

        if self.s == 1:
            return logits

        if self.m1 == 1.0 and self.m3 == 0.0:
            with torch.no_grad():
                target_logit.arccos_()
                logits.arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
                logits.cos_()
            logits = logits * self.s

        elif self.m3 > 0:
            final_target_logit = target_logit - self.m3
            logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
            logits = logits * self.s
        else:
            raise

        return logits


class PartialFC_V2(torch.nn.Module):
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
    When sample rate less than 1, in each iteration, positive class centers and a random subset of
    negative class centers are selected to compute the margin-based softmax loss, all class
    centers are still maintained throughout the whole training process, but only a subset is
    selected and updated in each iteration.
    .. note::
        When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
    Example:
    --------
    >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels)
    >>>     loss.backward()
    >>>     optimizer.step()
    """
    _version = 2

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
        is_normlize: int = 1,
        sample_num_feat=None
    ):
        """
        Paramenters:
        -----------
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        """
        super(PartialFC_V2, self).__init__()
        assert (
            distributed.is_initialized()
        ), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.sample_num_feat: int = sample_num_feat
        self.fp16 = fp16
        self.is_normlize = is_normlize
        self.num_local: int = num_classes // self.world_size + int(
            self.rank < num_classes % self.world_size
        )
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0

        self.is_updated: bool = True
        self.init_weight_update: bool = True
        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # margin_loss
        if isinstance(margin_loss, Callable):
            self.margin_softmax = margin_loss
        else:
            raise

    def sample(self, labels, index_positive):
        """
            This functions will change the value of labels
            Parameters:
            -----------
            labels: torch.Tensor
                pass
            index_positive: torch.Tensor
                pass
            optimizer: torch.optim.Optimizer
                pass
        """
        with torch.no_grad():
            positive = torch.unique(labels[index_positive], sorted=True).cuda()
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local]).cuda()
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1].cuda()
                index = index.sort()[0].cuda()
            else:
                index = positive
            self.weight_index = index

            labels[index_positive] = torch.searchsorted(
                index, labels[index_positive])

        return self.weight[self.weight_index]

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).
        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            weight = self.sample(labels, index_positive)
        else:
            weight = self.weight

        if self.sample_num_feat is not None and self.sample_num_feat < weight.size(1):
            with torch.no_grad():
                noise = torch.rand(weight.size(1), device=weight.device)  # noise in [0, 1]
                ids_shuffle = torch.argsort(noise)[: self.sample_num_feat]
            weight = weight.index_select(1, ids_shuffle)
            embeddings = embeddings.index_select(1, ids_shuffle)

        with torch.cuda.amp.autocast(self.fp16):
            if self.is_normlize:
                norm_embeddings = normalize(embeddings)
                norm_weight_activated = normalize(weight)
                logits = linear(norm_embeddings, norm_weight_activated)
            else:
                logits = linear(embeddings, weight)
        if self.fp16:
            logits = logits.float()
        if self.is_normlize:
            logits = logits.clamp(-1, 1)
        else:
            logits = torch.clip(logits, -64, 64)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss


class DistCrossEntropyFunc(torch.autograd.Function):
    """
    CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
    Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """

    @staticmethod
    def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
        """ """
        batch_size = logits.size(0)
        # for numerical stability
        max_logits, _ = torch.max(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
        logits.sub_(max_logits)
        logits.exp_()
        sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
        logits.div_(sum_logits_exp)
        index = torch.where(label != -1)[0]
        # loss
        loss = torch.zeros(batch_size, 1, device=logits.device)
        loss[index] = logits[index].gather(1, label[index])
        distributed.all_reduce(loss, distributed.ReduceOp.SUM)
        ctx.save_for_backward(index, logits, label)
        return loss.clamp_min_(1e-30).log_().mean() * (-1)

    @staticmethod
    def backward(ctx, loss_gradient):
        """
        Args:
            loss_grad (torch.Tensor): gradient backward by last layer
        Returns:
            gradients for each input in forward function
            `None` gradients for one-hot label
        """
        (
            index,
            logits,
            label,
        ) = ctx.saved_tensors
        batch_size = logits.size(0)
        one_hot = torch.zeros(
            size=[index.size(0), logits.size(1)], device=logits.device
        )
        one_hot.scatter_(1, label[index], 1)
        logits[index] -= one_hot
        logits.div_(batch_size)
        return logits * loss_gradient.item(), None


class DistCrossEntropy(torch.nn.Module):
    def __init__(self):
        super(DistCrossEntropy, self).__init__()

    def forward(self, logit_part, label_part):
        return DistCrossEntropyFunc.apply(logit_part, label_part)


class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward"""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        dist_ops = [
            distributed.reduce(
                grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
            if i == rank
            else distributed.reduce(
                grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
            )
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])


AllGather = AllGatherFunc.apply

In [23]:
# setting s, m1, m2, m3, sample_rate, num_feat based on https://github.com/deepglint/unicom/blob/main/retrieval.py#L34
s = 32
m1 = 1.0
m2 = 0.3
m3 = 0.0
sample_rate = 1.0
num_feat = 256 # random number, not based on defaults

# initializing CombinedMarginLoss and PartialFC
margin_loss = CombinedMarginLoss(s, m1, m2, m3)

import torch.distributed as dist
import os

# Initialize distributed processing with a single process
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group(backend='nccl', world_size=1, rank=0)

partial_fc = PartialFC_V2(
    margin_loss, embedding_size=EMBEDDING_SIZE,
    num_classes=NUM_CLUSTERS, sample_rate=sample_rate,
    sample_num_feat=num_feat)


# Obtain base model

In [24]:
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
from torch.utils.checkpoint import checkpoint
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode,
                                    Normalize, Resize, ToTensor)


class VisionTransformer(nn.Module):
    def __init__(self, input_size=224, patch_size=32, in_channels=3, dim=768, embedding_size=768,
                 depth=12, num_heads=12, mlp_ratio=4, drop_path_rate=0.0, using_checkpoint=True):
        super().__init__()
        self.dim = dim
        self.patch_embed = PatchEmbedding(
            input_size, patch_size, in_channels, dim,)
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.patch_embed.num_patches, dim))
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList(
            [
                Block(dim, num_heads, mlp_ratio, dpr[i], self.patch_embed.num_patches, using_checkpoint) for i in range(depth)
            ])
        self.norm = nn.LayerNorm(dim)

        self.feature = nn.Sequential(
            nn.Linear(dim * self.patch_embed.num_patches, dim, False),
            nn.BatchNorm1d(dim, eps=2e-5),
            nn.Linear(dim, embedding_size, False),
            nn.BatchNorm1d(embedding_size, eps=2e-5))

        trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)
        self.extra_gflops = 0.0
        for _block in self.blocks:
            self.extra_gflops += _block.extra_gflops

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        for func in self.blocks:
            x = func(x)
        x = self.norm(x.float())
        return torch.reshape(x, (B, self.patch_embed.num_patches * self.dim))

    def forward(self, x):
        x = self.forward_features(x)
        x = self.feature(x)
        return x


class Mlp(nn.Module):
    def __init__(self, dim, dim_hidden):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim_hidden)
        self.act = nn.ReLU6()
        self.fc2 = nn.Linear(dim_hidden, dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        with torch.cuda.amp.autocast(True):
            B, L, D = x.shape
            qkv = self.qkv(x).reshape(B, L, 3, self.num_heads,
                                      D // self.num_heads).permute(2, 0, 3, 1, 4)
        with torch.cuda.amp.autocast(False):
            q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            x = (attn @ v).transpose(1, 2).reshape(B, L, D)
        with torch.cuda.amp.autocast(True):
            x = self.proj(x)
        return x


class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int, mlp_ratio: int = 4, drop_path: float = 0.0, patch_n: int = 32, using_checkpoint=False):
        super().__init__()
        self.using_checkpoint = using_checkpoint
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads)
        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.mlp = Mlp(dim, dim * mlp_ratio)
        self.extra_gflops = (num_heads * patch_n * (dim // num_heads) * patch_n * 2) / (1000**3)

    def forward_impl(self, x):
        with torch.cuda.amp.autocast(True):
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def forward(self, x):
        if self.using_checkpoint:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class PatchEmbedding(nn.Module):
    def __init__(self, input_size=224, patch_size=32, in_channels: int = 3, dim: int = 768):
        super().__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        H = input_size[0] // patch_size[0]
        W = input_size[1] // patch_size[1]
        self.num_patches = H * W
        self.proj = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


def build_model(name="ViT-L/14@336px"):
    if name == "ViT-B/32":
        model = VisionTransformer(
            input_size=224, patch_size=32, in_channels=3, dim=768, embedding_size=512,
            depth=12, num_heads=12, drop_path_rate=0.1, using_checkpoint=True)
    elif name == "ViT-B/16":
        model = VisionTransformer(
            input_size=224, patch_size=16, in_channels=3, dim=768, embedding_size=768,
            depth=12, num_heads=12, drop_path_rate=0.1, using_checkpoint=True)
    elif name == "ViT-L/14":
        model = VisionTransformer(
            input_size=224, patch_size=14, in_channels=3, dim=1024, embedding_size=768,
            depth=24, num_heads=16, drop_path_rate=0.1, using_checkpoint=True)
    elif name == "ViT-L/14@336px":
        model = VisionTransformer(
            input_size=336, patch_size=14, in_channels=3, dim=1024, embedding_size=768,
            depth=24, num_heads=16, drop_path_rate=0.1, using_checkpoint=True)
    return model


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])


def load_model_and_transform(name="ViT-L/14@336px"):
    if name == "ViT-B/32":
        return build_model(name), _transform(224)
    elif name == "ViT-B/16":
        return build_model(name), _transform(224)
    elif name == "ViT-L/14":
        return build_model(name), _transform(224)
    elif name == "ViT-L/14@336px":
        return build_model(name), _transform(336)
    else:
        raise



In [25]:
# from UNICOM https://github.com/deepglint/unicom/blob/main/unicom/model.py#L82

import hashlib
import os
import urllib
import warnings

import torch
from typing import List
from tqdm import tqdm

__all__ = ["load", "available_models"]

_MODELS = {
    "ViT-B/32": "https://github.com/deepglint/unicom/releases/download/b32/FP16-ViT-B-32.pt",
    "ViT-B/16": "https://github.com/deepglint/unicom/releases/download/b16/FP16-ViT-B-16.pt",
    "ViT-L/14": "https://github.com/deepglint/unicom/releases/download/l14/FP16-ViT-L-14.pt",
    "ViT-L/14@336px": "https://github.com/deepglint/unicom/releases/download/l14_336px/FP16-ViT-L-14-336px.pt",
}

_SHA256 = {
    "FP16-ViT-B-32.pt": "f9d5696a9b58dbbbefee2d31615ca59084f2895a0fdd2ca4c235e0f9b2793f7a",
    "FP16-ViT-B-16.pt": "c04f324f7c3b4435667236ec6c0eca1cd62f9d64fbfc2d06f8e8e60e6497edef",
    "FP16-ViT-L-14.pt": "ff3ab62ff782876460099e6e0ee17b73a7c01109de2fffd595f16f4129404bbd",
    "FP16-ViT-L-14-336px.pt": "3916ab5aed3b522fc90345be8b4457fe5dad60801ad2af5a6871c0c096e8d7ea",
}


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def rm_module_from_state_dict(state_dict: dict) -> dict:
    result = {}
    for k, value in state_dict.items():

        if "module." in k:
            k_removed = k.split("module.")[-1]
            result[k_removed] = value
        else:
            result[k] = value
    return result


# copy from https://github.com/openai/CLIP/blob/main/clip/clip.py#L43
def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = _SHA256[filename]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(
            f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not not match")

    return download_target


# copy from https://github.com/openai/CLIP/blob/main/clip/clip.py#L94
def load(name: str, device: str = "cpu", download_root: str = None):
    if name in _MODELS:
        model_path = _download(
            _MODELS[name], download_root or os.path.expanduser("~/.cache/unicom"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}")
    with open(model_path, 'rb') as opened_file:
        state_dict = torch.load(opened_file, map_location="cpu")

    model, transform = load_model_and_transform(name)
    state_dict_fp32 = {}
    for k, v in state_dict.items():
        state_dict_fp32[k] = v.float()

    model.load_state_dict(state_dict)
    return model, transform

# Training Loop

In [26]:
# get_transform func as per https://github.com/deepglint/unicom/blob/main/retrieval.py#L408
from torchvision import transforms
import PIL


def get_transform(
        image_size: int = 224,
        is_train: bool = True
):
    from timm.data import create_transform
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=image_size,
            is_training=True,
            color_jitter=0.4,
            auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bicubic',
            re_prob=0.25,
            re_mode='pixel',
            re_count=1,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if image_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(image_size / crop_pct)
    t.append(transforms.Resize(size, interpolation=PIL.Image.BICUBIC))
    t.append(transforms.CenterCrop(image_size))
    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


In [27]:
from torch import optim
from tqdm import tqdm

BATCH_SIZE = 128

# defaults from retrieval.py
LR = 0.0001
LR_PFC = 5.0

num_epochs = 10

partial_fc.train().cuda()

backbone_model, transform = load("ViT-L/14@336px", device="cuda")
train_transform = get_transform(336)

train_dataset = VLM4BioEmbeddingDataset(train_df, transform=train_transform)
test_dataset = VLM4BioEmbeddingDataset(test_df, transform=train_transform)

steps_per_epoch = len(train_dataset) // BATCH_SIZE

dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
optimizer = torch.optim.AdamW(
    [
        {"params": backbone_model.parameters(), "lr": LR},
        {"params": partial_fc.parameters(), "lr": LR * LR_PFC}
    ]
)
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=[LR, LR*LR_PFC],
    steps_per_epoch=len(dataloader),
    epochs=num_epochs,
    pct_start=0.1
)

backbone_model.train()
backbone_model.float().cuda()
partial_fc.train()

for epoch in range(num_epochs):
    epoch_loss = 0
    with tqdm(dataloader, unit="batch") as batch_epoch:
        batch_epoch.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
        for images, labels in batch_epoch:
          images = images.cuda()
          labels = labels.long().cuda()

          optimizer.zero_grad()

          with torch.cuda.amp.autocast(True):
              batch_embeddings = backbone_model(images).float().cuda()
              # batch_embeddings = normalize(batch_embeddings)
              loss = partial_fc(batch_embeddings, labels)

          loss.backward()
          optimizer.step()
          lr_scheduler.step()

          epoch_loss += loss.item()
          batch_epoch.set_postfix(loss=loss.item())

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

# Final centroids after training
updated_centroids = partial_fc.weight.detach().cpu().numpy()

# Save model and centroids
torch.save(backbone_model.state_dict(), 'trained_model.pth')
np.save('final_centroids.npy', updated_centroids)

print("Training complete. Model and centroids saved.")



100%|█████████████████████████████████████| 1.69G/1.69G [03:18<00:00, 9.14MiB/s]
  state_dict = torch.load(opened_file, map_location="cpu")
  with torch.cuda.amp.autocast(True):
  return fn(*args, **kwargs)
  with torch.cuda.amp.autocast(True):
  with torch.cuda.amp.autocast(True):
  with torch.cuda.amp.autocast(False):
  with torch.cuda.amp.autocast(True):
  with torch.cuda.amp.autocast(self.fp16):
Epoch [1/10]:   0%|          | 0/194 [00:07<?, ?batch/s]


RuntimeError: Function AllGatherFuncBackward returned an invalid gradient at index 0 - got [128, 512] but expected shape compatible with [128, 768]

# Evaluate using test set

# CLEANUP

In [None]:
dist.destroy_process_group()
