In [1]:
import sys
# !{sys.executable} -m pip install timm
# !{sys.executable} -m pip install torchvision
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
import pandas as pd
import torch
import numpy as np
from torch import nn
import timm
import ast
import tqdm

from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from torch.utils.data import TensorDataset, DataLoader

# CHANGE HERE #######################################################################################################
# Root folder of the experiment CSVs; predictions are written below BASE_DIR / "predictions".
BASE_DIR = Path("CRC100K")     # SET BASE_DIR HERE
# Folder containing the PNG images that the experiment CSVs refer to.
IMAGE_DIR = Path("./data/CRC-VAL-HE-7K-png")  # SET IMAGE_DIR HERE
# Sub-folders of BASE_DIR to process, e.g. ["binary", "complete", "knn"].
DATA_DIRECTORIES = ["knn"]                                                     # SET DATA DIRECTORIES HERE
# CHANGE HERE #######################################################################################################

# MAKE NO CHANGES HERE
# timm model names; each one is trained and evaluated on every dataset.
VISION_MODELS = ["resnet18", "resnet50", "tiny_vit_21m_224", "vit_small_patch8_224"]

# Pick the fastest available device: CUDA first, then Apple MPS, otherwise CPU.
# (Previously CUDA was never considered, so GPU hosts silently ran on the CPU.)
if torch.cuda.is_available():
    BACKEND = torch.device("cuda")
elif torch.backends.mps.is_available():
    BACKEND = torch.device("mps")
else:
    BACKEND = torch.device("cpu")

NUM_EPOCHS = 10  # passes over each experiment's training set
LR = 1e-3        # Adam learning rate

print(f"Data directories: {DATA_DIRECTORIES}")

In [2]:
########################################## DONT CHANGE ##########################################

def _load_rgb_image(img_path, img_transform):
    """Open ``IMAGE_DIR / img_path``, force 3-channel RGB and apply ``img_transform``."""
    # The context manager releases the file handle deterministically. Converting to
    # RGB lets palette, grayscale or RGBA PNGs pass through the same transform as
    # ordinary RGB images (an RGB image is unaffected by the conversion).
    with Image.open(IMAGE_DIR / img_path) as img:
        return img_transform(img.convert("RGB"))


def run_train_test(experiment, VISION_MODEL, BACKEND, test_only=False):
    """Train a fresh pretrained model on one experiment and classify its test image.

    The function relies on module-level names that the driver loop sets before each
    call: ``num_classes``, ``label_enc`` and ``label_dec`` (label <-> index maps),
    plus the configuration constants ``IMAGE_DIR``, ``LR`` and ``NUM_EPOCHS``.

    Args:
        experiment: Row of the experiment table. Must provide ``"label"`` and
            ``"img_path"``; unless ``test_only`` is set it must also provide
            ``"train_data"``, a mapping from image path to label name.
        VISION_MODEL: timm model name passed to ``timm.create_model``.
        BACKEND: ``torch.device`` on which the model and tensors are placed.
        test_only: If true, skip training and only predict the test image.

    Returns:
        Tuple ``(predicted_label, final_loss)``. ``final_loss`` is the loss of the
        last training step as a float, or ``None`` when no training step ran.

    Raises:
        ValueError: If training is requested but ``train_data`` is empty.
        KeyError: If a label is missing from ``label_enc``.
    """
    # NOTE(review): with num_classes given, timm presumably re-initialises the
    # classifier head, so test_only predictions would come from an untrained
    # head - confirm this is the intended "zero shot" baseline.
    model = timm.create_model(VISION_MODEL, pretrained=True, num_classes=num_classes)
    print("Model # Parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
    model = model.to(BACKEND)
    # Use the preprocessing (resize / normalisation) that belongs to this model.
    img_transform = create_transform(**resolve_data_config({}, model=model))

    test_label_idx = label_enc[experiment["label"]]
    test_image = _load_rgb_image(experiment["img_path"], img_transform)

    final_loss = None
    if not test_only:
        train_data = experiment["train_data"]
        if not train_data:
            raise ValueError("experiment['train_data'] is empty; nothing to train on")

        train_paths, train_label_names = zip(*train_data.items())
        train_images = torch.stack([_load_rgb_image(path, img_transform) for path in train_paths])
        train_labels = torch.tensor([label_enc[name] for name in train_label_names])

        train_dataset = TensorDataset(train_images, train_labels)
        # One batch holds the complete (small) training set, i.e. full-batch training.
        train_dataloader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=LR)

        model = model.train()
        for epoch in range(NUM_EPOCHS):
            for images, labels in train_dataloader:
                images, labels = images.to(BACKEND), labels.to(BACKEND)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                # Read the scalar once; it is both logged and returned.
                final_loss = loss.item()
                sys.stdout.write(f"\rEpoch {epoch} Loss: {final_loss}")
                sys.stdout.flush()
                optimizer.step()
        sys.stdout.flush()

    model = model.eval()
    with torch.inference_mode():
        test_image = test_image.to(BACKEND)
        # Add a batch dimension for the forward pass and drop it again afterwards.
        y_pred = model(test_image.unsqueeze(0)).squeeze(0)
        y_pred_label = label_dec[y_pred.argmax().item()]

    sys.stdout.write(f"Predicted: {y_pred_label}, Ground Truth: {label_dec[test_label_idx]}")
    sys.stdout.flush()
    return y_pred_label, final_loss


