In [None]:
import os

# Run from the parent directory so `data_utils` and the relative "data/..." paths resolve.
# The working directory persists across cell executions within a kernel, so an
# unguarded chdir would walk one more level up every time this cell is re-run.
if "_NOTEBOOK_CWD_SET" not in globals():
    os.chdir("..")
    _NOTEBOOK_CWD_SET = True

import data_utils
import torch
import open_clip

from tqdm import tqdm
from torch.utils.data import DataLoader

In [2]:
# Use the GPU when present; fall back to CPU so the notebook still runs without one.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Images per forward pass during CLIP feature extraction.
batch_size = 128

In [None]:
# Pretrained OpenAI CLIP ViT-B/32 placed on `device`. The second return value is
# discarded; `clip_preprocess` is the transform handed to the datasets below.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="openai",
    device=device,
)

In [6]:
# Load the CUB train / val / test splits, each preprocessed with the CLIP transform.
train_data, val_data, test_data = (
    data_utils.get_data(split_name, clip_preprocess)
    for split_name in ("cub_train", "cub_val", "cub_test")
)

In [22]:
def get_clip_image_features(model, dataset, batch_size=256, device="cuda"):
    """Encode every image in `dataset` with `model` and collect its concept labels.

    The dataset is iterated in order (no shuffling), so row i of both returned
    tensors corresponds to item i of `dataset`.

    Args:
        model: an ``nn.Module`` exposing ``encode_image(images)`` (here an open_clip
            model). It is put in eval mode for the pass and its previous
            train/eval mode is restored afterwards.
        dataset: dataset whose items unpack as ``(image, label, concept_labels)``.
            ``concept_labels`` is presumably a sequence of per-concept scalars
            (TODO confirm against data_utils); the default collate then yields a
            list of ``(batch,)`` tensors, which are stacked as columns.
        batch_size: number of images per forward pass.
        device: device the images are moved to; the outputs live on this device.

    Returns:
        ``(features, concepts)``: the image embeddings concatenated over all
        batches, and the matching concept labels as a float tensor with one
        column per concept.
    """
    all_features = []
    all_concepts = []

    # Feature extraction should be deterministic: disable train-time behaviour
    # (e.g. dropout) for the pass, then restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, _labels, concept_labels in tqdm(DataLoader(dataset, batch_size)):
                features = model.encode_image(images.to(device).float())
                all_features.append(features)
                # One (batch,) tensor per concept -> (batch, n_concepts) matrix.
                all_concepts.append(torch.stack(concept_labels, dim=1).float().to(device))
    finally:
        model.train(was_training)
    return torch.cat(all_features), torch.cat(all_concepts)

In [23]:
# Frozen CLIP image features + concept targets for each split (train, then val, then test).
(train_feats, train_c), (val_feats, val_c), (test_feats, test_c) = (
    get_clip_image_features(clip_model, split, batch_size, device)
    for split in (train_data, val_data, test_data)
)

100%|██████████| 38/38 [00:20<00:00,  1.88it/s]
100%|██████████| 10/10 [00:05<00:00,  1.96it/s]
100%|██████████| 46/46 [00:24<00:00,  1.90it/s]


In [24]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define Logistic Regression Model
class LogisticRegression(nn.Module):
    """Multi-label logistic-regression probe with one independent sigmoid per concept.

    A single affine layer maps ``input_dim`` features to ``n_concepts`` logits,
    which are squashed element-wise into probabilities by a sigmoid.
    """

    def __init__(self, input_dim, n_concepts):
        super().__init__()
        # Affine map: input_dim features -> one logit per concept.
        self.linear = nn.Linear(input_dim, n_concepts)
        # Parameter-free element-wise sigmoid applied to the logits.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        concept_logits = self.linear(x)
        return self.sigmoid(concept_logits)

# Train a per-concept logistic-regression probe on the frozen CLIP image features.
# Every concept is an independent binary target (multi-label BCE over all columns).
SEED = 0
NUM_EPOCHS = 2400
LOG_EVERY = 100
LEARNING_RATE = 1e-3
THRESHOLD = 0.5  # probability cut-off for a positive concept prediction

