# Contrastive Language-Image Pretraining with SogCLR

### Introduction

In this tutorial, you will learn how to conduct contrastive language-image pretraining by optimizing the [Global Contrastive Loss](https://arxiv.org/abs/2202.12387) (GCL) on a subset of the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/) dataset. You will also learn how to evaluate the model on a retrieval task using the [MSCOCO](https://cocodataset.org/#home) dataset and on a zero-shot classification task using the [ImageNet](https://www.image-net.org/challenges/LSVRC/index.php) dataset. The code is based on the [iSogCLR](https://github.com/zhqiu/contrastive-learning-iSogCLR) codebase, which includes implementations of CLIP, SogCLR, and iSogCLR.

### Preparation

First, we:

1. Download the source code and data
2. Install required packages

In [None]:
!git clone -b project https://github.com/xywei00/csce689_iSogCLR.git iSogCLR

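import os

# `!export` only affects a throwaway subshell, so set the variables on the notebook
# process itself; later `!` commands inherit them.
os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":./iSogCLR/bimodal_exps"
os.environ["HUGGINGFACE_HUB_CACHE"] = "./checkpoints/huggingface"
!mkdir -p checkpoints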

!gdown 142xxRoMaHxX3BIfCw_1b_G_dgu-02Yq3    # clip_train.tar.gz
!gdown 142zQjlOw0Xw4tKzXMrQjYE6NtGRTeasT    # cc3m_subset_100k.tar.gz
!gdown 142tMsnclHTTPpnTXHSeNgTUlBk4She6o    # mscoco_val.tar.gz
!gdown 1NXhfhwFy-nhdABACkodgYqm9pomDKE39    # val.tar

!mkdir -p datasets/imagenet    # also creates ./datasets
!tar xf clip_train.tar.gz
!tar xf cc3m_subset_100k.tar.gz -C datasets
!tar xf mscoco_val.tar.gz -C datasets
!tar xf val.tar -C datasets/imagenet

!pip install -r ./iSogCLR/requirements_colab.txt    # pip may print warnings or errors here; they should be safe to ignore
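
Before training, it can help to confirm that the archives were extracted where the later commands expect them. The following is a minimal sketch, not part of the codebase: the helper name `missing_dirs` is hypothetical, and the directory list is inferred from the `--ann_path`, `--train_image_root`, and `--zs_datafolder` arguments used in this tutorial.

```python
from pathlib import Path

# Directories inferred from the commands in this tutorial (an assumption);
# adjust the list if your layout differs.
EXPECTED_DIRS = [
    "clip_train",
    "datasets/cc3m_subset_100k",
    "datasets/imagenet/val",
]


def missing_dirs(root=".", expected=EXPECTED_DIRS):
    """Return the expected directories that do not exist under `root`."""
    root = Path(root)
    return [d for d in expected if not (root / d).is_dir()]
```

Calling `missing_dirs()` from the notebook's working directory should return an empty list once every archive has been extracted.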

### Training

The following command runs the training script, which trains a ResNet50 image encoder (pretrained on ImageNet) and a DistilBERT text encoder (pretrained on BookCorpus and English Wikipedia) on the CC3M subset downloaded above. It uses the SogCLR loss (`--ita_type sogclr`) for 30 epochs with temperature 0.01 (`--tau_init 0.01`).

In [None]:
!CUDA_VISIBLE_DEVICES=0 python ./iSogCLR/bimodal_exps/clip.py \
    --data_path ./datasets \
    --ann_path ./clip_train \
    --train_file cc3m_train_subset.json \
    --train_image_root cc3m_subset_100k \
    --output_dir output/sogclr_cc3m_g0.8_e30 \
    --init_model \
    --use_amp \
    --ita_type sogclr \
    --tau_init 0.01 \
    --sogclr_gamma 0.8 \
    --eta_init 0.03 --sched cosine \
    --no-distributed \
    --epochs 30
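
As a rough illustration of what a SogCLR-style loss does with a moving-average coefficient such as `--sogclr_gamma 0.8` and a temperature such as `--tau_init 0.01`, here is a minimal sketch. It is not the codebase's implementation: the class name `SogCLRSketch`, the `seen` bookkeeping for first-time samples, and the CPU-only buffers are assumptions made for illustration, and the sketch has no protection against overflow in `exp` at very small temperatures.

```python
import torch
import torch.nn.functional as F


class SogCLRSketch:
    """Simplified bimodal SogCLR-style loss with per-sample moving averages."""

    def __init__(self, num_samples, gamma=0.8, tau=0.01, eps=1e-14):
        self.gamma, self.tau, self.eps = gamma, tau, eps
        # Running estimates of the contrastive normalizer, one per training sample
        self.u_img = torch.zeros(num_samples)
        self.u_txt = torch.zeros(num_samples)
        self.seen = torch.zeros(num_samples, dtype=torch.bool)

    def __call__(self, img_feat, txt_feat, ids):
        img_feat = F.normalize(img_feat, dim=-1)
        txt_feat = F.normalize(txt_feat, dim=-1)
        sim = img_feat @ txt_feat.T                 # (B, B) cosine similarities
        pos = sim.diagonal()                        # matched image-text pairs
        batch = sim.shape[0]
        neg_mask = 1.0 - torch.eye(batch)
        img_diff = sim - pos[:, None]               # image anchors: compare along rows
        txt_diff = sim - pos[None, :]               # text anchors: compare along columns

        # Mini-batch estimates of the normalizers (no gradient flows through them)
        exp_img = torch.exp(img_diff.detach() / self.tau) * neg_mask
        exp_txt = torch.exp(txt_diff.detach() / self.tau) * neg_mask
        g_img = exp_img.sum(dim=1) / (batch - 1)
        g_txt = exp_txt.sum(dim=0) / (batch - 1)

        # Moving-average update; a sample seen for the first time starts at its estimate
        first = ~self.seen[ids]
        u_img = torch.where(first, g_img, (1 - self.gamma) * self.u_img[ids] + self.gamma * g_img)
        u_txt = torch.where(first, g_txt, (1 - self.gamma) * self.u_txt[ids] + self.gamma * g_txt)
        self.u_img[ids], self.u_txt[ids] = u_img, u_txt
        self.seen[ids] = True

        # Surrogate loss: detached weights times the differentiable similarity gaps
        w_img = exp_img / (u_img[:, None] + self.eps)
        w_txt = exp_txt / (u_txt[None, :] + self.eps)
        loss_img = (w_img * img_diff).sum(dim=1) / (batch - 1)
        loss_txt = (w_txt * txt_diff).sum(dim=0) / (batch - 1)
        return loss_img.mean() + loss_txt.mean()
```

In this sketch, `ids` holds the dataset indices of the samples in the batch, which is what lets the normalizer estimates persist across iterations. With `gamma=1.0` the running estimates are ignored and the weights reduce to a purely in-batch normalization.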

### Evaluation

The following command runs the evaluation script to evaluate the retrieval performance of the trained model on the MSCOCO validation dataset and the zero-shot classification performance on the ImageNet validation dataset. The evaluation command is obtained by appending `--evaluate --checkpoint /path/to/your/checkpoint --zs_dataset imagenet --zs_datafolder /path/to/imagenet/val` to the training command.

In [None]:
!CUDA_VISIBLE_DEVICES=0 python ./iSogCLR/bimodal_exps/clip.py \
    --data_path ./datasets \
    --ann_path ./clip_train \
    --train_file cc3m_train_subset.json \
    --train_image_root cc3m_subset_100k \
    --output_dir output/sogclr_cc3m_g0.8_e30 \
    --init_model \
    --use_amp \
    --ita_type sogclr \
    --tau_init 0.01 \
    --sogclr_gamma 0.8 \
    --eta_init 0.03 --sched cosine \
    --no-distributed \
    --epochs 30 \
    --evaluate --checkpoint './output/sogclr_cc3m_g0.8_e30/checkpoint_30.pth' \
    --zs_dataset imagenet --zs_datafolder ./datasets/imagenet/val
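
To make the zero-shot metric concrete, the sketch below shows one common way to compute top-1 accuracy with a CLIP-style model: each image is assigned to the class whose text embedding is most similar to the image embedding. The function name `zero_shot_top1` and the idea of one text embedding per class (for example from a hypothetical prompt like "a photo of a {class name}") are assumptions for illustration; the evaluation script may differ in its details.

```python
import torch
import torch.nn.functional as F


def zero_shot_top1(image_features, class_text_features, labels):
    """Top-1 accuracy (%) from nearest class-prompt embedding.

    image_features:      (N, D) image embeddings
    class_text_features: (C, D) one text embedding per class
    labels:              (N,)   ground-truth class indices
    """
    img = F.normalize(image_features, dim=-1)
    txt = F.normalize(class_text_features, dim=-1)
    pred = (img @ txt.T).argmax(dim=-1)   # most similar class per image
    return 100.0 * (pred == labels).float().mean().item()
```

Normalizing both sets of features makes the dot product a cosine similarity, so the prediction does not depend on embedding norms.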

### Benchmarks

The table below reports results on the provided MSCOCO and ImageNet validation sets. The first row is the model trained with the CLIP loss, and the second row is the model trained with the SogCLR loss. Both models were pretrained for 30 epochs with a batch size of 128. TR@1 and IR@1 denote recall at 1 for text retrieval and image retrieval on MSCOCO, ACC@1 denotes top-1 accuracy on ImageNet, and Average is the mean of the three metrics.

| Method | MSCOCO TR@1 | MSCOCO IR@1 | ImageNet ACC@1 | Average |
|:------:|:-----------:|:-----------:|:--------------:|:-------:|
| CLIP   | 12.0        | 9.32        | 21.35          | 14.22   |
| SogCLR | 14.38       | 10.73       | 24.54          | 16.55   |
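
For readers who want to see how the retrieval columns are defined, here is a small sketch of recall at 1 in both directions, plus the arithmetic behind the Average column. The function names and the `txt2img` mapping are assumptions for illustration. The mapping allows several captions per image, as in MSCOCO.

```python
import numpy as np


def recall_at_1(sim, txt2img):
    """Return (TR@1, IR@1) in percent.

    sim:     (num_images, num_texts) image-text similarity matrix
    txt2img: txt2img[j] is the index of the image that caption j describes
    """
    sim = np.asarray(sim)
    txt2img = np.asarray(txt2img)
    # Text retrieval: for each image, does the top-ranked caption belong to it?
    top_txt = sim.argmax(axis=1)
    tr = np.mean(txt2img[top_txt] == np.arange(sim.shape[0]))
    # Image retrieval: for each caption, is the top-ranked image its source image?
    top_img = sim.argmax(axis=0)
    ir = np.mean(top_img == txt2img)
    return 100.0 * tr, 100.0 * ir


def average_metric(*values):
    """Plain mean of the reported metrics, as in the Average column."""
    return sum(values) / len(values)
```

For the CLIP row, `average_metric(12.0, 9.32, 21.35)` is about 14.22, and for the SogCLR row, `average_metric(14.38, 10.73, 24.54)` is about 16.55, matching the table.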

In [None]:

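# Define new loss functions
import torch
import torch.nn.functional as F


def _off_diagonal_sq_sum(cov):
    """Sum of squared off-diagonal entries of a square matrix."""
    return (cov ** 2).sum() - (torch.diagonal(cov) ** 2).sum()


def vicreg_loss(preds, targets, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0):
    """VICReg loss: invariance (MSE) + variance hinge + covariance penalty."""
    n, d = preds.shape
    # Similarity (invariance) term
    sim_loss = F.mse_loss(preds, targets)
    # Variance term: keep the per-dimension std of each branch at or above 1
    std_pred = torch.sqrt(preds.var(dim=0) + 1e-4)
    std_targ = torch.sqrt(targets.var(dim=0) + 1e-4)
    std_loss = torch.mean(F.relu(1 - std_pred)) + torch.mean(F.relu(1 - std_targ))
    # Covariance term: center the features, then penalize only off-diagonal entries
    preds_c = preds - preds.mean(dim=0)
    targets_c = targets - targets.mean(dim=0)
    cov_pred = preds_c.T @ preds_c / (n - 1)
    cov_targ = targets_c.T @ targets_c / (n - 1)
    cov_loss = (_off_diagonal_sq_sum(cov_pred) + _off_diagonal_sq_sum(cov_targ)) / d
    return sim_coeff * sim_loss + std_coeff * std_loss + cov_coeff * cov_loss


def supcon_loss(features, labels, temperature=0.07):
    """Supervised contrastive loss with a single view per sample."""
    features = F.normalize(features, dim=-1)  # the loss assumes unit-norm features
    n = features.shape[0]
    not_self = 1 - torch.eye(n, device=features.device)
    labels = labels.view(-1, 1)
    # Positives share the anchor's label; the anchor itself is excluded
    mask = torch.eq(labels, labels.T).float() * not_self
    logits = features @ features.T / temperature
    logits = logits - logits.max(dim=1, keepdim=True).values.detach()  # numerical stability
    exp_logits = torch.exp(logits) * not_self
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    # Average over positives; anchors without any positive are skipped
    pos_count = mask.sum(1)
    mean_log_prob_pos = (mask * log_prob).sum(1) / pos_count.clamp(min=1)
    return -mean_log_prob_pos[pos_count > 0].mean()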


In [None]:

# Initialize optimizers
def get_optimizer(model, optimizer_name, lr=2e-4, weight_decay=1e-2):
    if optimizer_name == "adamw":
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    elif optimizer_name == "sgd":
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")


In [None]:

# Training loop with different loss functions.
# Assumes `model`, `train_loader`, `num_epochs`, and `global_contrastive_loss` are
# already defined, and that each batch also carries class labels (needed by supcon).
import copy

initial_state = copy.deepcopy(model.state_dict())  # every run starts from the same weights

for loss_function_name in ["global_contrastive", "vicreg", "supcon"]:
    for optimizer_name in ["adamw", "sgd"]:
        print(f"Training with {loss_function_name} and {optimizer_name}")
        model.load_state_dict(initial_state)
        optimizer = get_optimizer(model, optimizer_name)

        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for batch in train_loader:
                images, texts, labels = batch
                optimizer.zero_grad()

                # Forward pass
                image_features = model.visual_encoder(images)
                text_features = model.text_encoder(texts)

                # Compute loss
                if loss_function_name == "global_contrastive":
                    loss = global_contrastive_loss(image_features, text_features)
                elif loss_function_name == "vicreg":
                    loss = vicreg_loss(image_features, text_features)
                elif loss_function_name == "supcon":
                    loss = supcon_loss(image_features, labels)
                else:
                    raise ValueError(f"Unsupported loss function: {loss_function_name}")

                # Backward and optimization
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            # Log the mean loss once per epoch
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")


In [None]:

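# Evaluation for retrieval and zero-shot classification
import numpy as np
import torch
import torch.nn.functional as F


def evaluate(model, val_loader, metric="recall@1"):
    model.eval()
    results = []
    with torch.no_grad():
        for batch in val_loader:
            images, texts = batch
            image_features = F.normalize(model.visual_encoder(images), dim=-1)
            text_features = F.normalize(model.text_encoder(texts), dim=-1)

            if metric == "recall@1":
                # In-batch image-to-text recall@1: the match for image i is text i.
                # This is a stand-in; full retrieval ranks against every caption.
                sim = image_features @ text_features.T
                target = torch.arange(sim.shape[0], device=sim.device)
                metric_result = (sim.argmax(dim=1) == target).float().mean().item()
            elif metric == "zeroshot_top1":
                # Needs class-prompt embeddings and labels, which this loader does not provide
                raise NotImplementedError("Add your zero-shot accuracy calculation")
            else:
                raise ValueError(f"Unsupported metric: {metric}")

            results.append(metric_result)
    return float(np.mean(results))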


In [None]:

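# Log results for each experiment.
# NOTE: evaluate() scores whatever weights `model` currently holds, so make sure
# `model` has been trained (or reloaded) for the configuration named by each key.
experiments = [
    (loss_function_name, optimizer_name)
    for loss_function_name in ["global_contrastive", "vicreg", "supcon"]
    for optimizer_name in ["adamw", "sgd"]
]
experiment_results = {}
for loss_function_name, optimizer_name in experiments:
    result_key = f"{loss_function_name}_{optimizer_name}"
    experiment_results[result_key] = evaluate(model, val_loader)

# Print results
print(experiment_results)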
