In [1]:
import os
import clip
import torch
from torchvision import datasets, transforms
from PIL import Image
import csv
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-L/14@336px', device)

# Set up the dataset
dataset_path = 'E:/iith-dl-contest-2024/train/train'
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    # transforms.ToTensor()
])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
class_names = dataset.classes

100%|███████████████████████████████████████| 891M/891M [09:58<00:00, 1.56MiB/s]


In [49]:

class SimpleImageDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                                             on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.file_paths = [os.path.join(dp, f) for dp, dn, filenames in os.walk(root_dir) for f in filenames if f.endswith(('.png', '.jpg', '.jpeg', '.JPEG'))]
    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = Image.open(img_path).convert('RGB')  # Convert to RGB
        
        if self.transform:
            image = self.transform(image)
        
        # Extract filename from img_path
        filename = os.path.basename(img_path)
        
        return image, filename  # Returns the image and its filename


In [13]:
class SimpleImageDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = [f for f in os.listdir(root_dir) if os.path.isfile(os.path.join(root_dir, f))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image, self.image_files[idx]

In [14]:
transform_val = transforms.Compose([
    transforms.Resize(64),  # Resize to 64x64 to match training
    # transforms.CenterCrop(64),  # Crop to 64x64 to match training
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # transforms.ToTensor(),
])


csv_dir = 'iith-dl-contest-2024\\test\\test'
csv_dataset = SimpleImageDataset( csv_dir, transform=transform_val)

In [2]:
def load_class_ids(file_path):
    class_id_to_name = {}
    with open(file_path, 'r') as file:
        for line in file:
            class_id, name = line.strip().split(': ')
            class_id_to_name[class_id] = name
    return class_id_to_name

# Assuming 'classnName_iith_dataset.txt' is in the current directory
class_ids = load_class_ids('classnName_iith_dataset.txt')
print(class_ids)

{'n01443537': 'goldfish', 'n01774750': 'tarantula', 'n01784675': 'centipede', 'n01882714': 'koala', 'n01910747': 'jellyfish', 'n01944390': 'snail', 'n01983481': 'American_lobster', 'n02056570': 'king_penguin', 'n02085620': 'Chihuahua', 'n02094433': 'Yorkshire_terrier', 'n02099601': 'golden_retriever', 'n02099712': 'Labrador_retriever', 'n02106662': 'German_shepherd', 'n02190166': 'fly', 'n02206856': 'bee', 'n02226429': 'grasshopper', 'n02233338': 'cockroach', 'n02236044': 'mantis', 'n02268443': 'dragonfly', 'n02279972': 'monarch', 'n02364673': 'guinea_pig', 'n02395406': 'hog', 'n02410509': 'bison', 'n02423022': 'gazelle', 'n02480495': 'orangutan', 'n02481823': 'chimpanzee', 'n02486410': 'baboon', 'n02769748': 'backpack', 'n02793495': 'barn', 'n02802426': 'basketball', 'n02808440': 'bathtub', 'n02814860': 'beacon', 'n02843684': 'birdhouse', 'n02906734': 'broom', 'n02948072': 'candle', 'n02950826': 'cannon', 'n03424325': 'gasmask', 'n03649909': 'lawn_mower', 'n04133789': 'sandal', 'n0414

In [None]:
def get_label(class_index):
    class_id = dataset.classes[class_index]  # Get the class ID using index
    return class_ids.get(class_id, "Unknown class ID")

# Example usage
for idx, (image, label_idx) in enumerate(dataset):
    label = get_label(label_idx)
    print(f"Image {idx} is labeled as: {label}")


In [None]:
import matplotlib.pyplot as plt

def show_image(image, label):
    # Convert tensor image to numpy
    image = image.numpy().transpose((1, 2, 0))
    plt.imshow(image)
    plt.title(label)
    plt.show()

# Show the first few images
for image, label_idx in dataset:
    label = get_label(label_idx)
    show_image(image, label)
    break  # Remove this to show more images


In [5]:
text_prompts = [f"a photo of a {name}" for name in class_ids.values()]
text_tokens = clip.tokenize(text_prompts).to(device)
print(text_prompts)

['a photo of a goldfish', 'a photo of a tarantula', 'a photo of a centipede', 'a photo of a koala', 'a photo of a jellyfish', 'a photo of a snail', 'a photo of a American_lobster', 'a photo of a king_penguin', 'a photo of a Chihuahua', 'a photo of a Yorkshire_terrier', 'a photo of a golden_retriever', 'a photo of a Labrador_retriever', 'a photo of a German_shepherd', 'a photo of a fly', 'a photo of a bee', 'a photo of a grasshopper', 'a photo of a cockroach', 'a photo of a mantis', 'a photo of a dragonfly', 'a photo of a monarch', 'a photo of a guinea_pig', 'a photo of a hog', 'a photo of a bison', 'a photo of a gazelle', 'a photo of a orangutan', 'a photo of a chimpanzee', 'a photo of a baboon', 'a photo of a backpack', 'a photo of a barn', 'a photo of a basketball', 'a photo of a bathtub', 'a photo of a beacon', 'a photo of a birdhouse', 'a photo of a broom', 'a photo of a candle', 'a photo of a cannon', 'a photo of a gasmask', 'a photo of a lawn_mower', 'a photo of a sandal', 'a pho

In [3]:
def classify_image(image_tensor):
    image_tensor = preprocess(image_tensor).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)
        logits_per_image = (image_features @ text_features.T).softmax(dim=-1)        
        predicted_class_index = logits_per_image.argmax().item()
    predicted_class_id = class_names[predicted_class_index]
    return class_ids[predicted_class_id], predicted_class_id



In [None]:
correct_predictions = 0
total_images = 0

for idx, (image, label_idx) in enumerate(dataset):
    actual_class_id = class_names[label_idx]
    actual_class_name = class_ids[actual_class_id]
    predicted_class_name, predicted_class_id = classify_image(image)

    # Check if the predicted class ID matches the actual class ID
    if predicted_class_id == actual_class_id:
        correct_predictions += 1

    print(f"Image {idx}: Actual Class - {actual_class_name} ({actual_class_id}), Predicted Class - {predicted_class_name} ({predicted_class_id})")
    
    total_images += 1
    # if idx == 100:  # Limiting the loop to the first 10 images for brevity
    #     break



In [37]:
# Calculate and print the accuracy
if total_images > 0:
    accuracy = (correct_predictions / total_images) * 100
    print(f"Accuracy: {accuracy:.2f}%")
else:
    print("No images processed.")

Accuracy: 87.43%


In [6]:
correct_predictions = 0
total_images = 0

# Open a file to save the logs
with open('OpenAICLILP Logs\predicted_output_vit_L_14@336px_64_OpenAI_CLIP.txt', 'w') as log_file:
    for idx, (image, label_idx) in enumerate(dataset):
        actual_class_id = class_names[label_idx]
        actual_class_name = class_ids[actual_class_id]
        predicted_class_name, predicted_class_id = classify_image(image)

        # Check if the predicted class ID matches the actual class ID
        if predicted_class_id == actual_class_id:
            correct_predictions += 1

        # Log the result for each image
        print(f"Image {idx}: Actual Class - {actual_class_name} ({actual_class_id}), Predicted Class - {predicted_class_name} ({predicted_class_id})")

        log_entry = f"Image {idx}: Actual Class - {actual_class_name} ({actual_class_id}), Predicted Class - {predicted_class_name} ({predicted_class_id})\n"
        log_file.write(log_entry)
        
        total_images += 1
        # if idx == 10:  # Limiting the loop to the first 10 images for brevity
        #     break

    # Calculate and log the accuracy
    if total_images > 0:
        accuracy = (correct_predictions / total_images) * 100
        accuracy_entry = f"Accuracy: {accuracy:.2f}%\n"
        print(f"Accuracy: {accuracy:.2f}%")
        log_file.write(accuracy_entry)
    else:
        log_file.write("No images processed.\n")

print("Classification and logging complete. See 'classification_log.txt' for details.")


Image 0: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 1: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 2: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 3: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 4: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 5: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 6: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 7: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 8: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 9: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 10: Actual Class - goldfish (n01443537), Predicted Class - goldfish (n01443537)
Image 11: Actual Class - goldfish (n01443537), Predicted Class -

# Output Writer On TEST Set

In [None]:
# Process images and save results
results = []

for idx, (image, label_idx) in enumerate(csv_dataset):
    predicted_class_name, predicted_class_id = classify_image(image)
    results.append((f"{idx}.JPEG", predicted_class_id))
    if idx % 100 == 0:
        print(f"Processed {idx} images...")

# Write results to a CSV file
with open('predicted_output.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['ID', 'Category'])  # Write header
    writer.writerows(results)

print("All images processed and results saved to 'predicted_output.csv'.")

In [15]:
results = []
for image, filename in csv_dataset:
    predicted_class_name, predicted_class_id = classify_image(image)
    results.append((filename, predicted_class_id))
    print(f"Processed {filename}...")

# Write results to a CSV file
with open('predicted_output_vit_L_14_64_openAI_CLIP.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['ID', 'Category'])  # Write header
    writer.writerows(results)


Processed 0.JPEG...
Processed 1.JPEG...
Processed 10.JPEG...
Processed 100.JPEG...
Processed 1000.JPEG...
Processed 10000.JPEG...
Processed 10001.JPEG...
Processed 10002.JPEG...
Processed 10003.JPEG...
Processed 10004.JPEG...
Processed 10005.JPEG...
Processed 10006.JPEG...
Processed 10007.JPEG...
Processed 10008.JPEG...
Processed 10009.JPEG...
Processed 1001.JPEG...
Processed 10010.JPEG...
Processed 10011.JPEG...
Processed 10012.JPEG...
Processed 10013.JPEG...
Processed 10014.JPEG...
Processed 10015.JPEG...
Processed 10016.JPEG...
Processed 10017.JPEG...
Processed 10018.JPEG...
Processed 10019.JPEG...
Processed 1002.JPEG...
Processed 10020.JPEG...
Processed 10021.JPEG...
Processed 10022.JPEG...
Processed 10023.JPEG...
Processed 10024.JPEG...
Processed 10025.JPEG...
Processed 10026.JPEG...
Processed 10027.JPEG...
Processed 10028.JPEG...
Processed 10029.JPEG...
Processed 1003.JPEG...
Processed 10030.JPEG...
Processed 10031.JPEG...
Processed 10032.JPEG...
Processed 10033.JPEG...
Processed