# Requirements

In [None]:
!pip install astroNN transformers datasets accelerate evaluate albumentations 

# Imports

In [None]:
import os
import random

random.seed(1337)

import ssl

import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
from sklearn.model_selection import train_test_split

ssl._create_default_https_context = ssl._create_stdlib_context

import torch
import albumentations as A
from astroNN.datasets import galaxy10
from datasets import load_dataset, load_metric
from transformers import (
    Trainer,
    TrainingArguments,
    AutoFeatureExtractor,
    AutoModelForImageClassification,
)

# Data Prep

In [None]:
import shutil

# Start from a clean slate: remove the image folders and checkpoints left by a
# previous run. `ignore_errors=True` keeps this best-effort (a missing
# directory is fine) without a bare `except:` that would also swallow
# KeyboardInterrupt / SystemExit.
for stale_dir in ("./data/", "./vit-base-beans/"):
    shutil.rmtree(stale_dir, ignore_errors=True)

In [None]:
# Load the Galaxy10 images and labels through astroNN, then hold out 20% of the
# samples as a test set. The split is stratified on the labels and uses a
# fixed random_state so it is reproducible.
images, labels = galaxy10.load_data()
x_train, x_test, y_train, y_test = train_test_split(
    images, labels, test_size=0.2, random_state=1337, stratify=labels, shuffle=True
)

# Human-readable class names. Only the length of this list is used below (one
# output directory per class id); presumably the list position matches the
# integer class id in `labels` -- confirm against the Galaxy10 documentation.
features = [
    "Disk, Face-on, No Spiral",
    "Smooth, Completely round",
    "Smooth, in-between round",
    "Smooth, Cigar shaped",
    "Disk, Edge-on, Rounded Bulge",
    "Disk, Edge-on, Boxy Bulge",
    "Disk, Edge-on, No Bulge",
    "Disk, Face-on, Tight Spiral",
    "Disk, Face-on, Medium Spiral",
    "Disk, Face-on, Loose Spiral",
]
# Create ./data/<split>/<class_id>/ for every class so the tree can be read
# back as an image-folder dataset. `exist_ok=True` tolerates directories that
# already exist while still surfacing real failures (e.g. permissions), which
# the previous bare `except: pass` silently hid.
for class_id in range(len(features)):
    for split_name in ("train", "test"):
        os.makedirs(f"./data/{split_name}/{class_id}", exist_ok=True)

# Write every sample as an RGB PNG into the directory of its class. The file
# name is the sample's index within its split.
for split_name, split_images, split_labels in (
    ("train", x_train, y_train),
    ("test", x_test, y_test),
):
    for i, (pixels, label) in enumerate(zip(split_images, split_labels)):
        im = Image.fromarray(np.uint8(pixels)).convert("RGB")
        im.save(f"./data/{split_name}/{label}/{i}.png")

# Dataset loading

In [None]:
# Read the PNG tree under ./data with the generic "imagefolder" builder. The
# rest of the notebook relies on it exposing "train"/"test" splits with an
# "image" column and a "label" class-label column.
ds = load_dataset("imagefolder", data_dir="./data")
print(ds)

# Data exploration

