<a href="https://colab.research.google.com/github/Olalekan-Ojo/Generative-Ai/blob/main/ITM%20-%20Image%20Text%20Matching-%20with%20argumentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
################################################################################
# Image-Text Matching Classifier: baseline system for visual question answering
#
# This program has been adapted and rewriten from the CMP9137 materials of 2024.
#
# It treats the task of multi-choice visual question answering as a binary
# classification task. This is possible by rewriting the questions from this format:
# v7w_2358727.jpg	When was this?  Nighttime. | Daytime. | Dawn. Sunset.
#
# to the following format:
# v7w_2358727.jpg	When was this? Nighttime. 	match
# v7w_2358727.jpg	When was this?  Daytime. 	no-match
# v7w_2358727.jpg	When was this?  Dawn. 	no-match
# v7w_2358727.jpg	When was this?  Sunset.	no-match
#
# The list above contains the image file name, the question-answer pairs, and the labels.
# Only question types "when", "where" and "who" were used due to compute requirements. In
# this folder, files v7w.*Images.itm.txt are used and v7w.*Images.txt are ignored. The
# two formats are provided for your information and convenience.
#
# To enable the above this implementation provides the following classes and functions:
# - Class ITM_Dataset() to load the multimodal data (image & text (question and answer)).
# - Class Transformer_VisionEncoder() to create a pre-trained Vision Transformer, which
#   can be finetuned or trained from scratch -- update USE_PRETRAINED_MODEL accordingly.
# - Function load_sentence_embeddings() to load pre-generated sentence embeddings of questions
#   and answers, which were generated using SentenceTransformer('sentence-transformers/gtr-t5-large').
# - Class ITM_Model() to create a model combining the vision and text encoders above.
# - Function train_model trains/finetunes one of two possible models: CNN or ViT. The CNN
#   model is based on resnet18, and the Vision Transformer (ViT) is based on vit_b_32.
# - Function evaluate_model() calculates the accuracy of the selected model using test data.
# - The last block of code brings everything together calling all classes & functions above.
#
# info of resnet18: https://pytorch.org/vision/main/models/resnet.html
# info of vit_b_32: https://pytorch.org/vision/main/models/vision_transformer.html
# info of SentenceTransformer: https://huggingface.co/sentence-transformers/gtr-t5-large
#
# This program was tested on Windows 11 using WSL and does not generate any plots.
# Feel free to use and extend this program as part of your our assignment work.
#
# Version 1.0, main functionality in tensorflow tested with COCO data
# Version 1.2, extended functionality for Flickr data
# Version 1.3, ported to pytorch and tested with visual7w data
# Contact: {hcuayahuitl}@lincoln.ac.uk
################################################################################


In [18]:
import os
import time
import pickle
import torch
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.models import vit_b_32


In [3]:
# prompt: Simple code to import google drive folder

from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [7]:
# # prompt: i imported my google drive folder, i need to navidate to the folder ohis to my zipped dataset call ITM_classifier

# # Assuming your zipped dataset is in 'My Drive/ohis'
# zip_file_path = '/content/drive/MyDrive/ohis/ITM_Classifier-baselines.zip'

# # Check if the zip file exists
# if os.path.exists(zip_file_path):
#   print(f"Zip file found at: {zip_file_path}")

# else:
#   print(f"Error: Zip file not found at {zip_file_path}. Please check the path.")


Zip file found at: /content/drive/MyDrive/ohis/ITM_Classifier-baselines.zip


In [8]:
# # prompt: code to unzip my folder - zip_file_path

# import zipfile
# import os

# # Assuming your zipped dataset is in 'My Drive/ohis'
# zip_file_path = '/content/drive/MyDrive/ohis/ITM_Classifier-baselines.zip'

# # Check if the zip file exists
# if os.path.exists(zip_file_path):
#   print(f"Zip file found at: {zip_file_path}")
#   try:
#     with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
#       zip_ref.extractall('/content/drive/MyDrive/ohis/') # Extract to the same directory
#     print("Files extracted successfully!")

#   except zipfile.BadZipFile:
#     print(f"Error: The file at {zip_file_path} is not a valid zip file.")

#   except Exception as e:
#     print(f"An error occurred during extraction: {e}")
# else:
#   print(f"Error: Zip file not found at {zip_file_path}. Please check the path.")


Zip file found at: /content/drive/MyDrive/ohis/ITM_Classifier-baselines.zip
Files extracted successfully!


