# C3 Project 1: Image Classification

In [None]:
import glob
import os
from typing import Any, List, Literal, Tuple, Type

import numpy as np
import tqdm
import wandb
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

from bovw import BOVW

# Root folder of the dataset split. The trailing slash matters: the split
# names ("train", "test") are appended by plain string concatenation.
SPLIT_PATH = "../data/MIT_split/"

## Load the data

In [2]:
# Load both splits eagerly as lists of (image, label) pairs.
# NOTE: `Dataset` is defined in the "Functions" cell further down, so that
# cell has to be executed before this one.
data_train, data_test = (
    Dataset(ImageFolder=SPLIT_PATH + split) for split in ("train", "test")
)

## Functions

In [None]:
def extract_bovw_histograms(bovw: BOVW, descriptors: List[np.ndarray]) -> np.ndarray:
    """Encode each image's local descriptors with the BoVW codebook.

    Args:
        bovw: Bag-of-visual-words model. Its ``codebook_algo`` is passed to
            ``_compute_codebook_descriptor`` and is expected to be fitted
            already (``train`` does this via ``_update_fit_codebook``).
        descriptors: One entry per image, holding the local descriptors that
            ``bovw._extract_features`` returned for it (presumably a
            ``(T, d)`` array per image -- TODO confirm against ``BOVW``).

    Returns:
        Array with one ``_compute_codebook_descriptor`` result per image, in
        the same order as ``descriptors``.
    """
    histograms = [
        bovw._compute_codebook_descriptor(
            descriptors=image_descriptors, kmeans=bovw.codebook_algo
        )
        for image_descriptors in descriptors
    ]
    return np.array(histograms)


def test(dataset: List[Tuple[Image.Image, int]],
         bovw: BOVW,
         classifier: Any) -> float:
    """Evaluate a fitted classifier on a dataset and return its accuracy.

    Args:
        dataset: ``(image, label)`` pairs, e.g. as returned by ``Dataset``.
        bovw: Bag-of-visual-words model used to extract descriptors and to
            build the histograms; its codebook is expected to be fitted.
        classifier: Fitted estimator exposing ``predict``.

    Returns:
        Accuracy over the images for which descriptors could be extracted.

    Raises:
        ValueError: If no image in ``dataset`` yields any descriptors.
    """
    test_descriptors = []
    descriptors_labels = []

    # Iterate over the samples directly; tqdm still takes the total from
    # len(dataset).
    for image, label in tqdm.tqdm(dataset, desc="Phase [Eval]: Extracting the descriptors"):
        _, descriptors = bovw._extract_features(image=np.array(image))

        # Images without descriptors are dropped together with their label,
        # so they do not count towards the reported accuracy.
        if descriptors is not None:
            test_descriptors.append(descriptors)
            descriptors_labels.append(label)

    if not test_descriptors:
        # Fail early with a clear message instead of a shape error raised
        # deep inside the classifier.
        raise ValueError("No descriptors could be extracted from the evaluation dataset")

    print("Computing the bovw histograms")
    bovw_histograms = extract_bovw_histograms(descriptors=test_descriptors, bovw=bovw)

    print("predicting the values")
    y_pred = classifier.predict(bovw_histograms)

    acc = accuracy_score(y_true=descriptors_labels, y_pred=y_pred)
    print("Accuracy on Phase[Test]:", acc)

    return acc