torch.manual_seed(SEED)  # reproducible weight initialisation across kernel restarts

input_dim = train_feats.shape[1]
n_concepts = train_c.shape[1]
model = LogisticRegression(input_dim, n_concepts).to(device)
# BCEWithLogitsLoss fuses the sigmoid into the loss and is numerically more stable
# than Sigmoid + BCELoss, so it is fed the raw logits from `model.linear`
# (calling `model(...)` would apply the sigmoid a second time).
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
num_epochs = NUM_EPOCHS


def _evaluate_concepts(probe, feats, targets, threshold=THRESHOLD):
    """Return (probabilities, binary predictions, mean element-wise accuracy).

    Accuracy is averaged over every (sample, concept) entry, i.e. it is the fraction
    of individual binary concept predictions that match `targets`.
    """
    probe.eval()
    with torch.no_grad():
        probs = probe(feats)
        preds = (probs >= threshold).float()
        acc = (preds == targets).float().mean()
    return probs, preds, acc


# Full-batch training: the whole training split is one batch per epoch.
for epoch in range(num_epochs):
    model.train()

    logits = model.linear(train_feats)
    loss = criterion(logits, train_c)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % LOG_EVERY == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")
        _, _, val_accuracy = _evaluate_concepts(model, val_feats, val_c)
        print(f"Val Accuracy: {val_accuracy.item():.4f}")

# Final evaluation on the held-out test split (`predictions` keeps the probabilities).
predictions, bin_predictions, accuracy = _evaluate_concepts(model, test_feats, test_c)
print(f"Test Accuracy: {accuracy.item():.4f}")

Epoch [100/2400], Loss: 0.3117
Val Accuracy: 0.8675
Epoch [200/2400], Loss: 0.2716
Val Accuracy: 0.8823
Epoch [300/2400], Loss: 0.2533
Val Accuracy: 0.8879
Epoch [400/2400], Loss: 0.2422
Val Accuracy: 0.8909
Epoch [500/2400], Loss: 0.2344
Val Accuracy: 0.8932
Epoch [600/2400], Loss: 0.2284
Val Accuracy: 0.8946
Epoch [700/2400], Loss: 0.2236
Val Accuracy: 0.8956
Epoch [800/2400], Loss: 0.2194
Val Accuracy: 0.8961
Epoch [900/2400], Loss: 0.2158
Val Accuracy: 0.8972
Epoch [1000/2400], Loss: 0.2127
Val Accuracy: 0.8976
Epoch [1100/2400], Loss: 0.2098
Val Accuracy: 0.8977
Epoch [1200/2400], Loss: 0.2072
Val Accuracy: 0.8982
Epoch [1300/2400], Loss: 0.2048
Val Accuracy: 0.8986
Epoch [1400/2400], Loss: 0.2026
Val Accuracy: 0.8990
Epoch [1500/2400], Loss: 0.2005
Val Accuracy: 0.8992
Epoch [1600/2400], Loss: 0.1985
Val Accuracy: 0.8993
Epoch [1700/2400], Loss: 0.1967
Val Accuracy: 0.8995
Epoch [1800/2400], Loss: 0.1950
Val Accuracy: 0.8995
Epoch [1900/2400], Loss: 0.1934
Val Accuracy: 0.8995
Ep

In [None]:
# Save the probe's weight matrix (nn.Linear layout: n_concepts x input_dim).
# torch.save needs a destination; calling it without one raises TypeError.
# The tensor is detached and moved to CPU so it loads without autograd state or a GPU.
os.makedirs("data/cub", exist_ok=True)
torch.save(model.linear.weight.detach().cpu(), "data/cub/clip_vit_b_32_probe_weight.pt")
#torch.save(predictions, "data/cub/clip_vit_b_32_probe_c_preds.pt")

In [35]:
# Persist the trained probe's linear layer as a whole pickled nn.Linear module
# (loading it therefore needs torch.nn importable on the reading side).
torch.save(obj=model.linear, f="data/cub_linear_probe.pth")