In [22]:
# Custom Dataset
class ITM_Dataset(Dataset):
    def __init__(self, images_path, data_file, sentence_embeddings, data_split, train_ratio=1.0):
        self.images_path = images_path
        self.data_file = data_file
        self.sentence_embeddings = sentence_embeddings
        self.data_split = data_split.lower()
        self.train_ratio = train_ratio if self.data_split == "train" else 1.0

        self.image_data = []
        self.question_data = []
        self.answer_data = []
        self.question_embeddings_data = []
        self.answer_embeddings_data = []
        self.label_data = []
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.RandomCrop(224),     # Random crop for training - changed
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard for pretrained models on ImageNet
        ]) if data_split == "train" else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.load_data()

In [23]:
# Custom Dataset
class ITM_Dataset(Dataset):
    def __init__(self, images_path, data_file, sentence_embeddings, data_split, train_ratio=1.0):
        self.images_path = images_path
        self.data_file = data_file
        self.sentence_embeddings = sentence_embeddings
        self.data_split = data_split.lower()
        self.train_ratio = train_ratio if self.data_split == "train" else 1.0

        self.image_data = []
        self.question_data = []
        self.answer_data = []
        self.question_embeddings_data = []
        self.answer_embeddings_data = []
        self.label_data = []
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard for pretrained models on ImageNet
        ])

        self.load_data()

    def load_data(self):
        print("LOADING data from "+str(self.data_file))
        print("=========================================")

        random.seed(42)

        with open(self.data_file) as f:
            lines = f.readlines()

            # Apply train_ratio only for training data
            if self.data_split == "train":
                random.shuffle(lines)  # Shuffle before selecting
                num_samples = int(len(lines) * self.train_ratio)
                lines = lines[:num_samples]

            for line in lines:
                line = line.rstrip("\n")
                img_name, text, raw_label = line.split("\t")
                img_path = os.path.join(self.images_path, img_name.strip())

                question_answer_text = text.split("?")
                question_text = question_answer_text[0].strip() + '?'
                answer_text = question_answer_text[1].strip()

                # Get binary labels from match/no-match answers
                label = 1 if raw_label == "match" else 0
                self.image_data.append(img_path)
                self.question_data.append(question_text)
                self.answer_data.append(answer_text)
                self.question_embeddings_data.append(self.sentence_embeddings[question_text])
                self.answer_embeddings_data.append(self.sentence_embeddings[answer_text])
                self.label_data.append(label)

        print("|image_data|="+str(len(self.image_data)))
        print("|question_data|="+str(len(self.question_data)))
        print("|answer_data|="+str(len(self.answer_data)))
        print("done loading data...")

    def __len__(self):
        return len(self.image_data)

    def __getitem__(self, idx):
        img_path = self.image_data[idx]
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        question_embedding = torch.tensor(self.question_embeddings_data[idx], dtype=torch.float32)
        answer_embedding = torch.tensor(self.answer_embeddings_data[idx], dtype=torch.float32)
        label = torch.tensor(self.label_data[idx], dtype=torch.long)
        return img, question_embedding, answer_embedding, label


In [24]:
# Load sentence embeddings from an existing file -- generated a priori
def load_sentence_embeddings(file_path):
    print("READING sentence embeddings...")
    with open(file_path, 'rb') as f:
        data = pickle.load(f)
    return data

# Pre-trained ViT model
class Transformer_VisionEncoder(nn.Module):
    def __init__(self, pretrained=None):
        super(Transformer_VisionEncoder, self).__init__()

        if pretrained:
            self.vision_model = vit_b_32(weights="IMAGENET1K_V1")
            # Freeze all layers initially
            for param in self.vision_model.parameters():
                param.requires_grad = False

            # Unfreeze the last two layers
            for param in list(self.vision_model.heads.parameters())[-2:]:
                param.requires_grad = True
        else:
            self.vision_model = vit_b_32(weights=None)  # Initialize without pretrained weights

        # Get feature size after initialising the model
        self.num_features = self.vision_model.heads[0].in_features

        # Remove original classification head
        self.vision_model.heads = nn.Identity()

    def forward(self, x):
        features = self.vision_model(x)  # Shape should be (batch_size, num_features)
        return features