In [None]:
def show_examples(
    ds,
    seed: int = 1234,
    examples_per_class: int = 3,
    size=(350, 350),
    font_path=r"C:\Users\USUARIO\Downloads\liberation_mono\LiberationMono-Bold.ttf",
    font_size: int = 24,
):
    """Build a contact sheet with a few random training images per class.

    Each class occupies one row of the grid; every tile is resized to ``size``
    and captioned with the class name.

    Args:
        ds: Dataset dict with a ``"train"`` split that has ``"image"`` and
            ``"label"`` columns (``"label"`` must expose ``.names``).
        seed: Seed used to shuffle each class before sampling.
        examples_per_class: Number of tiles per row. Classes with fewer images
            simply leave the remaining tiles black.
        size: ``(width, height)`` of a single tile in pixels.
        font_path: TrueType font used for the captions. If it is ``None`` or
            cannot be opened, Pillow's built-in default font is used instead.
        font_size: Point size for ``font_path``.

    Returns:
        A ``PIL.Image.Image`` containing the assembled grid.
    """
    w, h = size
    labels = ds["train"].features["label"].names
    grid = Image.new("RGB", size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)

    # The default font path only exists on one machine; degrade to the bundled
    # font rather than crashing when it is unavailable.
    if font_path is None:
        font = ImageFont.load_default()
    else:
        try:
            font = ImageFont.truetype(font_path, font_size)
        except OSError:
            font = ImageFont.load_default()

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label and shuffle it ...
        matching = (
            ds["train"]
            .filter(lambda ex: ex["label"] == label_id)
            .shuffle(seed)
        )
        # ... then take at most `examples_per_class` samples (never more than
        # the class actually contains, which would raise on select()).
        n_shown = min(examples_per_class, len(matching))
        ds_slice = matching.select(range(n_shown))

        # Plot this label's examples along a row: column = sample, row = class
        for col, example in enumerate(ds_slice):
            box = (col * w, label_id * h)
            grid.paste(example["image"].resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid


# Display the grid as the cell output; the shuffle seed is drawn from the
# `random` module so the sampled images depend on that generator's state.
show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)

# Data preprocessing

In [None]:
# Hub identifier of the pretrained checkpoint that is fine-tuned below.
model_name_or_path = "google/vit-base-patch16-224-in21k"
# Preprocessor loaded from the same checkpoint; `transform` uses it to turn
# images into the tensors the model consumes.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)

In [None]:
# Training-time augmentations only. Pixel rescaling/normalisation is left to
# `feature_extractor` (applied afterwards in `transform`): also running
# A.Normalize() here would normalise every image twice, and with
# albumentations' default statistics instead of the ones stored with the
# checkpoint's preprocessor.
train_transforms = A.Compose([
    A.RandomRotate90(),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
])

# Evaluation images are passed through unchanged; the pipeline object is kept
# so `preprocess_val` mirrors `preprocess_train`.
val_transforms = A.Compose([
    A.NoOp(),
])


def transform(example_batch):
    """Run the feature extractor over a batch and attach its labels.

    Reads the images from ``example_batch["pixel_values"]``, encodes them as
    PyTorch tensors with the module-level ``feature_extractor`` and copies
    ``example_batch["label"]`` onto the result under the ``"label"`` key.
    """
    batch_images = list(example_batch["pixel_values"])
    encoded = feature_extractor(batch_images, return_tensors="pt")
    encoded["label"] = example_batch["label"]
    return encoded

def preprocess_train(examples):
    """Apply the training augmentations to a batch, then encode it.

    Every image in ``examples["image"]`` is converted to a NumPy array and run
    through ``train_transforms``; the results are stored under
    ``examples["pixel_values"]`` before handing the batch to ``transform``.
    """
    augmented = []
    for pil_image in examples["image"]:
        pixels = np.array(pil_image)
        augmented.append(train_transforms(image=pixels)["image"])
    examples["pixel_values"] = augmented

    return transform(examples)

def preprocess_val(examples):
    """Apply the evaluation transforms to a batch, then encode it.

    Every image in ``examples["image"]`` is converted to a NumPy array and run
    through ``val_transforms``; the results are stored under
    ``examples["pixel_values"]`` before handing the batch to ``transform``.
    """
    processed = []
    for pil_image in examples["image"]:
        pixels = np.array(pil_image)
        processed.append(val_transforms(image=pixels)["image"])
    examples["pixel_values"] = processed

    return transform(examples)

In [None]:
# Attach the preprocessing functions to each split: the training split gets the
# augmenting pipeline, the test split the evaluation one.
train_ds = ds["train"].with_transform(preprocess_train)
test_ds = ds["test"].with_transform(preprocess_val)

