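## About the notebook

Fine-tunes a previously saved ViT image-classification model to predict either roof pitch or roof material (selected with the `target` variable below), then evaluates the fine-tuned model on the test split.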

In [0]:
!pip install -U datasets
!pip install geopandas

In [0]:
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer, DefaultDataCollator
from torchvision.transforms import RandomResizedCrop, Compose, Resize, Normalize, ToTensor, RandomRotation
import matplotlib.pyplot as plt
import evaluate
import geopandas as gpd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, ConfusionMatrixDisplay

In [0]:
# Classification target for this run: 'pitch' (roof pitch) or 'material' (roof material).
target = 'pitch'

In [0]:
# Read the tile annotations, keep only the rows assigned to the train or test
# split, and retain just the columns this notebook uses.
labels = gpd.read_file("/Volumes/prd_datascience_depcaribbeansids/volumes/depcaribbeansids/tiles_VCT.geojson")
split_mask = labels["roof_material_dataset"].isin(["train", "test"])
kept_columns = ["filename", "roof_material", "roof_pitch", "roof_material_dataset"]
labels = labels[split_mask][kept_columns]

# Build label <-> id lookup tables for both targets. Ids are the string form of
# each class's position in `unique()`, i.e. the order of first appearance.
# NOTE(review): confirm this order matches the integer class ids stored in
# `dataset`; nothing in this notebook enforces that.
_material_classes = list(labels.roof_material.unique())
label2id_material = {name: str(idx) for idx, name in enumerate(_material_classes)}
id2label_material = {str(idx): name for idx, name in enumerate(_material_classes)}

_pitch_classes = list(labels.roof_pitch.unique())
label2id_pitch = {name: str(idx) for idx, name in enumerate(_pitch_classes)}
id2label_pitch = {str(idx): name for idx, name in enumerate(_pitch_classes)}

In [0]:
# Load the image dataset that belongs to `target`; each target has its own
# `dataset_<target>/` folder. Fail fast on an unknown target instead of leaving
# `dataset` undefined (or silently holding the value from an earlier run).
if target not in ('material', 'pitch'):
  raise ValueError(f"Unsupported target {target!r}; expected 'material' or 'pitch'.")
dataset = load_dataset(f"/Volumes/prd_datascience_depcaribbeansids/volumes/depcaribbeansids/dataset_{target}/")

In [0]:
# Hub id of the ViT checkpoint whose image-processor configuration is used for
# preprocessing (its size / mean / std are read in the transform setup below).
# The model weights themselves are loaded later from a local directory, not
# from this checkpoint.
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [0]:
# Per-channel normalisation built from the image processor's statistics.
# NOTE(review): `normalize` is defined but is not part of the pipeline below,
# so images reach the model un-normalised -- confirm this is intended.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Target size: a single int when the processor specifies "shortest_edge",
# otherwise a (height, width) pair.
_processor_size = image_processor.size
if "shortest_edge" in _processor_size:
    size = _processor_size["shortest_edge"]
else:
    size = (_processor_size["height"], _processor_size["width"])

# Image -> tensor pipeline: resize, random rotation, convert to tensor.
_transforms = Compose(
    [
        Resize(size),
        RandomRotation(360),
        ToTensor(),
    ]
)

In [0]:
def transforms(examples):
    """Turn a batch's "image" entries into "pixel_values" in place.

    Each item under "image" is converted to RGB and passed through
    `_transforms`; the results are stored under "pixel_values", the "image"
    key is removed, and the same (mutated) batch object is returned.
    """
    pixel_values = []
    for picture in examples["image"]:
        rgb_picture = picture.convert("RGB")
        pixel_values.append(_transforms(rgb_picture))
    examples["pixel_values"] = pixel_values
    del examples["image"]
    return examples

In [0]:
# Attach `transforms` to the whole `dataset` object, so the "train" and "test"
# splits used below both go through the same function.
# NOTE(review): that includes the random rotation in `_transforms`, so the
# test split is also randomly rotated -- confirm this is intended.
dataset_transformed = dataset.with_transform(transforms)

In [0]:
# Preview one transformed training example. The tensor's axes are reordered
# from (0, 1, 2) to (1, 2, 0) before it is handed to imshow.
sample_pixels = dataset_transformed['train'][7]['pixel_values']
plt.imshow(sample_pixels.permute(1, 2, 0))

In [0]:
# Collator with default settings; passed to the Trainer below as `data_collator`.
data_collator = DefaultDataCollator()

In [0]:
# Load the four `evaluate` metrics in one pass. Only `f1` is used by the
# Trainer's `compute_metrics` in this notebook.
accuracy, f1, precision, recall = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "f1", "precision", "recall")
)

In [0]:
def compute_metrics(eval_pred):
    """Return the weighted F1 score for one set of evaluation predictions.

    `eval_pred` unpacks into (predictions, labels). The predicted class of each
    row is the index of its largest value along axis 1; the result is whatever
    the module-level `f1` metric's `compute` returns for those class ids.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return f1.compute(predictions=predicted_ids, references=references, average="weighted")

In [0]:
# Load the saved ViT classifier and configure its classification head for the
# selected target. Both targets start from the same saved model directory and
# differ only in the label column and the label/id maps.
_model_dir = '/Volumes/prd_datascience_depcaribbeansids/volumes/depcaribbeansids/model_vit/model/'
_target_settings = {
    'material': (labels.roof_material, id2label_material, label2id_material),
    'pitch': (labels.roof_pitch, id2label_pitch, label2id_pitch),
}
# Fail fast on an unknown target instead of leaving `model` undefined (or
# silently reusing a model built for a different target in an earlier run).
if target not in _target_settings:
    raise ValueError(f"Unsupported target {target!r}; expected 'material' or 'pitch'.")
_target_column, _id2label, _label2id = _target_settings[target]

model = AutoModelForImageClassification.from_pretrained(
    _model_dir,
    num_labels=len(_target_column.unique()),
    id2label=_id2label,
    label2id=_label2id,
)

In [0]:
# Checkpoints for this run go to a target-specific folder.
output_model_dir = f"/Volumes/prd_datascience_depcaribbeansids/volumes/depcaribbeansids/model_vit/model_ft_{target}"

training_args = TrainingArguments(
    # Output and reporting
    output_dir=output_model_dir,
    push_to_hub=False,
    report_to='none',
    logging_steps=10,
    # Keep all dataset columns; `transforms` reads the "image" column.
    remove_unused_columns=False,
    # Optimisation: 16 samples per step x 4 accumulation steps per device
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=10,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    # Evaluate every epoch, save the best checkpoint by eval F1, reload it at the end
    eval_strategy="epoch",
    save_strategy="best",
    metric_for_best_model="eval_f1",
    load_best_model_at_end=True,
)

In [0]:
# Wire model, data, and metrics into the Trainer: train on the "train" split,
# evaluate on the "test" split.
train_split = dataset_transformed["train"]
eval_split = dataset_transformed["test"]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=image_processor,
)

In [0]:
# Run fine-tuning with the settings in `training_args` (evaluation every epoch,
# best checkpoint by eval_f1 reloaded at the end).
trainer.train()

In [0]:
def compute_metrics(pred):
  """Summarise predictions as accuracy plus weighted precision, recall and F1.

  Reads the true class ids from `pred.label_ids` and takes the predicted class
  as the argmax over the last axis of `pred.predictions`. Returns a dict with
  the keys 'accuracy', 'f1', 'precision' and 'recall'.

  Note: this rebinds the name `compute_metrics` defined earlier in the notebook.
  """
  y_true = pred.label_ids
  y_pred = pred.predictions.argmax(-1)
  weighted_precision, weighted_recall, weighted_f1, _support = precision_recall_fscore_support(
      y_true, y_pred, average='weighted'
  )
  return {
      'accuracy': accuracy_score(y_true, y_pred),
      'f1': weighted_f1,
      'precision': weighted_precision,
      'recall': weighted_recall,
  }

In [0]:
# Predict on the test split, then summarise the predictions with `compute_metrics`.
test_split = dataset_transformed["test"]
preds = trainer.predict(test_split)
metrics = compute_metrics(preds)

In [0]:
# Confusion matrix of true (rows) vs. predicted (columns) class ids on the test split.
# The class names below only apply to the pitch target; for any other target
# fall back to numeric class ids rather than mislabelling the axes (or failing
# on a label-count mismatch).
display_labels = ['flat', 'gable', 'hip', 'no_roof'] if target == 'pitch' else None

cm = confusion_matrix(preds.label_ids, preds.predictions.argmax(-1))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
disp.plot()
# The title follows `target`; for target == 'pitch' it reads "Pitch Classification ...".
plt.title(f"{target.capitalize()} Classification F1 score: {metrics['f1']:.3f}")
plt.xticks(rotation=90)
plt.show()

In [0]:
# Class distribution for the selected target: the `roof_pitch` column when
# target == 'pitch', `roof_material` when target == 'material'.
labels[f"roof_{target}"].value_counts()