In [25]:
# Image-Text Matching Model
class ITM_Model(nn.Module):
    def __init__(self, num_classes=2, ARCHITECTURE=None, PRETRAINED=None, dropout_rate=0.3): #changed
        print(f'BUILDING %s model, pretrained=%s' % (ARCHITECTURE, PRETRAINED))
        super(ITM_Model, self).__init__()
        self.ARCHITECTURE = ARCHITECTURE

        if self.ARCHITECTURE == "CNN":
            self.vision_model = models.resnet18(pretrained=PRETRAINED)
            if PRETRAINED:
			    # Freeze all layers
                for param in self.vision_model.parameters():
                    param.requires_grad = False
                # Unfreeze the last two layers
                for param in list(self.vision_model.children())[-2:]:
                    for p in param.parameters():
                        p.requires_grad = True
            else:
                for param in self.vision_model.parameters():
                    param.requires_grad = True
            self.vision_model.fc = nn.Linear(self.vision_model.fc.in_features, 128) # Change output

        elif self.ARCHITECTURE == "ViT":
            self.vision_model = Transformer_VisionEncoder(pretrained=PRETRAINED)
            self.fc_vit = nn.Linear(self.vision_model.num_features, 128)  # Reduce features

        else:
            print("UNKNOWN neural architecture", ARCHITECTURE)
            exit(0)

         # Enhanced feature processing
        self.question_embedding_layer = nn.Sequential(
            nn.Linear(768, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128)
        )

        self.answer_embedding_layer = nn.Sequential(
            nn.Linear(768, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128)
        )

        # Cross-attention mechanism
        self.cross_attention = nn.MultiheadAttention(128, num_heads=4, dropout=dropout_rate)

        # Final classification layers
        self.classifier = nn.Sequential(
            nn.Linear(128 * 3, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )

    def forward(self, img, question_embedding, answer_embedding):
        img_features = self.vision_model(img)
        if self.ARCHITECTURE == "ViT":
            img_features = self.fc_vit(img_features) # Use the custom linear layer for ViT
        question_features = self.question_embedding_layer(question_embedding)
        answer_features = self.answer_embedding_layer(answer_embedding)

        #---------------------------------------- changed
        # Apply cross-attention
        attn_output, _ = self.cross_attention(
            img_features.unsqueeze(0),
            torch.stack([question_features, answer_features], dim=0),
            torch.stack([question_features, answer_features], dim=0)
        )

        combined_features = torch.cat((
            img_features,
            question_features,
            answer_features
        ), dim=1)

        return self.classifier(combined_features)


In [26]:
def train_model(model, ARCHITECTURE, train_loader, criterion, optimizer, num_epochs=10):
    print(f'TRAINING {ARCHITECTURE} model')

    # Add learning rate scheduler
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=3e-4,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader)
    )

    # Add gradient scaler for mixed precision training
    scaler = torch.cuda.amp.GradScaler()

#----------------------------------------- Changed above

    # Track the overall loss for each epoch
    for epoch in range(num_epochs):
        running_loss = 0.0
        total_batches = len(train_loader)
        start_time = time.time()

        for batch_idx, (images, question_embeddings, answer_embeddings, labels) in enumerate(train_loader):
            # Move images/text/labels to the GPU (if available)
            images = images.to(device)
            question_embeddings = question_embeddings.to(device)
            answer_embeddings = answer_embeddings.to(device)
            labels = labels.to(device)

            # # Forward pass -- given input data to the model
            # outputs = model(images, question_embeddings, answer_embeddings)

            # # Calculate loss (error)
            # loss = criterion(outputs, labels)  # output should be raw logits

            # # Backward pass -- given loss above
            # optimiser.zero_grad() # clear the gradients
            # loss.backward() # computes gradient of the loss/error
            # optimiser.step() # updates parameters using gradients

            # Mixed precision training
            with torch.cuda.amp.autocast():
                outputs = model(images, question_embeddings, answer_embeddings)
                loss = criterion(outputs, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            running_loss += loss.item()

            #-------------------------changed above

            if batch_idx % 100 == 0:
              print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{total_batches}], '
                      f'Loss: {loss.item():.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')



        # Print average loss for the epoch
        print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx}/{total_batches}], '
                      f'Loss: {loss.item():.4f}, LR: {scheduler.get_last_lr()[0]:.6f}')

        avg_loss = running_loss / total_batches
        print(f'Epoch [{epoch + 1}/{num_epochs}] Average Loss: {avg_loss:.4f}')