In [None]:
def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Stacks each sample's ``"pixel_values"`` tensor along a new leading
    dimension and gathers the ``"label"`` values into a tensor stored under
    ``"labels"``.
    """
    stacked_pixels = torch.stack([sample["pixel_values"] for sample in batch])
    label_tensor = torch.tensor([sample["label"] for sample in batch])
    return {"pixel_values": stacked_pixels, "labels": label_tensor}

In [None]:
# F1 metric object shared by every evaluation call.
metric = load_metric("f1")


def compute_metrics(p):
    """Return the weighted F1 score for one evaluation pass.

    ``p.predictions`` holds per-class scores; the highest-scoring class along
    axis 1 is taken as the prediction and compared against ``p.label_ids``.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(
        predictions=predicted_ids,
        references=p.label_ids,
        average="weighted",
    )

In [None]:
# Class names as exposed by the training split's "label" feature.
labels = ds["train"].features["label"].names

# Load the pretrained checkpoint with a classification head sized for our
# classes. Both id<->label mappings use stringified indices, and
# `ignore_mismatched_sizes` lets weights whose shapes differ from the
# checkpoint be re-initialised instead of raising.
model = AutoModelForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(index): name for index, name in enumerate(labels)},
    label2id={name: str(index) for index, name in enumerate(labels)},
    ignore_mismatched_sizes=True,
)

# Training

In [None]:
# Hyper-parameters and bookkeeping options for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./vit-base-beans",  # same directory the clean-up cell removes
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,  # 16 x 4 = 64 samples per optimizer step per device
    gradient_checkpointing=True,  # trade extra compute for lower activation memory
    evaluation_strategy="steps",  # evaluate every `eval_steps` steps
    num_train_epochs=100,
    # fp16=True,
    tf32=True,  # NOTE(review): TF32 needs supporting GPU hardware -- confirm on the target machine
    save_steps=100,  # kept equal to eval_steps so each evaluation has a matching checkpoint
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,  # cap the number of checkpoints kept on disk
    remove_unused_columns=False,  # the on-the-fly transforms still need the raw "image" column
    push_to_hub=False,
    report_to="tensorboard",
    load_best_model_at_end=True  # reload the best checkpoint once training finishes
)

In [None]:
# Wire model, data, batching and metrics together. The test split doubles as
# the evaluation set used during training.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=feature_extractor,  # passed so the preprocessor travels with the model (see Trainer docs)
)

In [None]:
# Fine-tune, then persist the model, the training metrics and the trainer state.
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluation

In [None]:
# Score the trained model on the test split and log/save the results.
# NOTE: the name `metrics` is re-bound to the sklearn module by a later
# `from sklearn import metrics` import.
metrics = trainer.evaluate(test_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [None]:
import torch
import torch.nn.functional as F
from sklearn import metrics

# Run the model over the test split one sample at a time and record the
# predicted and true class *names*, so both lists live in the same label space
# as `labels` (which the confusion matrix and report use).

# Follow the model to whatever device it lives on instead of assuming CUDA, and
# make sure dropout etc. are disabled for inference.
device = model.device
model.eval()

y_preds = []
y_trues = []
for data in test_ds:
    # Add a batch dimension of 1 and move the sample next to the model.
    x = data["pixel_values"].unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x).logits
    probs = F.softmax(logits, dim=-1)[0].cpu()
    # The most probable class index is mapped to its name.
    y_preds.append(labels[int(probs.argmax())])
    y_trues.append(labels[data["label"]])

In [None]:
# Normalise both lists to plain strings so they compare equal to the class
# names passed to the sklearn reports.
y_trues = list(map(str, y_trues))
y_preds = list(map(str, y_preds))

In [None]:
# Confusion matrix over the test split; `labels=labels` fixes the row/column
# order to the dataset's class-name order.
confusion_matrix = metrics.confusion_matrix(y_trues, y_preds, labels=labels)
print(confusion_matrix)

In [None]:
import seaborn as sns
# Plot the confusion matrix as an annotated heatmap (counts shown without decimals).
sns.heatmap(confusion_matrix, annot=True, fmt=".0f", linewidth=.1, cmap="crest")

In [None]:
from sklearn.metrics import classification_report
# Per-class summary of the predictions, with rows named after the class names in `labels`.
print(classification_report(y_trues, y_preds, target_names=labels))