In [None]:
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
import pandas as pd
import torch
import numpy as np
from torch import nn

import sys
sys.path.append("..")
from fewshot_histo.fewshot.feature_extractors import load_feature_extractor, FEATURE_EXTRACTORS


# Root folder holding the PNG tiles and the "binary" / "predictions" subfolders.
DATA_DIR = Path("/app/data/CRC-100K-VIT/")

# Set feature extractor here!
FEATURE_EXTRACTOR = "phikon"
# FEATURE_EXTRACTOR = "vit"
# FEATURE_EXTRACTOR = "resnet50"

# Embedding width assumed for the chosen backbone: the extractors named in the
# set are treated as 2048-wide, every other extractor as 768-wide.
FEATURE_DIM = 768 if FEATURE_EXTRACTOR not in {"resnet50", "retccl", "bt", "swav", "mocov2"} else 2048

# Number of full-batch optimisation steps used when fitting each linear head.
NUM_EPOCHS = 10

# The linear heads trained below are randomly initialised; seeding torch's
# global RNG once here makes a "Restart Kernel -> Run All" reproduce the same
# predictions and accuracies.
SEED = 0
torch.manual_seed(SEED)

print(f"Available feature extractors: {', '.join(FEATURE_EXTRACTORS.keys())}")
print(f"Using feature extractor: {FEATURE_EXTRACTOR}")

In [None]:
# Instantiate the backbone selected in the configuration cell.
feature_extractor = load_feature_extractor(FEATURE_EXTRACTOR)

# Every PNG located directly inside DATA_DIR (the pattern is not recursive).
images = [*DATA_DIR.glob("*.png")]

In [None]:
to_tensor = transforms.ToTensor()

# Extract features: maps each image file name to the extractor output for that
# image with the leading batch axis squeezed away.
features = {}
with torch.no_grad():
    for image in tqdm(images, desc="Extracting features"):
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it as soon as the pixels have been copied to a tensor.
        with Image.open(image) as img:
            # Force 3 channels so palette / grayscale / RGBA PNGs do not reach
            # the extractor with an unexpected channel count (a no-op for
            # images that are already RGB).
            # NOTE(review): assumes the extractor expects RGB input - confirm.
            pixels = to_tensor(img.convert("RGB")).unsqueeze(0)
        features[image.name] = feature_extractor(pixels).squeeze(0)

In [None]:
def run_train_experiment(train_features, test_feature, num_epochs=None):
    """Fit a linear probe on a few labelled features and classify one query.

    A fresh ``nn.Linear`` head is trained full-batch with Adam (lr=0.001) and
    cross-entropy loss, then used to predict the label of ``test_feature``.

    Args:
        train_features: Non-empty sequence of ``(feature, label)`` pairs. All
            features must have the same shape so they can be stacked; labels
            must be hashable and sortable (sorting fixes the class order).
        test_feature: Feature tensor of the query image, shaped like a single
            training feature.
        num_epochs: Number of optimisation steps. ``None`` (the default) falls
            back to the notebook-level ``NUM_EPOCHS`` constant.

    Returns:
        A ``(predicted_label, final_loss)`` tuple. ``final_loss`` is the
        training loss computed in the last epoch (before that epoch's weight
        update), or NaN when no epoch was run.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCHS

    # Class index <-> label mapping; the sorted list doubles as the decoder.
    classes = sorted({label for _, label in train_features})
    label_enc = {label: idx for idx, label in enumerate(classes)}

    train_feats = torch.stack([feat for feat, _ in train_features])
    device = train_feats.device
    train_labels_enc = torch.tensor(
        [label_enc[label] for _, label in train_features], device=device
    )

    # Size the head from the features themselves instead of a global lookup
    # table, so an extractor with a different embedding width still works.
    # The head is placed on the same device as the features it is trained on.
    model = nn.Linear(train_feats.shape[-1], len(classes)).to(device)

    # Train model
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    loss = None  # stays None when num_epochs == 0
    for _ in range(num_epochs):
        optimizer.zero_grad()
        outputs = model(train_feats)
        loss = criterion(outputs, train_labels_enc)
        loss.backward()
        optimizer.step()

    # Test model
    model.eval()
    with torch.no_grad():
        y_pred = model(test_feature.to(device).unsqueeze(0)).squeeze(0)
        y_pred_label = classes[y_pred.argmax().item()]

    final_loss = loss.item() if loss is not None else float("nan")
    return y_pred_label, final_loss

In [None]:
# Input CSVs live under DATA_DIR/binary; per-extractor predictions are written
# to a mirrored folder tree under DATA_DIR/predictions/<extractor>.
binary_dir = DATA_DIR / "binary"
predictions_dir = DATA_DIR / "predictions" / FEATURE_EXTRACTOR
predictions_dir.mkdir(exist_ok=True, parents=True)

# Sorted so files are processed (and accuracies printed) in a stable order.
gpt_results_files = sorted(binary_dir.glob("**/*.csv"))

print(f"Accuracy on GPT results files using linear classification head on features extracted by {feature_extractor.name}:")

for gpt_results_file in gpt_results_files:
    # Zero-shot result files are skipped (presumably because they carry no
    # training examples to fit a linear head on).
    if "zero_shot" in gpt_results_file.name:
        continue
    gpt_results = pd.read_csv(gpt_results_file, index_col=0, header=0)
    # NOTE(review): eval() executes whatever expression is stored in the CSV.
    # If "train_data" always holds a plain dict literal, ast.literal_eval
    # would be a safe replacement - confirm the column format before switching.
    gpt_results["train_data"] = gpt_results["train_data"].map(eval)

    # Collect results in lists and assign the columns once after the loop:
    # writing via .loc[i, ...] while iterating would hit every row sharing a
    # duplicated index label and would leave the columns undefined for an
    # empty file.
    pred_labels = []
    final_losses = []
    for _, experiment in gpt_results.iterrows():
        test_image = experiment["fname"] + ".png"
        train_images = experiment["train_data"]

        # Get features of training and test images
        train_features = [
            (features[image], label)
            for image, label in train_images.items()
        ]
        test_feature = features[test_image]

        # Run experiment
        pred_label, final_loss = run_train_experiment(train_features, test_feature)
        pred_labels.append(pred_label)
        final_losses.append(final_loss)

    gpt_results["pred_label"] = pred_labels
    gpt_results["final_loss"] = final_losses

    # Mirror the input's relative path below the predictions folder.
    output_file = predictions_dir / gpt_results_file.relative_to(binary_dir)
    output_file.parent.mkdir(exist_ok=True, parents=True)
    gpt_results.to_csv(output_file)

    accuracy = (gpt_results["label"] == gpt_results["pred_label"]).mean()
    print(f"{gpt_results_file.name}: {accuracy:.2f}")