def train(dataset: List[Tuple[Image.Image, int]],
          bovw: BOVW) -> Tuple[BOVW, LogisticRegression, float]:
    """Fit the BoVW codebook and a logistic-regression classifier.

    Args:
        dataset: ``(image, label)`` pairs, e.g. as returned by ``Dataset``.
        bovw: Bag-of-visual-words model; its codebook is fitted here through
            ``_update_fit_codebook``.

    Returns:
        ``(bovw, classifier, acc)``: the same ``bovw`` object that was passed
        in, the fitted classifier, and the accuracy on the training images
        for which descriptors could be extracted.

    Raises:
        ValueError: If no image in ``dataset`` yields any descriptors.
    """
    all_descriptors = []
    all_labels = []

    # Iterate over the samples directly; tqdm still takes the total from
    # len(dataset).
    for image, label in tqdm.tqdm(dataset, desc="Phase [Training]: Extracting the descriptors"):
        _, descriptors = bovw._extract_features(image=np.array(image))

        # Images without descriptors are dropped together with their label.
        if descriptors is not None:
            all_descriptors.append(descriptors)
            all_labels.append(label)

    if not all_descriptors:
        # Nothing to cluster or classify: fail early with a clear message.
        raise ValueError("No descriptors could be extracted from the training dataset")

    print("Fitting the codebook")
    # The return value is not needed here; the histograms below read the
    # codebook from bovw.codebook_algo.
    bovw._update_fit_codebook(descriptors=all_descriptors)

    print("Computing the bovw histograms")
    bovw_histograms = extract_bovw_histograms(descriptors=all_descriptors, bovw=bovw)

    print("Fitting the classifier")
    classifier = LogisticRegression(class_weight="balanced").fit(bovw_histograms, all_labels)

    acc = accuracy_score(y_true=all_labels, y_pred=classifier.predict(bovw_histograms))
    print("Accuracy on Phase[Train]:", acc)

    return bovw, classifier, acc


def Dataset(ImageFolder: str = SPLIT_PATH + "train") -> List[Tuple[Image.Image, int]]:
    """Load every ``.jpg`` image of a class-per-folder dataset into memory.

    Expected structure::

        ImageFolder/<cls label>/xxx1.jpg
        ImageFolder/<cls label>/xxx2.jpg
        ImageFolder/<cls label>/xxx3.jpg
        ...

        Example:
            ImageFolder/cat/123.jpg
            ImageFolder/cat/nsdf3.jpg

    Only ``.jpg`` files located directly inside a class folder are read;
    nested sub-folders are not searched.

    Class indices follow the alphabetical order of the class folder names,
    so two splits that contain the same class folders (e.g. train and test)
    map every class to the same index.

    Args:
        ImageFolder: Directory containing one sub-directory per class.

    Returns:
        A list of ``(image, class_index)`` pairs, with every image converted
        to RGB.
    """
    # os.listdir returns entries in arbitrary order, so sort them to get a
    # reproducible class -> index mapping, and skip stray files (for example
    # .DS_Store) that would otherwise take up a class index.
    class_names = sorted(
        entry
        for entry in os.listdir(ImageFolder)
        if os.path.isdir(os.path.join(ImageFolder, entry))
    )

    dataset: List[Tuple[Image.Image, int]] = []

    for class_index, cls_folder in enumerate(class_names):
        pattern = os.path.join(ImageFolder, cls_folder, "*.jpg")
        for img_path in sorted(glob.glob(pattern)):
            # convert() returns an independent in-memory image, so the file
            # handle opened by Image.open can be closed right away.
            with Image.open(img_path) as img:
                dataset.append((img.convert("RGB"), class_index))

    return dataset

### Experiment 1

In [None]:
# HYPERPARAMETERS
# Feature detector name, forwarded to BOVW(detector_type=...).
detector_type = "SIFT"
# Forwarded to BOVW(codebook_size=...); presumably the number of visual
# words in the codebook -- confirm against the BOVW implementation.
codebook_size = 100
# Label recorded in the wandb config only: train() always fits a
# LogisticRegression, whatever this string says.
classifier = "Logistic Regression"

## Run experiment

In [None]:
# Start a new wandb run to track this experiment. Using the run as a context
# manager makes sure it is finished even if training or evaluation raises.
with wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="project-5",
    # Set the wandb project where this run will be logged.
    project="project-1",
    # Track hyperparameters and run metadata.
    config={
        "detector_type": detector_type,
        "codebook_size": codebook_size,
        "classifier": classifier,
    },
) as run:
    # Original note: codebook_kwargs and detector_kwargs can be used for
    # passing further hyperparameters (not verified here -- see BOVW).
    bovw = BOVW(detector_type=detector_type, codebook_size=codebook_size)

    # Keep the fitted model under its own name. Rebinding `classifier` would
    # overwrite the hyperparameter string, and re-running this cell would
    # then log the estimator object in the run config.
    bovw, fitted_classifier, train_acc = train(dataset=data_train, bovw=bovw)
    test_acc = test(dataset=data_test, bovw=bovw, classifier=fitted_classifier)

    run.log({"train_acc": train_acc, "test_acc": test_acc})