In [None]:
# -*- coding: utf-8 -*-
"""
Created on Sat Nov 23 12:22:20 2024

@author: Mahwash Shakoor

Autism Detection
"""

import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
import torch.nn as nn
import pytorch_lightning as pl
import torch
from torchmetrics.functional import accuracy
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from PIL import Image



# Reproducibility: seed Python / NumPy / torch RNGs (and DataLoader workers) so the
# training shuffle order and the classifier-head initialisation are repeatable.
pl.seed_everything(42, workers=True)

# Define image transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT input size
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # map [0, 1] -> [-1, 1]
])

# Load datasets (ImageFolder expects one sub-directory per class inside each split)
data_dir = "/content/drive/MyDrive/Data_9B"
train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=transform)

# ImageFolder derives label indices from the sorted folder names of EACH split
# independently; if the splits disagree, the same integer label would mean
# different classes in train vs. val/test. Fail loudly instead.
assert train_dataset.class_to_idx == val_dataset.class_to_idx == test_dataset.class_to_idx, (
    f"Class folders differ between splits: train={train_dataset.class_to_idx}, "
    f"val={val_dataset.class_to_idx}, test={test_dataset.class_to_idx}"
)
# The model below is a 2-way classifier.
assert len(train_dataset.classes) == 2, f"Expected 2 classes, found {train_dataset.classes}"

# Data loaders
BATCH_SIZE = 32
NUM_WORKERS = 4
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# ViT Transformer Model Definition

# Load pre-trained ViT and modify for binary classification
class AutismClassifier(nn.Module):
    """Pre-trained ViT backbone with a freshly initialised linear classification head.

    Every weight loaded from the pre-trained checkpoint is frozen; only the new
    ``nn.Linear`` head created in ``__init__`` is trainable (linear probing).

    Args:
        num_labels: number of output classes (default 2: autism / no autism).
        model_name: Hugging Face checkpoint the backbone is loaded from.
    """

    def __init__(self, num_labels=2, model_name="google/vit-base-patch16-224-in21k"):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name, num_labels=num_labels)

        # Freeze EVERY parameter of the loaded model, including the head that
        # from_pretrained built for `num_labels` ...
        for param in self.vit.parameters():
            param.requires_grad = False

        # ... then swap in a new head. Parameters of a newly constructed nn.Linear
        # have requires_grad=True, so this head is the only trainable part.
        # hidden_size is used instead of classifier.in_features because HF builds
        # an nn.Identity (no in_features) when num_labels == 0.
        self.vit.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)

    def forward(self, x):
        """Return raw class logits, shape (batch, num_labels).

        `x` is presumably a (batch, 3, 224, 224) pixel tensor normalised like the
        notebook's `transform` -- TODO confirm for other callers.
        """
        return self.vit(x).logits


# PyTorch Lightning Module

from torchmetrics.functional import accuracy

class AutismClassifierLit(pl.LightningModule):
    """Lightning wrapper that trains and evaluates ``AutismClassifier``.

    Uses cross-entropy loss and binary accuracy, logged under the
    ``train_*``, ``val_*`` and ``test_*`` keys.

    Args:
        learning_rate: AdamW learning rate.
    """

    def __init__(self, learning_rate=2e-5):
        super().__init__()
        # Persist `learning_rate` in the checkpoint so that `load_from_checkpoint`
        # rebuilds the module with the value actually used for training instead
        # of silently falling back to the default.
        self.save_hyperparameters()
        self.model = AutismClassifier()
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Return (loss, accuracy) for one ``(images, labels)`` batch."""
        images, labels = batch
        logits = self(images)
        loss = self.criterion(logits, labels)

        # Convert logits to predicted classes (0/1 for the 2-class head),
        # which is what torchmetrics' binary accuracy expects.
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, labels, task="binary")
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        # Frozen backbone parameters never receive a .grad, so AdamW skips them;
        # effectively only the classifier head is optimised.
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


# Choose a logger (TensorBoard or WandB)
logger = TensorBoardLogger("logs", name="Autism_Classifier")
# logger = WandbLogger(project="Autism_Classifier")

# Model checkpointing: keep only the checkpoint with the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints/",
    filename="best-checkpoint",
    save_top_k=1,
    mode="min",
)

# Instantiate the model
model = AutismClassifierLit()

# Trainer
trainer = pl.Trainer(
    max_epochs=5,
    devices=1,
    logger=logger,
    callbacks=[checkpoint_callback],
    log_every_n_steps=10
)

# Train the model
trainer.fit(model, train_loader, val_loader)


# Evaluate the model on the test set

# Reload the best checkpoint. `best_model_path` stays "" when no checkpoint was
# written (e.g. training stopped before the first validation run); in that case
# fall back to the in-memory weights instead of crashing on load_from_checkpoint("").
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
    model = AutismClassifierLit.load_from_checkpoint(best_model_path)
else:
    print("No checkpoint was saved; evaluating the last in-memory weights.")

# Test the model
test_results = trainer.test(model, dataloaders=test_loader)

# Inference



def predict_image(image_path, model, transform, class_names=("Autism", "No_Autism")):
    """Classify a single image file and return its class name.

    Args:
        image_path: path to an image file readable by PIL.
        model: module mapping a (1, C, H, W) tensor to (1, num_classes) logits,
            e.g. ``AutismClassifierLit(...).model``.
        transform: preprocessing applied to the RGB PIL image; it should match
            the transform used during training.
        class_names: sequence indexed by predicted class id. The default assumes
            class 0 = autism and class 1 = no autism. ImageFolder assigns ids from
            the alphabetically sorted folder names, so passing
            ``train_dataset.classes`` uses the real mapping (TODO confirm the default
            matches the dataset's folders).

    Returns:
        ``class_names[i]`` for the arg-max logit ``i``.
    """
    model.eval()  # inference behaviour for dropout and similar layers
    # Use a context manager so the file handle is closed; convert() returns an
    # independent, fully loaded copy that outlives the `with` block.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # Put the input on the same device as the model weights (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    batch = transform(image).unsqueeze(0).to(device)  # Add batch dimension
    with torch.no_grad():  # no autograd graph needed for prediction
        logits = model(batch)
    pred = torch.argmax(logits, dim=1).item()
    return class_names[pred]

# Example inference on one test image. The path is built from `data_dir` so it
# follows the dataset location configured when the datasets were loaded.
image_path = os.path.join(data_dir, "test", "autistic", "001.jpg")
prediction = predict_image(image_path, model.model, transform)
print(f"Prediction: {prediction}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params | Mode 
-------------------------------------------------------
0 | model     | AutismClassifier | 85.8 M | train
1 | criterion | CrossEntropyLoss | 0      | train
-------------------------------------------------------
1.5 K     Trainable params
85.8 M    Non-trainable params
85.8 M    Total params
343.201   Total estimated model params size (MB)
3         Modul

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Testing: |          | 0/? [00:00<?, ?it/s]

Prediction: No_Autism


In [None]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Downloading pytorch_lightning-2.4.0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Downloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m926.4/926.4 kB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.11.9 pytorch_lightning-2.4.0 torchmetrics-1.6.0
