<a href="https://colab.research.google.com/github/MuhammadIrzam447/MultiModel/blob/master/Train_12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# google/vit-base-patch16-224

In [None]:
# !gdown

In [None]:
# !unzip

In [None]:
!pip install transformers

In [None]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor, AdamW
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_auc_score
from PIL import Image

In [None]:
# Root directories of the training and validation splits (consumed by ImageFolder later on).
train_data_root, val_data_root = (
    "/content/Dataset(s)/fused-food-101-train",
    "/content/Dataset(s)/fused-food-101-test",
)

In [None]:
from transformers import ViTImageProcessor

# Read the normalisation statistics and the target resolution from the processor
# config that ships with the checkpoint instead of hard-coding them.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

# Define transformations for the input images.
# Resize is given an explicit (height, width) pair: with a bare int torchvision only
# scales the shorter edge and keeps the aspect ratio, so non-square images would come
# out with different shapes and could not be stacked into one batch.
transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_mean, std=image_std),
])


In [None]:
# Build one ImageFolder per split; both splits go through the same `transform` pipeline.
train_dataset, val_dataset = (
    ImageFolder(root=split_root, transform=transform)
    for split_root in (train_data_root, val_data_root)
)

In [None]:
# Two-way lookup between class names and integer ids, derived from the
# class_to_idx mapping of the training dataset.
label2id = dict(train_dataset.class_to_idx)
id2label = {class_id: class_name for class_name, class_id in label2id.items()}

In [None]:
# (Unused) legacy feature extractor, kept for reference:
# feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# Loader settings -- tune both to the memory / CPU budget of the runtime.
batch_size = 32
num_workers = 1

# Training batches are reshuffled every epoch; validation keeps the dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
# Number of classes in the training dataset; passed as `num_labels` when the
# classifier is built further down.
num_classes = len(train_dataset.classes)
print(num_classes)

101


In [None]:
save_dir = '/content/Model/Models-Train-12/'
load_path = os.path.join(save_dir, '8_model.pth')

# Choose the device before touching the checkpoint so the saved tensors can be
# mapped straight onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the architecture from the pretrained checkpoint with a num_classes-way head.
# ignore_mismatched_sizes lets the checkpoint's differently-shaped classifier
# weights be skipped instead of raising.
vit = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
    num_labels=num_classes,
)

# Overwrite every weight with the fine-tuned state saved at `load_path`.
# map_location keeps this working on a CPU-only runtime even when the checkpoint
# was written from a CUDA device.
vit.load_state_dict(torch.load(load_path, map_location=device))

vit.to(device)
print(vit)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([101, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([101]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [None]:
# vit = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224", id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)
# vit.classifier = nn.Linear(vit.config.hidden_size, num_classes)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# vit.to(device)
# print(vit)

In [None]:
# Multi-class classification loss applied to the raw logits.
criterion = nn.CrossEntropyLoss()
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and weight_decay
# are set explicitly to the values the transformers implementation used by default
# (torch's own defaults differ), keeping the hyper-parameters of earlier epochs.
optimizer = optim.AdamW(vit.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)
# optimizer = optim.SGD(vit.parameters(), lr=0.001, momentum=0.9)
num_epochs = 12



In [None]:
# Checkpoint numbering continues after '8_model.pth' (the weights this run resumes
# from), so the first epoch of this loop is written as '9_model.pth'.
checkpoint_offset = 9

# Loop-invariant setup: resolve and create the checkpoint directory once.
save_dir = "/content/Model/Models-Train-12/"
os.makedirs(save_dir, exist_ok=True)  # Create the directory if it doesn't exist

for epoch in range(num_epochs):
    # ---------------- Training ----------------
    vit.train()
    train_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = vit(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # average below is per-sample even when the last batch is smaller.
        train_loss += loss.item() * images.size(0)

    # Calculate average loss for this epoch
    train_loss /= len(train_loader.dataset)

    # Persist the weights after every epoch so an interrupted run can be resumed.
    model_name = str(epoch + checkpoint_offset) + "_model.pth"
    save_path = os.path.join(save_dir, model_name)  # Specify the complete path to the model file
    torch.save(vit.state_dict(), save_path)

    # ---------------- Validation ----------------
    vit.eval()
    val_loss = 0.0
    correct = 0
    predicted_classes = []
    actual_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = vit(images).logits
            val_loss += criterion(outputs, labels).item() * images.size(0)

            # argmax directly on the logits: softmax is monotonic, so applying it
            # first cannot change which class wins.
            predicted = torch.argmax(outputs, dim=1)

            correct += (predicted == labels).sum().item()

            predicted_classes.extend(predicted.cpu().tolist())
            actual_labels.extend(labels.cpu().tolist())

    # Calculate average loss and accuracy for validation set
    val_loss /= len(val_loader.dataset)
    accuracy = correct / len(val_loader.dataset)

    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {train_loss:.4f} - Validation Loss: {val_loss:.4f} - Accuracy: {accuracy:.4f}")

    # Compute evaluation metrics using the predicted_classes and actual_labels lists
    accuracy = accuracy_score(actual_labels, predicted_classes)
    precision = precision_score(actual_labels, predicted_classes, average='weighted')
    recall = recall_score(actual_labels, predicted_classes, average='weighted')
    f1 = f1_score(actual_labels, predicted_classes, average='weighted')

    print("Accuracy:", accuracy)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1-score:", f1)
    print(classification_report(actual_labels, predicted_classes))
    print("Confusion Matrix:")
    print(confusion_matrix(actual_labels, predicted_classes))
    # print("AUROC:", roc_auc_score(actual_labels, predicted_classes))

Epoch 1/12 - Training Loss: 0.0150 - Validation Loss: 0.9687 - Accuracy: 0.8216
Accuracy: 0.8215795034337031
Precision: 0.8251463022019719
Recall: 0.8215795034337031
F1-score: 0.8220048196654802
              precision    recall  f1-score   support

           0       0.80      0.84      0.82       234
           1       0.92      0.93      0.92       221
           2       0.92      0.92      0.92       226
           3       0.80      0.77      0.78       222
           4       0.77      0.55      0.64       225
           5       0.89      0.79      0.84       224
           6       0.80      0.77      0.79       224
           7       0.90      0.76      0.83       225
           8       0.80      0.76      0.78       226
           9       0.77      0.82      0.79       214
          10       0.82      0.84      0.83       231
          11       0.85      0.89      0.87       227
          12       0.85      0.86      0.85       230
          13       0.88      0.91      0.90     

KeyboardInterrupt: ignored