In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
from PIL import Image
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Criterion(nn.Module):
    """
    Batch-based classifcation loss
    """
    def __init__(self):
        super(Criterion, self).__init__()
    
    def forward(self, scores):
        return F.cross_entropy(
            scores, 
            torch.arange(scores.shape[0]).long().to(scores.device)
        )


class Combiner(nn.Module):
    """ TODO: Combiner module, which fuses textual and visual information.
    Given an image feature and a text feature, you should fuse them to get a fused feature. The dimension of the fused feature should be embed_dim.
    Hint: You can concatenate image and text features and feed them to a FC layer, or you can devise your own fusion module, e.g., add, multiply, or attention, to achieve a higher retrieval score.
    """
    def __init__(self, vision_feature_dim, text_feature_dim, embed_dim):
        super(Combiner, self).__init__()
        # Fully connected layer to project concatenated features to embed_dim
        self.fc = nn.Linear(vision_feature_dim + text_feature_dim, embed_dim)

    def forward(self, image_features, text_features):
        # Concatenate image and text features
        combined_features = torch.cat((image_features, text_features), dim=-1)
        # Pass through fully connected layer
        fused_features = self.fc(combined_features)
        return fused_features


class Model(nn.Module):
    """
    CLIP-based Composed Image Retrieval Model.
    """
    def __init__(self, vision_feature_dim, text_feature_dim, embed_dim):
        super(Model, self).__init__()
        self.vision_feature_dim = vision_feature_dim
        self.text_feature_dim = text_feature_dim
        self.embed_dim = embed_dim

        # Load clip model and freeze its parameters
        self.clip_model, self.preprocess = clip.load("ViT-B/32", device=device)
        for p in self.clip_model.parameters():
            p.requires_grad = False
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.combiner = Combiner(vision_feature_dim, text_feature_dim, embed_dim)
    
    def train(self):
        self.combiner.train()

    def eval(self):
        self.combiner.eval()
    
    def encode_image(self, image_paths):
        """ TODO: Encode images to get image features by the vision encoder of clip model. See https://github.com/openai/CLIP
        Note: The clip model has loaded in the __init__() function. You do not need to create and load it on your own.

        Args:
            Image_paths (list[str]): a list of image paths.
        
        Returns:
            vision_features (torch.Tensor): image features.
        """
        images = [self.preprocess(Image.open(path)).unsqueeze(0) for path in image_paths]
        images = torch.cat(images).to(device)

        with torch.no_grad():
            vision_features = self.clip_model.encode_image(images)
        
        return vision_features.float() # Convert to float32 data type

    def encode_text(self, texts):
        """ TODO: Encode texts to get text features by the text encoder of clip model. See https://github.com/openai/CLIP
        Note: The clip model has loaded in the __init__() function. You do not need to create and load it on your own.

        Args:
            texts (list[str]): a list of captions.
        
        Returns:
            text_features (torch.Tensor): text features.
        """
        tokens = clip.tokenize(texts).to(device)
        
        with torch.no_grad():
            text_features = self.clip_model.encode_text(tokens)

        return text_features.float()# Convert to float32 data type

    def inference(self, ref_image_paths, texts):
        with torch.no_grad():
            ref_vision_features = self.encode_image(ref_image_paths)
            text_features = self.encode_text(texts)
            fused_features = self.combiner(ref_vision_features, text_features)
        return fused_features
    
    def forward(self, ref_image_paths, texts, tgt_image_paths):
        """
        Args:
            ref_image_paths (list[str]): image paths of reference images.
            texts (list[str]): captions.
            tgt_image_paths (list[str]): image paths of reference images.
        
        Returns:
            scores (torch.Tensor): score matrix with shape batch_size * batch_size.
        """
        batch_size = len(ref_image_paths)

        # Extract vision and text features
        with torch.no_grad():
            ref_vision_features = self.encode_image(ref_image_paths)
            tgt_vision_features = self.encode_image(tgt_image_paths)
            text_features = self.encode_text(texts)
        assert ref_vision_features.shape == torch.Size([batch_size, self.vision_feature_dim])
        assert tgt_vision_features.shape == torch.Size([batch_size, self.vision_feature_dim])
        assert text_features.shape == torch.Size([batch_size, self.text_feature_dim])

        # Fuse vision and text features 
        fused_features = self.combiner(ref_vision_features, text_features)
        assert fused_features.shape == torch.Size([batch_size, self.embed_dim])

        # L2 norm
        fused_features = F.normalize(fused_features)
        tgt_vision_features = F.normalize(tgt_vision_features)

        # Calculate scores
        scores = self.temperature.exp() * fused_features @ tgt_vision_features.t()
        assert scores.shape == torch.Size([batch_size, batch_size])

        return scores