In [27]:
def evaluate_model(model, ARCHITECTURE, test_loader, device):
    print(f'EVALUATING %s model' % (ARCHITECTURE))
    model.eval()
    total_test_loss = 0
    all_labels = []
    all_predictions = []
    start_time = time.time()

    with torch.no_grad():
        for images, question_embeddings, answer_embeddings, labels in test_loader:
            # Move images/text/labels to the GPU (if available)
            images = images.to(device)
            question_embeddings = question_embeddings.to(device)
            answer_embeddings = answer_embeddings.to(device)
            labels = labels.to(device)  # Labels are single integers (0 or 1)

            # Perform forward pass on our data
            outputs = model(images, question_embeddings, answer_embeddings)

            # Accumulate loss on test data
            total_test_loss += criterion(outputs, labels)

            # Since outputs are logits, apply softmax to get probabilities
            predicted_probabilities = torch.softmax(outputs, dim=1)  # Use softmax for multi-class output
            predicted_class = predicted_probabilities.argmax(dim=1)  # Get the predicted class index (0 or 1)

            # Store labels and predictions for later analysis
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted_class.cpu().numpy())

    # Convert to numpy arrays for easier calculations
    all_labels = np.array(all_labels)
    all_predictions = np.array(all_predictions)

    # Calculate true positives, true negatives, false positives, false negatives
    tp = np.sum((all_predictions == 1) & (all_labels == 1))  # True positives
    tn = np.sum((all_predictions == 0) & (all_labels == 0))  # True negatives
    fp = np.sum((all_predictions == 1) & (all_labels == 0))  # False positives
    fn = np.sum((all_predictions == 0) & (all_labels == 1))  # False negatives

    # Calculate sensitivity, specificity, and balanced accuracy
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    balanced_accuracy = (sensitivity + specificity) / 2.0

    elapsed_time = time.time() - start_time
    print(f'Balanced Accuracy: {balanced_accuracy:.4f}, {elapsed_time:.2f} seconds')
    print(f'Total Test Loss: {total_test_loss:.4f}')



In [28]:
# Main Execution
if __name__ == '__main__':
    # Check GPU availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    # Paths and files
    # Paths and files
    IMAGES_PATH = "/content/drive/MyDrive/ohis/visual7w-images"
    train_data_file = "/content/drive/MyDrive/ohis/visual7w-text/v7w.TrainImages.itm.txt"
    dev_data_file = "/content/drive/MyDrive/ohis/visual7w-text/v7w.DevImages.itm.txt"
    test_data_file = "/content/drive/MyDrive/ohis/visual7w-text/v7w.TestImages.itm.txt"
    sentence_embeddings_file = "/content/drive/MyDrive/ohis/v7w.sentence_embeddings-gtr-t5-large.pkl"
    sentence_embeddings = load_sentence_embeddings(sentence_embeddings_file)

    # Create datasets and loaders
    train_dataset = ITM_Dataset(IMAGES_PATH, train_data_file, sentence_embeddings, data_split="train", train_ratio=0.2)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    test_dataset = ITM_Dataset(IMAGES_PATH, test_data_file, sentence_embeddings, data_split="test")  # whole test data
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    # The dev set is not used in this program and you should/could use it for example to optimise your hyperparameters
    #dev_dataset = ITM_Dataset(images_path, "dev_data.txt", sentence_embeddings, data_split="dev")  # whole dev data

    # Create the model using one of the two supported architectures
    MODEL_ARCHITECTURE = "ViT"  # ViT often performs better for this task
    USE_PRETRAINED_MODEL = True

    # Create model with dropout
    model = ITM_Model(
        num_classes=2,
        ARCHITECTURE=MODEL_ARCHITECTURE,
        PRETRAINED=USE_PRETRAINED_MODEL,
        dropout_rate=0.3
    ).to(device)

    # Use a weighted loss function to handle class imbalance
    class_weights = torch.tensor([1.0, 2.0]).to(device)  # Adjust based on your class distribution
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Use AdamW with weight decay
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-5,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )


    # Increase epochs and use larger batch size
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    train_model(model, MODEL_ARCHITECTURE, train_loader, criterion, optimizer, num_epochs=20)
    evaluate_model(model, MODEL_ARCHITECTURE, test_loader, device)


Using device: cuda
READING sentence embeddings...
LOADING data from /content/drive/MyDrive/ohis/visual7w-text/v7w.TrainImages.itm.txt
|image_data|=9780
|question_data|=9780
|answer_data|=9780
done loading data...
LOADING data from /content/drive/MyDrive/ohis/visual7w-text/v7w.TestImages.itm.txt




