In [5]:
!pip install transformers datasets torch torchvision


In [6]:
!pip install tf-keras

In [None]:
import torch
from transformers import ViTForImageClassification, TrainingArguments, Trainer, ViTImageProcessor
from datasets import load_dataset
from PIL import Image
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Release unused cached GPU memory before the model is loaded.
torch.cuda.empty_cache()

# Hub identifiers for the dataset and the pretrained checkpoint.
DATASET_NAME = "sagar27kumar/ECG-XRAY-dataset"
CHECKPOINT = "google/vit-large-patch32-384"

# Class names in label-index order; both label mappings are derived from it.
CLASS_NAMES = [
    "Abnormal Heartbeat",
    "History of MI",
    "Myocardial Infarction",
    "Normal Person",
]

dataset = load_dataset(DATASET_NAME)

processor = ViTImageProcessor.from_pretrained(CHECKPOINT)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ViTForImageClassification.from_pretrained(
    CHECKPOINT,
    num_labels=len(CLASS_NAMES),
    id2label=dict(enumerate(CLASS_NAMES)),
    label2id={name: index for index, name in enumerate(CLASS_NAMES)},
    # Presumably needed because the checkpoint's classification head does not
    # have 4 outputs -- TODO confirm against the checkpoint config.
    ignore_mismatched_sizes=True,
)
model = model.to(device)


# Preprocessing function
def preprocess_function(examples):
    """Convert a batch of raw examples into model inputs.

    Args:
        examples: Batch mapping with an ``"image"`` column (file paths or
            already-decoded PIL images) and a ``"label"`` column of integer
            class ids.

    Returns:
        The processor output (requested with ``return_tensors="pt"``) with a
        ``"labels"`` int64 tensor added.
    """
    images = []
    for item in examples["image"]:
        if isinstance(item, Image.Image):
            # Already decoded; only normalise the colour mode.
            images.append(item.convert("RGB"))
        else:
            # ``convert`` returns an independent in-memory copy, so the file
            # handle can be closed as soon as the conversion is done.
            with Image.open(item) as opened:
                images.append(opened.convert("RGB"))

    inputs = processor(images=images, return_tensors="pt")
    inputs["labels"] = torch.tensor(examples["label"], dtype=torch.long)
    return inputs

# Run the preprocessing over the dataset, one batch of examples per call.
prepared_dataset = dataset.map(function=preprocess_function, batched=True)

# Define data collator function
def data_collator(features):
    """Collate per-example feature dicts into a single model batch.

    Args:
        features: Sequence of dicts, each holding ``"pixel_values"`` for one
            image (a tensor, array, or nested list) and ``"labels"`` (an
            integer class id or a 0-d tensor).

    Returns:
        Dict with ``"pixel_values"`` stacked along a new leading batch
        dimension and ``"labels"`` as a 1-D int64 tensor.
    """
    # Rows read back from a mapped dataset can be plain nested lists rather
    # than tensors; ``torch.stack`` only accepts tensors, so coerce first.
    pixel_values = torch.stack(
        [torch.as_tensor(f["pixel_values"]) for f in features]
    )
    # ``int()`` normalises Python ints, NumPy ints and 0-d tensors alike.
    labels = torch.tensor([int(f["labels"]) for f in features], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    # `eval_strategy` replaces the deprecated `evaluation_strategy` name.
    eval_strategy="epoch",
    # Kept identical to the evaluation strategy, which
    # `load_best_model_at_end=True` requires.
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
    save_total_limit=2,
    # Keep every dataset column; the custom collator selects what it needs.
    remove_unused_columns=False,
    push_to_hub=True,
    hub_model_id="sagar27kumar/sagarsahu_ECG-XRAY-ViT",
    # `use_cpu` replaces the deprecated `no_cuda` flag: force CPU only when
    # no CUDA device is available.
    use_cpu=not torch.cuda.is_available(),
)

# Compute evaluation metrics
def compute_metrics(eval_pred):
    """Compute classification metrics for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)`` as handed over by ``Trainer``.

    Returns:
        Dict with ``"accuracy"`` plus weighted-average ``"precision"``,
        ``"recall"`` and ``"f1"``.
    """
    logits, labels = eval_pred
    # Predicted class = index of the largest logit; hand scikit-learn a
    # plain NumPy array.
    predictions = torch.argmax(torch.tensor(logits), dim=-1).numpy()
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average="weighted"
    )
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Define Trainer
# The metric function is attached up front so the evaluations run during
# training report accuracy/precision/recall/F1 as well as the loss, and no
# second Trainer has to be built afterwards just to evaluate.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_dataset["train"],
    eval_dataset=prepared_dataset["validation"],
    processing_class=processor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train model
trainer.train()

# Push to Hugging Face Hub
trainer.push_to_hub()

# Evaluate model
eval_results = trainer.evaluate()
print(f"✅ Evaluation Results: {eval_results}")