# Training function
def train(data_loader, model, criterion, optimizer, log_step=15):
    model.train()
    for i, (_, ref_img_paths, tgt_img_paths, raw_captions) in enumerate(data_loader):
        # Clear the previous gradients
        optimizer.zero_grad()

        # Forward pass through the model
        scores = model(ref_img_paths, raw_captions, tgt_img_paths)

        # Compute loss
        loss = criterion(scores)

        # Backpropagation
        loss.backward()

        # Optimizer step to update parameters
        optimizer.step()

        # Log the training loss every log_step
        if i % log_step == 0:
            print("training loss: {:.3f}".format(loss.item()))
            
# Validation function
def eval_batch(data_loader, model, ranker):
    model.eval()
    ranker.update_emb(model)
    rankings = []
    for meta_info, ref_img_paths, _, raw_captions in data_loader:
        with torch.no_grad():
            fused_features = model.inference(ref_img_paths, raw_captions)
            target_asins = [ meta_info[m]['target'] for m in range(len(meta_info)) ]
            rankings.append(ranker.compute_rank(fused_features, target_asins))
    metrics = {}
    rankings = torch.cat(rankings, dim=0)
    metrics['score'] = 1 - rankings.mean().item() / ranker.data_emb.size(0)
    model.train()
    return metrics

def val(data_loader, model, ranker, best_score):
    model.eval()
    metrics = eval_batch(data_loader, model, ranker)
    dev_score = metrics['score']
    best_score = max(best_score, dev_score)
    print('-' * 77)
    print('| score {:8.5f} / {:8.5f} '.format(dev_score, best_score))
    print('-' * 77)
    print('best_dev_score: {}'.format(best_score))
    return best_score

In [17]:
def predict_custom_input(model, ref_image_paths, texts, tgt_image_paths=None):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        # Encode reference images and text
        ref_vision_features = model.encode_image(ref_image_paths)
        text_features = model.encode_text(texts)
        
        if tgt_image_paths is not None:
            tgt_vision_features = model.encode_image(tgt_image_paths)
            fused_features = model.combiner(ref_vision_features, text_features)
            fused_features = F.normalize(fused_features)
            tgt_vision_features = F.normalize(tgt_vision_features)
            # Calculate similarity score
            scores = model.temperature.exp() * fused_features @ tgt_vision_features.t()
            return scores
        else:
            # Only return fused features if no target images are provided
            fused_features = model.combiner(ref_vision_features, text_features)
            return fused_features

In [36]:
def load_model(vision_feature_dim, text_feature_dim, embed_dim, model_path):
    model = Model(vision_feature_dim, text_feature_dim, embed_dim)
    model.load_state_dict(torch.load("/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/CIR/trained_model.pth", map_location=device))
    model.to(device)
    model.eval()
    return model


In [37]:
from resize_images import resize_images_parallel

# Define input directories and output directories for resized images
input_ref_dir = "/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/inputref"  # Directory containing your reference images
input_tgt_dir = "/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/outputref"  # Directory containing your target images
output_ref_dir = ""  # Directory to save resized reference images
output_tgt_dir = ""  # Directory to save resized target images
image_size = 256  # Size to resize images to (256x256 in this case)

# Resize reference images
resize_images_parallel(input_ref_dir, output_ref_dir, (image_size, image_size))

# Resize target images (if available)
resize_images_parallel(input_tgt_dir, output_tgt_dir, (image_size, image_size))

ModuleNotFoundError: No module named 'joblib'

In [39]:
# Full process of prediction using custom input
def main_predict():
    # Define your parameters
    vision_feature_dim = 512
    text_feature_dim = 512
    embed_dim = 512
    model_path = "path_to_your_trained_model.pth"
    
    # Load your model
    model = load_model(vision_feature_dim, text_feature_dim, embed_dim, 'CIR/trained_model.pth')

    # Define custom input
    ref_image_paths = ["/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/inputref/i1.png", "/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/inputref/i2.png"]
    texts = ["A man standing in front of a red car.", "A dog running on the beach during sunset."]
    
    # Optionally define target image paths if available
    tgt_image_paths = ["/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/outputref/q1.png", "/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/outputref/q2.png"]

    # Predict scores between reference and target images
    scores = predict_custom_input(model, ref_image_paths, texts, tgt_image_paths)
    
    print("Predicted scores:", scores)

if __name__ == "__main__":
    main_predict()


  model.load_state_dict(torch.load("/Users/ved/Desktop/Sem 1/Vision and Language/Fall24_CSE597_Homework1/CIR/trained_model.pth", map_location=device))


Predicted scores: tensor([[3.3151, 1.3381],
        [0.9653, 1.4226]])