# Driver: for every model and data directory, run each experiment CSV through
# run_train_test and store the predictions next to a printed accuracy summary.
for VISION_MODEL in VISION_MODELS:
    print(f"Running {VISION_MODEL}...")

    for DATA_DIR in DATA_DIRECTORIES:
        print(f"Running {DATA_DIR}...")

        data_root = BASE_DIR / DATA_DIR
        predictions_dir = BASE_DIR / "predictions" / DATA_DIR / VISION_MODEL
        predictions_dir.mkdir(exist_ok=True, parents=True)

        # Sorted so that every run processes the CSVs in the same order.
        training_csvs = sorted(data_root.glob("**/*.csv"))

        for train_csv in tqdm.tqdm(training_csvs, desc="Training Datasets", colour="blue"):
            # Files with "zero_shot" in their name are evaluated without training.
            is_zero_shot = "zero_shot" in train_csv.name

            print("Training run on ", train_csv.name)
            gpt_results = pd.read_csv(train_csv, index_col=0, header=0)
            if not is_zero_shot:
                # train_data is stored as a Python literal string in the CSV.
                gpt_results["train_data"] = gpt_results["train_data"].map(ast.literal_eval)
            gpt_results["img_path"] = gpt_results["fname"] + ".png"

            # These four names are read as globals by run_train_test, so they
            # must be (re)defined here before the experiments of this CSV run.
            classes = sorted(set(gpt_results["label"].to_list()))
            num_classes = len(classes)
            label_enc = {label: i for i, label in enumerate(classes)}
            label_dec = {i: label for label, i in label_enc.items()}

            # Collect results positionally and assign them in one go: row-wise
            # .loc writes would overwrite every row sharing a duplicated index
            # label, and would leave the columns missing for an empty CSV.
            answers, final_losses = [], []
            for _, experiment in tqdm.tqdm(gpt_results.iterrows(), total=len(gpt_results),
                                           desc="Training Combinations",
                                           leave=False, colour="red"):
                y_pred_label, final_loss = run_train_test(experiment, VISION_MODEL, BACKEND,
                                                          test_only=is_zero_shot)
                answers.append(y_pred_label)
                final_losses.append(final_loss)

            gpt_results["answer"] = answers
            gpt_results["final_loss"] = final_losses

            # Mirror the input folder layout below the predictions directory.
            output_file = predictions_dir / train_csv.relative_to(data_root)
            output_file.parent.mkdir(exist_ok=True, parents=True)
            gpt_results.to_csv(output_file)

            accuracy = (gpt_results["label"] == gpt_results["answer"]).mean()
            print(f"{train_csv.name}: {accuracy:.2f}")

        print(f"####################### Finished {DATA_DIR}. #######################")
    print(f"####################### Finished {VISION_MODEL}. #######################")
print(f"####################### Finished Run. #######################")