In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import weightwatcher as ww
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, random_split
from vendi_score import image_utils

from scripts.clip_rn50 import ClipFinetuner, train, CustomDataset, evaluate
from scripts.datasets import office_home, convert_bytes_to_images

In [2]:

# Hugging Face DomainNet mirror; files are cached under this local directory.
PATH = "dataset"
ds = load_dataset(path="wltjr1007/DomainNet", cache_dir=PATH)

In [None]:
# Split the DomainNet train split by domain id (0=clipart, 1=infograph, 2=painting, 5=sketch).
# `input_columns='domain'` hands only the domain value to the predicate, so the image
# column is not decoded for every row just to read its domain id.
train_split = ds['train']
train_clipart_images = train_split.filter(lambda domain: domain == 0, input_columns='domain')
train_infograph_images = train_split.filter(lambda domain: domain == 1, input_columns='domain')
train_painting_images = train_split.filter(lambda domain: domain == 2, input_columns='domain')
train_sketch_images = train_split.filter(lambda domain: domain == 5, input_columns='domain')

In [4]:
def _keep_first_34_classes(domain_split):
    """Return the examples of `domain_split` with label in [0, 33] as a list of row dicts.

    The label test runs through `filter(..., input_columns='label')`, so only the label
    column is read for the test. The previous row-by-row iteration decoded every image
    just to inspect its label. Only the kept rows are materialised, in their original order.
    """
    kept = domain_split.filter(lambda label: 0 <= label <= 33, input_columns='label')
    return list(kept)


clipart_images_filtered = _keep_first_34_classes(train_clipart_images)
infograph_images_filtered = _keep_first_34_classes(train_infograph_images)
painting_images_filtered = _keep_first_34_classes(train_painting_images)
sketch_images_filtered = _keep_first_34_classes(train_sketch_images)

In [9]:
def _column(rows, key):
    """Pull one field out of each example dict, keeping the original order."""
    return [row[key] for row in rows]


clipart_images = _column(clipart_images_filtered, 'image')
infograph_images = _column(infograph_images_filtered, 'image')
painting_images = _column(painting_images_filtered, 'image')
sketch_images = _column(sketch_images_filtered, 'image')

clipart_labels = _column(clipart_images_filtered, 'label')
infograph_labels = _column(infograph_images_filtered, 'label')
painting_labels = _column(painting_images_filtered, 'label')
sketch_labels = _column(sketch_images_filtered, 'label')


In [None]:
# The first 34 entries of the label feature's class names (labels 0..33 are kept above).
all_label_names = train_clipart_images.features['label'].names
classes = all_label_names[:34]
len(classes)

In [12]:
# CLIP's per-channel mean and std used for input normalisation.
clip_mean = (0.48145466, 0.4578275, 0.40821073)
clip_std = (0.26862954, 0.26130258, 0.27577711)

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # fixed 224x224 input size for CLIP (aspect ratio not preserved)
        transforms.ToTensor(),
        transforms.Normalize(clip_mean, clip_std),
    ]
)

In [13]:
# One CustomDataset per domain, sharing the same class list and CLIP transform.
clipart_dataset, infograph_dataset, painting_dataset, sketch_dataset = (
    CustomDataset(images=domain_images, labels=domain_labels, classes=classes, transform=transform)
    for domain_images, domain_labels in (
        (clipart_images, clipart_labels),
        (infograph_images, infograph_labels),
        (painting_images, painting_labels),
        (sketch_images, sketch_labels),
    )
)


In [14]:
# Unshuffled loaders so every evaluation pass sees the examples in the same order.
EVAL_BATCH_SIZE = 32
clipart_dataloader, infograph_dataloader, painting_dataloader, sketch_dataloader = (
    DataLoader(domain_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)
    for domain_dataset in (clipart_dataset, infograph_dataset, painting_dataset, sketch_dataset)
)

In [37]:

# num_classes matches the 34 labels (0..33) kept by the filtering cell.
NUM_CLASSES = 34
model = ClipFinetuner(num_classes=NUM_CLASSES)