|image_data|=5980
|question_data|=5980
|answer_data|=5980
done loading data...
BUILDING ViT model, pretrained=True
TRAINING ViT model


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Epoch [1/20], Batch [0/306], Loss: 0.8167, LR: 0.000012
Epoch [1/20], Batch [100/306], Loss: 0.7500, LR: 0.000014
Epoch [1/20], Batch [200/306], Loss: 0.6718, LR: 0.000020
Epoch [1/20], Batch [300/306], Loss: 0.5642, LR: 0.000031
Epoch [1/20], Batch [305/306], Loss: 0.8502, LR: 0.000031
Epoch [1/20] Average Loss: 0.6839
Epoch [2/20], Batch [0/306], Loss: 0.7288, LR: 0.000031
Epoch [2/20], Batch [100/306], Loss: 0.7569, LR: 0.000046
Epoch [2/20], Batch [200/306], Loss: 0.6028, LR: 0.000063
Epoch [2/20], Batch [300/306], Loss: 0.4362, LR: 0.000083
Epoch [2/20], Batch [305/306], Loss: 0.7542, LR: 0.000084
Epoch [2/20] Average Loss: 0.6100
Epoch [3/20], Batch [0/306], Loss: 0.4442, LR: 0.000084
Epoch [3/20], Batch [100/306], Loss: 0.6364, LR: 0.000107
Epoch [3/20], Batch [200/306], Loss: 0.6359, LR: 0.000130
Epoch [3/20], Batch [300/306], Loss: 0.5446, LR: 0.000155
Epoch [3/20], Batch [305/306], Loss: 0.6230, LR: 0.000156
Epoch [3/20] Average Loss: 0.5580
Epoch [4/20], Batch [0/306], Loss:

In [31]:
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk

In [None]:
def predict_time_of_day(model, image_path, sentence_embeddings, device):
    """
    Predicts the time of day for a given image
    """
    # Image preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load and preprocess image
    try:
        image = Image.open(image_path).convert('RGB')
        image = transform(image).unsqueeze(0).to(device)  # Add batch dimension
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

    # Fixed question and possible answers
    question = "When was this?"
    possible_answers = ["Nighttime", "Daytime", "Dawn", "Sunset"]

    # Set model to evaluation mode
    model.eval()

    # Get question embedding (same for all predictions)
    question_embedding = torch.tensor(sentence_embeddings[question], dtype=torch.float32)
    question_embedding = question_embedding.unsqueeze(0).to(device)

    # Test each possible answer
    results = []
    for answer in possible_answers:
        # Get answer embedding
        answer_embedding = torch.tensor(sentence_embeddings[answer], dtype=torch.float32)
        answer_embedding = answer_embedding.unsqueeze(0).to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(image, question_embedding, answer_embedding)
            probabilities = torch.softmax(outputs, dim=1)
            prediction = torch.argmax(probabilities, dim=1)
            confidence = probabilities[0][prediction[0]].item()

        results.append((answer, confidence, prediction.item()))

    # Find the best matching answer
    best_match = max(results, key=lambda x: x[1] if x[2] == 1 else 0)
    return best_match[0], best_match[1]  # Return time of day and confidence

def interactive_image_testing(model, sentence_embeddings, device):
    """
    Interactive loop for testing new images
    """
    print("\n=== Time of Day Image Predictor ===")
    print("Enter 'quit' to exit")

    while True:
        # Get image path
        image_path = input("\nEnter the path to your image: ").strip()
        if image_path.lower() == 'quit':
            break

        # Make prediction
        try:
            time_of_day, confidence = predict_time_of_day(model, image_path, sentence_embeddings, device)

            # Display results
            print("\nResults:")
            print("-" * 50)
            print(f"Image: {image_path}")
            print(f"Predicted Time of Day: {time_of_day}")
            print(f"Confidence: {confidence:.2%}")
            print("-" * 50)
        except Exception as e:
            print(f"Error processing image: {e}")

# Add this at the end of your main block:
if __name__ == '__main__':
    # ... (your existing training code) ...

    # After training and evaluation, add:
    print("\nWould you like to test new images? (yes/no)")
    response = input().strip().lower()
    if response == 'yes':
        interactive_image_testing(model, sentence_embeddings, device)


Would you like to test new images? (yes/no)
yes 

=== Time of Day Image Predictor ===
Enter 'quit' to exit

Enter the path to your image: /content/img__01327_.png
Error processing image: 'Nighttime'

Enter the path to your image: /content/img__01327_.png
Error processing image: 'Nighttime'

Enter the path to your image: /content/Firefly 20250127010512.png
Error processing image: 'Nighttime'
