In [None]:
!pip install transformers datasets torch torchvision

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTConfig, Trainer, TrainingArguments
from datasets import load_dataset
from torch.utils.data import DataLoader

In [None]:
MODEL_NAME = "google/vit-base-patch16-224"  # checkpoint whose config (and weights) the model is built from
NUM_CLASSES = 3  # number of target labels; sizes the classification head
NUM_INPUT_CHANNELS = 1  # channels per input image (single-band NDVI maps)
IMAGE_SIZE = 224  # input height/width the vit-base-patch16-224 checkpoint expects
SEED = 42  # same value as TrainingArguments' default `seed`

# Seed torch now: the model is created before the Trainer seeds the RNG, so without this
# the freshly initialised layers would differ on every run.
torch.manual_seed(SEED)

In [None]:
# TODO: set to the Hugging Face Hub id (or local path) of the NDVI image dataset.
DATASET_NAME = ""
assert DATASET_NAME, "DATASET_NAME is empty: set it to the NDVI dataset's Hub id or local path."

dataset = load_dataset(DATASET_NAME)

# Later cells read dataset["train"] / dataset["validation"] and the "image" / "label" columns;
# fail here with a clear message rather than deep inside preprocessing or training.
missing_splits = {"train", "validation"} - set(dataset.keys())
assert not missing_splits, f"dataset is missing required split(s): {sorted(missing_splits)}"
missing_columns = {"image", "label"} - set(dataset["train"].column_names)
assert not missing_columns, f"dataset is missing required column(s): {sorted(missing_columns)}"

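### Modify the input layer

The ViT patch embedding for `MODEL_NAME` expects 3-channel RGB images. `modify_vit` adapts it to accept `NUM_INPUT_CHANNELS`-channel input and resizes the classification head to `NUM_CLASSES` labels.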

In [None]:
def _adapt_patch_projection(old_projection, num_input_channels):
    """Return a Conv2d like ``old_projection`` that accepts ``num_input_channels`` input channels.

    The old kernels are averaged over the input-channel axis, copied to every new
    channel and rescaled by ``old_in_channels / num_input_channels``. The summed kernel
    is therefore unchanged: an input whose channels all carry the same image gives the
    same response as before. For one channel this equals summing the RGB kernels.
    The bias is copied unchanged.
    """
    new_projection = nn.Conv2d(
        in_channels=num_input_channels,
        out_channels=old_projection.out_channels,
        kernel_size=old_projection.kernel_size,
        stride=old_projection.stride,
        padding=old_projection.padding,
        dilation=old_projection.dilation,
        bias=old_projection.bias is not None,
        device=old_projection.weight.device,
        dtype=old_projection.weight.dtype,
    )
    with torch.no_grad():
        # Conv2d weights are laid out (out_channels, in_channels, kernel_h, kernel_w).
        channel_mean = old_projection.weight.mean(dim=1, keepdim=True)
        scale = old_projection.in_channels / num_input_channels
        new_projection.weight.copy_(channel_mean.expand_as(new_projection.weight) * scale)
        if old_projection.bias is not None:
            new_projection.bias.copy_(old_projection.bias)
    return new_projection


def modify_vit(num_input_channels, num_classes, pretrained=True):
    """Build a ViT classifier for ``num_input_channels``-channel images and ``num_classes`` labels.

    Args:
        num_input_channels: channels per input image (e.g. 1 for a single-band NDVI map).
        num_classes: number of output labels; sizes the classification head.
        pretrained: if True (default), start from MODEL_NAME's pretrained weights.
            The classification head is re-initialised for ``num_classes``. If the
            channel count differs, the patch-embedding convolution is rebuilt from
            the pretrained kernels. If False, build a randomly initialised model from
            MODEL_NAME's config with the requested channel and label counts.

    Returns:
        A ``ViTForImageClassification`` whose config's ``num_channels`` and
        ``num_labels`` match the arguments.
    """
    if not pretrained:
        config = ViTConfig.from_pretrained(MODEL_NAME)
        config.num_labels = num_classes  # sizes the output layer
        config.num_channels = num_input_channels  # sizes the patch-embedding conv
        return ViTForImageClassification(config)

    # ignore_mismatched_sizes lets the checkpoint's classifier head (sized for its
    # original label set) be replaced by a freshly initialised num_classes head.
    model = ViTForImageClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_classes,
        ignore_mismatched_sizes=True,
    )

    # `patch_embeddings` is a ViTPatchEmbeddings wrapper, not a Conv2d. The conv lives
    # in `.projection`, and forward() validates the input's channel count against
    # `.num_channels`, so both must change. Replacing the whole wrapper with a bare
    # Conv2d breaks the model.
    patch_embeddings = model.vit.embeddings.patch_embeddings
    if patch_embeddings.num_channels != num_input_channels:
        patch_embeddings.projection = _adapt_patch_projection(
            patch_embeddings.projection, num_input_channels
        )
        patch_embeddings.num_channels = num_input_channels
        # Keep the saved config consistent with the weights so the model reloads correctly.
        model.config.num_channels = num_input_channels

    return model


In [None]:


# ToTensor first (accepts PIL images or arrays), then resize the tensor to IMAGE_SIZE.
# The ViT patch embedding rejects any other height/width unless
# interpolate_pos_encoding is enabled.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), antialias=True),
    # TODO(review): pretrained ViT checkpoints were trained on inputs normalised to
    # roughly [-1, 1]; add transforms.Normalize(...) if the NDVI images are not already scaled that way.
])


def transform_dataset(example):
    """Convert one row's ``"image"`` into a ``(C, H, W)`` float tensor stored under ``"pixel_values"``.

    Raises:
        ValueError: if the converted image does not have NUM_INPUT_CHANNELS channels,
            so a channel mismatch surfaces here instead of inside the model.
    """
    pixel_values = transform(example["image"])
    if pixel_values.shape[0] != NUM_INPUT_CHANNELS:
        raise ValueError(
            f"expected {NUM_INPUT_CHANNELS} image channel(s), got {pixel_values.shape[0]}; "
            "check the dataset's image mode"
        )
    example["pixel_values"] = pixel_values
    return example


# map() stores the returned tensors as nested Arrow lists; the "torch" format turns
# them back into tensors when rows are read.
# NOTE: rebinds `dataset`. Re-running this cell without reloading the dataset fails
# because the "image" column has already been removed.
dataset = dataset.map(transform_dataset, remove_columns=["image"]).with_format("torch")

# ------------------ COLLATE FUNCTION ------------------
# Passed to the Trainer as `data_collator` to merge a list of dataset rows into one batch.
def collate_fn(batch):
    """Stack per-example rows into the ``pixel_values`` / ``labels`` tensors the model takes.

    ``pixel_values`` may arrive as a tensor or as nested lists (how ``Dataset.map``
    stores outputs), so each one is converted with ``torch.as_tensor`` before stacking.

    Args:
        batch: list of dicts, each with a ``"pixel_values"`` image array and an
            integer-like ``"label"``.

    Returns:
        dict with ``"pixel_values"`` (float32, examples stacked along a new leading
        dim) and ``"labels"`` (int64, shape ``(len(batch),)``).
    """
    pixel_values = torch.stack(
        [torch.as_tensor(example["pixel_values"], dtype=torch.float32) for example in batch]
    )
    labels = torch.tensor([int(example["label"]) for example in batch], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}

In [None]:
# Build the classifier once; the training cell hands this `model` to the Trainer.
model = modify_vit(num_input_channels=NUM_INPUT_CHANNELS, num_classes=NUM_CLASSES)

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./vit-ndvi",
    # `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41) and the old
    # name was later removed, so recent releases reject it with a TypeError.
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    push_to_hub=False,  # Set True if you want to upload to Hugging Face Hub
)

# NOTE(review): assumes the dataset ships a "validation" split. If it only has
# "train"/"test", create one with dataset["train"].train_test_split(...).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=collate_fn,
)

In [None]:
# Run fine-tuning; the value returned by `trainer.train()` is displayed as the cell output.
train_output = trainer.train()
train_output