In [13]:
from snorkel.labeling import labeling_function
import json
import os
import numpy as np
from snorkel.labeling import LFApplier
from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import LabelModel
from snorkel.analysis import metric_score
from snorkel.utils import probs_to_preds
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
from transformers import CLIPProcessor, CLIPModel

In [8]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
import numpy as np

def calculate_metrics(y_true, y_pred, abstain_class=-1):
    # Filter out samples where prediction is -1
    valid_indices = y_pred != abstain_class
    y_true_filtered = y_true[valid_indices]
    y_pred_filtered = y_pred[valid_indices]

    # Compute metrics
    precision = precision_score(y_true_filtered, y_pred_filtered, average='macro')
    recall = recall_score(y_true_filtered, y_pred_filtered, average='macro')
    f1 = f1_score(y_true_filtered, y_pred_filtered, average='macro')
    accuracy = accuracy_score(y_true_filtered, y_pred_filtered)

    return {
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1,
        'Accuracy': accuracy
    }

In [2]:

@labeling_function()
def llava_7b(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter/'
    llava_7b_results = 'eurosat_llava7b.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)

@labeling_function()
def llava_13b(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter'
    llava_7b_results = 'eurosat_llava13b.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)

@labeling_function()
def bakllava(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter'
    llava_7b_results = 'eurosat_bakllava.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)

@labeling_function()
def llava_llama3(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter'
    llava_7b_results = 'eurosat_llava_llama3.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)

@labeling_function()
def llava_phi3(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter'
    llava_7b_results = 'eurosat_llava_phi3.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)


@labeling_function()
def moondream(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter'
    llava_7b_results = 'eurosat_moondream.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)

@labeling_function()
def llava_34b(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter'
    llava_7b_results = 'eurosat_llava34b.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)

@labeling_function()
def minicpm(image_name):
    root_path = '../prompting_framework/prompting_results/eurosat/interpreter'
    llava_7b_results = 'eurosat_minicpm.json'
    path_to_llava_7b_results = os.path.join(root_path,llava_7b_results)
    with open(path_to_llava_7b_results, 'r') as file:
        data = json.load(file)

    return data.get(image_name, -1)

In [3]:
train_data_json_path = '../prompting_framework/prompting_results/eurosat/interpreter/train_gt.json'
dev_data_json_path = '../prompting_framework/prompting_results/eurosat/interpreter/test_gt.json'

with open(train_data_json_path, 'r') as file:
    train_data = json.load(file)

# Extract and pad image names, ensuring they are 5 digits long before the '.png'
train_image_names = []
for item in train_data:
    train_image_names.append(item)

with open(dev_data_json_path, 'r') as file:
    dev_data = json.load(file)
    
dev_image_names = []
Y_dev = []
for item in dev_data:
    Y_dev.append(dev_data[item])
    dev_image_names.append(item)

print(f"There are {len(train_image_names)} images in the Train set.")
print(f"There are {len(dev_image_names)} images in the dev set.")
print(f"There are {len(Y_dev)} labels in the dev set.")


There are 13500 images in the Train set.
There are 8100 images in the dev set.
There are 8100 labels in the dev set.


In [4]:
from snorkel.labeling import LFApplier

list_of_all_the_models = ['llava_34b',
       'llava_13b',
       'llava_phi3',
       'llava_7b',
       'llava_llama3',
       'minicpm',
       'bakllava'
       ]

lfs = [llava_34b,
       llava_13b,
       llava_phi3,
       llava_7b,
       llava_llama3,
       minicpm,
       bakllava
       ]

applier = LFApplier(lfs)

In [5]:
from snorkel.labeling import LFAnalysis

L_dev = applier.apply(dev_image_names)
L_train = applier.apply(train_image_names)

8100it [01:59, 67.71it/s]
13500it [03:28, 64.87it/s]


In [10]:
Y_dev = np.array(Y_dev)
LFAnalysis(L_dev, lfs).lf_summary(Y_dev)

Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.
llava_34b,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]",0.901481,0.901358,0.765679,4462,2840,0.611065
llava_13b,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]",0.915062,0.914938,0.765309,4360,3052,0.588235
llava_phi3,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]",0.340988,0.337654,0.280247,1085,1677,0.392831
llava_7b,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]",0.973704,0.965679,0.814444,4367,3520,0.553696
llava_llama3,4,"[1, 3, 4, 5, 6, 7, 8, 9]",0.908025,0.905185,0.756667,3016,4339,0.410061
minicpm,5,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]",0.772222,0.768272,0.645802,2414,3841,0.385931
bakllava,6,"[1, 3, 4, 5, 6, 7, 8, 9]",0.187654,0.186914,0.174568,935,585,0.615132


In [11]:
from snorkel.labeling.model import LabelModel
from snorkel.analysis import metric_score
from snorkel.utils import probs_to_preds

label_model = LabelModel(cardinality=10, verbose=False)
label_model.fit(L_train, Y_dev, n_epochs=5000, log_freq=500, seed=12345)



probs_dev = label_model.predict_proba(L_dev)
preds_dev = probs_to_preds(probs_dev)

metrics = calculate_metrics(Y_dev, preds_dev)
for metric, value in metrics.items():
    print(f"{metric}: {value}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:01<00:00, 2819.32epoch/s]


Precision: 0.6121511260118815
Recall: 0.6347555555555555
F1 Score: 0.6080876216368022
Accuracy: 0.6227160493827161


In [14]:
import torch.nn.functional as F

def expected_cross_entropy_loss(logits, target_distributions):
    """
    Computes the expected cross-entropy loss for a batch of predictions and target distributions.

    Parameters:
    logits (torch.Tensor): The raw output from the model of shape (batch_size, num_classes).
    target_distributions (torch.Tensor): The target class distributions of shape (batch_size, num_classes),
                                         where each row is a probability distribution over classes.

    Returns:
    torch.Tensor: The expected cross-entropy loss.
    """
    # Convert logits to log probabilities
    log_probs = F.log_softmax(logits, dim=1)
    
    # Compute the element-wise product between target distributions and log probabilities
    # Then, sum across classes to get the cross-entropy for each instance
    cross_entropy = -torch.sum(target_distributions * log_probs, dim=1)
    
    # Take the mean over the batch
    loss = cross_entropy.mean()
    
    return loss
    
class EuroSAT_Dataset(Dataset):
    def __init__(self, image_names, root_dir, labels, target_dists, processor):
        """
        Args:
            data_frame (DataFrame): DataFrame containing image names and labels.
            image_dir (str): Directory where the images are stored.
            processor (CLIPProcessor): CLIP processor for preprocessing images.
        """
        self.image_names = image_names
        self.root_dir = root_dir
        self.labels = labels
        self.target_dists = target_dists
        self.processor = processor

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # Get image name and label from the dataframe
        img_name = os.path.join(self.root_dir, self.image_names[idx])
        label = self.labels[idx]
        target_dist = self.target_dists[idx]

        # Load and process image
        image = Image.open(img_name).convert('RGB')
        inputs = self.processor(images=image, return_tensors="pt")

        # Return image and label
        return inputs['pixel_values'].squeeze(0), torch.tensor(label, dtype=torch.long), torch.tensor(target_dist)

# MLP head to be added after the CLIP model
class MLPHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLPHead, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        return self.fc(x)

# CLIP model with MLP head for binary classification
class CLIPWithMLP(nn.Module):
    def __init__(self, clip_model, mlp_head):
        super(CLIPWithMLP, self).__init__()
        self.clip_model = clip_model
        self.mlp_head = mlp_head

        # Freeze CLIP's parameters
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def forward(self, image):
        # Extract image features from CLIP
        image_features = self.clip_model.get_image_features(pixel_values=image)
        # Pass through the MLP head
        outputs = self.mlp_head(image_features)
        return outputs

# Training function
def train_model(model, train_loader, dev_loader, criterion, optimizer, device, epochs=5):
    model.train()  # Set model to training mode
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels, target_dist in tqdm(train_loader):
            images, labels, target_dist = images.to(device), labels.to(device), target_dist.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # loss = expected_cross_entropy_loss(outputs, target_dist)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")
        evaluate_model(model, dev_loader, device)

# Evaluation function to compute precision, recall, and F1-score
def evaluate_model(model, dev_loader, device):
    model.eval()  # Set model to evaluation mode
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels, target_dist in tqdm(dev_loader):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    # Compute metrics
    metrics = calculate_metrics(np.array(all_labels), np.array(all_preds))
    for metric, value in metrics.items():
        print(f"{metric}: {value}")

    return metrics

In [16]:
root_dir = "/home1/pupil/goowfd/CVPR_2025/Eurosat_all_images/"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

probs_train = label_model.predict_proba(L_train)
preds_train = probs_to_preds(probs_train)

# Create datasets and dataloaders
train_dataset = EuroSAT_Dataset(image_names=train_image_names, 
                                    root_dir=root_dir, 
                                    labels=preds_train,
                                    target_dists=probs_train,
                                    processor=processor)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=16)

dev_dataset = EuroSAT_Dataset(image_names=dev_image_names, 
                                  root_dir=root_dir, 
                                  labels=Y_dev, 
                                  target_dists=probs_dev,
                                  processor=processor)
dev_loader = DataLoader(dev_dataset, batch_size=8, shuffle=False)

# Define MLP head (the dimension is based on CLIP output size)
mlp_head = MLPHead(input_dim=512, output_dim=2)  # Binary classification, so output_dim = 2

# Create the full model with CLIP + MLP
model = CLIPWithMLP(clip_model=clip_model, mlp_head=mlp_head)
model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.mlp_head.parameters(), lr=0.0001)

# Train the model
epochs = 10
train_model(model, train_loader, dev_loader, criterion, optimizer, device, epochs=epochs)

# Evaluate the model
evaluate_model(model, dev_loader, device)


  0%|                                                                                                                    | 0/106 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [4,0,0] Assertion `t >= 0 && t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