In [38]:
# One {model_name: metric} dict per (domain, metric) pair; filled by the evaluation loop.
ood_acc_clipart, ood_loss_clipart, ood_ece_clipart = {}, {}, {}
ood_acc_infograph, ood_loss_infograph, ood_ece_infograph = {}, {}, {}
ood_acc_painting, ood_loss_painting, ood_ece_painting = {}, {}, {}
ood_acc_sketch, ood_loss_sketch, ood_ece_sketch = {}, {}, {}

In [None]:
# Sanity-check the checkpoint naming convention parsed by the evaluation loop:
# "ft_model_<name>.pth" is keyed as "<name>".
example_model_name = "ft_model_dp0.1.pth".split('_')[2].replace('.pth', '')
assert example_model_name == "dp0.1", example_model_name
example_model_name

In [None]:
# Evaluate every fine-tuned checkpoint in `ft_path` on each domain and record
# loss / accuracy / ECE into the per-domain result dicts, keyed by model name.
dataloaders = {
    'clipart': clipart_dataloader,
    'infograph': infograph_dataloader,
    'painting': painting_dataloader,
    'sketch': sketch_dataloader
}

# Per-domain (accuracy, loss, ece) result dicts; replaces the per-domain if-chain.
results_by_domain = {
    'clipart': (ood_acc_clipart, ood_loss_clipart, ood_ece_clipart),
    'infograph': (ood_acc_infograph, ood_loss_infograph, ood_ece_infograph),
    'painting': (ood_acc_painting, ood_loss_painting, ood_ece_painting),
    'sketch': (ood_acc_sketch, ood_loss_sketch, ood_ece_sketch),
}

ft_path = "domainnet/model"
device = 'cuda'
model = model.to(device)  # loop-invariant: move once, then load each checkpoint into it

# os.listdir order is arbitrary; sort for a reproducible evaluation / print order.
for model_file in sorted(os.listdir(ft_path)):
    if not model_file.endswith('.pth'):  # only checkpoint files
        continue
    model_path = os.path.join(ft_path, model_file)

    # "ft_model_<name>.pth" -> "<name>". maxsplit=2 keeps underscores inside <name>,
    # so e.g. "ft_model_dp_0.1.pth" and "ft_model_dp_0.2.pth" no longer both collapse
    # to "dp" and overwrite each other's results. Unexpected names fall back to the stem.
    stem = os.path.splitext(model_file)[0]
    name_parts = stem.split('_', 2)
    model_name = name_parts[2] if len(name_parts) == 3 else stem

    # NOTE(review): torch.load unpickles the file; only point ft_path at trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    print(f"Evaluating model: {model_name}")

    for dataset_name, dataloader in dataloaders.items():
        test_loss, test_acc, test_ece = evaluate(model, dataloader)
        print(f"{dataset_name.capitalize()}:")
        print(f"  Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}, ECE: {test_ece:.4f}")

        acc_results, loss_results, ece_results = results_by_domain[dataset_name]
        acc_results[model_name] = test_acc
        loss_results[model_name] = test_loss
        ece_results[model_name] = test_ece

    print("\n")  # Add a blank line between models for readability


In [64]:
# Wrap each {model_name: value} dict as a one-row DataFrame (one column per model).
# NOTE: the `*_art_to` frames hold the *painting* results; the "art" naming is presumably
# carried over from the Office-Home variant of this notebook and is kept unchanged.
(
    ood_acc_clipart_to, ood_loss_clipart_to, ood_ece_clipart_to,
    ood_acc_infograph_to, ood_loss_infograph_to, ood_ece_infograph_to,
    ood_acc_art_to, ood_loss_art_to, ood_ece_art_to,
    ood_acc_sketch_to, ood_loss_sketch_to, ood_ece_sketch_to,
) = (
    pd.DataFrame([metric_dict])
    for metric_dict in (
        ood_acc_clipart, ood_loss_clipart, ood_ece_clipart,
        ood_acc_infograph, ood_loss_infograph, ood_ece_infograph,
        ood_acc_painting, ood_loss_painting, ood_ece_painting,
        ood_acc_sketch, ood_loss_sketch, ood_ece_sketch,
    )
)